// SPDX-License-Identifier: AGPL-3.0-or-later
#![feature(impl_trait_in_assoc_type)]

use btmsg::*;

use btlib::{
    crypto::{ConcreteCreds, Creds, CredsPriv},
    BlockPath, Epoch, Principal, Principaled,
};
use core::future::{ready, Future, Ready};
use ctor::ctor;
use futures::join;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::{
    io::Write,
    net::{IpAddr, Ipv6Addr},
    sync::{Arc, Mutex as SyncMutex},
    time::Duration,
};
use tokio::sync::mpsc::{self, Sender};

/// Installs a global `env_logger` before any test runs (via `#[ctor]`).
///
/// The filter defaults to `ERROR` unless overridden by the environment. Each line is
/// rendered as `[<rfc3339 timestamp> <level> <file>:<line>] <message>`.
#[ctor]
fn setup_logging() {
    use env_logger::Env;

    let env = Env::default().default_filter_or("ERROR");
    env_logger::builder()
        .format(|fmt, record| {
            let timestamp = chrono::Utc::now().to_rfc3339();
            // Records without location info fall back to placeholder values.
            let file = record.file().unwrap_or("(unknown)");
            let line = record.line().unwrap_or(u32::MAX);
            writeln!(
                fmt,
                "[{} {} {}:{}] {}",
                timestamp,
                record.level(),
                file,
                line,
                record.args(),
            )
        })
        .parse_env(env)
        .init();
}

lazy_static! {
    /// Freshly generated root credentials shared by every test in this file.
    static ref ROOT_CREDS: ConcreteCreds = ConcreteCreds::generate().unwrap();
    /// Node credentials whose writecap is issued by `ROOT_CREDS`, valid for one hour.
    static ref NODE_CREDS: ConcreteCreds = {
        let mut node = ConcreteCreds::generate().unwrap();
        let node_writecap = ROOT_CREDS
            .issue_writecap(
                node.principal(),
                &mut std::iter::empty(),
                Epoch::now() + Duration::from_secs(3600),
            )
            .unwrap();
        node.set_writecap(node_writecap).unwrap();
        node
    };
    /// Principal of `ROOT_CREDS`.
    static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
}

/// Replies returned for [`Msg`] calls (see `impl CallMsg for Msg`).
///
/// `file_server` answers `Write` with `Success`, `Read` with `ReadReply`, and anything
/// else with `Fail`.
///
/// NOTE(review): if btmsg's serializer encodes enum variants by index, reordering these
/// variants changes the wire format — confirm before reordering.
#[derive(Debug, Serialize, Deserialize)]
enum Reply {
    Success,
    Fail,
    // Bytes read from `offset`; `offset` echoes the request's offset.
    ReadReply { offset: u64, buf: Vec<u8> },
}

/// Messages sent to the test receivers.
///
/// `Write` borrows its payload, which is why `Msg` carries a lifetime.
///
/// NOTE(review): as with `Reply`, variant order may be part of the serialized form.
#[derive(Serialize, Deserialize)]
enum Msg<'a> {
    Ping,
    Success,
    Fail,
    Read { offset: u64, size: u64 },
    Write { offset: u64, buf: &'a [u8] },
}

// Calls made with `Msg` (e.g. `call_through::<Msg>`) expect a `Reply` back.
impl<'a> CallMsg<'a> for Msg<'a> {
    type Reply<'b> = Reply;
}

// Allows `Msg` to be sent one-way with `send` (used by the `Ping` tests).
impl<'a> SendMsg<'a> for Msg<'a> {}

/// Alias trait for closures usable as a [`Delegate`] callback.
///
/// The closure receives a message and a clone of the delegate's result `Sender<S>`, and
/// returns a `Send` future. A blanket impl covers every closure with that shape, so
/// `dyn TestFunc<S, Fut>` can be stored without spelling out the `Fn` bound.
trait TestFunc<S, Fut>: Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut
where
    S: 'static + Send,
    Fut: Send + Future,
{
}

impl<S, Fut, T> TestFunc<S, Fut> for T
where
    S: 'static + Send,
    Fut: Send + Future,
    T: Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut,
{
}

/// Wraps a [`TestFunc`] closure together with the channel it reports results on.
///
/// Cloning is cheap: the closure is shared behind an `Arc` and the sender is cloned.
struct Delegate<S, Fut> {
    func: Arc<dyn TestFunc<S, Fut>>,
    sender: Sender<S>,
}

impl<S, Fut> Clone for Delegate<S, Fut> {
    fn clone(&self) -> Self {
        Delegate {
            func: Arc::clone(&self.func),
            sender: self.sender.clone(),
        }
    }
}

