use crate::common::BitVec;
use crate::mul_triple::boolean::MulTriples;
use crate::mul_triple::MTProvider;
use crate::protocols::SetupStorage;
use async_trait::async_trait;
use rand::rngs::OsRng;
use rand::SeedableRng;
use rand_chacha::rand_core::{CryptoRng, RngCore};
use rand_chacha::ChaChaRng;
use std::thread::available_parallelism;
use zappot::silent_ot;
use zappot::silent_ot::MultType;
pub type Msg = silent_ot::Msg;
pub struct SilentMtProvider<Rng> {
    rng: Rng,
    configured_ots: usize,
    stored_mts: Option<MulTriples>,
    silent_sender: Option<silent_ot::Sender>,
    silent_receiver: Option<silent_ot::Receiver>,
    ch1: Option<seec_channel::Channel<silent_ot::Msg>>,
    ch2: Option<seec_channel::Channel<silent_ot::Msg>>,
}
impl<Rng: RngCore + CryptoRng + Send> SilentMtProvider<Rng> {
    pub async fn new(
        num_ots: usize,
        rng: Rng,
        ch1: seec_channel::Channel<silent_ot::Msg>,
        ch2: seec_channel::Channel<silent_ot::Msg>,
    ) -> Self {
        #[cfg(feature = "silent-ot")]
        {
            Self::new_with_mult_type(num_ots, MultType::ExConv7x24, rng, ch1, ch2).await
        }
        #[cfg(not(feature = "silent-ot"))]
        {
            Self::new_with_mult_type(num_ots, MultType::QuasiCyclic { scaler: 2 }, rng, ch1, ch2)
                .await
        }
    }
    pub async fn new_with_mult_type(
        num_ots: usize,
        mul_type: MultType,
        mut rng: Rng,
        mut ch1: seec_channel::Channel<silent_ot::Msg>,
        mut ch2: seec_channel::Channel<silent_ot::Msg>,
    ) -> Self {
        let mut rng1 = ChaChaRng::from_rng(&mut rng).expect("Seeding Rng in SilentMtProvider::new");
        let mut rng2 = ChaChaRng::from_rng(&mut rng).expect("Seeding Rng in SilentMtProvider::new");
        let threads_per_ot = available_parallelism().map(usize::from).unwrap_or(2) / 2;
        let (silent_sender, silent_receiver) = tokio::join!(
            silent_ot::Sender::new_with_base_ot_sender(
                zappot::base_ot::Sender::new(),
                &mut rng1,
                num_ots,
                mul_type,
                threads_per_ot,
                &mut ch1.0,
                &mut ch1.1
            ),
            silent_ot::Receiver::new_with_base_ot_receiver(
                zappot::base_ot::Receiver::new(),
                &mut rng2,
                num_ots,
                mul_type,
                threads_per_ot,
                &mut ch2.0,
                &mut ch2.1
            ),
        );
        Self {
            rng,
            configured_ots: num_ots,
            stored_mts: None,
            silent_sender: Some(silent_sender),
            silent_receiver: Some(silent_receiver),
            ch1: Some(ch1),
            ch2: Some(ch2),
        }
    }
    pub async fn precompute_mts(&mut self) {
        let silent_sender = self
            .silent_sender
            .take()
            .expect("precompute_mts can only be called once");
        let silent_receiver = self.silent_receiver.take().unwrap();
        let ch1 = self.ch1.take().unwrap();
        let ch2 = self.ch2.take().unwrap();
        let send = silent_sender.random_silent_send(&mut self.rng, ch1.0, ch1.1);
        let receive = silent_receiver.random_silent_receive(ch2.0, ch2.1);
        let (send_ots, (recv_ots, a_i)) = tokio::join!(send, receive);
        eprintln!("{send_ots:?}");
        eprintln!("{recv_ots:?}");
        eprintln!("{a_i:?}");
        let mut b_i = BitVec::with_capacity(self.configured_ots);
        let mut v_i: BitVec<usize> = BitVec::with_capacity(self.configured_ots);
        send_ots
            .into_iter()
            .map(|arr| arr.map(|b| b.lsb()))
            .for_each(|[m0, m1]| {
                b_i.push(m0 ^ m1);
                v_i.push(m0);
            });
        let u_i = recv_ots.into_iter().map(|b| b.lsb());
        let c_i = a_i
            .iter()
            .by_vals()
            .zip(b_i.iter().by_vals())
            .zip(u_i)
            .zip(v_i)
            .map(|(((a, b), u), v)| a & b ^ u ^ v)
            .collect();
        self.stored_mts = Some(MulTriples::from_raw(a_i, b_i, c_i));
    }
    pub fn mts_available(&self) -> usize {
        self.stored_mts.as_ref().map(|mts| mts.len()).unwrap_or(0)
    }
}
impl SilentMtProvider<OsRng> {
    pub fn from_raw_mts(mts: MulTriples) -> Self {
        Self {
            rng: OsRng,
            configured_ots: mts.len(),
            stored_mts: Some(mts),
            silent_sender: None,
            silent_receiver: None,
            ch1: None,
            ch2: None,
        }
    }
}
#[async_trait]
impl<Rng: RngCore + CryptoRng + Send> MTProvider for SilentMtProvider<Rng> {
    type Output = MulTriples;
    type Error = ();
    async fn precompute_mts(&mut self, amount: usize) -> Result<(), Self::Error> {
        assert_eq!(
            amount, self.configured_ots,
            "Requested OTs must be equal to configured OTs of SilentMtProvider"
        );
        self.precompute_mts().await;
        Ok(())
    }
    async fn request_mts(&mut self, amount: usize) -> Result<Self::Output, Self::Error> {
        if let Some(stored_mts) = &mut self.stored_mts {
            return Ok(stored_mts.split_off_last(amount));
        }
        self.precompute_mts().await;
        self.request_mts(amount).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[tokio::test]
    async fn silent_mts() {
        let (ch11, ch21) = seec_channel::in_memory::new_pair(128);
        let (ch12, ch22) = seec_channel::in_memory::new_pair(128);
        let (mut mtp1, mut mtp2) = tokio::join!(
            SilentMtProvider::new(100, OsRng, ch11, ch22),
            SilentMtProvider::new(100, OsRng, ch12, ch21)
        );
        let (mts1, mts2) = tokio::try_join!(mtp1.request_mts(100), mtp2.request_mts(100),).unwrap();
        let left = mts1.c ^ mts2.c;
        let right = (mts1.a ^ mts2.a) & (mts1.b ^ mts2.b);
        assert_eq!(left, right);
    }
}