summaryrefslogtreecommitdiff
path: root/broadcast/src/handler.rs
diff options
context:
space:
mode:
authorAria <me@aria.rip>2023-10-13 14:13:58 +0100
committerAria <me@aria.rip>2023-10-13 14:13:58 +0100
commitb2d679f05d04052bfc25167eaaf09c60c03251cb (patch)
tree7ca49d117a3167169e5b92613ca21c88c12bd47f /broadcast/src/handler.rs
parentc063f4da42a538138cc3e80a0e1faaf813a13bd2 (diff)
wip: fault tolerant broadcast
Diffstat (limited to 'broadcast/src/handler.rs')
-rw-r--r--broadcast/src/handler.rs167
1 files changed, 167 insertions, 0 deletions
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs
new file mode 100644
index 0000000..09c66b9
--- /dev/null
+++ b/broadcast/src/handler.rs
@@ -0,0 +1,167 @@
+use smol::{
+ lock::{Mutex, RwLock},
+ prelude::*,
+ Task, Timer,
+};
+use std::{
+ collections::{HashMap, HashSet},
+ sync::Arc,
+ time::Duration,
+};
+
+use crate::msg::*;
+
+use common::{
+ msg::{MessageHeader, Output},
+ msg_id::{gen_msg_id, MessageID},
+ Handler,
+};
+
+const RETRY_TIMEOUT: u64 = 2;
+
+pub struct BroadcastHandler {
+ node_id: String,
+ seen: RwLock<HashSet<BroadcastTarget>>,
+ broadcast_targets: RwLock<Vec<String>>,
+ output: Output<BroadcastBody>,
+ attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>,
+}
+
+impl Handler for BroadcastHandler {
+ type Body = BroadcastBody;
+
+ fn init(node_id: String, mut node_ids: Vec<String>, output: Output<Self::Body>) -> Self {
+ node_ids.retain(|x| *x != node_id);
+
+ BroadcastHandler {
+ node_id,
+ broadcast_targets: RwLock::new(node_ids),
+ seen: RwLock::new(HashSet::new()),
+ output,
+ attempted_broadcasts: Mutex::default(),
+ }
+ }
+
+ fn handle<'a>(
+ self: &'a Arc<Self>,
+ header: MessageHeader,
+ body: BroadcastBody,
+ ) -> impl Future<Output = ()> + Send + 'a {
+ async move {
+ match body {
+ BroadcastBody::Broadcast {
+ msg_id: Some(msg_id),
+ message,
+ } => {
+ self.receive_broadcast(&header.src, message).await;
+ self.send_broadcast_ok(&header.src, msg_id).await;
+ }
+ BroadcastBody::Broadcast {
+ msg_id: None,
+ message,
+ } => {
+ self.receive_broadcast(&header.src, message).await;
+ }
+
+ BroadcastBody::Topology {
+ msg_id,
+ mut topology,
+ } => {
+ // Start using the new topology
+ if let Some(broadcast_targets) = topology.remove(&self.node_id) {
+ *self.broadcast_targets.write().await = broadcast_targets;
+ }
+
+ // Send reply if needed
+ if let Some(msg_id) = msg_id {
+ self.output
+ .send(
+ &header.src,
+ &BroadcastBody::TopologyOk {
+ in_reply_to: msg_id,
+ },
+ )
+ .await;
+ }
+ }
+
+ BroadcastBody::Read { msg_id } => {
+ // Send all received messages back
+ self.output
+ .send(
+ &header.src,
+ &BroadcastBody::ReadOk {
+ in_reply_to: msg_id,
+ messages: self.seen.read().await.clone(),
+ },
+ )
+ .await
+ }
+
+ // Ignore OK messages - we never actually request them
+ BroadcastBody::BroadcastOk { in_reply_to } => {
+ if let Some(task) = self.attempted_broadcasts.lock().await.remove(&in_reply_to)
+ {
+ task.cancel().await;
+ }
+ }
+ BroadcastBody::TopologyOk { .. } => {}
+ BroadcastBody::ReadOk { .. } => {}
+ }
+ }
+ }
+}
+
+impl BroadcastHandler {
+ /// Reply with a broadcast OK message
+ async fn send_broadcast_ok(&self, src: &str, msg_id: MessageID) {
+ self.output
+ .send(
+ &src,
+ &BroadcastBody::BroadcastOk {
+ in_reply_to: msg_id,
+ },
+ )
+ .await;
+ }
+
+ /// Receive a given message, and broadcast it onwards if it is new
+ async fn receive_broadcast(self: &Arc<Self>, src: &str, message: BroadcastTarget) {
+ let new = self.seen.write().await.insert(message);
+ if !new {
+ return;
+ }
+
+ // Ensure we don't keep holding the read lock
+ let targets = self.broadcast_targets.read().await.clone();
+
+ // Race all send futures
+ let mut tasks = self.attempted_broadcasts.lock().await;
+ for target in targets.into_iter() {
+ if &target == &src {
+ return;
+ }
+
+ let msg_id = gen_msg_id();
+ let this = self.clone();
+ tasks.insert(
+ msg_id,
+ smol::spawn(async move {
+ loop {
+ this.output
+ .send(
+ &target,
+ &BroadcastBody::Broadcast {
+ msg_id: None,
+ message,
+ },
+ )
+ .await;
+
+ Timer::after(Duration::from_secs(RETRY_TIMEOUT)).await;
+ }
+ }),
+ );
+ }
+ }
+}