diff options
Diffstat (limited to 'broadcast/src')
-rw-r--r-- | broadcast/src/batch.rs | 1 | ||||
-rw-r--r-- | broadcast/src/handler.rs | 90 | ||||
-rw-r--r-- | broadcast/src/topology.rs | 30 |
3 files changed, 31 insertions, 90 deletions
diff --git a/broadcast/src/batch.rs b/broadcast/src/batch.rs index 42cc3b8..d69771d 100644 --- a/broadcast/src/batch.rs +++ b/broadcast/src/batch.rs @@ -50,7 +50,6 @@ impl MessageBatch { self.first_added .elapsed() .saturating_sub(self.max_message_delay) - .saturating_sub(Duration::from_millis(10)) } pub fn broadcast( diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs index e71edef..2b470fb 100644 --- a/broadcast/src/handler.rs +++ b/broadcast/src/handler.rs @@ -1,4 +1,3 @@ -use elsa::sync::FrozenMap; use smol::{ future, lock::{Mutex, RwLock}, @@ -11,11 +10,7 @@ use std::{ time::Duration, }; -use crate::{ - batch::MessageBatch, - msg::*, - topology::{NodeId, Topology}, -}; +use crate::{batch::MessageBatch, msg::*, topology::Topology}; use common::{ msg::{MessageHeader, Output}, @@ -29,9 +24,8 @@ pub struct BroadcastHandler { node_id: String, seen: RwLock<HashSet<BroadcastTarget>>, topology: Topology, - batch: FrozenMap<NodeId, Box<Mutex<MessageBatch>>>, + batch: Mutex<MessageBatch>, pub(crate) output: Output<BroadcastBody>, - max_message_delay: Duration, attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>, } @@ -42,13 +36,12 @@ impl Handler for BroadcastHandler { let max_message_delay = MAX_STABLE_DELAY_MS / (node_ids.len() as f32).sqrt() as u32; let this = Arc::new(Self { - max_message_delay, node_id, topology: Topology::dense(node_ids.clone()), seen: RwLock::new(HashSet::new()), output, attempted_broadcasts: Mutex::default(), - batch: FrozenMap::new(), + batch: Mutex::new(MessageBatch::new(max_message_delay)), }); smol::spawn(this.clone().poll_batch()).detach(); @@ -68,7 +61,7 @@ impl Handler for BroadcastHandler { message, } => { future::zip( - self.receive_broadcast(&header.src, message), + self.receive_broadcast(message), self.send_broadcast_ok(&header.src, msg_id), ) .await; @@ -79,7 +72,7 @@ impl Handler for BroadcastHandler { messages, } => { future::zip( - self.receive_broadcast_batch(&header.src, messages), + self.receive_broadcast_batch(messages), self.send_broadcast_ok(&header.src, msg_id), ) .await; @@ -89,14 +82,14 @@ impl Handler for BroadcastHandler { msg_id: None, message, } => { - self.receive_broadcast(&header.src, message).await; + self.receive_broadcast(message).await; } BroadcastBody::BroadcastBatch { msg_id: None, messages, } => { - self.receive_broadcast_batch(&header.src, messages).await; + self.receive_broadcast_batch(messages).await; } BroadcastBody::Topology { msg_id, topology } => { @@ -159,76 +152,49 @@ impl BroadcastHandler { } /// Receive a given message, and broadcast it onwards if it is new - async fn receive_broadcast(self: &Arc<Self>, src: &String, message: BroadcastTarget) { + async fn receive_broadcast(self: &Arc<Self>, message: BroadcastTarget) { let new = self.seen.write().await.insert(message); if !new { return; } - for dest in self.topology.targets(&self.node_id, &src).await { - let batch_lock = self.get_batch_lock(&dest); - - let mut batch = batch_lock.lock().await; - batch.add(message); - } + let mut batch = self.batch.lock().await; + batch.add(message); } - async fn receive_broadcast_batch( - self: &Arc<Self>, - src: &String, - message: Vec<BroadcastTarget>, - ) { + async fn receive_broadcast_batch(self: &Arc<Self>, message: Vec<BroadcastTarget>) { + let mut batch = self.batch.lock().await; let mut seen = self.seen.write().await; for message in message.into_iter() { if seen.insert(message) { - for dest in self.topology.targets(&self.node_id, &src).await { - let batch_lock = self.get_batch_lock(&dest); - let mut batch = batch_lock.lock().await; - batch.add(message); - } + batch.add(message); } } } - fn get_batch_lock(&self, key: &String) -> &Mutex<MessageBatch> { - match self.batch.get(key) { - Some(x) => x, - None => { - self.batch.insert( - key.clone(), - Box::new(Mutex::new(MessageBatch::new(self.max_message_delay))), - ); + async fn poll_batch(self: Arc<Self>) { + loop { + let mut batch = self.batch.lock().await; + self.do_batch_check(&mut batch).await; - self.batch.get(key).unwrap() - } + let time = batch.sleep_time(); + drop(batch); + + Timer::after(time).await; } } - async fn poll_batch(self: Arc<Self>) { - loop { + async fn do_batch_check(self: &Arc<Self>, batch: &mut MessageBatch) { + if batch.should_broadcast() { let mut tasks = self.attempted_broadcasts.lock().await; - let mut min_sleep_time = self.max_message_delay; - for key in self.batch.keys_cloned() { - let batch_lock = self.batch.get(&key).unwrap(); - let mut batch = batch_lock.lock().await; - if batch.should_broadcast() { - let msg_id = gen_msg_id(); - let this = self.clone(); - tasks.insert(msg_id, smol::spawn(batch.broadcast(this, key, msg_id))); - - batch.clear(); - } - - let batch_sleep_time = batch.sleep_time(); - if batch_sleep_time < min_sleep_time { - min_sleep_time = batch_sleep_time; - } + for target in self.topology.all_targets(&self.node_id).await { + let msg_id = gen_msg_id(); + let this = self.clone(); + tasks.insert(msg_id, smol::spawn(batch.broadcast(this, target, msg_id))); } - drop(tasks); - - Timer::after(min_sleep_time).await; + batch.clear(); } } } diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs index bb10466..5e16666 100644 --- a/broadcast/src/topology.rs +++ b/broadcast/src/topology.rs @@ -14,40 +14,16 @@ impl Topology { for node_id in node_ids.iter() { top.insert(node_id.clone(), node_ids.iter().cloned().collect()); } - Self::filter_desc(&mut top); Topology(RwLock::new(top)) } /// Replace the current topology with a new one. - pub async fn replace(&self, mut new: TopologyDesc) { - Self::filter_desc(&mut new); + pub async fn replace(&self, new: TopologyDesc) { *self.0.write().await = new; } - fn filter_desc(desc: &mut TopologyDesc) { - for (node_id, neighbours) in desc.iter_mut() { - neighbours.remove(node_id); - } - } - - /// Get the next targets from the given topology, for a message - /// which has travelled across the given path and is now at node_id. - pub async fn targets(&self, node_id: &String, last_node_id: &String) -> HashSet<String> { - // Ensure we don't keep holding the read lock - let topology = self.0.read().await; - - // Get all nodes the last node sent it to - let visited = topology.get(last_node_id); - - let neighbours = topology.get(node_id).unwrap(); - match visited { - Some(visited) => neighbours - .difference(&visited) - .cloned() - .filter(|n| n != node_id) - .collect(), - None => neighbours.clone(), - } + pub async fn all_targets(&self, node_id: &NodeId) -> HashSet<String> { + self.0.read().await.get(node_id).unwrap().clone() } } |