summaryrefslogtreecommitdiff
path: root/broadcast/src/topology.rs
diff options
context:
space:
mode:
Diffstat (limited to 'broadcast/src/topology.rs')
-rw-r--r--broadcast/src/topology.rs30
1 files changed, 27 insertions, 3 deletions
diff --git a/broadcast/src/topology.rs b/broadcast/src/topology.rs
index 5e16666..bb10466 100644
--- a/broadcast/src/topology.rs
+++ b/broadcast/src/topology.rs
@@ -14,16 +14,40 @@ impl Topology {
for node_id in node_ids.iter() {
top.insert(node_id.clone(), node_ids.iter().cloned().collect());
}
+ Self::filter_desc(&mut top);
Topology(RwLock::new(top))
}
/// Replace the current topology with a new one.
- pub async fn replace(&self, new: TopologyDesc) {
+ pub async fn replace(&self, mut new: TopologyDesc) {
+ Self::filter_desc(&mut new);
*self.0.write().await = new;
}
- pub async fn all_targets(&self, node_id: &NodeId) -> HashSet<String> {
- self.0.read().await.get(node_id).unwrap().clone()
+ fn filter_desc(desc: &mut TopologyDesc) {
+ for (node_id, neighbours) in desc.iter_mut() {
+ neighbours.remove(node_id);
+ }
+ }
+
+ /// Get the next targets from the given topology, for a message
+ /// which has travelled across the given path and is now at node_id.
+ pub async fn targets(&self, node_id: &String, last_node_id: &String) -> HashSet<String> {
+ // Ensure we don't keep holding the read lock
+ let topology = self.0.read().await;
+
+ // Get all nodes the last node sent it to
+ let visited = topology.get(last_node_id);
+
+ let neighbours = topology.get(node_id).unwrap();
+ match visited {
+ Some(visited) => neighbours
+ .difference(&visited)
+ .cloned()
+ .filter(|n| n != node_id)
+ .collect(),
+ None => neighbours.clone(),
+ }
}
}