summaryrefslogtreecommitdiff
path: root/broadcast/src/topology.rs
diff options
context:
space:
mode:
authorAria <me@aria.rip>2023-10-15 00:57:41 +0100
committerAria <me@aria.rip>2023-10-15 00:57:41 +0100
commit42123efe8fd92d6d81b6d5d10ae86866ea9b6a3c (patch)
tree2dfcb1a27c337d1c9b89f54aff35e94795970b74 /broadcast/src/topology.rs
parent7447f3fb801ba954c7b8cbf3f47700ffcc562d20 (diff)
some refactors
Diffstat (limited to 'broadcast/src/topology.rs')
-rw-r--r--broadcast/src/topology.rs153
1 files changed, 153 insertions, 0 deletions
diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs
new file mode 100644
index 0000000..d91b8ae
--- /dev/null
+++ b/broadcast/src/topology.rs
@@ -0,0 +1,153 @@
+use std::collections::{HashMap, HashSet};
+
+use smol::lock::RwLock;
+
+pub type NodeId = String;
+pub type TopologyDesc = HashMap<NodeId, HashSet<NodeId>>;
+
+pub struct Topology(RwLock<TopologyDesc>);
+
+impl Topology {
+ /// Create a new topology from the given description
+ pub fn new(top: TopologyDesc) -> Self {
+ Topology(RwLock::new(top))
+ }
+
+ /// Create a new topology in which all nodes are connected to each other.
+ pub fn dense(node_ids: Vec<String>) -> Self {
+ let mut top = TopologyDesc::new();
+ for node_id in node_ids.iter() {
+ top.insert(node_id.clone(), node_ids.iter().cloned().collect());
+ }
+
+ Topology(RwLock::new(top))
+ }
+
+ /// Replace the current topology with a new one.
+ pub async fn replace(&self, new: TopologyDesc) {
+ *self.0.write().await = new;
+ }
+
+ /// Get the next targets from the given topology, for a message
+ /// which has travelled across the given path and is now at node_id.
+ pub async fn targets(
+ &self,
+ node_id: &String,
+ path: impl Iterator<Item = &str>,
+ ) -> HashSet<String> {
+ // Ensure we don't keep holding the read lock
+ let topology = self.0.read().await;
+
+ // Get all visited nodes, from all neighbours of all node along the source path
+ let mut visited = HashSet::new();
+ for node in path {
+ visited.insert(node.to_string());
+ if let Some(neighbours) = topology.get(node) {
+ for neighbour in neighbours {
+ visited.insert(neighbour.clone());
+ }
+ }
+ }
+
+ // Send to all neighbours that haven't already been sent to
+ topology
+ .get(node_id)
+ .unwrap()
+ .difference(&visited)
+ .cloned()
+ .filter(|n| n != node_id)
+ .collect()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::iter;
+
+ use super::*;
+
+ fn name(x: usize, y: usize) -> String {
+ format!("{},{}", x, y)
+ }
+
+ fn grid(w: usize, h: usize) -> TopologyDesc {
+ let mut top = HashMap::new();
+ for x in 0..w {
+ for y in 0..h {
+ let mut neighbours = HashSet::new();
+ if x > 0 {
+ neighbours.insert(name(x - 1, y));
+ if y > 0 {
+ neighbours.insert(name(x - 1, y - 1));
+ }
+ if y < h - 1 {
+ neighbours.insert(name(x - 1, y + 1));
+ }
+ }
+ if x < h - 1 {
+ neighbours.insert(name(x + 1, y));
+ if y > 0 {
+ neighbours.insert(name(x + 1, y - 1));
+ }
+ if y < h - 1 {
+ neighbours.insert(name(x + 1, y + 1));
+ }
+ }
+
+ if y > 0 {
+ neighbours.insert(name(x, y - 1));
+ }
+ if y < h - 1 {
+ neighbours.insert(name(x, y + 1));
+ }
+
+ top.insert(name(x, y), neighbours);
+ }
+ }
+
+ top
+ }
+
+ #[test]
+ pub fn test_grid_entrypoint() {
+ smol::block_on(async {
+ let top = Topology::new(grid(3, 3));
+
+ // any corner must have 3 targets
+ assert_eq!(top.targets(&name(0, 0), iter::empty()).await.len(), 3);
+ assert_eq!(top.targets(&name(2, 0), iter::empty()).await.len(), 3);
+ assert_eq!(top.targets(&name(2, 2), iter::empty()).await.len(), 3);
+ assert_eq!(top.targets(&name(0, 2), iter::empty()).await.len(), 3);
+
+ // any side must have 5 targets
+ assert_eq!(top.targets(&name(0, 1), iter::empty()).await.len(), 5);
+ assert_eq!(top.targets(&name(1, 0), iter::empty()).await.len(), 5);
+ assert_eq!(top.targets(&name(2, 1), iter::empty()).await.len(), 5);
+ assert_eq!(top.targets(&name(1, 2), iter::empty()).await.len(), 5);
+
+ // the center must have 8 targets
+ assert_eq!(top.targets(&name(1, 1), iter::empty()).await.len(), 8);
+ })
+ }
+
+ #[test]
+ pub fn test_grid_previous() {
+ smol::block_on(async {
+ let top = Topology::new(grid(3, 3));
+
+ // if we've passed through the center, we will never have any targets
+ for x in 0..3 {
+ for y in 0..3 {
+ assert_eq!(
+ dbg!(
+ top.targets(&name(x, y), iter::once(name(1, 1).as_str()))
+ .await
+ )
+ .len(),
+ 0
+ );
+ }
+ }
+ })
+ }
+}