1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
//! This example shows how to load a Bristol circuit (sha256) and execute it, optionally generating
//! the multiplication triples via a trusted third party (see the `trusted_party_mts.rs` example).
//!
//! It also demonstrates how to use the [`Statistics`] API to track the communication of
//! different phases and write it to a file.

use std::fs::File;
use std::io::{stdout, BufWriter, Write};
use std::net::SocketAddr;
use std::path::PathBuf;

use anyhow::{Context, Result};
use clap::Parser;
use tracing_subscriber::EnvFilter;

use seec::circuit::base_circuit::Load;
use seec::circuit::{BaseCircuit, ExecutableCircuit};
use seec::common::BitVec;
use seec::executor::{Executor, Input, Message};
use seec::mul_triple::boolean::insecure_provider::InsecureMTProvider;
use seec::mul_triple::boolean::trusted_seed_provider::TrustedMTProviderClient;
use seec::protocols::boolean_gmw::BooleanGmw;
use seec::BooleanGate;
use seec_channel::sub_channels_for;
use seec_channel::util::{Phase, Statistics};

// Command-line options of this example, parsed with clap's derive API (`Args::parse()` in
// `main`). The `///` doc comments on the fields are what clap prints as `--help` text, so they
// are runtime-visible; the `//` comments below are notes for maintainers only.
#[derive(Parser, Debug)]
struct Args {
    /// Id of this party
    // `main` accepts only 0 (listens on `server`) or 1 (connects to `server`); any other value
    // makes it return an error.
    #[clap(long)]
    id: usize,

    /// Address of server to bind or connect to
    // Party 0 listens on this address, party 1 connects to it.
    #[clap(long)]
    server: SocketAddr,

    /// Optional address of trusted server providing MTs
    // When set, multiplication triples come from a `TrustedMTProviderClient` connected to this
    // address; when absent, `main` falls back to the `InsecureMTProvider`.
    #[clap(long)]
    mt_provider: Option<SocketAddr>,

    /// Sha256 as a bristol circuit
    // Path to the Bristol-format circuit file loaded via `BaseCircuit::load_bristol`. The default
    // is a relative path, so it presumably has to be run from the repository root — confirm.
    #[clap(
        long,
        default_value = "test_resources/bristol-circuits/sha-256-low_depth.txt"
    )]
    circuit: PathBuf,
    /// File path for the communication statistics. Will overwrite existing files.
    // When absent, `main` prints the JSON statistics to stdout instead.
    #[clap(long)]
    stats: Option<PathBuf>,
}

#[tokio::main]
async fn main() -> Result<()> {
    let _guard = init_tracing()?;
    let args = Args::parse();
    let circuit: ExecutableCircuit<bool, BooleanGate, u32> = ExecutableCircuit::DynLayers(
        BaseCircuit::load_bristol(args.circuit, Load::Circuit)?.into(),
    );

    let (mut sender, bytes_written, mut receiver, bytes_read) = match args.id {
        0 => seec_channel::tcp::listen(args.server).await?,
        1 => seec_channel::tcp::connect(args.server).await?,
        illegal => anyhow::bail!("Illegal party id {illegal}. Must be 0 or 1."),
    };

    // Initialize the communication statistics tracker with the counters for the main channel
    let mut comm_stats = Statistics::new(bytes_written, bytes_read).without_unaccounted(true);

    let (mut sender, mut receiver) =
        sub_channels_for!(&mut sender, &mut receiver, 8, Message<BooleanGmw>).await?;

    let mut executor: Executor<BooleanGmw, _> = if let Some(addr) = args.mt_provider {
        let (mt_sender, bytes_written, mt_receiver, bytes_read) =
            seec_channel::tcp::connect(addr).await?;
        // Set the counters for the helper channel
        comm_stats.set_helper(bytes_written, bytes_read);
        let mt_provider = TrustedMTProviderClient::new("unique-id".into(), mt_sender, mt_receiver);
        // As the MTs are generated when the Executor is created, we record the communication
        // with the `record_helper` method and a custom category
        comm_stats
            .record_helper(
                Phase::Custom("Helper-Mts".into()),
                Executor::new(&circuit, args.id, mt_provider),
            )
            .await?
    } else {
        let mt_provider = InsecureMTProvider::default();
        comm_stats
            .record(
                Phase::FunctionDependentSetup,
                Executor::new(&circuit, args.id, mt_provider),
            )
            .await?
    };
    let input = BitVec::repeat(false, 768);
    let _out = comm_stats
        .record(
            Phase::Online,
            executor.execute(Input::Scalar(input), &mut sender, &mut receiver),
        )
        .await?;

    // Depending on whether a --stats file is set, create a file writer or stdout
    let mut writer: Box<dyn Write> = match args.stats {
        Some(path) => {
            let file = File::create(path)?;
            Box::new(file)
        }
        None => Box::new(stdout()),
    };
    // serde_json is used to write the statistics in json format. `.csv` is currently not
    // supported.
    let mut res = comm_stats.into_run_result();
    res.add_metadata("circuit", "sha256.rs");
    serde_json::to_writer_pretty(&mut writer, &res)?;
    writeln!(writer)?;

    Ok(())
}

/// Installs a global JSON-formatted tracing subscriber that logs to `sha256.log`.
///
/// The log file is created in the current working directory, truncating any existing file. It
/// is wrapped in a [`BufWriter`] and handed to a [`tracing_appender::non_blocking`] writer. The
/// event filter is taken from the `RUST_LOG` environment variable via
/// [`EnvFilter::from_default_env`].
///
/// The returned guard must be kept alive for as long as logging is needed; dropping it flushes
/// the log lines still buffered by the non-blocking writer.
///
/// # Errors
/// Returns an error if `sha256.log` cannot be created.
///
/// # Panics
/// Panics (via `.init()`) if a global tracing subscriber has already been installed.
pub fn init_tracing() -> Result<tracing_appender::non_blocking::WorkerGuard> {
    let log_file = File::create("sha256.log")?;
    let buffered_file = BufWriter::new(log_file);
    let (log_sink, worker_guard) = tracing_appender::non_blocking(buffered_file);

    let subscriber_builder = tracing_subscriber::fmt()
        .json()
        .with_env_filter(EnvFilter::from_default_env())
        .with_writer(log_sink);
    subscriber_builder.init();

    Ok(worker_guard)
}