From dba41282cac86a740c007498709e996b9fa3e59b Mon Sep 17 00:00:00 2001
From: Aria
Date: Thu, 19 Oct 2023 21:00:26 +0100
Subject: wip: broadcast message batching

---
 broadcast/src/batch.rs    |  81 ++++++++++++++++++++++++++++++
 broadcast/src/handler.rs  | 119 +++++++++++++++++++++++++-----------------
 broadcast/src/main.rs     |   1 +
 broadcast/src/msg.rs      |   7 ++-
 broadcast/src/topology.rs | 128 +---------------------------------------------
 common/src/lib.rs         |   4 +-
 6 files changed, 164 insertions(+), 176 deletions(-)
 create mode 100644 broadcast/src/batch.rs

diff --git a/broadcast/src/batch.rs b/broadcast/src/batch.rs
new file mode 100644
index 0000000..8c1c17d
--- /dev/null
+++ b/broadcast/src/batch.rs
@@ -0,0 +1,81 @@
+use std::{
+    collections::HashSet,
+    sync::Arc,
+    time::{Duration, Instant},
+};
+
+use common::msg_id::MessageID;
+use futures::Future;
+use smol::Timer;
+
+use crate::{
+    handler::BroadcastHandler,
+    msg::{BroadcastBody, BroadcastTarget},
+};
+
+const RETRY_TIMEOUT_SECS: u64 = 1;
+
+#[derive(Debug, Clone)]
+pub struct MessageBatch {
+    max_message_delay: Duration,
+    max_message_count: usize,
+    last_update: Instant,
+    messages: HashSet<BroadcastTarget>,
+}
+
+impl MessageBatch {
+    pub fn new(max_message_delay: Duration, max_message_count: usize) -> Self {
+        Self {
+            max_message_delay,
+            max_message_count,
+            last_update: Instant::now(),
+            messages: HashSet::new(),
+        }
+    }
+
+    pub fn add(&mut self, msg: BroadcastTarget) {
+        self.messages.insert(msg);
+        self.last_update = Instant::now();
+    }
+
+    pub fn clear(&mut self) {
+        self.messages.clear();
+        self.last_update = Instant::now();
+    }
+
+    pub fn should_broadcast(&self) -> bool {
+        !self.messages.is_empty()
+            && (self.last_update.elapsed() >= self.max_message_delay
+                || self.messages.len() >= self.max_message_count)
+    }
+
+    pub fn sleep_time(&self) -> Duration {
+        self.last_update
+            .elapsed()
+            .saturating_sub(self.max_message_delay)
+    }
+
+    pub fn broadcast(
+        &self,
+        this: Arc<BroadcastHandler>,
+        dst: String,
+        msg_id: MessageID,
+    ) -> impl Future<Output = ()> + 'static {
+        let messages = self.messages.clone();
+        async move {
+            loop {
+                this.output
+                    .send(
+                        &dst,
+                        &BroadcastBody::BroadcastBatch {
+                            msg_id: Some(msg_id),
+                            messages: messages.clone().into_iter().collect(),
+                        },
+                    )
+                    .await;
+
+                Timer::after(Duration::from_secs(RETRY_TIMEOUT_SECS)).await;
+            }
+        }
+    }
+}
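
For orientation, a minimal sketch of how the new MessageBatch is expected to
behave, written as a unit test that could sit at the bottom of batch.rs. The
thresholds are made up, and it assumes BroadcastTarget is a small integer
alias as in the broadcast workload; it is not part of the commit itself.

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn flushes_on_count_but_not_before() {
            // Use a long delay so only the count threshold can fire in this test.
            let mut batch = MessageBatch::new(Duration::from_secs(60), 3);
            assert!(!batch.should_broadcast()); // empty batches are never flushed

            batch.add(1);
            batch.add(2);
            assert!(!batch.should_broadcast()); // below both thresholds

            batch.add(3);
            assert!(batch.should_broadcast()); // max_message_count reached

            batch.clear();
            assert!(!batch.should_broadcast()); // a flushed batch starts over
        }
    }
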
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs
index b38e2b7..4ee2664 100644
--- a/broadcast/src/handler.rs
+++ b/broadcast/src/handler.rs
@@ -10,35 +10,42 @@ use std::{
     time::Duration,
 };
 
-use crate::{msg::*, topology::Topology};
+use crate::{batch::MessageBatch, msg::*, topology::Topology};
 use common::{
     msg::{MessageHeader, Output},
     msg_id::{gen_msg_id, MessageID},
     Handler,
 };
-
-const RETRY_TIMEOUT_SECS: u64 = 1;
+const MAX_STABLE_DELAY_MS: Duration = Duration::from_millis(700);
 
 pub struct BroadcastHandler {
     node_id: String,
     seen: RwLock<HashSet<BroadcastTarget>>,
     topology: Topology,
-    output: Output<BroadcastBody>,
+    batch: Mutex<MessageBatch>,
+    pub(crate) output: Output<BroadcastBody>,
     attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>,
 }
 
 impl Handler for BroadcastHandler {
     type Body = BroadcastBody;
 
-    fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Self {
-        BroadcastHandler {
+    fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Arc<Self> {
+        let max_message_delay = MAX_STABLE_DELAY_MS / (node_ids.len() / node_ids.len()) as u32;
+
+        let this = Arc::new(Self {
             node_id,
             topology: Topology::dense(node_ids.clone()),
             seen: RwLock::new(HashSet::new()),
             output,
             attempted_broadcasts: Mutex::default(),
-        }
+            batch: Mutex::new(MessageBatch::new(max_message_delay, 1000)),
+        });
+
+        smol::spawn(this.clone().poll_batch()).detach();
+
+        this
     }
 
     fn handle<'a>(
         self: &'a Arc<Self>,
@@ -51,20 +58,37 @@ impl Handler for BroadcastHandler {
             BroadcastBody::Broadcast {
                 msg_id: Some(msg_id),
                 message,
-                path,
             } => {
                 future::zip(
-                    self.receive_broadcast(&header.src, path, message),
+                    self.receive_broadcast(message),
                     self.send_broadcast_ok(&header.src, msg_id),
                 )
                 .await;
             }
+
+            BroadcastBody::BroadcastBatch {
+                msg_id: Some(msg_id),
+                messages,
+            } => {
+                future::zip(
+                    self.receive_broadcast_batch(messages),
+                    self.send_broadcast_ok(&header.src, msg_id),
+                )
+                .await;
+            }
+
             BroadcastBody::Broadcast {
                 msg_id: None,
                 message,
-                path,
             } => {
-                self.receive_broadcast(&header.src, path, message).await;
+                self.receive_broadcast(message).await;
+            }
+
+            BroadcastBody::BroadcastBatch {
+                msg_id: None,
+                messages,
+            } => {
+                self.receive_broadcast_batch(messages).await;
             }
 
             BroadcastBody::Topology { msg_id, topology } => {
@@ -127,48 +151,49 @@ impl BroadcastHandler {
     }
 
     /// Receive a given message, and broadcast it onwards if it is new
-    async fn receive_broadcast(
-        self: &Arc<Self>,
-        src: &str,
-        previous_path: Option<Vec<String>>,
-        message: BroadcastTarget,
-    ) {
+    async fn receive_broadcast(self: &Arc<Self>, message: BroadcastTarget) {
         let new = self.seen.write().await.insert(message);
         if !new {
             return;
         }
 
-        // Race all send futures
-        let mut previous_path = previous_path.unwrap_or_else(|| vec![]);
-        previous_path.push(src.to_string());
-        let mut tasks = self.attempted_broadcasts.lock().await;
-        for target in self
-            .topology
-            .targets(&self.node_id, previous_path.iter().map(String::as_str))
-            .await
-        {
-            let msg_id = gen_msg_id();
-            let this = self.clone();
-            let path = previous_path.clone();
-            tasks.insert(
-                msg_id,
-                smol::spawn(async move {
-                    loop {
-                        this.output
-                            .send(
-                                &target,
-                                &BroadcastBody::Broadcast {
-                                    msg_id: Some(msg_id),
-                                    message,
-                                    path: Some(path.clone()),
-                                },
-                            )
-                            .await;
+        let mut batch = self.batch.lock().await;
+        batch.add(message);
+    }
 
-                        Timer::after(Duration::from_secs(RETRY_TIMEOUT_SECS)).await;
-                    }
-                }),
-            );
+    async fn receive_broadcast_batch(self: &Arc<Self>, message: Vec<BroadcastTarget>) {
+        let mut batch = self.batch.lock().await;
+        let mut seen = self.seen.write().await;
+        let mut new = false;
+
+        for message in message.into_iter() {
+            new |= seen.insert(message);
+            batch.add(message);
+        }
+
+        if !new {
+            return;
+        }
+    }
+
+    async fn poll_batch(self: Arc<Self>) {
+        loop {
+            let mut batch = self.batch.lock().await;
+            if batch.should_broadcast() {
+                let mut tasks = self.attempted_broadcasts.lock().await;
+                for target in self.topology.all_targets(&self.node_id).await {
+                    let msg_id = gen_msg_id();
+                    let this = self.clone();
+                    tasks.insert(msg_id, smol::spawn(batch.broadcast(this, target, msg_id)));
+                }
+
+                batch.clear();
+            }
+
+            let time = batch.sleep_time();
+            drop(batch);
+
+            Timer::after(time).await;
         }
     }
 }
diff --git a/broadcast/src/main.rs b/broadcast/src/main.rs
index f4f1bf4..4b1b72f 100644
--- a/broadcast/src/main.rs
+++ b/broadcast/src/main.rs
@@ -2,6 +2,7 @@
 
 use common::run_server;
 
+mod batch;
 mod handler;
 mod msg;
 mod topology;
diff --git a/broadcast/src/msg.rs b/broadcast/src/msg.rs
index c252394..bbe83a4 100644
--- a/broadcast/src/msg.rs
+++ b/broadcast/src/msg.rs
@@ -11,7 +11,12 @@ pub enum BroadcastBody {
     Broadcast {
         msg_id: Option<MessageID>,
         message: BroadcastTarget,
-        path: Option<Vec<String>>,
+    },
+
+    #[serde(rename = "broadcast_batch")]
+    BroadcastBatch {
+        msg_id: Option<MessageID>,
+        messages: Vec<BroadcastTarget>,
     },
 
     #[serde(rename = "broadcast_ok")]
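
As a rough wire-format illustration (not part of the commit): assuming
BroadcastBody is internally tagged with a "type" field, as the Maelstrom
broadcast protocol expects, a batched message should look roughly like the
JSON built below. The node names, msg_id and payload values are made up.

    use serde_json::json;

    fn main() {
        // Hypothetical broadcast_batch envelope; the body field names follow
        // msg.rs, everything else is illustrative.
        let msg = json!({
            "src": "n1",
            "dest": "n2",
            "body": {
                "type": "broadcast_batch",
                "msg_id": 17,
                "messages": [5, 6, 7]
            }
        });
        println!("{msg}");
    }
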
diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs
index d91b8ae..5e16666 100644
--- a/broadcast/src/topology.rs
+++ b/broadcast/src/topology.rs
@@ -8,11 +8,6 @@ pub type TopologyDesc = HashMap<String, HashSet<String>>;
 pub struct Topology(RwLock<TopologyDesc>);
 
 impl Topology {
-    /// Create a new topology from the given description
-    pub fn new(top: TopologyDesc) -> Self {
-        Topology(RwLock::new(top))
-    }
-
     /// Create a new topology in which all nodes are connected to each other.
     pub fn dense(node_ids: Vec<String>) -> Self {
         let mut top = TopologyDesc::new();
@@ -28,126 +23,7 @@ impl Topology {
         *self.0.write().await = new;
     }
 
-    /// Get the next targets from the given topology, for a message
-    /// which has travelled across the given path and is now at node_id.
-    pub async fn targets(
-        &self,
-        node_id: &String,
-        path: impl Iterator<Item = &str>,
-    ) -> HashSet<String> {
-        // Ensure we don't keep holding the read lock
-        let topology = self.0.read().await;
-
-        // Get all visited nodes, from all neighbours of all node along the source path
-        let mut visited = HashSet::new();
-        for node in path {
-            visited.insert(node.to_string());
-            if let Some(neighbours) = topology.get(node) {
-                for neighbour in neighbours {
-                    visited.insert(neighbour.clone());
-                }
-            }
-        }
-
-        // Send to all neighbours that haven't already been sent to
-        topology
-            .get(node_id)
-            .unwrap()
-            .difference(&visited)
-            .cloned()
-            .filter(|n| n != node_id)
-            .collect()
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use std::iter;
-
-    use super::*;
-
-    fn name(x: usize, y: usize) -> String {
-        format!("{},{}", x, y)
-    }
-
-    fn grid(w: usize, h: usize) -> TopologyDesc {
-        let mut top = HashMap::new();
-        for x in 0..w {
-            for y in 0..h {
-                let mut neighbours = HashSet::new();
-                if x > 0 {
-                    neighbours.insert(name(x - 1, y));
-                    if y > 0 {
-                        neighbours.insert(name(x - 1, y - 1));
-                    }
-                    if y < h - 1 {
-                        neighbours.insert(name(x - 1, y + 1));
-                    }
-                }
-                if x < h - 1 {
-                    neighbours.insert(name(x + 1, y));
-                    if y > 0 {
-                        neighbours.insert(name(x + 1, y - 1));
-                    }
-                    if y < h - 1 {
-                        neighbours.insert(name(x + 1, y + 1));
-                    }
-                }
-
-                if y > 0 {
-                    neighbours.insert(name(x, y - 1));
-                }
-                if y < h - 1 {
-                    neighbours.insert(name(x, y + 1));
-                }
-
-                top.insert(name(x, y), neighbours);
-            }
-        }
-
-        top
-    }
-
-    #[test]
-    pub fn test_grid_entrypoint() {
-        smol::block_on(async {
-            let top = Topology::new(grid(3, 3));
-
-            // any corner must have 3 targets
-            assert_eq!(top.targets(&name(0, 0), iter::empty()).await.len(), 3);
-            assert_eq!(top.targets(&name(2, 0), iter::empty()).await.len(), 3);
-            assert_eq!(top.targets(&name(2, 2), iter::empty()).await.len(), 3);
-            assert_eq!(top.targets(&name(0, 2), iter::empty()).await.len(), 3);
-
-            // any side must have 5 targets
-            assert_eq!(top.targets(&name(0, 1), iter::empty()).await.len(), 5);
-            assert_eq!(top.targets(&name(1, 0), iter::empty()).await.len(), 5);
-            assert_eq!(top.targets(&name(2, 1), iter::empty()).await.len(), 5);
-            assert_eq!(top.targets(&name(1, 2), iter::empty()).await.len(), 5);
-
-            // the center must have 8 targets
-            assert_eq!(top.targets(&name(1, 1), iter::empty()).await.len(), 8);
-        })
-    }
-
-    #[test]
-    pub fn test_grid_previous() {
-        smol::block_on(async {
-            let top = Topology::new(grid(3, 3));
-
-            // if we've passed through the center, we will never have any targets
-            for x in 0..3 {
-                for y in 0..3 {
-                    assert_eq!(
-                        dbg!(
-                            top.targets(&name(x, y), iter::once(name(1, 1).as_str()))
-                                .await
-                        )
-                        .len(),
-                        0
-                    );
-                }
-            }
-        })
+    pub async fn all_targets(&self, node_id: &NodeId) -> HashSet<String> {
+        self.0.read().await.get(node_id).unwrap().clone()
     }
 }
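
Since the path-pruning tests above disappear with this change, a much smaller
smoke test for the dense topology could replace them. This sketch assumes
NodeId is an alias for String, which is not shown in this patch.

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn dense_topology_reaches_every_other_node() {
            smol::block_on(async {
                let nodes = vec!["n1".to_string(), "n2".to_string(), "n3".to_string()];
                let top = Topology::dense(nodes);

                let targets = top.all_targets(&"n1".to_string()).await;
                // In a dense topology every other node should be a broadcast target.
                assert!(targets.contains("n2"));
                assert!(targets.contains("n3"));
            })
        }
    }
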
diff --git a/common/src/lib.rs b/common/src/lib.rs
index 69a872b..bfdfa42 100644
--- a/common/src/lib.rs
+++ b/common/src/lib.rs
@@ -21,7 +21,7 @@ pub mod msg_id;
 pub trait Handler: Send + Sync + 'static {
     type Body: Serialize + for<'a> Deserialize<'a> + Send + Clone;
 
-    fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Self;
+    fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Arc<Self>;
     fn handle<'a>(
         self: &'a Arc<Self>,
         header: MessageHeader,
@@ -63,7 +63,7 @@ pub fn run_server() {
 fn sync_init_handler(
     reader: R,
     mut writer: W,
-) -> (H, Receiver>) {
+) -> (Arc<H>, Receiver>) {
     // Receive the init message
     let deser = Deserializer::from_reader(reader);
     let mut deser = deser.into_iter::>();
-- 
cgit v1.2.3