From 42123efe8fd92d6d81b6d5d10ae86866ea9b6a3c Mon Sep 17 00:00:00 2001 From: Aria Date: Sun, 15 Oct 2023 00:57:41 +0100 Subject: some refactors --- broadcast/src/topology.rs | 153 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 broadcast/src/topology.rs (limited to 'broadcast/src/topology.rs') diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs new file mode 100644 index 0000000..d91b8ae --- /dev/null +++ b/broadcast/src/topology.rs @@ -0,0 +1,153 @@ +use std::collections::{HashMap, HashSet}; + +use smol::lock::RwLock; + +pub type NodeId = String; +pub type TopologyDesc = HashMap>; + +pub struct Topology(RwLock); + +impl Topology { + /// Create a new topology from the given description + pub fn new(top: TopologyDesc) -> Self { + Topology(RwLock::new(top)) + } + + /// Create a new topology in which all nodes are connected to each other. + pub fn dense(node_ids: Vec) -> Self { + let mut top = TopologyDesc::new(); + for node_id in node_ids.iter() { + top.insert(node_id.clone(), node_ids.iter().cloned().collect()); + } + + Topology(RwLock::new(top)) + } + + /// Replace the current topology with a new one. + pub async fn replace(&self, new: TopologyDesc) { + *self.0.write().await = new; + } + + /// Get the next targets from the given topology, for a message + /// which has travelled across the given path and is now at node_id. + pub async fn targets( + &self, + node_id: &String, + path: impl Iterator, + ) -> HashSet { + // Ensure we don't keep holding the read lock + let topology = self.0.read().await; + + // Get all visited nodes, from all neighbours of all node along the source path + let mut visited = HashSet::new(); + for node in path { + visited.insert(node.to_string()); + if let Some(neighbours) = topology.get(node) { + for neighbour in neighbours { + visited.insert(neighbour.clone()); + } + } + } + + // Send to all neighbours that haven't already been sent to + topology + .get(node_id) + .unwrap() + .difference(&visited) + .cloned() + .filter(|n| n != node_id) + .collect() + } +} + +#[cfg(test)] +mod tests { + use std::iter; + + use super::*; + + fn name(x: usize, y: usize) -> String { + format!("{},{}", x, y) + } + + fn grid(w: usize, h: usize) -> TopologyDesc { + let mut top = HashMap::new(); + for x in 0..w { + for y in 0..h { + let mut neighbours = HashSet::new(); + if x > 0 { + neighbours.insert(name(x - 1, y)); + if y > 0 { + neighbours.insert(name(x - 1, y - 1)); + } + if y < h - 1 { + neighbours.insert(name(x - 1, y + 1)); + } + } + if x < h - 1 { + neighbours.insert(name(x + 1, y)); + if y > 0 { + neighbours.insert(name(x + 1, y - 1)); + } + if y < h - 1 { + neighbours.insert(name(x + 1, y + 1)); + } + } + + if y > 0 { + neighbours.insert(name(x, y - 1)); + } + if y < h - 1 { + neighbours.insert(name(x, y + 1)); + } + + top.insert(name(x, y), neighbours); + } + } + + top + } + + #[test] + pub fn test_grid_entrypoint() { + smol::block_on(async { + let top = Topology::new(grid(3, 3)); + + // any corner must have 3 targets + assert_eq!(top.targets(&name(0, 0), iter::empty()).await.len(), 3); + assert_eq!(top.targets(&name(2, 0), iter::empty()).await.len(), 3); + assert_eq!(top.targets(&name(2, 2), iter::empty()).await.len(), 3); + assert_eq!(top.targets(&name(0, 2), iter::empty()).await.len(), 3); + + // any side must have 5 targets + assert_eq!(top.targets(&name(0, 1), iter::empty()).await.len(), 5); + assert_eq!(top.targets(&name(1, 0), iter::empty()).await.len(), 5); + assert_eq!(top.targets(&name(2, 1), iter::empty()).await.len(), 5); + assert_eq!(top.targets(&name(1, 2), iter::empty()).await.len(), 5); + + // the center must have 8 targets + assert_eq!(top.targets(&name(1, 1), iter::empty()).await.len(), 8); + }) + } + + #[test] + pub fn test_grid_previous() { + smol::block_on(async { + let top = Topology::new(grid(3, 3)); + + // if we've passed through the center, we will never have any targets + for x in 0..3 { + for y in 0..3 { + assert_eq!( + dbg!( + top.targets(&name(x, y), iter::once(name(1, 1).as_str())) + .await + ) + .len(), + 0 + ); + } + } + }) + } +} -- cgit v1.2.3