author     Aria <me@aria.rip>  2023-10-20 00:33:52 +0100
committer  Aria <me@aria.rip>  2023-10-20 00:33:52 +0100
commit     4b6e257263f7bb6eaebb9be672f88bda85c77586 (patch)
tree       ba3e5130284a96e73418be1440a97646b1727bdf
parent     bb54fd5b450ea1b126f7c31845f12893bf061058 (diff)

broadcast batching with minimal set of messages needed
-rw-r--r--   Cargo.lock                  16
-rw-r--r--   broadcast/Cargo.toml         3
-rw-r--r--   broadcast/src/batch.rs       1
-rw-r--r--   broadcast/src/handler.rs    90
-rw-r--r--   broadcast/src/topology.rs   30

5 files changed, 108 insertions, 32 deletions
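In outline, the commit replaces the single shared batch with one batch per destination node, so each neighbour only receives the messages it has not already seen. A rough sketch of the new state, with illustrative names (Vec<u64> stands in for the real MessageBatch type from broadcast/src/batch.rs):

    use std::time::Duration;

    use elsa::sync::FrozenMap;
    use smol::lock::Mutex;

    // Sketch only: the handler keeps one pending batch per destination node.
    // FrozenMap allows inserting new entries through a shared reference, so
    // the map itself needs no outer lock; each batch carries its own Mutex.
    struct PerDestinationBatches {
        batches: FrozenMap<String, Box<Mutex<Vec<u64>>>>, // NodeId -> pending messages
        max_message_delay: Duration,                      // flush deadline for a batch
    }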
diff --git a/Cargo.lock b/Cargo.lock
index 191c81a..d9b0d5f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -165,6 +165,7 @@ name = "broadcast"
version = "0.1.0"
dependencies = [
"common",
+ "elsa",
"futures",
"serde",
"serde_json",
@@ -215,6 +216,15 @@ dependencies = [
]
[[package]]
+name = "elsa"
+version = "1.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "714f766f3556b44e7e4776ad133fcc3445a489517c25c704ace411bb14790194"
+dependencies = [
+ "stable_deref_trait",
+]
+
+[[package]]
name = "errno"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -618,6 +628,12 @@ dependencies = [
]
[[package]]
+name = "stable_deref_trait"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
+
+[[package]]
name = "syn"
version = "2.0.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/broadcast/Cargo.toml b/broadcast/Cargo.toml
index 06e4426..718f7d1 100644
--- a/broadcast/Cargo.toml
+++ b/broadcast/Cargo.toml
@@ -10,4 +10,5 @@ smol = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
common = { path = "../common/" }
-futures = { version = "0.3.28", default_features = false, features = ["std"] }
\ No newline at end of file
+futures = { version = "0.3.28", default_features = false, features = ["std"] }
+elsa = "1.9.0"
diff --git a/broadcast/src/batch.rs b/broadcast/src/batch.rs
index d69771d..42cc3b8 100644
--- a/broadcast/src/batch.rs
+++ b/broadcast/src/batch.rs
@@ -50,6 +50,7 @@ impl MessageBatch {
self.first_added
.elapsed()
.saturating_sub(self.max_message_delay)
+ .saturating_sub(Duration::from_millis(10))
}
pub fn broadcast(
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs
index 2b470fb..e71edef 100644
--- a/broadcast/src/handler.rs
+++ b/broadcast/src/handler.rs
@@ -1,3 +1,4 @@
+use elsa::sync::FrozenMap;
use smol::{
future,
lock::{Mutex, RwLock},
@@ -10,7 +11,11 @@ use std::{
time::Duration,
};
-use crate::{batch::MessageBatch, msg::*, topology::Topology};
+use crate::{
+ batch::MessageBatch,
+ msg::*,
+ topology::{NodeId, Topology},
+};
use common::{
msg::{MessageHeader, Output},
@@ -24,8 +29,9 @@ pub struct BroadcastHandler {
node_id: String,
seen: RwLock<HashSet<BroadcastTarget>>,
topology: Topology,
- batch: Mutex<MessageBatch>,
+ batch: FrozenMap<NodeId, Box<Mutex<MessageBatch>>>,
pub(crate) output: Output<BroadcastBody>,
+ max_message_delay: Duration,
attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>,
}
@@ -36,12 +42,13 @@ impl Handler for BroadcastHandler {
let max_message_delay = MAX_STABLE_DELAY_MS / (node_ids.len() as f32).sqrt() as u32;
let this = Arc::new(Self {
+ max_message_delay,
node_id,
topology: Topology::dense(node_ids.clone()),
seen: RwLock::new(HashSet::new()),
output,
attempted_broadcasts: Mutex::default(),
- batch: Mutex::new(MessageBatch::new(max_message_delay)),
+ batch: FrozenMap::new(),
});
smol::spawn(this.clone().poll_batch()).detach();
@@ -61,7 +68,7 @@ impl Handler for BroadcastHandler {
message,
} => {
future::zip(
- self.receive_broadcast(message),
+ self.receive_broadcast(&header.src, message),
self.send_broadcast_ok(&header.src, msg_id),
)
.await;
@@ -72,7 +79,7 @@ impl Handler for BroadcastHandler {
messages,
} => {
future::zip(
- self.receive_broadcast_batch(messages),
+ self.receive_broadcast_batch(&header.src, messages),
self.send_broadcast_ok(&header.src, msg_id),
)
.await;
@@ -82,14 +89,14 @@ impl Handler for BroadcastHandler {
msg_id: None,
message,
} => {
- self.receive_broadcast(message).await;
+ self.receive_broadcast(&header.src, message).await;
}
BroadcastBody::BroadcastBatch {
msg_id: None,
messages,
} => {
- self.receive_broadcast_batch(messages).await;
+ self.receive_broadcast_batch(&header.src, messages).await;
}
BroadcastBody::Topology { msg_id, topology } => {
@@ -152,49 +159,76 @@ impl BroadcastHandler {
}
/// Receive a given message, and broadcast it onwards if it is new
- async fn receive_broadcast(self: &Arc<Self>, message: BroadcastTarget) {
+ async fn receive_broadcast(self: &Arc<Self>, src: &String, message: BroadcastTarget) {
let new = self.seen.write().await.insert(message);
if !new {
return;
}
- let mut batch = self.batch.lock().await;
- batch.add(message);
+ for dest in self.topology.targets(&self.node_id, &src).await {
+ let batch_lock = self.get_batch_lock(&dest);
+
+ let mut batch = batch_lock.lock().await;
+ batch.add(message);
+ }
}
- async fn receive_broadcast_batch(self: &Arc<Self>, message: Vec<BroadcastTarget>) {
- let mut batch = self.batch.lock().await;
+ async fn receive_broadcast_batch(
+ self: &Arc<Self>,
+ src: &String,
+ message: Vec<BroadcastTarget>,
+ ) {
let mut seen = self.seen.write().await;
for message in message.into_iter() {
if seen.insert(message) {
- batch.add(message);
+ for dest in self.topology.targets(&self.node_id, &src).await {
+ let batch_lock = self.get_batch_lock(&dest);
+ let mut batch = batch_lock.lock().await;
+ batch.add(message);
+ }
}
}
}
- async fn poll_batch(self: Arc<Self>) {
- loop {
- let mut batch = self.batch.lock().await;
- self.do_batch_check(&mut batch).await;
-
- let time = batch.sleep_time();
- drop(batch);
+ fn get_batch_lock(&self, key: &String) -> &Mutex<MessageBatch> {
+ match self.batch.get(key) {
+ Some(x) => x,
+ None => {
+ self.batch.insert(
+ key.clone(),
+ Box::new(Mutex::new(MessageBatch::new(self.max_message_delay))),
+ );
- Timer::after(time).await;
+ self.batch.get(key).unwrap()
+ }
}
}
- async fn do_batch_check(self: &Arc<Self>, batch: &mut MessageBatch) {
- if batch.should_broadcast() {
+ async fn poll_batch(self: Arc<Self>) {
+ loop {
let mut tasks = self.attempted_broadcasts.lock().await;
- for target in self.topology.all_targets(&self.node_id).await {
- let msg_id = gen_msg_id();
- let this = self.clone();
- tasks.insert(msg_id, smol::spawn(batch.broadcast(this, target, msg_id)));
+ let mut min_sleep_time = self.max_message_delay;
+ for key in self.batch.keys_cloned() {
+ let batch_lock = self.batch.get(&key).unwrap();
+ let mut batch = batch_lock.lock().await;
+ if batch.should_broadcast() {
+ let msg_id = gen_msg_id();
+ let this = self.clone();
+ tasks.insert(msg_id, smol::spawn(batch.broadcast(this, key, msg_id)));
+
+ batch.clear();
+ }
+
+ let batch_sleep_time = batch.sleep_time();
+ if batch_sleep_time < min_sleep_time {
+ min_sleep_time = batch_sleep_time;
+ }
}
- batch.clear();
+ drop(tasks);
+
+ Timer::after(min_sleep_time).await;
}
}
}
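get_batch_lock above relies on elsa's FrozenMap to create batches lazily while the handler sits behind an Arc. A standalone sketch of that get-or-insert pattern, assuming (as in the diff) that values are boxed so the reference handed out stays valid while other keys are added:

    use elsa::sync::FrozenMap;
    use smol::lock::Mutex;

    // Mirrors get_batch_lock: look the key up, insert a fresh boxed Mutex if
    // it is missing, then look it up again. Boxing gives each value a stable
    // address, which is what lets FrozenMap hand out references even though
    // the map can still grow through &self.
    fn get_or_create<'a>(
        map: &'a FrozenMap<String, Box<Mutex<Vec<u64>>>>,
        key: &str,
    ) -> &'a Mutex<Vec<u64>> {
        match map.get(key) {
            Some(existing) => existing,
            None => {
                map.insert(key.to_owned(), Box::new(Mutex::new(Vec::new())));
                map.get(key).unwrap()
            }
        }
    }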
diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs
index 5e16666..bb10466 100644
--- a/broadcast/src/topology.rs
+++ b/broadcast/src/topology.rs
@@ -14,16 +14,40 @@ impl Topology {
for node_id in node_ids.iter() {
top.insert(node_id.clone(), node_ids.iter().cloned().collect());
}
+ Self::filter_desc(&mut top);
Topology(RwLock::new(top))
}
/// Replace the current topology with a new one.
- pub async fn replace(&self, new: TopologyDesc) {
+ pub async fn replace(&self, mut new: TopologyDesc) {
+ Self::filter_desc(&mut new);
*self.0.write().await = new;
}
- pub async fn all_targets(&self, node_id: &NodeId) -> HashSet<String> {
- self.0.read().await.get(node_id).unwrap().clone()
+ fn filter_desc(desc: &mut TopologyDesc) {
+ for (node_id, neighbours) in desc.iter_mut() {
+ neighbours.remove(node_id);
+ }
+ }
+
+ /// Get the next targets from the given topology, for a message
+ /// which has travelled across the given path and is now at node_id.
+ pub async fn targets(&self, node_id: &String, last_node_id: &String) -> HashSet<String> {
+ // Ensure we don't keep holding the read lock
+ let topology = self.0.read().await;
+
+ // Get all nodes the last node sent it to
+ let visited = topology.get(last_node_id);
+
+ let neighbours = topology.get(node_id).unwrap();
+ match visited {
+ Some(visited) => neighbours
+ .difference(&visited)
+ .cloned()
+ .filter(|n| n != node_id)
+ .collect(),
+ None => neighbours.clone(),
+ }
}
}
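The new targets method is where the "minimal set of messages" in the commit message comes from: forward to our own neighbours, minus whatever the previous hop already broadcast to, and never back to ourselves. A self-contained restatement of that rule, using illustrative names rather than the repository's types:

    use std::collections::HashSet;

    // Next hops for a message that arrived at `node_id`: our neighbours,
    // minus the nodes the previous hop already sent to, minus ourselves.
    // If the previous hop's neighbour set is unknown, fall back to all of
    // our neighbours.
    fn next_targets(
        neighbours: &HashSet<String>,
        covered_by_last_hop: Option<&HashSet<String>>,
        node_id: &str,
    ) -> HashSet<String> {
        match covered_by_last_hop {
            Some(covered) => neighbours
                .difference(covered)
                .filter(|n| n.as_str() != node_id)
                .cloned()
                .collect(),
            None => neighbours.clone(),
        }
    }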