diff options
author | Aria <me@aria.rip> | 2023-10-13 01:33:38 +0100 |
---|---|---|
committer | Aria <me@aria.rip> | 2023-10-13 01:33:38 +0100 |
commit | 1b6c1b425f78f4ec3eb275f21a792776e50cbf93 (patch) | |
tree | 9adb3c9fc11ee379078b60243f1705e991f7bf5d /common/src | |
parent | 186087b2010f7f2b9631a28b80527d99b751b882 (diff) |
start using async
Diffstat (limited to 'common/src')
-rw-r--r-- | common/src/lib.rs | 154 | ||||
-rw-r--r-- | common/src/msg.rs | 50 | ||||
-rw-r--r-- | common/src/msg_id.rs | 16 |
3 files changed, 141 insertions, 79 deletions
diff --git a/common/src/lib.rs b/common/src/lib.rs index 616dfcc..5317b3e 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -1,94 +1,122 @@ -use std::io::{Read, Write}; +#![feature(return_position_impl_trait_in_trait)] +use std::{ + future::Future, + io::{self, Read, Write}, + thread, +}; -use msg::{MaelstromBody, MaelstromBodyOr, Message, MessageHeader}; +use msg::{MaelstromBody, Message, MessageHeader, Output}; +use msg_id::gen_msg_id; use serde::{Deserialize, Serialize}; use serde_json::Deserializer; +use smol::{ + channel::{self, Receiver, Sender}, + future, + stream::StreamExt, + Executor, +}; pub mod msg; +pub mod msg_id; pub trait Handler { - type Body: Serialize + for<'a> Deserialize<'a>; + type Body: Serialize + for<'a> Deserialize<'a> + Send + Clone; - fn init(node_id: String, node_ids: Vec<String>, msg_id: usize) -> Self; - fn handle( - &mut self, - header: MessageHeader, - body: Self::Body, - writer: &mut MsgWriter<impl Write>, - ) -> (); + fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Self; + fn handle(&self, header: MessageHeader, body: Self::Body) -> impl Future<Output = ()> + Send; } -pub struct MsgWriter<W> { - node_id: String, - writer: W, -} +pub fn run_server<H: Handler>() { + // Perform sync initialisation of the handler + // This is a special case so that we can use a different message body type just for init messages + let (handler, out_recv) = sync_init_handler::<H, _, _>(io::stdin(), io::stdout()); -impl<W: Write> MsgWriter<W> { - pub fn new(node_id: String, writer: W) -> Self { - Self { node_id, writer } - } + let (inp_send, mut inp_recv) = channel::unbounded::<Message<H::Body>>(); - pub fn write<T: Serialize>(&mut self, dst: String, msg: &T) { - let msg = Message { - header: MessageHeader { - src: self.node_id.clone(), - dst, - }, - body: MaelstromBodyOr::Other { inner: msg }, - }; - serde_json::to_writer(&mut self.writer, &msg).unwrap(); - self.writer.write(&[b'\n']).unwrap(); - } -} - -pub fn run_with<T: Handler>(mut reader: impl Read, mut writer: impl Write) { - let (mut handler, mut msg_writer) = init_handler::<T, _, _>(&mut reader, &mut writer); + thread::scope(|s| { + // Worker threads for receiving and sending + // This is easier than making it async, and good enough for our usecase. + s.spawn(|| recv_loop(io::stdin(), inp_send)); + s.spawn(|| send_loop(io::stdout(), out_recv)); - let deser = Deserializer::from_reader(reader); - for msg in deser.into_iter::<Message<T::Body>>() { - let msg = msg.unwrap(); - match msg.body { - MaelstromBodyOr::Other { inner } => { - handler.handle(msg.header, inner, &mut msg_writer); + // As we receive messages, spawn a future for each + let executor = Executor::new(); + future::block_on(executor.run(async { + while let Some(msg) = inp_recv.next().await { + executor + .spawn(handler.handle(msg.header, msg.body)) + .detach(); } - _ => todo!(), - }; - } + })); + }); } -pub fn init_handler<T: Handler, R: Read, W: Write>(reader: R, writer: W) -> (T, MsgWriter<W>) { +/// Initialises the handler synchronously. +/// +/// This is done as a seperate step because we initially deserialize into a different type +/// than our handler will accept, so there's no point spawning and immediately finishing threads. +fn sync_init_handler<H: Handler, R: Read, W: Write>( + reader: R, + mut writer: W, +) -> (H, Receiver<Message<H::Body>>) { + // Receive the init message let deser = Deserializer::from_reader(reader); - let mut deser = deser.into_iter::<Message<()>>(); - let Some(msg) = deser.next() else { - panic!("stream ended before init message"); - }; - let Ok(msg) = msg else { - panic!("{}", msg.unwrap_err()); - }; - - let (node_id, node_ids, msg_id) = match msg.body { - MaelstromBodyOr::MaelstromBody { - inner: + let mut deser = deser.into_iter::<Message<MaelstromBody>>(); + let (init_header, node_id, node_ids, init_msg_id) = match deser.next() { + Some(Ok(Message { + header, + body: MaelstromBody::Init { node_id, node_ids, msg_id, }, - } => (node_id, node_ids, msg_id), + })) => (header, node_id, node_ids, msg_id), + Some(Err(e)) => panic!("invalid init message: {}", e), _ => { panic!("expected init message to be first message"); } }; - let mut writer = MsgWriter::new(node_id.clone(), writer); - - writer.write( - msg.header.src, - &MaelstromBody::InitOk { - msg_id: 0, - in_reply_to: msg_id, + // Write the init_ok message + write_newline( + &mut writer, + &Message { + header: init_header.flip(), + body: MaelstromBody::InitOk { + in_reply_to: init_msg_id, + msg_id: gen_msg_id(), + }, }, ); - (T::init(node_id, node_ids, msg_id), writer) + // Create handler, and channel to go with it + let (send, recv) = channel::unbounded(); + + ( + H::init(node_id.clone(), node_ids, Output::new(node_id, send)), + recv, + ) +} + +/// Receives JSON from a reader, and outputs the deserialised result to a channel +fn recv_loop<M: for<'a> Deserialize<'a>>(reader: impl Read, channel: Sender<M>) { + let deser = Deserializer::from_reader(reader); + for msg in deser.into_iter() { + let msg = msg.unwrap(); + channel.send_blocking(msg).unwrap(); + } +} + +/// Receives things to send, and outputs them as JSON to writer +fn send_loop<M: Serialize>(mut writer: impl Write, channel: Receiver<M>) { + while let Ok(msg) = channel.recv_blocking() { + write_newline(&mut writer, msg); + } +} + +/// Write a message to writer, followed by a newline +fn write_newline(mut writer: impl Write, msg: impl Serialize) { + serde_json::to_writer(&mut writer, &msg).unwrap(); + writer.write(&[b'\n']).unwrap(); } diff --git a/common/src/msg.rs b/common/src/msg.rs index 23db171..7e9863f 100644 --- a/common/src/msg.rs +++ b/common/src/msg.rs @@ -1,10 +1,13 @@ use serde::{Deserialize, Serialize}; +use smol::channel::Sender; + +use crate::msg_id::MessageID; #[derive(Debug, Serialize, Deserialize)] pub struct Message<B> { #[serde(flatten)] pub header: MessageHeader, - pub body: MaelstromBodyOr<B>, + pub body: B, } #[derive(Debug, Serialize, Deserialize)] @@ -24,27 +27,42 @@ impl MessageHeader { } #[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] -pub enum MaelstromBodyOr<B> { - MaelstromBody { - #[serde(flatten)] - inner: MaelstromBody, - }, - Other { - #[serde(flatten)] - inner: B, - }, -} - -#[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] pub enum MaelstromBody { #[serde(rename = "init")] Init { node_id: String, node_ids: Vec<String>, - msg_id: usize, + msg_id: MessageID, }, #[serde(rename = "init_ok")] - InitOk { msg_id: usize, in_reply_to: usize }, + InitOk { + msg_id: MessageID, + in_reply_to: MessageID, + }, +} + +pub struct Output<B> { + node_id: String, + channel: Sender<Message<B>>, +} + +impl<B: Serialize + Clone> Output<B> { + pub fn new(node_id: String, channel: Sender<Message<B>>) -> Self { + Self { node_id, channel } + } + + pub async fn send(&self, dst: &str, body: &B) { + self.send_raw(Message { + header: MessageHeader { + src: self.node_id.clone(), + dst: dst.to_string(), + }, + body: body.clone(), + }) + .await; + } + pub async fn send_raw(&self, msg: Message<B>) { + self.channel.send(msg).await.unwrap(); + } } diff --git a/common/src/msg_id.rs b/common/src/msg_id.rs new file mode 100644 index 0000000..e953f08 --- /dev/null +++ b/common/src/msg_id.rs @@ -0,0 +1,16 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +use rand::{thread_rng, Rng}; + +pub type MessageID = u64; + +pub fn gen_msg_id() -> MessageID { + // Time since UNIX epoch in milliseconds, (48 bits) + let now = SystemTime::now(); + let time_millis: u128 = now.duration_since(UNIX_EPOCH).unwrap().as_millis(); + + // 16 bits of randomness + let rand: u16 = thread_rng().gen(); + + ((time_millis as u64) << 16) | (rand as u64) +} |