use std::collections::VecDeque;

use frame::{MessageEncodeError, Reader, ReceivedFrame, SentFrame};
use rtt::RttState;
use tokio::{
    io::{AsyncBufRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
    pin, select, spawn,
    sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
    task::JoinHandle,
    time::{sleep_until, Instant},
};
use tokio_util::{bytes::BytesMut, sync::CancellationToken};

mod frame;
mod rtt;

/// Number of framing bytes that precede the payload of a frame
/// (the sequence byte lives at offset [`MESSAGE_POSITION_SEQ`]).
pub const MESSAGE_HEADER_SIZE: usize = 2;
/// Number of framing bytes that follow the payload of a frame.
pub const MESSAGE_TRAILER_SIZE: usize = 3;
/// Size of a frame whose payload is empty, i.e. header plus trailer.
pub const MESSAGE_LENGTH_MIN: usize = MESSAGE_HEADER_SIZE + MESSAGE_TRAILER_SIZE;
/// Largest frame size, framing bytes included.
pub const MESSAGE_LENGTH_MAX: usize = 64;
/// Largest payload that fits into a single frame.
pub const MESSAGE_LENGTH_PAYLOAD_MAX: usize = MESSAGE_LENGTH_MAX - MESSAGE_LENGTH_MIN;

/// Byte offset of the sequence byte within a frame.
pub const MESSAGE_POSITION_SEQ: usize = 1;
/// Location of the CRC within the trailer.
/// NOTE(review): presumably counted back from the end of the frame — confirm in `frame`.
pub const MESSAGE_TRAILER_CRC: usize = 3;

/// Sync byte value; it is also written ahead of a batch of retransmitted frames.
pub const MESSAGE_VALUE_SYNC: u8 = 0x7E;
/// NOTE(review): presumably the fixed destination bits carried in the sequence byte —
/// confirm against the encoder in `frame`.
pub const MESSAGE_DEST: u8 = 0x10;
/// Mask selecting the sequence-number bits carried on the wire.
pub const MESSAGE_SEQ_MASK: u8 = 0x0F;

/// Handle to a connection with a klipper firmware MCU.
///
/// Retransmission and flow control are performed by a background async task; this type only
/// holds that task's join handle plus a channel for talking to it, so queueing a message does
/// not block the calling task.
#[derive(Debug)] pub struct Transport { // Handles to the associated async task task_inner: JoinHandle<()>, /// Queue for outbound messages cmd_send: UnboundedSender, } /// A message sent to the transport task #[derive(Debug)] enum TransportCommand { SendMessage(Vec), Exit, } #[allow(clippy::module_name_repetitions)] pub type TransportReceiver = UnboundedReceiver, Error>>; impl Transport { pub(crate) fn connect( rdr: impl AsyncBufRead + Unpin + Send + 'static, wr: impl AsyncWrite + Unpin + Send + 'static, ) -> (Self, TransportReceiver) { let (data_send, data_recv) = unbounded_channel(); let (cmd_send, cmd_recv) = unbounded_channel(); let cancel_token = CancellationToken::new(); let task_inner = spawn(async move { let mut ts = TransportState::new(rdr, wr, data_send, cmd_recv, cancel_token); if let Err(e) = ts.run().await { let _ = ts.data_send.send(Err(e)); } }); ( Self { task_inner, cmd_send, }, data_recv, ) } pub(crate) fn send(&self, msg: &[u8]) -> Result<(), TransmitterError> { self.cmd_send .send(TransportCommand::SendMessage(msg.into())) .map_err(|_| TransmitterError::ConnectionClosed) } pub(crate) async fn close(self) { let _ = self.cmd_send.send(TransportCommand::Exit); let _ = self.task_inner.await; } } #[derive(thiserror::Error, Debug)] pub enum Error { #[error("message encoding failed: {0}")] MessageEncode(#[from] MessageEncodeError), #[error("transmitter error: {0}")] Transmitter(#[from] TransmitterError), #[error("io error: {0}")] IO(#[from] std::io::Error), } /// State for the task which deals with transport state #[derive(Debug)] struct TransportState { frdr: Reader, rdr: R, wr: W, data_send: UnboundedSender, Error>>, cmd_recv: UnboundedReceiver, cancel: CancellationToken, is_synchronized: bool, rtt_state: RttState, receive_sequence: u64, send_sequence: u64, last_ack_sequence: u64, ignore_nak_seq: u64, retransmit_seq: u64, retransmit_now: bool, corked_until: Option, inflight_messages: VecDeque, pending_messages: VecDeque>, } impl TransportState { fn 
new( rdr: R, wr: W, data_send: UnboundedSender, Error>>, cmd_recv: UnboundedReceiver, cancel: CancellationToken, ) -> Self { Self { frdr: Reader::default(), rdr, wr, data_send, cmd_recv, cancel, is_synchronized: false, rtt_state: RttState::default(), receive_sequence: 1, send_sequence: 1, last_ack_sequence: 0, ignore_nak_seq: 0, retransmit_seq: 0, retransmit_now: false, corked_until: None, inflight_messages: VecDeque::new(), pending_messages: VecDeque::new(), } } async fn run(&mut self) -> Result<(), Error> { let mut buf = BytesMut::with_capacity(MESSAGE_LENGTH_MAX); loop { if self.retransmit_now { self.retransmit_pending().await?; } if !self.pending_messages.is_empty() && self.can_send() { self.send_more_frames().await?; } while let Some(frame) = self.frdr.read_frame() { self.handle_frame(frame); } let retransmit_deadline = self .inflight_messages .front() .map(|msg| msg.sent_at + self.rtt_state.rto()); let retransmit_timeout: futures::future::OptionFuture<_> = retransmit_deadline.map(sleep_until).into(); pin!(retransmit_timeout); let corked_timeout: futures::future::OptionFuture<_> = self.corked_until.map(sleep_until).into(); pin!(corked_timeout); select! { _ = self.rdr.read_buf(&mut buf) => { self.frdr.receive_data(&buf); buf.clear(); }, msg = self.cmd_recv.recv() => { match msg { Some(TransportCommand::SendMessage(msg)) => { self.pending_messages.push_back(msg); }, Some(TransportCommand::Exit) => { self.cancel.cancel(); } None => break, }; }, _ = &mut retransmit_timeout, if retransmit_deadline.is_some() => { self.retransmit_now = true; }, _ = &mut corked_timeout, if self.corked_until.is_some() => { // Timeout for when we are able to send again } () = self.cancel.cancelled() => { break; }, } } Ok(()) } /// Handle an incoming frame, by updating sequence numbers and sending the data upwards if needed fn handle_frame(&mut self, frame: ReceivedFrame) { let rseq = self.receive_sequence; // wrap-around logic(?) 
let mut sequence = (rseq & !u64::from(MESSAGE_SEQ_MASK)) | u64::from(frame.sequence); if sequence < rseq { sequence += u64::from(MESSAGE_SEQ_MASK) + 1; } // Frame acknowledges some messages if !self.is_synchronized || sequence != rseq { if sequence > self.send_sequence && self.is_synchronized { // Ack for unsent message - weird, but ignore and try to continue return; } self.update_receive_seq(frame.receive_time, sequence); } if !frame.payload.is_empty() { // Data message, we deliver this directly to the application as the MCU can't actually // retransmit anyway. // TODO: Maybe check the CRC anyway so we can discard it here let _ = self.data_send.send(Ok(frame.payload)); } else if sequence > self.last_ack_sequence { // ACK self.last_ack_sequence = sequence; } else if sequence > self.ignore_nak_seq && !self.inflight_messages.is_empty() { // NAK self.retransmit_now = true; } } /// Update the last received sequence number, removing acknowledged messages from `self.inflight_messages` fn update_receive_seq(&mut self, receive_time: Instant, sequence: u64) { let mut sent_seq = self.receive_sequence; // Discard messages from inflight_messages up to sequence loop { if let Some(msg) = self.inflight_messages.pop_front() { sent_seq += 1; if sequence == sent_seq { // Found the matching sent message if !msg.is_retransmit { let elapsed = receive_time.saturating_duration_since(msg.sent_at); self.rtt_state.update(elapsed); } break; } } else { // Ack with no outstanding messages, happens during connection init self.send_sequence = sequence; break; } } self.receive_sequence = sequence; self.is_synchronized = true; } fn can_send(&self) -> bool { self.corked_until.is_none() && self.inflight_messages.len() < 12 } /// Send as many more frames as possible from [`self.pending_messages`] async fn send_more_frames(&mut self) -> Result<(), Error> { while self.can_send() && !self.pending_messages.is_empty() { self.send_new_frame().await?; } Ok(()) } /// Send a single new frame from 
[`self.pending_messages`] async fn send_new_frame(&mut self) -> Result<(), Error> { let mut buf = Vec::new(); while let Some(next) = self.pending_messages.front() { if !buf.is_empty() && buf.len() + next.len() <= MESSAGE_LENGTH_PAYLOAD_MAX { // Add to the end of the frame. Unwrap is safe because we already peeked. let mut next = self.pending_messages.pop_front().unwrap(); buf.append(&mut next); } else { break; } } self.inflight_messages .push_back(SentFrame::new(self.send_sequence, &buf)?); self.send_sequence += 1; self.wr .write_all(&self.inflight_messages.back().unwrap().payload) .await?; Ok(()) } /// Retransmit all inflight messages async fn retransmit_pending(&mut self) -> Result<(), Error> { let len: usize = self .inflight_messages .iter() .map(|msg| msg.payload.len()) .sum(); let mut buf = Vec::with_capacity(1 + len); buf.push(MESSAGE_VALUE_SYNC); let now = Instant::now(); for msg in &mut self.inflight_messages { buf.extend_from_slice(&msg.payload); msg.is_retransmit = true; msg.sent_at = now; } self.wr.write_all(&buf).await?; if self.retransmit_now { self.ignore_nak_seq = self.receive_sequence; if self.receive_sequence < self.retransmit_seq { self.ignore_nak_seq = self.retransmit_seq; } self.retransmit_now = false; } else { self.rtt_state.double(); self.ignore_nak_seq = self.send_sequence; } self.retransmit_seq = self.send_sequence; Ok(()) } } /// An error encountered when transmitting a message #[derive(thiserror::Error, Debug)] pub enum TransmitterError { #[error("io error: {0}")] IoError(#[from] std::io::Error), #[error("connection closed")] ConnectionClosed, }