From 4b6e257263f7bb6eaebb9be672f88bda85c77586 Mon Sep 17 00:00:00 2001 From: Aria Date: Fri, 20 Oct 2023 00:33:52 +0100 Subject: broadcast batching with minimal set of messages needed --- Cargo.lock | 16 +++++++++ broadcast/Cargo.toml | 3 +- broadcast/src/batch.rs | 1 + broadcast/src/handler.rs | 90 ++++++++++++++++++++++++++++++++--------------- broadcast/src/topology.rs | 30 ++++++++++++++-- 5 files changed, 108 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 191c81a..d9b0d5f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -165,6 +165,7 @@ name = "broadcast" version = "0.1.0" dependencies = [ "common", + "elsa", "futures", "serde", "serde_json", @@ -214,6 +215,15 @@ dependencies = [ "serde_json", ] +[[package]] +name = "elsa" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714f766f3556b44e7e4776ad133fcc3445a489517c25c704ace411bb14790194" +dependencies = [ + "stable_deref_trait", +] + [[package]] name = "errno" version = "0.3.5" @@ -617,6 +627,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "syn" version = "2.0.29" diff --git a/broadcast/Cargo.toml b/broadcast/Cargo.toml index 06e4426..718f7d1 100644 --- a/broadcast/Cargo.toml +++ b/broadcast/Cargo.toml @@ -10,4 +10,5 @@ smol = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } common = { path = "../common/" } -futures = { version = "0.3.28", default_features = false, features = ["std"] } \ No newline at end of file +futures = { version = "0.3.28", default_features = false, features = ["std"] } +elsa = "1.9.0" diff --git a/broadcast/src/batch.rs b/broadcast/src/batch.rs index d69771d..42cc3b8 100644 --- a/broadcast/src/batch.rs +++ b/broadcast/src/batch.rs @@ -50,6 +50,7 @@ impl MessageBatch { 
self.first_added .elapsed() .saturating_sub(self.max_message_delay) + .saturating_sub(Duration::from_millis(10)) } pub fn broadcast( diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs index 2b470fb..e71edef 100644 --- a/broadcast/src/handler.rs +++ b/broadcast/src/handler.rs @@ -1,3 +1,4 @@ +use elsa::sync::FrozenMap; use smol::{ future, lock::{Mutex, RwLock}, @@ -10,7 +11,11 @@ use std::{ time::Duration, }; -use crate::{batch::MessageBatch, msg::*, topology::Topology}; +use crate::{ + batch::MessageBatch, + msg::*, + topology::{NodeId, Topology}, +}; use common::{ msg::{MessageHeader, Output}, @@ -24,8 +29,9 @@ pub struct BroadcastHandler { node_id: String, seen: RwLock>, topology: Topology, - batch: Mutex, + batch: FrozenMap>>, pub(crate) output: Output, + max_message_delay: Duration, attempted_broadcasts: Mutex>>, } @@ -36,12 +42,13 @@ impl Handler for BroadcastHandler { let max_message_delay = MAX_STABLE_DELAY_MS / (node_ids.len() as f32).sqrt() as u32; let this = Arc::new(Self { + max_message_delay, node_id, topology: Topology::dense(node_ids.clone()), seen: RwLock::new(HashSet::new()), output, attempted_broadcasts: Mutex::default(), - batch: Mutex::new(MessageBatch::new(max_message_delay)), + batch: FrozenMap::new(), }); smol::spawn(this.clone().poll_batch()).detach(); @@ -61,7 +68,7 @@ impl Handler for BroadcastHandler { message, } => { future::zip( - self.receive_broadcast(message), + self.receive_broadcast(&header.src, message), self.send_broadcast_ok(&header.src, msg_id), ) .await; @@ -72,7 +79,7 @@ impl Handler for BroadcastHandler { messages, } => { future::zip( - self.receive_broadcast_batch(messages), + self.receive_broadcast_batch(&header.src, messages), self.send_broadcast_ok(&header.src, msg_id), ) .await; @@ -82,14 +89,14 @@ impl Handler for BroadcastHandler { msg_id: None, message, } => { - self.receive_broadcast(message).await; + self.receive_broadcast(&header.src, message).await; } BroadcastBody::BroadcastBatch { msg_id: None, 
messages, } => { - self.receive_broadcast_batch(messages).await; + self.receive_broadcast_batch(&header.src, messages).await; } BroadcastBody::Topology { msg_id, topology } => { @@ -152,49 +159,76 @@ impl BroadcastHandler { } /// Receive a given message, and broadcast it onwards if it is new - async fn receive_broadcast(self: &Arc, message: BroadcastTarget) { + async fn receive_broadcast(self: &Arc, src: &String, message: BroadcastTarget) { let new = self.seen.write().await.insert(message); if !new { return; } - let mut batch = self.batch.lock().await; - batch.add(message); + for dest in self.topology.targets(&self.node_id, &src).await { + let batch_lock = self.get_batch_lock(&dest); + + let mut batch = batch_lock.lock().await; + batch.add(message); + } } - async fn receive_broadcast_batch(self: &Arc, message: Vec) { - let mut batch = self.batch.lock().await; + async fn receive_broadcast_batch( + self: &Arc, + src: &String, + message: Vec, + ) { let mut seen = self.seen.write().await; for message in message.into_iter() { if seen.insert(message) { - batch.add(message); + for dest in self.topology.targets(&self.node_id, &src).await { + let batch_lock = self.get_batch_lock(&dest); + let mut batch = batch_lock.lock().await; + batch.add(message); + } } } } - async fn poll_batch(self: Arc) { - loop { - let mut batch = self.batch.lock().await; - self.do_batch_check(&mut batch).await; - - let time = batch.sleep_time(); - drop(batch); + fn get_batch_lock(&self, key: &String) -> &Mutex { + match self.batch.get(key) { + Some(x) => x, + None => { + self.batch.insert( + key.clone(), + Box::new(Mutex::new(MessageBatch::new(self.max_message_delay))), + ); - Timer::after(time).await; + self.batch.get(key).unwrap() + } } } - async fn do_batch_check(self: &Arc, batch: &mut MessageBatch) { - if batch.should_broadcast() { + async fn poll_batch(self: Arc) { + loop { let mut tasks = self.attempted_broadcasts.lock().await; - for target in self.topology.all_targets(&self.node_id).await { 
- let msg_id = gen_msg_id(); - let this = self.clone(); - tasks.insert(msg_id, smol::spawn(batch.broadcast(this, target, msg_id))); + let mut min_sleep_time = self.max_message_delay; + for key in self.batch.keys_cloned() { + let batch_lock = self.batch.get(&key).unwrap(); + let mut batch = batch_lock.lock().await; + if batch.should_broadcast() { + let msg_id = gen_msg_id(); + let this = self.clone(); + tasks.insert(msg_id, smol::spawn(batch.broadcast(this, key, msg_id))); + + batch.clear(); + } + + let batch_sleep_time = batch.sleep_time(); + if batch_sleep_time < min_sleep_time { + min_sleep_time = batch_sleep_time; + } } - batch.clear(); + drop(tasks); + + Timer::after(min_sleep_time).await; } } } diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs index 5e16666..bb10466 100644 --- a/broadcast/src/topology.rs +++ b/broadcast/src/topology.rs @@ -14,16 +14,40 @@ impl Topology { for node_id in node_ids.iter() { top.insert(node_id.clone(), node_ids.iter().cloned().collect()); } + Self::filter_desc(&mut top); Topology(RwLock::new(top)) } /// Replace the current topology with a new one. - pub async fn replace(&self, new: TopologyDesc) { + pub async fn replace(&self, mut new: TopologyDesc) { + Self::filter_desc(&mut new); *self.0.write().await = new; } - pub async fn all_targets(&self, node_id: &NodeId) -> HashSet { - self.0.read().await.get(node_id).unwrap().clone() + fn filter_desc(desc: &mut TopologyDesc) { + for (node_id, neighbours) in desc.iter_mut() { + neighbours.remove(node_id); + } + } + + /// Get the next targets from the given topology, for a message + /// which has travelled across the given path and is now at node_id. 
+ pub async fn targets(&self, node_id: &String, last_node_id: &String) -> HashSet { + // Take the read lock once; it is held for the duration of this lookup + let topology = self.0.read().await; + + // Get all nodes the last node sent it to + let visited = topology.get(last_node_id); + + let neighbours = topology.get(node_id).unwrap(); + match visited { + Some(visited) => neighbours + .difference(&visited) + .cloned() + .filter(|n| n != node_id) + .collect(), + None => neighbours.clone(), + } } } -- cgit v1.2.3