use read::{FrameReader, ReceivedFrame};
use rtt::RttState;
use std::{collections::VecDeque, sync::Arc};
use tokio::io::AsyncReadExt;
use tokio::{
    io::{AsyncBufRead, AsyncWrite, AsyncWriteExt},
    pin, select, spawn,
    sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
    task::JoinHandle,
    time::{sleep_until, Instant},
};
use tokio_util::{bytes::BytesMut, sync::CancellationToken};

use crate::encoding::crc16;

mod read;
mod rtt;

/// Number of bytes that precede the payload in a frame: `encode_frame` writes a
/// length byte followed by a destination/sequence byte.
pub(crate) const MESSAGE_HEADER_SIZE: usize = 2;
/// Number of bytes that follow the payload in a frame: `encode_frame` writes two
/// CRC bytes followed by the sync byte.
pub(crate) const MESSAGE_TRAILER_SIZE: usize = 3;
/// Size of a frame that carries no payload (header plus trailer).
pub(crate) const MESSAGE_LENGTH_MIN: usize = MESSAGE_HEADER_SIZE + MESSAGE_TRAILER_SIZE;
/// Largest total frame size; `encode_frame` rejects anything longer.
pub(crate) const MESSAGE_LENGTH_MAX: usize = 64;
/// Largest payload that still fits in a single frame.
pub(crate) const MESSAGE_LENGTH_PAYLOAD_MAX: usize = MESSAGE_LENGTH_MAX - MESSAGE_LENGTH_MIN;
/// Presumably the index of the destination/sequence byte within a frame (this
/// matches the layout produced by `encode_frame`) - TODO confirm against `read`.
pub(crate) const MESSAGE_POSITION_SEQ: usize = 1;
/// Presumably the offset of the CRC counted back from the end of a frame -
/// TODO confirm against `read`.
pub(crate) const MESSAGE_TRAILER_CRC: usize = 3;
/// Byte that terminates every encoded frame; it is also written ahead of a
/// retransmission burst.
pub(crate) const MESSAGE_VALUE_SYNC: u8 = 0x7E;
/// Constant OR-ed into the destination/sequence byte by `encode_frame`.
pub(crate) const MESSAGE_DEST: u8 = 0x10;
/// Mask selecting the bits of the destination/sequence byte that carry the
/// (truncated) sequence number.
pub(crate) const MESSAGE_SEQ_MASK: u8 = 0x0F;

