author    | Aria <me@aria.rip> | 2023-10-20 00:33:52 +0100
committer | Aria <me@aria.rip> | 2023-10-20 00:33:52 +0100
commit    | 4b6e257263f7bb6eaebb9be672f88bda85c77586 (patch)
tree      | ba3e5130284a96e73418be1440a97646b1727bdf /broadcast/src
parent    | bb54fd5b450ea1b126f7c31845f12893bf061058 (diff)
broadcast batching with minimal set of messages needed
Diffstat (limited to 'broadcast/src')
-rw-r--r-- | broadcast/src/batch.rs    |  1
-rw-r--r-- | broadcast/src/handler.rs  | 90
-rw-r--r-- | broadcast/src/topology.rs | 30
3 files changed, 90 insertions, 31 deletions
diff --git a/broadcast/src/batch.rs b/broadcast/src/batch.rs
index d69771d..42cc3b8 100644
--- a/broadcast/src/batch.rs
+++ b/broadcast/src/batch.rs
@@ -50,6 +50,7 @@ impl MessageBatch {
         self.first_added
             .elapsed()
             .saturating_sub(self.max_message_delay)
+            .saturating_sub(Duration::from_millis(10))
     }
 
     pub fn broadcast(
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs
index 2b470fb..e71edef 100644
--- a/broadcast/src/handler.rs
+++ b/broadcast/src/handler.rs
@@ -1,3 +1,4 @@
+use elsa::sync::FrozenMap;
 use smol::{
     future,
     lock::{Mutex, RwLock},
@@ -10,7 +11,11 @@ use std::{
     time::Duration,
 };
 
-use crate::{batch::MessageBatch, msg::*, topology::Topology};
+use crate::{
+    batch::MessageBatch,
+    msg::*,
+    topology::{NodeId, Topology},
+};
 
 use common::{
     msg::{MessageHeader, Output},
@@ -24,8 +29,9 @@ pub struct BroadcastHandler {
     node_id: String,
     seen: RwLock<HashSet<BroadcastTarget>>,
     topology: Topology,
-    batch: Mutex<MessageBatch>,
+    batch: FrozenMap<NodeId, Box<Mutex<MessageBatch>>>,
     pub(crate) output: Output<BroadcastBody>,
+    max_message_delay: Duration,
     attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>,
 }
 
@@ -36,12 +42,13 @@ impl Handler for BroadcastHandler {
         let max_message_delay = MAX_STABLE_DELAY_MS / (node_ids.len() as f32).sqrt() as u32;
 
         let this = Arc::new(Self {
+            max_message_delay,
             node_id,
             topology: Topology::dense(node_ids.clone()),
             seen: RwLock::new(HashSet::new()),
             output,
             attempted_broadcasts: Mutex::default(),
-            batch: Mutex::new(MessageBatch::new(max_message_delay)),
+            batch: FrozenMap::new(),
         });
 
         smol::spawn(this.clone().poll_batch()).detach();
@@ -61,7 +68,7 @@ impl Handler for BroadcastHandler {
                 message,
             } => {
                 future::zip(
-                    self.receive_broadcast(message),
+                    self.receive_broadcast(&header.src, message),
                     self.send_broadcast_ok(&header.src, msg_id),
                 )
                 .await;
@@ -72,7 +79,7 @@ impl Handler for BroadcastHandler {
                 messages,
             } => {
                 future::zip(
-                    self.receive_broadcast_batch(messages),
+                    self.receive_broadcast_batch(&header.src, messages),
                     self.send_broadcast_ok(&header.src, msg_id),
                 )
                 .await;
@@ -82,14 +89,14 @@ impl Handler for BroadcastHandler {
                 msg_id: None,
                 message,
             } => {
-                self.receive_broadcast(message).await;
+                self.receive_broadcast(&header.src, message).await;
             }
 
             BroadcastBody::BroadcastBatch {
                 msg_id: None,
                 messages,
             } => {
-                self.receive_broadcast_batch(messages).await;
+                self.receive_broadcast_batch(&header.src, messages).await;
             }
 
             BroadcastBody::Topology { msg_id, topology } => {
@@ -152,49 +159,76 @@ impl BroadcastHandler {
     }
 
     /// Receive a given message, and broadcast it onwards if it is new
-    async fn receive_broadcast(self: &Arc<Self>, message: BroadcastTarget) {
+    async fn receive_broadcast(self: &Arc<Self>, src: &String, message: BroadcastTarget) {
         let new = self.seen.write().await.insert(message);
         if !new {
             return;
         }
 
-        let mut batch = self.batch.lock().await;
-        batch.add(message);
+        for dest in self.topology.targets(&self.node_id, &src).await {
+            let batch_lock = self.get_batch_lock(&dest);
+
+            let mut batch = batch_lock.lock().await;
+            batch.add(message);
+        }
     }
 
-    async fn receive_broadcast_batch(self: &Arc<Self>, message: Vec<BroadcastTarget>) {
-        let mut batch = self.batch.lock().await;
+    async fn receive_broadcast_batch(
+        self: &Arc<Self>,
+        src: &String,
+        message: Vec<BroadcastTarget>,
+    ) {
         let mut seen = self.seen.write().await;
         for message in message.into_iter() {
             if seen.insert(message) {
-                batch.add(message);
+                for dest in self.topology.targets(&self.node_id, &src).await {
+                    let batch_lock = self.get_batch_lock(&dest);
+
+                    let mut batch = batch_lock.lock().await;
+                    batch.add(message);
+                }
             }
         }
     }
 
-    async fn poll_batch(self: Arc<Self>) {
-        loop {
-            let mut batch = self.batch.lock().await;
-            self.do_batch_check(&mut batch).await;
-
-            let time = batch.sleep_time();
-            drop(batch);
+    fn get_batch_lock(&self, key: &String) -> &Mutex<MessageBatch> {
+        match self.batch.get(key) {
+            Some(x) => x,
+            None => {
+                self.batch.insert(
+                    key.clone(),
+                    Box::new(Mutex::new(MessageBatch::new(self.max_message_delay))),
+                );
 
-            Timer::after(time).await;
+                self.batch.get(key).unwrap()
+            }
         }
     }
 
-    async fn do_batch_check(self: &Arc<Self>, batch: &mut MessageBatch) {
-        if batch.should_broadcast() {
+    async fn poll_batch(self: Arc<Self>) {
+        loop {
             let mut tasks = self.attempted_broadcasts.lock().await;
-            for target in self.topology.all_targets(&self.node_id).await {
-                let msg_id = gen_msg_id();
-                let this = self.clone();
-                tasks.insert(msg_id, smol::spawn(batch.broadcast(this, target, msg_id)));
+            let mut min_sleep_time = self.max_message_delay;
+            for key in self.batch.keys_cloned() {
+                let batch_lock = self.batch.get(&key).unwrap();
+                let mut batch = batch_lock.lock().await;
+                if batch.should_broadcast() {
+                    let msg_id = gen_msg_id();
+                    let this = self.clone();
+                    tasks.insert(msg_id, smol::spawn(batch.broadcast(this, key, msg_id)));
+
+                    batch.clear();
+                }
+
+                let batch_sleep_time = batch.sleep_time();
+                if batch_sleep_time < min_sleep_time {
+                    min_sleep_time = batch_sleep_time;
+                }
             }
 
-            batch.clear();
+            drop(tasks);
+
+            Timer::after(min_sleep_time).await;
         }
     }
 }
diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs
index 5e16666..bb10466 100644
--- a/broadcast/src/topology.rs
+++ b/broadcast/src/topology.rs
@@ -14,16 +14,40 @@ impl Topology {
         for node_id in node_ids.iter() {
             top.insert(node_id.clone(), node_ids.iter().cloned().collect());
         }
+        Self::filter_desc(&mut top);
         Topology(RwLock::new(top))
     }
 
     /// Replace the current topology with a new one.
-    pub async fn replace(&self, new: TopologyDesc) {
+    pub async fn replace(&self, mut new: TopologyDesc) {
+        Self::filter_desc(&mut new);
         *self.0.write().await = new;
     }
 
-    pub async fn all_targets(&self, node_id: &NodeId) -> HashSet<String> {
-        self.0.read().await.get(node_id).unwrap().clone()
+    fn filter_desc(desc: &mut TopologyDesc) {
+        for (node_id, neighbours) in desc.iter_mut() {
+            neighbours.remove(node_id);
+        }
+    }
+
+    /// Get the next targets from the given topology, for a message
+    /// which has travelled across the given path and is now at node_id.
+    pub async fn targets(&self, node_id: &String, last_node_id: &String) -> HashSet<String> {
+        // Ensure we don't keep holding the read lock
+        let topology = self.0.read().await;
+
+        // Get all nodes the last node sent it to
+        let visited = topology.get(last_node_id);
+
+        let neighbours = topology.get(node_id).unwrap();
+        match visited {
+            Some(visited) => neighbours
+                .difference(&visited)
+                .cloned()
+                .filter(|n| n != node_id)
+                .collect(),
+            None => neighbours.clone(),
+        }
     }
 }
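The "minimal set" in the commit title comes from the new Topology::targets method: a node only forwards a message to neighbours that the node it heard it from does not already reach. Below is a rough, self-contained sketch of that set arithmetic over plain std collections; forward_targets, the 4-cycle topology, and the node names are illustrative assumptions rather than code from this repo, whose real method is async and reads the shared topology behind an RwLock.

use std::collections::{HashMap, HashSet};

type NodeId = String;
type TopologyDesc = HashMap<NodeId, HashSet<NodeId>>;

// Forward only to neighbours the previous hop does not already reach, and
// never to ourselves (a hypothetical mirror of Topology::targets).
fn forward_targets(top: &TopologyDesc, node_id: &str, last_node_id: &str) -> HashSet<NodeId> {
    let neighbours = &top[node_id];
    match top.get(last_node_id) {
        // The previous hop's own neighbours are (or will be) covered by it.
        Some(covered) => neighbours
            .difference(covered)
            .filter(|n| n.as_str() != node_id)
            .cloned()
            .collect(),
        // Unknown sender (e.g. a client): forward to every neighbour.
        None => neighbours.clone(),
    }
}

fn main() {
    // 4-cycle n1 - n2 - n4 - n3 - n1; each node's set already excludes itself.
    let mut top = TopologyDesc::new();
    for (node, peers) in [
        ("n1", ["n2", "n3"]),
        ("n2", ["n1", "n4"]),
        ("n3", ["n1", "n4"]),
        ("n4", ["n2", "n3"]),
    ] {
        top.insert(node.to_string(), peers.iter().map(|s| s.to_string()).collect());
    }

    // When n2 receives a broadcast from n1, n3 is skipped (n1 already
    // reaches it) but n4 still needs the message.
    let from_n1 = forward_targets(&top, "n2", "n1");
    assert!(from_n1.contains("n4"));
    assert!(!from_n1.contains("n3"));

    // A message arriving straight from a client fans out to all neighbours.
    assert_eq!(forward_targets(&top, "n1", "c5").len(), 2);
}

The per-destination batching in handler.rs (batch: FrozenMap<NodeId, Box<Mutex<MessageBatch>>>) follows from this: once different targets receive different subsets of messages, each destination needs its own pending batch rather than one shared MessageBatch.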