diff options
Diffstat (limited to 'broadcast/src/handler.rs')
-rw-r--r-- | broadcast/src/handler.rs | 167 |
1 files changed, 167 insertions, 0 deletions
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs new file mode 100644 index 0000000..09c66b9 --- /dev/null +++ b/broadcast/src/handler.rs @@ -0,0 +1,167 @@ +use smol::{ + lock::{Mutex, RwLock}, + prelude::*, + Task, Timer, +}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + time::Duration, +}; + +use crate::msg::*; + +use common::{ + msg::{MessageHeader, Output}, + msg_id::{gen_msg_id, MessageID}, + Handler, +}; + +const RETRY_TIMEOUT: u64 = 2; + +pub struct BroadcastHandler { + node_id: String, + seen: RwLock<HashSet<BroadcastTarget>>, + broadcast_targets: RwLock<Vec<String>>, + output: Output<BroadcastBody>, + attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>, +} + +impl Handler for BroadcastHandler { + type Body = BroadcastBody; + + fn init(node_id: String, mut node_ids: Vec<String>, output: Output<Self::Body>) -> Self { + node_ids.retain(|x| *x != node_id); + + BroadcastHandler { + node_id, + broadcast_targets: RwLock::new(node_ids), + seen: RwLock::new(HashSet::new()), + output, + attempted_broadcasts: Mutex::default(), + } + } + + fn handle<'a>( + self: &'a Arc<Self>, + header: MessageHeader, + body: BroadcastBody, + ) -> impl Future<Output = ()> + Send + 'a { + async move { + match body { + BroadcastBody::Broadcast { + msg_id: Some(msg_id), + message, + } => { + self.receive_broadcast(&header.src, message).await; + self.send_broadcast_ok(&header.src, msg_id).await; + } + BroadcastBody::Broadcast { + msg_id: None, + message, + } => { + self.receive_broadcast(&header.src, message).await; + } + + BroadcastBody::Topology { + msg_id, + mut topology, + } => { + // Start using the new topology + if let Some(broadcast_targets) = topology.remove(&self.node_id) { + *self.broadcast_targets.write().await = broadcast_targets; + } + + // Send reply if needed + if let Some(msg_id) = msg_id { + self.output + .send( + &header.src, + &BroadcastBody::TopologyOk { + in_reply_to: msg_id, + }, + ) + .await; + } + } + + BroadcastBody::Read { msg_id } => { + // Send all received messages back + self.output + .send( + &header.src, + &BroadcastBody::ReadOk { + in_reply_to: msg_id, + messages: self.seen.read().await.clone(), + }, + ) + .await + } + + // Ignore OK messages - we never actually request them + BroadcastBody::BroadcastOk { in_reply_to } => { + if let Some(task) = self.attempted_broadcasts.lock().await.remove(&in_reply_to) + { + task.cancel().await; + } + } + BroadcastBody::TopologyOk { .. } => {} + BroadcastBody::ReadOk { .. } => {} + } + } + } +} + +impl BroadcastHandler { + /// Reply with a broadcast OK message + async fn send_broadcast_ok(&self, src: &str, msg_id: MessageID) { + self.output + .send( + &src, + &BroadcastBody::BroadcastOk { + in_reply_to: msg_id, + }, + ) + .await; + } + + /// Receive a given message, and broadcast it onwards if it is new + async fn receive_broadcast(self: &Arc<Self>, src: &str, message: BroadcastTarget) { + let new = self.seen.write().await.insert(message); + if !new { + return; + } + + // Ensure we don't keep holding the read lock + let targets = self.broadcast_targets.read().await.clone(); + + // Race all send futures + let mut tasks = self.attempted_broadcasts.lock().await; + for target in targets.into_iter() { + if &target == &src { + return; + } + + let msg_id = gen_msg_id(); + let this = self.clone(); + tasks.insert( + msg_id, + smol::spawn(async move { + loop { + this.output + .send( + &target, + &BroadcastBody::Broadcast { + msg_id: None, + message, + }, + ) + .await; + + Timer::after(Duration::from_secs(RETRY_TIMEOUT)).await; + } + }), + ); + } + } +} |