From 42123efe8fd92d6d81b6d5d10ae86866ea9b6a3c Mon Sep 17 00:00:00 2001 From: Aria Date: Sun, 15 Oct 2023 00:57:41 +0100 Subject: some refactors --- broadcast/src/handler.rs | 33 +++------- broadcast/src/main.rs | 1 + broadcast/src/topology.rs | 153 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 161 insertions(+), 26 deletions(-) create mode 100644 broadcast/src/topology.rs (limited to 'broadcast/src') diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs index d4bb94a..b55cd9b 100644 --- a/broadcast/src/handler.rs +++ b/broadcast/src/handler.rs @@ -6,11 +6,12 @@ use smol::{ }; use std::{ collections::{HashMap, HashSet}, + iter, sync::Arc, time::Duration, }; -use crate::msg::*; +use crate::{msg::*, topology::Topology}; use common::{ msg::{MessageHeader, Output}, @@ -23,7 +24,7 @@ const RETRY_TIMEOUT_SECS: u64 = 1; pub struct BroadcastHandler { node_id: String, seen: RwLock>, - topology: RwLock>>, + topology: Topology, output: Output, attempted_broadcasts: Mutex>>, } @@ -32,15 +33,9 @@ impl Handler for BroadcastHandler { type Body = BroadcastBody; fn init(node_id: String, node_ids: Vec, output: Output) -> Self { - // Initial topology assumes all nodes are neighbours - let mut topology = HashMap::new(); - for id in node_ids.iter() { - topology.insert(id.clone(), node_ids.iter().cloned().collect()); - } - BroadcastHandler { node_id, - topology: RwLock::new(topology), + topology: Topology::dense(node_ids.clone()), seen: RwLock::new(HashSet::new()), output, attempted_broadcasts: Mutex::default(), @@ -73,7 +68,7 @@ impl Handler for BroadcastHandler { BroadcastBody::Topology { msg_id, topology } => { // Start using the new topology - *self.topology.write().await = topology; + self.topology.replace(topology).await; // Send reply if needed if let Some(msg_id) = msg_id { @@ -137,24 +132,10 @@ impl BroadcastHandler { return; } - // Ensure we don't keep holding the read lock - let mut targets = self.topology.read().await.clone(); - - // Only send to neighbours that the source has not sent to. - // This isn't technically optimal, but its as close as we can get without - // tracking the path of each broadcast message. - let our_targets = targets.remove(&self.node_id).unwrap(); - let their_targets = targets - .remove(&src.to_string()) - .unwrap_or_else(|| HashSet::new()); - // Race all send futures + let path = iter::once(src); let mut tasks = self.attempted_broadcasts.lock().await; - for target in our_targets.into_iter() { - if &target == &src || &target == &self.node_id || their_targets.contains(&target) { - continue; - } - + for target in self.topology.targets(&self.node_id, path).await { let msg_id = gen_msg_id(); let this = self.clone(); tasks.insert( diff --git a/broadcast/src/main.rs b/broadcast/src/main.rs index bdc8413..f4f1bf4 100644 --- a/broadcast/src/main.rs +++ b/broadcast/src/main.rs @@ -4,6 +4,7 @@ use common::run_server; mod handler; mod msg; +mod topology; use handler::BroadcastHandler; diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs new file mode 100644 index 0000000..d91b8ae --- /dev/null +++ b/broadcast/src/topology.rs @@ -0,0 +1,153 @@ +use std::collections::{HashMap, HashSet}; + +use smol::lock::RwLock; + +pub type NodeId = String; +pub type TopologyDesc = HashMap>; + +pub struct Topology(RwLock); + +impl Topology { + /// Create a new topology from the given description + pub fn new(top: TopologyDesc) -> Self { + Topology(RwLock::new(top)) + } + + /// Create a new topology in which all nodes are connected to each other. + pub fn dense(node_ids: Vec) -> Self { + let mut top = TopologyDesc::new(); + for node_id in node_ids.iter() { + top.insert(node_id.clone(), node_ids.iter().cloned().collect()); + } + + Topology(RwLock::new(top)) + } + + /// Replace the current topology with a new one. + pub async fn replace(&self, new: TopologyDesc) { + *self.0.write().await = new; + } + + /// Get the next targets from the given topology, for a message + /// which has travelled across the given path and is now at node_id. + pub async fn targets( + &self, + node_id: &String, + path: impl Iterator, + ) -> HashSet { + // Ensure we don't keep holding the read lock + let topology = self.0.read().await; + + // Get all visited nodes, from all neighbours of all node along the source path + let mut visited = HashSet::new(); + for node in path { + visited.insert(node.to_string()); + if let Some(neighbours) = topology.get(node) { + for neighbour in neighbours { + visited.insert(neighbour.clone()); + } + } + } + + // Send to all neighbours that haven't already been sent to + topology + .get(node_id) + .unwrap() + .difference(&visited) + .cloned() + .filter(|n| n != node_id) + .collect() + } +} + +#[cfg(test)] +mod tests { + use std::iter; + + use super::*; + + fn name(x: usize, y: usize) -> String { + format!("{},{}", x, y) + } + + fn grid(w: usize, h: usize) -> TopologyDesc { + let mut top = HashMap::new(); + for x in 0..w { + for y in 0..h { + let mut neighbours = HashSet::new(); + if x > 0 { + neighbours.insert(name(x - 1, y)); + if y > 0 { + neighbours.insert(name(x - 1, y - 1)); + } + if y < h - 1 { + neighbours.insert(name(x - 1, y + 1)); + } + } + if x < h - 1 { + neighbours.insert(name(x + 1, y)); + if y > 0 { + neighbours.insert(name(x + 1, y - 1)); + } + if y < h - 1 { + neighbours.insert(name(x + 1, y + 1)); + } + } + + if y > 0 { + neighbours.insert(name(x, y - 1)); + } + if y < h - 1 { + neighbours.insert(name(x, y + 1)); + } + + top.insert(name(x, y), neighbours); + } + } + + top + } + + #[test] + pub fn test_grid_entrypoint() { + smol::block_on(async { + let top = Topology::new(grid(3, 3)); + + // any corner must have 3 targets + assert_eq!(top.targets(&name(0, 0), iter::empty()).await.len(), 3); + assert_eq!(top.targets(&name(2, 0), iter::empty()).await.len(), 3); + assert_eq!(top.targets(&name(2, 2), iter::empty()).await.len(), 3); + assert_eq!(top.targets(&name(0, 2), iter::empty()).await.len(), 3); + + // any side must have 5 targets + assert_eq!(top.targets(&name(0, 1), iter::empty()).await.len(), 5); + assert_eq!(top.targets(&name(1, 0), iter::empty()).await.len(), 5); + assert_eq!(top.targets(&name(2, 1), iter::empty()).await.len(), 5); + assert_eq!(top.targets(&name(1, 2), iter::empty()).await.len(), 5); + + // the center must have 8 targets + assert_eq!(top.targets(&name(1, 1), iter::empty()).await.len(), 8); + }) + } + + #[test] + pub fn test_grid_previous() { + smol::block_on(async { + let top = Topology::new(grid(3, 3)); + + // if we've passed through the center, we will never have any targets + for x in 0..3 { + for y in 0..3 { + assert_eq!( + dbg!( + top.targets(&name(x, y), iter::once(name(1, 1).as_str())) + .await + ) + .len(), + 0 + ); + } + } + }) + } +} -- cgit v1.2.3