use smol::{ future, lock::{Mutex, RwLock}, prelude::*, Task, Timer, }; use std::{ collections::{HashMap, HashSet}, sync::Arc, time::Duration, }; use crate::{batch::MessageBatch, msg::*, topology::Topology}; use common::{ msg::{MessageHeader, Output}, msg_id::{gen_msg_id, MessageID}, Handler, }; const MAX_STABLE_DELAY_MS: Duration = Duration::from_millis(700); pub struct BroadcastHandler { node_id: String, seen: RwLock>, topology: Topology, batch: Mutex, pub(crate) output: Output, attempted_broadcasts: Mutex>>, } impl Handler for BroadcastHandler { type Body = BroadcastBody; fn init(node_id: String, node_ids: Vec, output: Output) -> Arc { let max_message_delay = MAX_STABLE_DELAY_MS / (node_ids.len() / node_ids.len()) as u32; let this = Arc::new(Self { node_id, topology: Topology::dense(node_ids.clone()), seen: RwLock::new(HashSet::new()), output, attempted_broadcasts: Mutex::default(), batch: Mutex::new(MessageBatch::new(max_message_delay, 1000)), }); smol::spawn(this.clone().poll_batch()).detach(); this } fn handle<'a>( self: &'a Arc, header: MessageHeader, body: BroadcastBody, ) -> impl Future + Send + 'a { async move { match body { BroadcastBody::Broadcast { msg_id: Some(msg_id), message, } => { future::zip( self.receive_broadcast(message), self.send_broadcast_ok(&header.src, msg_id), ) .await; } BroadcastBody::BroadcastBatch { msg_id: Some(msg_id), messages, } => { future::zip( self.receive_broadcast_batch(messages), self.send_broadcast_ok(&header.src, msg_id), ) .await; } BroadcastBody::Broadcast { msg_id: None, message, } => { self.receive_broadcast(message).await; } BroadcastBody::BroadcastBatch { msg_id: None, messages, } => { self.receive_broadcast_batch(messages).await; } BroadcastBody::Topology { msg_id, topology } => { // Start using the new topology self.topology.replace(topology).await; // Send reply if needed if let Some(msg_id) = msg_id { self.output .send( &header.src, &BroadcastBody::TopologyOk { in_reply_to: msg_id, }, ) .await; } } BroadcastBody::Read { msg_id } => { // Send all received messages back self.output .send( &header.src, &BroadcastBody::ReadOk { in_reply_to: msg_id, messages: self.seen.read().await.clone(), }, ) .await } BroadcastBody::BroadcastOk { in_reply_to } => { // Stop retrying, if we still are if let Some(task) = self.attempted_broadcasts.lock().await.remove(&in_reply_to) { task.cancel().await; } } // Ignore other OK messages - we never actually request them BroadcastBody::TopologyOk { .. } => {} BroadcastBody::ReadOk { .. } => {} } } } } impl BroadcastHandler { /// Reply with a broadcast OK message async fn send_broadcast_ok(&self, src: &str, msg_id: MessageID) { self.output .send( &src, &BroadcastBody::BroadcastOk { in_reply_to: msg_id, }, ) .await; } /// Receive a given message, and broadcast it onwards if it is new async fn receive_broadcast(self: &Arc, message: BroadcastTarget) { let new = self.seen.write().await.insert(message); if !new { return; } let mut batch = self.batch.lock().await; batch.add(message); } async fn receive_broadcast_batch(self: &Arc, message: Vec) { let mut batch = self.batch.lock().await; let mut seen = self.seen.write().await; let mut new = false; for message in message.into_iter() { new |= seen.insert(message); batch.add(message); } if !new { return; } } async fn poll_batch(self: Arc) { loop { let mut batch = self.batch.lock().await; if batch.should_broadcast() { let mut tasks = self.attempted_broadcasts.lock().await; for target in self.topology.all_targets(&self.node_id).await { let msg_id = gen_msg_id(); let this = self.clone(); tasks.insert(msg_id, smol::spawn(batch.broadcast(this, target, msg_id))); } batch.clear(); } let time = batch.sleep_time(); drop(batch); Timer::after(time).await; } } }