summaryrefslogtreecommitdiff
path: root/broadcast/src/batch.rs
blob: d69771d3b31993988e308468cab504cca62e6f87 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
use std::{
    collections::HashSet,
    sync::Arc,
    time::{Duration, Instant},
};

use common::msg_id::MessageID;
use futures::Future;
use smol::Timer;

use crate::{
    handler::BroadcastHandler,
    msg::{BroadcastBody, BroadcastTarget},
};

/// Interval, in seconds, between repeated sends of the same batch by
/// `MessageBatch::broadcast`. NOTE(review): presumably a retry-until-acked
/// interval; the acknowledgement handling is not visible in this file.
const RETRY_TIMEOUT_SECS: u64 = 1;

/// A deduplicated set of pending broadcast messages that becomes due once the
/// oldest pending message has waited at least `max_message_delay`.
#[derive(Debug, Clone)]
pub struct MessageBatch {
    /// How long the first message of a batch may wait before the batch is due.
    max_message_delay: Duration,
    /// When the first message of the current batch was added. This value is
    /// stale while the batch is empty (`clear` does not reset it); `add`
    /// refreshes it when inserting into an empty batch.
    first_added: Instant,
    /// Pending messages; the `HashSet` collapses duplicates.
    messages: HashSet<BroadcastTarget>,
}

impl MessageBatch {
    /// Creates an empty batch that becomes due `max_message_delay` after its
    /// first message is added.
    pub fn new(max_message_delay: Duration) -> Self {
        Self {
            max_message_delay,
            first_added: Instant::now(),
            messages: HashSet::new(),
        }
    }

    /// Adds `msg` to the batch.
    ///
    /// Adding to an empty batch restarts the delay clock, so the deadline is
    /// always measured from the first message of the current batch. Duplicate
    /// messages are collapsed by the underlying `HashSet`.
    pub fn add(&mut self, msg: BroadcastTarget) {
        if self.messages.is_empty() {
            self.first_added = Instant::now();
        }
        self.messages.insert(msg);
    }

    /// Removes all pending messages. The next [`add`](Self::add) restarts the
    /// delay clock.
    pub fn clear(&mut self) {
        self.messages.clear();
    }

    /// Returns `true` when the batch is non-empty and its first message has
    /// waited at least `max_message_delay`.
    pub fn should_broadcast(&self) -> bool {
        !self.messages.is_empty() && self.first_added.elapsed() >= self.max_message_delay
    }

    /// Returns how long to wait before the batch becomes due.
    ///
    /// For a non-empty batch this is the time remaining until
    /// `max_message_delay` has elapsed since the first message was added, or
    /// zero once that deadline has passed. An empty batch has no deadline yet
    /// (`first_added` may be stale from a previous batch), so the full
    /// `max_message_delay` is returned instead of a misleading zero.
    pub fn sleep_time(&self) -> Duration {
        if self.messages.is_empty() {
            return self.max_message_delay;
        }
        // Remaining time = delay - elapsed. Saturating so an overdue batch
        // yields zero rather than panicking on underflow.
        self.max_message_delay
            .saturating_sub(self.first_added.elapsed())
    }

    /// Returns a future that sends the batch's current contents to `dst` as a
    /// `BroadcastBody::BroadcastBatch` tagged with `msg_id`. It then resends
    /// the same body every `RETRY_TIMEOUT_SECS` seconds.
    ///
    /// The contents are snapshotted when this method is called, so later
    /// `add`/`clear` calls do not change what this future sends. The future
    /// never completes on its own. NOTE(review): the caller presumably drops
    /// or cancels it once the batch is acknowledged; confirm at the call site.
    pub fn broadcast(
        &self,
        this: Arc<BroadcastHandler>,
        dst: String,
        msg_id: MessageID,
    ) -> impl Future<Output = ()> + 'static {
        // Every retry sends an identical body, so build it once up front
        // instead of re-cloning and re-collecting the set on each iteration.
        let body = BroadcastBody::BroadcastBatch {
            msg_id: Some(msg_id),
            messages: self.messages.iter().cloned().collect(),
        };
        async move {
            loop {
                this.output.send(&dst, &body).await;

                Timer::after(Duration::from_secs(RETRY_TIMEOUT_SECS)).await;
            }
        }
    }
}