summaryrefslogtreecommitdiff
path: root/broadcast/src/topology.rs
blob: d91b8ae7b29ba3fdd9d63434d16f97e94fa28c32 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
use std::collections::{HashMap, HashSet};

use smol::lock::RwLock;

/// Identifier of a node in the broadcast network.
pub type NodeId = String;

/// Adjacency description of a topology: each node maps to the set of its neighbours.
pub type TopologyDesc = HashMap<NodeId, HashSet<NodeId>>;

/// A broadcast topology guarded by an async [`RwLock`], so it can be queried
/// with [`Topology::targets`] and swapped out at runtime with [`Topology::replace`].
pub struct Topology(RwLock<TopologyDesc>);

impl Topology {
    /// Create a new topology from the given description.
    pub fn new(top: TopologyDesc) -> Self {
        Topology(RwLock::new(top))
    }

    /// Create a new topology in which all nodes are connected to each other.
    ///
    /// Each node's neighbour set also contains the node itself. This is harmless
    /// because [`Topology::targets`] never returns the current node.
    pub fn dense(node_ids: Vec<String>) -> Self {
        let everyone: HashSet<NodeId> = node_ids.iter().cloned().collect();
        let top = node_ids
            .into_iter()
            .map(|id| (id, everyone.clone()))
            .collect();

        Topology::new(top)
    }

    /// Replace the current topology with a new one.
    ///
    /// Acquires the write lock, so this waits until no reader holds the lock.
    pub async fn replace(&self, new: TopologyDesc) {
        *self.0.write().await = new;
    }

    /// Get the next targets from the given topology, for a message which has
    /// travelled across the given `path` and is now at `node_id`.
    ///
    /// Every node on `path`, and every neighbour of those nodes, is treated as
    /// already having been sent to and is excluded. `node_id` itself is never
    /// returned.
    ///
    /// Returns an empty set if `node_id` is not part of the current topology
    /// (for example after [`Topology::replace`] removed it) instead of panicking.
    pub async fn targets(
        &self,
        node_id: &str,
        path: impl Iterator<Item = &str>,
    ) -> HashSet<String> {
        // The read guard lives until this function returns. There is no further
        // `.await` below, so it is only held for the synchronous computation.
        let topology = self.0.read().await;

        let neighbours = match topology.get(node_id) {
            Some(neighbours) => neighbours,
            None => return HashSet::new(),
        };

        // Nodes along the path, plus all of their neighbours, are considered visited.
        let mut visited: HashSet<&str> = HashSet::new();
        for node in path {
            visited.insert(node);
            if let Some(path_neighbours) = topology.get(node) {
                visited.extend(path_neighbours.iter().map(String::as_str));
            }
        }

        // Send to all neighbours that haven't already been sent to (never to ourselves).
        neighbours
            .iter()
            .filter(|n| n.as_str() != node_id && !visited.contains(n.as_str()))
            .cloned()
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use std::iter;

    use super::*;

    /// Node name for the grid cell at column `x`, row `y`.
    fn name(x: usize, y: usize) -> String {
        format!("{},{}", x, y)
    }

    /// Build a `w` x `h` grid topology where every cell is connected to its
    /// (up to) eight surrounding cells, diagonals included.
    fn grid(w: usize, h: usize) -> TopologyDesc {
        let mut top = HashMap::new();
        for x in 0..w {
            for y in 0..h {
                let mut neighbours = HashSet::new();
                if x > 0 {
                    neighbours.insert(name(x - 1, y));
                    if y > 0 {
                        neighbours.insert(name(x - 1, y - 1));
                    }
                    if y + 1 < h {
                        neighbours.insert(name(x - 1, y + 1));
                    }
                }
                // The column bound is the grid width (not height), so that
                // non-square grids get the correct right-hand neighbours.
                if x + 1 < w {
                    neighbours.insert(name(x + 1, y));
                    if y > 0 {
                        neighbours.insert(name(x + 1, y - 1));
                    }
                    if y + 1 < h {
                        neighbours.insert(name(x + 1, y + 1));
                    }
                }

                if y > 0 {
                    neighbours.insert(name(x, y - 1));
                }
                if y + 1 < h {
                    neighbours.insert(name(x, y + 1));
                }

                top.insert(name(x, y), neighbours);
            }
        }

        top
    }

    #[test]
    pub fn test_grid_symmetric() {
        // Every edge must point to an existing node and be bidirectional,
        // including for non-square grids.
        for &(w, h) in [(1, 1), (3, 3), (4, 2), (2, 5)].iter() {
            let top = grid(w, h);
            assert_eq!(top.len(), w * h);
            for (node, neighbours) in &top {
                for neighbour in neighbours {
                    assert!(
                        top.get(neighbour)
                            .map_or(false, |back| back.contains(node)),
                        "edge {} -> {} is not symmetric in {}x{} grid",
                        node,
                        neighbour,
                        w,
                        h
                    );
                }
            }
        }
    }

    #[test]
    pub fn test_grid_entrypoint() {
        smol::block_on(async {
            let top = Topology::new(grid(3, 3));

            // any corner must have 3 targets
            assert_eq!(top.targets(&name(0, 0), iter::empty()).await.len(), 3);
            assert_eq!(top.targets(&name(2, 0), iter::empty()).await.len(), 3);
            assert_eq!(top.targets(&name(2, 2), iter::empty()).await.len(), 3);
            assert_eq!(top.targets(&name(0, 2), iter::empty()).await.len(), 3);

            // any side must have 5 targets
            assert_eq!(top.targets(&name(0, 1), iter::empty()).await.len(), 5);
            assert_eq!(top.targets(&name(1, 0), iter::empty()).await.len(), 5);
            assert_eq!(top.targets(&name(2, 1), iter::empty()).await.len(), 5);
            assert_eq!(top.targets(&name(1, 2), iter::empty()).await.len(), 5);

            // the center must have 8 targets
            assert_eq!(top.targets(&name(1, 1), iter::empty()).await.len(), 8);
        })
    }

    #[test]
    pub fn test_rectangular_grid_entrypoint() {
        smol::block_on(async {
            let top = Topology::new(grid(4, 2));

            // corners of a 4x2 grid have 3 targets
            for &(x, y) in [(0, 0), (3, 0), (0, 1), (3, 1)].iter() {
                assert_eq!(top.targets(&name(x, y), iter::empty()).await.len(), 3);
            }

            // inner columns of a 4x2 grid have 5 targets
            for &(x, y) in [(1, 0), (2, 0), (1, 1), (2, 1)].iter() {
                assert_eq!(top.targets(&name(x, y), iter::empty()).await.len(), 5);
            }
        })
    }

    #[test]
    pub fn test_grid_previous() {
        smol::block_on(async {
            let top = Topology::new(grid(3, 3));

            // if we've passed through the center, we will never have any targets
            for x in 0..3 {
                for y in 0..3 {
                    let targets = top
                        .targets(&name(x, y), iter::once(name(1, 1).as_str()))
                        .await;
                    assert_eq!(targets.len(), 0, "unexpected targets {:?}", targets);
                }
            }
        })
    }
}