diff options
Diffstat (limited to 'broadcast/src/handler.rs')
-rw-r--r-- | broadcast/src/handler.rs | 90 |
1 files changed, 62 insertions, 28 deletions
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs index 2b470fb..e71edef 100644 --- a/broadcast/src/handler.rs +++ b/broadcast/src/handler.rs @@ -1,3 +1,4 @@ +use elsa::sync::FrozenMap; use smol::{ future, lock::{Mutex, RwLock}, @@ -10,7 +11,11 @@ use std::{ time::Duration, }; -use crate::{batch::MessageBatch, msg::*, topology::Topology}; +use crate::{ + batch::MessageBatch, + msg::*, + topology::{NodeId, Topology}, +}; use common::{ msg::{MessageHeader, Output}, @@ -24,8 +29,9 @@ pub struct BroadcastHandler { node_id: String, seen: RwLock<HashSet<BroadcastTarget>>, topology: Topology, - batch: Mutex<MessageBatch>, + batch: FrozenMap<NodeId, Box<Mutex<MessageBatch>>>, pub(crate) output: Output<BroadcastBody>, + max_message_delay: Duration, attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>, } @@ -36,12 +42,13 @@ impl Handler for BroadcastHandler { let max_message_delay = MAX_STABLE_DELAY_MS / (node_ids.len() as f32).sqrt() as u32; let this = Arc::new(Self { + max_message_delay, node_id, topology: Topology::dense(node_ids.clone()), seen: RwLock::new(HashSet::new()), output, attempted_broadcasts: Mutex::default(), - batch: Mutex::new(MessageBatch::new(max_message_delay)), + batch: FrozenMap::new(), }); smol::spawn(this.clone().poll_batch()).detach(); @@ -61,7 +68,7 @@ impl Handler for BroadcastHandler { message, } => { future::zip( - self.receive_broadcast(message), + self.receive_broadcast(&header.src, message), self.send_broadcast_ok(&header.src, msg_id), ) .await; @@ -72,7 +79,7 @@ impl Handler for BroadcastHandler { messages, } => { future::zip( - self.receive_broadcast_batch(messages), + self.receive_broadcast_batch(&header.src, messages), self.send_broadcast_ok(&header.src, msg_id), ) .await; @@ -82,14 +89,14 @@ impl Handler for BroadcastHandler { msg_id: None, message, } => { - self.receive_broadcast(message).await; + self.receive_broadcast(&header.src, message).await; } BroadcastBody::BroadcastBatch { msg_id: None, messages, } => { - self.receive_broadcast_batch(messages).await; + self.receive_broadcast_batch(&header.src, messages).await; } BroadcastBody::Topology { msg_id, topology } => { @@ -152,49 +159,76 @@ impl BroadcastHandler { } /// Receive a given message, and broadcast it onwards if it is new - async fn receive_broadcast(self: &Arc<Self>, message: BroadcastTarget) { + async fn receive_broadcast(self: &Arc<Self>, src: &String, message: BroadcastTarget) { let new = self.seen.write().await.insert(message); if !new { return; } - let mut batch = self.batch.lock().await; - batch.add(message); + for dest in self.topology.targets(&self.node_id, &src).await { + let batch_lock = self.get_batch_lock(&dest); + + let mut batch = batch_lock.lock().await; + batch.add(message); + } } - async fn receive_broadcast_batch(self: &Arc<Self>, message: Vec<BroadcastTarget>) { - let mut batch = self.batch.lock().await; + async fn receive_broadcast_batch( + self: &Arc<Self>, + src: &String, + message: Vec<BroadcastTarget>, + ) { let mut seen = self.seen.write().await; for message in message.into_iter() { if seen.insert(message) { - batch.add(message); + for dest in self.topology.targets(&self.node_id, &src).await { + let batch_lock = self.get_batch_lock(&dest); + let mut batch = batch_lock.lock().await; + batch.add(message); + } } } } - async fn poll_batch(self: Arc<Self>) { - loop { - let mut batch = self.batch.lock().await; - self.do_batch_check(&mut batch).await; - - let time = batch.sleep_time(); - drop(batch); + fn get_batch_lock(&self, key: &String) -> &Mutex<MessageBatch> { + match self.batch.get(key) { + Some(x) => x, + None => { + self.batch.insert( + key.clone(), + Box::new(Mutex::new(MessageBatch::new(self.max_message_delay))), + ); - Timer::after(time).await; + self.batch.get(key).unwrap() + } } } - async fn do_batch_check(self: &Arc<Self>, batch: &mut MessageBatch) { - if batch.should_broadcast() { + async fn poll_batch(self: Arc<Self>) { + loop { let mut tasks = self.attempted_broadcasts.lock().await; - for target in self.topology.all_targets(&self.node_id).await { - let msg_id = gen_msg_id(); - let this = self.clone(); - tasks.insert(msg_id, smol::spawn(batch.broadcast(this, target, msg_id))); + let mut min_sleep_time = self.max_message_delay; + for key in self.batch.keys_cloned() { + let batch_lock = self.batch.get(&key).unwrap(); + let mut batch = batch_lock.lock().await; + if batch.should_broadcast() { + let msg_id = gen_msg_id(); + let this = self.clone(); + tasks.insert(msg_id, smol::spawn(batch.broadcast(this, key, msg_id))); + + batch.clear(); + } + + let batch_sleep_time = batch.sleep_time(); + if batch_sleep_time < min_sleep_time { + min_sleep_time = batch_sleep_time; + } } - batch.clear(); + drop(tasks); + + Timer::after(min_sleep_time).await; } } } |