diff options
Diffstat (limited to 'crates/windlass/src/transport')
-rw-r--r-- | crates/windlass/src/transport/mod.rs | 30 | ||||
-rw-r--r-- | crates/windlass/src/transport/read.rs | 179 |
2 files changed, 137 insertions, 72 deletions
diff --git a/crates/windlass/src/transport/mod.rs b/crates/windlass/src/transport/mod.rs index f5c5fc3..6f2cc7c 100644 --- a/crates/windlass/src/transport/mod.rs +++ b/crates/windlass/src/transport/mod.rs @@ -1,5 +1,6 @@ -use read::{FrameReader, ReceivedFrame, ReceiverError}; +use read::{FrameReader, ReceivedFrame}; use std::{collections::VecDeque, sync::Arc, time::Duration}; +use tokio::io::AsyncReadExt; use tokio::{ io::{AsyncBufRead, AsyncWrite, AsyncWriteExt}, pin, select, spawn, @@ -7,7 +8,7 @@ use tokio::{ task::JoinHandle, time::{sleep_until, Instant}, }; -use tokio_util::sync::CancellationToken; +use tokio_util::{bytes::BytesMut, sync::CancellationToken}; use crate::encoding::crc16; @@ -89,9 +90,6 @@ pub enum TransportError { #[error("message encoding failed: {0}")] MessageEncode(#[from] MessageEncodeError), - #[error("receiver error: {0}")] - Receiver(#[from] ReceiverError), - #[error("transmitter error: {0}")] Transmitter(#[from] TransmitterError), @@ -144,7 +142,8 @@ impl RttState { /// State for the task which deals with transport state #[derive(Debug)] struct TransportState<R, W> { - rdr: FrameReader<R>, + frdr: FrameReader, + rdr: R, wr: W, data_send: UnboundedSender<Result<Vec<u8>, TransportError>>, @@ -176,7 +175,8 @@ impl<R: AsyncBufRead + Unpin, W: AsyncWrite + Unpin> TransportState<R, W> { cancel: CancellationToken, ) -> Self { Self { - rdr: FrameReader::new(rdr), + frdr: FrameReader::default(), + rdr, wr, data_send, cmd_recv, @@ -199,6 +199,7 @@ impl<R: AsyncBufRead + Unpin, W: AsyncWrite + Unpin> TransportState<R, W> { } async fn run(&mut self) -> Result<(), TransportError> { + let mut buf = BytesMut::with_capacity(MESSAGE_LENGTH_MAX); loop { if self.retransmit_now { self.retransmit_pending().await?; @@ -208,6 +209,10 @@ impl<R: AsyncBufRead + Unpin, W: AsyncWrite + Unpin> TransportState<R, W> { self.send_more_frames().await?; } + while let Some(frame) = self.frdr.read_frame() { + self.handle_frame(frame); + } + let retransmit_deadline = self .inflight_messages .front() @@ -220,15 +225,10 @@ impl<R: AsyncBufRead + Unpin, W: AsyncWrite + Unpin> TransportState<R, W> { self.corked_until.map(sleep_until).into(); pin!(corked_timeout); - // FIXME: This is not correct because read_frame is not cancellation safe select! { - frame = self.rdr.read_frame() => { - let frame = frame?; - let frame = match frame { - Some(frame) => frame, - None => break, - }; - self.handle_frame(frame); + _ = self.rdr.read_buf(&mut buf) => { + self.frdr.receive_data(&buf); + buf.clear(); }, msg = self.cmd_recv.recv() => { diff --git a/crates/windlass/src/transport/read.rs b/crates/windlass/src/transport/read.rs index 079f7b3..1aaf20c 100644 --- a/crates/windlass/src/transport/read.rs +++ b/crates/windlass/src/transport/read.rs @@ -1,7 +1,6 @@ -use tokio::{ - io::{AsyncBufRead, AsyncReadExt}, - time::Instant, -}; +use std::{collections::VecDeque, mem}; + +use tokio::time::Instant; use tracing::trace; use crate::{ @@ -13,6 +12,7 @@ use crate::{ }, }; +/// A link-level frame. See [klipper docs](https://www.klipper3d.org/Protocol.html#message-blocks) #[derive(Debug)] pub(crate) struct ReceivedFrame { pub receive_time: Instant, @@ -20,68 +20,133 @@ pub(crate) struct ReceivedFrame { pub payload: Vec<u8>, } -#[derive(Debug)] -pub struct FrameReader<R> { - rdr: R, - synced: bool, +/// Buffers and reads link-level frames. See [`ReceivedFrame`]. +#[derive(Debug, Default)] +pub struct FrameReader { + buf: VecDeque<ReceivedFrame>, + partial: PartialFrameState, } -impl<R: AsyncBufRead + Unpin> FrameReader<R> { - pub fn new(rdr: R) -> Self { - Self { rdr, synced: false } - } - - pub async fn read_frame(&mut self) -> Result<Option<ReceivedFrame>, ReceiverError> { - let mut buf = [0u8; MESSAGE_LENGTH_MAX]; - let next_byte = self.rdr.read_u8().await?; - if next_byte == MESSAGE_VALUE_SYNC { - if !self.synced { - self.synced = true; - } - return Ok(None); - } - if !self.synced { - return Ok(None); +impl FrameReader { + /// Process new data from `b` + pub fn receive_data(&mut self, mut b: &[u8]) { + while !b.is_empty() { + if let Some((frame, remaining)) = self.partial.receive_data(b) { + self.buf.push_back(frame); + b = remaining; + }; } + } - let receive_time = Instant::now(); - let len = next_byte as usize; + /// Attempt to read a parsed frame from the data already processed. + pub fn read_frame(&mut self) -> Option<ReceivedFrame> { + self.buf.pop_front() + } +} - if !(MESSAGE_LENGTH_MIN..=MESSAGE_LENGTH_MAX).contains(&len) { - self.synced = false; - return Ok(None); - } +/// State machine for receiving a frame. +#[derive(Debug, Default)] +pub enum PartialFrameState { + /// Waiting to sync with the other side + #[default] + Unsynced, - self.rdr.read_exact(&mut buf[1..len]).await?; - buf[0] = len as u8; - let buf = &buf[..len]; - let seq = buf[MESSAGE_POSITION_SEQ]; - trace!(frame = ?buf, seq = seq, "Received frame"); + /// Synchronised and ready to receive data + Synced, - if seq & !MESSAGE_SEQ_MASK != MESSAGE_DEST { - self.synced = false; - return Ok(None); - } + /// Received length byte, waiting to receive more data. + Receiving { + /// The total length of this frame, including header and footer. + len: usize, - let actual_crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]); - let frame_crc = (buf[len - MESSAGE_TRAILER_CRC] as u16) << 8 - | (buf[len - MESSAGE_TRAILER_CRC + 1] as u16); - if frame_crc != actual_crc { - self.synced = false; - return Ok(None); - } + /// Accumulated data so far. + so_far: Vec<u8>, - Ok(Some(ReceivedFrame { - receive_time, - sequence: seq & MESSAGE_SEQ_MASK, - payload: buf[MESSAGE_HEADER_SIZE..len - MESSAGE_TRAILER_SIZE].into(), - })) - } + /// When the packet started being received. + receive_time: Instant, + }, } -/// An error encountered when receiving a message -#[derive(thiserror::Error, Debug)] -pub enum ReceiverError { - #[error("io error: {0}")] - IoError(#[from] std::io::Error), +impl PartialFrameState { + pub fn receive_data<'a>(&mut self, mut b: &'a [u8]) -> Option<(ReceivedFrame, &'a [u8])> { + while !b.is_empty() { + match self { + Self::Unsynced => { + // Wait for sync byte before doing anything else + if let Some(idx) = b.iter().position(|x| *x == MESSAGE_VALUE_SYNC) { + *self = Self::Synced; + b = &b[idx + 1..]; + } + } + PartialFrameState::Synced => { + // Attempt to start a new frame + let len = b[0] as usize; + if !(MESSAGE_LENGTH_MIN..=MESSAGE_LENGTH_MAX).contains(&len) { + *self = Self::Unsynced; + continue; + } + + let receive_time = Instant::now(); + let mut so_far = Vec::with_capacity(len); + so_far.push(b[0]); + *self = Self::Receiving { + len, + so_far, + receive_time, + }; + b = &b[1..]; + } + + PartialFrameState::Receiving { len, so_far, .. } => { + // Continue to receive data for frame + let len = *len; + let take = len - so_far.len(); + so_far.extend_from_slice(&b[..take]); + if so_far.len() < len { + // Frame not yet done, most likely b is empty now. + b = &b[take + 1..]; + continue; + } + + let seq = so_far[MESSAGE_POSITION_SEQ]; + trace!(frame = ?so_far, seq = seq, "Received frame"); + + // Check validity of frame + if seq & !MESSAGE_SEQ_MASK != MESSAGE_DEST { + *self = Self::Unsynced; + continue; + } + + let actual_crc = crc16(&so_far); + let frame_crc = (so_far[len - MESSAGE_TRAILER_CRC] as u16) << 8 + | (so_far[len - MESSAGE_TRAILER_CRC + 1] as u16); + if frame_crc != actual_crc { + *self = Self::Unsynced; + continue; + } + + // Set current state back to synced + let Self::Receiving { + so_far, + receive_time, + .. + } = mem::replace(self, Self::Synced) + else { + unreachable!() + }; + + // Return received frame and unused buffer. + return Some(( + ReceivedFrame { + receive_time, + sequence: seq & MESSAGE_SEQ_MASK, + payload: so_far[MESSAGE_HEADER_SIZE..len - MESSAGE_TRAILER_SIZE].into(), + }, + &b[take..], + )); + } + } + } + None + } } |