summaryrefslogtreecommitdiff
path: root/crates
diff options
context:
space:
mode:
authortcmal <me@aria.rip>2024-09-17 16:12:31 +0100
committertcmal <me@aria.rip>2024-09-30 23:17:34 +0100
commitfd82733126ee82b085875c44a0993534968afad3 (patch)
tree73ab58288c4da05e51c60a2e27cb62e33f2111ad /crates
parent5f9fbea5a1b08962887d16457e77a922919a8818 (diff)
Better docs / cleanup for windlass
Diffstat (limited to 'crates')
-rw-r--r--crates/windlass/src/dictionary.rs262
-rw-r--r--crates/windlass/src/encoding.rs274
-rw-r--r--crates/windlass/src/lib.rs14
-rw-r--r--crates/windlass/src/macros.rs157
-rw-r--r--crates/windlass/src/mcu.rs27
-rw-r--r--crates/windlass/src/messages.rs178
-rw-r--r--crates/windlass/src/transport.rs615
-rw-r--r--crates/windlass/src/transport/mod.rs434
-rw-r--r--crates/windlass/src/transport/read.rs87
9 files changed, 1042 insertions, 1006 deletions
diff --git a/crates/windlass/src/dictionary.rs b/crates/windlass/src/dictionary.rs
index 5665f80..270b0e5 100644
--- a/crates/windlass/src/dictionary.rs
+++ b/crates/windlass/src/dictionary.rs
@@ -1,78 +1,101 @@
-use std::collections::BTreeMap;
+use std::{collections::BTreeMap, ops::RangeInclusive};
use crate::{
encoding::encode_vlq_int,
messages::{MessageParser, MessageSkipperError},
};
-#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
-#[serde(untagged)]
-pub enum ConfigVar {
- String(String),
- Number(f64),
-}
+/// Deserialised data dictionary returned by the microcontroller.
+/// See [klipper protocol docs](https://www.klipper3d.org/Protocol.html#data-dictionary)
+#[derive(Debug)]
+pub struct Dictionary {
+ /// Map of message names to message IDs
+ message_ids: BTreeMap<String, u16>,
-#[derive(Debug, serde::Deserialize)]
-#[serde(untagged)]
-pub enum Enumeration {
- Single(i64),
- Range(i64, i64),
-}
+ /// Map of message IDs to parsers
+ message_parsers: BTreeMap<u16, MessageParser>,
-#[derive(Debug, serde::Deserialize)]
-pub(crate) struct RawDictionary {
- #[serde(default)]
+ /// Map of config variable name to variable values
config: BTreeMap<String, ConfigVar>,
- #[serde(default)]
- enumerations: BTreeMap<String, BTreeMap<String, Enumeration>>,
+ /// Map of declared [enumerations](https://www.klipper3d.org/Protocol.html#declaring-enumerations)
+ enumerations: BTreeMap<String, EnumDef>,
- #[serde(default)]
- commands: BTreeMap<String, i16>,
- #[serde(default)]
- responses: BTreeMap<String, i16>,
- #[serde(default)]
- output: BTreeMap<String, i16>,
+ /// Build version, if specified
+ build_version: Option<String>,
- #[serde(default)]
- build_versions: Option<String>,
- #[serde(default)]
+ /// Version, if specified
version: Option<String>,
- #[serde(flatten)]
+ /// Any extra data in the dictionary response
extra: BTreeMap<String, serde_json::Value>,
}
-/// Dictionary error
-#[derive(thiserror::Error, Debug)]
-pub enum DictionaryError {
- /// Found an empty command
- #[error("empty command found")]
- EmptyCommand,
- /// Received a command in an invalid format
- #[error("invalid command format: {0}")]
- InvalidCommandFormat(String, MessageSkipperError),
- /// Received an output string with an invalid format
- #[error("invalid output format: {0}")]
- InvalidOutputFormat(String, MessageSkipperError),
- /// Received a command with an invalid tag
- #[error("command tag {0} output valid range of -32..95")]
- InvalidCommandTag(u16),
-}
+impl Dictionary {
+ /// Get message id by name
+ pub fn message_id(&self, name: &str) -> Option<u16> {
+ self.message_ids.get(name).copied()
+ }
-#[derive(Debug)]
-pub struct Dictionary {
- pub message_ids: BTreeMap<String, u16>,
- pub message_parsers: BTreeMap<u16, MessageParser>,
- pub config: BTreeMap<String, ConfigVar>,
- pub enumerations: BTreeMap<String, BTreeMap<String, Enumeration>>,
- pub build_versions: Option<String>,
- pub version: Option<String>,
- pub extra: BTreeMap<String, serde_json::Value>,
+ /// Get message name by id
+ pub fn message_name(&self, id: u16) -> Option<&str> {
+ self.message_ids
+ .iter()
+ .find(|(_, i)| **i == id)
+ .map(|(name, _)| name.as_str())
+ }
+
+ /// Get message name by id
+ pub fn message_parser(&self, id: u16) -> Option<&MessageParser> {
+ self.message_parsers.get(&id)
+ }
+
+ /// Get config variable by name
+ pub fn config_var(&self, var_name: &str) -> Option<&ConfigVar> {
+ self.config.get(var_name)
+ }
+
+ /// Get enum definition by name
+ pub fn enum_def(&self, name: &str) -> Option<&EnumDef> {
+ self.enumerations.get(name)
+ }
+
+ /// Firmware build version, if specified
+ pub fn build_version(&self) -> Option<&str> {
+ self.build_version.as_deref()
+ }
+
+ /// Firmware version, if specified
+ pub fn version(&self) -> Option<&str> {
+ self.version.as_deref()
+ }
+
+ /// Extra data returned in the dictionary response
+ pub fn extra(&self) -> &BTreeMap<String, serde_json::Value> {
+ &self.extra
+ }
+
+ /// Figure out the actual value of a command tag
+ fn map_tag(tag: i16) -> Result<u16, DictionaryError> {
+ let mut buf = vec![];
+ encode_vlq_int(&mut buf, tag as u32);
+ let v = if buf.len() > 1 {
+ ((buf[0] as u16) & 0x7F) << 7 | (buf[1] as u16) & 0x7F
+ } else {
+ (buf[0] as u16) & 0x7F
+ };
+ if v >= 1 << 14 {
+ Err(DictionaryError::InvalidCommandTag(v))
+ } else {
+ Ok(v)
+ }
+ }
}
-impl Dictionary {
- pub(crate) fn from_raw_dictionary(raw: RawDictionary) -> Result<Self, DictionaryError> {
+impl TryFrom<RawDictionary> for Dictionary {
+ type Error = DictionaryError;
+
+ fn try_from(raw: RawDictionary) -> Result<Self, Self::Error> {
let mut message_ids = BTreeMap::new();
let mut message_parsers = BTreeMap::new();
@@ -107,25 +130,124 @@ impl Dictionary {
message_ids,
message_parsers,
config: raw.config,
- enumerations: raw.enumerations,
- build_versions: raw.build_versions,
+ enumerations: raw
+ .enumerations
+ .into_iter()
+ .map(|(name, vals)| (name, EnumDef(vals)))
+ .collect(),
+ build_version: raw.build_versions,
version: raw.version,
extra: raw.extra,
})
}
+}
- fn map_tag(tag: i16) -> Result<u16, DictionaryError> {
- let mut buf = vec![];
- encode_vlq_int(&mut buf, tag as u32);
- let v = if buf.len() > 1 {
- ((buf[0] as u16) & 0x7F) << 7 | (buf[1] as u16) & 0x7F
- } else {
- (buf[0] as u16) & 0x7F
- };
- if v >= 1 << 14 {
- Err(DictionaryError::InvalidCommandTag(v))
- } else {
- Ok(v)
- }
+/// Definition of an [enumerations](https://www.klipper3d.org/Protocol.html#declaring-enumerations)'s possible variables.
+#[derive(Debug)]
+pub struct EnumDef(BTreeMap<String, EnumValue>);
+
+impl EnumDef {
+ /// Get the range of valid values for this enum
+ /// Note that there's no guarantee the enum is actually contiguous for a range.
+ pub fn valid_range(&self) -> RangeInclusive<i64> {
+ let min = self
+ .0
+ .values()
+ .map(|v| match *v {
+ EnumValue::Single(x) => x,
+ EnumValue::Range(start, _end) => start,
+ })
+ .min()
+ .unwrap_or(0);
+ let max = self
+ .0
+ .values()
+ .map(|v| match *v {
+ EnumValue::Single(x) => x,
+ EnumValue::Range(_start, end) => end,
+ })
+ .max()
+ .unwrap_or(0);
+
+ min..=max
}
+
+ /// Lookup the name of the given enum value
+ pub fn lookup_name(&self, val: i64) -> Option<String> {
+ self.0
+ .iter()
+ .find(|(_, vals)| match **vals {
+ EnumValue::Single(x) => x == val,
+ EnumValue::Range(start, end) => (start..=end).contains(&val),
+ })
+ .map(|(name, vals)| match vals {
+ EnumValue::Single(_) => name.to_string(),
+ EnumValue::Range(_, _) => todo!("modify name to end with correct number"),
+ })
+ }
+}
+
+/// Error encountered when getting dictionary from microcontroller
+#[derive(thiserror::Error, Debug)]
+pub enum DictionaryError {
+ /// Found an empty command
+ #[error("empty command found")]
+ EmptyCommand,
+
+ /// Received a command in an invalid format
+ #[error("invalid command format: {0}")]
+ InvalidCommandFormat(String, MessageSkipperError),
+
+ /// Received an output string with an invalid format
+ #[error("invalid output format: {0}")]
+ InvalidOutputFormat(String, MessageSkipperError),
+
+ /// Received a command with an invalid tag
+ #[error("command tag {0} output valid range of -32..95")]
+ InvalidCommandTag(u16),
+}
+
+/// The raw JSON data dictionary response from the microcontroller
+#[derive(Debug, serde::Deserialize)]
+pub(crate) struct RawDictionary {
+ #[serde(default)]
+ config: BTreeMap<String, ConfigVar>,
+
+ #[serde(default)]
+ enumerations: BTreeMap<String, BTreeMap<String, EnumValue>>,
+
+ #[serde(default)]
+ commands: BTreeMap<String, i16>,
+
+ #[serde(default)]
+ responses: BTreeMap<String, i16>,
+
+ #[serde(default)]
+ output: BTreeMap<String, i16>,
+
+ #[serde(default)]
+ build_versions: Option<String>,
+
+ #[serde(default)]
+ version: Option<String>,
+
+ #[serde(flatten)]
+ extra: BTreeMap<String, serde_json::Value>,
+}
+
+/// The value of a configuration key
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+#[serde(untagged)]
+pub enum ConfigVar {
+ String(String),
+ Number(f64),
+}
+
+/// Specifies a value an enumeration can take.
+/// Either one string corresponds to a single integer value, or it ends in a number and corresponds to a range of integer values.
+#[derive(Debug, serde::Deserialize)]
+#[serde(untagged)]
+enum EnumValue {
+ Single(i64),
+ Range(i64, i64),
}
diff --git a/crates/windlass/src/encoding.rs b/crates/windlass/src/encoding.rs
index 0950956..27238c0 100644
--- a/crates/windlass/src/encoding.rs
+++ b/crates/windlass/src/encoding.rs
@@ -2,17 +2,19 @@ use std::fmt::Display;
use crate::messages::MessageSkipperError;
-/// Message decoding error
-#[derive(thiserror::Error, Debug, Clone)]
-pub enum MessageDecodeError {
- /// More data was expected but none is available
- #[error("eof unexpected")]
- UnexpectedEof,
- /// A received string could not be decoded as UTF8
- #[error("invalid utf8 string")]
- Utf8Error(#[from] std::str::Utf8Error),
+/// Calculate the CRC-16 of the given buffer
+pub(crate) fn crc16(buf: &[u8]) -> u16 {
+ let mut crc = 0xFFFFu16;
+ for b in buf {
+ let b = *b ^ ((crc & 0xFF) as u8);
+ let b = b ^ (b << 4);
+ let b16 = b as u16;
+ crc = (b16 << 8 | crc >> 8) ^ (b16 >> 4) ^ (b16 << 3);
+ }
+ crc
}
+/// Encode the given integer as a [VLQ](https://www.klipper3d.org/Protocol.html#variable-length-quantities), pushing it to the back of `output`.
pub(crate) fn encode_vlq_int(output: &mut Vec<u8>, v: u32) {
let sv = v as i32;
if !(-(1 << 26)..(3 << 26)).contains(&sv) {
@@ -30,16 +32,7 @@ pub(crate) fn encode_vlq_int(output: &mut Vec<u8>, v: u32) {
output.push((sv & 0x7F) as u8);
}
-pub(crate) fn next_byte(data: &mut &[u8]) -> Result<u8, MessageDecodeError> {
- if data.is_empty() {
- Err(MessageDecodeError::UnexpectedEof)
- } else {
- let v = data[0];
- *data = &data[1..];
- Ok(v)
- }
-}
-
+/// Parse a [VLQ](https://www.klipper3d.org/Protocol.html#variable-length-quantities) from the top of `data`.
pub(crate) fn parse_vlq_int(data: &mut &[u8]) -> Result<u32, MessageDecodeError> {
let mut c = next_byte(data)? as u32;
let mut v = c & 0x7F;
@@ -54,16 +47,35 @@ pub(crate) fn parse_vlq_int(data: &mut &[u8]) -> Result<u32, MessageDecodeError>
Ok(v)
}
+/// Read the next byte from `data`, or error
+pub(crate) fn next_byte(data: &mut &[u8]) -> Result<u8, MessageDecodeError> {
+ if data.is_empty() {
+ Err(MessageDecodeError::UnexpectedEof)
+ } else {
+ let v = data[0];
+ *data = &data[1..];
+ Ok(v)
+ }
+}
+
+/// A type which can possibly be read from the top of a data buffer.
pub trait Readable<'de>: Sized {
+ /// Attempt to deserialise a value from the top of `data`.
fn read(data: &mut &'de [u8]) -> Result<Self, MessageDecodeError>;
- fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError>;
+ /// Skip over a value of this type at the top of `data`.
+ fn skip(data: &mut &'de [u8]) -> Result<(), MessageDecodeError> {
+ Self::read(data).map(|_| ())
+ }
}
+/// A type which can be written to the back of a data buffer.
pub trait Writable: Sized {
+ /// Write this type to the back of `output`.
fn write(&self, output: &mut Vec<u8>);
}
+/// TODO: Docs
pub trait Borrowable: Sized {
type Borrowed<'a>
where
@@ -71,20 +83,121 @@ pub trait Borrowable: Sized {
fn from_borrowed(src: Self::Borrowed<'_>) -> Self;
}
+/// The type of an encoded field
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum FieldType {
+ U32,
+ I32,
+ U16,
+ I16,
+ U8,
+ String,
+ ByteArray,
+}
+
+impl FieldType {
+ /// Skip over a field of this type at the top of `input`.
+ pub(crate) fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> {
+ match self {
+ Self::U32 => <u32 as Readable>::skip(input),
+ Self::I32 => <i32 as Readable>::skip(input),
+ Self::U16 => <u16 as Readable>::skip(input),
+ Self::I16 => <i16 as Readable>::skip(input),
+ Self::U8 => <u8 as Readable>::skip(input),
+ Self::String => <&str as Readable>::skip(input),
+ Self::ByteArray => <&[u8] as Readable>::skip(input),
+ }
+ }
+
+ /// Read a field of this type from the top of `input`.
+ pub(crate) fn read(&self, input: &mut &[u8]) -> Result<FieldValue, MessageDecodeError> {
+ Ok(match self {
+ Self::U32 => FieldValue::U32(<u32 as Readable>::read(input)?),
+ Self::I32 => FieldValue::I32(<i32 as Readable>::read(input)?),
+ Self::U16 => FieldValue::U16(<u16 as Readable>::read(input)?),
+ Self::I16 => FieldValue::I16(<i16 as Readable>::read(input)?),
+ Self::U8 => FieldValue::U8(<u8 as Readable>::read(input)?),
+ Self::String => FieldValue::String(<&str as Readable>::read(input)?.into()),
+ Self::ByteArray => FieldValue::ByteArray(<&[u8] as Readable>::read(input)?.into()),
+ })
+ }
+
+ /// Deserialise the field type from a printf style declaration
+ pub(crate) fn from_msg(s: &str) -> Result<Self, MessageSkipperError> {
+ match s {
+ "%u" => Ok(Self::U32),
+ "%i" => Ok(Self::I32),
+ "%hu" => Ok(Self::U16),
+ "%hi" => Ok(Self::I16),
+ "%c" => Ok(Self::U8),
+ "%s" => Ok(Self::String),
+ "%*s" => Ok(Self::ByteArray),
+ "%.*s" => Ok(Self::ByteArray),
+ s => Err(MessageSkipperError::InvalidFormatFieldType(s.to_string())),
+ }
+ }
+
+ /// Deserialise the next field type from a printf style declaration, also returning the rest of the string.
+ pub(crate) fn from_format(s: &str) -> Result<(Self, &str), MessageSkipperError> {
+ if let Some(rest) = s.strip_prefix("%u") {
+ Ok((Self::U32, rest))
+ } else if let Some(rest) = s.strip_prefix("%i") {
+ Ok((Self::I32, rest))
+ } else if let Some(rest) = s.strip_prefix("%hu") {
+ Ok((Self::U16, rest))
+ } else if let Some(rest) = s.strip_prefix("%hi") {
+ Ok((Self::I16, rest))
+ } else if let Some(rest) = s.strip_prefix("%c") {
+ Ok((Self::U8, rest))
+ } else if let Some(rest) = s.strip_prefix("%.*s") {
+ Ok((Self::ByteArray, rest))
+ } else if let Some(rest) = s.strip_prefix("%*s") {
+ Ok((Self::String, rest))
+ } else {
+ Err(MessageSkipperError::InvalidFormatFieldType(s.to_string()))
+ }
+ }
+}
+
+/// The decoded value of a field
+#[derive(Debug)]
+pub enum FieldValue {
+ U32(u32),
+ I32(i32),
+ U16(u16),
+ I16(i16),
+ U8(u8),
+ String(String),
+ ByteArray(Vec<u8>),
+}
+
+impl Display for FieldValue {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Self::U32(v) => write!(f, "{}", v),
+ Self::I32(v) => write!(f, "{}", v),
+ Self::U16(v) => write!(f, "{}", v),
+ Self::I16(v) => write!(f, "{}", v),
+ Self::U8(v) => write!(f, "{}", v),
+ Self::String(v) => write!(f, "{}", v),
+ Self::ByteArray(v) => write!(f, "{:?}", v),
+ }
+ }
+}
+
+/// A type which can be expressed as a [`FieldType`]
pub trait ToFieldType: Sized {
+ /// Get the corresponding [`FieldType`].
fn as_field_type() -> FieldType;
}
+/// Implements [`Readable`], [`Writable`], and [`ToFieldType`] for the given integer type.
macro_rules! int_readwrite {
( $type:tt, $field_type:expr ) => {
impl Readable<'_> for $type {
fn read(data: &mut &[u8]) -> Result<Self, MessageDecodeError> {
parse_vlq_int(data).map(|v| v as $type)
}
-
- fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> {
- parse_vlq_int(data).map(|_| ())
- }
}
impl Writable for $type {
@@ -118,10 +231,6 @@ impl Readable<'_> for bool {
fn read(data: &mut &[u8]) -> Result<Self, MessageDecodeError> {
parse_vlq_int(data).map(|v| v != 0)
}
-
- fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> {
- parse_vlq_int(data).map(|_| ())
- }
}
impl Writable for bool {
@@ -217,111 +326,20 @@ impl Writable for &str {
}
}
-impl Borrowable for String {
- type Borrowed<'a> = &'a str;
- fn from_borrowed(src: Self::Borrowed<'_>) -> Self {
- src.to_string()
- }
-}
-
impl ToFieldType for String {
fn as_field_type() -> FieldType {
FieldType::String
}
}
-#[derive(Copy, Clone, Debug, Eq, PartialEq)]
-pub enum FieldType {
- U32,
- I32,
- U16,
- I16,
- U8,
- String,
- ByteArray,
-}
-
-impl FieldType {
- pub(crate) fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> {
- match self {
- Self::U32 => <u32 as Readable>::skip(input),
- Self::I32 => <i32 as Readable>::skip(input),
- Self::U16 => <u16 as Readable>::skip(input),
- Self::I16 => <i16 as Readable>::skip(input),
- Self::U8 => <u8 as Readable>::skip(input),
- Self::String => <&str as Readable>::skip(input),
- Self::ByteArray => <&[u8] as Readable>::skip(input),
- }
- }
-
- pub(crate) fn read(&self, input: &mut &[u8]) -> Result<FieldValue, MessageDecodeError> {
- Ok(match self {
- Self::U32 => FieldValue::U32(<u32 as Readable>::read(input)?),
- Self::I32 => FieldValue::I32(<i32 as Readable>::read(input)?),
- Self::U16 => FieldValue::U16(<u16 as Readable>::read(input)?),
- Self::I16 => FieldValue::I16(<i16 as Readable>::read(input)?),
- Self::U8 => FieldValue::U8(<u8 as Readable>::read(input)?),
- Self::String => FieldValue::String(<&str as Readable>::read(input)?.into()),
- Self::ByteArray => FieldValue::ByteArray(<&[u8] as Readable>::read(input)?.into()),
- })
- }
-
- pub(crate) fn from_msg(s: &str) -> Result<Self, MessageSkipperError> {
- match s {
- "%u" => Ok(Self::U32),
- "%i" => Ok(Self::I32),
- "%hu" => Ok(Self::U16),
- "%hi" => Ok(Self::I16),
- "%c" => Ok(Self::U8),
- "%s" => Ok(Self::String),
- "%*s" => Ok(Self::ByteArray),
- "%.*s" => Ok(Self::ByteArray),
- s => Err(MessageSkipperError::InvalidFormatFieldType(s.to_string())),
- }
- }
-
- pub(crate) fn from_format(s: &str) -> Result<(Self, &str), MessageSkipperError> {
- if let Some(rest) = s.strip_prefix("%u") {
- Ok((Self::U32, rest))
- } else if let Some(rest) = s.strip_prefix("%i") {
- Ok((Self::I32, rest))
- } else if let Some(rest) = s.strip_prefix("%hu") {
- Ok((Self::U16, rest))
- } else if let Some(rest) = s.strip_prefix("%hi") {
- Ok((Self::I16, rest))
- } else if let Some(rest) = s.strip_prefix("%c") {
- Ok((Self::U8, rest))
- } else if let Some(rest) = s.strip_prefix("%.*s") {
- Ok((Self::ByteArray, rest))
- } else if let Some(rest) = s.strip_prefix("%*s") {
- Ok((Self::String, rest))
- } else {
- Err(MessageSkipperError::InvalidFormatFieldType(s.to_string()))
- }
- }
-}
-
-#[derive(Debug)]
-pub enum FieldValue {
- U32(u32),
- I32(i32),
- U16(u16),
- I16(i16),
- U8(u8),
- String(String),
- ByteArray(Vec<u8>),
-}
+/// Error encountered when decoding a message
+#[derive(thiserror::Error, Debug, Clone)]
+pub enum MessageDecodeError {
+ /// More data was expected but none is available
+ #[error("eof unexpected")]
+ UnexpectedEof,
-impl Display for FieldValue {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- Self::U32(v) => write!(f, "{}", v),
- Self::I32(v) => write!(f, "{}", v),
- Self::U16(v) => write!(f, "{}", v),
- Self::I16(v) => write!(f, "{}", v),
- Self::U8(v) => write!(f, "{}", v),
- Self::String(v) => write!(f, "{}", v),
- Self::ByteArray(v) => write!(f, "{:?}", v),
- }
- }
+ /// A received string could not be decoded as UTF8
+ #[error("invalid utf8 string")]
+ Utf8Error(#[from] std::str::Utf8Error),
}
diff --git a/crates/windlass/src/lib.rs b/crates/windlass/src/lib.rs
index c578bf2..66a7674 100644
--- a/crates/windlass/src/lib.rs
+++ b/crates/windlass/src/lib.rs
@@ -1,21 +1,13 @@
+//! Windlass is an implementation of the host side of the Klipper protocol.
+
#[macro_use]
#[doc(hidden)]
pub mod macros;
-
-#[doc(hidden)]
pub mod dictionary;
-
-#[doc(hidden)]
pub mod encoding;
-
-mod mcu;
-
-#[doc(hidden)]
pub mod messages;
+mod mcu;
mod transport;
-pub use dictionary::DictionaryError;
-pub use encoding::MessageDecodeError;
pub use mcu::{McuConnection, McuConnectionError};
-pub use messages::EncodedMessage;
diff --git a/crates/windlass/src/macros.rs b/crates/windlass/src/macros.rs
index 6f57a70..f17c66d 100644
--- a/crates/windlass/src/macros.rs
+++ b/crates/windlass/src/macros.rs
@@ -1,16 +1,73 @@
-pub use crate::transport::MESSAGE_LENGTH_PAYLOAD_MAX;
-
+/// Declare a host -> device message
+///
+/// Declares the format of a message sent from host to device. The general format is as follows:
+///
+/// ```text
+/// mcu_command!(<struct name>, "<command name>"( = <id>) [, arg: type, ..]);
+/// ```
+///
+/// A struct with the given name will be defined, with relevant interfaces. The message name and
+/// arguments will be matched to the MCU at runtime. The struct name has no restrictions, but it is
+/// recommended to pick a SnakeCased version of the command name.
+///
+/// Optionally an `id` can be directly specified. Generally this is not needed. When not specified,
+/// it will be automatically inferred at runtime and matched to the dictionary retrieved from the
+/// target MCU.
+///
+/// Arguments can be specified, they will be mapped to the relevant Klipper argument types. The
+/// supported types and mappings are as follows:
+///
+/// | Rust type | Format string |
+/// |------------|---------------|
+/// | `u32` | `%u` |
+/// | `i32` | `%i` |
+/// | `u16` | `%hu` |
+/// | `i16` | `%hi` |
+/// | `u8` | `%c` |
+/// | `&'a [u8]` | `%.*s`, `%*s` |
+/// | `&'a str` | `%s` |
+///
+/// Note that the buffer types take a lifetime. This must always be `'a`.
+///
+/// # Examples
+///
+/// ```ignore
+/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c'
+/// mcu_command!(ConfigEndstop, "config_endstop", oid: u8, pin: u8, pull_up: u8);
+/// ```
#[macro_export]
-#[doc(hidden)]
-macro_rules! mcu_message_impl_oid_check {
- ($ty_name:ident) => {
- impl $crate::messages::WithoutOid for $ty_name {}
+macro_rules! mcu_command {
+ ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
+ $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*);
};
- ($ty_name:ident, oid $(, $args:ident)*) => {
- impl $crate::messages::WithOid for $ty_name {}
+ ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
+ $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*);
};
- ($ty_name:ident, $arg:ident $(, $args:ident)*) => {
- $crate::mcu_message_impl_oid_check!($ty_name $(, $args)*);
+}
+
+/// Declare a device -> host message
+///
+/// Declares the format of a message sent from device to host. The general format is as follows:
+///
+/// ```text
+/// mcu_reply!(<struct name>, "<reply name>"( = <id>) [, arg: type, ..]);
+/// ```
+///
+/// For more information on the various fields, see the documentation for [mcu_command].
+///
+/// # Examples
+///
+/// ```ignore
+/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c'
+/// mcu_reply!(Uptime, "uptime", high: u32, clock: u32);
+/// ```
+#[macro_export]
+macro_rules! mcu_reply {
+ ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
+ $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*);
+ };
+ ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
+ $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*);
};
}
@@ -53,7 +110,7 @@ macro_rules! mcu_message_impl {
impl<'a> [<$ty_name Data>]<'a> {
fn encode(&self) -> $crate::messages::FrontTrimmableBuffer {
use $crate::encoding::Writable;
- let mut buf = Vec::with_capacity($crate::macros::MESSAGE_LENGTH_PAYLOAD_MAX);
+ let mut buf = Vec::with_capacity($crate::transport::MESSAGE_LENGTH_PAYLOAD_MAX);
buf.push(0);
buf.push(0);
$(self.$arg.write(&mut buf);)*
@@ -103,7 +160,7 @@ macro_rules! mcu_message_impl {
type PodOwned = [<$ty_name DataOwned>];
fn get_id(dict: Option<&$crate::dictionary::Dictionary>) -> Option<u16> {
- $cmd_id.or_else(|| dict.and_then(|dict| dict.message_ids.get($cmd_name).copied()))
+ $cmd_id.or_else(|| dict.and_then(|dict| dict.message_id($cmd_name)))
}
fn get_name() -> &'static str {
@@ -126,75 +183,17 @@ macro_rules! mcu_message_impl {
};
}
-/// Declare a host -> device message
-///
-/// Declares the format of a message sent from host to device. The general format is as follows:
-///
-/// ```text
-/// mcu_command!(<struct name>, "<command name>"( = <id>) [, arg: type, ..]);
-/// ```
-///
-/// A struct with the given name will be defined, with relevant interfaces. The message name and
-/// arguments will be matched to the MCU at runtime. The struct name has no restrictions, but it is
-/// recommended to pick a SnakeCased version of the command name.
-///
-/// Optionally an `id` can be directly specified. Generally this is not needed. When not specified,
-/// it will be automatically inferred at runtime and matched to the dictionary retrieved from the
-/// target MCU.
-///
-/// Arguments can be specified, they will be mapped to the relevant Klipper argument types. The
-/// supported types and mappings are as follows:
-///
-/// | Rust type | Format string |
-/// |------------|---------------|
-/// | `u32` | `%u` |
-/// | `i32` | `%i` |
-/// | `u16` | `%hu` |
-/// | `i16` | `%hi` |
-/// | `u8` | `%c` |
-/// | `&'a [u8]` | `%.*s`, `%*s` |
-/// | `&'a str` | `%s` |
-///
-/// Note that the buffer types take a lifetime. This must always be `'a`.
-///
-/// # Examples
-///
-/// ```ignore
-/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c'
-/// mcu_command!(ConfigEndstop, "config_endstop", oid: u8, pin: u8, pull_up: u8);
-/// ```
#[macro_export]
-macro_rules! mcu_command {
- ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
- $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*);
- };
- ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
- $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*);
+#[doc(hidden)]
+/// Implement the [`WithOid`](crate::messages::WithOid) / [`WithoutOid`](crate::messages::WithoutOid) marker type for an MCU message type.
+macro_rules! mcu_message_impl_oid_check {
+ ($ty_name:ident) => {
+ impl $crate::messages::WithoutOid for $ty_name {}
};
-}
-
-/// Declare a device -> host message
-///
-/// Declares the format of a message sent from device to host. The general format is as follows:
-///
-/// ```text
-/// mcu_reply!(<struct name>, "<reply name>"( = <id>) [, arg: type, ..]);
-/// ```
-///
-/// For more information on the various fields, see the documentation for [mcu_command].
-///
-/// # Examples
-///
-/// ```ignore
-/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c'
-/// mcu_reply!(Uptime, "uptime", high: u32, clock: u32);
-/// ```
-#[macro_export]
-macro_rules! mcu_reply {
- ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
- $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*);
+ ($ty_name:ident, oid $(, $args:ident)*) => {
+ impl $crate::messages::WithOid for $ty_name {}
};
- ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
- $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*);
+ ($ty_name:ident, $arg:ident $(, $args:ident)*) => {
+ $crate::mcu_message_impl_oid_check!($ty_name $(, $args)*);
};
}
diff --git a/crates/windlass/src/mcu.rs b/crates/windlass/src/mcu.rs
index 53fea7f..c2ca54a 100644
--- a/crates/windlass/src/mcu.rs
+++ b/crates/windlass/src/mcu.rs
@@ -7,7 +7,7 @@ use std::{
};
use tokio::{
- io::{AsyncRead, AsyncWrite},
+ io::{AsyncBufRead, AsyncWrite},
select,
sync::{
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
@@ -181,13 +181,7 @@ impl McuConnectionInner {
// There was no registered raw handler, check the dictionary to decode
if let Some(dict) = self.dictionary.get() {
- if let Some(parser) = dict.message_parsers.get(&cmd) {
- if let Some(msg) = parser.parse_and_output(frame) {
- let msg = msg?;
- debug!(msg = msg, "Output message from MCU");
- continue;
- }
-
+ if let Some(parser) = dict.message_parser(cmd) {
let mut curmsg = &frame[..];
let oid = parser.skip_with_oid(frame)?;
if tracing::enabled!(tracing::Level::TRACE) {
@@ -245,11 +239,12 @@ impl McuConnection {
/// Attempts to connect to a MCU on the given data stream interface. The MCU is contacted and
/// the dictionary is loaded and applied. The returned [McuConnection] is fully ready to
/// communicate with the MCU.
- pub async fn connect<R>(stream: R) -> Result<Self, McuConnectionError>
+ pub async fn connect<R, W>(rdr: R, wr: W) -> Result<Self, McuConnectionError>
where
- R: AsyncRead + AsyncWrite + Send + 'static,
+ R: AsyncBufRead + Send + Unpin + 'static,
+ W: AsyncWrite + Send + Unpin + 'static,
{
- let (transport, inbox) = Transport::connect(stream).await;
+ let (transport, inbox) = Transport::connect(rdr, wr).await;
let (exit_tx, exit_rx) = oneshot::channel();
let inner = Arc::new(McuConnectionInner {
@@ -302,8 +297,7 @@ impl McuConnection {
.map_err(|err| McuConnectionError::DictionaryFetch(Box::new(err)))?;
let raw_dict: RawDictionary = serde_json::from_slice(&buf)
.map_err(|err| McuConnectionError::DictionaryFetch(Box::new(err)))?;
- let dict =
- Dictionary::from_raw_dictionary(raw_dict).map_err(McuConnectionError::Dictionary)?;
+ let dict = Dictionary::try_from(raw_dict).map_err(McuConnectionError::Dictionary)?;
debug!(dictionary = ?dict, "MCU dictionary");
conn.inner
.dictionary
@@ -336,7 +330,7 @@ impl McuConnection {
.ok_or_else(|| McuConnectionError::UnknownMessageId(C::get_name()))?;
// Must exist because we know the tag
- let parser = dictionary.message_parsers.get(&id).unwrap();
+ let parser = dictionary.message_parser(id).unwrap();
let remote_fields = parser.fields.iter().map(|(s, t)| (s.as_str(), *t));
let local_fields = C::fields().into_iter();
@@ -361,9 +355,8 @@ impl McuConnection {
&self,
command: EncodedMessage<C>,
) -> Result<FrontTrimmableBuffer, McuConnectionError> {
- let id = command
- .message_id(self.inner.dictionary.get())
- .ok_or_else(|| McuConnectionError::UnknownMessageId(command.message_name()))?;
+ let id = C::get_id(self.inner.dictionary.get())
+ .ok_or_else(|| McuConnectionError::UnknownMessageId(C::get_name()))?;
let mut payload = command.payload;
if id >= 0x80 {
payload.content[0] = ((id >> 7) & 0x7F) as u8 | 0x80;
diff --git a/crates/windlass/src/messages.rs b/crates/windlass/src/messages.rs
index 5b97d31..f08a3ca 100644
--- a/crates/windlass/src/messages.rs
+++ b/crates/windlass/src/messages.rs
@@ -3,46 +3,20 @@ use std::collections::BTreeMap;
use crate::dictionary::Dictionary;
use crate::encoding::{FieldType, FieldValue, MessageDecodeError};
-#[derive(thiserror::Error, Debug)]
-pub enum MessageSkipperError {
- #[error("invalid argument format: {0}")]
- InvalidArgumentFormat(String),
- #[error("unknown type '{1}' for argument '{0}'")]
- UnknownType(String, String),
- #[error("invalid format field type '%{0}'")]
- InvalidFormatFieldType(String),
-}
-
-pub(crate) fn format_command_args<'a>(
- fields: impl Iterator<Item = (&'a str, FieldType)>,
-) -> String {
- let mut buf = String::new();
- for (idx, (name, ty)) in fields.enumerate() {
- if idx != 0 {
- buf.push(' ');
- }
- buf.push_str(name);
- buf.push('=');
- buf.push_str(match ty {
- FieldType::U32 => "%u",
- FieldType::I32 => "%i",
- FieldType::U16 => "%hu",
- FieldType::I16 => "%hi",
- FieldType::U8 => "%c",
- FieldType::String => "%s",
- FieldType::ByteArray => "%*s",
- });
- }
- buf
-}
-
+/// A parser for a single message type
pub struct MessageParser {
+ /// The name of the message
pub name: String,
+
+ /// The fields of the message, and their types.
pub fields: Vec<(String, FieldType)>,
+
+ /// How the message should be debug printed
pub output: Option<OutputFormat>,
}
impl MessageParser {
+ /// Create a parser for a message with the given name and printf declaration parts
pub(crate) fn new<'a>(
name: &str,
parts: impl Iterator<Item = &'a str>,
@@ -63,6 +37,7 @@ impl MessageParser {
})
}
+ /// Create a parser for a message type with the given printf-style specifier
pub(crate) fn new_output(msg: &str) -> Result<MessageParser, MessageSkipperError> {
let mut fields = vec![];
let mut parts = vec![];
@@ -94,18 +69,17 @@ impl MessageParser {
})
}
+ /// Skip over this message at the top of `input`.
#[allow(dead_code)]
- pub(crate) fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> {
+ pub fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> {
for (_, field) in &self.fields {
field.skip(input)?;
}
Ok(())
}
- pub(crate) fn skip_with_oid(
- &self,
- input: &mut &[u8],
- ) -> Result<Option<u8>, MessageDecodeError> {
+ /// Skip over this message at the top of `input`, but try to read the `oid` field if it is part of this message.
+ pub fn skip_with_oid(&self, input: &mut &[u8]) -> Result<Option<u8>, MessageDecodeError> {
let mut oid = None;
for (name, field) in &self.fields {
if name == "oid" {
@@ -119,7 +93,8 @@ impl MessageParser {
Ok(oid)
}
- pub(crate) fn parse(
+ /// Parse a message of this type from the top of `input`.
+ pub fn parse(
&self,
input: &mut &[u8],
) -> Result<BTreeMap<String, FieldValue>, MessageDecodeError> {
@@ -129,37 +104,31 @@ impl MessageParser {
}
Ok(output)
}
-
- pub(crate) fn parse_and_output(
- &self,
- input: &mut &[u8],
- ) -> Option<Result<String, MessageDecodeError>> {
- self.output.as_ref().map(|output| {
- let fields = self.parse(input)?;
- Ok(output.format(fields.values()))
- })
- }
-}
-
-impl std::fmt::Debug for MessageParser {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.debug_map()
- .entry(&"name", &self.name)
- .entry(&"fields", &self.fields)
- .finish()
- }
}
+/// A message type
pub trait Message: 'static {
type Pod<'a>: Into<Self::PodOwned> + std::fmt::Debug;
type PodOwned: Clone + Send + std::fmt::Debug + 'static;
+
+ /// Get the message ID from the given data dictionary
fn get_id(dict: Option<&Dictionary>) -> Option<u16>;
+
+ /// Get the message name
+ // TODO: this could be an associated constant?
fn get_name() -> &'static str;
+
+ /// Decode this message type from the top of `input`.
fn decode<'a>(input: &mut &'a [u8]) -> Result<Self::Pod<'a>, MessageDecodeError>;
+
+ /// Get a list of field names and types
fn fields() -> Vec<(&'static str, FieldType)>;
}
+/// Marker trait for messages with an oid field
pub trait WithOid: 'static {}
+
+/// Marker trait for messages without an oid field
pub trait WithoutOid: 'static {}
/// Represents an encoded message, with a type-level link to the message kind
@@ -168,27 +137,6 @@ pub struct EncodedMessage<M> {
pub _message_kind: std::marker::PhantomData<M>,
}
-impl<R: Message> std::fmt::Debug for EncodedMessage<R> {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.debug_struct("EncodedMessage")
- .field("kind", &R::get_name())
- .field("payload", &self.payload)
- .finish()
- }
-}
-
-impl<M: Message> EncodedMessage<M> {
- #[doc(hidden)]
- pub fn message_id(&self, dict: Option<&Dictionary>) -> Option<u16> {
- M::get_id(dict)
- }
-
- #[doc(hidden)]
- pub fn message_name(&self) -> &'static str {
- M::get_name()
- }
-}
-
/// Wraps a `Vec<u8>` allowing removal of front bytes in a zero-copy way
pub struct FrontTrimmableBuffer {
pub content: Vec<u8>,
@@ -196,17 +144,12 @@ pub struct FrontTrimmableBuffer {
}
impl FrontTrimmableBuffer {
+ /// Get the rest of the buffer as a slice
pub fn as_slice(&self) -> &[u8] {
&self.content[self.offset..]
}
}
-impl std::fmt::Debug for FrontTrimmableBuffer {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- std::fmt::Debug::fmt(self.as_slice(), f)
- }
-}
-
/// Holds the format of a `output()` style debug message
#[derive(Debug)]
pub struct OutputFormat {
@@ -214,7 +157,8 @@ pub struct OutputFormat {
}
impl OutputFormat {
- fn format<'a>(&self, mut fields: impl Iterator<Item = &'a FieldValue>) -> String {
+ /// Format the given fields according to this output format.
+ pub fn format<'a>(&self, mut fields: impl Iterator<Item = &'a FieldValue>) -> String {
let mut buf = String::new();
for part in &self.parts {
match part {
@@ -230,8 +174,70 @@ impl OutputFormat {
}
}
+/// Part of an [`OutputFormat`].
#[derive(Debug)]
enum FormatBlock {
Static(String),
Field,
}
+
+/// Format the given name and type pairs as a printf style declaration.
+pub(crate) fn format_command_args<'a>(
+ fields: impl Iterator<Item = (&'a str, FieldType)>,
+) -> String {
+ let mut buf = String::new();
+ for (idx, (name, ty)) in fields.enumerate() {
+ if idx != 0 {
+ buf.push(' ');
+ }
+ buf.push_str(name);
+ buf.push('=');
+ buf.push_str(match ty {
+ FieldType::U32 => "%u",
+ FieldType::I32 => "%i",
+ FieldType::U16 => "%hu",
+ FieldType::I16 => "%hi",
+ FieldType::U8 => "%c",
+ FieldType::String => "%s",
+ FieldType::ByteArray => "%*s",
+ });
+ }
+ buf
+}
+
+/// An error enountered when parsing a message format string
+#[derive(thiserror::Error, Debug)]
+pub enum MessageSkipperError {
+ #[error("invalid argument format: {0}")]
+ InvalidArgumentFormat(String),
+
+ #[error("unknown type '{1}' for argument '{0}'")]
+ UnknownType(String, String),
+
+ #[error("invalid format field type '%{0}'")]
+ InvalidFormatFieldType(String),
+}
+
+impl std::fmt::Debug for FrontTrimmableBuffer {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ std::fmt::Debug::fmt(self.as_slice(), f)
+ }
+}
+
+impl std::fmt::Debug for MessageParser {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_map()
+ .entry(&"name", &self.name)
+ .entry(&"fields", &self.fields)
+ .finish()
+ }
+}
+
+impl<R: Message> std::fmt::Debug for EncodedMessage<R> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("EncodedMessage")
+ .field("kind", &R::get_name())
+ .field("payload", &self.payload)
+ .finish()
+ }
+}
diff --git a/crates/windlass/src/transport.rs b/crates/windlass/src/transport.rs
deleted file mode 100644
index e47cc8e..0000000
--- a/crates/windlass/src/transport.rs
+++ /dev/null
@@ -1,615 +0,0 @@
-use std::{collections::VecDeque, sync::Arc};
-use tokio::{
- io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader},
- pin, select,
- sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
- task::{spawn, JoinHandle},
- time::{sleep_until, Instant},
-};
-use tokio_util::sync::CancellationToken;
-use tracing::trace;
-
-const MESSAGE_HEADER_SIZE: usize = 2;
-const MESSAGE_TRAILER_SIZE: usize = 3;
-pub const MESSAGE_LENGTH_MIN: usize = MESSAGE_HEADER_SIZE + MESSAGE_TRAILER_SIZE;
-pub const MESSAGE_LENGTH_MAX: usize = 64;
-pub const MESSAGE_LENGTH_PAYLOAD_MAX: usize = MESSAGE_LENGTH_MAX - MESSAGE_LENGTH_MIN;
-const MESSAGE_POSITION_SEQ: usize = 1;
-const MESSAGE_TRAILER_CRC: usize = 3;
-const MESSAGE_VALUE_SYNC: u8 = 0x7E;
-const MESSAGE_DEST: u8 = 0x10;
-const MESSAGE_SEQ_MASK: u8 = 0x0F;
-
-#[derive(thiserror::Error, Debug)]
-pub enum ReceiverError {
- #[error("io error: {0}")]
- IoError(#[from] std::io::Error),
-}
-
-#[derive(thiserror::Error, Debug)]
-pub enum TransmitterError {
- #[error("io error: {0}")]
- IoError(#[from] std::io::Error),
- #[error("connection closed")]
- ConnectionClosed,
-}
-
-#[derive(Debug)]
-pub(crate) struct Transport {
- _task_rdr: JoinHandle<()>,
- _task_wr: JoinHandle<()>,
- task_inner: JoinHandle<()>,
- outbox_tx: UnboundedSender<TransportCommand>,
-}
-
-#[derive(Debug)]
-enum TransportCommand {
- SendMessage(Vec<u8>),
- Exit,
-}
-
-pub(crate) type TransportReceiver = UnboundedReceiver<Result<Vec<u8>, TransportError>>;
-
-impl Transport {
- pub(crate) async fn connect<R>(stream: R) -> (Transport, TransportReceiver)
- where
- R: AsyncRead + AsyncWrite + Send + 'static,
- {
- let (rdr, wr) = tokio::io::split(stream);
-
- let (raw_send_tx, raw_send_rx) = unbounded_channel();
- let (raw_recv_tx, raw_recv_rx) = unbounded_channel();
- let (app_inbox_tx, app_inbox_rx) = unbounded_channel();
- let (app_outbox_tx, app_outbox_rx) = unbounded_channel();
-
- let cancel_token = CancellationToken::new();
-
- let cancel = cancel_token.clone();
- let ait = app_inbox_tx.clone();
- let task_rdr = spawn(async move {
- if let Err(e) = LowlevelReader::run(raw_recv_tx, rdr, cancel).await {
- let _ = ait.send(Err(TransportError::Receiver(e)));
- }
- });
-
- let cancel = cancel_token.clone();
- let ait = app_inbox_tx.clone();
- let task_wr = spawn(async move {
- if let Err(e) = LowlevelWriter::run(raw_send_rx, wr, cancel).await {
- let _ = ait.send(Err(TransportError::Transmitter(e)));
- }
- });
-
- let task_inner = spawn(async move {
- let mut ts = TransportState::new(
- raw_recv_rx,
- raw_send_tx,
- app_inbox_tx,
- app_outbox_rx,
- cancel_token,
- );
- if let Err(e) = ts.protocol_handler().await {
- let _ = ts.app_inbox.send(Err(e));
- }
- });
-
- (
- Transport {
- _task_rdr: task_rdr,
- _task_wr: task_wr,
- task_inner,
- outbox_tx: app_outbox_tx,
- },
- app_inbox_rx,
- )
- }
-
- pub(crate) fn send(&self, msg: &[u8]) -> Result<(), TransmitterError> {
- self.outbox_tx
- .send(TransportCommand::SendMessage(msg.into()))
- .map_err(|_| TransmitterError::ConnectionClosed)
- }
-
- pub(crate) async fn close(self) {
- let _ = self.outbox_tx.send(TransportCommand::Exit);
- let _ = self.task_inner.await;
- }
-}
-
-struct LowlevelReader<R> {
- rdr: BufReader<R>,
- synced: bool,
-}
-
-impl<R: AsyncRead + Unpin> LowlevelReader<R> {
- async fn read_frame(&mut self) -> Result<Option<Frame>, ReceiverError> {
- let mut buf = [0u8; MESSAGE_LENGTH_MAX];
- let next_byte = self.rdr.read_u8().await?;
- if next_byte == MESSAGE_VALUE_SYNC {
- if !self.synced {
- self.synced = true;
- }
- return Ok(None);
- }
- if !self.synced {
- return Ok(None);
- }
-
- let receive_time = Instant::now();
- let len = next_byte as usize;
-
- if !(MESSAGE_LENGTH_MIN..=MESSAGE_LENGTH_MAX).contains(&len) {
- self.synced = false;
- return Ok(None);
- }
-
- self.rdr.read_exact(&mut buf[1..len]).await?;
- buf[0] = len as u8;
- let buf = &buf[..len];
- let seq = buf[MESSAGE_POSITION_SEQ];
- trace!(frame = ?buf, seq = seq, "Received frame");
-
- if seq & !MESSAGE_SEQ_MASK != MESSAGE_DEST {
- self.synced = false;
- return Ok(None);
- }
-
- let actual_crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]);
- let frame_crc = (buf[len - MESSAGE_TRAILER_CRC] as u16) << 8
- | (buf[len - MESSAGE_TRAILER_CRC + 1] as u16);
- if frame_crc != actual_crc {
- self.synced = false;
- return Ok(None);
- }
-
- Ok(Some(Frame {
- receive_time,
- sequence: seq & MESSAGE_SEQ_MASK,
- payload: buf[MESSAGE_HEADER_SIZE..len - MESSAGE_TRAILER_SIZE].into(),
- }))
- }
-
- async fn run(
- outbox: UnboundedSender<Frame>,
- rdr: R,
- cancel_token: CancellationToken,
- ) -> Result<(), ReceiverError>
- where
- R: AsyncRead + Unpin,
- {
- let mut state = Self {
- rdr: BufReader::new(rdr),
- synced: false,
- };
-
- loop {
- match state.read_frame().await {
- Ok(None) => {}
- Ok(Some(frame)) => {
- if outbox.send(frame).is_err() {
- break Ok(());
- }
- }
- Err(_) if cancel_token.is_cancelled() => break Ok(()),
- Err(e) => break Err(e),
- }
- }
- }
-}
-
-fn crc16(buf: &[u8]) -> u16 {
- let mut crc = 0xFFFFu16;
- for b in buf {
- let b = *b ^ ((crc & 0xFF) as u8);
- let b = b ^ (b << 4);
- let b16 = b as u16;
- crc = (b16 << 8 | crc >> 8) ^ (b16 >> 4) ^ (b16 << 3);
- }
- crc
-}
-
-#[derive(Debug)]
-struct Frame {
- receive_time: Instant,
- sequence: u8,
- payload: Vec<u8>,
-}
-
-#[derive(Debug, Clone)]
-struct InflightFrame {
- sent_at: Instant,
- #[allow(dead_code)]
- sequence: u64,
- payload: Arc<Vec<u8>>,
- is_retransmit: bool,
-}
-
-#[derive(thiserror::Error, Debug)]
-pub enum MessageEncodeError {
- #[error("message would exceed the maximum packet length of {MESSAGE_LENGTH_MAX} bytes")]
- MessageTooLong,
-}
-
-fn encode_frame(sequence: u64, payload: &[u8]) -> Result<Vec<u8>, MessageEncodeError> {
- let len = MESSAGE_LENGTH_MIN + payload.len();
- if len > MESSAGE_LENGTH_MAX {
- return Err(MessageEncodeError::MessageTooLong);
- }
- let mut buf = Vec::with_capacity(len);
- buf.push(len as u8);
- buf.push(MESSAGE_DEST | ((sequence as u8) & MESSAGE_SEQ_MASK));
- buf.extend_from_slice(payload);
- let crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]);
- buf.push(((crc >> 8) & 0xFF) as u8);
- buf.push((crc & 0xFF) as u8);
- buf.push(MESSAGE_VALUE_SYNC);
- Ok(buf)
-}
-
-struct LowlevelWriter {}
-
-impl LowlevelWriter {
- async fn run<W>(
- mut inbox: UnboundedReceiver<Arc<Vec<u8>>>,
- mut wr: W,
- cancel_token: CancellationToken,
- ) -> Result<(), TransmitterError>
- where
- W: AsyncWrite + Unpin,
- {
- loop {
- select! {
- msg = inbox.recv() => {
- let msg = match msg {
- Some(msg) => msg,
- None => break,
- };
- trace!(payload = ?msg, seq = msg[MESSAGE_POSITION_SEQ], "Sent frame");
- wr.write_all(&msg).await?;
- wr.flush().await?;
- }
-
- _ = cancel_token.cancelled() => {
- break;
- }
- }
- }
-
- Ok(())
- }
-}
-
-#[derive(thiserror::Error, Debug)]
-pub enum TransportError {
- #[error("message encoding failed: {0}")]
- MessageEncode(#[from] MessageEncodeError),
- #[error("receiver error: {0}")]
- Receiver(#[from] ReceiverError),
- #[error("transmitter error: {0}")]
- Transmitter(#[from] TransmitterError),
-}
-
-const MIN_RTO: f32 = 0.025;
-const MAX_RTO: f32 = 5.000;
-
-#[derive(Debug)]
-struct RttState {
- srtt: f32,
- rttvar: f32,
- rto: f32,
-}
-
-impl RttState {
- fn new() -> Self {
- Self {
- srtt: 0.0,
- rttvar: 0.0,
- rto: MIN_RTO,
- }
- }
-
- fn rto(&self) -> std::time::Duration {
- std::time::Duration::from_secs_f32(self.rto)
- }
-
- fn update(&mut self, rtt: std::time::Duration) {
- let r = rtt.as_secs_f32();
- if self.srtt == 0.0 {
- self.rttvar = r / 2.0;
- self.srtt = r * 10.0; // Klipper uses this, we'll copy it
- } else {
- self.rttvar = (3.0 * self.rttvar + (self.srtt - r).abs()) / 4.0;
- self.srtt = (7.0 * self.srtt + r) / 8.0;
- }
- let rttvar4 = (self.rttvar * 4.0).max(0.001);
- self.rto = (self.srtt + rttvar4).clamp(MIN_RTO, MAX_RTO);
- }
-}
-
-#[derive(Debug)]
-struct TransportState {
- link_inbox: UnboundedReceiver<Frame>,
- link_outbox: UnboundedSender<Arc<Vec<u8>>>,
- app_inbox: UnboundedSender<Result<Vec<u8>, TransportError>>,
- app_outbox: UnboundedReceiver<TransportCommand>,
- cancel: CancellationToken,
-
- is_synchronized: bool,
- rtt_state: RttState,
- receive_sequence: u64,
- send_sequence: u64,
- last_ack_sequence: u64,
- ignore_nak_seq: u64,
- retransmit_seq: u64,
- retransmit_now: bool,
-
- corked_until: Option<Instant>,
-
- inflight_messages: VecDeque<InflightFrame>,
- ready_messages: MessageQueue<Vec<u8>>,
-}
-
-impl TransportState {
- fn new(
- link_inbox: UnboundedReceiver<Frame>,
- link_outbox: UnboundedSender<Arc<Vec<u8>>>,
- app_inbox: UnboundedSender<Result<Vec<u8>, TransportError>>,
- app_outbox: UnboundedReceiver<TransportCommand>,
- cancel: CancellationToken,
- ) -> Self {
- Self {
- link_inbox,
- link_outbox,
- app_inbox,
- app_outbox,
- cancel,
-
- is_synchronized: false,
- rtt_state: RttState::new(),
- receive_sequence: 1,
- send_sequence: 1,
- last_ack_sequence: 0,
- ignore_nak_seq: 0,
- retransmit_seq: 0,
- retransmit_now: false,
-
- corked_until: None,
-
- inflight_messages: VecDeque::new(),
- ready_messages: MessageQueue::new(),
- }
- }
-
- fn update_receive_seq(&mut self, receive_time: Instant, sequence: u64) {
- let mut sent_seq = self.receive_sequence;
-
- loop {
- if let Some(msg) = self.inflight_messages.pop_front() {
- sent_seq += 1;
- if sequence == sent_seq {
- // Found the matching sent message
- if !msg.is_retransmit {
- let elapsed = receive_time.saturating_duration_since(msg.sent_at);
- self.rtt_state.update(elapsed);
- }
- break;
- }
- } else {
- // Ack with no outstanding messages, happens during connection init
- self.send_sequence = sequence;
- break;
- }
- }
-
- self.receive_sequence = sequence;
- self.is_synchronized = true;
- }
-
- fn handle_frame(&mut self, frame: Frame) {
- let rseq = self.receive_sequence;
- let mut sequence = (rseq & !(MESSAGE_SEQ_MASK as u64)) | (frame.sequence as u64);
- if sequence < rseq {
- sequence += (MESSAGE_SEQ_MASK as u64) + 1;
- }
- if !self.is_synchronized || sequence != rseq {
- if sequence > self.send_sequence && self.is_synchronized {
- // Ack for unsent message
- return;
- }
- self.update_receive_seq(frame.receive_time, sequence);
- }
-
- if frame.payload.is_empty() {
- if self.last_ack_sequence < sequence {
- self.last_ack_sequence = sequence;
- } else if sequence > self.ignore_nak_seq && !self.inflight_messages.is_empty() {
- // Trigger retransmit from NAK
- self.retransmit_now = true;
- }
- } else {
- // Data message, we deliver this directly to the application as the MCU can't actually
- // retransmit anyway.
- let _ = self.app_inbox.send(Ok(frame.payload));
- }
- }
-
- fn can_send(&self) -> bool {
- self.corked_until.is_none() && self.inflight_messages.len() < 12
- }
-
- fn send_new_frame(&mut self, mut initial: Vec<u8>) -> Result<(), TransportError> {
- while let Some(next) = self.ready_messages.try_peek() {
- if initial.len() + next.len() <= MESSAGE_LENGTH_PAYLOAD_MAX {
- // Add to the end of the message. Unwrap is safe because we already peeked.
- let mut next = self.ready_messages.try_recv().unwrap();
- initial.append(&mut next);
- } else {
- break;
- }
- }
- let frame = Arc::new(encode_frame(self.send_sequence, &initial)?);
- self.send_sequence += 1;
- self.inflight_messages.push_back(InflightFrame {
- sent_at: Instant::now(),
- sequence: self.send_sequence,
- payload: frame.clone(),
- is_retransmit: false,
- });
- let _ = self.link_outbox.send(frame);
- Ok(())
- }
-
- fn send_more_frames(&mut self) -> Result<(), TransportError> {
- while self.can_send() && self.ready_messages.try_peek().is_some() {
- let msg = self.ready_messages.try_recv().unwrap();
- self.send_new_frame(msg)?;
- }
- Ok(())
- }
-
- fn retransmit_pending(&mut self) {
- let len: usize = self
- .inflight_messages
- .iter()
- .map(|msg| msg.payload.len())
- .sum();
- let mut buf = Vec::with_capacity(1 + len);
- buf.push(MESSAGE_VALUE_SYNC);
- let now = Instant::now();
- for msg in self.inflight_messages.iter_mut() {
- buf.extend_from_slice(&msg.payload);
- msg.is_retransmit = true;
- msg.sent_at = now;
- }
- let _ = self.link_outbox.send(Arc::new(buf));
-
- if self.retransmit_now {
- self.ignore_nak_seq = self.receive_sequence;
- if self.receive_sequence < self.retransmit_seq {
- self.ignore_nak_seq = self.retransmit_seq;
- }
- self.retransmit_now = false;
- } else {
- self.rtt_state.rto = (self.rtt_state.rto * 2.0).clamp(MIN_RTO, MAX_RTO);
- self.ignore_nak_seq = self.send_sequence;
- }
- self.retransmit_seq = self.send_sequence;
- }
-
- async fn protocol_handler(&mut self) -> Result<(), TransportError> {
- loop {
- if self.retransmit_now {
- self.retransmit_pending();
- }
-
- let retransmit_deadline = self
- .inflight_messages
- .front()
- .map(|msg| msg.sent_at + self.rtt_state.rto());
- let retransmit_timeout: futures::future::OptionFuture<_> =
- retransmit_deadline.map(sleep_until).into();
- pin!(retransmit_timeout);
-
- let corked_timeout: futures::future::OptionFuture<_> =
- self.corked_until.map(sleep_until).into();
- pin!(corked_timeout);
-
- select! {
- frame = self.link_inbox.recv() => {
- let frame = match frame {
- Some(frame) => frame,
- None => break,
- };
- self.handle_frame(frame);
- },
-
- msg = self.app_outbox.recv() => {
- match msg {
- Some(TransportCommand::SendMessage(msg)) => {
- self.ready_messages.send(msg);
- },
- Some(TransportCommand::Exit) => {
- self.cancel.cancel();
- }
- None => break,
- };
- },
-
- _ = self.ready_messages.recv_peek(), if self.can_send() => {
- self.send_more_frames()?;
- },
-
- _ = &mut retransmit_timeout, if retransmit_deadline.is_some() => {
- self.retransmit_now = true;
- },
-
- _ = &mut corked_timeout, if self.corked_until.is_some() => {
- // Timeout for when we are able to send again
- }
-
- _ = self.cancel.cancelled() => {
- break;
- },
- }
- }
- Ok(())
- }
-}
-
-#[derive(Debug)]
-struct MessageQueue<T> {
- sender: UnboundedSender<T>,
- receiver: UnboundedReceiver<T>,
- peeked: Option<T>,
-}
-
-impl<T> MessageQueue<T> {
- fn new() -> Self {
- let (sender, receiver) = unbounded_channel();
- Self {
- sender,
- receiver,
- peeked: None,
- }
- }
-
- async fn recv_peek(&mut self) -> Option<&T> {
- if self.peeked.is_some() {
- self.peeked.as_ref()
- } else {
- match self.receiver.recv().await {
- Some(msg) => {
- self.peeked = Some(msg);
- self.peeked.as_ref()
- }
- None => None,
- }
- }
- }
-
- fn try_recv(&mut self) -> Option<T> {
- if let Some(msg) = self.peeked.take() {
- Some(msg)
- } else {
- self.receiver.try_recv().ok()
- }
- }
-
- fn try_peek(&mut self) -> Option<&T> {
- if self.peeked.is_some() {
- self.peeked.as_ref()
- } else {
- match self.receiver.try_recv() {
- Ok(msg) => {
- self.peeked = Some(msg);
- self.peeked.as_ref()
- }
- Err(_) => None,
- }
- }
- }
-
- fn send(&mut self, msg: T) {
- let _ = self.sender.send(msg);
- }
-}
diff --git a/crates/windlass/src/transport/mod.rs b/crates/windlass/src/transport/mod.rs
new file mode 100644
index 0000000..f5c5fc3
--- /dev/null
+++ b/crates/windlass/src/transport/mod.rs
@@ -0,0 +1,434 @@
+use read::{FrameReader, ReceivedFrame, ReceiverError};
+use std::{collections::VecDeque, sync::Arc, time::Duration};
+use tokio::{
+ io::{AsyncBufRead, AsyncWrite, AsyncWriteExt},
+ pin, select, spawn,
+ sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
+ task::JoinHandle,
+ time::{sleep_until, Instant},
+};
+use tokio_util::sync::CancellationToken;
+
+use crate::encoding::crc16;
+
+pub(crate) const MESSAGE_HEADER_SIZE: usize = 2;
+pub(crate) const MESSAGE_TRAILER_SIZE: usize = 3;
+pub(crate) const MESSAGE_LENGTH_MIN: usize = MESSAGE_HEADER_SIZE + MESSAGE_TRAILER_SIZE;
+pub(crate) const MESSAGE_LENGTH_MAX: usize = 64;
+pub(crate) const MESSAGE_LENGTH_PAYLOAD_MAX: usize = MESSAGE_LENGTH_MAX - MESSAGE_LENGTH_MIN;
+pub(crate) const MESSAGE_POSITION_SEQ: usize = 1;
+pub(crate) const MESSAGE_TRAILER_CRC: usize = 3;
+pub(crate) const MESSAGE_VALUE_SYNC: u8 = 0x7E;
+pub(crate) const MESSAGE_DEST: u8 = 0x10;
+pub(crate) const MESSAGE_SEQ_MASK: u8 = 0x0F;
+
+mod read;
+
+/// Wrapper around a connection to a klipper firmware MCU, which deals with
+/// retransmission, flow control, etc.
+///
+/// Internally, this holds a handle to an async task and a channel to communicate with it,
+/// meaning operations don't necessarily block the current task.
+#[derive(Debug)]
+pub struct Transport {
+ // Handles to the associated async task
+ task_inner: JoinHandle<()>,
+
+ /// Queue for outbound messages
+ cmd_send: UnboundedSender<TransportCommand>,
+}
+
+/// A message sent to the transport task
+#[derive(Debug)]
+enum TransportCommand {
+ SendMessage(Vec<u8>),
+ Exit,
+}
+
+pub(crate) type TransportReceiver = UnboundedReceiver<Result<Vec<u8>, TransportError>>;
+
+impl Transport {
+ pub(crate) async fn connect(
+ rdr: impl AsyncBufRead + Unpin + Send + 'static,
+ wr: impl AsyncWrite + Unpin + Send + 'static,
+ ) -> (Transport, TransportReceiver) {
+ let (data_send, data_recv) = unbounded_channel();
+ let (cmd_send, cmd_recv) = unbounded_channel();
+
+ let cancel_token = CancellationToken::new();
+ let task_inner = spawn(async move {
+ let mut ts = TransportState::new(rdr, wr, data_send, cmd_recv, cancel_token);
+ if let Err(e) = ts.run().await {
+ let _ = ts.data_send.send(Err(e));
+ }
+ });
+
+ (
+ Transport {
+ task_inner,
+ cmd_send,
+ },
+ data_recv,
+ )
+ }
+
+ pub(crate) fn send(&self, msg: &[u8]) -> Result<(), TransmitterError> {
+ self.cmd_send
+ .send(TransportCommand::SendMessage(msg.into()))
+ .map_err(|_| TransmitterError::ConnectionClosed)
+ }
+
+ pub(crate) async fn close(self) {
+ let _ = self.cmd_send.send(TransportCommand::Exit);
+ let _ = self.task_inner.await;
+ }
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum TransportError {
+ #[error("message encoding failed: {0}")]
+ MessageEncode(#[from] MessageEncodeError),
+
+ #[error("receiver error: {0}")]
+ Receiver(#[from] ReceiverError),
+
+ #[error("transmitter error: {0}")]
+ Transmitter(#[from] TransmitterError),
+
+ #[error("io error: {0}")]
+ IOError(#[from] std::io::Error),
+}
+
+const MIN_RTO: f32 = 0.025;
+const MAX_RTO: f32 = 5.000;
+
+/// State for estimating the round trip time of the connection
+#[derive(Debug)]
+struct RttState {
+ srtt: f32,
+ rttvar: f32,
+ rto: f32,
+}
+
+impl Default for RttState {
+ fn default() -> Self {
+ Self {
+ srtt: 0.0,
+ rttvar: 0.0,
+ rto: MIN_RTO,
+ }
+ }
+}
+
+impl RttState {
+ /// Get the current recommended retransmission timeout
+ fn rto(&self) -> Duration {
+ Duration::from_secs_f32(self.rto)
+ }
+
+ /// Update the RTT estimation given a new observation
+ fn update(&mut self, rtt: Duration) {
+ let r = rtt.as_secs_f32();
+ if self.srtt == 0.0 {
+ self.rttvar = r / 2.0;
+ self.srtt = r * 10.0; // Klipper uses this, we'll copy it
+ } else {
+ self.rttvar = (3.0 * self.rttvar + (self.srtt - r).abs()) / 4.0;
+ self.srtt = (7.0 * self.srtt + r) / 8.0;
+ }
+ let rttvar4 = (self.rttvar * 4.0).max(0.001);
+ self.rto = (self.srtt + rttvar4).clamp(MIN_RTO, MAX_RTO);
+ }
+}
+
+/// State for the task which deals with transport state
+#[derive(Debug)]
+struct TransportState<R, W> {
+ rdr: FrameReader<R>,
+ wr: W,
+
+ data_send: UnboundedSender<Result<Vec<u8>, TransportError>>,
+ cmd_recv: UnboundedReceiver<TransportCommand>,
+
+ cancel: CancellationToken,
+
+ is_synchronized: bool,
+ rtt_state: RttState,
+ receive_sequence: u64,
+ send_sequence: u64,
+ last_ack_sequence: u64,
+ ignore_nak_seq: u64,
+ retransmit_seq: u64,
+ retransmit_now: bool,
+
+ corked_until: Option<Instant>,
+
+ inflight_messages: VecDeque<SentFrame>,
+ pending_messages: VecDeque<Vec<u8>>,
+}
+
+impl<R: AsyncBufRead + Unpin, W: AsyncWrite + Unpin> TransportState<R, W> {
+ fn new(
+ rdr: R,
+ wr: W,
+ data_send: UnboundedSender<Result<Vec<u8>, TransportError>>,
+ cmd_recv: UnboundedReceiver<TransportCommand>,
+ cancel: CancellationToken,
+ ) -> Self {
+ Self {
+ rdr: FrameReader::new(rdr),
+ wr,
+ data_send,
+ cmd_recv,
+ cancel,
+
+ is_synchronized: false,
+ rtt_state: RttState::default(),
+ receive_sequence: 1,
+ send_sequence: 1,
+ last_ack_sequence: 0,
+ ignore_nak_seq: 0,
+ retransmit_seq: 0,
+ retransmit_now: false,
+
+ corked_until: None,
+
+ inflight_messages: VecDeque::new(),
+ pending_messages: VecDeque::new(),
+ }
+ }
+
+ async fn run(&mut self) -> Result<(), TransportError> {
+ loop {
+ if self.retransmit_now {
+ self.retransmit_pending().await?;
+ }
+
+ if !self.pending_messages.is_empty() && self.can_send() {
+ self.send_more_frames().await?;
+ }
+
+ let retransmit_deadline = self
+ .inflight_messages
+ .front()
+ .map(|msg| msg.sent_at + self.rtt_state.rto());
+ let retransmit_timeout: futures::future::OptionFuture<_> =
+ retransmit_deadline.map(sleep_until).into();
+ pin!(retransmit_timeout);
+
+ let corked_timeout: futures::future::OptionFuture<_> =
+ self.corked_until.map(sleep_until).into();
+ pin!(corked_timeout);
+
+ // FIXME: This is not correct because read_frame is not cancellation safe
+ select! {
+ frame = self.rdr.read_frame() => {
+ let frame = frame?;
+ let frame = match frame {
+ Some(frame) => frame,
+ None => break,
+ };
+ self.handle_frame(frame);
+ },
+
+ msg = self.cmd_recv.recv() => {
+ match msg {
+ Some(TransportCommand::SendMessage(msg)) => {
+ self.pending_messages.push_back(msg);
+ },
+ Some(TransportCommand::Exit) => {
+ self.cancel.cancel();
+ }
+ None => break,
+ };
+ },
+
+ _ = &mut retransmit_timeout, if retransmit_deadline.is_some() => {
+ self.retransmit_now = true;
+ },
+
+ _ = &mut corked_timeout, if self.corked_until.is_some() => {
+ // Timeout for when we are able to send again
+ }
+
+ _ = self.cancel.cancelled() => {
+ break;
+ },
+ }
+ }
+ Ok(())
+ }
+
+ /// Handle an incoming frame, by updating sequence numbers and sending the data upwards if needed
+ fn handle_frame(&mut self, frame: ReceivedFrame) {
+ let rseq = self.receive_sequence;
+
+ // wrap-around logic(?)
+ let mut sequence = (rseq & !(MESSAGE_SEQ_MASK as u64)) | (frame.sequence as u64);
+ if sequence < rseq {
+ sequence += (MESSAGE_SEQ_MASK as u64) + 1;
+ }
+
+ // Frame acknowledges some messages
+ if !self.is_synchronized || sequence != rseq {
+ if sequence > self.send_sequence && self.is_synchronized {
+ // Ack for unsent message - weird, but ignore and try to continue
+ return;
+ }
+
+ self.update_receive_seq(frame.receive_time, sequence);
+ }
+
+ if !frame.payload.is_empty() {
+ // Data message, we deliver this directly to the application as the MCU can't actually
+ // retransmit anyway.
+ // TODO: Maybe check the CRC anyway so we can discard it here
+ let _ = self.data_send.send(Ok(frame.payload));
+ } else if sequence > self.last_ack_sequence {
+ // ACK
+ self.last_ack_sequence = sequence;
+ } else if sequence > self.ignore_nak_seq && !self.inflight_messages.is_empty() {
+ // NAK
+ self.retransmit_now = true;
+ }
+ }
+
+ /// Update the last received sequence number, removing acknowledged messages from `self.inflight_messages`
+ fn update_receive_seq(&mut self, receive_time: Instant, sequence: u64) {
+ let mut sent_seq = self.receive_sequence;
+
+ // Discard messages from inflight_messages up to sequence
+ loop {
+ if let Some(msg) = self.inflight_messages.pop_front() {
+ sent_seq += 1;
+ if sequence == sent_seq {
+ // Found the matching sent message
+ if !msg.is_retransmit {
+ let elapsed = receive_time.saturating_duration_since(msg.sent_at);
+ self.rtt_state.update(elapsed);
+ }
+ break;
+ }
+ } else {
+ // Ack with no outstanding messages, happens during connection init
+ self.send_sequence = sequence;
+ break;
+ }
+ }
+
+ self.receive_sequence = sequence;
+ self.is_synchronized = true;
+ }
+
+ fn can_send(&self) -> bool {
+ self.corked_until.is_none() && self.inflight_messages.len() < 12
+ }
+
+ /// Send as many more frames as possible from [`self.pending_messages`]
+ async fn send_more_frames(&mut self) -> Result<(), TransportError> {
+ while self.can_send() && !self.pending_messages.is_empty() {
+ self.send_new_frame().await?;
+ }
+
+ Ok(())
+ }
+
+ /// Send a single new frame from [`self.pending_messages`]
+ async fn send_new_frame(&mut self) -> Result<(), TransportError> {
+ let mut buf = Vec::new();
+ while let Some(next) = self.pending_messages.front() {
+ if !buf.is_empty() && buf.len() + next.len() <= MESSAGE_LENGTH_PAYLOAD_MAX {
+ // Add to the end of the frame. Unwrap is safe because we already peeked.
+ let mut next = self.pending_messages.pop_front().unwrap();
+ buf.append(&mut next);
+ } else {
+ break;
+ }
+ }
+
+ let frame = Arc::new(encode_frame(self.send_sequence, &buf)?);
+ self.send_sequence += 1;
+ self.inflight_messages.push_back(SentFrame {
+ sent_at: Instant::now(),
+ sequence: self.send_sequence,
+ payload: frame.clone(),
+ is_retransmit: false,
+ });
+ self.wr.write_all(&frame).await?;
+
+ Ok(())
+ }
+
+ /// Retransmit all inflight messages
+ async fn retransmit_pending(&mut self) -> Result<(), TransportError> {
+ let len: usize = self
+ .inflight_messages
+ .iter()
+ .map(|msg| msg.payload.len())
+ .sum();
+ let mut buf = Vec::with_capacity(1 + len);
+ buf.push(MESSAGE_VALUE_SYNC);
+ let now = Instant::now();
+ for msg in self.inflight_messages.iter_mut() {
+ buf.extend_from_slice(&msg.payload);
+ msg.is_retransmit = true;
+ msg.sent_at = now;
+ }
+ self.wr.write_all(&buf).await?;
+
+ if self.retransmit_now {
+ self.ignore_nak_seq = self.receive_sequence;
+ if self.receive_sequence < self.retransmit_seq {
+ self.ignore_nak_seq = self.retransmit_seq;
+ }
+ self.retransmit_now = false;
+ } else {
+ self.rtt_state.rto = (self.rtt_state.rto * 2.0).clamp(MIN_RTO, MAX_RTO);
+ self.ignore_nak_seq = self.send_sequence;
+ }
+ self.retransmit_seq = self.send_sequence;
+
+ Ok(())
+ }
+}
+
+/// An error encountered when transmitting a message
+#[derive(thiserror::Error, Debug)]
+pub enum TransmitterError {
+ #[error("io error: {0}")]
+ IoError(#[from] std::io::Error),
+
+ #[error("connection closed")]
+ ConnectionClosed,
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct SentFrame {
+ pub sent_at: Instant,
+ #[allow(dead_code)]
+ pub sequence: u64,
+ pub payload: Arc<Vec<u8>>,
+ pub is_retransmit: bool,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum MessageEncodeError {
+ #[error("message would exceed the maximum packet length of {MESSAGE_LENGTH_MAX} bytes")]
+ MessageTooLong,
+}
+
+fn encode_frame(sequence: u64, payload: &[u8]) -> Result<Vec<u8>, MessageEncodeError> {
+ let len = MESSAGE_LENGTH_MIN + payload.len();
+ if len > MESSAGE_LENGTH_MAX {
+ return Err(MessageEncodeError::MessageTooLong);
+ }
+ let mut buf = Vec::with_capacity(len);
+ buf.push(len as u8);
+ buf.push(MESSAGE_DEST | ((sequence as u8) & MESSAGE_SEQ_MASK));
+ buf.extend_from_slice(payload);
+ let crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]);
+ buf.push(((crc >> 8) & 0xFF) as u8);
+ buf.push((crc & 0xFF) as u8);
+ buf.push(MESSAGE_VALUE_SYNC);
+ Ok(buf)
+}
diff --git a/crates/windlass/src/transport/read.rs b/crates/windlass/src/transport/read.rs
new file mode 100644
index 0000000..079f7b3
--- /dev/null
+++ b/crates/windlass/src/transport/read.rs
@@ -0,0 +1,87 @@
+use tokio::{
+ io::{AsyncBufRead, AsyncReadExt},
+ time::Instant,
+};
+use tracing::trace;
+
+use crate::{
+ encoding::crc16,
+ transport::{
+ MESSAGE_DEST, MESSAGE_HEADER_SIZE, MESSAGE_LENGTH_MAX, MESSAGE_LENGTH_MIN,
+ MESSAGE_POSITION_SEQ, MESSAGE_SEQ_MASK, MESSAGE_TRAILER_CRC, MESSAGE_TRAILER_SIZE,
+ MESSAGE_VALUE_SYNC,
+ },
+};
+
+#[derive(Debug)]
+pub(crate) struct ReceivedFrame {
+ pub receive_time: Instant,
+ pub sequence: u8,
+ pub payload: Vec<u8>,
+}
+
+#[derive(Debug)]
+pub struct FrameReader<R> {
+ rdr: R,
+ synced: bool,
+}
+
+impl<R: AsyncBufRead + Unpin> FrameReader<R> {
+ pub fn new(rdr: R) -> Self {
+ Self { rdr, synced: false }
+ }
+
+ pub async fn read_frame(&mut self) -> Result<Option<ReceivedFrame>, ReceiverError> {
+ let mut buf = [0u8; MESSAGE_LENGTH_MAX];
+ let next_byte = self.rdr.read_u8().await?;
+ if next_byte == MESSAGE_VALUE_SYNC {
+ if !self.synced {
+ self.synced = true;
+ }
+ return Ok(None);
+ }
+ if !self.synced {
+ return Ok(None);
+ }
+
+ let receive_time = Instant::now();
+ let len = next_byte as usize;
+
+ if !(MESSAGE_LENGTH_MIN..=MESSAGE_LENGTH_MAX).contains(&len) {
+ self.synced = false;
+ return Ok(None);
+ }
+
+ self.rdr.read_exact(&mut buf[1..len]).await?;
+ buf[0] = len as u8;
+ let buf = &buf[..len];
+ let seq = buf[MESSAGE_POSITION_SEQ];
+ trace!(frame = ?buf, seq = seq, "Received frame");
+
+ if seq & !MESSAGE_SEQ_MASK != MESSAGE_DEST {
+ self.synced = false;
+ return Ok(None);
+ }
+
+ let actual_crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]);
+ let frame_crc = (buf[len - MESSAGE_TRAILER_CRC] as u16) << 8
+ | (buf[len - MESSAGE_TRAILER_CRC + 1] as u16);
+ if frame_crc != actual_crc {
+ self.synced = false;
+ return Ok(None);
+ }
+
+ Ok(Some(ReceivedFrame {
+ receive_time,
+ sequence: seq & MESSAGE_SEQ_MASK,
+ payload: buf[MESSAGE_HEADER_SIZE..len - MESSAGE_TRAILER_SIZE].into(),
+ }))
+ }
+}
+
+/// An error encountered when receiving a message
+#[derive(thiserror::Error, Debug)]
+pub enum ReceiverError {
+ #[error("io error: {0}")]
+ IoError(#[from] std::io::Error),
+}