diff options
Diffstat (limited to 'broadcast/src/handler.rs')
-rw-r--r-- | broadcast/src/handler.rs | 90 |
1 files changed, 28 insertions, 62 deletions
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs index e71edef..2b470fb 100644 --- a/broadcast/src/handler.rs +++ b/broadcast/src/handler.rs @@ -1,4 +1,3 @@ -use elsa::sync::FrozenMap; use smol::{ future, lock::{Mutex, RwLock}, @@ -11,11 +10,7 @@ use std::{ time::Duration, }; -use crate::{ - batch::MessageBatch, - msg::*, - topology::{NodeId, Topology}, -}; +use crate::{batch::MessageBatch, msg::*, topology::Topology}; use common::{ msg::{MessageHeader, Output}, @@ -29,9 +24,8 @@ pub struct BroadcastHandler { node_id: String, seen: RwLock<HashSet<BroadcastTarget>>, topology: Topology, - batch: FrozenMap<NodeId, Box<Mutex<MessageBatch>>>, + batch: Mutex<MessageBatch>, pub(crate) output: Output<BroadcastBody>, - max_message_delay: Duration, attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>, } @@ -42,13 +36,12 @@ impl Handler for BroadcastHandler { let max_message_delay = MAX_STABLE_DELAY_MS / (node_ids.len() as f32).sqrt() as u32; let this = Arc::new(Self { - max_message_delay, node_id, topology: Topology::dense(node_ids.clone()), seen: RwLock::new(HashSet::new()), output, attempted_broadcasts: Mutex::default(), - batch: FrozenMap::new(), + batch: Mutex::new(MessageBatch::new(max_message_delay)), }); smol::spawn(this.clone().poll_batch()).detach(); @@ -68,7 +61,7 @@ impl Handler for BroadcastHandler { message, } => { future::zip( - self.receive_broadcast(&header.src, message), + self.receive_broadcast(message), self.send_broadcast_ok(&header.src, msg_id), ) .await; @@ -79,7 +72,7 @@ impl Handler for BroadcastHandler { messages, } => { future::zip( - self.receive_broadcast_batch(&header.src, messages), + self.receive_broadcast_batch(messages), self.send_broadcast_ok(&header.src, msg_id), ) .await; @@ -89,14 +82,14 @@ impl Handler for BroadcastHandler { msg_id: None, message, } => { - self.receive_broadcast(&header.src, message).await; + self.receive_broadcast(message).await; } BroadcastBody::BroadcastBatch { msg_id: None, messages, } => { - self.receive_broadcast_batch(&header.src, messages).await; + self.receive_broadcast_batch(messages).await; } BroadcastBody::Topology { msg_id, topology } => { @@ -159,76 +152,49 @@ impl BroadcastHandler { } /// Receive a given message, and broadcast it onwards if it is new - async fn receive_broadcast(self: &Arc<Self>, src: &String, message: BroadcastTarget) { + async fn receive_broadcast(self: &Arc<Self>, message: BroadcastTarget) { let new = self.seen.write().await.insert(message); if !new { return; } - for dest in self.topology.targets(&self.node_id, &src).await { - let batch_lock = self.get_batch_lock(&dest); - - let mut batch = batch_lock.lock().await; - batch.add(message); - } + let mut batch = self.batch.lock().await; + batch.add(message); } - async fn receive_broadcast_batch( - self: &Arc<Self>, - src: &String, - message: Vec<BroadcastTarget>, - ) { + async fn receive_broadcast_batch(self: &Arc<Self>, message: Vec<BroadcastTarget>) { + let mut batch = self.batch.lock().await; let mut seen = self.seen.write().await; for message in message.into_iter() { if seen.insert(message) { - for dest in self.topology.targets(&self.node_id, &src).await { - let batch_lock = self.get_batch_lock(&dest); - let mut batch = batch_lock.lock().await; - batch.add(message); - } + batch.add(message); } } } - fn get_batch_lock(&self, key: &String) -> &Mutex<MessageBatch> { - match self.batch.get(key) { - Some(x) => x, - None => { - self.batch.insert( - key.clone(), - Box::new(Mutex::new(MessageBatch::new(self.max_message_delay))), - ); + async fn poll_batch(self: Arc<Self>) { + loop { + let mut batch = self.batch.lock().await; + self.do_batch_check(&mut batch).await; - self.batch.get(key).unwrap() - } + let time = batch.sleep_time(); + drop(batch); + + Timer::after(time).await; } } - async fn poll_batch(self: Arc<Self>) { - loop { + async fn do_batch_check(self: &Arc<Self>, batch: &mut MessageBatch) { + if batch.should_broadcast() { let mut tasks = self.attempted_broadcasts.lock().await; - let mut min_sleep_time = self.max_message_delay; - for key in self.batch.keys_cloned() { - let batch_lock = self.batch.get(&key).unwrap(); - let mut batch = batch_lock.lock().await; - if batch.should_broadcast() { - let msg_id = gen_msg_id(); - let this = self.clone(); - tasks.insert(msg_id, smol::spawn(batch.broadcast(this, key, msg_id))); - - batch.clear(); - } - - let batch_sleep_time = batch.sleep_time(); - if batch_sleep_time < min_sleep_time { - min_sleep_time = batch_sleep_time; - } + for target in self.topology.all_targets(&self.node_id).await { + let msg_id = gen_msg_id(); + let this = self.clone(); + tasks.insert(msg_id, smol::spawn(batch.broadcast(this, target, msg_id))); } - drop(tasks); - - Timer::after(min_sleep_time).await; + batch.clear(); } } } |