use crate::circuit::base_circuit::Load;
use crate::circuit::circuit_connections::CrossCircuitConnections;
use crate::circuit::{base_circuit, BaseCircuit, CircuitId, DefaultIdx, GateIdx, LayerIterable};
use crate::errors::CircuitError;
use crate::gate::base::BaseGate;
use crate::protocols::{Gate, Plain, Wire};
use crate::{bristol, BooleanGate, SharedCircuit, SubCircuitGate};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::marker::PhantomData;
use std::mem;
use std::num::NonZeroUsize;
use std::path::Path;
use std::sync::Arc;
use tracing::trace;
/// A circuit made up of [`BaseCircuit`]s (sub-circuits) and the connections
/// between them.
///
/// Sub-circuit id `0` acts as the main circuit: `input_count`, `output_count`
/// and layer iteration (`CircuitLayerIter::new`) start from it. A default /
/// freshly created circuit contains no sub-circuits at all.
#[derive(Debug, Serialize, Deserialize)]
// Explicit serde bounds: these replace the bounds serde would otherwise infer
// from the type parameters. No bound is listed for `P`.
#[serde(bound = "\
G: serde::Serialize + serde::de::DeserializeOwned,\
Idx: GateIdx + Ord + Eq + Hash + serde::Serialize + serde::de::DeserializeOwned,\
W: serde::Serialize + serde::de::DeserializeOwned")]
pub struct Circuit<P = bool, G = BooleanGate, Idx = DefaultIdx, W = ()> {
/// Storage for the base circuits; looked up through `circ_map`.
pub(crate) circuits: Vec<BaseCircuit<P, G, Idx, W>>,
/// Maps a sub-circuit id to an index into `circuits`.
// NOTE(review): nothing prevents several ids from pointing at the same base
// circuit (e.g. reused sub-circuits) — confirm whether that is intended usage.
pub(crate) circ_map: HashMap<CircuitId, usize>,
/// Gate connections that cross sub-circuit boundaries.
pub(crate) connections: CrossCircuitConnections<Idx>,
}
impl<P, G, Idx, W> Circuit<P, G, Idx, W> {
    /// Returns the base circuit registered under sub-circuit id `id`.
    ///
    /// # Panics
    /// Panics if `id` has no entry in the id map, or if that entry points
    /// outside the stored base circuits.
    pub fn get_circ(&self, id: CircuitId) -> &BaseCircuit<P, G, Idx, W> {
        let slot = self.circ_map[&id];
        &self.circuits[slot]
    }

    /// Iterates the base circuits for sub-circuit ids `0..circ_map.len()`, in
    /// ascending id order. Ids that share a base circuit yield it once per id.
    ///
    /// # Panics
    /// The returned iterator panics if the registered ids are not exactly
    /// `0..circ_map.len()`.
    pub fn iter_circs(&self) -> impl Iterator<Item = &BaseCircuit<P, G, Idx, W>> + '_ {
        let id_count = self.circ_map.len();
        (0..id_count).map(move |raw_id| self.get_circ(raw_id as CircuitId))
    }
}
impl<P: Plain, G: Gate<P>, Idx: GateIdx, W: Wire> Circuit<P, G, Idx, W> {
    /// Creates an empty circuit that has no sub-circuits.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the gate identified by `id`.
    ///
    /// # Panics
    /// Panics if `id.circuit_id` is not a registered sub-circuit.
    pub fn get_gate(&self, id: SubCircuitGate<Idx>) -> G {
        let sub_circuit = self.get_circ(id.circuit_id);
        sub_circuit.get_gate(id.gate_id)
    }

