diff options
author | tcmal <me@aria.rip> | 2024-09-17 16:12:31 +0100 |
---|---|---|
committer | tcmal <me@aria.rip> | 2024-09-30 23:17:34 +0100 |
commit | fd82733126ee82b085875c44a0993534968afad3 (patch) | |
tree | 73ab58288c4da05e51c60a2e27cb62e33f2111ad /crates/windlass/src/transport/mod.rs | |
parent | 5f9fbea5a1b08962887d16457e77a922919a8818 (diff) |
Better docs / cleanup for windlass
Diffstat (limited to 'crates/windlass/src/transport/mod.rs')
-rw-r--r-- | crates/windlass/src/transport/mod.rs | 434 |
1 files changed, 434 insertions, 0 deletions
diff --git a/crates/windlass/src/transport/mod.rs b/crates/windlass/src/transport/mod.rs new file mode 100644 index 0000000..f5c5fc3 --- /dev/null +++ b/crates/windlass/src/transport/mod.rs @@ -0,0 +1,434 @@ +use read::{FrameReader, ReceivedFrame, ReceiverError}; +use std::{collections::VecDeque, sync::Arc, time::Duration}; +use tokio::{ + io::{AsyncBufRead, AsyncWrite, AsyncWriteExt}, + pin, select, spawn, + sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + task::JoinHandle, + time::{sleep_until, Instant}, +}; +use tokio_util::sync::CancellationToken; + +use crate::encoding::crc16; + +pub(crate) const MESSAGE_HEADER_SIZE: usize = 2; +pub(crate) const MESSAGE_TRAILER_SIZE: usize = 3; +pub(crate) const MESSAGE_LENGTH_MIN: usize = MESSAGE_HEADER_SIZE + MESSAGE_TRAILER_SIZE; +pub(crate) const MESSAGE_LENGTH_MAX: usize = 64; +pub(crate) const MESSAGE_LENGTH_PAYLOAD_MAX: usize = MESSAGE_LENGTH_MAX - MESSAGE_LENGTH_MIN; +pub(crate) const MESSAGE_POSITION_SEQ: usize = 1; +pub(crate) const MESSAGE_TRAILER_CRC: usize = 3; +pub(crate) const MESSAGE_VALUE_SYNC: u8 = 0x7E; +pub(crate) const MESSAGE_DEST: u8 = 0x10; +pub(crate) const MESSAGE_SEQ_MASK: u8 = 0x0F; + +mod read; + +/// Wrapper around a connection to a klipper firmware MCU, which deals with +/// retransmission, flow control, etc. +/// +/// Internally, this holds a handle to an async task and a channel to communicate with it, +/// meaning operations don't necessarily block the current task. +#[derive(Debug)] +pub struct Transport { + // Handles to the associated async task + task_inner: JoinHandle<()>, + + /// Queue for outbound messages + cmd_send: UnboundedSender<TransportCommand>, +} + +/// A message sent to the transport task +#[derive(Debug)] +enum TransportCommand { + SendMessage(Vec<u8>), + Exit, +} + +pub(crate) type TransportReceiver = UnboundedReceiver<Result<Vec<u8>, TransportError>>; + +impl Transport { + pub(crate) async fn connect( + rdr: impl AsyncBufRead + Unpin + Send + 'static, + wr: impl AsyncWrite + Unpin + Send + 'static, + ) -> (Transport, TransportReceiver) { + let (data_send, data_recv) = unbounded_channel(); + let (cmd_send, cmd_recv) = unbounded_channel(); + + let cancel_token = CancellationToken::new(); + let task_inner = spawn(async move { + let mut ts = TransportState::new(rdr, wr, data_send, cmd_recv, cancel_token); + if let Err(e) = ts.run().await { + let _ = ts.data_send.send(Err(e)); + } + }); + + ( + Transport { + task_inner, + cmd_send, + }, + data_recv, + ) + } + + pub(crate) fn send(&self, msg: &[u8]) -> Result<(), TransmitterError> { + self.cmd_send + .send(TransportCommand::SendMessage(msg.into())) + .map_err(|_| TransmitterError::ConnectionClosed) + } + + pub(crate) async fn close(self) { + let _ = self.cmd_send.send(TransportCommand::Exit); + let _ = self.task_inner.await; + } +} + +#[derive(thiserror::Error, Debug)] +pub enum TransportError { + #[error("message encoding failed: {0}")] + MessageEncode(#[from] MessageEncodeError), + + #[error("receiver error: {0}")] + Receiver(#[from] ReceiverError), + + #[error("transmitter error: {0}")] + Transmitter(#[from] TransmitterError), + + #[error("io error: {0}")] + IOError(#[from] std::io::Error), +} + +const MIN_RTO: f32 = 0.025; +const MAX_RTO: f32 = 5.000; + +/// State for estimating the round trip time of the connection +#[derive(Debug)] +struct RttState { + srtt: f32, + rttvar: f32, + rto: f32, +} + +impl Default for RttState { + fn default() -> Self { + Self { + srtt: 0.0, + rttvar: 0.0, + rto: MIN_RTO, + } + } +} + +impl RttState { + /// Get the current recommended retransmission timeout + fn rto(&self) -> Duration { + Duration::from_secs_f32(self.rto) + } + + /// Update the RTT estimation given a new observation + fn update(&mut self, rtt: Duration) { + let r = rtt.as_secs_f32(); + if self.srtt == 0.0 { + self.rttvar = r / 2.0; + self.srtt = r * 10.0; // Klipper uses this, we'll copy it + } else { + self.rttvar = (3.0 * self.rttvar + (self.srtt - r).abs()) / 4.0; + self.srtt = (7.0 * self.srtt + r) / 8.0; + } + let rttvar4 = (self.rttvar * 4.0).max(0.001); + self.rto = (self.srtt + rttvar4).clamp(MIN_RTO, MAX_RTO); + } +} + +/// State for the task which deals with transport state +#[derive(Debug)] +struct TransportState<R, W> { + rdr: FrameReader<R>, + wr: W, + + data_send: UnboundedSender<Result<Vec<u8>, TransportError>>, + cmd_recv: UnboundedReceiver<TransportCommand>, + + cancel: CancellationToken, + + is_synchronized: bool, + rtt_state: RttState, + receive_sequence: u64, + send_sequence: u64, + last_ack_sequence: u64, + ignore_nak_seq: u64, + retransmit_seq: u64, + retransmit_now: bool, + + corked_until: Option<Instant>, + + inflight_messages: VecDeque<SentFrame>, + pending_messages: VecDeque<Vec<u8>>, +} + +impl<R: AsyncBufRead + Unpin, W: AsyncWrite + Unpin> TransportState<R, W> { + fn new( + rdr: R, + wr: W, + data_send: UnboundedSender<Result<Vec<u8>, TransportError>>, + cmd_recv: UnboundedReceiver<TransportCommand>, + cancel: CancellationToken, + ) -> Self { + Self { + rdr: FrameReader::new(rdr), + wr, + data_send, + cmd_recv, + cancel, + + is_synchronized: false, + rtt_state: RttState::default(), + receive_sequence: 1, + send_sequence: 1, + last_ack_sequence: 0, + ignore_nak_seq: 0, + retransmit_seq: 0, + retransmit_now: false, + + corked_until: None, + + inflight_messages: VecDeque::new(), + pending_messages: VecDeque::new(), + } + } + + async fn run(&mut self) -> Result<(), TransportError> { + loop { + if self.retransmit_now { + self.retransmit_pending().await?; + } + + if !self.pending_messages.is_empty() && self.can_send() { + self.send_more_frames().await?; + } + + let retransmit_deadline = self + .inflight_messages + .front() + .map(|msg| msg.sent_at + self.rtt_state.rto()); + let retransmit_timeout: futures::future::OptionFuture<_> = + retransmit_deadline.map(sleep_until).into(); + pin!(retransmit_timeout); + + let corked_timeout: futures::future::OptionFuture<_> = + self.corked_until.map(sleep_until).into(); + pin!(corked_timeout); + + // FIXME: This is not correct because read_frame is not cancellation safe + select! { + frame = self.rdr.read_frame() => { + let frame = frame?; + let frame = match frame { + Some(frame) => frame, + None => break, + }; + self.handle_frame(frame); + }, + + msg = self.cmd_recv.recv() => { + match msg { + Some(TransportCommand::SendMessage(msg)) => { + self.pending_messages.push_back(msg); + }, + Some(TransportCommand::Exit) => { + self.cancel.cancel(); + } + None => break, + }; + }, + + _ = &mut retransmit_timeout, if retransmit_deadline.is_some() => { + self.retransmit_now = true; + }, + + _ = &mut corked_timeout, if self.corked_until.is_some() => { + // Timeout for when we are able to send again + } + + _ = self.cancel.cancelled() => { + break; + }, + } + } + Ok(()) + } + + /// Handle an incoming frame, by updating sequence numbers and sending the data upwards if needed + fn handle_frame(&mut self, frame: ReceivedFrame) { + let rseq = self.receive_sequence; + + // wrap-around logic(?) + let mut sequence = (rseq & !(MESSAGE_SEQ_MASK as u64)) | (frame.sequence as u64); + if sequence < rseq { + sequence += (MESSAGE_SEQ_MASK as u64) + 1; + } + + // Frame acknowledges some messages + if !self.is_synchronized || sequence != rseq { + if sequence > self.send_sequence && self.is_synchronized { + // Ack for unsent message - weird, but ignore and try to continue + return; + } + + self.update_receive_seq(frame.receive_time, sequence); + } + + if !frame.payload.is_empty() { + // Data message, we deliver this directly to the application as the MCU can't actually + // retransmit anyway. + // TODO: Maybe check the CRC anyway so we can discard it here + let _ = self.data_send.send(Ok(frame.payload)); + } else if sequence > self.last_ack_sequence { + // ACK + self.last_ack_sequence = sequence; + } else if sequence > self.ignore_nak_seq && !self.inflight_messages.is_empty() { + // NAK + self.retransmit_now = true; + } + } + + /// Update the last received sequence number, removing acknowledged messages from `self.inflight_messages` + fn update_receive_seq(&mut self, receive_time: Instant, sequence: u64) { + let mut sent_seq = self.receive_sequence; + + // Discard messages from inflight_messages up to sequence + loop { + if let Some(msg) = self.inflight_messages.pop_front() { + sent_seq += 1; + if sequence == sent_seq { + // Found the matching sent message + if !msg.is_retransmit { + let elapsed = receive_time.saturating_duration_since(msg.sent_at); + self.rtt_state.update(elapsed); + } + break; + } + } else { + // Ack with no outstanding messages, happens during connection init + self.send_sequence = sequence; + break; + } + } + + self.receive_sequence = sequence; + self.is_synchronized = true; + } + + fn can_send(&self) -> bool { + self.corked_until.is_none() && self.inflight_messages.len() < 12 + } + + /// Send as many more frames as possible from [`self.pending_messages`] + async fn send_more_frames(&mut self) -> Result<(), TransportError> { + while self.can_send() && !self.pending_messages.is_empty() { + self.send_new_frame().await?; + } + + Ok(()) + } + + /// Send a single new frame from [`self.pending_messages`] + async fn send_new_frame(&mut self) -> Result<(), TransportError> { + let mut buf = Vec::new(); + while let Some(next) = self.pending_messages.front() { + if !buf.is_empty() && buf.len() + next.len() <= MESSAGE_LENGTH_PAYLOAD_MAX { + // Add to the end of the frame. Unwrap is safe because we already peeked. + let mut next = self.pending_messages.pop_front().unwrap(); + buf.append(&mut next); + } else { + break; + } + } + + let frame = Arc::new(encode_frame(self.send_sequence, &buf)?); + self.send_sequence += 1; + self.inflight_messages.push_back(SentFrame { + sent_at: Instant::now(), + sequence: self.send_sequence, + payload: frame.clone(), + is_retransmit: false, + }); + self.wr.write_all(&frame).await?; + + Ok(()) + } + + /// Retransmit all inflight messages + async fn retransmit_pending(&mut self) -> Result<(), TransportError> { + let len: usize = self + .inflight_messages + .iter() + .map(|msg| msg.payload.len()) + .sum(); + let mut buf = Vec::with_capacity(1 + len); + buf.push(MESSAGE_VALUE_SYNC); + let now = Instant::now(); + for msg in self.inflight_messages.iter_mut() { + buf.extend_from_slice(&msg.payload); + msg.is_retransmit = true; + msg.sent_at = now; + } + self.wr.write_all(&buf).await?; + + if self.retransmit_now { + self.ignore_nak_seq = self.receive_sequence; + if self.receive_sequence < self.retransmit_seq { + self.ignore_nak_seq = self.retransmit_seq; + } + self.retransmit_now = false; + } else { + self.rtt_state.rto = (self.rtt_state.rto * 2.0).clamp(MIN_RTO, MAX_RTO); + self.ignore_nak_seq = self.send_sequence; + } + self.retransmit_seq = self.send_sequence; + + Ok(()) + } +} + +/// An error encountered when transmitting a message +#[derive(thiserror::Error, Debug)] +pub enum TransmitterError { + #[error("io error: {0}")] + IoError(#[from] std::io::Error), + + #[error("connection closed")] + ConnectionClosed, +} + +#[derive(Debug, Clone)] +pub(crate) struct SentFrame { + pub sent_at: Instant, + #[allow(dead_code)] + pub sequence: u64, + pub payload: Arc<Vec<u8>>, + pub is_retransmit: bool, +} + +#[derive(thiserror::Error, Debug)] +pub enum MessageEncodeError { + #[error("message would exceed the maximum packet length of {MESSAGE_LENGTH_MAX} bytes")] + MessageTooLong, +} + +fn encode_frame(sequence: u64, payload: &[u8]) -> Result<Vec<u8>, MessageEncodeError> { + let len = MESSAGE_LENGTH_MIN + payload.len(); + if len > MESSAGE_LENGTH_MAX { + return Err(MessageEncodeError::MessageTooLong); + } + let mut buf = Vec::with_capacity(len); + buf.push(len as u8); + buf.push(MESSAGE_DEST | ((sequence as u8) & MESSAGE_SEQ_MASK)); + buf.extend_from_slice(payload); + let crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]); + buf.push(((crc >> 8) & 0xFF) as u8); + buf.push((crc & 0xFF) as u8); + buf.push(MESSAGE_VALUE_SYNC); + Ok(buf) +} |