use smol::{ future, lock::{Mutex, RwLock}, prelude::*, Task, Timer, }; use std::{ collections::{HashMap, HashSet}, sync::Arc, time::Duration, }; use crate::{msg::*, topology::Topology}; use common::{ msg::{MessageHeader, Output}, msg_id::{gen_msg_id, MessageID}, Handler, }; const RETRY_TIMEOUT_SECS: u64 = 1; pub struct BroadcastHandler { node_id: String, seen: RwLock>, topology: Topology, output: Output, attempted_broadcasts: Mutex>>, } impl Handler for BroadcastHandler { type Body = BroadcastBody; fn init(node_id: String, node_ids: Vec, output: Output) -> Self { BroadcastHandler { node_id, topology: Topology::dense(node_ids.clone()), seen: RwLock::new(HashSet::new()), output, attempted_broadcasts: Mutex::default(), } } fn handle<'a>( self: &'a Arc, header: MessageHeader, body: BroadcastBody, ) -> impl Future + Send + 'a { async move { match body { BroadcastBody::Broadcast { msg_id: Some(msg_id), message, path, } => { future::zip( self.receive_broadcast(&header.src, path, message), self.send_broadcast_ok(&header.src, msg_id), ) .await; } BroadcastBody::Broadcast { msg_id: None, message, path, } => { self.receive_broadcast(&header.src, path, message).await; } BroadcastBody::Topology { msg_id, topology } => { // Start using the new topology self.topology.replace(topology).await; // Send reply if needed if let Some(msg_id) = msg_id { self.output .send( &header.src, &BroadcastBody::TopologyOk { in_reply_to: msg_id, }, ) .await; } } BroadcastBody::Read { msg_id } => { // Send all received messages back self.output .send( &header.src, &BroadcastBody::ReadOk { in_reply_to: msg_id, messages: self.seen.read().await.clone(), }, ) .await } BroadcastBody::BroadcastOk { in_reply_to } => { // Stop retrying, if we still are if let Some(task) = self.attempted_broadcasts.lock().await.remove(&in_reply_to) { task.cancel().await; } } // Ignore other OK messages - we never actually request them BroadcastBody::TopologyOk { .. } => {} BroadcastBody::ReadOk { .. } => {} } } } } impl BroadcastHandler { /// Reply with a broadcast OK message async fn send_broadcast_ok(&self, src: &str, msg_id: MessageID) { self.output .send( &src, &BroadcastBody::BroadcastOk { in_reply_to: msg_id, }, ) .await; } /// Receive a given message, and broadcast it onwards if it is new async fn receive_broadcast( self: &Arc, src: &str, previous_path: Option>, message: BroadcastTarget, ) { let new = self.seen.write().await.insert(message); if !new { return; } // Race all send futures let mut previous_path = previous_path.unwrap_or_else(|| vec![]); previous_path.push(src.to_string()); let mut tasks = self.attempted_broadcasts.lock().await; for target in self .topology .targets(&self.node_id, previous_path.iter().map(String::as_str)) .await { let msg_id = gen_msg_id(); let this = self.clone(); let path = previous_path.clone(); tasks.insert( msg_id, smol::spawn(async move { loop { this.output .send( &target, &BroadcastBody::Broadcast { msg_id: Some(msg_id), message, path: Some(path.clone()), }, ) .await; Timer::after(Duration::from_secs(RETRY_TIMEOUT_SECS)).await; } }), ); } } }