summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Aria <me@aria.rip> 2023-10-19 21:00:26 +0100
committer Aria <me@aria.rip> 2023-10-19 21:00:26 +0100
commit dba41282cac86a740c007498709e996b9fa3e59b (patch)
tree 5d9309962f55ca00be6fee3b7ac24b7314e86f2a
parent 07e2085190e30010ad595369a07842413bacd3d1 (diff)
wip: broadcast message batching
-rw-r--r-- broadcast/src/batch.rs 81
-rw-r--r-- broadcast/src/handler.rs 119
-rw-r--r-- broadcast/src/main.rs 1
-rw-r--r-- broadcast/src/msg.rs 7
-rw-r--r-- broadcast/src/topology.rs 128
-rw-r--r-- common/src/lib.rs 4
6 files changed, 164 insertions, 176 deletions
diff --git a/broadcast/src/batch.rs b/broadcast/src/batch.rs
new file mode 100644
index 0000000..8c1c17d
--- /dev/null
+++ b/broadcast/src/batch.rs
@@ -0,0 +1,81 @@
+use std::{
+ collections::HashSet,
+ sync::Arc,
+ time::{Duration, Instant},
+};
+
+use common::msg_id::MessageID;
+use futures::Future;
+use smol::Timer;
+
+use crate::{
+ handler::BroadcastHandler,
+ msg::{BroadcastBody, BroadcastTarget},
+};
+
+const RETRY_TIMEOUT_SECS: u64 = 1;
+
+/// Accumulates broadcast values so they can be sent to a peer in a single
+/// `BroadcastBatch` message instead of one message per value.
+///
+/// A batch is due once it is non-empty and has either been left untouched for
+/// `max_message_delay` or holds at least `max_message_count` values.
+#[derive(Debug, Clone)]
+pub struct MessageBatch {
+    /// How long the batch must stay unchanged before it is flushed.
+    max_message_delay: Duration,
+    /// Number of queued values that makes the batch due regardless of timing.
+    max_message_count: usize,
+    /// When the batch was last added to or cleared.
+    last_update: Instant,
+    /// Values waiting to be sent.
+    messages: HashSet<BroadcastTarget>,
+}
+
+impl MessageBatch {
+    /// Create an empty batch with the given flush thresholds.
+    pub fn new(max_message_delay: Duration, max_message_count: usize) -> Self {
+        Self {
+            max_message_delay,
+            max_message_count,
+            last_update: Instant::now(),
+            messages: HashSet::new(),
+        }
+    }
+
+    /// Queue `msg` and restart the quiet-period timer.
+    ///
+    /// The timer restarts even when `msg` was already queued.
+    pub fn add(&mut self, msg: BroadcastTarget) {
+        self.messages.insert(msg);
+        self.last_update = Instant::now();
+    }
+
+    /// Drop every queued value and restart the quiet-period timer.
+    pub fn clear(&mut self) {
+        self.messages.clear();
+        self.last_update = Instant::now();
+    }
+
+    /// Whether the batch is non-empty and has hit either the delay or the
+    /// count threshold.
+    pub fn should_broadcast(&self) -> bool {
+        !self.messages.is_empty()
+            && (self.last_update.elapsed() >= self.max_message_delay
+                || self.messages.len() >= self.max_message_count)
+    }
+
+    /// How long a poller should wait before checking the batch again.
+    ///
+    /// For a non-empty batch this is the time remaining until the delay
+    /// threshold is reached (zero if it has already passed). For an empty
+    /// batch it is a full `max_message_delay`.
+    pub fn sleep_time(&self) -> Duration {
+        if self.messages.is_empty() {
+            // Nothing is pending, so there is no deadline to compute. Waiting
+            // a full delay avoids a zero-length sleep (busy loop) once the
+            // batch has been idle for longer than `max_message_delay`.
+            return self.max_message_delay;
+        }
+
+        // Remaining time, not overshoot: `delay - elapsed`, clamped at zero.
+        self.max_message_delay
+            .saturating_sub(self.last_update.elapsed())
+    }
+
+    /// Build a future that sends the batch's current contents to `dst`,
+    /// re-sending every `RETRY_TIMEOUT_SECS` seconds.
+    ///
+    /// The contents are copied when this method is called, so later calls to
+    /// `add` or `clear` do not affect what the future sends. The future never
+    /// completes on its own; it stops only when it is dropped.
+    pub fn broadcast(
+        &self,
+        this: Arc<BroadcastHandler>,
+        dst: String,
+        msg_id: MessageID,
+    ) -> impl Future<Output = ()> + 'static {
+        // Snapshot once, already in wire format, so each retry only clones a `Vec`.
+        let messages: Vec<BroadcastTarget> = self.messages.iter().cloned().collect();
+        async move {
+            loop {
+                this.output
+                    .send(
+                        &dst,
+                        &BroadcastBody::BroadcastBatch {
+                            msg_id: Some(msg_id),
+                            messages: messages.clone(),
+                        },
+                    )
+                    .await;
+
+                Timer::after(Duration::from_secs(RETRY_TIMEOUT_SECS)).await;
+            }
+        }
+    }
+}
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs
index b38e2b7..4ee2664 100644
--- a/broadcast/src/handler.rs
+++ b/broadcast/src/handler.rs
@@ -10,35 +10,42 @@ use std::{
time::Duration,
};
-use crate::{msg::*, topology::Topology};
+use crate::{batch::MessageBatch, msg::*, topology::Topology};
use common::{
msg::{MessageHeader, Output},
msg_id::{gen_msg_id, MessageID},
Handler,
};
-
-const RETRY_TIMEOUT_SECS: u64 = 1;
+const MAX_STABLE_DELAY_MS: Duration = Duration::from_millis(700);
+/// State held by one node of the broadcast service.
pub struct BroadcastHandler {
+ /// ID of this node, as passed to `init`.
node_id: String,
+ /// Values this node has already received; used to ignore repeats.
seen: RwLock<HashSet<BroadcastTarget>>,
+ /// Neighbour map queried for the nodes each batch is sent to.
topology: Topology,
- output: Output<BroadcastBody>,
+ /// Values waiting to be sent in the next batch; flushed by `poll_batch`.
+ batch: Mutex<MessageBatch>,
+ /// Handle used to send outgoing messages; `pub(crate)` so `batch.rs` can send through it.
+ pub(crate) output: Output<BroadcastBody>,
+ /// Retry tasks for batch sends, keyed by the message ID each one carries.
+ /// NOTE(review): presumably an entry is removed when the matching `broadcast_ok` arrives; that code is not visible in this diff, confirm.
attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>,
}
impl Handler for BroadcastHandler {
type Body = BroadcastBody;
-    fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Self {
-        BroadcastHandler {
+    /// Build the handler for `node_id` and start its background batch poller.
+    ///
+    /// The topology starts out dense (every node in `node_ids` connected to
+    /// every other). The returned `Arc` is shared with a detached task running
+    /// `poll_batch`.
+    fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Arc<Self> {
+        // The previous divisor, `node_ids.len() / node_ids.len()`, was always 1
+        // for a non-empty list and panicked with a division by zero for an
+        // empty one, so the delay budget is used directly.
+        let max_message_delay = MAX_STABLE_DELAY_MS;
+
+        let this = Arc::new(Self {
             node_id,
             topology: Topology::dense(node_ids.clone()),
             seen: RwLock::new(HashSet::new()),
             output,
             attempted_broadcasts: Mutex::default(),
-        }
+            batch: Mutex::new(MessageBatch::new(max_message_delay, 1000)),
+        });
+
+        // Detached: the poller runs for as long as the executor does.
+        smol::spawn(this.clone().poll_batch()).detach();
+
+        this
     }
fn handle<'a>(
@@ -51,20 +58,37 @@ impl Handler for BroadcastHandler {
BroadcastBody::Broadcast {
msg_id: Some(msg_id),
message,
- path,
} => {
future::zip(
- self.receive_broadcast(&header.src, path, message),
+ self.receive_broadcast(message),
self.send_broadcast_ok(&header.src, msg_id),
)
.await;
}
+
+ BroadcastBody::BroadcastBatch {
+ msg_id: Some(msg_id),
+ messages,
+ } => {
+ future::zip(
+ self.receive_broadcast_batch(messages),
+ self.send_broadcast_ok(&header.src, msg_id),
+ )
+ .await;
+ }
+
BroadcastBody::Broadcast {
msg_id: None,
message,
- path,
} => {
- self.receive_broadcast(&header.src, path, message).await;
+ self.receive_broadcast(message).await;
+ }
+
+ BroadcastBody::BroadcastBatch {
+ msg_id: None,
+ messages,
+ } => {
+ self.receive_broadcast_batch(messages).await;
}
BroadcastBody::Topology { msg_id, topology } => {
@@ -127,48 +151,49 @@ impl BroadcastHandler {
}
/// Receive a given message, and broadcast it onwards if it is new
- async fn receive_broadcast(
- self: &Arc<Self>,
- src: &str,
- previous_path: Option<Vec<String>>,
- message: BroadcastTarget,
- ) {
+ async fn receive_broadcast(self: &Arc<Self>, message: BroadcastTarget) {
let new = self.seen.write().await.insert(message);
if !new {
return;
}
- // Race all send futures
- let mut previous_path = previous_path.unwrap_or_else(|| vec![]);
- previous_path.push(src.to_string());
- let mut tasks = self.attempted_broadcasts.lock().await;
- for target in self
- .topology
- .targets(&self.node_id, previous_path.iter().map(String::as_str))
- .await
- {
- let msg_id = gen_msg_id();
- let this = self.clone();
- let path = previous_path.clone();
- tasks.insert(
- msg_id,
- smol::spawn(async move {
- loop {
- this.output
- .send(
- &target,
- &BroadcastBody::Broadcast {
- msg_id: Some(msg_id),
- message,
- path: Some(path.clone()),
- },
- )
- .await;
+ let mut batch = self.batch.lock().await;
+ batch.add(message);
+ }
- Timer::after(Duration::from_secs(RETRY_TIMEOUT_SECS)).await;
- }
- }),
- );
+    /// Receive a batch of messages, and queue the ones that are new for
+    /// broadcasting onwards.
+    ///
+    /// Values already in `seen` are skipped. Re-queueing them would make
+    /// peers send the same values back and forth indefinitely.
+    async fn receive_broadcast_batch(self: &Arc<Self>, message: Vec<BroadcastTarget>) {
+        let mut batch = self.batch.lock().await;
+        let mut seen = self.seen.write().await;
+
+        for message in message {
+            // `insert` returns `true` only when the value was not already present.
+            if seen.insert(message) {
+                batch.add(message);
+            }
+        }
+    }
+
+    /// Loop forever, flushing the pending batch whenever it is due.
+    ///
+    /// When `should_broadcast` reports the batch as due, one retrying send
+    /// task per neighbour is spawned and stored in `attempted_broadcasts`
+    /// under a fresh message ID, and the batch is cleared. The loop then
+    /// sleeps for the batch's `sleep_time`.
+    async fn poll_batch(self: Arc<Self>) {
+        loop {
+            // The batch guard lives only inside this block, so the lock is
+            // released before the task goes to sleep.
+            let delay = {
+                let mut pending = self.batch.lock().await;
+                if pending.should_broadcast() {
+                    let mut in_flight = self.attempted_broadcasts.lock().await;
+                    for peer in self.topology.all_targets(&self.node_id).await {
+                        let id = gen_msg_id();
+                        let task = smol::spawn(pending.broadcast(self.clone(), peer, id));
+                        in_flight.insert(id, task);
+                    }
+
+                    pending.clear();
+                }
+
+                pending.sleep_time()
+            };
+
+            Timer::after(delay).await;
         }
     }
}
diff --git a/broadcast/src/main.rs b/broadcast/src/main.rs
index f4f1bf4..4b1b72f 100644
--- a/broadcast/src/main.rs
+++ b/broadcast/src/main.rs
@@ -2,6 +2,7 @@
use common::run_server;
+mod batch;
mod handler;
mod msg;
mod topology;
diff --git a/broadcast/src/msg.rs b/broadcast/src/msg.rs
index c252394..bbe83a4 100644
--- a/broadcast/src/msg.rs
+++ b/broadcast/src/msg.rs
@@ -11,7 +11,12 @@ pub enum BroadcastBody {
Broadcast {
msg_id: Option<MessageID>,
message: BroadcastTarget,
- path: Option<Vec<String>>,
+ },
+
+ #[serde(rename = "broadcast_batch")]
+ BroadcastBatch {
+ msg_id: Option<MessageID>,
+ messages: Vec<BroadcastTarget>,
},
#[serde(rename = "broadcast_ok")]
diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs
index d91b8ae..5e16666 100644
--- a/broadcast/src/topology.rs
+++ b/broadcast/src/topology.rs
@@ -8,11 +8,6 @@ pub type TopologyDesc = HashMap<NodeId, HashSet<NodeId>>;
pub struct Topology(RwLock<TopologyDesc>);
impl Topology {
- /// Create a new topology from the given description
- pub fn new(top: TopologyDesc) -> Self {
- Topology(RwLock::new(top))
- }
-
/// Create a new topology in which all nodes are connected to each other.
pub fn dense(node_ids: Vec<String>) -> Self {
let mut top = TopologyDesc::new();
@@ -28,126 +23,7 @@ impl Topology {
*self.0.write().await = new;
}
- /// Get the next targets from the given topology, for a message
- /// which has travelled across the given path and is now at node_id.
- pub async fn targets(
- &self,
- node_id: &String,
- path: impl Iterator<Item = &str>,
- ) -> HashSet<String> {
- // Ensure we don't keep holding the read lock
- let topology = self.0.read().await;
-
- // Get all visited nodes, from all neighbours of all node along the source path
- let mut visited = HashSet::new();
- for node in path {
- visited.insert(node.to_string());
- if let Some(neighbours) = topology.get(node) {
- for neighbour in neighbours {
- visited.insert(neighbour.clone());
- }
- }
- }
-
- // Send to all neighbours that haven't already been sent to
- topology
- .get(node_id)
- .unwrap()
- .difference(&visited)
- .cloned()
- .filter(|n| n != node_id)
- .collect()
- }
-}
-
-#[cfg(test)]
-mod tests {
- use std::iter;
-
- use super::*;
-
- fn name(x: usize, y: usize) -> String {
- format!("{},{}", x, y)
- }
-
- fn grid(w: usize, h: usize) -> TopologyDesc {
- let mut top = HashMap::new();
- for x in 0..w {
- for y in 0..h {
- let mut neighbours = HashSet::new();
- if x > 0 {
- neighbours.insert(name(x - 1, y));
- if y > 0 {
- neighbours.insert(name(x - 1, y - 1));
- }
- if y < h - 1 {
- neighbours.insert(name(x - 1, y + 1));
- }
- }
- if x < h - 1 {
- neighbours.insert(name(x + 1, y));
- if y > 0 {
- neighbours.insert(name(x + 1, y - 1));
- }
- if y < h - 1 {
- neighbours.insert(name(x + 1, y + 1));
- }
- }
-
- if y > 0 {
- neighbours.insert(name(x, y - 1));
- }
- if y < h - 1 {
- neighbours.insert(name(x, y + 1));
- }
-
- top.insert(name(x, y), neighbours);
- }
- }
-
- top
- }
-
- #[test]
- pub fn test_grid_entrypoint() {
- smol::block_on(async {
- let top = Topology::new(grid(3, 3));
-
- // any corner must have 3 targets
- assert_eq!(top.targets(&name(0, 0), iter::empty()).await.len(), 3);
- assert_eq!(top.targets(&name(2, 0), iter::empty()).await.len(), 3);
- assert_eq!(top.targets(&name(2, 2), iter::empty()).await.len(), 3);
- assert_eq!(top.targets(&name(0, 2), iter::empty()).await.len(), 3);
-
- // any side must have 5 targets
- assert_eq!(top.targets(&name(0, 1), iter::empty()).await.len(), 5);
- assert_eq!(top.targets(&name(1, 0), iter::empty()).await.len(), 5);
- assert_eq!(top.targets(&name(2, 1), iter::empty()).await.len(), 5);
- assert_eq!(top.targets(&name(1, 2), iter::empty()).await.len(), 5);
-
- // the center must have 8 targets
- assert_eq!(top.targets(&name(1, 1), iter::empty()).await.len(), 8);
- })
- }
-
- #[test]
- pub fn test_grid_previous() {
- smol::block_on(async {
- let top = Topology::new(grid(3, 3));
-
- // if we've passed through the center, we will never have any targets
- for x in 0..3 {
- for y in 0..3 {
- assert_eq!(
- dbg!(
- top.targets(&name(x, y), iter::once(name(1, 1).as_str()))
- .await
- )
- .len(),
- 0
- );
- }
- }
- })
+    /// Get every neighbour of `node_id` in the current topology, excluding
+    /// `node_id` itself.
+    ///
+    /// Returns an empty set when the topology has no entry for `node_id`.
+    pub async fn all_targets(&self, node_id: &NodeId) -> HashSet<String> {
+        self.0
+            .read()
+            .await
+            .get(node_id)
+            .map(|neighbours| {
+                neighbours
+                    .iter()
+                    // A node never needs to send a batch to itself.
+                    .filter(|neighbour| *neighbour != node_id)
+                    .cloned()
+                    .collect()
+            })
+            .unwrap_or_default()
     }
}
diff --git a/common/src/lib.rs b/common/src/lib.rs
index 69a872b..bfdfa42 100644
--- a/common/src/lib.rs
+++ b/common/src/lib.rs
@@ -21,7 +21,7 @@ pub mod msg_id;
pub trait Handler: Send + Sync + 'static {
type Body: Serialize + for<'a> Deserialize<'a> + Send + Clone;
- fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Self;
+ fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Arc<Self>;
fn handle<'a>(
self: &'a Arc<Self>,
header: MessageHeader,
@@ -63,7 +63,7 @@ pub fn run_server<H: Handler>() {
fn sync_init_handler<H: Handler, R: Read, W: Write>(
reader: R,
mut writer: W,
-) -> (H, Receiver<Message<H::Body>>) {
+) -> (Arc<H>, Receiver<Message<H::Body>>) {
// Receive the init message
let deser = Deserializer::from_reader(reader);
let mut deser = deser.into_iter::<Message<MaelstromBody>>();