diff options
author | tcmal <me@aria.rip> | 2024-09-17 16:12:31 +0100 |
---|---|---|
committer | tcmal <me@aria.rip> | 2024-09-30 23:17:34 +0100 |
commit | fd82733126ee82b085875c44a0993534968afad3 (patch) | |
tree | 73ab58288c4da05e51c60a2e27cb62e33f2111ad /crates | |
parent | 5f9fbea5a1b08962887d16457e77a922919a8818 (diff) |
Better docs / cleanup for windlass
Diffstat (limited to 'crates')
-rw-r--r-- | crates/windlass/src/dictionary.rs | 262 | ||||
-rw-r--r-- | crates/windlass/src/encoding.rs | 274 | ||||
-rw-r--r-- | crates/windlass/src/lib.rs | 14 | ||||
-rw-r--r-- | crates/windlass/src/macros.rs | 157 | ||||
-rw-r--r-- | crates/windlass/src/mcu.rs | 27 | ||||
-rw-r--r-- | crates/windlass/src/messages.rs | 178 | ||||
-rw-r--r-- | crates/windlass/src/transport.rs | 615 | ||||
-rw-r--r-- | crates/windlass/src/transport/mod.rs | 434 | ||||
-rw-r--r-- | crates/windlass/src/transport/read.rs | 87 |
9 files changed, 1042 insertions, 1006 deletions
diff --git a/crates/windlass/src/dictionary.rs b/crates/windlass/src/dictionary.rs index 5665f80..270b0e5 100644 --- a/crates/windlass/src/dictionary.rs +++ b/crates/windlass/src/dictionary.rs @@ -1,78 +1,101 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap, ops::RangeInclusive}; use crate::{ encoding::encode_vlq_int, messages::{MessageParser, MessageSkipperError}, }; -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] -#[serde(untagged)] -pub enum ConfigVar { - String(String), - Number(f64), -} +/// Deserialised data dictionary returned by the microcontroller. +/// See [klipper protocol docs](https://www.klipper3d.org/Protocol.html#data-dictionary) +#[derive(Debug)] +pub struct Dictionary { + /// Map of message names to message IDs + message_ids: BTreeMap<String, u16>, -#[derive(Debug, serde::Deserialize)] -#[serde(untagged)] -pub enum Enumeration { - Single(i64), - Range(i64, i64), -} + /// Map of message IDs to parsers + message_parsers: BTreeMap<u16, MessageParser>, -#[derive(Debug, serde::Deserialize)] -pub(crate) struct RawDictionary { - #[serde(default)] + /// Map of config variable name to variable values config: BTreeMap<String, ConfigVar>, - #[serde(default)] - enumerations: BTreeMap<String, BTreeMap<String, Enumeration>>, + /// Map of declared [enumerations](https://www.klipper3d.org/Protocol.html#declaring-enumerations) + enumerations: BTreeMap<String, EnumDef>, - #[serde(default)] - commands: BTreeMap<String, i16>, - #[serde(default)] - responses: BTreeMap<String, i16>, - #[serde(default)] - output: BTreeMap<String, i16>, + /// Build version, if specified + build_version: Option<String>, - #[serde(default)] - build_versions: Option<String>, - #[serde(default)] + /// Version, if specified version: Option<String>, - #[serde(flatten)] + /// Any extra data in the dictionary response extra: BTreeMap<String, serde_json::Value>, } -/// Dictionary error -#[derive(thiserror::Error, Debug)] -pub enum DictionaryError { - /// Found an empty command - #[error("empty command found")] - EmptyCommand, - /// Received a command in an invalid format - #[error("invalid command format: {0}")] - InvalidCommandFormat(String, MessageSkipperError), - /// Received an output string with an invalid format - #[error("invalid output format: {0}")] - InvalidOutputFormat(String, MessageSkipperError), - /// Received a command with an invalid tag - #[error("command tag {0} output valid range of -32..95")] - InvalidCommandTag(u16), -} +impl Dictionary { + /// Get message id by name + pub fn message_id(&self, name: &str) -> Option<u16> { + self.message_ids.get(name).copied() + } -#[derive(Debug)] -pub struct Dictionary { - pub message_ids: BTreeMap<String, u16>, - pub message_parsers: BTreeMap<u16, MessageParser>, - pub config: BTreeMap<String, ConfigVar>, - pub enumerations: BTreeMap<String, BTreeMap<String, Enumeration>>, - pub build_versions: Option<String>, - pub version: Option<String>, - pub extra: BTreeMap<String, serde_json::Value>, + /// Get message name by id + pub fn message_name(&self, id: u16) -> Option<&str> { + self.message_ids + .iter() + .find(|(_, i)| **i == id) + .map(|(name, _)| name.as_str()) + } + + /// Get message name by id + pub fn message_parser(&self, id: u16) -> Option<&MessageParser> { + self.message_parsers.get(&id) + } + + /// Get config variable by name + pub fn config_var(&self, var_name: &str) -> Option<&ConfigVar> { + self.config.get(var_name) + } + + /// Get enum definition by name + pub fn enum_def(&self, name: &str) -> Option<&EnumDef> { + self.enumerations.get(name) + } + + /// Firmware build version, if specified + pub fn build_version(&self) -> Option<&str> { + self.build_version.as_deref() + } + + /// Firmware version, if specified + pub fn version(&self) -> Option<&str> { + self.version.as_deref() + } + + /// Extra data returned in the dictionary response + pub fn extra(&self) -> &BTreeMap<String, serde_json::Value> { + &self.extra + } + + /// Figure out the actual value of a command tag + fn map_tag(tag: i16) -> Result<u16, DictionaryError> { + let mut buf = vec![]; + encode_vlq_int(&mut buf, tag as u32); + let v = if buf.len() > 1 { + ((buf[0] as u16) & 0x7F) << 7 | (buf[1] as u16) & 0x7F + } else { + (buf[0] as u16) & 0x7F + }; + if v >= 1 << 14 { + Err(DictionaryError::InvalidCommandTag(v)) + } else { + Ok(v) + } + } } -impl Dictionary { - pub(crate) fn from_raw_dictionary(raw: RawDictionary) -> Result<Self, DictionaryError> { +impl TryFrom<RawDictionary> for Dictionary { + type Error = DictionaryError; + + fn try_from(raw: RawDictionary) -> Result<Self, Self::Error> { let mut message_ids = BTreeMap::new(); let mut message_parsers = BTreeMap::new(); @@ -107,25 +130,124 @@ impl Dictionary { message_ids, message_parsers, config: raw.config, - enumerations: raw.enumerations, - build_versions: raw.build_versions, + enumerations: raw + .enumerations + .into_iter() + .map(|(name, vals)| (name, EnumDef(vals))) + .collect(), + build_version: raw.build_versions, version: raw.version, extra: raw.extra, }) } +} - fn map_tag(tag: i16) -> Result<u16, DictionaryError> { - let mut buf = vec![]; - encode_vlq_int(&mut buf, tag as u32); - let v = if buf.len() > 1 { - ((buf[0] as u16) & 0x7F) << 7 | (buf[1] as u16) & 0x7F - } else { - (buf[0] as u16) & 0x7F - }; - if v >= 1 << 14 { - Err(DictionaryError::InvalidCommandTag(v)) - } else { - Ok(v) - } +/// Definition of an [enumerations](https://www.klipper3d.org/Protocol.html#declaring-enumerations)'s possible variables. +#[derive(Debug)] +pub struct EnumDef(BTreeMap<String, EnumValue>); + +impl EnumDef { + /// Get the range of valid values for this enum + /// Note that there's no guarantee the enum is actually contiguous for a range. + pub fn valid_range(&self) -> RangeInclusive<i64> { + let min = self + .0 + .values() + .map(|v| match *v { + EnumValue::Single(x) => x, + EnumValue::Range(start, _end) => start, + }) + .min() + .unwrap_or(0); + let max = self + .0 + .values() + .map(|v| match *v { + EnumValue::Single(x) => x, + EnumValue::Range(_start, end) => end, + }) + .max() + .unwrap_or(0); + + min..=max } + + /// Lookup the name of the given enum value + pub fn lookup_name(&self, val: i64) -> Option<String> { + self.0 + .iter() + .find(|(_, vals)| match **vals { + EnumValue::Single(x) => x == val, + EnumValue::Range(start, end) => (start..=end).contains(&val), + }) + .map(|(name, vals)| match vals { + EnumValue::Single(_) => name.to_string(), + EnumValue::Range(_, _) => todo!("modify name to end with correct number"), + }) + } +} + +/// Error encountered when getting dictionary from microcontroller +#[derive(thiserror::Error, Debug)] +pub enum DictionaryError { + /// Found an empty command + #[error("empty command found")] + EmptyCommand, + + /// Received a command in an invalid format + #[error("invalid command format: {0}")] + InvalidCommandFormat(String, MessageSkipperError), + + /// Received an output string with an invalid format + #[error("invalid output format: {0}")] + InvalidOutputFormat(String, MessageSkipperError), + + /// Received a command with an invalid tag + #[error("command tag {0} output valid range of -32..95")] + InvalidCommandTag(u16), +} + +/// The raw JSON data dictionary response from the microcontroller +#[derive(Debug, serde::Deserialize)] +pub(crate) struct RawDictionary { + #[serde(default)] + config: BTreeMap<String, ConfigVar>, + + #[serde(default)] + enumerations: BTreeMap<String, BTreeMap<String, EnumValue>>, + + #[serde(default)] + commands: BTreeMap<String, i16>, + + #[serde(default)] + responses: BTreeMap<String, i16>, + + #[serde(default)] + output: BTreeMap<String, i16>, + + #[serde(default)] + build_versions: Option<String>, + + #[serde(default)] + version: Option<String>, + + #[serde(flatten)] + extra: BTreeMap<String, serde_json::Value>, +} + +/// The value of a configuration key +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +#[serde(untagged)] +pub enum ConfigVar { + String(String), + Number(f64), +} + +/// Specifies a value an enumeration can take. +/// Either one string corresponds to a single integer value, or it ends in a number and corresponds to a range of integer values. +#[derive(Debug, serde::Deserialize)] +#[serde(untagged)] +enum EnumValue { + Single(i64), + Range(i64, i64), } diff --git a/crates/windlass/src/encoding.rs b/crates/windlass/src/encoding.rs index 0950956..27238c0 100644 --- a/crates/windlass/src/encoding.rs +++ b/crates/windlass/src/encoding.rs @@ -2,17 +2,19 @@ use std::fmt::Display; use crate::messages::MessageSkipperError; -/// Message decoding error -#[derive(thiserror::Error, Debug, Clone)] -pub enum MessageDecodeError { - /// More data was expected but none is available - #[error("eof unexpected")] - UnexpectedEof, - /// A received string could not be decoded as UTF8 - #[error("invalid utf8 string")] - Utf8Error(#[from] std::str::Utf8Error), +/// Calculate the CRC-16 of the given buffer +pub(crate) fn crc16(buf: &[u8]) -> u16 { + let mut crc = 0xFFFFu16; + for b in buf { + let b = *b ^ ((crc & 0xFF) as u8); + let b = b ^ (b << 4); + let b16 = b as u16; + crc = (b16 << 8 | crc >> 8) ^ (b16 >> 4) ^ (b16 << 3); + } + crc } +/// Encode the given integer as a [VLQ](https://www.klipper3d.org/Protocol.html#variable-length-quantities), pushing it to the back of `output`. pub(crate) fn encode_vlq_int(output: &mut Vec<u8>, v: u32) { let sv = v as i32; if !(-(1 << 26)..(3 << 26)).contains(&sv) { @@ -30,16 +32,7 @@ pub(crate) fn encode_vlq_int(output: &mut Vec<u8>, v: u32) { output.push((sv & 0x7F) as u8); } -pub(crate) fn next_byte(data: &mut &[u8]) -> Result<u8, MessageDecodeError> { - if data.is_empty() { - Err(MessageDecodeError::UnexpectedEof) - } else { - let v = data[0]; - *data = &data[1..]; - Ok(v) - } -} - +/// Parse a [VLQ](https://www.klipper3d.org/Protocol.html#variable-length-quantities) from the top of `data`. pub(crate) fn parse_vlq_int(data: &mut &[u8]) -> Result<u32, MessageDecodeError> { let mut c = next_byte(data)? as u32; let mut v = c & 0x7F; @@ -54,16 +47,35 @@ pub(crate) fn parse_vlq_int(data: &mut &[u8]) -> Result<u32, MessageDecodeError> Ok(v) } +/// Read the next byte from `data`, or error +pub(crate) fn next_byte(data: &mut &[u8]) -> Result<u8, MessageDecodeError> { + if data.is_empty() { + Err(MessageDecodeError::UnexpectedEof) + } else { + let v = data[0]; + *data = &data[1..]; + Ok(v) + } +} + +/// A type which can possibly be read from the top of a data buffer. pub trait Readable<'de>: Sized { + /// Attempt to deserialise a value from the top of `data`. fn read(data: &mut &'de [u8]) -> Result<Self, MessageDecodeError>; - fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError>; + /// Skip over a value of this type at the top of `data`. + fn skip(data: &mut &'de [u8]) -> Result<(), MessageDecodeError> { + Self::read(data).map(|_| ()) + } } +/// A type which can be written to the back of a data buffer. pub trait Writable: Sized { + /// Write this type to the back of `output`. fn write(&self, output: &mut Vec<u8>); } +/// TODO: Docs pub trait Borrowable: Sized { type Borrowed<'a> where @@ -71,20 +83,121 @@ pub trait Borrowable: Sized { fn from_borrowed(src: Self::Borrowed<'_>) -> Self; } +/// The type of an encoded field +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum FieldType { + U32, + I32, + U16, + I16, + U8, + String, + ByteArray, +} + +impl FieldType { + /// Skip over a field of this type at the top of `input`. + pub(crate) fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> { + match self { + Self::U32 => <u32 as Readable>::skip(input), + Self::I32 => <i32 as Readable>::skip(input), + Self::U16 => <u16 as Readable>::skip(input), + Self::I16 => <i16 as Readable>::skip(input), + Self::U8 => <u8 as Readable>::skip(input), + Self::String => <&str as Readable>::skip(input), + Self::ByteArray => <&[u8] as Readable>::skip(input), + } + } + + /// Read a field of this type from the top of `input`. + pub(crate) fn read(&self, input: &mut &[u8]) -> Result<FieldValue, MessageDecodeError> { + Ok(match self { + Self::U32 => FieldValue::U32(<u32 as Readable>::read(input)?), + Self::I32 => FieldValue::I32(<i32 as Readable>::read(input)?), + Self::U16 => FieldValue::U16(<u16 as Readable>::read(input)?), + Self::I16 => FieldValue::I16(<i16 as Readable>::read(input)?), + Self::U8 => FieldValue::U8(<u8 as Readable>::read(input)?), + Self::String => FieldValue::String(<&str as Readable>::read(input)?.into()), + Self::ByteArray => FieldValue::ByteArray(<&[u8] as Readable>::read(input)?.into()), + }) + } + + /// Deserialise the field type from a printf style declaration + pub(crate) fn from_msg(s: &str) -> Result<Self, MessageSkipperError> { + match s { + "%u" => Ok(Self::U32), + "%i" => Ok(Self::I32), + "%hu" => Ok(Self::U16), + "%hi" => Ok(Self::I16), + "%c" => Ok(Self::U8), + "%s" => Ok(Self::String), + "%*s" => Ok(Self::ByteArray), + "%.*s" => Ok(Self::ByteArray), + s => Err(MessageSkipperError::InvalidFormatFieldType(s.to_string())), + } + } + + /// Deserialise the next field type from a printf style declaration, also returning the rest of the string. + pub(crate) fn from_format(s: &str) -> Result<(Self, &str), MessageSkipperError> { + if let Some(rest) = s.strip_prefix("%u") { + Ok((Self::U32, rest)) + } else if let Some(rest) = s.strip_prefix("%i") { + Ok((Self::I32, rest)) + } else if let Some(rest) = s.strip_prefix("%hu") { + Ok((Self::U16, rest)) + } else if let Some(rest) = s.strip_prefix("%hi") { + Ok((Self::I16, rest)) + } else if let Some(rest) = s.strip_prefix("%c") { + Ok((Self::U8, rest)) + } else if let Some(rest) = s.strip_prefix("%.*s") { + Ok((Self::ByteArray, rest)) + } else if let Some(rest) = s.strip_prefix("%*s") { + Ok((Self::String, rest)) + } else { + Err(MessageSkipperError::InvalidFormatFieldType(s.to_string())) + } + } +} + +/// The decoded value of a field +#[derive(Debug)] +pub enum FieldValue { + U32(u32), + I32(i32), + U16(u16), + I16(i16), + U8(u8), + String(String), + ByteArray(Vec<u8>), +} + +impl Display for FieldValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::U32(v) => write!(f, "{}", v), + Self::I32(v) => write!(f, "{}", v), + Self::U16(v) => write!(f, "{}", v), + Self::I16(v) => write!(f, "{}", v), + Self::U8(v) => write!(f, "{}", v), + Self::String(v) => write!(f, "{}", v), + Self::ByteArray(v) => write!(f, "{:?}", v), + } + } +} + +/// A type which can be expressed as a [`FieldType`] pub trait ToFieldType: Sized { + /// Get the corresponding [`FieldType`]. fn as_field_type() -> FieldType; } +/// Implements [`Readable`], [`Writable`], and [`ToFieldType`] for the given integer type. macro_rules! int_readwrite { ( $type:tt, $field_type:expr ) => { impl Readable<'_> for $type { fn read(data: &mut &[u8]) -> Result<Self, MessageDecodeError> { parse_vlq_int(data).map(|v| v as $type) } - - fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> { - parse_vlq_int(data).map(|_| ()) - } } impl Writable for $type { @@ -118,10 +231,6 @@ impl Readable<'_> for bool { fn read(data: &mut &[u8]) -> Result<Self, MessageDecodeError> { parse_vlq_int(data).map(|v| v != 0) } - - fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> { - parse_vlq_int(data).map(|_| ()) - } } impl Writable for bool { @@ -217,111 +326,20 @@ impl Writable for &str { } } -impl Borrowable for String { - type Borrowed<'a> = &'a str; - fn from_borrowed(src: Self::Borrowed<'_>) -> Self { - src.to_string() - } -} - impl ToFieldType for String { fn as_field_type() -> FieldType { FieldType::String } } -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub enum FieldType { - U32, - I32, - U16, - I16, - U8, - String, - ByteArray, -} - -impl FieldType { - pub(crate) fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> { - match self { - Self::U32 => <u32 as Readable>::skip(input), - Self::I32 => <i32 as Readable>::skip(input), - Self::U16 => <u16 as Readable>::skip(input), - Self::I16 => <i16 as Readable>::skip(input), - Self::U8 => <u8 as Readable>::skip(input), - Self::String => <&str as Readable>::skip(input), - Self::ByteArray => <&[u8] as Readable>::skip(input), - } - } - - pub(crate) fn read(&self, input: &mut &[u8]) -> Result<FieldValue, MessageDecodeError> { - Ok(match self { - Self::U32 => FieldValue::U32(<u32 as Readable>::read(input)?), - Self::I32 => FieldValue::I32(<i32 as Readable>::read(input)?), - Self::U16 => FieldValue::U16(<u16 as Readable>::read(input)?), - Self::I16 => FieldValue::I16(<i16 as Readable>::read(input)?), - Self::U8 => FieldValue::U8(<u8 as Readable>::read(input)?), - Self::String => FieldValue::String(<&str as Readable>::read(input)?.into()), - Self::ByteArray => FieldValue::ByteArray(<&[u8] as Readable>::read(input)?.into()), - }) - } - - pub(crate) fn from_msg(s: &str) -> Result<Self, MessageSkipperError> { - match s { - "%u" => Ok(Self::U32), - "%i" => Ok(Self::I32), - "%hu" => Ok(Self::U16), - "%hi" => Ok(Self::I16), - "%c" => Ok(Self::U8), - "%s" => Ok(Self::String), - "%*s" => Ok(Self::ByteArray), - "%.*s" => Ok(Self::ByteArray), - s => Err(MessageSkipperError::InvalidFormatFieldType(s.to_string())), - } - } - - pub(crate) fn from_format(s: &str) -> Result<(Self, &str), MessageSkipperError> { - if let Some(rest) = s.strip_prefix("%u") { - Ok((Self::U32, rest)) - } else if let Some(rest) = s.strip_prefix("%i") { - Ok((Self::I32, rest)) - } else if let Some(rest) = s.strip_prefix("%hu") { - Ok((Self::U16, rest)) - } else if let Some(rest) = s.strip_prefix("%hi") { - Ok((Self::I16, rest)) - } else if let Some(rest) = s.strip_prefix("%c") { - Ok((Self::U8, rest)) - } else if let Some(rest) = s.strip_prefix("%.*s") { - Ok((Self::ByteArray, rest)) - } else if let Some(rest) = s.strip_prefix("%*s") { - Ok((Self::String, rest)) - } else { - Err(MessageSkipperError::InvalidFormatFieldType(s.to_string())) - } - } -} - -#[derive(Debug)] -pub enum FieldValue { - U32(u32), - I32(i32), - U16(u16), - I16(i16), - U8(u8), - String(String), - ByteArray(Vec<u8>), -} +/// Error encountered when decoding a message +#[derive(thiserror::Error, Debug, Clone)] +pub enum MessageDecodeError { + /// More data was expected but none is available + #[error("eof unexpected")] + UnexpectedEof, -impl Display for FieldValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::U32(v) => write!(f, "{}", v), - Self::I32(v) => write!(f, "{}", v), - Self::U16(v) => write!(f, "{}", v), - Self::I16(v) => write!(f, "{}", v), - Self::U8(v) => write!(f, "{}", v), - Self::String(v) => write!(f, "{}", v), - Self::ByteArray(v) => write!(f, "{:?}", v), - } - } + /// A received string could not be decoded as UTF8 + #[error("invalid utf8 string")] + Utf8Error(#[from] std::str::Utf8Error), } diff --git a/crates/windlass/src/lib.rs b/crates/windlass/src/lib.rs index c578bf2..66a7674 100644 --- a/crates/windlass/src/lib.rs +++ b/crates/windlass/src/lib.rs @@ -1,21 +1,13 @@ +//! Windlass is an implementation of the host side of the Klipper protocol. + #[macro_use] #[doc(hidden)] pub mod macros; - -#[doc(hidden)] pub mod dictionary; - -#[doc(hidden)] pub mod encoding; - -mod mcu; - -#[doc(hidden)] pub mod messages; +mod mcu; mod transport; -pub use dictionary::DictionaryError; -pub use encoding::MessageDecodeError; pub use mcu::{McuConnection, McuConnectionError}; -pub use messages::EncodedMessage; diff --git a/crates/windlass/src/macros.rs b/crates/windlass/src/macros.rs index 6f57a70..f17c66d 100644 --- a/crates/windlass/src/macros.rs +++ b/crates/windlass/src/macros.rs @@ -1,16 +1,73 @@ -pub use crate::transport::MESSAGE_LENGTH_PAYLOAD_MAX; - +/// Declare a host -> device message +/// +/// Declares the format of a message sent from host to device. The general format is as follows: +/// +/// ```text +/// mcu_command!(<struct name>, "<command name>"( = <id>) [, arg: type, ..]); +/// ``` +/// +/// A struct with the given name will be defined, with relevant interfaces. The message name and +/// arguments will be matched to the MCU at runtime. The struct name has no restrictions, but it is +/// recommended to pick a SnakeCased version of the command name. +/// +/// Optionally an `id` can be directly specified. Generally this is not needed. When not specified, +/// it will be automatically inferred at runtime and matched to the dictionary retrieved from the +/// target MCU. +/// +/// Arguments can be specified, they will be mapped to the relevant Klipper argument types. The +/// supported types and mappings are as follows: +/// +/// | Rust type | Format string | +/// |------------|---------------| +/// | `u32` | `%u` | +/// | `i32` | `%i` | +/// | `u16` | `%hu` | +/// | `i16` | `%hi` | +/// | `u8` | `%c` | +/// | `&'a [u8]` | `%.*s`, `%*s` | +/// | `&'a str` | `%s` | +/// +/// Note that the buffer types take a lifetime. This must always be `'a`. +/// +/// # Examples +/// +/// ```ignore +/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c' +/// mcu_command!(ConfigEndstop, "config_endstop", oid: u8, pin: u8, pull_up: u8); +/// ``` #[macro_export] -#[doc(hidden)] -macro_rules! mcu_message_impl_oid_check { - ($ty_name:ident) => { - impl $crate::messages::WithoutOid for $ty_name {} +macro_rules! mcu_command { + ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => { + $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*); }; - ($ty_name:ident, oid $(, $args:ident)*) => { - impl $crate::messages::WithOid for $ty_name {} + ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => { + $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*); }; - ($ty_name:ident, $arg:ident $(, $args:ident)*) => { - $crate::mcu_message_impl_oid_check!($ty_name $(, $args)*); +} + +/// Declare a device -> host message +/// +/// Declares the format of a message sent from device to host. The general format is as follows: +/// +/// ```text +/// mcu_reply!(<struct name>, "<reply name>"( = <id>) [, arg: type, ..]); +/// ``` +/// +/// For more information on the various fields, see the documentation for [mcu_command]. +/// +/// # Examples +/// +/// ```ignore +/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c' +/// mcu_reply!(Uptime, "uptime", high: u32, clock: u32); +/// ``` +#[macro_export] +macro_rules! mcu_reply { + ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => { + $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*); + }; + ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => { + $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*); }; } @@ -53,7 +110,7 @@ macro_rules! mcu_message_impl { impl<'a> [<$ty_name Data>]<'a> { fn encode(&self) -> $crate::messages::FrontTrimmableBuffer { use $crate::encoding::Writable; - let mut buf = Vec::with_capacity($crate::macros::MESSAGE_LENGTH_PAYLOAD_MAX); + let mut buf = Vec::with_capacity($crate::transport::MESSAGE_LENGTH_PAYLOAD_MAX); buf.push(0); buf.push(0); $(self.$arg.write(&mut buf);)* @@ -103,7 +160,7 @@ macro_rules! mcu_message_impl { type PodOwned = [<$ty_name DataOwned>]; fn get_id(dict: Option<&$crate::dictionary::Dictionary>) -> Option<u16> { - $cmd_id.or_else(|| dict.and_then(|dict| dict.message_ids.get($cmd_name).copied())) + $cmd_id.or_else(|| dict.and_then(|dict| dict.message_id($cmd_name))) } fn get_name() -> &'static str { @@ -126,75 +183,17 @@ macro_rules! mcu_message_impl { }; } -/// Declare a host -> device message -/// -/// Declares the format of a message sent from host to device. The general format is as follows: -/// -/// ```text -/// mcu_command!(<struct name>, "<command name>"( = <id>) [, arg: type, ..]); -/// ``` -/// -/// A struct with the given name will be defined, with relevant interfaces. The message name and -/// arguments will be matched to the MCU at runtime. The struct name has no restrictions, but it is -/// recommended to pick a SnakeCased version of the command name. -/// -/// Optionally an `id` can be directly specified. Generally this is not needed. When not specified, -/// it will be automatically inferred at runtime and matched to the dictionary retrieved from the -/// target MCU. -/// -/// Arguments can be specified, they will be mapped to the relevant Klipper argument types. The -/// supported types and mappings are as follows: -/// -/// | Rust type | Format string | -/// |------------|---------------| -/// | `u32` | `%u` | -/// | `i32` | `%i` | -/// | `u16` | `%hu` | -/// | `i16` | `%hi` | -/// | `u8` | `%c` | -/// | `&'a [u8]` | `%.*s`, `%*s` | -/// | `&'a str` | `%s` | -/// -/// Note that the buffer types take a lifetime. This must always be `'a`. -/// -/// # Examples -/// -/// ```ignore -/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c' -/// mcu_command!(ConfigEndstop, "config_endstop", oid: u8, pin: u8, pull_up: u8); -/// ``` #[macro_export] -macro_rules! mcu_command { - ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => { - $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*); - }; - ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => { - $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*); +#[doc(hidden)] +/// Implement the [`WithOid`](crate::messages::WithOid) / [`WithoutOid`](crate::messages::WithoutOid) marker type for an MCU message type. +macro_rules! mcu_message_impl_oid_check { + ($ty_name:ident) => { + impl $crate::messages::WithoutOid for $ty_name {} }; -} - -/// Declare a device -> host message -/// -/// Declares the format of a message sent from device to host. The general format is as follows: -/// -/// ```text -/// mcu_reply!(<struct name>, "<reply name>"( = <id>) [, arg: type, ..]); -/// ``` -/// -/// For more information on the various fields, see the documentation for [mcu_command]. -/// -/// # Examples -/// -/// ```ignore -/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c' -/// mcu_reply!(Uptime, "uptime", high: u32, clock: u32); -/// ``` -#[macro_export] -macro_rules! mcu_reply { - ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => { - $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*); + ($ty_name:ident, oid $(, $args:ident)*) => { + impl $crate::messages::WithOid for $ty_name {} }; - ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => { - $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*); + ($ty_name:ident, $arg:ident $(, $args:ident)*) => { + $crate::mcu_message_impl_oid_check!($ty_name $(, $args)*); }; } diff --git a/crates/windlass/src/mcu.rs b/crates/windlass/src/mcu.rs index 53fea7f..c2ca54a 100644 --- a/crates/windlass/src/mcu.rs +++ b/crates/windlass/src/mcu.rs @@ -7,7 +7,7 @@ use std::{ }; use tokio::{ - io::{AsyncRead, AsyncWrite}, + io::{AsyncBufRead, AsyncWrite}, select, sync::{ mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, @@ -181,13 +181,7 @@ impl McuConnectionInner { // There was no registered raw handler, check the dictionary to decode if let Some(dict) = self.dictionary.get() { - if let Some(parser) = dict.message_parsers.get(&cmd) { - if let Some(msg) = parser.parse_and_output(frame) { - let msg = msg?; - debug!(msg = msg, "Output message from MCU"); - continue; - } - + if let Some(parser) = dict.message_parser(cmd) { let mut curmsg = &frame[..]; let oid = parser.skip_with_oid(frame)?; if tracing::enabled!(tracing::Level::TRACE) { @@ -245,11 +239,12 @@ impl McuConnection { /// Attempts to connect to a MCU on the given data stream interface. The MCU is contacted and /// the dictionary is loaded and applied. The returned [McuConnection] is fully ready to /// communicate with the MCU. - pub async fn connect<R>(stream: R) -> Result<Self, McuConnectionError> + pub async fn connect<R, W>(rdr: R, wr: W) -> Result<Self, McuConnectionError> where - R: AsyncRead + AsyncWrite + Send + 'static, + R: AsyncBufRead + Send + Unpin + 'static, + W: AsyncWrite + Send + Unpin + 'static, { - let (transport, inbox) = Transport::connect(stream).await; + let (transport, inbox) = Transport::connect(rdr, wr).await; let (exit_tx, exit_rx) = oneshot::channel(); let inner = Arc::new(McuConnectionInner { @@ -302,8 +297,7 @@ impl McuConnection { .map_err(|err| McuConnectionError::DictionaryFetch(Box::new(err)))?; let raw_dict: RawDictionary = serde_json::from_slice(&buf) .map_err(|err| McuConnectionError::DictionaryFetch(Box::new(err)))?; - let dict = - Dictionary::from_raw_dictionary(raw_dict).map_err(McuConnectionError::Dictionary)?; + let dict = Dictionary::try_from(raw_dict).map_err(McuConnectionError::Dictionary)?; debug!(dictionary = ?dict, "MCU dictionary"); conn.inner .dictionary @@ -336,7 +330,7 @@ impl McuConnection { .ok_or_else(|| McuConnectionError::UnknownMessageId(C::get_name()))?; // Must exist because we know the tag - let parser = dictionary.message_parsers.get(&id).unwrap(); + let parser = dictionary.message_parser(id).unwrap(); let remote_fields = parser.fields.iter().map(|(s, t)| (s.as_str(), *t)); let local_fields = C::fields().into_iter(); @@ -361,9 +355,8 @@ impl McuConnection { &self, command: EncodedMessage<C>, ) -> Result<FrontTrimmableBuffer, McuConnectionError> { - let id = command - .message_id(self.inner.dictionary.get()) - .ok_or_else(|| McuConnectionError::UnknownMessageId(command.message_name()))?; + let id = C::get_id(self.inner.dictionary.get()) + .ok_or_else(|| McuConnectionError::UnknownMessageId(C::get_name()))?; let mut payload = command.payload; if id >= 0x80 { payload.content[0] = ((id >> 7) & 0x7F) as u8 | 0x80; diff --git a/crates/windlass/src/messages.rs b/crates/windlass/src/messages.rs index 5b97d31..f08a3ca 100644 --- a/crates/windlass/src/messages.rs +++ b/crates/windlass/src/messages.rs @@ -3,46 +3,20 @@ use std::collections::BTreeMap; use crate::dictionary::Dictionary; use crate::encoding::{FieldType, FieldValue, MessageDecodeError}; -#[derive(thiserror::Error, Debug)] -pub enum MessageSkipperError { - #[error("invalid argument format: {0}")] - InvalidArgumentFormat(String), - #[error("unknown type '{1}' for argument '{0}'")] - UnknownType(String, String), - #[error("invalid format field type '%{0}'")] - InvalidFormatFieldType(String), -} - -pub(crate) fn format_command_args<'a>( - fields: impl Iterator<Item = (&'a str, FieldType)>, -) -> String { - let mut buf = String::new(); - for (idx, (name, ty)) in fields.enumerate() { - if idx != 0 { - buf.push(' '); - } - buf.push_str(name); - buf.push('='); - buf.push_str(match ty { - FieldType::U32 => "%u", - FieldType::I32 => "%i", - FieldType::U16 => "%hu", - FieldType::I16 => "%hi", - FieldType::U8 => "%c", - FieldType::String => "%s", - FieldType::ByteArray => "%*s", - }); - } - buf -} - +/// A parser for a single message type pub struct MessageParser { + /// The name of the message pub name: String, + + /// The fields of the message, and their types. pub fields: Vec<(String, FieldType)>, + + /// How the message should be debug printed pub output: Option<OutputFormat>, } impl MessageParser { + /// Create a parser for a message with the given name and printf declaration parts pub(crate) fn new<'a>( name: &str, parts: impl Iterator<Item = &'a str>, @@ -63,6 +37,7 @@ impl MessageParser { }) } + /// Create a parser for a message type with the given printf-style specifier pub(crate) fn new_output(msg: &str) -> Result<MessageParser, MessageSkipperError> { let mut fields = vec![]; let mut parts = vec![]; @@ -94,18 +69,17 @@ impl MessageParser { }) } + /// Skip over this message at the top of `input`. #[allow(dead_code)] - pub(crate) fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> { + pub fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> { for (_, field) in &self.fields { field.skip(input)?; } Ok(()) } - pub(crate) fn skip_with_oid( - &self, - input: &mut &[u8], - ) -> Result<Option<u8>, MessageDecodeError> { + /// Skip over this message at the top of `input`, but try to read the `oid` field if it is part of this message. + pub fn skip_with_oid(&self, input: &mut &[u8]) -> Result<Option<u8>, MessageDecodeError> { let mut oid = None; for (name, field) in &self.fields { if name == "oid" { @@ -119,7 +93,8 @@ impl MessageParser { Ok(oid) } - pub(crate) fn parse( + /// Parse a message of this type from the top of `input`. + pub fn parse( &self, input: &mut &[u8], ) -> Result<BTreeMap<String, FieldValue>, MessageDecodeError> { @@ -129,37 +104,31 @@ impl MessageParser { } Ok(output) } - - pub(crate) fn parse_and_output( - &self, - input: &mut &[u8], - ) -> Option<Result<String, MessageDecodeError>> { - self.output.as_ref().map(|output| { - let fields = self.parse(input)?; - Ok(output.format(fields.values())) - }) - } -} - -impl std::fmt::Debug for MessageParser { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_map() - .entry(&"name", &self.name) - .entry(&"fields", &self.fields) - .finish() - } } +/// A message type pub trait Message: 'static { type Pod<'a>: Into<Self::PodOwned> + std::fmt::Debug; type PodOwned: Clone + Send + std::fmt::Debug + 'static; + + /// Get the message ID from the given data dictionary fn get_id(dict: Option<&Dictionary>) -> Option<u16>; + + /// Get the message name + // TODO: this could be an associated constant? fn get_name() -> &'static str; + + /// Decode this message type from the top of `input`. fn decode<'a>(input: &mut &'a [u8]) -> Result<Self::Pod<'a>, MessageDecodeError>; + + /// Get a list of field names and types fn fields() -> Vec<(&'static str, FieldType)>; } +/// Marker trait for messages with an oid field pub trait WithOid: 'static {} + +/// Marker trait for messages without an oid field pub trait WithoutOid: 'static {} /// Represents an encoded message, with a type-level link to the message kind @@ -168,27 +137,6 @@ pub struct EncodedMessage<M> { pub _message_kind: std::marker::PhantomData<M>, } -impl<R: Message> std::fmt::Debug for EncodedMessage<R> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("EncodedMessage") - .field("kind", &R::get_name()) - .field("payload", &self.payload) - .finish() - } -} - -impl<M: Message> EncodedMessage<M> { - #[doc(hidden)] - pub fn message_id(&self, dict: Option<&Dictionary>) -> Option<u16> { - M::get_id(dict) - } - - #[doc(hidden)] - pub fn message_name(&self) -> &'static str { - M::get_name() - } -} - /// Wraps a `Vec<u8>` allowing removal of front bytes in a zero-copy way pub struct FrontTrimmableBuffer { pub content: Vec<u8>, @@ -196,17 +144,12 @@ pub struct FrontTrimmableBuffer { } impl FrontTrimmableBuffer { + /// Get the rest of the buffer as a slice pub fn as_slice(&self) -> &[u8] { &self.content[self.offset..] } } -impl std::fmt::Debug for FrontTrimmableBuffer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - std::fmt::Debug::fmt(self.as_slice(), f) - } -} - /// Holds the format of a `output()` style debug message #[derive(Debug)] pub struct OutputFormat { @@ -214,7 +157,8 @@ pub struct OutputFormat { } impl OutputFormat { - fn format<'a>(&self, mut fields: impl Iterator<Item = &'a FieldValue>) -> String { + /// Format the given fields according to this output format. + pub fn format<'a>(&self, mut fields: impl Iterator<Item = &'a FieldValue>) -> String { let mut buf = String::new(); for part in &self.parts { match part { @@ -230,8 +174,70 @@ impl OutputFormat { } } +/// Part of an [`OutputFormat`]. #[derive(Debug)] enum FormatBlock { Static(String), Field, } + +/// Format the given name and type pairs as a printf style declaration. +pub(crate) fn format_command_args<'a>( + fields: impl Iterator<Item = (&'a str, FieldType)>, +) -> String { + let mut buf = String::new(); + for (idx, (name, ty)) in fields.enumerate() { + if idx != 0 { + buf.push(' '); + } + buf.push_str(name); + buf.push('='); + buf.push_str(match ty { + FieldType::U32 => "%u", + FieldType::I32 => "%i", + FieldType::U16 => "%hu", + FieldType::I16 => "%hi", + FieldType::U8 => "%c", + FieldType::String => "%s", + FieldType::ByteArray => "%*s", + }); + } + buf +} + +/// An error enountered when parsing a message format string +#[derive(thiserror::Error, Debug)] +pub enum MessageSkipperError { + #[error("invalid argument format: {0}")] + InvalidArgumentFormat(String), + + #[error("unknown type '{1}' for argument '{0}'")] + UnknownType(String, String), + + #[error("invalid format field type '%{0}'")] + InvalidFormatFieldType(String), +} + +impl std::fmt::Debug for FrontTrimmableBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Debug::fmt(self.as_slice(), f) + } +} + +impl std::fmt::Debug for MessageParser { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_map() + .entry(&"name", &self.name) + .entry(&"fields", &self.fields) + .finish() + } +} + +impl<R: Message> std::fmt::Debug for EncodedMessage<R> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EncodedMessage") + .field("kind", &R::get_name()) + .field("payload", &self.payload) + .finish() + } +} diff --git a/crates/windlass/src/transport.rs b/crates/windlass/src/transport.rs deleted file mode 100644 index e47cc8e..0000000 --- a/crates/windlass/src/transport.rs +++ /dev/null @@ -1,615 +0,0 @@ -use std::{collections::VecDeque, sync::Arc}; -use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}, - pin, select, - sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, - task::{spawn, JoinHandle}, - time::{sleep_until, Instant}, -}; -use tokio_util::sync::CancellationToken; -use tracing::trace; - -const MESSAGE_HEADER_SIZE: usize = 2; -const MESSAGE_TRAILER_SIZE: usize = 3; -pub const MESSAGE_LENGTH_MIN: usize = MESSAGE_HEADER_SIZE + MESSAGE_TRAILER_SIZE; -pub const MESSAGE_LENGTH_MAX: usize = 64; -pub const MESSAGE_LENGTH_PAYLOAD_MAX: usize = MESSAGE_LENGTH_MAX - MESSAGE_LENGTH_MIN; -const MESSAGE_POSITION_SEQ: usize = 1; -const MESSAGE_TRAILER_CRC: usize = 3; -const MESSAGE_VALUE_SYNC: u8 = 0x7E; -const MESSAGE_DEST: u8 = 0x10; -const MESSAGE_SEQ_MASK: u8 = 0x0F; - -#[derive(thiserror::Error, Debug)] -pub enum ReceiverError { - #[error("io error: {0}")] - IoError(#[from] std::io::Error), -} - -#[derive(thiserror::Error, Debug)] -pub enum TransmitterError { - #[error("io error: {0}")] - IoError(#[from] std::io::Error), - #[error("connection closed")] - ConnectionClosed, -} - -#[derive(Debug)] -pub(crate) struct Transport { - _task_rdr: JoinHandle<()>, - _task_wr: JoinHandle<()>, - task_inner: JoinHandle<()>, - outbox_tx: UnboundedSender<TransportCommand>, -} - -#[derive(Debug)] -enum TransportCommand { - SendMessage(Vec<u8>), - Exit, -} - -pub(crate) type TransportReceiver = UnboundedReceiver<Result<Vec<u8>, TransportError>>; - -impl Transport { - pub(crate) async fn connect<R>(stream: R) -> (Transport, TransportReceiver) - where - R: AsyncRead + AsyncWrite + Send + 'static, - { - let (rdr, wr) = tokio::io::split(stream); - - let (raw_send_tx, raw_send_rx) = unbounded_channel(); - let (raw_recv_tx, raw_recv_rx) = unbounded_channel(); - let (app_inbox_tx, app_inbox_rx) = unbounded_channel(); - let (app_outbox_tx, app_outbox_rx) = unbounded_channel(); - - let cancel_token = CancellationToken::new(); - - let cancel = cancel_token.clone(); - let ait = app_inbox_tx.clone(); - let task_rdr = spawn(async move { - if let Err(e) = LowlevelReader::run(raw_recv_tx, rdr, cancel).await { - let _ = ait.send(Err(TransportError::Receiver(e))); - } - }); - - let cancel = cancel_token.clone(); - let ait = app_inbox_tx.clone(); - let task_wr = spawn(async move { - if let Err(e) = LowlevelWriter::run(raw_send_rx, wr, cancel).await { - let _ = ait.send(Err(TransportError::Transmitter(e))); - } - }); - - let task_inner = spawn(async move { - let mut ts = TransportState::new( - raw_recv_rx, - raw_send_tx, - app_inbox_tx, - app_outbox_rx, - cancel_token, - ); - if let Err(e) = ts.protocol_handler().await { - let _ = ts.app_inbox.send(Err(e)); - } - }); - - ( - Transport { - _task_rdr: task_rdr, - _task_wr: task_wr, - task_inner, - outbox_tx: app_outbox_tx, - }, - app_inbox_rx, - ) - } - - pub(crate) fn send(&self, msg: &[u8]) -> Result<(), TransmitterError> { - self.outbox_tx - .send(TransportCommand::SendMessage(msg.into())) - .map_err(|_| TransmitterError::ConnectionClosed) - } - - pub(crate) async fn close(self) { - let _ = self.outbox_tx.send(TransportCommand::Exit); - let _ = self.task_inner.await; - } -} - -struct LowlevelReader<R> { - rdr: BufReader<R>, - synced: bool, -} - -impl<R: AsyncRead + Unpin> LowlevelReader<R> { - async fn read_frame(&mut self) -> Result<Option<Frame>, ReceiverError> { - let mut buf = [0u8; MESSAGE_LENGTH_MAX]; - let next_byte = self.rdr.read_u8().await?; - if next_byte == MESSAGE_VALUE_SYNC { - if !self.synced { - self.synced = true; - } - return Ok(None); - } - if !self.synced { - return Ok(None); - } - - let receive_time = Instant::now(); - let len = next_byte as usize; - - if !(MESSAGE_LENGTH_MIN..=MESSAGE_LENGTH_MAX).contains(&len) { - self.synced = false; - return Ok(None); - } - - self.rdr.read_exact(&mut buf[1..len]).await?; - buf[0] = len as u8; - let buf = &buf[..len]; - let seq = buf[MESSAGE_POSITION_SEQ]; - trace!(frame = ?buf, seq = seq, "Received frame"); - - if seq & !MESSAGE_SEQ_MASK != MESSAGE_DEST { - self.synced = false; - return Ok(None); - } - - let actual_crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]); - let frame_crc = (buf[len - MESSAGE_TRAILER_CRC] as u16) << 8 - | (buf[len - MESSAGE_TRAILER_CRC + 1] as u16); - if frame_crc != actual_crc { - self.synced = false; - return Ok(None); - } - - Ok(Some(Frame { - receive_time, - sequence: seq & MESSAGE_SEQ_MASK, - payload: buf[MESSAGE_HEADER_SIZE..len - MESSAGE_TRAILER_SIZE].into(), - })) - } - - async fn run( - outbox: UnboundedSender<Frame>, - rdr: R, - cancel_token: CancellationToken, - ) -> Result<(), ReceiverError> - where - R: AsyncRead + Unpin, - { - let mut state = Self { - rdr: BufReader::new(rdr), - synced: false, - }; - - loop { - match state.read_frame().await { - Ok(None) => {} - Ok(Some(frame)) => { - if outbox.send(frame).is_err() { - break Ok(()); - } - } - Err(_) if cancel_token.is_cancelled() => break Ok(()), - Err(e) => break Err(e), - } - } - } -} - -fn crc16(buf: &[u8]) -> u16 { - let mut crc = 0xFFFFu16; - for b in buf { - let b = *b ^ ((crc & 0xFF) as u8); - let b = b ^ (b << 4); - let b16 = b as u16; - crc = (b16 << 8 | crc >> 8) ^ (b16 >> 4) ^ (b16 << 3); - } - crc -} - -#[derive(Debug)] -struct Frame { - receive_time: Instant, - sequence: u8, - payload: Vec<u8>, -} - -#[derive(Debug, Clone)] -struct InflightFrame { - sent_at: Instant, - #[allow(dead_code)] - sequence: u64, - payload: Arc<Vec<u8>>, - is_retransmit: bool, -} - -#[derive(thiserror::Error, Debug)] -pub enum MessageEncodeError { - #[error("message would exceed the maximum packet length of {MESSAGE_LENGTH_MAX} bytes")] - MessageTooLong, -} - -fn encode_frame(sequence: u64, payload: &[u8]) -> Result<Vec<u8>, MessageEncodeError> { - let len = MESSAGE_LENGTH_MIN + payload.len(); - if len > MESSAGE_LENGTH_MAX { - return Err(MessageEncodeError::MessageTooLong); - } - let mut buf = Vec::with_capacity(len); - buf.push(len as u8); - buf.push(MESSAGE_DEST | ((sequence as u8) & MESSAGE_SEQ_MASK)); - buf.extend_from_slice(payload); - let crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]); - buf.push(((crc >> 8) & 0xFF) as u8); - buf.push((crc & 0xFF) as u8); - buf.push(MESSAGE_VALUE_SYNC); - Ok(buf) -} - -struct LowlevelWriter {} - -impl LowlevelWriter { - async fn run<W>( - mut inbox: UnboundedReceiver<Arc<Vec<u8>>>, - mut wr: W, - cancel_token: CancellationToken, - ) -> Result<(), TransmitterError> - where - W: AsyncWrite + Unpin, - { - loop { - select! { - msg = inbox.recv() => { - let msg = match msg { - Some(msg) => msg, - None => break, - }; - trace!(payload = ?msg, seq = msg[MESSAGE_POSITION_SEQ], "Sent frame"); - wr.write_all(&msg).await?; - wr.flush().await?; - } - - _ = cancel_token.cancelled() => { - break; - } - } - } - - Ok(()) - } -} - -#[derive(thiserror::Error, Debug)] -pub enum TransportError { - #[error("message encoding failed: {0}")] - MessageEncode(#[from] MessageEncodeError), - #[error("receiver error: {0}")] - Receiver(#[from] ReceiverError), - #[error("transmitter error: {0}")] - Transmitter(#[from] TransmitterError), -} - -const MIN_RTO: f32 = 0.025; -const MAX_RTO: f32 = 5.000; - -#[derive(Debug)] -struct RttState { - srtt: f32, - rttvar: f32, - rto: f32, -} - -impl RttState { - fn new() -> Self { - Self { - srtt: 0.0, - rttvar: 0.0, - rto: MIN_RTO, - } - } - - fn rto(&self) -> std::time::Duration { - std::time::Duration::from_secs_f32(self.rto) - } - - fn update(&mut self, rtt: std::time::Duration) { - let r = rtt.as_secs_f32(); - if self.srtt == 0.0 { - self.rttvar = r / 2.0; - self.srtt = r * 10.0; // Klipper uses this, we'll copy it - } else { - self.rttvar = (3.0 * self.rttvar + (self.srtt - r).abs()) / 4.0; - self.srtt = (7.0 * self.srtt + r) / 8.0; - } - let rttvar4 = (self.rttvar * 4.0).max(0.001); - self.rto = (self.srtt + rttvar4).clamp(MIN_RTO, MAX_RTO); - } -} - -#[derive(Debug)] -struct TransportState { - link_inbox: UnboundedReceiver<Frame>, - link_outbox: UnboundedSender<Arc<Vec<u8>>>, - app_inbox: UnboundedSender<Result<Vec<u8>, TransportError>>, - app_outbox: UnboundedReceiver<TransportCommand>, - cancel: CancellationToken, - - is_synchronized: bool, - rtt_state: RttState, - receive_sequence: u64, - send_sequence: u64, - last_ack_sequence: u64, - ignore_nak_seq: u64, - retransmit_seq: u64, - retransmit_now: bool, - - corked_until: Option<Instant>, - - inflight_messages: VecDeque<InflightFrame>, - ready_messages: MessageQueue<Vec<u8>>, -} - -impl TransportState { - fn new( - link_inbox: UnboundedReceiver<Frame>, - link_outbox: UnboundedSender<Arc<Vec<u8>>>, - app_inbox: UnboundedSender<Result<Vec<u8>, TransportError>>, - app_outbox: UnboundedReceiver<TransportCommand>, - cancel: CancellationToken, - ) -> Self { - Self { - link_inbox, - link_outbox, - app_inbox, - app_outbox, - cancel, - - is_synchronized: false, - rtt_state: RttState::new(), - receive_sequence: 1, - send_sequence: 1, - last_ack_sequence: 0, - ignore_nak_seq: 0, - retransmit_seq: 0, - retransmit_now: false, - - corked_until: None, - - inflight_messages: VecDeque::new(), - ready_messages: MessageQueue::new(), - } - } - - fn update_receive_seq(&mut self, receive_time: Instant, sequence: u64) { - let mut sent_seq = self.receive_sequence; - - loop { - if let Some(msg) = self.inflight_messages.pop_front() { - sent_seq += 1; - if sequence == sent_seq { - // Found the matching sent message - if !msg.is_retransmit { - let elapsed = receive_time.saturating_duration_since(msg.sent_at); - self.rtt_state.update(elapsed); - } - break; - } - } else { - // Ack with no outstanding messages, happens during connection init - self.send_sequence = sequence; - break; - } - } - - self.receive_sequence = sequence; - self.is_synchronized = true; - } - - fn handle_frame(&mut self, frame: Frame) { - let rseq = self.receive_sequence; - let mut sequence = (rseq & !(MESSAGE_SEQ_MASK as u64)) | (frame.sequence as u64); - if sequence < rseq { - sequence += (MESSAGE_SEQ_MASK as u64) + 1; - } - if !self.is_synchronized || sequence != rseq { - if sequence > self.send_sequence && self.is_synchronized { - // Ack for unsent message - return; - } - self.update_receive_seq(frame.receive_time, sequence); - } - - if frame.payload.is_empty() { - if self.last_ack_sequence < sequence { - self.last_ack_sequence = sequence; - } else if sequence > self.ignore_nak_seq && !self.inflight_messages.is_empty() { - // Trigger retransmit from NAK - self.retransmit_now = true; - } - } else { - // Data message, we deliver this directly to the application as the MCU can't actually - // retransmit anyway. - let _ = self.app_inbox.send(Ok(frame.payload)); - } - } - - fn can_send(&self) -> bool { - self.corked_until.is_none() && self.inflight_messages.len() < 12 - } - - fn send_new_frame(&mut self, mut initial: Vec<u8>) -> Result<(), TransportError> { - while let Some(next) = self.ready_messages.try_peek() { - if initial.len() + next.len() <= MESSAGE_LENGTH_PAYLOAD_MAX { - // Add to the end of the message. Unwrap is safe because we already peeked. - let mut next = self.ready_messages.try_recv().unwrap(); - initial.append(&mut next); - } else { - break; - } - } - let frame = Arc::new(encode_frame(self.send_sequence, &initial)?); - self.send_sequence += 1; - self.inflight_messages.push_back(InflightFrame { - sent_at: Instant::now(), - sequence: self.send_sequence, - payload: frame.clone(), - is_retransmit: false, - }); - let _ = self.link_outbox.send(frame); - Ok(()) - } - - fn send_more_frames(&mut self) -> Result<(), TransportError> { - while self.can_send() && self.ready_messages.try_peek().is_some() { - let msg = self.ready_messages.try_recv().unwrap(); - self.send_new_frame(msg)?; - } - Ok(()) - } - - fn retransmit_pending(&mut self) { - let len: usize = self - .inflight_messages - .iter() - .map(|msg| msg.payload.len()) - .sum(); - let mut buf = Vec::with_capacity(1 + len); - buf.push(MESSAGE_VALUE_SYNC); - let now = Instant::now(); - for msg in self.inflight_messages.iter_mut() { - buf.extend_from_slice(&msg.payload); - msg.is_retransmit = true; - msg.sent_at = now; - } - let _ = self.link_outbox.send(Arc::new(buf)); - - if self.retransmit_now { - self.ignore_nak_seq = self.receive_sequence; - if self.receive_sequence < self.retransmit_seq { - self.ignore_nak_seq = self.retransmit_seq; - } - self.retransmit_now = false; - } else { - self.rtt_state.rto = (self.rtt_state.rto * 2.0).clamp(MIN_RTO, MAX_RTO); - self.ignore_nak_seq = self.send_sequence; - } - self.retransmit_seq = self.send_sequence; - } - - async fn protocol_handler(&mut self) -> Result<(), TransportError> { - loop { - if self.retransmit_now { - self.retransmit_pending(); - } - - let retransmit_deadline = self - .inflight_messages - .front() - .map(|msg| msg.sent_at + self.rtt_state.rto()); - let retransmit_timeout: futures::future::OptionFuture<_> = - retransmit_deadline.map(sleep_until).into(); - pin!(retransmit_timeout); - - let corked_timeout: futures::future::OptionFuture<_> = - self.corked_until.map(sleep_until).into(); - pin!(corked_timeout); - - select! { - frame = self.link_inbox.recv() => { - let frame = match frame { - Some(frame) => frame, - None => break, - }; - self.handle_frame(frame); - }, - - msg = self.app_outbox.recv() => { - match msg { - Some(TransportCommand::SendMessage(msg)) => { - self.ready_messages.send(msg); - }, - Some(TransportCommand::Exit) => { - self.cancel.cancel(); - } - None => break, - }; - }, - - _ = self.ready_messages.recv_peek(), if self.can_send() => { - self.send_more_frames()?; - }, - - _ = &mut retransmit_timeout, if retransmit_deadline.is_some() => { - self.retransmit_now = true; - }, - - _ = &mut corked_timeout, if self.corked_until.is_some() => { - // Timeout for when we are able to send again - } - - _ = self.cancel.cancelled() => { - break; - }, - } - } - Ok(()) - } -} - -#[derive(Debug)] -struct MessageQueue<T> { - sender: UnboundedSender<T>, - receiver: UnboundedReceiver<T>, - peeked: Option<T>, -} - -impl<T> MessageQueue<T> { - fn new() -> Self { - let (sender, receiver) = unbounded_channel(); - Self { - sender, - receiver, - peeked: None, - } - } - - async fn recv_peek(&mut self) -> Option<&T> { - if self.peeked.is_some() { - self.peeked.as_ref() - } else { - match self.receiver.recv().await { - Some(msg) => { - self.peeked = Some(msg); - self.peeked.as_ref() - } - None => None, - } - } - } - - fn try_recv(&mut self) -> Option<T> { - if let Some(msg) = self.peeked.take() { - Some(msg) - } else { - self.receiver.try_recv().ok() - } - } - - fn try_peek(&mut self) -> Option<&T> { - if self.peeked.is_some() { - self.peeked.as_ref() - } else { - match self.receiver.try_recv() { - Ok(msg) => { - self.peeked = Some(msg); - self.peeked.as_ref() - } - Err(_) => None, - } - } - } - - fn send(&mut self, msg: T) { - let _ = self.sender.send(msg); - } -} diff --git a/crates/windlass/src/transport/mod.rs b/crates/windlass/src/transport/mod.rs new file mode 100644 index 0000000..f5c5fc3 --- /dev/null +++ b/crates/windlass/src/transport/mod.rs @@ -0,0 +1,434 @@ +use read::{FrameReader, ReceivedFrame, ReceiverError}; +use std::{collections::VecDeque, sync::Arc, time::Duration}; +use tokio::{ + io::{AsyncBufRead, AsyncWrite, AsyncWriteExt}, + pin, select, spawn, + sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + task::JoinHandle, + time::{sleep_until, Instant}, +}; +use tokio_util::sync::CancellationToken; + +use crate::encoding::crc16; + +pub(crate) const MESSAGE_HEADER_SIZE: usize = 2; +pub(crate) const MESSAGE_TRAILER_SIZE: usize = 3; +pub(crate) const MESSAGE_LENGTH_MIN: usize = MESSAGE_HEADER_SIZE + MESSAGE_TRAILER_SIZE; +pub(crate) const MESSAGE_LENGTH_MAX: usize = 64; +pub(crate) const MESSAGE_LENGTH_PAYLOAD_MAX: usize = MESSAGE_LENGTH_MAX - MESSAGE_LENGTH_MIN; +pub(crate) const MESSAGE_POSITION_SEQ: usize = 1; +pub(crate) const MESSAGE_TRAILER_CRC: usize = 3; +pub(crate) const MESSAGE_VALUE_SYNC: u8 = 0x7E; +pub(crate) const MESSAGE_DEST: u8 = 0x10; +pub(crate) const MESSAGE_SEQ_MASK: u8 = 0x0F; + +mod read; + +/// Wrapper around a connection to a klipper firmware MCU, which deals with +/// retransmission, flow control, etc. +/// +/// Internally, this holds a handle to an async task and a channel to communicate with it, +/// meaning operations don't necessarily block the current task. +#[derive(Debug)] +pub struct Transport { + // Handles to the associated async task + task_inner: JoinHandle<()>, + + /// Queue for outbound messages + cmd_send: UnboundedSender<TransportCommand>, +} + +/// A message sent to the transport task +#[derive(Debug)] +enum TransportCommand { + SendMessage(Vec<u8>), + Exit, +} + +pub(crate) type TransportReceiver = UnboundedReceiver<Result<Vec<u8>, TransportError>>; + +impl Transport { + pub(crate) async fn connect( + rdr: impl AsyncBufRead + Unpin + Send + 'static, + wr: impl AsyncWrite + Unpin + Send + 'static, + ) -> (Transport, TransportReceiver) { + let (data_send, data_recv) = unbounded_channel(); + let (cmd_send, cmd_recv) = unbounded_channel(); + + let cancel_token = CancellationToken::new(); + let task_inner = spawn(async move { + let mut ts = TransportState::new(rdr, wr, data_send, cmd_recv, cancel_token); + if let Err(e) = ts.run().await { + let _ = ts.data_send.send(Err(e)); + } + }); + + ( + Transport { + task_inner, + cmd_send, + }, + data_recv, + ) + } + + pub(crate) fn send(&self, msg: &[u8]) -> Result<(), TransmitterError> { + self.cmd_send + .send(TransportCommand::SendMessage(msg.into())) + .map_err(|_| TransmitterError::ConnectionClosed) + } + + pub(crate) async fn close(self) { + let _ = self.cmd_send.send(TransportCommand::Exit); + let _ = self.task_inner.await; + } +} + +#[derive(thiserror::Error, Debug)] +pub enum TransportError { + #[error("message encoding failed: {0}")] + MessageEncode(#[from] MessageEncodeError), + + #[error("receiver error: {0}")] + Receiver(#[from] ReceiverError), + + #[error("transmitter error: {0}")] + Transmitter(#[from] TransmitterError), + + #[error("io error: {0}")] + IOError(#[from] std::io::Error), +} + +const MIN_RTO: f32 = 0.025; +const MAX_RTO: f32 = 5.000; + +/// State for estimating the round trip time of the connection +#[derive(Debug)] +struct RttState { + srtt: f32, + rttvar: f32, + rto: f32, +} + +impl Default for RttState { + fn default() -> Self { + Self { + srtt: 0.0, + rttvar: 0.0, + rto: MIN_RTO, + } + } +} + +impl RttState { + /// Get the current recommended retransmission timeout + fn rto(&self) -> Duration { + Duration::from_secs_f32(self.rto) + } + + /// Update the RTT estimation given a new observation + fn update(&mut self, rtt: Duration) { + let r = rtt.as_secs_f32(); + if self.srtt == 0.0 { + self.rttvar = r / 2.0; + self.srtt = r * 10.0; // Klipper uses this, we'll copy it + } else { + self.rttvar = (3.0 * self.rttvar + (self.srtt - r).abs()) / 4.0; + self.srtt = (7.0 * self.srtt + r) / 8.0; + } + let rttvar4 = (self.rttvar * 4.0).max(0.001); + self.rto = (self.srtt + rttvar4).clamp(MIN_RTO, MAX_RTO); + } +} + +/// State for the task which deals with transport state +#[derive(Debug)] +struct TransportState<R, W> { + rdr: FrameReader<R>, + wr: W, + + data_send: UnboundedSender<Result<Vec<u8>, TransportError>>, + cmd_recv: UnboundedReceiver<TransportCommand>, + + cancel: CancellationToken, + + is_synchronized: bool, + rtt_state: RttState, + receive_sequence: u64, + send_sequence: u64, + last_ack_sequence: u64, + ignore_nak_seq: u64, + retransmit_seq: u64, + retransmit_now: bool, + + corked_until: Option<Instant>, + + inflight_messages: VecDeque<SentFrame>, + pending_messages: VecDeque<Vec<u8>>, +} + +impl<R: AsyncBufRead + Unpin, W: AsyncWrite + Unpin> TransportState<R, W> { + fn new( + rdr: R, + wr: W, + data_send: UnboundedSender<Result<Vec<u8>, TransportError>>, + cmd_recv: UnboundedReceiver<TransportCommand>, + cancel: CancellationToken, + ) -> Self { + Self { + rdr: FrameReader::new(rdr), + wr, + data_send, + cmd_recv, + cancel, + + is_synchronized: false, + rtt_state: RttState::default(), + receive_sequence: 1, + send_sequence: 1, + last_ack_sequence: 0, + ignore_nak_seq: 0, + retransmit_seq: 0, + retransmit_now: false, + + corked_until: None, + + inflight_messages: VecDeque::new(), + pending_messages: VecDeque::new(), + } + } + + async fn run(&mut self) -> Result<(), TransportError> { + loop { + if self.retransmit_now { + self.retransmit_pending().await?; + } + + if !self.pending_messages.is_empty() && self.can_send() { + self.send_more_frames().await?; + } + + let retransmit_deadline = self + .inflight_messages + .front() + .map(|msg| msg.sent_at + self.rtt_state.rto()); + let retransmit_timeout: futures::future::OptionFuture<_> = + retransmit_deadline.map(sleep_until).into(); + pin!(retransmit_timeout); + + let corked_timeout: futures::future::OptionFuture<_> = + self.corked_until.map(sleep_until).into(); + pin!(corked_timeout); + + // FIXME: This is not correct because read_frame is not cancellation safe + select! { + frame = self.rdr.read_frame() => { + let frame = frame?; + let frame = match frame { + Some(frame) => frame, + None => break, + }; + self.handle_frame(frame); + }, + + msg = self.cmd_recv.recv() => { + match msg { + Some(TransportCommand::SendMessage(msg)) => { + self.pending_messages.push_back(msg); + }, + Some(TransportCommand::Exit) => { + self.cancel.cancel(); + } + None => break, + }; + }, + + _ = &mut retransmit_timeout, if retransmit_deadline.is_some() => { + self.retransmit_now = true; + }, + + _ = &mut corked_timeout, if self.corked_until.is_some() => { + // Timeout for when we are able to send again + } + + _ = self.cancel.cancelled() => { + break; + }, + } + } + Ok(()) + } + + /// Handle an incoming frame, by updating sequence numbers and sending the data upwards if needed + fn handle_frame(&mut self, frame: ReceivedFrame) { + let rseq = self.receive_sequence; + + // wrap-around logic(?) + let mut sequence = (rseq & !(MESSAGE_SEQ_MASK as u64)) | (frame.sequence as u64); + if sequence < rseq { + sequence += (MESSAGE_SEQ_MASK as u64) + 1; + } + + // Frame acknowledges some messages + if !self.is_synchronized || sequence != rseq { + if sequence > self.send_sequence && self.is_synchronized { + // Ack for unsent message - weird, but ignore and try to continue + return; + } + + self.update_receive_seq(frame.receive_time, sequence); + } + + if !frame.payload.is_empty() { + // Data message, we deliver this directly to the application as the MCU can't actually + // retransmit anyway. + // TODO: Maybe check the CRC anyway so we can discard it here + let _ = self.data_send.send(Ok(frame.payload)); + } else if sequence > self.last_ack_sequence { + // ACK + self.last_ack_sequence = sequence; + } else if sequence > self.ignore_nak_seq && !self.inflight_messages.is_empty() { + // NAK + self.retransmit_now = true; + } + } + + /// Update the last received sequence number, removing acknowledged messages from `self.inflight_messages` + fn update_receive_seq(&mut self, receive_time: Instant, sequence: u64) { + let mut sent_seq = self.receive_sequence; + + // Discard messages from inflight_messages up to sequence + loop { + if let Some(msg) = self.inflight_messages.pop_front() { + sent_seq += 1; + if sequence == sent_seq { + // Found the matching sent message + if !msg.is_retransmit { + let elapsed = receive_time.saturating_duration_since(msg.sent_at); + self.rtt_state.update(elapsed); + } + break; + } + } else { + // Ack with no outstanding messages, happens during connection init + self.send_sequence = sequence; + break; + } + } + + self.receive_sequence = sequence; + self.is_synchronized = true; + } + + fn can_send(&self) -> bool { + self.corked_until.is_none() && self.inflight_messages.len() < 12 + } + + /// Send as many more frames as possible from [`self.pending_messages`] + async fn send_more_frames(&mut self) -> Result<(), TransportError> { + while self.can_send() && !self.pending_messages.is_empty() { + self.send_new_frame().await?; + } + + Ok(()) + } + + /// Send a single new frame from [`self.pending_messages`] + async fn send_new_frame(&mut self) -> Result<(), TransportError> { + let mut buf = Vec::new(); + while let Some(next) = self.pending_messages.front() { + if !buf.is_empty() && buf.len() + next.len() <= MESSAGE_LENGTH_PAYLOAD_MAX { + // Add to the end of the frame. Unwrap is safe because we already peeked. + let mut next = self.pending_messages.pop_front().unwrap(); + buf.append(&mut next); + } else { + break; + } + } + + let frame = Arc::new(encode_frame(self.send_sequence, &buf)?); + self.send_sequence += 1; + self.inflight_messages.push_back(SentFrame { + sent_at: Instant::now(), + sequence: self.send_sequence, + payload: frame.clone(), + is_retransmit: false, + }); + self.wr.write_all(&frame).await?; + + Ok(()) + } + + /// Retransmit all inflight messages + async fn retransmit_pending(&mut self) -> Result<(), TransportError> { + let len: usize = self + .inflight_messages + .iter() + .map(|msg| msg.payload.len()) + .sum(); + let mut buf = Vec::with_capacity(1 + len); + buf.push(MESSAGE_VALUE_SYNC); + let now = Instant::now(); + for msg in self.inflight_messages.iter_mut() { + buf.extend_from_slice(&msg.payload); + msg.is_retransmit = true; + msg.sent_at = now; + } + self.wr.write_all(&buf).await?; + + if self.retransmit_now { + self.ignore_nak_seq = self.receive_sequence; + if self.receive_sequence < self.retransmit_seq { + self.ignore_nak_seq = self.retransmit_seq; + } + self.retransmit_now = false; + } else { + self.rtt_state.rto = (self.rtt_state.rto * 2.0).clamp(MIN_RTO, MAX_RTO); + self.ignore_nak_seq = self.send_sequence; + } + self.retransmit_seq = self.send_sequence; + + Ok(()) + } +} + +/// An error encountered when transmitting a message +#[derive(thiserror::Error, Debug)] +pub enum TransmitterError { + #[error("io error: {0}")] + IoError(#[from] std::io::Error), + + #[error("connection closed")] + ConnectionClosed, +} + +#[derive(Debug, Clone)] +pub(crate) struct SentFrame { + pub sent_at: Instant, + #[allow(dead_code)] + pub sequence: u64, + pub payload: Arc<Vec<u8>>, + pub is_retransmit: bool, +} + +#[derive(thiserror::Error, Debug)] +pub enum MessageEncodeError { + #[error("message would exceed the maximum packet length of {MESSAGE_LENGTH_MAX} bytes")] + MessageTooLong, +} + +fn encode_frame(sequence: u64, payload: &[u8]) -> Result<Vec<u8>, MessageEncodeError> { + let len = MESSAGE_LENGTH_MIN + payload.len(); + if len > MESSAGE_LENGTH_MAX { + return Err(MessageEncodeError::MessageTooLong); + } + let mut buf = Vec::with_capacity(len); + buf.push(len as u8); + buf.push(MESSAGE_DEST | ((sequence as u8) & MESSAGE_SEQ_MASK)); + buf.extend_from_slice(payload); + let crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]); + buf.push(((crc >> 8) & 0xFF) as u8); + buf.push((crc & 0xFF) as u8); + buf.push(MESSAGE_VALUE_SYNC); + Ok(buf) +} diff --git a/crates/windlass/src/transport/read.rs b/crates/windlass/src/transport/read.rs new file mode 100644 index 0000000..079f7b3 --- /dev/null +++ b/crates/windlass/src/transport/read.rs @@ -0,0 +1,87 @@ +use tokio::{ + io::{AsyncBufRead, AsyncReadExt}, + time::Instant, +}; +use tracing::trace; + +use crate::{ + encoding::crc16, + transport::{ + MESSAGE_DEST, MESSAGE_HEADER_SIZE, MESSAGE_LENGTH_MAX, MESSAGE_LENGTH_MIN, + MESSAGE_POSITION_SEQ, MESSAGE_SEQ_MASK, MESSAGE_TRAILER_CRC, MESSAGE_TRAILER_SIZE, + MESSAGE_VALUE_SYNC, + }, +}; + +#[derive(Debug)] +pub(crate) struct ReceivedFrame { + pub receive_time: Instant, + pub sequence: u8, + pub payload: Vec<u8>, +} + +#[derive(Debug)] +pub struct FrameReader<R> { + rdr: R, + synced: bool, +} + +impl<R: AsyncBufRead + Unpin> FrameReader<R> { + pub fn new(rdr: R) -> Self { + Self { rdr, synced: false } + } + + pub async fn read_frame(&mut self) -> Result<Option<ReceivedFrame>, ReceiverError> { + let mut buf = [0u8; MESSAGE_LENGTH_MAX]; + let next_byte = self.rdr.read_u8().await?; + if next_byte == MESSAGE_VALUE_SYNC { + if !self.synced { + self.synced = true; + } + return Ok(None); + } + if !self.synced { + return Ok(None); + } + + let receive_time = Instant::now(); + let len = next_byte as usize; + + if !(MESSAGE_LENGTH_MIN..=MESSAGE_LENGTH_MAX).contains(&len) { + self.synced = false; + return Ok(None); + } + + self.rdr.read_exact(&mut buf[1..len]).await?; + buf[0] = len as u8; + let buf = &buf[..len]; + let seq = buf[MESSAGE_POSITION_SEQ]; + trace!(frame = ?buf, seq = seq, "Received frame"); + + if seq & !MESSAGE_SEQ_MASK != MESSAGE_DEST { + self.synced = false; + return Ok(None); + } + + let actual_crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]); + let frame_crc = (buf[len - MESSAGE_TRAILER_CRC] as u16) << 8 + | (buf[len - MESSAGE_TRAILER_CRC + 1] as u16); + if frame_crc != actual_crc { + self.synced = false; + return Ok(None); + } + + Ok(Some(ReceivedFrame { + receive_time, + sequence: seq & MESSAGE_SEQ_MASK, + payload: buf[MESSAGE_HEADER_SIZE..len - MESSAGE_TRAILER_SIZE].into(), + })) + } +} + +/// An error encountered when receiving a message +#[derive(thiserror::Error, Debug)] +pub enum ReceiverError { + #[error("io error: {0}")] + IoError(#[from] std::io::Error), +} |