use bitvec::vec::BitVec;
use bitvec::{bitvec, order::Lsb0};
use clap::Parser;
use rand::Rng;
use rand_core::OsRng;
use seec_channel::sub_channel;
use std::time::{Duration, Instant};
use tracing_subscriber::EnvFilter;
use zappot::base_ot;
use zappot::ot_ext::{Receiver, Sender};
use zappot::traits::{ExtROTReceiver, ExtROTSender};
use zappot::util::Block;
/// Benchmark random OT extension between a sender and a receiver on 127.0.0.1.
#[derive(Parser, Debug, Clone)]
struct Args {
    /// Number of random OTs to generate.
    #[clap(short, long, default_value_t = 1000)]
    num_ots: usize,
    /// TCP port on 127.0.0.1 used for the sender/receiver connection.
    #[clap(short, long, default_value_t = 8066)]
    port: u16,
}
/// Drives the sender side of the random-OT extension benchmark.
///
/// Listens on `127.0.0.1:<args.port>` for the receiver, opens a sub channel
/// on top of the resulting TCP channel, and generates `args.num_ots` random OTs.
///
/// Returns both sender messages for every OT, followed by the values of the
/// two channel counters (reported by `main` as sent and received bytes),
/// read after the protocol has finished.
///
/// # Panics
/// Panics if listening, sub-channel setup, or the OT protocol fails.
async fn sender(args: Args) -> (Vec<[Block; 2]>, usize, usize) {
    let mut os_rng = OsRng;
    // The extension sender plays the base-OT *receiver* role.
    let mut ot_ext_sender = Sender::new(base_ot::Receiver);

    let addr = ("127.0.0.1", args.port);
    let (mut tcp_tx, sent_bytes, mut tcp_rx, recv_bytes) =
        seec_channel::tcp::listen::<seec_channel::Sender<_>>(addr)
            .await
            .expect("Error listening for channel connection");

    // Third argument (128) is passed through unchanged from the original setup.
    let (sub_tx, mut sub_rx) = sub_channel(&mut tcp_tx, &mut tcp_rx, 128)
        .await
        .expect("Establishing sub channel");

    let random_ots = ot_ext_sender
        .send_random(args.num_ots, &mut os_rng, &sub_tx, &mut sub_rx)
        .await
        .expect("Failed to generate ROTs");

    let sent = sent_bytes.get();
    let received = recv_bytes.get();
    (random_ots, sent, received)
}
/// Drives the receiver side of the random-OT extension benchmark.
///
/// Connects to the sender at `127.0.0.1:<args.port>`, opens a sub channel on
/// top of the TCP channel, draws `args.num_ots` random choice bits from the
/// OS RNG, and runs the random-OT extension receiver with those choices.
///
/// Returns the received blocks together with the choice bits that were used,
/// so the caller can check them against the sender's output.
///
/// # Panics
/// Panics if connecting, sub-channel setup, or the OT protocol fails.
async fn receiver(args: Args) -> (Vec<Block>, BitVec) {
    let mut rng = OsRng;
    // The extension receiver plays the base-OT *sender* role.
    let mut receiver = Receiver::new(base_ot::Sender);
    // This side connects (the sender listens). Its channel counters are
    // unused; `main` reports the sender's counters instead.
    let (mut base_sender, _, mut base_receiver, _) =
        seec_channel::tcp::connect::<seec_channel::Sender<_>>(("127.0.0.1", args.port))
            .await
            .expect("Error connecting to channel");
    let (ch_sender, mut ch_receiver) = sub_channel(&mut base_sender, &mut base_receiver, 128)
        .await
        .expect("Establishing sub channel");
    let choices: BitVec = {
        let mut bv = bitvec![usize, Lsb0; 0; args.num_ots];
        // Randomize whole backing words at once instead of bit by bit.
        rng.fill(bv.as_raw_mut_slice());
        bv
    };
    let ots = receiver
        .receive_random(&choices, &mut rng, &ch_sender, &mut ch_receiver)
        .await
        .expect("Failed to generate ROTs");
    (ots, choices)
}
/// Runs and verifies a random-OT extension between two in-process tasks.
///
/// Spawns the sender, which listens on `127.0.0.1:<port>`, and then the
/// receiver, which connects to it. It prints the elapsed wall-clock time and
/// the sender's channel counters. It then checks every OT: the receiver's
/// block must equal the sender block selected by the choice bit and differ
/// from the other one.
///
/// # Panics
/// Panics in any of these cases:
/// - either task fails;
/// - either side returns a different number of OTs than requested;
/// - any OT fails verification.
#[tokio::main(flavor = "multi_thread", worker_threads = 4)]
async fn main() {
    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .init();
    let args: Args = Args::parse();
    let sender_fut = tokio::spawn(sender(args.clone()));
    // Give the sender a moment to bind its listener before the receiver
    // connects. This wait is not part of the protocol, so the timer starts
    // only after it.
    tokio::time::sleep(Duration::from_millis(50)).await;
    let now = Instant::now();
    let (receiver_ots, choices) = tokio::spawn(receiver(args.clone()))
        .await
        .expect("Error awaiting receiver");
    let (sender_ots, send_cnt, recv_cnt) = sender_fut.await.expect("Error awaiting sender");
    println!(
        "Executed {} ots in {} ms. Sent bytes: {}, Recv bytes: {}",
        args.num_ots,
        now.elapsed().as_millis(),
        send_cnt,
        recv_cnt
    );
    // `zip` stops at the shortest input. Without these checks, a short or
    // empty result from either side would pass verification vacuously.
    assert_eq!(
        sender_ots.len(),
        args.num_ots,
        "sender returned an unexpected number of OTs"
    );
    assert_eq!(
        receiver_ots.len(),
        args.num_ots,
        "receiver returned an unexpected number of OTs"
    );
    assert_eq!(
        choices.len(),
        args.num_ots,
        "unexpected number of choice bits"
    );
    for (i, ((recv, choice), [send0, send1])) in receiver_ots
        .into_iter()
        .zip(choices)
        .zip(sender_ots)
        .enumerate()
    {
        let (chosen, not_chosen) = if choice {
            (send1, send0)
        } else {
            (send0, send1)
        };
        assert_eq!(
            recv, chosen,
            "OT {i}: receiver block does not match the chosen sender block"
        );
        assert_ne!(
            recv, not_chosen,
            "OT {i}: receiver block matches the unchosen sender block"
        );
    }
}