    /// Returns the parents of `id`. Parents inside the same sub-circuit come
    /// first, followed by parents reached through cross-circuit connections.
    pub fn parent_gates(
        &self,
        id: SubCircuitGate<Idx>,
    ) -> impl Iterator<Item = SubCircuitGate<Idx>> + '_ {
        let circuit_id = id.circuit_id;
        let local_parents = self
            .get_circ(circuit_id)
            .parent_gates(id.gate_id)
            .map(move |gate_idx| SubCircuitGate::new(circuit_id, gate_idx));
        let cross_parents = self.connections.parent_gates(id);
        local_parents.chain(cross_parents)
    }

    /// Total gate count, summed over every sub-circuit id (see `iter_circs`).
    pub fn gate_count(&self) -> usize {
        self.iter_circs()
            .fold(0, |total, circ| total + circ.gate_count())
    }

    /// Iterates all gates layer by layer across sub-circuits. Each item is a
    /// gate paired with its sub-circuit-qualified id. Within one layer,
    /// sub-circuits appear in ascending id order.
    pub fn iter(&self) -> impl Iterator<Item = (G, SubCircuitGate<Idx>)> + Clone + '_ {
        CircuitLayerIter::new(self).flat_map(|circuit_layer| {
            circuit_layer
                .sc_layers
                .into_iter()
                .flat_map(|(circuit_id, _simd_size, layer)| layer.into_sc_iter(circuit_id))
        })
    }

    /// Like [`Self::iter`], but keeps only gates whose `is_interactive()` is true.
    pub fn interactive_iter(&self) -> impl Iterator<Item = (G, SubCircuitGate<Idx>)> + Clone + '_ {
        self.iter().filter(|entry| entry.0.is_interactive())
    }
}
impl<P, G, Idx, W> Circuit<P, G, Idx, W> {
pub fn interactive_count(&self) -> usize {
self.iter_circs().map(|circ| circ.interactive_count()).sum()
}
pub fn interactive_count_times_simd(&self) -> usize {
self.iter_circs()
.map(|circ| {
circ.interactive_count() * circ.simd_size().map(NonZeroUsize::get).unwrap_or(1)
})
.sum()
}
pub fn input_count(&self) -> usize {
self.get_circ(0).input_count()
}
pub fn output_count(&self) -> usize {
self.get_circ(0).output_count()
}
}
impl<P, G> Circuit<P, G, u32>
where
    P: Clone,
    G: Gate<P> + From<BaseGate<P>> + for<'a> From<&'a bristol::Gate>,
{
    /// Loads a Bristol-format circuit from `path`. The result is a [`Circuit`]
    /// whose only sub-circuit (id `0`) is the loaded base circuit.
    ///
    /// # Errors
    /// Propagates the error returned by `BaseCircuit::load_bristol`.
    pub fn load_bristol(path: impl AsRef<Path>) -> Result<Self, CircuitError> {
        let base = BaseCircuit::load_bristol(path, Load::Circuit)?;
        Ok(Self::from(base))
    }
}
impl<P: Clone, G: Clone, Idx: GateIdx, W: Clone> Clone for Circuit<P, G, Idx, W> {
    /// Clones each field: the base circuits, the id map and the cross-circuit
    /// connections.
    // NOTE(review): presumably hand-written rather than derived to control the
    // trait bounds (`Idx: GateIdx` instead of `Idx: Clone`) — confirm.
    fn clone(&self) -> Self {
        let Self {
            circuits,
            circ_map,
            connections,
        } = self;
        Self {
            circuits: circuits.clone(),
            circ_map: circ_map.clone(),
            connections: connections.clone(),
        }
    }
}
/// Iterator over the layers of a [`Circuit`] that spans all of its
/// sub-circuits. Each item is a [`CircuitLayer`].
#[derive(Clone)]
pub struct CircuitLayerIter<'a, P, G, Idx: GateIdx, W> {
/// The circuit being iterated.
circuit: &'a Circuit<P, G, Idx, W>,
/// Live per-sub-circuit layer iterators, keyed by sub-circuit id. Entries are
/// created on demand for sub-circuits reached via cross-circuit connections,
/// and removed once they report themselves exhausted.
layer_iters: HashMap<CircuitId, base_circuit::BaseLayerIter<'a, P, G, Idx, W>>,
}
impl<'a, P, G: Gate<P>, Idx: GateIdx, W: Wire> CircuitLayerIter<'a, P, G, Idx, W> {
pub fn new(circuit: &'a Circuit<P, G, Idx, W>) -> Self {
let first_iter = circuit.get_circ(0).layer_iter();
Self {
circuit,
layer_iters: [(0, first_iter)].into(),
}
}
}
/// One layer of a [`Circuit`]. It combines the layers that the active
/// sub-circuits produced in the same iteration step.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitLayer<P, G, Idx: Hash + PartialEq + Eq> {
/// Entries of the form `(sub-circuit id, SIMD size of that sub-circuit,
/// that sub-circuit's layer)`. `CircuitLayerIter` produces them sorted by
/// sub-circuit id.
pub(crate) sc_layers: Vec<(
CircuitId,
Option<NonZeroUsize>,
base_circuit::CircuitLayer<G, Idx>,
)>,
/// Marks the otherwise unused type parameter `P`.
_plain: PhantomData<P>,
}
/// Yields a [`Circuit`] layer by layer across all of its sub-circuits.
///
/// Each call advances every tracked sub-circuit iterator by one layer and
/// bundles the produced layers into one [`CircuitLayer`].
impl<'a, P: Debug, G: Gate<P>, Idx: GateIdx, W: Wire> Iterator
for CircuitLayerIter<'a, P, G, Idx, W>
{
type Item = CircuitLayer<P, G, Idx>;
/// Returns the next combined layer. Returns `None` once no tracked
/// sub-circuit iterator produces a layer.
fn next(&mut self) -> Option<Self::Item> {
trace!("layer_iters: {:#?}", &self.layer_iters);
// Advance each tracked sub-circuit by one layer, recording that
// sub-circuit's SIMD size next to the produced layer.
let mut sc_layers: Vec<_> = self
.layer_iters
.iter_mut()
.filter_map(|(&sc_id, iter)| {
iter.next().map(|layer| {
let simd_size = self.circuit.get_circ(sc_id).simd_size();
(sc_id, simd_size, layer)
})
})
.collect();
// Drop iterators that report themselves exhausted. This runs BEFORE the
// cross-circuit scheduling below. A sub-circuit removed here that receives
// gates from this layer therefore gets a fresh `new_uninit` iterator.
// NOTE(review): confirm no state of the removed iterator is needed then.
self.layer_iters.retain(|_sc_id, iter| !iter.is_exhausted());
// `layer_iters` is a HashMap, so the collection order above is unspecified.
// Sorting by sub-circuit id (tuple field 0; the closure argument is the whole
// tuple despite its name) makes the output order deterministic.
sc_layers.sort_unstable_by_key(|sc_id| sc_id.0);
// Propagate across sub-circuit boundaries. Each gate id in the produced
// layers may have outgoing cross-circuit connections. Each target gate is
// added to its sub-circuit's iterator via `add_to_next_layer`, and that
// iterator is created (uninitialized) if it is not tracked yet.
for (sc_id, _simd_size, layer) in &sc_layers {
for potential_out in layer.iter_ids() {
let from = SubCircuitGate::new(*sc_id, potential_out);
let outgoing = self.circuit.connections.outgoing_gates(from);
for sc_gate in outgoing {
let to_layer_iter =
self.layer_iters
.entry(sc_gate.circuit_id)
.or_insert_with(|| {
base_circuit::BaseLayerIter::new_uninit(
self.circuit.get_circ(sc_gate.circuit_id),
)
});
to_layer_iter.add_to_next_layer(sc_gate.gate_id.into());
}
}
}
// No sub-circuit produced a layer in this step: iteration is over.
if sc_layers.is_empty() {
None
} else {
Some(CircuitLayer {
sc_layers,
_plain: PhantomData,
})
}
}
}
impl<P, G: Clone, Idx: GateIdx + Hash + PartialEq + Eq + Copy> CircuitLayer<P, G, Idx> {
    /// Number of interactive gates in this layer. Each sub-circuit layer's
    /// count is multiplied by its SIMD size (`1` when it has none).
    pub(crate) fn interactive_count_times_simd(&self) -> usize {
        let mut total = 0;
        for (_, simd_size, layer) in &self.sc_layers {
            let lanes = match simd_size {
                Some(size) => size.get(),
                None => 1,
            };
            total += lanes * layer.interactive_len();
        }
        total
    }