impl<S: 'static + Send, Fut: Send + Future> Delegate<S, Fut> {
    /// Builds a delegate that calls `func` for every message, handing it `sender`.
    fn new<F: 'static + TestFunc<S, Fut>>(sender: Sender<S>, func: F) -> Self {
        Delegate {
            sender,
            func: Arc::new(func),
        }
    }
}

/// Lets a [`Delegate`] serve as a receiver callback for [`Msg`] messages.
///
/// Each call forwards the received message and a fresh clone of the delegate's
/// `sender` to the wrapped closure, and returns the closure's future unchanged.
impl<S: 'static + Send, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
    for Delegate<S, Fut>
{
    type Arg<'de> = Msg<'de> where Self: 'de;
    type CallFut<'s> = Fut where Fut: 's;
    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
        // The sender is cloned per call so the closure can move it into its future.
        (self.func)(arg, self.sender.clone())
    }
}

/// Generates fresh process credentials whose writecap is issued by `NODE_CREDS` and
/// expires one hour from now.
fn proc_creds() -> impl Creds {
    let mut proc = ConcreteCreds::generate().unwrap();
    let principal = proc.principal();
    let writecap = NODE_CREDS
        .issue_writecap(
            principal,
            &mut std::iter::empty(),
            Epoch::now() + Duration::from_secs(3600),
        )
        .unwrap();
    proc.set_writecap(writecap).unwrap();
    proc
}

/// Creates a `Receiver` on IPv6 localhost that dispatches to `callback`, using fresh
/// [`proc_creds`].
///
/// Returns the receiver and the block address (localhost plus the writecap's bind
/// path) at which it can be reached.
fn proc_rx<F: 'static + MsgCallback>(callback: F) -> (Receiver, Arc<BlockAddr>) {
    let ip_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
    let creds = proc_creds();
    let bind_path = Arc::new(creds.writecap().unwrap().bind_path());
    let addr = Arc::new(BlockAddr::new(ip_addr, bind_path));
    let receiver = Receiver::new(ip_addr, Arc::new(creds), callback).unwrap();
    (receiver, addr)
}

/// Starts a receiver backed by `func` and returns it together with a transmitter
/// obtained from that receiver and pointed at its own address.
async fn proc_tx_rx<F: 'static + MsgCallback>(func: F) -> (Transmitter, Receiver) {
    let (receiver, addr) = proc_rx(func);
    let transmitter = receiver.transmitter(addr).await.unwrap();
    (transmitter, receiver)
}

/// Converts a wire-format `offset` and a byte count into `(start, end)` indices.
///
/// Returns `None` when `offset` does not fit in `usize` or `offset + len` overflows, so
/// the caller can reject the request instead of panicking.
fn byte_span(offset: u64, len: usize) -> Option<(usize, usize)> {
    let start = usize::try_from(offset).ok()?;
    let end = start.checked_add(len)?;
    Some((start, end))
}

/// Starts a receiver that serves `Read`/`Write` requests against an in-memory 8-byte
/// "file" (initially `[0, 1, ..., 7]`). Returns a transmitter connected to it, plus the
/// receiver.
///
/// Replies:
/// - `Write` that fits inside the file: `Reply::Success`.
/// - `Read` that fits inside the file: `Reply::ReadReply` echoing `offset`.
/// - Any range outside the file, an offset/size that can't be represented as `usize`,
///   or any other message: `Reply::Fail`. The callback does not panic in these cases.
async fn file_server() -> (Transmitter, Receiver) {
    // `Delegate` requires a result sender, but this server never reports anything back
    // to the test, so the receiving half is dropped immediately.
    let (sender, _) = mpsc::channel::<()>(1);
    let file = Arc::new(SyncMutex::new([0, 1, 2, 3, 4, 5, 6, 7]));
    proc_tx_rx(Delegate::new(
        sender,
        move |mut received: MsgReceived<Msg<'_>>, _| {
            // Build the reply while holding the lock. The guard is dropped before the
            // returned future runs, so the std mutex is never held across an `.await`.
            let reply_body = {
                let mut guard = file.lock().unwrap();
                match received.body() {
                    Msg::Read { offset, size } => {
                        let span = usize::try_from(*size)
                            .ok()
                            .and_then(|len| byte_span(*offset, len));
                        match span {
                            Some((start, end)) if end <= guard.len() => Reply::ReadReply {
                                offset: *offset,
                                buf: guard[start..end].to_vec(),
                            },
                            _ => Reply::Fail,
                        }
                    }
                    Msg::Write { offset, buf } => match byte_span(*offset, buf.len()) {
                        Some((start, end)) if end <= guard.len() => {
                            guard[start..end].copy_from_slice(buf);
                            Reply::Success
                        }
                        _ => Reply::Fail,
                    },
                    _ => Reply::Fail,
                }
            };
            let mut replier = received.take_replier().unwrap();
            async move { replier.reply(reply_body).await }
        },
    ))
    .await
}

