summaryrefslogtreecommitdiff
path: root/broadcast
diff options
context:
space:
mode:
Diffstat (limited to 'broadcast')
-rw-r--r--broadcast/Cargo.toml3
-rw-r--r--broadcast/src/handler.rs167
-rw-r--r--broadcast/src/main.rs136
-rw-r--r--broadcast/src/msg.rs36
4 files changed, 209 insertions, 133 deletions
diff --git a/broadcast/Cargo.toml b/broadcast/Cargo.toml
index f3e4cc7..06e4426 100644
--- a/broadcast/Cargo.toml
+++ b/broadcast/Cargo.toml
@@ -9,4 +9,5 @@ edition = "2021"
smol = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
-common = { path = "../common/" } \ No newline at end of file
+common = { path = "../common/" }
+futures = { version = "0.3.28", default_features = false, features = ["std"] } \ No newline at end of file
diff --git a/broadcast/src/handler.rs b/broadcast/src/handler.rs
new file mode 100644
index 0000000..09c66b9
--- /dev/null
+++ b/broadcast/src/handler.rs
@@ -0,0 +1,167 @@
+use smol::{
+ lock::{Mutex, RwLock},
+ prelude::*,
+ Task, Timer,
+};
+use std::{
+ collections::{HashMap, HashSet},
+ sync::Arc,
+ time::Duration,
+};
+
+use crate::msg::*;
+
+use common::{
+ msg::{MessageHeader, Output},
+ msg_id::{gen_msg_id, MessageID},
+ Handler,
+};
+
+const RETRY_TIMEOUT: u64 = 2;
+
+pub struct BroadcastHandler {
+ node_id: String,
+ seen: RwLock<HashSet<BroadcastTarget>>,
+ broadcast_targets: RwLock<Vec<String>>,
+ output: Output<BroadcastBody>,
+ attempted_broadcasts: Mutex<HashMap<MessageID, Task<()>>>,
+}
+
+impl Handler for BroadcastHandler {
+ type Body = BroadcastBody;
+
+ fn init(node_id: String, mut node_ids: Vec<String>, output: Output<Self::Body>) -> Self {
+ node_ids.retain(|x| *x != node_id);
+
+ BroadcastHandler {
+ node_id,
+ broadcast_targets: RwLock::new(node_ids),
+ seen: RwLock::new(HashSet::new()),
+ output,
+ attempted_broadcasts: Mutex::default(),
+ }
+ }
+
+ fn handle<'a>(
+ self: &'a Arc<Self>,
+ header: MessageHeader,
+ body: BroadcastBody,
+ ) -> impl Future<Output = ()> + Send + 'a {
+ async move {
+ match body {
+ BroadcastBody::Broadcast {
+ msg_id: Some(msg_id),
+ message,
+ } => {
+ self.receive_broadcast(&header.src, message).await;
+ self.send_broadcast_ok(&header.src, msg_id).await;
+ }
+ BroadcastBody::Broadcast {
+ msg_id: None,
+ message,
+ } => {
+ self.receive_broadcast(&header.src, message).await;
+ }
+
+ BroadcastBody::Topology {
+ msg_id,
+ mut topology,
+ } => {
+ // Start using the new topology
+ if let Some(broadcast_targets) = topology.remove(&self.node_id) {
+ *self.broadcast_targets.write().await = broadcast_targets;
+ }
+
+ // Send reply if needed
+ if let Some(msg_id) = msg_id {
+ self.output
+ .send(
+ &header.src,
+ &BroadcastBody::TopologyOk {
+ in_reply_to: msg_id,
+ },
+ )
+ .await;
+ }
+ }
+
+ BroadcastBody::Read { msg_id } => {
+ // Send all received messages back
+ self.output
+ .send(
+ &header.src,
+ &BroadcastBody::ReadOk {
+ in_reply_to: msg_id,
+ messages: self.seen.read().await.clone(),
+ },
+ )
+ .await
+ }
+
+ // Ignore OK messages - we never actually request them
+ BroadcastBody::BroadcastOk { in_reply_to } => {
+ if let Some(task) = self.attempted_broadcasts.lock().await.remove(&in_reply_to)
+ {
+ task.cancel().await;
+ }
+ }
+ BroadcastBody::TopologyOk { .. } => {}
+ BroadcastBody::ReadOk { .. } => {}
+ }
+ }
+ }
+}
+
+impl BroadcastHandler {
+ /// Reply with a broadcast OK message
+ async fn send_broadcast_ok(&self, src: &str, msg_id: MessageID) {
+ self.output
+ .send(
+ &src,
+ &BroadcastBody::BroadcastOk {
+ in_reply_to: msg_id,
+ },
+ )
+ .await;
+ }
+
+ /// Receive a given message, and broadcast it onwards if it is new
+ async fn receive_broadcast(self: &Arc<Self>, src: &str, message: BroadcastTarget) {
+ let new = self.seen.write().await.insert(message);
+ if !new {
+ return;
+ }
+
+ // Ensure we don't keep holding the read lock
+ let targets = self.broadcast_targets.read().await.clone();
+
+ // Race all send futures
+ let mut tasks = self.attempted_broadcasts.lock().await;
+ for target in targets.into_iter() {
+ if &target == &src {
+ return;
+ }
+
+ let msg_id = gen_msg_id();
+ let this = self.clone();
+ tasks.insert(
+ msg_id,
+ smol::spawn(async move {
+ loop {
+ this.output
+ .send(
+ &target,
+ &BroadcastBody::Broadcast {
+ msg_id: None,
+ message,
+ },
+ )
+ .await;
+
+ Timer::after(Duration::from_secs(RETRY_TIMEOUT)).await;
+ }
+ }),
+ );
+ }
+ }
+}
diff --git a/broadcast/src/main.rs b/broadcast/src/main.rs
index 09c0268..bdc8413 100644
--- a/broadcast/src/main.rs
+++ b/broadcast/src/main.rs
@@ -1,140 +1,12 @@
#![feature(return_position_impl_trait_in_trait)]
-use smol::{lock::RwLock, prelude::*};
-use std::collections::{HashMap, HashSet};
+use common::run_server;
-use common::{
- msg::{MessageHeader, Output},
- run_server, Handler,
-};
-use serde::{Deserialize, Serialize};
+mod handler;
+mod msg;
-type BroadcastTarget = usize;
+use handler::BroadcastHandler;
fn main() {
run_server::<BroadcastHandler>();
}
-
-#[derive(Debug, Deserialize, Serialize, Clone)]
-#[serde(tag = "type")]
-pub enum BroadcastBody {
- #[serde(rename = "broadcast")]
- Broadcast {
- msg_id: Option<usize>,
- message: BroadcastTarget,
- },
-
- #[serde(rename = "broadcast_ok")]
- BroadcastOk { in_reply_to: usize },
-
- #[serde(rename = "topology")]
- Topology {
- msg_id: Option<usize>,
- topology: HashMap<String, Vec<String>>,
- },
-
- #[serde(rename = "topology_ok")]
- TopologyOk { in_reply_to: usize },
-
- #[serde(rename = "read")]
- Read { msg_id: usize },
-
- #[serde(rename = "read_ok")]
- ReadOk {
- in_reply_to: usize,
- messages: HashSet<BroadcastTarget>,
- },
-}
-
-pub struct BroadcastHandler {
- node_id: String,
- seen: RwLock<HashSet<BroadcastTarget>>,
- broadcast_targets: RwLock<Vec<String>>,
- output: Output<BroadcastBody>,
-}
-
-impl Handler for BroadcastHandler {
- type Body = BroadcastBody;
-
- fn init(node_id: String, mut node_ids: Vec<String>, output: Output<Self::Body>) -> Self {
- node_ids.retain(|x| *x != node_id);
-
- BroadcastHandler {
- node_id,
- broadcast_targets: RwLock::new(node_ids),
- seen: RwLock::new(HashSet::new()),
- output,
- }
- }
-
- fn handle(
- &self,
- header: MessageHeader,
- body: BroadcastBody,
- ) -> impl Future<Output = ()> + Send {
- async move {
- match body {
- BroadcastBody::Broadcast { msg_id, message } => {
- self.seen.write().await.insert(message);
- if let Some(msg_id) = msg_id {
- self.output
- .send(
- &header.src,
- &BroadcastBody::BroadcastOk {
- in_reply_to: msg_id,
- },
- )
- .await;
- }
-
- for target in self.broadcast_targets.read().await.iter() {
- self.output
- .send(
- target,
- &BroadcastBody::Broadcast {
- msg_id: None,
- message,
- },
- )
- .await;
- }
- }
- BroadcastBody::Topology {
- msg_id,
- mut topology,
- } => {
- if let Some(broadcast_targets) = topology.remove(&self.node_id) {
- *self.broadcast_targets.write().await = broadcast_targets;
- }
-
- if let Some(msg_id) = msg_id {
- self.output
- .send(
- &header.src,
- &BroadcastBody::TopologyOk {
- in_reply_to: msg_id,
- },
- )
- .await;
- }
- }
-
- BroadcastBody::Read { msg_id } => {
- self.output
- .send(
- &header.src,
- &BroadcastBody::ReadOk {
- in_reply_to: msg_id,
- messages: self.seen.read().await.clone(),
- },
- )
- .await
- }
-
- BroadcastBody::BroadcastOk { .. } => {}
- BroadcastBody::TopologyOk { .. } => {}
- BroadcastBody::ReadOk { .. } => {}
- }
- }
- }
-}
diff --git a/broadcast/src/msg.rs b/broadcast/src/msg.rs
new file mode 100644
index 0000000..6433982
--- /dev/null
+++ b/broadcast/src/msg.rs
@@ -0,0 +1,36 @@
+use common::msg_id::MessageID;
+use serde::{Deserialize, Serialize};
+use std::collections::{HashMap, HashSet};
+
+pub type BroadcastTarget = usize;
+
+#[derive(Debug, Deserialize, Serialize, Clone)]
+#[serde(tag = "type")]
+pub enum BroadcastBody {
+ #[serde(rename = "broadcast")]
+ Broadcast {
+ msg_id: Option<MessageID>,
+ message: BroadcastTarget,
+ },
+
+ #[serde(rename = "broadcast_ok")]
+ BroadcastOk { in_reply_to: MessageID },
+
+ #[serde(rename = "topology")]
+ Topology {
+ msg_id: Option<MessageID>,
+ topology: HashMap<String, Vec<String>>,
+ },
+
+ #[serde(rename = "topology_ok")]
+ TopologyOk { in_reply_to: MessageID },
+
+ #[serde(rename = "read")]
+ Read { msg_id: MessageID },
+
+ #[serde(rename = "read_ok")]
+ ReadOk {
+ in_reply_to: MessageID,
+ messages: HashSet<BroadcastTarget>,
+ },
+}