    /// Splits the layer into `(non_simd, simd)`. The first part holds the
    /// sub-circuit layers without a SIMD size; the second holds those with one.
    /// Relative order is kept in both parts.
    pub(crate) fn split_simd(self) -> (Self, Self) {
        let (non_simd, simd): (Vec<_>, Vec<_>) = self
            .sc_layers
            .into_iter()
            .partition(|(_, simd_size, _)| simd_size.is_none());
        (
            Self {
                sc_layers: non_simd,
                _plain: PhantomData,
            },
            Self {
                sc_layers: simd,
                _plain: PhantomData,
            },
        )
    }

    /// Iterates the non-interactive gates of every contained sub-circuit layer
    /// in `sc_layers` order. Each gate is paired with its qualified id.
    pub(crate) fn non_interactive_iter(
        &self,
    ) -> impl Iterator<Item = (G, SubCircuitGate<Idx>)> + Clone + '_ {
        self.sc_layers.iter().flat_map(|(circ_id, _, layer)| {
            let circ_id = *circ_id;
            layer
                .non_interactive_iter()
                .map(move |(gate, idx)| (gate, SubCircuitGate::new(circ_id, idx)))
        })
    }

    /// Iterates the interactive gates of every contained sub-circuit layer in
    /// `sc_layers` order. Each gate is paired with its qualified id.
    pub(crate) fn interactive_iter(
        &self,
    ) -> impl Iterator<Item = (G, SubCircuitGate<Idx>)> + Clone + '_ {
        self.sc_layers.iter().flat_map(|(circ_id, _, layer)| {
            let circ_id = *circ_id;
            layer
                .interactive_iter()
                .map(move |(gate, idx)| (gate, SubCircuitGate::new(circ_id, idx)))
        })
    }

    /// Yields the `freeable_gates` of every contained sub-circuit layer as
    /// qualified ids. No filtering by SIMD size happens here.
    // NOTE(review): presumably intended for the SIMD half of `split_simd`.
    pub(crate) fn freeable_simd_gates(&self) -> impl Iterator<Item = SubCircuitGate<Idx>> + '_ {
        self.sc_layers.iter().flat_map(|(circ_id, _, layer)| {
            let circ_id = *circ_id;
            layer
                .freeable_gates
                .iter()
                .map(move |idx| SubCircuitGate::new(circ_id, *idx))
        })
    }
}
impl<P, G, Idx: GateIdx, W> Default for Circuit<P, G, Idx, W> {
fn default() -> Self {
Self {
circuits: vec![],
circ_map: Default::default(),
connections: Default::default(),
}
}
}
impl<P, G, Idx: GateIdx + Default, W> From<BaseCircuit<P, G, Idx, W>> for Circuit<P, G, Idx, W> {
fn from(bc: BaseCircuit<P, G, Idx, W>) -> Self {
Self {
circuits: vec![bc],
circ_map: [(0, 0)].into_iter().collect(),
connections: Default::default(),
}
}
}
impl<P, G, Idx: GateIdx, W> TryFrom<SharedCircuit<P, G, Idx, W>> for Circuit<P, G, Idx, W> {
    type Error = SharedCircuit<P, G, Idx, W>;

    /// Takes the value out of the shared handle and converts it into a
    /// [`Circuit`].
    ///
    /// # Errors
    /// If other strong references to the shared value still exist, the
    /// unchanged `circuit` handle is returned as the error.
    fn try_from(circuit: SharedCircuit<P, G, Idx, W>) -> Result<Self, Self::Error> {
        match Arc::try_unwrap(circuit) {
            Ok(lock) => Ok(lock.into_inner().into()),
            Err(still_shared) => Err(still_shared),
        }
    }
}