diff options
author | Aria <me@aria.rip> | 2023-10-13 14:13:58 +0100 |
---|---|---|
committer | Aria <me@aria.rip> | 2023-10-13 14:13:58 +0100 |
commit | b2d679f05d04052bfc25167eaaf09c60c03251cb (patch) | |
tree | 7ca49d117a3167169e5b92613ca21c88c12bd47f /broadcast/src | |
parent | c063f4da42a538138cc3e80a0e1faaf813a13bd2 (diff) |
wip: fault tolerant broadcast
Diffstat (limited to 'broadcast/src')
-rw-r--r-- | broadcast/src/handler.rs | 167 | ||||
-rw-r--r-- | broadcast/src/main.rs | 136 | ||||
-rw-r--r-- | broadcast/src/msg.rs | 36 |
3 files changed, 207 insertions, 132 deletions
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs new file mode 100644 index 0000000..09c66b9 --- /dev/null +++ b/broadcast/src/handler.rs @@ -0,0 +1,167 @@ +use smol::{ + lock::{Mutex, RwLock}, + prelude::*, + Task, Timer, +}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + time::Duration, +}; + +use crate::msg::*; + +use common::{ + msg::{MessageHeader, Output}, + msg_id::{gen_msg_id, MessageID}, + Handler, +}; + +const RETRY_TIMEOUT: u64 = 2; + +pub struct BroadcastHandler { + node_id: String, + seen: RwLock<HashSet<BroadcastTarget>>, + broadcast_targets: RwLock<Vec<String>>, + output: Output<BroadcastBody>, + attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>, +} + +impl Handler for BroadcastHandler { + type Body = BroadcastBody; + + fn init(node_id: String, mut node_ids: Vec<String>, output: Output<Self::Body>) -> Self { + node_ids.retain(|x| *x != node_id); + + BroadcastHandler { + node_id, + broadcast_targets: RwLock::new(node_ids), + seen: RwLock::new(HashSet::new()), + output, + attempted_broadcasts: Mutex::default(), + } + } + + fn handle<'a>( + self: &'a Arc<Self>, + header: MessageHeader, + body: BroadcastBody, + ) -> impl Future<Output = ()> + Send + 'a { + async move { + match body { + BroadcastBody::Broadcast { + msg_id: Some(msg_id), + message, + } => { + self.receive_broadcast(&header.src, message).await; + self.send_broadcast_ok(&header.src, msg_id).await; + } + BroadcastBody::Broadcast { + msg_id: None, + message, + } => { + self.receive_broadcast(&header.src, message).await; + } + + BroadcastBody::Topology { + msg_id, + mut topology, + } => { + // Start using the new topology + if let Some(broadcast_targets) = topology.remove(&self.node_id) { + *self.broadcast_targets.write().await = broadcast_targets; + } + + // Send reply if needed + if let Some(msg_id) = msg_id { + self.output + .send( + &header.src, + &BroadcastBody::TopologyOk { + in_reply_to: msg_id, + }, + ) + .await; + } + } + + BroadcastBody::Read { msg_id } => { + // Send all received messages back + self.output + .send( + &header.src, + &BroadcastBody::ReadOk { + in_reply_to: msg_id, + messages: self.seen.read().await.clone(), + }, + ) + .await + } + + // Ignore OK messages - we never actually request them + BroadcastBody::BroadcastOk { in_reply_to } => { + if let Some(task) = self.attempted_broadcasts.lock().await.remove(&in_reply_to) + { + task.cancel().await; + } + } + BroadcastBody::TopologyOk { .. } => {} + BroadcastBody::ReadOk { .. } => {} + } + } + } +} + +impl BroadcastHandler { + /// Reply with a broadcast OK message + async fn send_broadcast_ok(&self, src: &str, msg_id: MessageID) { + self.output + .send( + &src, + &BroadcastBody::BroadcastOk { + in_reply_to: msg_id, + }, + ) + .await; + } + + /// Receive a given message, and broadcast it onwards if it is new + async fn receive_broadcast(self: &Arc<Self>, src: &str, message: BroadcastTarget) { + let new = self.seen.write().await.insert(message); + if !new { + return; + } + + // Ensure we don't keep holding the read lock + let targets = self.broadcast_targets.read().await.clone(); + + // Race all send futures + let mut tasks = self.attempted_broadcasts.lock().await; + for target in targets.into_iter() { + if &target == &src { + return; + } + + let msg_id = gen_msg_id(); + let this = self.clone(); + tasks.insert( + msg_id, + smol::spawn(async move { + loop { + this.output + .send( + &target, + &BroadcastBody::Broadcast { + msg_id: None, + message, + }, + ) + .await; + + Timer::after(Duration::from_secs(RETRY_TIMEOUT)).await; + } + }), + ); + } + } +} diff --git a/broadcast/src/main.rs b/broadcast/src/main.rs index 09c0268..bdc8413 100644 --- a/broadcast/src/main.rs +++ b/broadcast/src/main.rs @@ -1,140 +1,12 @@ #![feature(return_position_impl_trait_in_trait)] -use smol::{lock::RwLock, prelude::*}; -use std::collections::{HashMap, HashSet}; +use common::run_server; -use common::{ - msg::{MessageHeader, Output}, - run_server, Handler, -}; -use serde::{Deserialize, Serialize}; +mod handler; +mod msg; -type BroadcastTarget = usize; +use handler::BroadcastHandler; fn main() { run_server::<BroadcastHandler>(); } - -#[derive(Debug, Deserialize, Serialize, Clone)] -#[serde(tag = "type")] -pub enum BroadcastBody { - #[serde(rename = "broadcast")] - Broadcast { - msg_id: Option<usize>, - message: BroadcastTarget, - }, - - #[serde(rename = "broadcast_ok")] - BroadcastOk { in_reply_to: usize }, - - #[serde(rename = "topology")] - Topology { - msg_id: Option<usize>, - topology: HashMap<String, Vec<String>>, - }, - - #[serde(rename = "topology_ok")] - TopologyOk { in_reply_to: usize }, - - #[serde(rename = "read")] - Read { msg_id: usize }, - - #[serde(rename = "read_ok")] - ReadOk { - in_reply_to: usize, - messages: HashSet<BroadcastTarget>, - }, -} - -pub struct BroadcastHandler { - node_id: String, - seen: RwLock<HashSet<BroadcastTarget>>, - broadcast_targets: RwLock<Vec<String>>, - output: Output<BroadcastBody>, -} - -impl Handler for BroadcastHandler { - type Body = BroadcastBody; - - fn init(node_id: String, mut node_ids: Vec<String>, output: Output<Self::Body>) -> Self { - node_ids.retain(|x| *x != node_id); - - BroadcastHandler { - node_id, - broadcast_targets: RwLock::new(node_ids), - seen: RwLock::new(HashSet::new()), - output, - } - } - - fn handle( - &self, - header: MessageHeader, - body: BroadcastBody, - ) -> impl Future<Output = ()> + Send { - async move { - match body { - BroadcastBody::Broadcast { msg_id, message } => { - self.seen.write().await.insert(message); - if let Some(msg_id) = msg_id { - self.output - .send( - &header.src, - &BroadcastBody::BroadcastOk { - in_reply_to: msg_id, - }, - ) - .await; - } - - for target in self.broadcast_targets.read().await.iter() { - self.output - .send( - target, - &BroadcastBody::Broadcast { - msg_id: None, - message, - }, - ) - .await; - } - } - BroadcastBody::Topology { - msg_id, - mut topology, - } => { - if let Some(broadcast_targets) = topology.remove(&self.node_id) { - *self.broadcast_targets.write().await = broadcast_targets; - } - - if let Some(msg_id) = msg_id { - self.output - .send( - &header.src, - &BroadcastBody::TopologyOk { - in_reply_to: msg_id, - }, - ) - .await; - } - } - - BroadcastBody::Read { msg_id } => { - self.output - .send( - &header.src, - &BroadcastBody::ReadOk { - in_reply_to: msg_id, - messages: self.seen.read().await.clone(), - }, - ) - .await - } - - BroadcastBody::BroadcastOk { .. } => {} - BroadcastBody::TopologyOk { .. } => {} - BroadcastBody::ReadOk { .. } => {} - } - } - } -} diff --git a/broadcast/src/msg.rs b/broadcast/src/msg.rs new file mode 100644 index 0000000..6433982 --- /dev/null +++ b/broadcast/src/msg.rs @@ -0,0 +1,36 @@ +use common::msg_id::MessageID; +use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, HashSet}; + +pub type BroadcastTarget = usize; + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(tag = "type")] +pub enum BroadcastBody { + #[serde(rename = "broadcast")] + Broadcast { + msg_id: Option<MessageID>, + message: BroadcastTarget, + }, + + #[serde(rename = "broadcast_ok")] + BroadcastOk { in_reply_to: MessageID }, + + #[serde(rename = "topology")] + Topology { + msg_id: Option<MessageID>, + topology: HashMap<String, Vec<String>>, + }, + + #[serde(rename = "topology_ok")] + TopologyOk { in_reply_to: MessageID }, + + #[serde(rename = "read")] + Read { msg_id: MessageID }, + + #[serde(rename = "read_ok")] + ReadOk { + in_reply_to: MessageID, + messages: HashSet<BroadcastTarget>, + }, +} |