/// Awaits `future` and returns its output, panicking if it takes longer than 1000 ms.
async fn timeout<F: Future>(future: F) -> F::Output {
    let limit = Duration::from_millis(1000);
    let outcome = tokio::time::timeout(limit, future).await;
    outcome.unwrap()
}

/// Awaits the next value from channel receiver `$rx`, wrapped in [`timeout`].
///
/// Panics if the 1000 ms limit in `timeout` elapses, or (via `unwrap`) if `recv()`
/// yields `None`. For `tokio::sync::mpsc::Receiver`, `None` means the channel has closed.
macro_rules! recv {
    ($rx:expr) => {
        timeout($rx.recv()).await.unwrap()
    };
}

/// Sends a one-way `Msg::Ping` and checks that the callback receives a `Ping` body.
#[tokio::test]
async fn message_received_is_message_sent() {
    let (result_tx, mut passed) = mpsc::channel(1);
    let (sender, _receiver) = proc_tx_rx(Delegate::new(
        result_tx,
        |msg: MsgReceived<Msg<'_>>, sender: Sender<bool>| {
            let passed = matches!(msg.body(), Msg::Ping);
            // `sender` is already an owned clone handed out per call; move it directly.
            async move {
                sender.send(passed).await.unwrap();
                Ok(())
            }
        },
    ))
    .await;

    sender.send(Msg::Ping).await.unwrap();

    assert!(recv!(passed));
}

/// Checks that the `from` path reported to the callback equals the receiver's own
/// address path. The transmitter in this test was created via `receiver.transmitter`.
#[tokio::test]
async fn message_received_from_path_is_correct() {
    let (path_tx, mut path) = mpsc::channel(1);
    let (sender, receiver) = proc_tx_rx(Delegate::new(
        path_tx,
        |msg: MsgReceived<Msg<'_>>, sender: Sender<Arc<BlockPath>>| {
            // Clone the path out of `msg` so the future does not borrow the message.
            let path = msg.from().clone();
            async move {
                sender.send(path).await.unwrap();
                Ok(())
            }
        },
    ))
    .await;

    sender.send(Msg::Ping).await.unwrap();

    assert_eq!(receiver.addr().path().as_ref(), recv!(path).as_ref());
}

/// Reads two bytes at offset 2 from the file server and checks the echoed offset and
/// the returned bytes.
#[tokio::test]
async fn reply_to_read() {
    let (sender, _receiver) = file_server().await;
    let reply = sender
        .call_through::<Msg>(Msg::Read { offset: 2, size: 2 })
        .await
        .unwrap();
    match reply {
        Reply::ReadReply { offset, buf } => {
            assert_eq!(2, offset);
            assert_eq!([2, 3].as_slice(), buf.as_slice());
        }
        _ => panic!("reply was not the right type"),
    }
}

/// Writes two bytes and then reads them back through the same transmitter, checking
/// that the second call sees the first call's write.
#[tokio::test]
async fn call_twice() {
    let (sender, _receiver) = file_server().await;

    let reply = sender
        .call_through::<Msg>(Msg::Write {
            offset: 1,
            buf: &[1, 1],
        })
        .await
        .unwrap();
    assert!(
        matches!(reply, Reply::Success),
        "reply was not the right type: {reply:?}"
    );

    let reply = sender
        .call_through::<Msg>(Msg::Read { offset: 1, size: 2 })
        .await
        .unwrap();
    let Reply::ReadReply { offset, buf } = reply else {
        panic!("second reply was not the right type");
    };
    assert_eq!(1, offset);
    assert_eq!([1, 1].as_slice(), buf.as_slice());
}

/// Checks that a `Transmitter` built independently, with its own `proc_creds`, can call
/// the file server.
#[tokio::test]
async fn separate_transmitter() {
    let (_sender, receiver) = file_server().await;
    let creds = proc_creds();
    let transmitter = Transmitter::new(receiver.addr().clone(), Arc::new(creds))
        .await
        .unwrap();

    let reply = transmitter
        .call_through::<Msg>(Msg::Write {
            offset: 5,
            buf: &[7, 7, 7],
        })
        .await
        .unwrap();
    assert!(
        matches!(reply, Reply::Success),
        "expected Reply::Success, got {reply:?}"
    );
}

/// Request for `size` bytes starting at `offset`, answered by `read_server` with a
/// [`ReadReply`].
///
/// NOTE(review): field order is likely part of the serialized form (depends on btmsg's
/// serializer); confirm before reordering.
#[derive(Serialize, Deserialize)]
struct Read {
    offset: usize,
    size: usize,
}

/// Reply to [`Read`]. `buf` is a borrowed slice, so the reply carries a lifetime.
/// Presumably it borrows from the bytes it was deserialized from; TODO confirm against
/// btmsg's deserializer.
#[derive(Serialize, Deserialize)]
struct ReadReply<'a> {
    buf: &'a [u8],
}

