summaryrefslogtreecommitdiff
path: root/broadcast/src/topology.rs
diff options
context:
space:
mode:
authorAria <me@aria.rip>2023-10-19 21:00:26 +0100
committerAria <me@aria.rip>2023-10-19 21:00:26 +0100
commitdba41282cac86a740c007498709e996b9fa3e59b (patch)
tree5d9309962f55ca00be6fee3b7ac24b7314e86f2a /broadcast/src/topology.rs
parent07e2085190e30010ad595369a07842413bacd3d1 (diff)
wip: broadcast message batching
Diffstat (limited to 'broadcast/src/topology.rs')
-rw-r--r--broadcast/src/topology.rs128
1 files changed, 2 insertions, 126 deletions
diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs
index d91b8ae..5e16666 100644
--- a/broadcast/src/topology.rs
+++ b/broadcast/src/topology.rs
@@ -8,11 +8,6 @@ pub type TopologyDesc = HashMap<NodeId, HashSet<NodeId>>;
pub struct Topology(RwLock<TopologyDesc>);
impl Topology {
- /// Create a new topology from the given description
- pub fn new(top: TopologyDesc) -> Self {
- Topology(RwLock::new(top))
- }
-
/// Create a new topology in which all nodes are connected to each other.
pub fn dense(node_ids: Vec<String>) -> Self {
let mut top = TopologyDesc::new();
@@ -28,126 +23,7 @@ impl Topology {
*self.0.write().await = new;
}
- /// Get the next targets from the given topology, for a message
- /// which has travelled across the given path and is now at node_id.
- pub async fn targets(
- &self,
- node_id: &String,
- path: impl Iterator<Item = &str>,
- ) -> HashSet<String> {
- // Ensure we don't keep holding the read lock
- let topology = self.0.read().await;
-
- // Get all visited nodes, from all neighbours of all node along the source path
- let mut visited = HashSet::new();
- for node in path {
- visited.insert(node.to_string());
- if let Some(neighbours) = topology.get(node) {
- for neighbour in neighbours {
- visited.insert(neighbour.clone());
- }
- }
- }
-
- // Send to all neighbours that haven't already been sent to
- topology
- .get(node_id)
- .unwrap()
- .difference(&visited)
- .cloned()
- .filter(|n| n != node_id)
- .collect()
- }
-}
-
-#[cfg(test)]
-mod tests {
- use std::iter;
-
- use super::*;
-
- fn name(x: usize, y: usize) -> String {
- format!("{},{}", x, y)
- }
-
- fn grid(w: usize, h: usize) -> TopologyDesc {
- let mut top = HashMap::new();
- for x in 0..w {
- for y in 0..h {
- let mut neighbours = HashSet::new();
- if x > 0 {
- neighbours.insert(name(x - 1, y));
- if y > 0 {
- neighbours.insert(name(x - 1, y - 1));
- }
- if y < h - 1 {
- neighbours.insert(name(x - 1, y + 1));
- }
- }
- if x < h - 1 {
- neighbours.insert(name(x + 1, y));
- if y > 0 {
- neighbours.insert(name(x + 1, y - 1));
- }
- if y < h - 1 {
- neighbours.insert(name(x + 1, y + 1));
- }
- }
-
- if y > 0 {
- neighbours.insert(name(x, y - 1));
- }
- if y < h - 1 {
- neighbours.insert(name(x, y + 1));
- }
-
- top.insert(name(x, y), neighbours);
- }
- }
-
- top
- }
-
- #[test]
- pub fn test_grid_entrypoint() {
- smol::block_on(async {
- let top = Topology::new(grid(3, 3));
-
- // any corner must have 3 targets
- assert_eq!(top.targets(&name(0, 0), iter::empty()).await.len(), 3);
- assert_eq!(top.targets(&name(2, 0), iter::empty()).await.len(), 3);
- assert_eq!(top.targets(&name(2, 2), iter::empty()).await.len(), 3);
- assert_eq!(top.targets(&name(0, 2), iter::empty()).await.len(), 3);
-
- // any side must have 5 targets
- assert_eq!(top.targets(&name(0, 1), iter::empty()).await.len(), 5);
- assert_eq!(top.targets(&name(1, 0), iter::empty()).await.len(), 5);
- assert_eq!(top.targets(&name(2, 1), iter::empty()).await.len(), 5);
- assert_eq!(top.targets(&name(1, 2), iter::empty()).await.len(), 5);
-
- // the center must have 8 targets
- assert_eq!(top.targets(&name(1, 1), iter::empty()).await.len(), 8);
- })
- }
-
- #[test]
- pub fn test_grid_previous() {
- smol::block_on(async {
- let top = Topology::new(grid(3, 3));
-
- // if we've passed through the center, we will never have any targets
- for x in 0..3 {
- for y in 0..3 {
- assert_eq!(
- dbg!(
- top.targets(&name(x, y), iter::once(name(1, 1).as_str()))
- .await
- )
- .len(),
- 0
- );
- }
- }
- })
+ pub async fn all_targets(&self, node_id: &NodeId) -> HashSet<String> {
+ self.0.read().await.get(node_id).unwrap().clone()
}
}