From d3b2274504d16770ff1a3162cae4053a0820c284 Mon Sep 17 00:00:00 2001 From: Aria Date: Fri, 13 Oct 2023 01:33:56 +0100 Subject: wip: broadcast --- broadcast/src/main.rs | 140 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 broadcast/src/main.rs (limited to 'broadcast/src/main.rs') diff --git a/broadcast/src/main.rs b/broadcast/src/main.rs new file mode 100644 index 0000000..09c0268 --- /dev/null +++ b/broadcast/src/main.rs @@ -0,0 +1,140 @@ +#![feature(return_position_impl_trait_in_trait)] + +use smol::{lock::RwLock, prelude::*}; +use std::collections::{HashMap, HashSet}; + +use common::{ + msg::{MessageHeader, Output}, + run_server, Handler, +}; +use serde::{Deserialize, Serialize}; + +type BroadcastTarget = usize; + +fn main() { + run_server::(); +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(tag = "type")] +pub enum BroadcastBody { + #[serde(rename = "broadcast")] + Broadcast { + msg_id: Option, + message: BroadcastTarget, + }, + + #[serde(rename = "broadcast_ok")] + BroadcastOk { in_reply_to: usize }, + + #[serde(rename = "topology")] + Topology { + msg_id: Option, + topology: HashMap>, + }, + + #[serde(rename = "topology_ok")] + TopologyOk { in_reply_to: usize }, + + #[serde(rename = "read")] + Read { msg_id: usize }, + + #[serde(rename = "read_ok")] + ReadOk { + in_reply_to: usize, + messages: HashSet, + }, +} + +pub struct BroadcastHandler { + node_id: String, + seen: RwLock>, + broadcast_targets: RwLock>, + output: Output, +} + +impl Handler for BroadcastHandler { + type Body = BroadcastBody; + + fn init(node_id: String, mut node_ids: Vec, output: Output) -> Self { + node_ids.retain(|x| *x != node_id); + + BroadcastHandler { + node_id, + broadcast_targets: RwLock::new(node_ids), + seen: RwLock::new(HashSet::new()), + output, + } + } + + fn handle( + &self, + header: MessageHeader, + body: BroadcastBody, + ) -> impl Future + Send { + async move { + match body { + BroadcastBody::Broadcast { msg_id, message } => { + self.seen.write().await.insert(message); + if let Some(msg_id) = msg_id { + self.output + .send( + &header.src, + &BroadcastBody::BroadcastOk { + in_reply_to: msg_id, + }, + ) + .await; + } + + for target in self.broadcast_targets.read().await.iter() { + self.output + .send( + target, + &BroadcastBody::Broadcast { + msg_id: None, + message, + }, + ) + .await; + } + } + BroadcastBody::Topology { + msg_id, + mut topology, + } => { + if let Some(broadcast_targets) = topology.remove(&self.node_id) { + *self.broadcast_targets.write().await = broadcast_targets; + } + + if let Some(msg_id) = msg_id { + self.output + .send( + &header.src, + &BroadcastBody::TopologyOk { + in_reply_to: msg_id, + }, + ) + .await; + } + } + + BroadcastBody::Read { msg_id } => { + self.output + .send( + &header.src, + &BroadcastBody::ReadOk { + in_reply_to: msg_id, + messages: self.seen.read().await.clone(), + }, + ) + .await + } + + BroadcastBody::BroadcastOk { .. } => {} + BroadcastBody::TopologyOk { .. } => {} + BroadcastBody::ReadOk { .. } => {} + } + } + } +} -- cgit v1.2.3