// Calls made with `Read` expect a `ReadReply`, whose lifetime is chosen per call via
// the `Reply<'b>` associated type.
impl<'a> CallMsg<'a> for Read {
    type Reply<'b> = ReadReply<'b>;
}

/// Holds the bytes a read call is expected to return, so a reply can be checked
/// without copying the reply's buffer.
#[derive(Clone)]
struct ReadChecker<'a> {
    expected: &'a [u8],
}

impl<'a> ReadChecker<'a> {
    /// Creates a checker that expects exactly `expected`.
    fn new(expected: &'a [u8]) -> ReadChecker<'a> {
        ReadChecker { expected }
    }
}

/// Uses a [`ReadChecker`] as the reply handler for `Transmitter::call`.
///
/// The deserialized [`ReadReply`] is compared against `expected`. The returned future
/// is already complete (`ready`) and yields `true` when the bytes are equal.
impl<'a> DeserCallback for ReadChecker<'a> {
    type Arg<'de> = ReadReply<'de> where Self: 'de;
    type Return = bool;
    type CallFut<'s> = Ready<bool> where Self: 's;
    fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
        ready(self.expected == arg.buf)
    }
}

/// Alias trait for closures usable as an [`Action`] callback: they take a
/// `MsgReceived<Arg>` and return a `Send` future. A blanket impl covers every such
/// closure, so the bound can be used as a trait object.
trait ActionFn<Arg, Fut>: Send + Sync + Fn(MsgReceived<Arg>) -> Fut
where
    Fut: Send + Future,
{
}

impl<Arg, Fut, T> ActionFn<Arg, Fut> for T
where
    Fut: Send + Future,
    T: Send + Sync + Fn(MsgReceived<Arg>) -> Fut,
{
}

/// Wraps an [`ActionFn`] closure behind an `Arc` so it can be cloned cheaply.
struct Action<Arg, Fut> {
    func: Arc<dyn ActionFn<Arg, Fut>>,
}

impl<Arg, Fut> Clone for Action<Arg, Fut> {
    fn clone(&self) -> Self {
        Action {
            func: Arc::clone(&self.func),
        }
    }
}

impl<Arg, Fut: Send + Future> Action<Arg, Fut> {
    /// Builds an action that calls `func` for every received message.
    fn new<F: 'static + ActionFn<Arg, Fut>>(func: F) -> Self {
        Action {
            func: Arc::new(func),
        }
    }
}

/// Lets an [`Action`] serve as a receiver callback for messages of type `Arg`.
///
/// `Arg` must implement `CallMsg` for every lifetime. Each call passes the received
/// message straight to the wrapped closure and returns its future unchanged.
impl<Arg: for<'a> CallMsg<'a>, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
    for Action<Arg, Fut>
{
    type Arg<'de> = Arg where Arg: 'de, Fut: 'de;
    type CallFut<'de> = Fut where Arg: 'de, Fut: 'de;
    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
        (self.func)(arg)
    }
}

/// Starts a receiver that answers each [`Read`] with a [`ReadReply`] borrowing from a
/// fixed 8-byte buffer (`[0, 1, ..., 7]`). Returns a transmitter connected to it, plus
/// the receiver.
///
/// The callback panics (slice indexing) if `offset + size` exceeds the buffer length.
async fn read_server() -> (Transmitter, Receiver) {
    let file = [0, 1, 2, 3, 4, 5, 6, 7];
    let action = Action::new(move |mut msg: MsgReceived<Read>| async move {
        // Copy the requested bounds out first so the borrow of `msg` ends before
        // `take_replier` needs it mutably.
        let range = {
            let request = msg.body();
            request.offset..request.offset + request.size
        };
        let buf = &file[range];
        let mut replier = msg.take_replier().unwrap();
        replier.reply(ReadReply { buf }).await
    });
    proc_tx_rx(action).await
}

/// Makes two sequential `call`s whose replies borrow data, checking each with a
/// [`ReadChecker`].
#[tokio::test]
async fn call_with_lifetime() {
    let (sender, _receiver) = read_server().await;

    let first_request = Read { offset: 2, size: 3 };
    let correct_one = sender
        .call(first_request, ReadChecker::new(&[2, 3, 4]))
        .await
        .unwrap();

    let second_request = Read { offset: 0, size: 2 };
    let correct_two = sender
        .call(second_request, ReadChecker::new(&[0, 1]))
        .await
        .unwrap();

    assert!(correct_one);
    assert!(correct_two);
}

#[tokio::test]
async fn call_concurrently() {
    let (sender, _receiver) = read_server().await;

    let call_one = sender.call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]));
    let call_two = sender.call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]));
    let (result_one, result_two) = join!(call_one, call_two);

    assert!(result_one.unwrap());
    assert!(result_two.unwrap());
}