From 5f9fbea5a1b08962887d16457e77a922919a8818 Mon Sep 17 00:00:00 2001 From: tcmal Date: Thu, 12 Sep 2024 16:33:19 +0100 Subject: Vendor windlass (control-side klipper protocol implementation) --- crates/windlass/.gitignore | 2 + crates/windlass/Cargo.toml | 15 + crates/windlass/LICENSE.txt | 21 ++ crates/windlass/README.md | 18 + crates/windlass/examples/usb_test.rs | 60 ++++ crates/windlass/src/dictionary.rs | 131 ++++++++ crates/windlass/src/encoding.rs | 327 +++++++++++++++++++ crates/windlass/src/lib.rs | 21 ++ crates/windlass/src/macros.rs | 200 ++++++++++++ crates/windlass/src/mcu.rs | 585 +++++++++++++++++++++++++++++++++ crates/windlass/src/messages.rs | 237 ++++++++++++++ crates/windlass/src/transport.rs | 615 +++++++++++++++++++++++++++++++++++ 12 files changed, 2232 insertions(+) create mode 100644 crates/windlass/.gitignore create mode 100644 crates/windlass/Cargo.toml create mode 100644 crates/windlass/LICENSE.txt create mode 100644 crates/windlass/README.md create mode 100644 crates/windlass/examples/usb_test.rs create mode 100644 crates/windlass/src/dictionary.rs create mode 100644 crates/windlass/src/encoding.rs create mode 100644 crates/windlass/src/lib.rs create mode 100644 crates/windlass/src/macros.rs create mode 100644 crates/windlass/src/mcu.rs create mode 100644 crates/windlass/src/messages.rs create mode 100644 crates/windlass/src/transport.rs (limited to 'crates') diff --git a/crates/windlass/.gitignore b/crates/windlass/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/crates/windlass/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/crates/windlass/Cargo.toml b/crates/windlass/Cargo.toml new file mode 100644 index 0000000..2c5d436 --- /dev/null +++ b/crates/windlass/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "windlass" +version = "0.1.0" +edition = "2021" + +[dependencies] +futures = "0.3" +thiserror = "1" +paste = "1" +flate2 = "1" +tokio = { workspace = true, features = ["sync", "io-util", "time", "macros", "rt"] } +tokio-util = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +tracing = { workspace = true } diff --git a/crates/windlass/LICENSE.txt b/crates/windlass/LICENSE.txt new file mode 100644 index 0000000..2b6864b --- /dev/null +++ b/crates/windlass/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Lasse Dalegaard + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/crates/windlass/README.md b/crates/windlass/README.md new file mode 100644 index 0000000..5d35c7d --- /dev/null +++ b/crates/windlass/README.md @@ -0,0 +1,18 @@ +# Windlass + +This code was originally sourced from [https://github.com/Annex-Engineering/windlass/tree/master](annex-engineering). + +Windlass is an implementation of the host side of the Klipper protocol. + +## Licensing + +Windlass is licensed under the MIT license, which means you can do pretty much +whatever you want with it. Please see [LICENSE.txt](LICENSE.txt) for more +information. + +## Acknowledgements + +This project is in no way endorsed by the Klipper project. Please do not direct +any support requests to the Klipper project. + + * [Klipper](https://www.klipper3d.org/) by [Kevin O'Connor](https://www.patreon.com/koconnor) diff --git a/crates/windlass/examples/usb_test.rs b/crates/windlass/examples/usb_test.rs new file mode 100644 index 0000000..6aee77b --- /dev/null +++ b/crates/windlass/examples/usb_test.rs @@ -0,0 +1,60 @@ +use tokio::time::Duration; +use windlass::{mcu_command, mcu_reply, McuConnection}; + +mcu_command!(DebugNop, "debug_nop"); +mcu_command!(GetClock, "get_clock"); +mcu_reply!(Clock, "clock", clock: u32); +mcu_command!(GetUptime, "get_uptime"); +mcu_reply!(Uptime, "uptime", high: u32, clock: u32); +mcu_reply!( + AdxlStatus, + "adxl345_status", + oid: u8, + sequence: u16, + data: Vec, +); +mcu_reply!(Stats, "stats", count: u32, sum: u32, sumsq: u32); + +mcu_command!( + ConfigEndstop, + "config_endstop", + oid: u8, + pin: u8, + pull_up: u8 +); + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + let target = std::env::args().nth(1).expect("Missing path"); + let builder = tokio_serial::new(target, 96_000); + let port = tokio_serial::SerialStream::open(&builder).expect("USB open"); + + let mut mcu = McuConnection::connect(port).await.expect("MCU connect"); + + let mut resp = mcu.register_response(Stats).expect("Register"); + + let clock = mcu.send_receive(GetClock::encode(), Clock); + let uptime = mcu.send_receive(GetUptime::encode(), Uptime).await; + println!("Uptime: {uptime:?}"); + let clock = clock.await; + println!("Clock: {clock:?}"); + + let timeout = tokio::time::sleep(Duration::from_secs(10)); + tokio::pin!(timeout); + + loop { + tokio::select! { + _ = &mut timeout => break, + s = resp.recv() => { + println!("Stats: {s:?}"); + } + err = mcu.closed() => { + println!("MCU connection lost: {err:?}"); + break; + } + } + } + mcu.close().await; +} diff --git a/crates/windlass/src/dictionary.rs b/crates/windlass/src/dictionary.rs new file mode 100644 index 0000000..5665f80 --- /dev/null +++ b/crates/windlass/src/dictionary.rs @@ -0,0 +1,131 @@ +use std::collections::BTreeMap; + +use crate::{ + encoding::encode_vlq_int, + messages::{MessageParser, MessageSkipperError}, +}; + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +#[serde(untagged)] +pub enum ConfigVar { + String(String), + Number(f64), +} + +#[derive(Debug, serde::Deserialize)] +#[serde(untagged)] +pub enum Enumeration { + Single(i64), + Range(i64, i64), +} + +#[derive(Debug, serde::Deserialize)] +pub(crate) struct RawDictionary { + #[serde(default)] + config: BTreeMap, + + #[serde(default)] + enumerations: BTreeMap>, + + #[serde(default)] + commands: BTreeMap, + #[serde(default)] + responses: BTreeMap, + #[serde(default)] + output: BTreeMap, + + #[serde(default)] + build_versions: Option, + #[serde(default)] + version: Option, + + #[serde(flatten)] + extra: BTreeMap, +} + +/// Dictionary error +#[derive(thiserror::Error, Debug)] +pub enum DictionaryError { + /// Found an empty command + #[error("empty command found")] + EmptyCommand, + /// Received a command in an invalid format + #[error("invalid command format: {0}")] + InvalidCommandFormat(String, MessageSkipperError), + /// Received an output string with an invalid format + #[error("invalid output format: {0}")] + InvalidOutputFormat(String, MessageSkipperError), + /// Received a command with an invalid tag + #[error("command tag {0} output valid range of -32..95")] + InvalidCommandTag(u16), +} + +#[derive(Debug)] +pub struct Dictionary { + pub message_ids: BTreeMap, + pub message_parsers: BTreeMap, + pub config: BTreeMap, + pub enumerations: BTreeMap>, + pub build_versions: Option, + pub version: Option, + pub extra: BTreeMap, +} + +impl Dictionary { + pub(crate) fn from_raw_dictionary(raw: RawDictionary) -> Result { + let mut message_ids = BTreeMap::new(); + let mut message_parsers = BTreeMap::new(); + + for (cmd, tag) in raw.commands { + let mut split = cmd.split(' '); + let name = split.next().ok_or(DictionaryError::EmptyCommand)?; + let parser = MessageParser::new(name, split) + .map_err(|e| DictionaryError::InvalidCommandFormat(name.to_string(), e))?; + let tag = Self::map_tag(tag)?; + message_parsers.insert(tag, parser); + message_ids.insert(name.to_string(), tag); + } + + for (resp, tag) in raw.responses { + let mut split = resp.split(' '); + let name = split.next().ok_or(DictionaryError::EmptyCommand)?; + let parser = MessageParser::new(name, split) + .map_err(|e| DictionaryError::InvalidCommandFormat(name.to_string(), e))?; + let tag = Self::map_tag(tag)?; + message_parsers.insert(tag, parser); + message_ids.insert(name.to_string(), tag); + } + + for (msg, tag) in raw.output { + let parser = MessageParser::new_output(&msg) + .map_err(|e| DictionaryError::InvalidCommandFormat(msg.to_string(), e))?; + let tag = Self::map_tag(tag)?; + message_parsers.insert(tag, parser); + } + + Ok(Dictionary { + message_ids, + message_parsers, + config: raw.config, + enumerations: raw.enumerations, + build_versions: raw.build_versions, + version: raw.version, + extra: raw.extra, + }) + } + + fn map_tag(tag: i16) -> Result { + let mut buf = vec![]; + encode_vlq_int(&mut buf, tag as u32); + let v = if buf.len() > 1 { + ((buf[0] as u16) & 0x7F) << 7 | (buf[1] as u16) & 0x7F + } else { + (buf[0] as u16) & 0x7F + }; + if v >= 1 << 14 { + Err(DictionaryError::InvalidCommandTag(v)) + } else { + Ok(v) + } + } +} diff --git a/crates/windlass/src/encoding.rs b/crates/windlass/src/encoding.rs new file mode 100644 index 0000000..0950956 --- /dev/null +++ b/crates/windlass/src/encoding.rs @@ -0,0 +1,327 @@ +use std::fmt::Display; + +use crate::messages::MessageSkipperError; + +/// Message decoding error +#[derive(thiserror::Error, Debug, Clone)] +pub enum MessageDecodeError { + /// More data was expected but none is available + #[error("eof unexpected")] + UnexpectedEof, + /// A received string could not be decoded as UTF8 + #[error("invalid utf8 string")] + Utf8Error(#[from] std::str::Utf8Error), +} + +pub(crate) fn encode_vlq_int(output: &mut Vec, v: u32) { + let sv = v as i32; + if !(-(1 << 26)..(3 << 26)).contains(&sv) { + output.push(((sv >> 28) & 0x7F) as u8 | 0x80); + } + if !(-(1 << 19)..(3 << 19)).contains(&sv) { + output.push(((sv >> 21) & 0x7F) as u8 | 0x80); + } + if !(-(1 << 12)..(3 << 12)).contains(&sv) { + output.push(((sv >> 14) & 0x7F) as u8 | 0x80); + } + if !(-(1 << 5)..(3 << 5)).contains(&sv) { + output.push(((sv >> 7) & 0x7F) as u8 | 0x80); + } + output.push((sv & 0x7F) as u8); +} + +pub(crate) fn next_byte(data: &mut &[u8]) -> Result { + if data.is_empty() { + Err(MessageDecodeError::UnexpectedEof) + } else { + let v = data[0]; + *data = &data[1..]; + Ok(v) + } +} + +pub(crate) fn parse_vlq_int(data: &mut &[u8]) -> Result { + let mut c = next_byte(data)? as u32; + let mut v = c & 0x7F; + if (c & 0x60) == 0x60 { + v |= (-0x20i32) as u32; + } + while c & 0x80 != 0 { + c = next_byte(data)? as u32; + v = (v << 7) | (c & 0x7F); + } + + Ok(v) +} + +pub trait Readable<'de>: Sized { + fn read(data: &mut &'de [u8]) -> Result; + + fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError>; +} + +pub trait Writable: Sized { + fn write(&self, output: &mut Vec); +} + +pub trait Borrowable: Sized { + type Borrowed<'a> + where + Self: 'a; + fn from_borrowed(src: Self::Borrowed<'_>) -> Self; +} + +pub trait ToFieldType: Sized { + fn as_field_type() -> FieldType; +} + +macro_rules! int_readwrite { + ( $type:tt, $field_type:expr ) => { + impl Readable<'_> for $type { + fn read(data: &mut &[u8]) -> Result { + parse_vlq_int(data).map(|v| v as $type) + } + + fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> { + parse_vlq_int(data).map(|_| ()) + } + } + + impl Writable for $type { + fn write(&self, output: &mut Vec) { + encode_vlq_int(output, *self as u32) + } + } + + impl Borrowable for $type { + type Borrowed<'a> = Self; + fn from_borrowed(src: Self::Borrowed<'_>) -> Self { + src + } + } + + impl ToFieldType for $type { + fn as_field_type() -> FieldType { + $field_type + } + } + }; +} + +int_readwrite!(u32, FieldType::U32); +int_readwrite!(i32, FieldType::I32); +int_readwrite!(u16, FieldType::U16); +int_readwrite!(i16, FieldType::I16); +int_readwrite!(u8, FieldType::U8); + +impl Readable<'_> for bool { + fn read(data: &mut &[u8]) -> Result { + parse_vlq_int(data).map(|v| v != 0) + } + + fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> { + parse_vlq_int(data).map(|_| ()) + } +} + +impl Writable for bool { + fn write(&self, output: &mut Vec) { + encode_vlq_int(output, u32::from(*self)) + } +} + +impl Borrowable for bool { + type Borrowed<'a> = Self; + fn from_borrowed(src: Self::Borrowed<'_>) -> Self { + src + } +} + +impl ToFieldType for bool { + fn as_field_type() -> FieldType { + FieldType::U8 + } +} + +impl<'de> Readable<'de> for &'de [u8] { + fn read(data: &mut &'de [u8]) -> Result<&'de [u8], MessageDecodeError> { + let len = parse_vlq_int(data)? as usize; + if data.len() < len { + Err(MessageDecodeError::UnexpectedEof) + } else { + let ret = &data[..len]; + *data = &data[len..]; + Ok(ret) + } + } + + fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> { + let len = parse_vlq_int(data)? as usize; + if data.len() < len { + Err(MessageDecodeError::UnexpectedEof) + } else { + *data = &data[len..]; + Ok(()) + } + } +} + +impl Writable for &[u8] { + fn write(&self, output: &mut Vec) { + encode_vlq_int(output, self.len() as u32); + output.extend_from_slice(self); + } +} + +impl Borrowable for Vec { + type Borrowed<'a> = &'a [u8]; + fn from_borrowed(src: Self::Borrowed<'_>) -> Self { + src.into() + } +} + +impl ToFieldType for Vec { + fn as_field_type() -> FieldType { + FieldType::ByteArray + } +} + +impl<'de> Readable<'de> for &'de str { + fn read(data: &mut &'de [u8]) -> Result<&'de str, MessageDecodeError> { + let len = parse_vlq_int(data)? as usize; + if data.len() < len { + Err(MessageDecodeError::UnexpectedEof) + } else { + let ret = &data[..len]; + *data = &data[len..]; + Ok(std::str::from_utf8(ret)?) + } + } + + fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> { + let len = parse_vlq_int(data)? as usize; + if data.len() < len { + Err(MessageDecodeError::UnexpectedEof) + } else { + *data = &data[len..]; + Ok(()) + } + } +} + +impl Writable for &str { + fn write(&self, output: &mut Vec) { + let bytes = self.as_bytes(); + encode_vlq_int(output, bytes.len() as u32); + output.extend_from_slice(bytes); + } +} + +impl Borrowable for String { + type Borrowed<'a> = &'a str; + fn from_borrowed(src: Self::Borrowed<'_>) -> Self { + src.to_string() + } +} + +impl ToFieldType for String { + fn as_field_type() -> FieldType { + FieldType::String + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum FieldType { + U32, + I32, + U16, + I16, + U8, + String, + ByteArray, +} + +impl FieldType { + pub(crate) fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> { + match self { + Self::U32 => ::skip(input), + Self::I32 => ::skip(input), + Self::U16 => ::skip(input), + Self::I16 => ::skip(input), + Self::U8 => ::skip(input), + Self::String => <&str as Readable>::skip(input), + Self::ByteArray => <&[u8] as Readable>::skip(input), + } + } + + pub(crate) fn read(&self, input: &mut &[u8]) -> Result { + Ok(match self { + Self::U32 => FieldValue::U32(::read(input)?), + Self::I32 => FieldValue::I32(::read(input)?), + Self::U16 => FieldValue::U16(::read(input)?), + Self::I16 => FieldValue::I16(::read(input)?), + Self::U8 => FieldValue::U8(::read(input)?), + Self::String => FieldValue::String(<&str as Readable>::read(input)?.into()), + Self::ByteArray => FieldValue::ByteArray(<&[u8] as Readable>::read(input)?.into()), + }) + } + + pub(crate) fn from_msg(s: &str) -> Result { + match s { + "%u" => Ok(Self::U32), + "%i" => Ok(Self::I32), + "%hu" => Ok(Self::U16), + "%hi" => Ok(Self::I16), + "%c" => Ok(Self::U8), + "%s" => Ok(Self::String), + "%*s" => Ok(Self::ByteArray), + "%.*s" => Ok(Self::ByteArray), + s => Err(MessageSkipperError::InvalidFormatFieldType(s.to_string())), + } + } + + pub(crate) fn from_format(s: &str) -> Result<(Self, &str), MessageSkipperError> { + if let Some(rest) = s.strip_prefix("%u") { + Ok((Self::U32, rest)) + } else if let Some(rest) = s.strip_prefix("%i") { + Ok((Self::I32, rest)) + } else if let Some(rest) = s.strip_prefix("%hu") { + Ok((Self::U16, rest)) + } else if let Some(rest) = s.strip_prefix("%hi") { + Ok((Self::I16, rest)) + } else if let Some(rest) = s.strip_prefix("%c") { + Ok((Self::U8, rest)) + } else if let Some(rest) = s.strip_prefix("%.*s") { + Ok((Self::ByteArray, rest)) + } else if let Some(rest) = s.strip_prefix("%*s") { + Ok((Self::String, rest)) + } else { + Err(MessageSkipperError::InvalidFormatFieldType(s.to_string())) + } + } +} + +#[derive(Debug)] +pub enum FieldValue { + U32(u32), + I32(i32), + U16(u16), + I16(i16), + U8(u8), + String(String), + ByteArray(Vec), +} + +impl Display for FieldValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::U32(v) => write!(f, "{}", v), + Self::I32(v) => write!(f, "{}", v), + Self::U16(v) => write!(f, "{}", v), + Self::I16(v) => write!(f, "{}", v), + Self::U8(v) => write!(f, "{}", v), + Self::String(v) => write!(f, "{}", v), + Self::ByteArray(v) => write!(f, "{:?}", v), + } + } +} diff --git a/crates/windlass/src/lib.rs b/crates/windlass/src/lib.rs new file mode 100644 index 0000000..c578bf2 --- /dev/null +++ b/crates/windlass/src/lib.rs @@ -0,0 +1,21 @@ +#[macro_use] +#[doc(hidden)] +pub mod macros; + +#[doc(hidden)] +pub mod dictionary; + +#[doc(hidden)] +pub mod encoding; + +mod mcu; + +#[doc(hidden)] +pub mod messages; + +mod transport; + +pub use dictionary::DictionaryError; +pub use encoding::MessageDecodeError; +pub use mcu::{McuConnection, McuConnectionError}; +pub use messages::EncodedMessage; diff --git a/crates/windlass/src/macros.rs b/crates/windlass/src/macros.rs new file mode 100644 index 0000000..6f57a70 --- /dev/null +++ b/crates/windlass/src/macros.rs @@ -0,0 +1,200 @@ +pub use crate::transport::MESSAGE_LENGTH_PAYLOAD_MAX; + +#[macro_export] +#[doc(hidden)] +macro_rules! mcu_message_impl_oid_check { + ($ty_name:ident) => { + impl $crate::messages::WithoutOid for $ty_name {} + }; + ($ty_name:ident, oid $(, $args:ident)*) => { + impl $crate::messages::WithOid for $ty_name {} + }; + ($ty_name:ident, $arg:ident $(, $args:ident)*) => { + $crate::mcu_message_impl_oid_check!($ty_name $(, $args)*); + }; +} + +#[macro_export] +#[doc(hidden)] +macro_rules! mcu_message_impl { + ($ty_name:ident, $cmd_name:literal = $cmd_id:expr $(, $arg:ident : $kind:ty)*) => { + paste::paste! { + #[derive(Debug)] + struct $ty_name; + + #[allow(dead_code)] + impl $ty_name { + #[allow(clippy::extra_unused_lifetimes)] + pub fn encode<'a>($($arg: <$kind as $crate::encoding::Borrowable>::Borrowed<'a>, )*) -> $crate::messages::EncodedMessage { + let payload = [<$ty_name Data>] { + $($arg,)* + _lifetime: Default::default(), + }.encode(); + $crate::messages::EncodedMessage { + payload, + _message_kind: Default::default(), + } + } + + pub fn decode<'a>(input: &mut &'a [u8]) -> Result<[<$ty_name Data>]<'a>, $crate::encoding::MessageDecodeError> { + [<$ty_name Data>]::decode(input) + } + } + + #[allow(dead_code)] + struct [<$ty_name Data>]<'a> { + // $(pub $arg: $kind,)* + $(pub $arg: <$kind as $crate::encoding::Borrowable>::Borrowed<'a>,)* + + _lifetime: std::marker::PhantomData<&'a ()>, + } + + #[allow(dead_code)] + impl<'a> [<$ty_name Data>]<'a> { + fn encode(&self) -> $crate::messages::FrontTrimmableBuffer { + use $crate::encoding::Writable; + let mut buf = Vec::with_capacity($crate::macros::MESSAGE_LENGTH_PAYLOAD_MAX); + buf.push(0); + buf.push(0); + $(self.$arg.write(&mut buf);)* + $crate::messages::FrontTrimmableBuffer { content: buf, offset: 0 } + } + + fn decode(input: &mut &'a [u8]) -> Result { + $(let $arg = $crate::encoding::Readable::read(input)?;)* + Ok(Self { $($arg,)* _lifetime: Default::default() }) + } + + } + + impl<'a> std::fmt::Debug for [<$ty_name Data>]<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut ds = f.debug_struct(stringify!([<$ty_name Data>])); + $( let ds = ds.field(stringify!($arg), &self.$arg); )* + ds.finish() + } + } + + #[derive(Clone)] + #[allow(dead_code)] + struct [<$ty_name DataOwned>] { + $(pub $arg: $kind,)* + } + + impl<'a> std::convert::From<[<$ty_name Data>]<'a>> for [<$ty_name DataOwned>] { + fn from(value: [<$ty_name Data>]) -> Self { + Self { + $($arg: $crate::encoding::Borrowable::from_borrowed(value.$arg),)* + } + } + } + + impl std::fmt::Debug for [<$ty_name DataOwned>] { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut ds = f.debug_struct(stringify!([<$ty_name DataOwned>])); + $( let ds = ds.field(stringify!($arg), &self.$arg); )* + ds.finish() + } + } + + #[allow(dead_code)] + impl $crate::messages::Message for $ty_name { + type Pod<'a> = [<$ty_name Data>]<'a>; + type PodOwned = [<$ty_name DataOwned>]; + + fn get_id(dict: Option<&$crate::dictionary::Dictionary>) -> Option { + $cmd_id.or_else(|| dict.and_then(|dict| dict.message_ids.get($cmd_name).copied())) + } + + fn get_name() -> &'static str { + $cmd_name + } + + fn decode<'a>(input: &mut &'a [u8]) -> Result, $crate::encoding::MessageDecodeError> { + Self::decode(input) + } + + fn fields<'a>() -> Vec<(&'static str, $crate::encoding::FieldType)> { + vec![ + $( ( stringify!($arg), <$kind as $crate::encoding::ToFieldType>::as_field_type() ), )* + ] + } + } + + $crate::mcu_message_impl_oid_check!($ty_name $(, $arg)*); + } + }; +} + +/// Declare a host -> device message +/// +/// Declares the format of a message sent from host to device. The general format is as follows: +/// +/// ```text +/// mcu_command!(, ""( = ) [, arg: type, ..]); +/// ``` +/// +/// A struct with the given name will be defined, with relevant interfaces. The message name and +/// arguments will be matched to the MCU at runtime. The struct name has no restrictions, but it is +/// recommended to pick a SnakeCased version of the command name. +/// +/// Optionally an `id` can be directly specified. Generally this is not needed. When not specified, +/// it will be automatically inferred at runtime and matched to the dictionary retrieved from the +/// target MCU. +/// +/// Arguments can be specified, they will be mapped to the relevant Klipper argument types. The +/// supported types and mappings are as follows: +/// +/// | Rust type | Format string | +/// |------------|---------------| +/// | `u32` | `%u` | +/// | `i32` | `%i` | +/// | `u16` | `%hu` | +/// | `i16` | `%hi` | +/// | `u8` | `%c` | +/// | `&'a [u8]` | `%.*s`, `%*s` | +/// | `&'a str` | `%s` | +/// +/// Note that the buffer types take a lifetime. This must always be `'a`. +/// +/// # Examples +/// +/// ```ignore +/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c' +/// mcu_command!(ConfigEndstop, "config_endstop", oid: u8, pin: u8, pull_up: u8); +/// ``` +#[macro_export] +macro_rules! mcu_command { + ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => { + $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*); + }; + ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => { + $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*); + }; +} + +/// Declare a device -> host message +/// +/// Declares the format of a message sent from device to host. The general format is as follows: +/// +/// ```text +/// mcu_reply!(, ""( = ) [, arg: type, ..]); +/// ``` +/// +/// For more information on the various fields, see the documentation for [mcu_command]. +/// +/// # Examples +/// +/// ```ignore +/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c' +/// mcu_reply!(Uptime, "uptime", high: u32, clock: u32); +/// ``` +#[macro_export] +macro_rules! mcu_reply { + ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => { + $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*); + }; + ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => { + $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*); + }; +} diff --git a/crates/windlass/src/mcu.rs b/crates/windlass/src/mcu.rs new file mode 100644 index 0000000..53fea7f --- /dev/null +++ b/crates/windlass/src/mcu.rs @@ -0,0 +1,585 @@ +use std::{ + any::TypeId, + collections::{BTreeMap, BTreeSet}, + io::Read, + sync::{atomic::AtomicUsize, Arc, Mutex}, + time::{Duration, Instant}, +}; + +use tokio::{ + io::{AsyncRead, AsyncWrite}, + select, + sync::{ + mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + oneshot, OnceCell, + }, + task::{spawn, JoinHandle}, +}; +use tracing::{debug, error, trace}; + +use crate::messages::{format_command_args, EncodedMessage, Message, WithOid, WithoutOid}; +use crate::transport::{Transport, TransportReceiver}; +use crate::{ + dictionary::{Dictionary, DictionaryError, RawDictionary}, + transport::TransportError, +}; +use crate::{ + encoding::{parse_vlq_int, MessageDecodeError}, + messages::FrontTrimmableBuffer, +}; + +mcu_command!(Identify, "identify" = 1, offset: u32, count: u8); +mcu_reply!( + IdentifyResponse, + "identify_response" = 0, + offset: u32, + data: Vec, +); + +/// MCU Connection Errors +#[derive(thiserror::Error, Debug)] +pub enum McuConnectionError { + /// Encoding a message was attempted but a command with a matching name doesn't exist in the + /// dictionary. + #[error("unknown message ID for command '{0}'")] + UnknownMessageId(&'static str), + /// An issue was encountered while decoding the received frame + #[error("error decoding message with ID '{0}'")] + DecodingError(#[from] MessageDecodeError), + /// An error was encountered while fetching the dictionary + #[error("error obtaining identify data: {0}")] + DictionaryFetch(Box), + /// An error was encountered while parsing the dictionary + #[error("dictionary issue: {0}")] + Dictionary(#[from] DictionaryError), + /// Received an unknown command from the MCU + #[error("unknown command {0}")] + UnknownCommand(u16), + /// There was a mismatch between the command arguments on the remote and local sides. + #[error("mismatched command {0}: '{1}' vs '{2}'")] + CommandMismatch(&'static str, String, String), + /// There was a transport-level issue + #[error("transport error: {0}")] + Transport(#[from] TransportError), +} + +#[derive(thiserror::Error, Debug)] +pub enum SendReceiveError { + #[error("timeout")] + Timeout, + #[error("mcu connection error: {0}")] + McuConnection(#[from] McuConnectionError), +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)] +struct HandlerId(u16, Option); + +trait Handler: Send + Sync { + fn handle(&mut self, data: &mut &[u8]) -> HandlerResult; +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct RegisteredResponseHandle(usize); + +#[derive(Default)] +struct Handlers { + handlers: std::sync::Mutex>, + next_handler: AtomicUsize, +} + +type UniqueHandler = (usize, Box); + +impl std::fmt::Debug for Handlers { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let h = self + .handlers + .try_lock() + .map(|h| h.keys().cloned().collect::>()) + .unwrap_or_default(); + f.debug_struct("Handlers") + .field("handlers", &h) + .finish_non_exhaustive() + } +} + +impl Handlers { + fn register( + &self, + command: u16, + oid: Option, + handler: Box, + ) -> RegisteredResponseHandle { + let id = HandlerId(command, oid); + let uniq = self + .next_handler + .fetch_add(1, std::sync::atomic::Ordering::AcqRel); + self.handlers.lock().unwrap().insert(id, (uniq, handler)); + RegisteredResponseHandle(uniq) + } + + fn call(&self, command: u16, oid: Option, data: &mut &[u8]) -> bool { + let id = HandlerId(command, oid); + let mut handlers = self.handlers.lock().unwrap(); + let handler = handlers.get_mut(&id); + if let Some((_, handler)) = handler { + if matches!(handler.handle(data), HandlerResult::Deregister) { + handlers.remove(&id); + } + true + } else { + false + } + } + + fn remove_handler( + &self, + command: u16, + oid: Option, + uniq: Option, + ) { + let id = HandlerId(command, oid); + let mut handlers = self.handlers.lock().unwrap(); + if let Some(RegisteredResponseHandle(uniq)) = uniq { + match handlers.get(&id) { + None => return, + Some((u, _)) if *u != uniq => return, + _ => {} + } + } + handlers.remove(&id); + } +} + +#[derive(Debug)] +struct McuConnectionInner { + dictionary: OnceCell, + handlers: Handlers, + good_types: Mutex>, +} + +impl McuConnectionInner { + async fn receive_loop(&self, mut inbox: TransportReceiver) -> Result<(), McuConnectionError> { + while let Some(frame) = inbox.recv().await { + let frame = match frame { + Ok(frame) => frame, + Err(e) => return Err(McuConnectionError::Transport(e)), + }; + let frame = &mut frame.as_slice(); + self.parse_frame(frame)?; + } + Ok(()) + } + + fn parse_frame(&self, frame: &mut &[u8]) -> Result<(), McuConnectionError> { + while !frame.is_empty() { + let cmd = parse_vlq_int(frame)? as u16; + + // If a raw handler is registered for this, call it + if self.handlers.call(cmd, None, frame) { + continue; + } + + // There was no registered raw handler, check the dictionary to decode + if let Some(dict) = self.dictionary.get() { + if let Some(parser) = dict.message_parsers.get(&cmd) { + if let Some(msg) = parser.parse_and_output(frame) { + let msg = msg?; + debug!(msg = msg, "Output message from MCU"); + continue; + } + + let mut curmsg = &frame[..]; + let oid = parser.skip_with_oid(frame)?; + if tracing::enabled!(tracing::Level::TRACE) { + #[allow(clippy::redundant_slicing)] // The parse call changes the slice! + let mut tmp = &curmsg[..]; + trace!( + cmd_id = cmd, + cmd_name = parser.name, + data = ?parser.parse(&mut tmp).unwrap(), // Safe because we skipped before and it passed + "Received message", + ); + } + if !self.handlers.call(cmd, oid, &mut curmsg) { + debug!( + cmd_id = cmd, + cmd_name = parser.name, + // Safe because we skipped before and it passed. Since call returned + // false, we know no handler took the data and can safely take curmsg + // here. + data = ?parser.parse(&mut curmsg).unwrap(), + "Unhandled message", + ); + } + } else { + return Err(McuConnectionError::UnknownCommand(cmd)); + } + } else { + break; + } + } + Ok(()) + } +} + +/// Manages a connection to a Klipper MCU +/// +/// +#[derive(Debug)] +pub struct McuConnection { + inner: Arc, + transport: Transport, + _receiver: JoinHandle<()>, + exit_code: ExitCode, +} + +#[derive(Debug)] +enum ExitCode { + Waiting(oneshot::Receiver>), + Exited(Result<(), McuConnectionError>), +} + +impl McuConnection { + /// Connect to an MCU + /// + /// Attempts to connect to a MCU on the given data stream interface. The MCU is contacted and + /// the dictionary is loaded and applied. The returned [McuConnection] is fully ready to + /// communicate with the MCU. + pub async fn connect(stream: R) -> Result + where + R: AsyncRead + AsyncWrite + Send + 'static, + { + let (transport, inbox) = Transport::connect(stream).await; + + let (exit_tx, exit_rx) = oneshot::channel(); + let inner = Arc::new(McuConnectionInner { + dictionary: OnceCell::new(), + handlers: Handlers::default(), + good_types: Mutex::default(), + }); + + let receiver_inner = inner.clone(); + let receiver = tokio::spawn(async move { + let _ = exit_tx.send(receiver_inner.receive_loop(inbox).await); + }); + + let conn = McuConnection { + inner, + transport, + _receiver: receiver, + exit_code: ExitCode::Waiting(exit_rx), + }; + + let mut identify_data = Vec::new(); + let identify_start = Instant::now(); + loop { + let data = conn + .send_receive( + Identify::encode(identify_data.len() as u32, 40), + IdentifyResponse, + ) + .await; + let mut data = match data { + Ok(data) => data, + Err(e) => { + if identify_start.elapsed() > Duration::from_secs(10) { + return Err(McuConnectionError::DictionaryFetch(Box::new(e))); + } + continue; + } + }; + if data.offset as usize == identify_data.len() { + if data.data.is_empty() { + break; + } + identify_data.append(&mut data.data); + } + } + let mut decoder = flate2::read::ZlibDecoder::new(identify_data.as_slice()); + let mut buf = Vec::new(); + decoder + .read_to_end(&mut buf) + .map_err(|err| McuConnectionError::DictionaryFetch(Box::new(err)))?; + let raw_dict: RawDictionary = serde_json::from_slice(&buf) + .map_err(|err| McuConnectionError::DictionaryFetch(Box::new(err)))?; + let dict = + Dictionary::from_raw_dictionary(raw_dict).map_err(McuConnectionError::Dictionary)?; + debug!(dictionary = ?dict, "MCU dictionary"); + conn.inner + .dictionary + .set(dict) + .expect("Dictionary already set"); + + Ok(conn) + } + + fn verify_command_matches(&self) -> Result<(), McuConnectionError> { + // Special handling for identify/identify_response + if C::get_id(None) == Some(0) || C::get_id(None) == Some(1) { + return Ok(()); + } + + if self + .inner + .good_types + .lock() + .unwrap() + .contains(&TypeId::of::()) + { + return Ok(()); + } + + // We can only get here if we _have_ a dictionary as only Identify/IdentifyResponse are + // tested before then. + let dictionary = self.inner.dictionary.get().unwrap(); + let id = C::get_id(Some(dictionary)) + .ok_or_else(|| McuConnectionError::UnknownMessageId(C::get_name()))?; + + // Must exist because we know the tag + let parser = dictionary.message_parsers.get(&id).unwrap(); + let remote_fields = parser.fields.iter().map(|(s, t)| (s.as_str(), *t)); + let local_fields = C::fields().into_iter(); + + if !remote_fields.eq(local_fields) { + return Err(McuConnectionError::CommandMismatch( + C::get_name(), + format_command_args(parser.fields.iter().map(|(s, t)| (s.as_str(), *t))), + format_command_args(C::fields().into_iter()), + )); + } + + self.inner + .good_types + .lock() + .unwrap() + .insert(TypeId::of::()); + + Ok(()) + } + + fn encode_command( + &self, + command: EncodedMessage, + ) -> Result { + let id = command + .message_id(self.inner.dictionary.get()) + .ok_or_else(|| McuConnectionError::UnknownMessageId(command.message_name()))?; + let mut payload = command.payload; + if id >= 0x80 { + payload.content[0] = ((id >> 7) & 0x7F) as u8 | 0x80; + } else { + payload.offset = 1; + } + payload.content[1] = (id & 0x7F) as u8; + Ok(payload) + } + + /// Sends a command to the MCU + pub async fn send( + &self, + command: EncodedMessage, + ) -> Result<(), McuConnectionError> { + let cmd = self.encode_command(command)?; + self.transport + .send(cmd.as_slice()) + .map_err(|e| McuConnectionError::Transport(TransportError::Transmitter(e))) + } + + async fn send_receive_impl( + &self, + command: EncodedMessage, + reply: R, + oid: Option, + ) -> Result { + struct RespHandler( + Option>>, + ); + + impl Handler for RespHandler { + fn handle(&mut self, data: &mut &[u8]) -> HandlerResult { + if let Some(tx) = self.0.take() { + let _ = match R::decode(data) { + Ok(msg) => tx.send(Ok(msg.into())), + Err(e) => tx.send(Err(McuConnectionError::DecodingError(e))), + }; + } + HandlerResult::Deregister + } + } + + let cmd = self.encode_command(command)?; + + self.verify_command_matches::()?; + self.verify_command_matches::()?; + + let (tx, mut rx) = tokio::sync::oneshot::channel::>(); + self.register_raw_response(reply, oid, Box::new(RespHandler::(Some(tx))))?; + + let mut retry_delay = 0.01; + for _retry in 0..=5 { + self.transport + .send(cmd.as_slice()) + .map_err(|e| McuConnectionError::Transport(TransportError::Transmitter(e)))?; + + let sleep = tokio::time::sleep(Duration::from_secs_f32(retry_delay)); + tokio::pin!(sleep); + + select! { + reply = &mut rx => { + return match reply { + Ok(Err(e)) => Err(SendReceiveError::McuConnection(e)), + Ok(Ok(v)) => Ok(v), + Err(_) => Err(SendReceiveError::Timeout) + }; + }, + _ = &mut sleep => {}, + } + + retry_delay *= 2.0; + } + + Err(SendReceiveError::Timeout) + } + + /// Sends a message to the MCU, and await a reply + /// + /// This sends a message to the MCU and awaits a reply matching the given `reply`. + /// This version works with replies that have an `oid` field. + pub async fn send_receive_oid( + &self, + command: EncodedMessage, + reply: R, + oid: u8, + ) -> Result { + self.send_receive_impl(command, reply, Some(oid)).await + } + + /// Sends a message to the MCU, and await a reply + /// + /// This sends a message to the MCU and awaits a reply matching the given `reply`. + /// This version works with replies that do not have an `oid` field. + pub async fn send_receive( + &self, + command: EncodedMessage, + reply: R, + ) -> Result { + self.send_receive_impl(command, reply, None).await + } + + fn register_raw_response( + &self, + _reply: R, + oid: Option, + handler: Box, + ) -> Result { + let id = R::get_id(self.inner.dictionary.get()) + .ok_or_else(|| McuConnectionError::UnknownMessageId(R::get_name()))?; + Ok(self.inner.handlers.register(id, oid, handler)) + } + + fn register_response_impl( + &self, + reply: R, + oid: Option, + ) -> Result, McuConnectionError> { + struct RespHandler(UnboundedSender, Option>); + + impl Drop for RespHandler { + fn drop(&mut self) { + self.1.take(); + } + } + + impl Handler for RespHandler { + fn handle(&mut self, data: &mut &[u8]) -> HandlerResult { + let msg = R::decode(data) + .expect("Parser should already have assured this could parse") + .into(); + match self.0.send(msg) { + Ok(_) => HandlerResult::Continue, + Err(_) => HandlerResult::Deregister, + } + } + } + + self.verify_command_matches::()?; + + let (tx, rx) = unbounded_channel(); + let tx_closed = tx.clone(); + let (closer_tx, closer_rx) = oneshot::channel(); + let uniq = self.register_raw_response( + reply, + oid, + Box::new(RespHandler::(tx, Some(closer_tx))), + )?; + + // Safe because register_raw_response already verified this + let id = R::get_id(self.inner.dictionary.get()).unwrap(); + + let inner = self.inner.clone(); + spawn(async move { + select! { + _ = tx_closed.closed() => {}, + _ = closer_rx => {}, + } + inner.handlers.remove_handler(id, oid, Some(uniq)); + }); + + Ok(rx) + } + + /// Registers a subscriber for a reply message + /// + /// Received replies matching the type will be delivered to the returned channel. + /// To end the subscription, simply drop the returned receiver. + /// This version works with replies that have an `oid` field. + pub fn register_response_oid( + &self, + reply: R, + oid: u8, + ) -> Result, McuConnectionError> { + self.register_response_impl(reply, Some(oid)) + } + + /// Registers a subscriber for a reply message + /// + /// Received replies matching the type will be delivered to the returned channel. + /// To end the subscription, simply drop the returned receiver. + /// This version works with replies that do not have an `oid` field. + pub fn register_response( + &self, + reply: R, + ) -> Result, McuConnectionError> { + self.register_response_impl(reply, None) + } + + /// Returns a reference the dictionary the MCU returned during initial handshake + pub fn dictionary(&self) -> &Dictionary { + self.inner.dictionary.get().unwrap() + } + + /// Closes the transport and ends all subscriptions. Returns when the transport is closed. + pub async fn close(self) { + self.transport.close().await; + } + + /// Waits for the connection to close, returning the error that closed it if any. + pub async fn closed(&mut self) -> &Result<(), McuConnectionError> { + if let ExitCode::Waiting(chan) = &mut self.exit_code { + self.exit_code = ExitCode::Exited(match chan.await { + Ok(r) => r, + Err(_) => Ok(()), + }); + } + match self.exit_code { + ExitCode::Exited(ref val) => val, + _ => unreachable!(), + } + } +} + +#[derive(Debug)] +enum HandlerResult { + Continue, + Deregister, +} diff --git a/crates/windlass/src/messages.rs b/crates/windlass/src/messages.rs new file mode 100644 index 0000000..5b97d31 --- /dev/null +++ b/crates/windlass/src/messages.rs @@ -0,0 +1,237 @@ +use std::collections::BTreeMap; + +use crate::dictionary::Dictionary; +use crate::encoding::{FieldType, FieldValue, MessageDecodeError}; + +#[derive(thiserror::Error, Debug)] +pub enum MessageSkipperError { + #[error("invalid argument format: {0}")] + InvalidArgumentFormat(String), + #[error("unknown type '{1}' for argument '{0}'")] + UnknownType(String, String), + #[error("invalid format field type '%{0}'")] + InvalidFormatFieldType(String), +} + +pub(crate) fn format_command_args<'a>( + fields: impl Iterator, +) -> String { + let mut buf = String::new(); + for (idx, (name, ty)) in fields.enumerate() { + if idx != 0 { + buf.push(' '); + } + buf.push_str(name); + buf.push('='); + buf.push_str(match ty { + FieldType::U32 => "%u", + FieldType::I32 => "%i", + FieldType::U16 => "%hu", + FieldType::I16 => "%hi", + FieldType::U8 => "%c", + FieldType::String => "%s", + FieldType::ByteArray => "%*s", + }); + } + buf +} + +pub struct MessageParser { + pub name: String, + pub fields: Vec<(String, FieldType)>, + pub output: Option, +} + +impl MessageParser { + pub(crate) fn new<'a>( + name: &str, + parts: impl Iterator, + ) -> Result { + let mut fields = vec![]; + for part in parts { + let (arg, ty) = part + .split_once('=') + .ok_or_else(|| MessageSkipperError::InvalidArgumentFormat(part.into()))?; + + let field_type = FieldType::from_msg(ty)?; + fields.push((arg.to_string(), field_type)); + } + Ok(Self { + name: name.to_string(), + fields, + output: None, + }) + } + + pub(crate) fn new_output(msg: &str) -> Result { + let mut fields = vec![]; + let mut parts = vec![]; + + let mut work = msg; + while let Some(pos) = work.find('%') { + let (pre, rest) = work.split_at(pos); + if !pre.is_empty() { + parts.push(FormatBlock::Static(pre.to_string())); + } + if let Some(rest) = rest.strip_prefix("%%") { + parts.push(FormatBlock::Static("%".to_string())); + work = rest; + break; + } + let (format, rest) = FieldType::from_format(rest)?; + parts.push(FormatBlock::Field); + fields.push((format!("field_{}", fields.len()), format)); + work = rest; + } + if !work.is_empty() { + parts.push(FormatBlock::Static(work.to_string())); + } + + Ok(Self { + name: msg.to_string(), + fields, + output: Some(OutputFormat { parts }), + }) + } + + #[allow(dead_code)] + pub(crate) fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> { + for (_, field) in &self.fields { + field.skip(input)?; + } + Ok(()) + } + + pub(crate) fn skip_with_oid( + &self, + input: &mut &[u8], + ) -> Result, MessageDecodeError> { + let mut oid = None; + for (name, field) in &self.fields { + if name == "oid" { + if let FieldValue::U8(read_oid) = field.read(input)? { + oid = Some(read_oid); + } + } else { + field.skip(input)?; + } + } + Ok(oid) + } + + pub(crate) fn parse( + &self, + input: &mut &[u8], + ) -> Result, MessageDecodeError> { + let mut output = BTreeMap::new(); + for (name, field) in &self.fields { + output.insert(name.to_string(), field.read(input)?); + } + Ok(output) + } + + pub(crate) fn parse_and_output( + &self, + input: &mut &[u8], + ) -> Option> { + self.output.as_ref().map(|output| { + let fields = self.parse(input)?; + Ok(output.format(fields.values())) + }) + } +} + +impl std::fmt::Debug for MessageParser { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_map() + .entry(&"name", &self.name) + .entry(&"fields", &self.fields) + .finish() + } +} + +pub trait Message: 'static { + type Pod<'a>: Into + std::fmt::Debug; + type PodOwned: Clone + Send + std::fmt::Debug + 'static; + fn get_id(dict: Option<&Dictionary>) -> Option; + fn get_name() -> &'static str; + fn decode<'a>(input: &mut &'a [u8]) -> Result, MessageDecodeError>; + fn fields() -> Vec<(&'static str, FieldType)>; +} + +pub trait WithOid: 'static {} +pub trait WithoutOid: 'static {} + +/// Represents an encoded message, with a type-level link to the message kind +pub struct EncodedMessage { + pub payload: FrontTrimmableBuffer, + pub _message_kind: std::marker::PhantomData, +} + +impl std::fmt::Debug for EncodedMessage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EncodedMessage") + .field("kind", &R::get_name()) + .field("payload", &self.payload) + .finish() + } +} + +impl EncodedMessage { + #[doc(hidden)] + pub fn message_id(&self, dict: Option<&Dictionary>) -> Option { + M::get_id(dict) + } + + #[doc(hidden)] + pub fn message_name(&self) -> &'static str { + M::get_name() + } +} + +/// Wraps a `Vec` allowing removal of front bytes in a zero-copy way +pub struct FrontTrimmableBuffer { + pub content: Vec, + pub offset: usize, +} + +impl FrontTrimmableBuffer { + pub fn as_slice(&self) -> &[u8] { + &self.content[self.offset..] + } +} + +impl std::fmt::Debug for FrontTrimmableBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Debug::fmt(self.as_slice(), f) + } +} + +/// Holds the format of a `output()` style debug message +#[derive(Debug)] +pub struct OutputFormat { + parts: Vec, +} + +impl OutputFormat { + fn format<'a>(&self, mut fields: impl Iterator) -> String { + let mut buf = String::new(); + for part in &self.parts { + match part { + FormatBlock::Static(s) => buf.push_str(s), + FormatBlock::Field => { + if let Some(v) = fields.next() { + std::fmt::write(&mut buf, format_args!("{v}")).ok(); + } + } + } + } + buf + } +} + +#[derive(Debug)] +enum FormatBlock { + Static(String), + Field, +} diff --git a/crates/windlass/src/transport.rs b/crates/windlass/src/transport.rs new file mode 100644 index 0000000..e47cc8e --- /dev/null +++ b/crates/windlass/src/transport.rs @@ -0,0 +1,615 @@ +use std::{collections::VecDeque, sync::Arc}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}, + pin, select, + sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + task::{spawn, JoinHandle}, + time::{sleep_until, Instant}, +}; +use tokio_util::sync::CancellationToken; +use tracing::trace; + +const MESSAGE_HEADER_SIZE: usize = 2; +const MESSAGE_TRAILER_SIZE: usize = 3; +pub const MESSAGE_LENGTH_MIN: usize = MESSAGE_HEADER_SIZE + MESSAGE_TRAILER_SIZE; +pub const MESSAGE_LENGTH_MAX: usize = 64; +pub const MESSAGE_LENGTH_PAYLOAD_MAX: usize = MESSAGE_LENGTH_MAX - MESSAGE_LENGTH_MIN; +const MESSAGE_POSITION_SEQ: usize = 1; +const MESSAGE_TRAILER_CRC: usize = 3; +const MESSAGE_VALUE_SYNC: u8 = 0x7E; +const MESSAGE_DEST: u8 = 0x10; +const MESSAGE_SEQ_MASK: u8 = 0x0F; + +#[derive(thiserror::Error, Debug)] +pub enum ReceiverError { + #[error("io error: {0}")] + IoError(#[from] std::io::Error), +} + +#[derive(thiserror::Error, Debug)] +pub enum TransmitterError { + #[error("io error: {0}")] + IoError(#[from] std::io::Error), + #[error("connection closed")] + ConnectionClosed, +} + +#[derive(Debug)] +pub(crate) struct Transport { + _task_rdr: JoinHandle<()>, + _task_wr: JoinHandle<()>, + task_inner: JoinHandle<()>, + outbox_tx: UnboundedSender, +} + +#[derive(Debug)] +enum TransportCommand { + SendMessage(Vec), + Exit, +} + +pub(crate) type TransportReceiver = UnboundedReceiver, TransportError>>; + +impl Transport { + pub(crate) async fn connect(stream: R) -> (Transport, TransportReceiver) + where + R: AsyncRead + AsyncWrite + Send + 'static, + { + let (rdr, wr) = tokio::io::split(stream); + + let (raw_send_tx, raw_send_rx) = unbounded_channel(); + let (raw_recv_tx, raw_recv_rx) = unbounded_channel(); + let (app_inbox_tx, app_inbox_rx) = unbounded_channel(); + let (app_outbox_tx, app_outbox_rx) = unbounded_channel(); + + let cancel_token = CancellationToken::new(); + + let cancel = cancel_token.clone(); + let ait = app_inbox_tx.clone(); + let task_rdr = spawn(async move { + if let Err(e) = LowlevelReader::run(raw_recv_tx, rdr, cancel).await { + let _ = ait.send(Err(TransportError::Receiver(e))); + } + }); + + let cancel = cancel_token.clone(); + let ait = app_inbox_tx.clone(); + let task_wr = spawn(async move { + if let Err(e) = LowlevelWriter::run(raw_send_rx, wr, cancel).await { + let _ = ait.send(Err(TransportError::Transmitter(e))); + } + }); + + let task_inner = spawn(async move { + let mut ts = TransportState::new( + raw_recv_rx, + raw_send_tx, + app_inbox_tx, + app_outbox_rx, + cancel_token, + ); + if let Err(e) = ts.protocol_handler().await { + let _ = ts.app_inbox.send(Err(e)); + } + }); + + ( + Transport { + _task_rdr: task_rdr, + _task_wr: task_wr, + task_inner, + outbox_tx: app_outbox_tx, + }, + app_inbox_rx, + ) + } + + pub(crate) fn send(&self, msg: &[u8]) -> Result<(), TransmitterError> { + self.outbox_tx + .send(TransportCommand::SendMessage(msg.into())) + .map_err(|_| TransmitterError::ConnectionClosed) + } + + pub(crate) async fn close(self) { + let _ = self.outbox_tx.send(TransportCommand::Exit); + let _ = self.task_inner.await; + } +} + +struct LowlevelReader { + rdr: BufReader, + synced: bool, +} + +impl LowlevelReader { + async fn read_frame(&mut self) -> Result, ReceiverError> { + let mut buf = [0u8; MESSAGE_LENGTH_MAX]; + let next_byte = self.rdr.read_u8().await?; + if next_byte == MESSAGE_VALUE_SYNC { + if !self.synced { + self.synced = true; + } + return Ok(None); + } + if !self.synced { + return Ok(None); + } + + let receive_time = Instant::now(); + let len = next_byte as usize; + + if !(MESSAGE_LENGTH_MIN..=MESSAGE_LENGTH_MAX).contains(&len) { + self.synced = false; + return Ok(None); + } + + self.rdr.read_exact(&mut buf[1..len]).await?; + buf[0] = len as u8; + let buf = &buf[..len]; + let seq = buf[MESSAGE_POSITION_SEQ]; + trace!(frame = ?buf, seq = seq, "Received frame"); + + if seq & !MESSAGE_SEQ_MASK != MESSAGE_DEST { + self.synced = false; + return Ok(None); + } + + let actual_crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]); + let frame_crc = (buf[len - MESSAGE_TRAILER_CRC] as u16) << 8 + | (buf[len - MESSAGE_TRAILER_CRC + 1] as u16); + if frame_crc != actual_crc { + self.synced = false; + return Ok(None); + } + + Ok(Some(Frame { + receive_time, + sequence: seq & MESSAGE_SEQ_MASK, + payload: buf[MESSAGE_HEADER_SIZE..len - MESSAGE_TRAILER_SIZE].into(), + })) + } + + async fn run( + outbox: UnboundedSender, + rdr: R, + cancel_token: CancellationToken, + ) -> Result<(), ReceiverError> + where + R: AsyncRead + Unpin, + { + let mut state = Self { + rdr: BufReader::new(rdr), + synced: false, + }; + + loop { + match state.read_frame().await { + Ok(None) => {} + Ok(Some(frame)) => { + if outbox.send(frame).is_err() { + break Ok(()); + } + } + Err(_) if cancel_token.is_cancelled() => break Ok(()), + Err(e) => break Err(e), + } + } + } +} + +fn crc16(buf: &[u8]) -> u16 { + let mut crc = 0xFFFFu16; + for b in buf { + let b = *b ^ ((crc & 0xFF) as u8); + let b = b ^ (b << 4); + let b16 = b as u16; + crc = (b16 << 8 | crc >> 8) ^ (b16 >> 4) ^ (b16 << 3); + } + crc +} + +#[derive(Debug)] +struct Frame { + receive_time: Instant, + sequence: u8, + payload: Vec, +} + +#[derive(Debug, Clone)] +struct InflightFrame { + sent_at: Instant, + #[allow(dead_code)] + sequence: u64, + payload: Arc>, + is_retransmit: bool, +} + +#[derive(thiserror::Error, Debug)] +pub enum MessageEncodeError { + #[error("message would exceed the maximum packet length of {MESSAGE_LENGTH_MAX} bytes")] + MessageTooLong, +} + +fn encode_frame(sequence: u64, payload: &[u8]) -> Result, MessageEncodeError> { + let len = MESSAGE_LENGTH_MIN + payload.len(); + if len > MESSAGE_LENGTH_MAX { + return Err(MessageEncodeError::MessageTooLong); + } + let mut buf = Vec::with_capacity(len); + buf.push(len as u8); + buf.push(MESSAGE_DEST | ((sequence as u8) & MESSAGE_SEQ_MASK)); + buf.extend_from_slice(payload); + let crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]); + buf.push(((crc >> 8) & 0xFF) as u8); + buf.push((crc & 0xFF) as u8); + buf.push(MESSAGE_VALUE_SYNC); + Ok(buf) +} + +struct LowlevelWriter {} + +impl LowlevelWriter { + async fn run( + mut inbox: UnboundedReceiver>>, + mut wr: W, + cancel_token: CancellationToken, + ) -> Result<(), TransmitterError> + where + W: AsyncWrite + Unpin, + { + loop { + select! { + msg = inbox.recv() => { + let msg = match msg { + Some(msg) => msg, + None => break, + }; + trace!(payload = ?msg, seq = msg[MESSAGE_POSITION_SEQ], "Sent frame"); + wr.write_all(&msg).await?; + wr.flush().await?; + } + + _ = cancel_token.cancelled() => { + break; + } + } + } + + Ok(()) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum TransportError { + #[error("message encoding failed: {0}")] + MessageEncode(#[from] MessageEncodeError), + #[error("receiver error: {0}")] + Receiver(#[from] ReceiverError), + #[error("transmitter error: {0}")] + Transmitter(#[from] TransmitterError), +} + +const MIN_RTO: f32 = 0.025; +const MAX_RTO: f32 = 5.000; + +#[derive(Debug)] +struct RttState { + srtt: f32, + rttvar: f32, + rto: f32, +} + +impl RttState { + fn new() -> Self { + Self { + srtt: 0.0, + rttvar: 0.0, + rto: MIN_RTO, + } + } + + fn rto(&self) -> std::time::Duration { + std::time::Duration::from_secs_f32(self.rto) + } + + fn update(&mut self, rtt: std::time::Duration) { + let r = rtt.as_secs_f32(); + if self.srtt == 0.0 { + self.rttvar = r / 2.0; + self.srtt = r * 10.0; // Klipper uses this, we'll copy it + } else { + self.rttvar = (3.0 * self.rttvar + (self.srtt - r).abs()) / 4.0; + self.srtt = (7.0 * self.srtt + r) / 8.0; + } + let rttvar4 = (self.rttvar * 4.0).max(0.001); + self.rto = (self.srtt + rttvar4).clamp(MIN_RTO, MAX_RTO); + } +} + +#[derive(Debug)] +struct TransportState { + link_inbox: UnboundedReceiver, + link_outbox: UnboundedSender>>, + app_inbox: UnboundedSender, TransportError>>, + app_outbox: UnboundedReceiver, + cancel: CancellationToken, + + is_synchronized: bool, + rtt_state: RttState, + receive_sequence: u64, + send_sequence: u64, + last_ack_sequence: u64, + ignore_nak_seq: u64, + retransmit_seq: u64, + retransmit_now: bool, + + corked_until: Option, + + inflight_messages: VecDeque, + ready_messages: MessageQueue>, +} + +impl TransportState { + fn new( + link_inbox: UnboundedReceiver, + link_outbox: UnboundedSender>>, + app_inbox: UnboundedSender, TransportError>>, + app_outbox: UnboundedReceiver, + cancel: CancellationToken, + ) -> Self { + Self { + link_inbox, + link_outbox, + app_inbox, + app_outbox, + cancel, + + is_synchronized: false, + rtt_state: RttState::new(), + receive_sequence: 1, + send_sequence: 1, + last_ack_sequence: 0, + ignore_nak_seq: 0, + retransmit_seq: 0, + retransmit_now: false, + + corked_until: None, + + inflight_messages: VecDeque::new(), + ready_messages: MessageQueue::new(), + } + } + + fn update_receive_seq(&mut self, receive_time: Instant, sequence: u64) { + let mut sent_seq = self.receive_sequence; + + loop { + if let Some(msg) = self.inflight_messages.pop_front() { + sent_seq += 1; + if sequence == sent_seq { + // Found the matching sent message + if !msg.is_retransmit { + let elapsed = receive_time.saturating_duration_since(msg.sent_at); + self.rtt_state.update(elapsed); + } + break; + } + } else { + // Ack with no outstanding messages, happens during connection init + self.send_sequence = sequence; + break; + } + } + + self.receive_sequence = sequence; + self.is_synchronized = true; + } + + fn handle_frame(&mut self, frame: Frame) { + let rseq = self.receive_sequence; + let mut sequence = (rseq & !(MESSAGE_SEQ_MASK as u64)) | (frame.sequence as u64); + if sequence < rseq { + sequence += (MESSAGE_SEQ_MASK as u64) + 1; + } + if !self.is_synchronized || sequence != rseq { + if sequence > self.send_sequence && self.is_synchronized { + // Ack for unsent message + return; + } + self.update_receive_seq(frame.receive_time, sequence); + } + + if frame.payload.is_empty() { + if self.last_ack_sequence < sequence { + self.last_ack_sequence = sequence; + } else if sequence > self.ignore_nak_seq && !self.inflight_messages.is_empty() { + // Trigger retransmit from NAK + self.retransmit_now = true; + } + } else { + // Data message, we deliver this directly to the application as the MCU can't actually + // retransmit anyway. + let _ = self.app_inbox.send(Ok(frame.payload)); + } + } + + fn can_send(&self) -> bool { + self.corked_until.is_none() && self.inflight_messages.len() < 12 + } + + fn send_new_frame(&mut self, mut initial: Vec) -> Result<(), TransportError> { + while let Some(next) = self.ready_messages.try_peek() { + if initial.len() + next.len() <= MESSAGE_LENGTH_PAYLOAD_MAX { + // Add to the end of the message. Unwrap is safe because we already peeked. + let mut next = self.ready_messages.try_recv().unwrap(); + initial.append(&mut next); + } else { + break; + } + } + let frame = Arc::new(encode_frame(self.send_sequence, &initial)?); + self.send_sequence += 1; + self.inflight_messages.push_back(InflightFrame { + sent_at: Instant::now(), + sequence: self.send_sequence, + payload: frame.clone(), + is_retransmit: false, + }); + let _ = self.link_outbox.send(frame); + Ok(()) + } + + fn send_more_frames(&mut self) -> Result<(), TransportError> { + while self.can_send() && self.ready_messages.try_peek().is_some() { + let msg = self.ready_messages.try_recv().unwrap(); + self.send_new_frame(msg)?; + } + Ok(()) + } + + fn retransmit_pending(&mut self) { + let len: usize = self + .inflight_messages + .iter() + .map(|msg| msg.payload.len()) + .sum(); + let mut buf = Vec::with_capacity(1 + len); + buf.push(MESSAGE_VALUE_SYNC); + let now = Instant::now(); + for msg in self.inflight_messages.iter_mut() { + buf.extend_from_slice(&msg.payload); + msg.is_retransmit = true; + msg.sent_at = now; + } + let _ = self.link_outbox.send(Arc::new(buf)); + + if self.retransmit_now { + self.ignore_nak_seq = self.receive_sequence; + if self.receive_sequence < self.retransmit_seq { + self.ignore_nak_seq = self.retransmit_seq; + } + self.retransmit_now = false; + } else { + self.rtt_state.rto = (self.rtt_state.rto * 2.0).clamp(MIN_RTO, MAX_RTO); + self.ignore_nak_seq = self.send_sequence; + } + self.retransmit_seq = self.send_sequence; + } + + async fn protocol_handler(&mut self) -> Result<(), TransportError> { + loop { + if self.retransmit_now { + self.retransmit_pending(); + } + + let retransmit_deadline = self + .inflight_messages + .front() + .map(|msg| msg.sent_at + self.rtt_state.rto()); + let retransmit_timeout: futures::future::OptionFuture<_> = + retransmit_deadline.map(sleep_until).into(); + pin!(retransmit_timeout); + + let corked_timeout: futures::future::OptionFuture<_> = + self.corked_until.map(sleep_until).into(); + pin!(corked_timeout); + + select! { + frame = self.link_inbox.recv() => { + let frame = match frame { + Some(frame) => frame, + None => break, + }; + self.handle_frame(frame); + }, + + msg = self.app_outbox.recv() => { + match msg { + Some(TransportCommand::SendMessage(msg)) => { + self.ready_messages.send(msg); + }, + Some(TransportCommand::Exit) => { + self.cancel.cancel(); + } + None => break, + }; + }, + + _ = self.ready_messages.recv_peek(), if self.can_send() => { + self.send_more_frames()?; + }, + + _ = &mut retransmit_timeout, if retransmit_deadline.is_some() => { + self.retransmit_now = true; + }, + + _ = &mut corked_timeout, if self.corked_until.is_some() => { + // Timeout for when we are able to send again + } + + _ = self.cancel.cancelled() => { + break; + }, + } + } + Ok(()) + } +} + +#[derive(Debug)] +struct MessageQueue { + sender: UnboundedSender, + receiver: UnboundedReceiver, + peeked: Option, +} + +impl MessageQueue { + fn new() -> Self { + let (sender, receiver) = unbounded_channel(); + Self { + sender, + receiver, + peeked: None, + } + } + + async fn recv_peek(&mut self) -> Option<&T> { + if self.peeked.is_some() { + self.peeked.as_ref() + } else { + match self.receiver.recv().await { + Some(msg) => { + self.peeked = Some(msg); + self.peeked.as_ref() + } + None => None, + } + } + } + + fn try_recv(&mut self) -> Option { + if let Some(msg) = self.peeked.take() { + Some(msg) + } else { + self.receiver.try_recv().ok() + } + } + + fn try_peek(&mut self) -> Option<&T> { + if self.peeked.is_some() { + self.peeked.as_ref() + } else { + match self.receiver.try_recv() { + Ok(msg) => { + self.peeked = Some(msg); + self.peeked.as_ref() + } + Err(_) => None, + } + } + } + + fn send(&mut self, msg: T) { + let _ = self.sender.send(msg); + } +} -- cgit v1.2.3