diff options
Diffstat (limited to 'broadcast/src')
-rw-r--r-- | broadcast/src/main.rs | 140 |
1 files changed, 140 insertions, 0 deletions
diff --git a/broadcast/src/main.rs b/broadcast/src/main.rs new file mode 100644 index 0000000..09c0268 --- /dev/null +++ b/broadcast/src/main.rs @@ -0,0 +1,140 @@ +#![feature(return_position_impl_trait_in_trait)] + +use smol::{lock::RwLock, prelude::*}; +use std::collections::{HashMap, HashSet}; + +use common::{ + msg::{MessageHeader, Output}, + run_server, Handler, +}; +use serde::{Deserialize, Serialize}; + +type BroadcastTarget = usize; + +fn main() { + run_server::<BroadcastHandler>(); +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(tag = "type")] +pub enum BroadcastBody { + #[serde(rename = "broadcast")] + Broadcast { + msg_id: Option<usize>, + message: BroadcastTarget, + }, + + #[serde(rename = "broadcast_ok")] + BroadcastOk { in_reply_to: usize }, + + #[serde(rename = "topology")] + Topology { + msg_id: Option<usize>, + topology: HashMap<String, Vec<String>>, + }, + + #[serde(rename = "topology_ok")] + TopologyOk { in_reply_to: usize }, + + #[serde(rename = "read")] + Read { msg_id: usize }, + + #[serde(rename = "read_ok")] + ReadOk { + in_reply_to: usize, + messages: HashSet<BroadcastTarget>, + }, +} + +pub struct BroadcastHandler { + node_id: String, + seen: RwLock<HashSet<BroadcastTarget>>, + broadcast_targets: RwLock<Vec<String>>, + output: Output<BroadcastBody>, +} + +impl Handler for BroadcastHandler { + type Body = BroadcastBody; + + fn init(node_id: String, mut node_ids: Vec<String>, output: Output<Self::Body>) -> Self { + node_ids.retain(|x| *x != node_id); + + BroadcastHandler { + node_id, + broadcast_targets: RwLock::new(node_ids), + seen: RwLock::new(HashSet::new()), + output, + } + } + + fn handle( + &self, + header: MessageHeader, + body: BroadcastBody, + ) -> impl Future<Output = ()> + Send { + async move { + match body { + BroadcastBody::Broadcast { msg_id, message } => { + self.seen.write().await.insert(message); + if let Some(msg_id) = msg_id { + self.output + .send( + &header.src, + &BroadcastBody::BroadcastOk { + in_reply_to: msg_id, + }, + ) + .await; + } + + for target in self.broadcast_targets.read().await.iter() { + self.output + .send( + target, + &BroadcastBody::Broadcast { + msg_id: None, + message, + }, + ) + .await; + } + } + BroadcastBody::Topology { + msg_id, + mut topology, + } => { + if let Some(broadcast_targets) = topology.remove(&self.node_id) { + *self.broadcast_targets.write().await = broadcast_targets; + } + + if let Some(msg_id) = msg_id { + self.output + .send( + &header.src, + &BroadcastBody::TopologyOk { + in_reply_to: msg_id, + }, + ) + .await; + } + } + + BroadcastBody::Read { msg_id } => { + self.output + .send( + &header.src, + &BroadcastBody::ReadOk { + in_reply_to: msg_id, + messages: self.seen.read().await.clone(), + }, + ) + .await + } + + BroadcastBody::BroadcastOk { .. } => {} + BroadcastBody::TopologyOk { .. } => {} + BroadcastBody::ReadOk { .. } => {} + } + } + } +} |