use std::collections::{HashMap, HashSet}; use smol::lock::RwLock; pub type NodeId = String; pub type TopologyDesc = HashMap>; pub struct Topology(RwLock); impl Topology { /// Create a new topology from the given description pub fn new(top: TopologyDesc) -> Self { Topology(RwLock::new(top)) } /// Create a new topology in which all nodes are connected to each other. pub fn dense(node_ids: Vec) -> Self { let mut top = TopologyDesc::new(); for node_id in node_ids.iter() { top.insert(node_id.clone(), node_ids.iter().cloned().collect()); } Topology(RwLock::new(top)) } /// Replace the current topology with a new one. pub async fn replace(&self, new: TopologyDesc) { *self.0.write().await = new; } /// Get the next targets from the given topology, for a message /// which has travelled across the given path and is now at node_id. pub async fn targets( &self, node_id: &String, path: impl Iterator, ) -> HashSet { // Ensure we don't keep holding the read lock let topology = self.0.read().await; // Get all visited nodes, from all neighbours of all node along the source path let mut visited = HashSet::new(); for node in path { visited.insert(node.to_string()); if let Some(neighbours) = topology.get(node) { for neighbour in neighbours { visited.insert(neighbour.clone()); } } } // Send to all neighbours that haven't already been sent to topology .get(node_id) .unwrap() .difference(&visited) .cloned() .filter(|n| n != node_id) .collect() } } #[cfg(test)] mod tests { use std::iter; use super::*; fn name(x: usize, y: usize) -> String { format!("{},{}", x, y) } fn grid(w: usize, h: usize) -> TopologyDesc { let mut top = HashMap::new(); for x in 0..w { for y in 0..h { let mut neighbours = HashSet::new(); if x > 0 { neighbours.insert(name(x - 1, y)); if y > 0 { neighbours.insert(name(x - 1, y - 1)); } if y < h - 1 { neighbours.insert(name(x - 1, y + 1)); } } if x < h - 1 { neighbours.insert(name(x + 1, y)); if y > 0 { neighbours.insert(name(x + 1, y - 1)); } if y < h - 1 { neighbours.insert(name(x + 1, y + 1)); } } if y > 0 { neighbours.insert(name(x, y - 1)); } if y < h - 1 { neighbours.insert(name(x, y + 1)); } top.insert(name(x, y), neighbours); } } top } #[test] pub fn test_grid_entrypoint() { smol::block_on(async { let top = Topology::new(grid(3, 3)); // any corner must have 3 targets assert_eq!(top.targets(&name(0, 0), iter::empty()).await.len(), 3); assert_eq!(top.targets(&name(2, 0), iter::empty()).await.len(), 3); assert_eq!(top.targets(&name(2, 2), iter::empty()).await.len(), 3); assert_eq!(top.targets(&name(0, 2), iter::empty()).await.len(), 3); // any side must have 5 targets assert_eq!(top.targets(&name(0, 1), iter::empty()).await.len(), 5); assert_eq!(top.targets(&name(1, 0), iter::empty()).await.len(), 5); assert_eq!(top.targets(&name(2, 1), iter::empty()).await.len(), 5); assert_eq!(top.targets(&name(1, 2), iter::empty()).await.len(), 5); // the center must have 8 targets assert_eq!(top.targets(&name(1, 1), iter::empty()).await.len(), 8); }) } #[test] pub fn test_grid_previous() { smol::block_on(async { let top = Topology::new(grid(3, 3)); // if we've passed through the center, we will never have any targets for x in 0..3 { for y in 0..3 { assert_eq!( dbg!( top.targets(&name(x, y), iter::once(name(1, 1).as_str())) .await ) .len(), 0 ); } } }) } }