summaryrefslogtreecommitdiff
path: root/broadcast/src
diff options
context:
space:
mode:
authorAria <me@aria.rip>2023-10-15 00:57:41 +0100
committerAria <me@aria.rip>2023-10-15 00:57:41 +0100
commit42123efe8fd92d6d81b6d5d10ae86866ea9b6a3c (patch)
tree2dfcb1a27c337d1c9b89f54aff35e94795970b74 /broadcast/src
parent7447f3fb801ba954c7b8cbf3f47700ffcc562d20 (diff)
some refactors
Diffstat (limited to 'broadcast/src')
-rw-r--r--broadcast/src/handler.rs33
-rw-r--r--broadcast/src/main.rs1
-rw-r--r--broadcast/src/topology.rs153
3 files changed, 161 insertions, 26 deletions
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs
index d4bb94a..b55cd9b 100644
--- a/broadcast/src/handler.rs
+++ b/broadcast/src/handler.rs
@@ -6,11 +6,12 @@ use smol::{
};
use std::{
collections::{HashMap, HashSet},
+ iter,
sync::Arc,
time::Duration,
};
-use crate::msg::*;
+use crate::{msg::*, topology::Topology};
use common::{
msg::{MessageHeader, Output},
@@ -23,7 +24,7 @@ const RETRY_TIMEOUT_SECS: u64 = 1;
pub struct BroadcastHandler {
node_id: String,
seen: RwLock<HashSet<BroadcastTarget>>,
- topology: RwLock<HashMap<String, HashSet<String>>>,
+ topology: Topology,
output: Output<BroadcastBody>,
attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>,
}
@@ -32,15 +33,9 @@ impl Handler for BroadcastHandler {
type Body = BroadcastBody;
fn init(node_id: String, node_ids: Vec<String>, output: Output<Self::Body>) -> Self {
- // Initial topology assumes all nodes are neighbours
- let mut topology = HashMap::new();
- for id in node_ids.iter() {
- topology.insert(id.clone(), node_ids.iter().cloned().collect());
- }
-
BroadcastHandler {
node_id,
- topology: RwLock::new(topology),
+ topology: Topology::dense(node_ids.clone()),
seen: RwLock::new(HashSet::new()),
output,
attempted_broadcasts: Mutex::default(),
@@ -73,7 +68,7 @@ impl Handler for BroadcastHandler {
BroadcastBody::Topology { msg_id, topology } => {
// Start using the new topology
- *self.topology.write().await = topology;
+ self.topology.replace(topology).await;
// Send reply if needed
if let Some(msg_id) = msg_id {
@@ -137,24 +132,10 @@ impl BroadcastHandler {
return;
}
- // Ensure we don't keep holding the read lock
- let mut targets = self.topology.read().await.clone();
-
- // Only send to neighbours that the source has not sent to.
- // This isn't technically optimal, but its as close as we can get without
- // tracking the path of each broadcast message.
- let our_targets = targets.remove(&self.node_id).unwrap();
- let their_targets = targets
- .remove(&src.to_string())
- .unwrap_or_else(|| HashSet::new());
-
// Race all send futures
+ let path = iter::once(src);
let mut tasks = self.attempted_broadcasts.lock().await;
- for target in our_targets.into_iter() {
- if &target == &src || &target == &self.node_id || their_targets.contains(&target) {
- continue;
- }
-
+ for target in self.topology.targets(&self.node_id, path).await {
let msg_id = gen_msg_id();
let this = self.clone();
tasks.insert(
diff --git a/broadcast/src/main.rs b/broadcast/src/main.rs
index bdc8413..f4f1bf4 100644
--- a/broadcast/src/main.rs
+++ b/broadcast/src/main.rs
@@ -4,6 +4,7 @@ use common::run_server;
mod handler;
mod msg;
+mod topology;
use handler::BroadcastHandler;
diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs
new file mode 100644
index 0000000..d91b8ae
--- /dev/null
+++ b/broadcast/src/topology.rs
@@ -0,0 +1,153 @@
+use std::collections::{HashMap, HashSet};
+
+use smol::lock::RwLock;
+
+pub type NodeId = String;
+pub type TopologyDesc = HashMap<NodeId, HashSet<NodeId>>;
+
+pub struct Topology(RwLock<TopologyDesc>);
+
+impl Topology {
+ /// Create a new topology from the given description
+ pub fn new(top: TopologyDesc) -> Self {
+ Topology(RwLock::new(top))
+ }
+
+ /// Create a new topology in which all nodes are connected to each other.
+ pub fn dense(node_ids: Vec<String>) -> Self {
+ let mut top = TopologyDesc::new();
+ for node_id in node_ids.iter() {
+ top.insert(node_id.clone(), node_ids.iter().cloned().collect());
+ }
+
+ Topology(RwLock::new(top))
+ }
+
+ /// Replace the current topology with a new one.
+ pub async fn replace(&self, new: TopologyDesc) {
+ *self.0.write().await = new;
+ }
+
+ /// Get the next targets from the given topology, for a message
+ /// which has travelled across the given path and is now at node_id.
+ pub async fn targets(
+ &self,
+ node_id: &String,
+ path: impl Iterator<Item = &str>,
+ ) -> HashSet<String> {
+ // Ensure we don't keep holding the read lock
+ let topology = self.0.read().await;
+
+ // Get all visited nodes, from all neighbours of all node along the source path
+ let mut visited = HashSet::new();
+ for node in path {
+ visited.insert(node.to_string());
+ if let Some(neighbours) = topology.get(node) {
+ for neighbour in neighbours {
+ visited.insert(neighbour.clone());
+ }
+ }
+ }
+
+ // Send to all neighbours that haven't already been sent to
+ topology
+ .get(node_id)
+ .unwrap()
+ .difference(&visited)
+ .cloned()
+ .filter(|n| n != node_id)
+ .collect()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::iter;
+
+ use super::*;
+
+ fn name(x: usize, y: usize) -> String {
+ format!("{},{}", x, y)
+ }
+
+ fn grid(w: usize, h: usize) -> TopologyDesc {
+ let mut top = HashMap::new();
+ for x in 0..w {
+ for y in 0..h {
+ let mut neighbours = HashSet::new();
+ if x > 0 {
+ neighbours.insert(name(x - 1, y));
+ if y > 0 {
+ neighbours.insert(name(x - 1, y - 1));
+ }
+ if y < h - 1 {
+ neighbours.insert(name(x - 1, y + 1));
+ }
+ }
+ if x < h - 1 {
+ neighbours.insert(name(x + 1, y));
+ if y > 0 {
+ neighbours.insert(name(x + 1, y - 1));
+ }
+ if y < h - 1 {
+ neighbours.insert(name(x + 1, y + 1));
+ }
+ }
+
+ if y > 0 {
+ neighbours.insert(name(x, y - 1));
+ }
+ if y < h - 1 {
+ neighbours.insert(name(x, y + 1));
+ }
+
+ top.insert(name(x, y), neighbours);
+ }
+ }
+
+ top
+ }
+
+ #[test]
+ pub fn test_grid_entrypoint() {
+ smol::block_on(async {
+ let top = Topology::new(grid(3, 3));
+
+ // any corner must have 3 targets
+ assert_eq!(top.targets(&name(0, 0), iter::empty()).await.len(), 3);
+ assert_eq!(top.targets(&name(2, 0), iter::empty()).await.len(), 3);
+ assert_eq!(top.targets(&name(2, 2), iter::empty()).await.len(), 3);
+ assert_eq!(top.targets(&name(0, 2), iter::empty()).await.len(), 3);
+
+ // any side must have 5 targets
+ assert_eq!(top.targets(&name(0, 1), iter::empty()).await.len(), 5);
+ assert_eq!(top.targets(&name(1, 0), iter::empty()).await.len(), 5);
+ assert_eq!(top.targets(&name(2, 1), iter::empty()).await.len(), 5);
+ assert_eq!(top.targets(&name(1, 2), iter::empty()).await.len(), 5);
+
+ // the center must have 8 targets
+ assert_eq!(top.targets(&name(1, 1), iter::empty()).await.len(), 8);
+ })
+ }
+
+ #[test]
+ pub fn test_grid_previous() {
+ smol::block_on(async {
+ let top = Topology::new(grid(3, 3));
+
+ // if we've passed through the center, we will never have any targets
+ for x in 0..3 {
+ for y in 0..3 {
+ assert_eq!(
+ dbg!(
+ top.targets(&name(x, y), iter::once(name(1, 1).as_str()))
+ .await
+ )
+ .len(),
+ 0
+ );
+ }
+ }
+ })
+ }
+}