diff options
author | Aria <me@aria.rip> | 2023-10-19 21:00:26 +0100 |
---|---|---|
committer | Aria <me@aria.rip> | 2023-10-19 21:00:26 +0100 |
commit | dba41282cac86a740c007498709e996b9fa3e59b (patch) | |
tree | 5d9309962f55ca00be6fee3b7ac24b7314e86f2a /broadcast/src/topology.rs | |
parent | 07e2085190e30010ad595369a07842413bacd3d1 (diff) |
wip: broadcast message batching
Diffstat (limited to 'broadcast/src/topology.rs')
-rw-r--r-- | broadcast/src/topology.rs | 128 |
1 files changed, 2 insertions, 126 deletions
diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs index d91b8ae..5e16666 100644 --- a/broadcast/src/topology.rs +++ b/broadcast/src/topology.rs @@ -8,11 +8,6 @@ pub type TopologyDesc = HashMap<NodeId, HashSet<NodeId>>; pub struct Topology(RwLock<TopologyDesc>); impl Topology { - /// Create a new topology from the given description - pub fn new(top: TopologyDesc) -> Self { - Topology(RwLock::new(top)) - } - /// Create a new topology in which all nodes are connected to each other. pub fn dense(node_ids: Vec<String>) -> Self { let mut top = TopologyDesc::new(); @@ -28,126 +23,7 @@ impl Topology { *self.0.write().await = new; } - /// Get the next targets from the given topology, for a message - /// which has travelled across the given path and is now at node_id. - pub async fn targets( - &self, - node_id: &String, - path: impl Iterator<Item = &str>, - ) -> HashSet<String> { - // Ensure we don't keep holding the read lock - let topology = self.0.read().await; - - // Get all visited nodes, from all neighbours of all node along the source path - let mut visited = HashSet::new(); - for node in path { - visited.insert(node.to_string()); - if let Some(neighbours) = topology.get(node) { - for neighbour in neighbours { - visited.insert(neighbour.clone()); - } - } - } - - // Send to all neighbours that haven't already been sent to - topology - .get(node_id) - .unwrap() - .difference(&visited) - .cloned() - .filter(|n| n != node_id) - .collect() - } -} - -#[cfg(test)] -mod tests { - use std::iter; - - use super::*; - - fn name(x: usize, y: usize) -> String { - format!("{},{}", x, y) - } - - fn grid(w: usize, h: usize) -> TopologyDesc { - let mut top = HashMap::new(); - for x in 0..w { - for y in 0..h { - let mut neighbours = HashSet::new(); - if x > 0 { - neighbours.insert(name(x - 1, y)); - if y > 0 { - neighbours.insert(name(x - 1, y - 1)); - } - if y < h - 1 { - neighbours.insert(name(x - 1, y + 1)); - } - } - if x < h - 1 { - neighbours.insert(name(x + 1, y)); - if y > 0 { - neighbours.insert(name(x + 1, y - 1)); - } - if y < h - 1 { - neighbours.insert(name(x + 1, y + 1)); - } - } - - if y > 0 { - neighbours.insert(name(x, y - 1)); - } - if y < h - 1 { - neighbours.insert(name(x, y + 1)); - } - - top.insert(name(x, y), neighbours); - } - } - - top - } - - #[test] - pub fn test_grid_entrypoint() { - smol::block_on(async { - let top = Topology::new(grid(3, 3)); - - // any corner must have 3 targets - assert_eq!(top.targets(&name(0, 0), iter::empty()).await.len(), 3); - assert_eq!(top.targets(&name(2, 0), iter::empty()).await.len(), 3); - assert_eq!(top.targets(&name(2, 2), iter::empty()).await.len(), 3); - assert_eq!(top.targets(&name(0, 2), iter::empty()).await.len(), 3); - - // any side must have 5 targets - assert_eq!(top.targets(&name(0, 1), iter::empty()).await.len(), 5); - assert_eq!(top.targets(&name(1, 0), iter::empty()).await.len(), 5); - assert_eq!(top.targets(&name(2, 1), iter::empty()).await.len(), 5); - assert_eq!(top.targets(&name(1, 2), iter::empty()).await.len(), 5); - - // the center must have 8 targets - assert_eq!(top.targets(&name(1, 1), iter::empty()).await.len(), 8); - }) - } - - #[test] - pub fn test_grid_previous() { - smol::block_on(async { - let top = Topology::new(grid(3, 3)); - - // if we've passed through the center, we will never have any targets - for x in 0..3 { - for y in 0..3 { - assert_eq!( - dbg!( - top.targets(&name(x, y), iter::once(name(1, 1).as_str())) - .await - ) - .len(), - 0 - ); - } - } - }) + pub async fn all_targets(&self, node_id: &NodeId) -> HashSet<String> { + self.0.read().await.get(node_id).unwrap().clone() } } |