use crate::util::{Counter, TrackingReadWrite};
use crate::{BaseReceiver, BaseSender, TrackingChannel};
use remoc::{ConnectError, RemoteSend};
use rustls::pki_types::{CertificateDer, InvalidDnsNameError, PrivateKeyDer, ServerName};
use rustls::version::TLS13;
use rustls_native_certs::load_native_certs;
use rustls_pemfile::{certs, private_key};
use std::fmt::Debug;
use std::fs::File;
use std::io;
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
use tokio::io::{split, AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_rustls::{TlsAcceptor, TlsConnector};
use tracing::info;
/// Errors that can occur while setting up a TLS-secured channel in this module.
///
/// Most variants are produced through `#[from]` conversions when a fallible
/// call is propagated with `?`.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// An underlying [`io::Error`].
    ///
    /// Every `io::Error` propagated with `?` in this module converts into this
    /// variant. That includes failures to open the key file, not only
    /// TLS-handshake I/O, despite the wording of the display message.
    #[error("Encountered io error when establishing TLS connection")]
    Io(#[from] io::Error),
    /// An error reported by `rustls`, converted from [`rustls::Error`].
    #[error("TLS error")]
    TlsError(#[from] rustls::Error),
    /// The domain string could not be converted into a `ServerName`.
    #[error("Invalid DNS name")]
    InvalidDnsNameError(#[from] InvalidDnsNameError),
    /// The private key file was read but no private key was found in it.
    #[error("Missing private key file")]
    MissingKey,
    /// Establishing the remoc connection failed.
    #[error("Error in establishing remoc connection")]
    RemocConnect(#[from] ConnectError<io::Error, io::Error>),
}
/// Reads every PEM-encoded certificate found in the file at `path`.
///
/// # Errors
///
/// Returns the [`io::Error`] raised when opening the file, or the first error
/// yielded by the PEM parser while iterating over the certificates.
fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, io::Error> {
    let pem_file = File::open(path)?;
    let mut reader = BufReader::new(pem_file);
    // Collecting into `Result<Vec<_>, _>` stops at the first parse error.
    certs(&mut reader).collect()
}
fn load_key(path: &Path) -> Result<PrivateKeyDer<'static>, Error> {
    private_key(&mut BufReader::new(File::open(path)?))?.ok_or(Error::MissingKey)
}
/// Binds to `addr`, accepts one TCP connection and upgrades it to a
/// TLS-secured remoc channel, acting as the TLS server.
///
/// The accepted TCP stream is wrapped in a byte-tracking layer before the TLS
/// handshake, so the returned counters are measured below the TLS layer.
///
/// On success the result holds, in order: the sender, the write counter, the
/// receiver and the read counter.
///
/// # Errors
///
/// Fails if binding or accepting fails, if the TCP stream cannot be
/// configured, or if [`tls_accept`] returns an error.
#[tracing::instrument(err)]
pub async fn listen<T: RemoteSend>(
    addr: impl ToSocketAddrs + Debug,
    private_key_file: impl AsRef<Path> + Debug,
    certificate_chain_file: impl AsRef<Path> + Debug,
) -> Result<TrackingChannel<T>, Error> {
    info!("Listening for connections");
    let listener = TcpListener::bind(addr).await?;
    let (tcp, remote_addr) = listener.accept().await?;
    info!(?remote_addr, "Accepted TCP connection to remote");

    let (tracked, bytes_out, bytes_in) = tracking_stream(tcp)?;
    let (tx, rx) = tls_accept(tracked, private_key_file, certificate_chain_file).await?;
    Ok((tx, bytes_out, rx, bytes_in))
}
/// Opens a TCP connection to `remote_addr` and upgrades it to a TLS-secured
/// remoc channel, acting as the TLS client.
///
/// `domain` is the server name handed to the TLS layer. The TCP stream is
/// wrapped in a byte-tracking layer before the TLS handshake, so the returned
/// counters are measured below the TLS layer.
///
/// On success the result holds, in order: the sender, the write counter, the
/// receiver and the read counter.
///
/// # Errors
///
/// Fails if the TCP connection cannot be established or configured, or if
/// [`tls_connect`] returns an error.
#[tracing::instrument(err)]
pub async fn connect<T: RemoteSend>(
    domain: &str,
    remote_addr: impl ToSocketAddrs + Debug,
) -> Result<TrackingChannel<T>, Error> {
    info!("Connecting to remote");
    let tcp = TcpStream::connect(remote_addr).await?;
    info!("Established TCP connection to server");

    let (tracked, bytes_out, bytes_in) = tracking_stream(tcp)?;
    let (tx, rx) = tls_connect(domain, tracked).await?;
    Ok((tx, bytes_out, rx, bytes_in))
}
/// Wraps `tcp_stream` in a [`TrackingReadWrite`] and hands back the wrapped
/// stream together with its byte counters.
///
/// `TCP_NODELAY` is enabled on the socket before it is split into its owned
/// read and write halves.
///
/// Returns `(stream, write_counter, read_counter)`, where the counters are the
/// values obtained from `bytes_written()` and `bytes_read()` respectively.
///
/// # Errors
///
/// Returns [`Error::Io`] if `TCP_NODELAY` cannot be set.
fn tracking_stream(
    tcp_stream: TcpStream,
) -> Result<
    (
        impl AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
        Counter,
        Counter,
    ),
    Error,
> {
    tcp_stream.set_nodelay(true)?;

    let (read_half, write_half) = tcp_stream.into_split();
    let tracked = TrackingReadWrite::new(read_half, write_half);

    // Grab the counter handles before the wrapper is moved into the result.
    let bytes_out = tracked.bytes_written();
    let bytes_in = tracked.bytes_read();
    Ok((tracked, bytes_out, bytes_in))
}
/// Performs the server side of the TLS handshake over `tcp_stream` and then
/// establishes a remoc connection on top of the encrypted stream.
///
/// The certificate chain is loaded first, then the private key. The server
/// configuration accepts TLS 1.3 only and does not request client
/// authentication.
///
/// # Errors
///
/// Fails if the certificate chain or key cannot be loaded, if `rustls` rejects
/// the certificate/key pair, if the TLS handshake fails, or if the remoc
/// connection cannot be established.
async fn tls_accept<T, IO>(
    tcp_stream: IO,
    private_key_file: impl AsRef<Path> + Debug,
    certificate_chain_file: impl AsRef<Path> + Debug,
) -> Result<(BaseSender<T>, BaseReceiver<T>), Error>
where
    T: RemoteSend,
    IO: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
{
    let cert_chain = load_certs(certificate_chain_file.as_ref())?;
    let signing_key = load_key(private_key_file.as_ref())?;

    let server_config = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13])
        .with_no_client_auth()
        .with_single_cert(cert_chain, signing_key)?;

    let tls_stream = TlsAcceptor::from(Arc::new(server_config))
        .accept(tcp_stream)
        .await?;
    info!("Established TLS connection to remote");

    let (read_half, write_half) = split(tls_stream);
    // The second and fourth elements of the result are not needed here.
    let (sender, _, receiver, _) =
        super::establish_remoc_connection(read_half, write_half).await?;
    Ok((sender, receiver))
}
async fn tls_connect<T, IO>(
    domain: &str,
    tcp_stream: IO,
) -> Result<(BaseSender<T>, BaseReceiver<T>), Error>
where
    T: RemoteSend,
    IO: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
{
    let domain = ServerName::try_from(domain.to_string())?;
    let mut root_cert_store = rustls::RootCertStore::empty();
    let (added, ignored) = root_cert_store.add_parsable_certificates(load_native_certs()?);
    info!("Added {added} certificates to store. Ignored {ignored}");
    let config = rustls::ClientConfig::builder()
        .with_root_certificates(root_cert_store)
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(config));
    let tls_stream = connector.connect(domain, tcp_stream).await?;
    info!("Established TLS connection to server");
    let (tls_reader, tls_writer) = split(tls_stream);
    let (sender, _, receiver, _) =
        super::establish_remoc_connection(tls_reader, tls_writer).await?;
    Ok((sender, receiver))
}