use std::cmp::{max, min};
use std::fmt::Debug;
use std::time::Instant;

use bitpolymul::{DecodeCache, FftPoly};
use bitvec::order::Lsb0;
use bitvec::slice::BitSlice;
use bytemuck::{cast_slice, cast_slice_mut};
use ndarray::Array2;
use num_integer::Integer;
use num_prime::nt_funcs::next_prime;
use rand::Rng;
use rand_core::SeedableRng;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;
use seec_bitmatrix::BitMatrixView;

use crate::silent_ot::get_reg_noise_weight;
use crate::silent_ot::pprf::PprfConfig;
use crate::util::aes_rng::AesRng;
use crate::util::Block;
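/// Encoder for the quasi-cyclic code used by silent OT.
///
/// Holds the `QuasiCyclicConf` parameters together with the FFT-encoded public
/// `a` polynomials that are multiplied into the input during dual encoding.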
#[derive(Debug, Clone)]
pub struct QuasiCyclicEncoder {
pub(crate) conf: QuasiCyclicConf,
a_polynomials: Vec<FftPoly>,
}
impl QuasiCyclicEncoder {
pub(crate) fn new(conf: QuasiCyclicConf) -> Self {
let a = init_a_polynomials(conf);
Self {
conf,
a_polynomials: a,
}
}
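    /// Dual-encode the 128-row Block matrix `rT` into one `Block` per requested OT.
    ///
    /// Each row is reduced independently (in parallel via rayon) with
    /// `MultAddReducer::reduce`; the reduced matrix is then bit-transposed by
    /// `copy_out` and the result is truncated to `requested_num_ots` Blocks.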
pub(crate) fn dual_encode(&self, rT: Array2<Block>) -> Vec<Block> {
let conf = self.conf;
let mut c_mod_p1: Array2<Block> = Array2::zeros((QuasiCyclicConf::ROWS, conf.n_blocks()));
let mut B = vec![Block::zero(); conf.N2];
let reducer = MultAddReducer::new(conf, &self.a_polynomials);
c_mod_p1
.outer_iter_mut()
.into_par_iter()
.zip(rT.outer_iter())
.for_each_init(
|| reducer.clone(),
|reducer, (mut cmod_row, rt_row)| {
let cmod_row = cmod_row.as_slice_mut().unwrap();
let rt_row = rt_row.as_slice().unwrap();
reducer.reduce(cmod_row, rt_row);
},
);
        // copy_out works on chunks of 128 Blocks (one Block per OT), so round the
        // requested amount up to a multiple of 128 before truncating to the exact count
        let num_blocks = Integer::next_multiple_of(&conf.requested_num_ots, &128);
        copy_out(&mut B[..num_blocks], &c_mod_p1);
        B.truncate(self.conf.requested_num_ots);
B
}
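    /// Dual-encode the receiver's packed choice vector `sb`.
    ///
    /// Returns one byte (0 or 1) per requested OT, taken from the bits of the
    /// reduced vector.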
pub(crate) fn dual_encode_choice(&self, sb: &[Block]) -> Vec<u8> {
let mut c128 = vec![Block::zero(); self.conf.n_blocks()];
let mut reducer = MultAddReducer::new(self.conf, &self.a_polynomials);
reducer.reduce(&mut c128, sb);
let mut C = vec![0; self.conf.requested_num_ots];
let c128_bits: &BitSlice<usize, Lsb0> = BitSlice::from_slice(cast_slice(&c128));
C.iter_mut()
.zip(c128_bits.iter().by_vals())
.for_each(|(c, bit)| {
*c = bit as u8;
});
C
}
}
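/// FFT-encode the `scaler - 1` public `a` polynomials. They are expanded from the
/// fixed AES seeds `1..scaler`, so both parties derive the same polynomials.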
fn init_a_polynomials(conf: QuasiCyclicConf) -> Vec<FftPoly> {
let mut temp = vec![0_u64; 2 * conf.n_blocks()];
(0..conf.scaler - 1)
.map(|s| {
let mut fft_poly = FftPoly::new();
let mut pub_rng = AesRng::from_seed((s + 1).into());
pub_rng.fill(&mut temp[..]);
fft_poly.encode(&temp);
fft_poly
})
.collect()
}
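/// Copy each column of `c_mod_p1` into a 128-Block chunk of `dest` and bit-transpose
/// the chunk in place, so that every Block of `dest` ends up holding one OT value.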
fn copy_out(dest: &mut [Block], c_mod_p1: &Array2<Block>) {
    assert_eq!(dest.len() % 128, 0, "dest length must be a multiple of 128");
dest.par_chunks_exact_mut(128)
.enumerate()
.for_each(|(i, chunk)| {
chunk
.iter_mut()
.zip(c_mod_p1.column(i))
.for_each(|(block, cmod)| *block = *cmod);
let transposed = BitMatrixView::from_slice(chunk, 128, 128)
.fast_transpose()
.into_vec();
chunk.copy_from_slice(&transposed);
});
}
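/// Parameters of the quasi-cyclic encoding, derived via `QuasiCyclicConf::configure`.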
#[derive(Copy, Clone, Debug)]
pub struct QuasiCyclicConf {
    /// The next prime past `max(requested_num_ots, 128 * 128)`.
    pub(crate) P: usize,
    /// Number of OTs the caller asked for.
    pub(crate) requested_num_ots: usize,
    /// Code length after reduction, `N = N2 / scaler`.
    pub(crate) N: usize,
    /// Scaled code length, `N2 = size_per * num_partitions`.
    pub(crate) N2: usize,
    /// Scaling factor of the dual encoding.
    pub(crate) scaler: usize,
    /// Size of each noise partition, rounded up to a multiple of 8.
    pub(crate) size_per: usize,
    /// Number of noise partitions (the regular noise weight).
    pub(crate) num_partitions: usize,
}
impl QuasiCyclicConf {
pub const ROWS: usize = 128;
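    /// Derive the encoding parameters for `num_ots` OTs.
    ///
    /// `P` is the next prime past `max(num_ots, 128 * 128)`, `num_partitions` is the
    /// regular noise weight for min-distance ratio 0.2 at the given security
    /// parameter, and `N2 = size_per * num_partitions` is the scaled code length with
    /// `size_per` rounded up to a multiple of 8; `N = N2 / scaler` is the unscaled
    /// length.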
pub fn configure(num_ots: usize, scaler: usize, sec_param: usize) -> Self {
let P = next_prime(&max(num_ots, 128 * 128), None).unwrap();
let num_partitions = get_reg_noise_weight(0.2, sec_param) as usize;
let ss = (P * scaler + num_partitions - 1) / num_partitions;
let size_per = Integer::next_multiple_of(&ss, &8);
let N2 = size_per * num_partitions;
let N = N2 / scaler;
Self {
P,
num_partitions,
size_per,
N2,
N,
scaler,
requested_num_ots: num_ots,
}
}
pub fn n_blocks(&self) -> usize {
self.N / Self::ROWS
}
pub fn n2_blocks(&self) -> usize {
self.N2 / Self::ROWS
}
pub fn n64(self) -> usize {
self.n_blocks() * 2
}
pub fn P(&self) -> usize {
self.P
}
pub fn requested_num_ots(&self) -> usize {
self.requested_num_ots
}
pub fn N(&self) -> usize {
self.N
}
pub fn N2(&self) -> usize {
self.N2
}
pub fn scaler(&self) -> usize {
self.scaler
}
pub fn size_per(&self) -> usize {
self.size_per
}
pub fn num_partitions(&self) -> usize {
self.num_partitions
}
pub fn base_ot_count(&self) -> usize {
let pprf_conf = PprfConfig::from(*self);
pprf_conf.base_ot_count()
}
}
impl From<QuasiCyclicConf> for PprfConfig {
fn from(conf: QuasiCyclicConf) -> Self {
PprfConfig::new(conf.size_per, conf.num_partitions)
}
}
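/// Helper that multiplies chunks of an input vector with the public `a` polynomials
/// in the FFT domain, accumulates the products, and reduces the sum to `P` bits via
/// `modp`. One clone of it is used per rayon worker in `dual_encode`.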
#[derive(Clone)]
pub struct MultAddReducer<'a> {
a_polynomials: &'a [FftPoly],
conf: QuasiCyclicConf,
b_poly: FftPoly,
temp128: Vec<Block>,
cache: DecodeCache,
}
impl<'a> MultAddReducer<'a> {
pub(crate) fn new(conf: QuasiCyclicConf, a_polynomials: &'a [FftPoly]) -> Self {
Self {
a_polynomials,
conf,
b_poly: FftPoly::new(),
temp128: vec![Block::zero(); 2 * conf.n_blocks()],
cache: DecodeCache::default(),
}
}
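    /// Multiply every higher `n64`-word chunk of `b128` with its `a` polynomial in
    /// the FFT domain, XOR in the first chunk, and reduce the accumulated result
    /// into `dest` via `modp`.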
pub(crate) fn reduce(&mut self, dest: &mut [Block], b128: &[Block]) {
let n64 = self.conf.n64();
let mut c_poly = FftPoly::new();
for s in 1..self.conf.scaler {
let a_poly = &self.a_polynomials[s - 1];
let b64 = &cast_slice(b128)[s * n64..(s + 1) * n64];
let _now = Instant::now();
self.b_poly.encode(b64);
if s == 1 {
c_poly.mult(a_poly, &self.b_poly);
} else {
self.b_poly.mult_eq(a_poly);
c_poly.add_eq(&self.b_poly);
}
}
c_poly.decode_with_cache(&mut self.cache, cast_slice_mut(&mut self.temp128));
self.temp128
.iter_mut()
.zip(b128)
.take(self.conf.n_blocks())
.for_each(|(t, b)| *t ^= *b);
modp(dest, &self.temp128, self.conf.P);
}
}
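/// Reduce the bit vector `inp` modulo `x^prime + 1` over GF(2): the `prime`-bit
/// chunks of `inp` are XORed together into the first `prime` bits of `dest`, and all
/// bits of `dest` past `prime` are cleared.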
pub fn modp(dest: &mut [Block], inp: &[Block], prime: usize) {
let p: usize = prime;
let p_blocks = (p + 127) / 128;
let p_bytes = (p + 7) / 8;
let dest_len = dest.len();
assert!(dest_len >= p_blocks);
assert!(inp.len() >= p_blocks);
    // number of p-bit chunks in inp
    let count = (inp.len() * 128 + p - 1) / p;
    {
        // copy the first chunk directly (rounded up to whole bytes);
        // the excess bits are cleared below
        let dest_bytes = cast_slice_mut::<_, u8>(dest);
        let inp_bytes = cast_slice::<_, u8>(inp);
        dest_bytes[..p_bytes].copy_from_slice(&inp_bytes[..p_bytes]);
    }
    // XOR every further p-bit chunk of inp into dest
    for i in 1..count {
        let begin = i * p;
        let begin_block = begin / 128;
        // end of the chunk in bits, clamped to the input, then rounded up to whole blocks
        let end_block = min(i * p + p, inp.len() * 128);
        let end_block = (end_block + 127) / 128;
        assert!(end_block <= inp.len());
        let in_i = &inp[begin_block..end_block];
        let shift = begin & 127;
        bit_shift_xor(dest, in_i, shift as u8);
    }
    // clear all bits of dest at positions >= p
    let dest_bytes = cast_slice_mut::<_, u8>(dest);
    let offset = p & 7;
    if offset != 0 {
        let mask = ((1 << offset) - 1) as u8;
        let idx = p / 8;
        dest_bytes[idx] &= mask;
    }
    let rem = dest_len * 16 - p_bytes;
    if rem != 0 {
        dest_bytes[p_bytes..p_bytes + rem].fill(0);
    }
}
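/// XOR the bit string of `inp`, starting at bit offset `bit_shift`, into `dest`.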
pub fn bit_shift_xor(dest: &mut [Block], inp: &[Block], bit_shift: u8) {
    assert!(bit_shift <= 127, "bit_shift must be at most 127");
    // XOR the shifted window over consecutive pairs of input blocks into dest
    dest.iter_mut()
        .zip(inp)
        .zip(&inp[1..])
        .for_each(|((d, inp), inp_off)| {
            let mut shifted = *inp >> bit_shift;
            shifted |= *inp_off << (128 - bit_shift);
            *d ^= shifted;
        });
    // the last input block has no successor; XOR its remaining shifted bits
    if dest.len() >= inp.len() {
        let inp_last = *inp.last().expect("empty input");
        dest[inp.len() - 1] ^= inp_last >> bit_shift;
    }
}
#[cfg(test)]
mod tests {
use crate::silent_ot::quasi_cyclic_encode::{bit_shift_xor, modp};
use crate::util::Block;
use bitvec::order::Lsb0;
use bitvec::prelude::{BitSlice, BitVec};
use std::cmp::min;
#[test]
fn basic_bit_shift_xor() {
let dest = &mut [Block::zero(), Block::zero()];
let inp = &[Block::all_ones(), Block::all_ones()];
let bit_shift = 10;
bit_shift_xor(dest, inp, bit_shift);
assert_eq!(Block::all_ones(), dest[0]);
let exp = Block::from(u128::MAX >> bit_shift);
assert_eq!(exp, dest[1]);
}
#[test]
fn basic_modp() {
let i_bits = 1026;
let n_bits = 223;
let n = (n_bits + 127) / 128;
let c = (i_bits + n_bits - 1) / n_bits;
let mut dest = vec![Block::zero(); n];
let mut inp = vec![Block::all_ones(); (i_bits + 127) / 128];
let p = n_bits;
let inp_bits: &mut BitSlice<usize, Lsb0> =
BitSlice::from_slice_mut(bytemuck::cast_slice_mut(&mut inp));
inp_bits[i_bits..].fill(false);
let mut dv: BitVec<usize, Lsb0> = BitVec::repeat(true, p);
let mut iv: BitVec<usize, Lsb0> = BitVec::new();
for j in 1..c {
let rem = min(p, i_bits - j * p);
iv.clear();
let inp = &inp_bits[j * p..(j * p) + rem];
iv.extend_from_bitslice(inp);
iv.resize(p, false);
dv ^= &iv;
}
modp(&mut dest, &inp, p);
let dest_bits: &BitSlice<usize, Lsb0> = BitSlice::from_slice(bytemuck::cast_slice(&dest));
let dv2 = &dest_bits[..p];
assert_eq!(dv, dv2);
}
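    /// Sanity check of the parameter relationships established by
    /// `QuasiCyclicConf::configure`; the concrete arguments are only illustrative.
    #[test]
    fn configure_invariants() {
        use crate::silent_ot::quasi_cyclic_encode::QuasiCyclicConf;
        let conf = QuasiCyclicConf::configure(1 << 16, 2, 128);
        // P is a prime at least as large as the requested number of OTs
        assert!(conf.P() >= 1 << 16);
        // N2 is split into num_partitions chunks of size_per, each a multiple of 8
        assert_eq!(conf.N2(), conf.size_per() * conf.num_partitions());
        assert_eq!(conf.size_per() % 8, 0);
        // N is the unscaled code length
        assert_eq!(conf.N(), conf.N2() / conf.scaler());
    }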
}