/// Wrapper around a connection to a klipper firmware MCU, which deals with
/// retransmission, flow control, etc.
///
/// Internally, this holds a handle to an async task and a channel to communicate with it,
/// meaning operations don't necessarily block the current task.
#[derive(Debug)] pub struct Transport { // Handles to the associated async task task_inner: JoinHandle<()>, /// Queue for outbound messages cmd_send: UnboundedSender, } /// A message sent to the transport task #[derive(Debug)] enum TransportCommand { SendMessage(Vec), Exit, } pub(crate) type TransportReceiver = UnboundedReceiver, TransportError>>; impl Transport { pub(crate) async fn connect( rdr: impl AsyncBufRead + Unpin + Send + 'static, wr: impl AsyncWrite + Unpin + Send + 'static, ) -> (Transport, TransportReceiver) { let (data_send, data_recv) = unbounded_channel(); let (cmd_send, cmd_recv) = unbounded_channel(); let cancel_token = CancellationToken::new(); let task_inner = spawn(async move { let mut ts = TransportState::new(rdr, wr, data_send, cmd_recv, cancel_token); if let Err(e) = ts.run().await { let _ = ts.data_send.send(Err(e)); } }); ( Transport { task_inner, cmd_send, }, data_recv, ) } pub(crate) fn send(&self, msg: &[u8]) -> Result<(), TransmitterError> { self.cmd_send .send(TransportCommand::SendMessage(msg.into())) .map_err(|_| TransmitterError::ConnectionClosed) } pub(crate) async fn close(self) { let _ = self.cmd_send.send(TransportCommand::Exit); let _ = self.task_inner.await; } } #[derive(thiserror::Error, Debug)] pub enum TransportError { #[error("message encoding failed: {0}")] MessageEncode(#[from] MessageEncodeError), #[error("transmitter error: {0}")] Transmitter(#[from] TransmitterError), #[error("io error: {0}")] IOError(#[from] std::io::Error), } /// State for the task which deals with transport state #[derive(Debug)] struct TransportState { frdr: FrameReader, rdr: R, wr: W, data_send: UnboundedSender, TransportError>>, cmd_recv: UnboundedReceiver, cancel: CancellationToken, is_synchronized: bool, rtt_state: RttState, receive_sequence: u64, send_sequence: u64, last_ack_sequence: u64, ignore_nak_seq: u64, retransmit_seq: u64, retransmit_now: bool, corked_until: Option, inflight_messages: VecDeque, pending_messages: VecDeque>, } impl 
TransportState { fn new( rdr: R, wr: W, data_send: UnboundedSender, TransportError>>, cmd_recv: UnboundedReceiver, cancel: CancellationToken, ) -> Self { Self { frdr: FrameReader::default(), rdr, wr, data_send, cmd_recv, cancel, is_synchronized: false, rtt_state: RttState::default(), receive_sequence: 1, send_sequence: 1, last_ack_sequence: 0, ignore_nak_seq: 0, retransmit_seq: 0, retransmit_now: false, corked_until: None, inflight_messages: VecDeque::new(), pending_messages: VecDeque::new(), } } async fn run(&mut self) -> Result<(), TransportError> { let mut buf = BytesMut::with_capacity(MESSAGE_LENGTH_MAX); loop { if self.retransmit_now { self.retransmit_pending().await?; } if !self.pending_messages.is_empty() && self.can_send() { self.send_more_frames().await?; } while let Some(frame) = self.frdr.read_frame() { self.handle_frame(frame); } let retransmit_deadline = self .inflight_messages .front() .map(|msg| msg.sent_at + self.rtt_state.rto()); let retransmit_timeout: futures::future::OptionFuture<_> = retransmit_deadline.map(sleep_until).into(); pin!(retransmit_timeout); let corked_timeout: futures::future::OptionFuture<_> = self.corked_until.map(sleep_until).into(); pin!(corked_timeout); select! { _ = self.rdr.read_buf(&mut buf) => { self.frdr.receive_data(&buf); buf.clear(); }, msg = self.cmd_recv.recv() => { match msg { Some(TransportCommand::SendMessage(msg)) => { self.pending_messages.push_back(msg); }, Some(TransportCommand::Exit) => { self.cancel.cancel(); } None => break, }; }, _ = &mut retransmit_timeout, if retransmit_deadline.is_some() => { self.retransmit_now = true; }, _ = &mut corked_timeout, if self.corked_until.is_some() => { // Timeout for when we are able to send again } _ = self.cancel.cancelled() => { break; }, } } Ok(()) } /// Handle an incoming frame, by updating sequence numbers and sending the data upwards if needed fn handle_frame(&mut self, frame: ReceivedFrame) { let rseq = self.receive_sequence; // wrap-around logic(?) 
let mut sequence = (rseq & !(MESSAGE_SEQ_MASK as u64)) | (frame.sequence as u64); if sequence < rseq { sequence += (MESSAGE_SEQ_MASK as u64) + 1; } // Frame acknowledges some messages if !self.is_synchronized || sequence != rseq { if sequence > self.send_sequence && self.is_synchronized { // Ack for unsent message - weird, but ignore and try to continue return; } self.update_receive_seq(frame.receive_time, sequence); } if !frame.payload.is_empty() { // Data message, we deliver this directly to the application as the MCU can't actually // retransmit anyway. // TODO: Maybe check the CRC anyway so we can discard it here let _ = self.data_send.send(Ok(frame.payload)); } else if sequence > self.last_ack_sequence { // ACK self.last_ack_sequence = sequence; } else if sequence > self.ignore_nak_seq && !self.inflight_messages.is_empty() { // NAK self.retransmit_now = true; } } /// Update the last received sequence number, removing acknowledged messages from `self.inflight_messages` fn update_receive_seq(&mut self, receive_time: Instant, sequence: u64) { let mut sent_seq = self.receive_sequence; // Discard messages from inflight_messages up to sequence loop { if let Some(msg) = self.inflight_messages.pop_front() { sent_seq += 1; if sequence == sent_seq { // Found the matching sent message if !msg.is_retransmit { let elapsed = receive_time.saturating_duration_since(msg.sent_at); self.rtt_state.update(elapsed); } break; } } else { // Ack with no outstanding messages, happens during connection init self.send_sequence = sequence; break; } } self.receive_sequence = sequence; self.is_synchronized = true; } fn can_send(&self) -> bool { self.corked_until.is_none() && self.inflight_messages.len() < 12 } /// Send as many more frames as possible from [`self.pending_messages`] async fn send_more_frames(&mut self) -> Result<(), TransportError> { while self.can_send() && !self.pending_messages.is_empty() { self.send_new_frame().await?; } Ok(()) } /// Send a single new frame from 
[`self.pending_messages`] async fn send_new_frame(&mut self) -> Result<(), TransportError> { let mut buf = Vec::new(); while let Some(next) = self.pending_messages.front() { if !buf.is_empty() && buf.len() + next.len() <= MESSAGE_LENGTH_PAYLOAD_MAX { // Add to the end of the frame. Unwrap is safe because we already peeked. let mut next = self.pending_messages.pop_front().unwrap(); buf.append(&mut next); } else { break; } } let frame = Arc::new(encode_frame(self.send_sequence, &buf)?); self.send_sequence += 1; self.inflight_messages.push_back(SentFrame { sent_at: Instant::now(), sequence: self.send_sequence, payload: frame.clone(), is_retransmit: false, }); self.wr.write_all(&frame).await?; Ok(()) } /// Retransmit all inflight messages async fn retransmit_pending(&mut self) -> Result<(), TransportError> { let len: usize = self .inflight_messages .iter() .map(|msg| msg.payload.len()) .sum(); let mut buf = Vec::with_capacity(1 + len); buf.push(MESSAGE_VALUE_SYNC); let now = Instant::now(); for msg in self.inflight_messages.iter_mut() { buf.extend_from_slice(&msg.payload); msg.is_retransmit = true; msg.sent_at = now; } self.wr.write_all(&buf).await?; if self.retransmit_now { self.ignore_nak_seq = self.receive_sequence; if self.receive_sequence < self.retransmit_seq { self.ignore_nak_seq = self.retransmit_seq; } self.retransmit_now = false; } else { self.rtt_state.double(); self.ignore_nak_seq = self.send_sequence; } self.retransmit_seq = self.send_sequence; Ok(()) } } /// An error encountered when transmitting a message #[derive(thiserror::Error, Debug)] pub enum TransmitterError { #[error("io error: {0}")] IoError(#[from] std::io::Error), #[error("connection closed")] ConnectionClosed, } #[derive(Debug, Clone)] pub(crate) struct SentFrame { pub sent_at: Instant, #[allow(dead_code)] pub sequence: u64, pub payload: Arc>, pub is_retransmit: bool, } #[derive(thiserror::Error, Debug)] pub enum MessageEncodeError { #[error("message would exceed the maximum packet length of 
{MESSAGE_LENGTH_MAX} bytes")] MessageTooLong, } fn encode_frame(sequence: u64, payload: &[u8]) -> Result, MessageEncodeError> { let len = MESSAGE_LENGTH_MIN + payload.len(); if len > MESSAGE_LENGTH_MAX { return Err(MessageEncodeError::MessageTooLong); } let mut buf = Vec::with_capacity(len); buf.push(len as u8); buf.push(MESSAGE_DEST | ((sequence as u8) & MESSAGE_SEQ_MASK)); buf.extend_from_slice(payload); let crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]); buf.push(((crc >> 8) & 0xFF) as u8); buf.push((crc & 0xFF) as u8); buf.push(MESSAGE_VALUE_SYNC); Ok(buf) }