summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
authorAria <me@aria.rip>2023-10-13 01:33:38 +0100
committerAria <me@aria.rip>2023-10-13 01:33:38 +0100
commit1b6c1b425f78f4ec3eb275f21a792776e50cbf93 (patch)
tree9adb3c9fc11ee379078b60243f1705e991f7bf5d /common
parent186087b2010f7f2b9631a28b80527d99b751b882 (diff)
start using async
Diffstat (limited to 'common')
-rw-r--r--common/Cargo.toml6
-rw-r--r--common/src/lib.rs154
-rw-r--r--common/src/msg.rs50
-rw-r--r--common/src/msg_id.rs16
4 files changed, 145 insertions, 81 deletions
diff --git a/common/Cargo.toml b/common/Cargo.toml
index 3abc652..b184b3c 100644
--- a/common/Cargo.toml
+++ b/common/Cargo.toml
@@ -6,5 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
-serde = { version = "1.0.185", features = ["derive"] }
-serde_json = "1.0.105" \ No newline at end of file
+smol = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+rand = "0.8.5" \ No newline at end of file
diff --git a/common/src/lib.rs b/common/src/lib.rs
index 616dfcc..5317b3e 100644
--- a/common/src/lib.rs
+++ b/common/src/lib.rs
@@ -1,94 +1,122 @@
-use std::io::{Read, Write};
+#![feature(return_position_impl_trait_in_trait)]
+use std::{
+ future::Future,
+ io::{self, Read, Write},
+ thread,
+};
-use msg::{MaelstromBody, MaelstromBodyOr, Message, MessageHeader};
+use msg::{MaelstromBody, Message, MessageHeader, Output};
+use msg_id::gen_msg_id;
use serde::{Deserialize, Serialize};
use serde_json::Deserializer;
+use smol::{
+ channel::{self, Receiver, Sender},
+ future,
+ stream::StreamExt,
+ Executor,
+};
pub mod msg;
+pub mod msg_id;
pub trait Handler {
- type Body: Serialize + for<'a> Deserialize<'a>;
+ type Body: Serialize + for<'a> Deserialize<'a> + Send + Clone;
- fn init(node_id: String, node_ids: Vec<String>, msg_id: usize) -> Self;
- fn handle(
- &mut self,
- header: MessageHeader,
- body: Self::Body,
- writer: &mut MsgWriter<impl Write>,
- ) -> ();
+ fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Self;
+ fn handle(&self, header: MessageHeader, body: Self::Body) -> impl Future<Output = ()> + Send;
}
-pub struct MsgWriter<W> {
- node_id: String,
- writer: W,
-}
+pub fn run_server<H: Handler>() {
+ // Perform sync initialisation of the handler
+ // This is a special case so that we can use a different message body type just for init messages
+ let (handler, out_recv) = sync_init_handler::<H, _, _>(io::stdin(), io::stdout());
-impl<W: Write> MsgWriter<W> {
- pub fn new(node_id: String, writer: W) -> Self {
- Self { node_id, writer }
- }
+ let (inp_send, mut inp_recv) = channel::unbounded::<Message<H::Body>>();
- pub fn write<T: Serialize>(&mut self, dst: String, msg: &T) {
- let msg = Message {
- header: MessageHeader {
- src: self.node_id.clone(),
- dst,
- },
- body: MaelstromBodyOr::Other { inner: msg },
- };
- serde_json::to_writer(&mut self.writer, &msg).unwrap();
- self.writer.write(&[b'\n']).unwrap();
- }
-}
-
-pub fn run_with<T: Handler>(mut reader: impl Read, mut writer: impl Write) {
- let (mut handler, mut msg_writer) = init_handler::<T, _, _>(&mut reader, &mut writer);
+ thread::scope(|s| {
+ // Worker threads for receiving and sending
+ // This is easier than making it async, and good enough for our usecase.
+ s.spawn(|| recv_loop(io::stdin(), inp_send));
+ s.spawn(|| send_loop(io::stdout(), out_recv));
- let deser = Deserializer::from_reader(reader);
- for msg in deser.into_iter::<Message<T::Body>>() {
- let msg = msg.unwrap();
- match msg.body {
- MaelstromBodyOr::Other { inner } => {
- handler.handle(msg.header, inner, &mut msg_writer);
+ // As we receive messages, spawn a future for each
+ let executor = Executor::new();
+ future::block_on(executor.run(async {
+ while let Some(msg) = inp_recv.next().await {
+ executor
+ .spawn(handler.handle(msg.header, msg.body))
+ .detach();
}
- _ => todo!(),
- };
- }
+ }));
+ });
}
-pub fn init_handler<T: Handler, R: Read, W: Write>(reader: R, writer: W) -> (T, MsgWriter<W>) {
+/// Initialises the handler synchronously.
+///
+/// This is done as a seperate step because we initially deserialize into a different type
+/// than our handler will accept, so there's no point spawning and immediately finishing threads.
+fn sync_init_handler<H: Handler, R: Read, W: Write>(
+ reader: R,
+ mut writer: W,
+) -> (H, Receiver<Message<H::Body>>) {
+ // Receive the init message
let deser = Deserializer::from_reader(reader);
- let mut deser = deser.into_iter::<Message<()>>();
- let Some(msg) = deser.next() else {
- panic!("stream ended before init message");
- };
- let Ok(msg) = msg else {
- panic!("{}", msg.unwrap_err());
- };
-
- let (node_id, node_ids, msg_id) = match msg.body {
- MaelstromBodyOr::MaelstromBody {
- inner:
+ let mut deser = deser.into_iter::<Message<MaelstromBody>>();
+ let (init_header, node_id, node_ids, init_msg_id) = match deser.next() {
+ Some(Ok(Message {
+ header,
+ body:
MaelstromBody::Init {
node_id,
node_ids,
msg_id,
},
- } => (node_id, node_ids, msg_id),
+ })) => (header, node_id, node_ids, msg_id),
+ Some(Err(e)) => panic!("invalid init message: {}", e),
_ => {
panic!("expected init message to be first message");
}
};
- let mut writer = MsgWriter::new(node_id.clone(), writer);
-
- writer.write(
- msg.header.src,
- &MaelstromBody::InitOk {
- msg_id: 0,
- in_reply_to: msg_id,
+ // Write the init_ok message
+ write_newline(
+ &mut writer,
+ &Message {
+ header: init_header.flip(),
+ body: MaelstromBody::InitOk {
+ in_reply_to: init_msg_id,
+ msg_id: gen_msg_id(),
+ },
},
);
- (T::init(node_id, node_ids, msg_id), writer)
+ // Create handler, and channel to go with it
+ let (send, recv) = channel::unbounded();
+
+ (
+ H::init(node_id.clone(), node_ids, Output::new(node_id, send)),
+ recv,
+ )
+}
+
+/// Receives JSON from a reader, and outputs the deserialised result to a channel
+fn recv_loop<M: for<'a> Deserialize<'a>>(reader: impl Read, channel: Sender<M>) {
+ let deser = Deserializer::from_reader(reader);
+ for msg in deser.into_iter() {
+ let msg = msg.unwrap();
+ channel.send_blocking(msg).unwrap();
+ }
+}
+
+/// Receives things to send, and outputs them as JSON to writer
+fn send_loop<M: Serialize>(mut writer: impl Write, channel: Receiver<M>) {
+ while let Ok(msg) = channel.recv_blocking() {
+ write_newline(&mut writer, msg);
+ }
+}
+
+/// Write a message to writer, followed by a newline
+fn write_newline(mut writer: impl Write, msg: impl Serialize) {
+ serde_json::to_writer(&mut writer, &msg).unwrap();
+ writer.write(&[b'\n']).unwrap();
}
diff --git a/common/src/msg.rs b/common/src/msg.rs
index 23db171..7e9863f 100644
--- a/common/src/msg.rs
+++ b/common/src/msg.rs
@@ -1,10 +1,13 @@
use serde::{Deserialize, Serialize};
+use smol::channel::Sender;
+
+use crate::msg_id::MessageID;
#[derive(Debug, Serialize, Deserialize)]
pub struct Message<B> {
#[serde(flatten)]
pub header: MessageHeader,
- pub body: MaelstromBodyOr<B>,
+ pub body: B,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -24,27 +27,42 @@ impl MessageHeader {
}
#[derive(Debug, Serialize, Deserialize)]
-#[serde(untagged)]
-pub enum MaelstromBodyOr<B> {
- MaelstromBody {
- #[serde(flatten)]
- inner: MaelstromBody,
- },
- Other {
- #[serde(flatten)]
- inner: B,
- },
-}
-
-#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum MaelstromBody {
#[serde(rename = "init")]
Init {
node_id: String,
node_ids: Vec<String>,
- msg_id: usize,
+ msg_id: MessageID,
},
#[serde(rename = "init_ok")]
- InitOk { msg_id: usize, in_reply_to: usize },
+ InitOk {
+ msg_id: MessageID,
+ in_reply_to: MessageID,
+ },
+}
+
+pub struct Output<B> {
+ node_id: String,
+ channel: Sender<Message<B>>,
+}
+
+impl<B: Serialize + Clone> Output<B> {
+ pub fn new(node_id: String, channel: Sender<Message<B>>) -> Self {
+ Self { node_id, channel }
+ }
+
+ pub async fn send(&self, dst: &str, body: &B) {
+ self.send_raw(Message {
+ header: MessageHeader {
+ src: self.node_id.clone(),
+ dst: dst.to_string(),
+ },
+ body: body.clone(),
+ })
+ .await;
+ }
+ pub async fn send_raw(&self, msg: Message<B>) {
+ self.channel.send(msg).await.unwrap();
+ }
}
diff --git a/common/src/msg_id.rs b/common/src/msg_id.rs
new file mode 100644
index 0000000..e953f08
--- /dev/null
+++ b/common/src/msg_id.rs
@@ -0,0 +1,16 @@
+use std::time::{SystemTime, UNIX_EPOCH};
+
+use rand::{thread_rng, Rng};
+
+pub type MessageID = u64;
+
+pub fn gen_msg_id() -> MessageID {
+ // Time since UNIX epoch in milliseconds, (48 bits)
+ let now = SystemTime::now();
+ let time_millis: u128 = now.duration_since(UNIX_EPOCH).unwrap().as_millis();
+
+ // 16 bits of randomness
+ let rand: u16 = thread_rng().gen();
+
+ ((time_millis as u64) << 16) | (rand as u64)
+}