use smol::{ future, lock::{Mutex, RwLock}, prelude::*, Task, Timer, }; use std::{ collections::{HashMap, HashSet}, sync::Arc, time::Duration, }; use crate::msg::*; use common::{ msg::{MessageHeader, Output}, msg_id::{gen_msg_id, MessageID}, Handler, }; const RETRY_TIMEOUT_SECS: u64 = 1; pub struct BroadcastHandler { node_id: String, seen: RwLock>, topology: RwLock>>, output: Output, attempted_broadcasts: Mutex>>, } impl Handler for BroadcastHandler { type Body = BroadcastBody; fn init(node_id: String, node_ids: Vec, output: Output) -> Self { // Initial topology assumes all nodes are neighbours let mut topology = HashMap::new(); for id in node_ids.iter() { topology.insert(id.clone(), node_ids.iter().cloned().collect()); } BroadcastHandler { node_id, topology: RwLock::new(topology), seen: RwLock::new(HashSet::new()), output, attempted_broadcasts: Mutex::default(), } } fn handle<'a>( self: &'a Arc, header: MessageHeader, body: BroadcastBody, ) -> impl Future + Send + 'a { async move { match body { BroadcastBody::Broadcast { msg_id: Some(msg_id), message, } => { future::zip( self.receive_broadcast(&header.src, message), self.send_broadcast_ok(&header.src, msg_id), ) .await; } BroadcastBody::Broadcast { msg_id: None, message, } => { self.receive_broadcast(&header.src, message).await; } BroadcastBody::Topology { msg_id, topology } => { // Start using the new topology *self.topology.write().await = topology; // Send reply if needed if let Some(msg_id) = msg_id { self.output .send( &header.src, &BroadcastBody::TopologyOk { in_reply_to: msg_id, }, ) .await; } } BroadcastBody::Read { msg_id } => { // Send all received messages back self.output .send( &header.src, &BroadcastBody::ReadOk { in_reply_to: msg_id, messages: self.seen.read().await.clone(), }, ) .await } // Ignore OK messages - we never actually request them BroadcastBody::BroadcastOk { in_reply_to } => { if let Some(task) = self.attempted_broadcasts.lock().await.remove(&in_reply_to) { task.cancel().await; } } BroadcastBody::TopologyOk { .. } => {} BroadcastBody::ReadOk { .. } => {} } } } } impl BroadcastHandler { /// Reply with a broadcast OK message async fn send_broadcast_ok(&self, src: &str, msg_id: MessageID) { self.output .send( &src, &BroadcastBody::BroadcastOk { in_reply_to: msg_id, }, ) .await; } /// Receive a given message, and broadcast it onwards if it is new async fn receive_broadcast(self: &Arc, src: &str, message: BroadcastTarget) { let new = self.seen.write().await.insert(message); if !new { return; } // Ensure we don't keep holding the read lock let mut targets = self.topology.read().await.clone(); // Only send to neighbours that the source has not sent to. // This isn't technically optimal, but its as close as we can get without // tracking the path of each broadcast message. let our_targets = targets.remove(&self.node_id).unwrap(); let their_targets = targets .remove(&src.to_string()) .unwrap_or_else(|| HashSet::new()); // Race all send futures let mut tasks = self.attempted_broadcasts.lock().await; for target in our_targets.into_iter() { if &target == &src || &target == &self.node_id || their_targets.contains(&target) { continue; } let msg_id = gen_msg_id(); let this = self.clone(); tasks.insert( msg_id, smol::spawn(async move { loop { this.output .send( &target, &BroadcastBody::Broadcast { msg_id: Some(msg_id), message, }, ) .await; Timer::after(Duration::from_secs(RETRY_TIMEOUT_SECS)).await; } }), ); } } }