diff options
Diffstat (limited to 'broadcast/src/handler.rs')
-rw-r--r-- | broadcast/src/handler.rs | 119 |
1 files changed, 72 insertions, 47 deletions
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs index b38e2b7..4ee2664 100644 --- a/broadcast/src/handler.rs +++ b/broadcast/src/handler.rs @@ -10,35 +10,42 @@ use std::{ time::Duration, }; -use crate::{msg::*, topology::Topology}; +use crate::{batch::MessageBatch, msg::*, topology::Topology}; use common::{ msg::{MessageHeader, Output}, msg_id::{gen_msg_id, MessageID}, Handler, }; - -const RETRY_TIMEOUT_SECS: u64 = 1; +const MAX_STABLE_DELAY_MS: Duration = Duration::from_millis(700); pub struct BroadcastHandler { node_id: String, seen: RwLock<HashSet<BroadcastTarget>>, topology: Topology, - output: Output<BroadcastBody>, + batch: Mutex<MessageBatch>, + pub(crate) output: Output<BroadcastBody>, attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>, } impl Handler for BroadcastHandler { type Body = BroadcastBody; - fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Self { - BroadcastHandler { + fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Arc<Self> { + let max_message_delay = MAX_STABLE_DELAY_MS / (node_ids.len() / node_ids.len()) as u32; + + let this = Arc::new(Self { node_id, topology: Topology::dense(node_ids.clone()), seen: RwLock::new(HashSet::new()), output, attempted_broadcasts: Mutex::default(), - } + batch: Mutex::new(MessageBatch::new(max_message_delay, 1000)), + }); + + smol::spawn(this.clone().poll_batch()).detach(); + + this } fn handle<'a>( @@ -51,20 +58,37 @@ impl Handler for BroadcastHandler { BroadcastBody::Broadcast { msg_id: Some(msg_id), message, - path, } => { future::zip( - self.receive_broadcast(&header.src, path, message), + self.receive_broadcast(message), self.send_broadcast_ok(&header.src, msg_id), ) .await; } + + BroadcastBody::BroadcastBatch { + msg_id: Some(msg_id), + messages, + } => { + future::zip( + self.receive_broadcast_batch(messages), + self.send_broadcast_ok(&header.src, msg_id), + ) + .await; + } + BroadcastBody::Broadcast { msg_id: None, message, - path, } => { - self.receive_broadcast(&header.src, path, message).await; + self.receive_broadcast(message).await; + } + + BroadcastBody::BroadcastBatch { + msg_id: None, + messages, + } => { + self.receive_broadcast_batch(messages).await; } BroadcastBody::Topology { msg_id, topology } => { @@ -127,48 +151,49 @@ impl BroadcastHandler { } /// Receive a given message, and broadcast it onwards if it is new - async fn receive_broadcast( - self: &Arc<Self>, - src: &str, - previous_path: Option<Vec<String>>, - message: BroadcastTarget, - ) { + async fn receive_broadcast(self: &Arc<Self>, message: BroadcastTarget) { let new = self.seen.write().await.insert(message); if !new { return; } - // Race all send futures - let mut previous_path = previous_path.unwrap_or_else(|| vec![]); - previous_path.push(src.to_string()); - let mut tasks = self.attempted_broadcasts.lock().await; - for target in self - .topology - .targets(&self.node_id, previous_path.iter().map(String::as_str)) - .await - { - let msg_id = gen_msg_id(); - let this = self.clone(); - let path = previous_path.clone(); - tasks.insert( - msg_id, - smol::spawn(async move { - loop { - this.output - .send( - &target, - &BroadcastBody::Broadcast { - msg_id: Some(msg_id), - message, - path: Some(path.clone()), - }, - ) - .await; + let mut batch = self.batch.lock().await; + batch.add(message); + } - Timer::after(Duration::from_secs(RETRY_TIMEOUT_SECS)).await; - } - }), - ); + async fn receive_broadcast_batch(self: &Arc<Self>, message: Vec<BroadcastTarget>) { + let mut batch = self.batch.lock().await; + let mut seen = self.seen.write().await; + let mut new = false; + + for message in message.into_iter() { + new |= seen.insert(message); + batch.add(message); + } + + if !new { + return; + } + } + + async fn poll_batch(self: Arc<Self>) { + loop { + let mut batch = self.batch.lock().await; + if batch.should_broadcast() { + let mut tasks = self.attempted_broadcasts.lock().await; + for target in self.topology.all_targets(&self.node_id).await { + let msg_id = gen_msg_id(); + let this = self.clone(); + tasks.insert(msg_id, smol::spawn(batch.broadcast(this, target, msg_id))); + } + + batch.clear(); + } + + let time = batch.sleep_time(); + drop(batch); + + Timer::after(time).await; } } } |