summaryrefslogtreecommitdiff
path: root/crates/windlass/src/transport/mod.rs
diff options
context:
space:
mode:
authortcmal <me@aria.rip>2024-09-17 16:12:31 +0100
committertcmal <me@aria.rip>2024-09-30 23:17:34 +0100
commitfd82733126ee82b085875c44a0993534968afad3 (patch)
tree73ab58288c4da05e51c60a2e27cb62e33f2111ad /crates/windlass/src/transport/mod.rs
parent5f9fbea5a1b08962887d16457e77a922919a8818 (diff)
Better docs / cleanup for windlass
Diffstat (limited to 'crates/windlass/src/transport/mod.rs')
-rw-r--r--crates/windlass/src/transport/mod.rs434
1 files changed, 434 insertions, 0 deletions
diff --git a/crates/windlass/src/transport/mod.rs b/crates/windlass/src/transport/mod.rs
new file mode 100644
index 0000000..f5c5fc3
--- /dev/null
+++ b/crates/windlass/src/transport/mod.rs
@@ -0,0 +1,434 @@
+use read::{FrameReader, ReceivedFrame, ReceiverError};
+use std::{collections::VecDeque, sync::Arc, time::Duration};
+use tokio::{
+ io::{AsyncBufRead, AsyncWrite, AsyncWriteExt},
+ pin, select, spawn,
+ sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
+ task::JoinHandle,
+ time::{sleep_until, Instant},
+};
+use tokio_util::sync::CancellationToken;
+
+use crate::encoding::crc16;
+
+pub(crate) const MESSAGE_HEADER_SIZE: usize = 2;
+pub(crate) const MESSAGE_TRAILER_SIZE: usize = 3;
+pub(crate) const MESSAGE_LENGTH_MIN: usize = MESSAGE_HEADER_SIZE + MESSAGE_TRAILER_SIZE;
+pub(crate) const MESSAGE_LENGTH_MAX: usize = 64;
+pub(crate) const MESSAGE_LENGTH_PAYLOAD_MAX: usize = MESSAGE_LENGTH_MAX - MESSAGE_LENGTH_MIN;
+pub(crate) const MESSAGE_POSITION_SEQ: usize = 1;
+pub(crate) const MESSAGE_TRAILER_CRC: usize = 3;
+pub(crate) const MESSAGE_VALUE_SYNC: u8 = 0x7E;
+pub(crate) const MESSAGE_DEST: u8 = 0x10;
+pub(crate) const MESSAGE_SEQ_MASK: u8 = 0x0F;
+
+mod read;
+
+/// Wrapper around a connection to a klipper firmware MCU, which deals with
+/// retransmission, flow control, etc.
+///
+/// Internally, this holds a handle to an async task and a channel to communicate with it,
+/// meaning operations don't necessarily block the current task.
+#[derive(Debug)]
+pub struct Transport {
+ // Handles to the associated async task
+ task_inner: JoinHandle<()>,
+
+ /// Queue for outbound messages
+ cmd_send: UnboundedSender<TransportCommand>,
+}
+
+/// A message sent to the transport task
+#[derive(Debug)]
+enum TransportCommand {
+ SendMessage(Vec<u8>),
+ Exit,
+}
+
+pub(crate) type TransportReceiver = UnboundedReceiver<Result<Vec<u8>, TransportError>>;
+
+impl Transport {
+ pub(crate) async fn connect(
+ rdr: impl AsyncBufRead + Unpin + Send + 'static,
+ wr: impl AsyncWrite + Unpin + Send + 'static,
+ ) -> (Transport, TransportReceiver) {
+ let (data_send, data_recv) = unbounded_channel();
+ let (cmd_send, cmd_recv) = unbounded_channel();
+
+ let cancel_token = CancellationToken::new();
+ let task_inner = spawn(async move {
+ let mut ts = TransportState::new(rdr, wr, data_send, cmd_recv, cancel_token);
+ if let Err(e) = ts.run().await {
+ let _ = ts.data_send.send(Err(e));
+ }
+ });
+
+ (
+ Transport {
+ task_inner,
+ cmd_send,
+ },
+ data_recv,
+ )
+ }
+
+ pub(crate) fn send(&self, msg: &[u8]) -> Result<(), TransmitterError> {
+ self.cmd_send
+ .send(TransportCommand::SendMessage(msg.into()))
+ .map_err(|_| TransmitterError::ConnectionClosed)
+ }
+
+ pub(crate) async fn close(self) {
+ let _ = self.cmd_send.send(TransportCommand::Exit);
+ let _ = self.task_inner.await;
+ }
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum TransportError {
+ #[error("message encoding failed: {0}")]
+ MessageEncode(#[from] MessageEncodeError),
+
+ #[error("receiver error: {0}")]
+ Receiver(#[from] ReceiverError),
+
+ #[error("transmitter error: {0}")]
+ Transmitter(#[from] TransmitterError),
+
+ #[error("io error: {0}")]
+ IOError(#[from] std::io::Error),
+}
+
+const MIN_RTO: f32 = 0.025;
+const MAX_RTO: f32 = 5.000;
+
+/// State for estimating the round trip time of the connection
+#[derive(Debug)]
+struct RttState {
+ srtt: f32,
+ rttvar: f32,
+ rto: f32,
+}
+
+impl Default for RttState {
+ fn default() -> Self {
+ Self {
+ srtt: 0.0,
+ rttvar: 0.0,
+ rto: MIN_RTO,
+ }
+ }
+}
+
+impl RttState {
+ /// Get the current recommended retransmission timeout
+ fn rto(&self) -> Duration {
+ Duration::from_secs_f32(self.rto)
+ }
+
+ /// Update the RTT estimation given a new observation
+ fn update(&mut self, rtt: Duration) {
+ let r = rtt.as_secs_f32();
+ if self.srtt == 0.0 {
+ self.rttvar = r / 2.0;
+ self.srtt = r * 10.0; // Klipper uses this, we'll copy it
+ } else {
+ self.rttvar = (3.0 * self.rttvar + (self.srtt - r).abs()) / 4.0;
+ self.srtt = (7.0 * self.srtt + r) / 8.0;
+ }
+ let rttvar4 = (self.rttvar * 4.0).max(0.001);
+ self.rto = (self.srtt + rttvar4).clamp(MIN_RTO, MAX_RTO);
+ }
+}
+
+/// State for the task which deals with transport state
+#[derive(Debug)]
+struct TransportState<R, W> {
+ rdr: FrameReader<R>,
+ wr: W,
+
+ data_send: UnboundedSender<Result<Vec<u8>, TransportError>>,
+ cmd_recv: UnboundedReceiver<TransportCommand>,
+
+ cancel: CancellationToken,
+
+ is_synchronized: bool,
+ rtt_state: RttState,
+ receive_sequence: u64,
+ send_sequence: u64,
+ last_ack_sequence: u64,
+ ignore_nak_seq: u64,
+ retransmit_seq: u64,
+ retransmit_now: bool,
+
+ corked_until: Option<Instant>,
+
+ inflight_messages: VecDeque<SentFrame>,
+ pending_messages: VecDeque<Vec<u8>>,
+}
+
+impl<R: AsyncBufRead + Unpin, W: AsyncWrite + Unpin> TransportState<R, W> {
+ fn new(
+ rdr: R,
+ wr: W,
+ data_send: UnboundedSender<Result<Vec<u8>, TransportError>>,
+ cmd_recv: UnboundedReceiver<TransportCommand>,
+ cancel: CancellationToken,
+ ) -> Self {
+ Self {
+ rdr: FrameReader::new(rdr),
+ wr,
+ data_send,
+ cmd_recv,
+ cancel,
+
+ is_synchronized: false,
+ rtt_state: RttState::default(),
+ receive_sequence: 1,
+ send_sequence: 1,
+ last_ack_sequence: 0,
+ ignore_nak_seq: 0,
+ retransmit_seq: 0,
+ retransmit_now: false,
+
+ corked_until: None,
+
+ inflight_messages: VecDeque::new(),
+ pending_messages: VecDeque::new(),
+ }
+ }
+
+ async fn run(&mut self) -> Result<(), TransportError> {
+ loop {
+ if self.retransmit_now {
+ self.retransmit_pending().await?;
+ }
+
+ if !self.pending_messages.is_empty() && self.can_send() {
+ self.send_more_frames().await?;
+ }
+
+ let retransmit_deadline = self
+ .inflight_messages
+ .front()
+ .map(|msg| msg.sent_at + self.rtt_state.rto());
+ let retransmit_timeout: futures::future::OptionFuture<_> =
+ retransmit_deadline.map(sleep_until).into();
+ pin!(retransmit_timeout);
+
+ let corked_timeout: futures::future::OptionFuture<_> =
+ self.corked_until.map(sleep_until).into();
+ pin!(corked_timeout);
+
+ // FIXME: This is not correct because read_frame is not cancellation safe
+ select! {
+ frame = self.rdr.read_frame() => {
+ let frame = frame?;
+ let frame = match frame {
+ Some(frame) => frame,
+ None => break,
+ };
+ self.handle_frame(frame);
+ },
+
+ msg = self.cmd_recv.recv() => {
+ match msg {
+ Some(TransportCommand::SendMessage(msg)) => {
+ self.pending_messages.push_back(msg);
+ },
+ Some(TransportCommand::Exit) => {
+ self.cancel.cancel();
+ }
+ None => break,
+ };
+ },
+
+ _ = &mut retransmit_timeout, if retransmit_deadline.is_some() => {
+ self.retransmit_now = true;
+ },
+
+ _ = &mut corked_timeout, if self.corked_until.is_some() => {
+ // Timeout for when we are able to send again
+ }
+
+ _ = self.cancel.cancelled() => {
+ break;
+ },
+ }
+ }
+ Ok(())
+ }
+
+ /// Handle an incoming frame, by updating sequence numbers and sending the data upwards if needed
+ fn handle_frame(&mut self, frame: ReceivedFrame) {
+ let rseq = self.receive_sequence;
+
+ // wrap-around logic(?)
+ let mut sequence = (rseq & !(MESSAGE_SEQ_MASK as u64)) | (frame.sequence as u64);
+ if sequence < rseq {
+ sequence += (MESSAGE_SEQ_MASK as u64) + 1;
+ }
+
+ // Frame acknowledges some messages
+ if !self.is_synchronized || sequence != rseq {
+ if sequence > self.send_sequence && self.is_synchronized {
+ // Ack for unsent message - weird, but ignore and try to continue
+ return;
+ }
+
+ self.update_receive_seq(frame.receive_time, sequence);
+ }
+
+ if !frame.payload.is_empty() {
+ // Data message, we deliver this directly to the application as the MCU can't actually
+ // retransmit anyway.
+ // TODO: Maybe check the CRC anyway so we can discard it here
+ let _ = self.data_send.send(Ok(frame.payload));
+ } else if sequence > self.last_ack_sequence {
+ // ACK
+ self.last_ack_sequence = sequence;
+ } else if sequence > self.ignore_nak_seq && !self.inflight_messages.is_empty() {
+ // NAK
+ self.retransmit_now = true;
+ }
+ }
+
+ /// Update the last received sequence number, removing acknowledged messages from `self.inflight_messages`
+ fn update_receive_seq(&mut self, receive_time: Instant, sequence: u64) {
+ let mut sent_seq = self.receive_sequence;
+
+ // Discard messages from inflight_messages up to sequence
+ loop {
+ if let Some(msg) = self.inflight_messages.pop_front() {
+ sent_seq += 1;
+ if sequence == sent_seq {
+ // Found the matching sent message
+ if !msg.is_retransmit {
+ let elapsed = receive_time.saturating_duration_since(msg.sent_at);
+ self.rtt_state.update(elapsed);
+ }
+ break;
+ }
+ } else {
+ // Ack with no outstanding messages, happens during connection init
+ self.send_sequence = sequence;
+ break;
+ }
+ }
+
+ self.receive_sequence = sequence;
+ self.is_synchronized = true;
+ }
+
+ fn can_send(&self) -> bool {
+ self.corked_until.is_none() && self.inflight_messages.len() < 12
+ }
+
+ /// Send as many more frames as possible from [`self.pending_messages`]
+ async fn send_more_frames(&mut self) -> Result<(), TransportError> {
+ while self.can_send() && !self.pending_messages.is_empty() {
+ self.send_new_frame().await?;
+ }
+
+ Ok(())
+ }
+
+ /// Send a single new frame from [`self.pending_messages`]
+ async fn send_new_frame(&mut self) -> Result<(), TransportError> {
+ let mut buf = Vec::new();
+ while let Some(next) = self.pending_messages.front() {
+ if !buf.is_empty() && buf.len() + next.len() <= MESSAGE_LENGTH_PAYLOAD_MAX {
+ // Add to the end of the frame. Unwrap is safe because we already peeked.
+ let mut next = self.pending_messages.pop_front().unwrap();
+ buf.append(&mut next);
+ } else {
+ break;
+ }
+ }
+
+ let frame = Arc::new(encode_frame(self.send_sequence, &buf)?);
+ self.send_sequence += 1;
+ self.inflight_messages.push_back(SentFrame {
+ sent_at: Instant::now(),
+ sequence: self.send_sequence,
+ payload: frame.clone(),
+ is_retransmit: false,
+ });
+ self.wr.write_all(&frame).await?;
+
+ Ok(())
+ }
+
+ /// Retransmit all inflight messages
+ async fn retransmit_pending(&mut self) -> Result<(), TransportError> {
+ let len: usize = self
+ .inflight_messages
+ .iter()
+ .map(|msg| msg.payload.len())
+ .sum();
+ let mut buf = Vec::with_capacity(1 + len);
+ buf.push(MESSAGE_VALUE_SYNC);
+ let now = Instant::now();
+ for msg in self.inflight_messages.iter_mut() {
+ buf.extend_from_slice(&msg.payload);
+ msg.is_retransmit = true;
+ msg.sent_at = now;
+ }
+ self.wr.write_all(&buf).await?;
+
+ if self.retransmit_now {
+ self.ignore_nak_seq = self.receive_sequence;
+ if self.receive_sequence < self.retransmit_seq {
+ self.ignore_nak_seq = self.retransmit_seq;
+ }
+ self.retransmit_now = false;
+ } else {
+ self.rtt_state.rto = (self.rtt_state.rto * 2.0).clamp(MIN_RTO, MAX_RTO);
+ self.ignore_nak_seq = self.send_sequence;
+ }
+ self.retransmit_seq = self.send_sequence;
+
+ Ok(())
+ }
+}
+
+/// An error encountered when transmitting a message
+#[derive(thiserror::Error, Debug)]
+pub enum TransmitterError {
+ #[error("io error: {0}")]
+ IoError(#[from] std::io::Error),
+
+ #[error("connection closed")]
+ ConnectionClosed,
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct SentFrame {
+ pub sent_at: Instant,
+ #[allow(dead_code)]
+ pub sequence: u64,
+ pub payload: Arc<Vec<u8>>,
+ pub is_retransmit: bool,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum MessageEncodeError {
+ #[error("message would exceed the maximum packet length of {MESSAGE_LENGTH_MAX} bytes")]
+ MessageTooLong,
+}
+
+fn encode_frame(sequence: u64, payload: &[u8]) -> Result<Vec<u8>, MessageEncodeError> {
+ let len = MESSAGE_LENGTH_MIN + payload.len();
+ if len > MESSAGE_LENGTH_MAX {
+ return Err(MessageEncodeError::MessageTooLong);
+ }
+ let mut buf = Vec::with_capacity(len);
+ buf.push(len as u8);
+ buf.push(MESSAGE_DEST | ((sequence as u8) & MESSAGE_SEQ_MASK));
+ buf.extend_from_slice(payload);
+ let crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]);
+ buf.push(((crc >> 8) & 0xFF) as u8);
+ buf.push((crc & 0xFF) as u8);
+ buf.push(MESSAGE_VALUE_SYNC);
+ Ok(buf)
+}