summaryrefslogtreecommitdiff
path: root/crates
diff options
context:
space:
mode:
authortcmal <me@aria.rip>2024-09-12 16:33:19 +0100
committertcmal <me@aria.rip>2024-09-17 16:12:30 +0100
commit5f9fbea5a1b08962887d16457e77a922919a8818 (patch)
treee326db0d8b56f41d4f62ffaf909311363c59cae4 /crates
parent9a7042df7f9a7ad0324c4442ae918c2a4a442966 (diff)
Vendor windlass (control-side klipper protocol implementation)
Diffstat (limited to 'crates')
-rw-r--r--crates/windlass/.gitignore2
-rw-r--r--crates/windlass/Cargo.toml15
-rw-r--r--crates/windlass/LICENSE.txt21
-rw-r--r--crates/windlass/README.md18
-rw-r--r--crates/windlass/examples/usb_test.rs60
-rw-r--r--crates/windlass/src/dictionary.rs131
-rw-r--r--crates/windlass/src/encoding.rs327
-rw-r--r--crates/windlass/src/lib.rs21
-rw-r--r--crates/windlass/src/macros.rs200
-rw-r--r--crates/windlass/src/mcu.rs585
-rw-r--r--crates/windlass/src/messages.rs237
-rw-r--r--crates/windlass/src/transport.rs615
12 files changed, 2232 insertions, 0 deletions
diff --git a/crates/windlass/.gitignore b/crates/windlass/.gitignore
new file mode 100644
index 0000000..4fffb2f
--- /dev/null
+++ b/crates/windlass/.gitignore
@@ -0,0 +1,2 @@
+/target
+/Cargo.lock
diff --git a/crates/windlass/Cargo.toml b/crates/windlass/Cargo.toml
new file mode 100644
index 0000000..2c5d436
--- /dev/null
+++ b/crates/windlass/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "windlass"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+futures = "0.3"
+thiserror = "1"
+paste = "1"
+flate2 = "1"
+tokio = { workspace = true, features = ["sync", "io-util", "time", "macros", "rt"] }
+tokio-util = { workspace = true }
+serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
+tracing = { workspace = true }
diff --git a/crates/windlass/LICENSE.txt b/crates/windlass/LICENSE.txt
new file mode 100644
index 0000000..2b6864b
--- /dev/null
+++ b/crates/windlass/LICENSE.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Lasse Dalegaard
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/crates/windlass/README.md b/crates/windlass/README.md
new file mode 100644
index 0000000..5d35c7d
--- /dev/null
+++ b/crates/windlass/README.md
@@ -0,0 +1,18 @@
+# Windlass
+
+This code was originally sourced from [https://github.com/Annex-Engineering/windlass/tree/master](annex-engineering).
+
+Windlass is an implementation of the host side of the Klipper protocol.
+
+## Licensing
+
+Windlass is licensed under the MIT license, which means you can do pretty much
+whatever you want with it. Please see [LICENSE.txt](LICENSE.txt) for more
+information.
+
+## Acknowledgements
+
+This project is in no way endorsed by the Klipper project. Please do not direct
+any support requests to the Klipper project.
+
+ * [Klipper](https://www.klipper3d.org/) by [Kevin O'Connor](https://www.patreon.com/koconnor)
diff --git a/crates/windlass/examples/usb_test.rs b/crates/windlass/examples/usb_test.rs
new file mode 100644
index 0000000..6aee77b
--- /dev/null
+++ b/crates/windlass/examples/usb_test.rs
@@ -0,0 +1,60 @@
+use tokio::time::Duration;
+use windlass::{mcu_command, mcu_reply, McuConnection};
+
+mcu_command!(DebugNop, "debug_nop");
+mcu_command!(GetClock, "get_clock");
+mcu_reply!(Clock, "clock", clock: u32);
+mcu_command!(GetUptime, "get_uptime");
+mcu_reply!(Uptime, "uptime", high: u32, clock: u32);
+mcu_reply!(
+ AdxlStatus,
+ "adxl345_status",
+ oid: u8,
+ sequence: u16,
+ data: Vec<u8>,
+);
+mcu_reply!(Stats, "stats", count: u32, sum: u32, sumsq: u32);
+
+mcu_command!(
+ ConfigEndstop,
+ "config_endstop",
+ oid: u8,
+ pin: u8,
+ pull_up: u8
+);
+
+#[tokio::main]
+async fn main() {
+ tracing_subscriber::fmt::init();
+
+ let target = std::env::args().nth(1).expect("Missing path");
+ let builder = tokio_serial::new(target, 96_000);
+ let port = tokio_serial::SerialStream::open(&builder).expect("USB open");
+
+ let mut mcu = McuConnection::connect(port).await.expect("MCU connect");
+
+ let mut resp = mcu.register_response(Stats).expect("Register");
+
+ let clock = mcu.send_receive(GetClock::encode(), Clock);
+ let uptime = mcu.send_receive(GetUptime::encode(), Uptime).await;
+ println!("Uptime: {uptime:?}");
+ let clock = clock.await;
+ println!("Clock: {clock:?}");
+
+ let timeout = tokio::time::sleep(Duration::from_secs(10));
+ tokio::pin!(timeout);
+
+ loop {
+ tokio::select! {
+ _ = &mut timeout => break,
+ s = resp.recv() => {
+ println!("Stats: {s:?}");
+ }
+ err = mcu.closed() => {
+ println!("MCU connection lost: {err:?}");
+ break;
+ }
+ }
+ }
+ mcu.close().await;
+}
diff --git a/crates/windlass/src/dictionary.rs b/crates/windlass/src/dictionary.rs
new file mode 100644
index 0000000..5665f80
--- /dev/null
+++ b/crates/windlass/src/dictionary.rs
@@ -0,0 +1,131 @@
+use std::collections::BTreeMap;
+
+use crate::{
+ encoding::encode_vlq_int,
+ messages::{MessageParser, MessageSkipperError},
+};
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+#[serde(untagged)]
+pub enum ConfigVar {
+ String(String),
+ Number(f64),
+}
+
+#[derive(Debug, serde::Deserialize)]
+#[serde(untagged)]
+pub enum Enumeration {
+ Single(i64),
+ Range(i64, i64),
+}
+
+#[derive(Debug, serde::Deserialize)]
+pub(crate) struct RawDictionary {
+ #[serde(default)]
+ config: BTreeMap<String, ConfigVar>,
+
+ #[serde(default)]
+ enumerations: BTreeMap<String, BTreeMap<String, Enumeration>>,
+
+ #[serde(default)]
+ commands: BTreeMap<String, i16>,
+ #[serde(default)]
+ responses: BTreeMap<String, i16>,
+ #[serde(default)]
+ output: BTreeMap<String, i16>,
+
+ #[serde(default)]
+ build_versions: Option<String>,
+ #[serde(default)]
+ version: Option<String>,
+
+ #[serde(flatten)]
+ extra: BTreeMap<String, serde_json::Value>,
+}
+
+/// Dictionary error
+#[derive(thiserror::Error, Debug)]
+pub enum DictionaryError {
+ /// Found an empty command
+ #[error("empty command found")]
+ EmptyCommand,
+ /// Received a command in an invalid format
+ #[error("invalid command format: {0}")]
+ InvalidCommandFormat(String, MessageSkipperError),
+ /// Received an output string with an invalid format
+ #[error("invalid output format: {0}")]
+ InvalidOutputFormat(String, MessageSkipperError),
+ /// Received a command with an invalid tag
+ #[error("command tag {0} output valid range of -32..95")]
+ InvalidCommandTag(u16),
+}
+
+#[derive(Debug)]
+pub struct Dictionary {
+ pub message_ids: BTreeMap<String, u16>,
+ pub message_parsers: BTreeMap<u16, MessageParser>,
+ pub config: BTreeMap<String, ConfigVar>,
+ pub enumerations: BTreeMap<String, BTreeMap<String, Enumeration>>,
+ pub build_versions: Option<String>,
+ pub version: Option<String>,
+ pub extra: BTreeMap<String, serde_json::Value>,
+}
+
+impl Dictionary {
+ pub(crate) fn from_raw_dictionary(raw: RawDictionary) -> Result<Self, DictionaryError> {
+ let mut message_ids = BTreeMap::new();
+ let mut message_parsers = BTreeMap::new();
+
+ for (cmd, tag) in raw.commands {
+ let mut split = cmd.split(' ');
+ let name = split.next().ok_or(DictionaryError::EmptyCommand)?;
+ let parser = MessageParser::new(name, split)
+ .map_err(|e| DictionaryError::InvalidCommandFormat(name.to_string(), e))?;
+ let tag = Self::map_tag(tag)?;
+ message_parsers.insert(tag, parser);
+ message_ids.insert(name.to_string(), tag);
+ }
+
+ for (resp, tag) in raw.responses {
+ let mut split = resp.split(' ');
+ let name = split.next().ok_or(DictionaryError::EmptyCommand)?;
+ let parser = MessageParser::new(name, split)
+ .map_err(|e| DictionaryError::InvalidCommandFormat(name.to_string(), e))?;
+ let tag = Self::map_tag(tag)?;
+ message_parsers.insert(tag, parser);
+ message_ids.insert(name.to_string(), tag);
+ }
+
+ for (msg, tag) in raw.output {
+ let parser = MessageParser::new_output(&msg)
+ .map_err(|e| DictionaryError::InvalidCommandFormat(msg.to_string(), e))?;
+ let tag = Self::map_tag(tag)?;
+ message_parsers.insert(tag, parser);
+ }
+
+ Ok(Dictionary {
+ message_ids,
+ message_parsers,
+ config: raw.config,
+ enumerations: raw.enumerations,
+ build_versions: raw.build_versions,
+ version: raw.version,
+ extra: raw.extra,
+ })
+ }
+
+ fn map_tag(tag: i16) -> Result<u16, DictionaryError> {
+ let mut buf = vec![];
+ encode_vlq_int(&mut buf, tag as u32);
+ let v = if buf.len() > 1 {
+ ((buf[0] as u16) & 0x7F) << 7 | (buf[1] as u16) & 0x7F
+ } else {
+ (buf[0] as u16) & 0x7F
+ };
+ if v >= 1 << 14 {
+ Err(DictionaryError::InvalidCommandTag(v))
+ } else {
+ Ok(v)
+ }
+ }
+}
diff --git a/crates/windlass/src/encoding.rs b/crates/windlass/src/encoding.rs
new file mode 100644
index 0000000..0950956
--- /dev/null
+++ b/crates/windlass/src/encoding.rs
@@ -0,0 +1,327 @@
+use std::fmt::Display;
+
+use crate::messages::MessageSkipperError;
+
+/// Message decoding error
+#[derive(thiserror::Error, Debug, Clone)]
+pub enum MessageDecodeError {
+ /// More data was expected but none is available
+ #[error("eof unexpected")]
+ UnexpectedEof,
+ /// A received string could not be decoded as UTF8
+ #[error("invalid utf8 string")]
+ Utf8Error(#[from] std::str::Utf8Error),
+}
+
+pub(crate) fn encode_vlq_int(output: &mut Vec<u8>, v: u32) {
+ let sv = v as i32;
+ if !(-(1 << 26)..(3 << 26)).contains(&sv) {
+ output.push(((sv >> 28) & 0x7F) as u8 | 0x80);
+ }
+ if !(-(1 << 19)..(3 << 19)).contains(&sv) {
+ output.push(((sv >> 21) & 0x7F) as u8 | 0x80);
+ }
+ if !(-(1 << 12)..(3 << 12)).contains(&sv) {
+ output.push(((sv >> 14) & 0x7F) as u8 | 0x80);
+ }
+ if !(-(1 << 5)..(3 << 5)).contains(&sv) {
+ output.push(((sv >> 7) & 0x7F) as u8 | 0x80);
+ }
+ output.push((sv & 0x7F) as u8);
+}
+
+pub(crate) fn next_byte(data: &mut &[u8]) -> Result<u8, MessageDecodeError> {
+ if data.is_empty() {
+ Err(MessageDecodeError::UnexpectedEof)
+ } else {
+ let v = data[0];
+ *data = &data[1..];
+ Ok(v)
+ }
+}
+
+pub(crate) fn parse_vlq_int(data: &mut &[u8]) -> Result<u32, MessageDecodeError> {
+ let mut c = next_byte(data)? as u32;
+ let mut v = c & 0x7F;
+ if (c & 0x60) == 0x60 {
+ v |= (-0x20i32) as u32;
+ }
+ while c & 0x80 != 0 {
+ c = next_byte(data)? as u32;
+ v = (v << 7) | (c & 0x7F);
+ }
+
+ Ok(v)
+}
+
+pub trait Readable<'de>: Sized {
+ fn read(data: &mut &'de [u8]) -> Result<Self, MessageDecodeError>;
+
+ fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError>;
+}
+
+pub trait Writable: Sized {
+ fn write(&self, output: &mut Vec<u8>);
+}
+
+pub trait Borrowable: Sized {
+ type Borrowed<'a>
+ where
+ Self: 'a;
+ fn from_borrowed(src: Self::Borrowed<'_>) -> Self;
+}
+
+pub trait ToFieldType: Sized {
+ fn as_field_type() -> FieldType;
+}
+
+macro_rules! int_readwrite {
+ ( $type:tt, $field_type:expr ) => {
+ impl Readable<'_> for $type {
+ fn read(data: &mut &[u8]) -> Result<Self, MessageDecodeError> {
+ parse_vlq_int(data).map(|v| v as $type)
+ }
+
+ fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> {
+ parse_vlq_int(data).map(|_| ())
+ }
+ }
+
+ impl Writable for $type {
+ fn write(&self, output: &mut Vec<u8>) {
+ encode_vlq_int(output, *self as u32)
+ }
+ }
+
+ impl Borrowable for $type {
+ type Borrowed<'a> = Self;
+ fn from_borrowed(src: Self::Borrowed<'_>) -> Self {
+ src
+ }
+ }
+
+ impl ToFieldType for $type {
+ fn as_field_type() -> FieldType {
+ $field_type
+ }
+ }
+ };
+}
+
+int_readwrite!(u32, FieldType::U32);
+int_readwrite!(i32, FieldType::I32);
+int_readwrite!(u16, FieldType::U16);
+int_readwrite!(i16, FieldType::I16);
+int_readwrite!(u8, FieldType::U8);
+
+impl Readable<'_> for bool {
+ fn read(data: &mut &[u8]) -> Result<Self, MessageDecodeError> {
+ parse_vlq_int(data).map(|v| v != 0)
+ }
+
+ fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> {
+ parse_vlq_int(data).map(|_| ())
+ }
+}
+
+impl Writable for bool {
+ fn write(&self, output: &mut Vec<u8>) {
+ encode_vlq_int(output, u32::from(*self))
+ }
+}
+
+impl Borrowable for bool {
+ type Borrowed<'a> = Self;
+ fn from_borrowed(src: Self::Borrowed<'_>) -> Self {
+ src
+ }
+}
+
+impl ToFieldType for bool {
+ fn as_field_type() -> FieldType {
+ FieldType::U8
+ }
+}
+
+impl<'de> Readable<'de> for &'de [u8] {
+ fn read(data: &mut &'de [u8]) -> Result<&'de [u8], MessageDecodeError> {
+ let len = parse_vlq_int(data)? as usize;
+ if data.len() < len {
+ Err(MessageDecodeError::UnexpectedEof)
+ } else {
+ let ret = &data[..len];
+ *data = &data[len..];
+ Ok(ret)
+ }
+ }
+
+ fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> {
+ let len = parse_vlq_int(data)? as usize;
+ if data.len() < len {
+ Err(MessageDecodeError::UnexpectedEof)
+ } else {
+ *data = &data[len..];
+ Ok(())
+ }
+ }
+}
+
+impl Writable for &[u8] {
+ fn write(&self, output: &mut Vec<u8>) {
+ encode_vlq_int(output, self.len() as u32);
+ output.extend_from_slice(self);
+ }
+}
+
+impl Borrowable for Vec<u8> {
+ type Borrowed<'a> = &'a [u8];
+ fn from_borrowed(src: Self::Borrowed<'_>) -> Self {
+ src.into()
+ }
+}
+
+impl ToFieldType for Vec<u8> {
+ fn as_field_type() -> FieldType {
+ FieldType::ByteArray
+ }
+}
+
+impl<'de> Readable<'de> for &'de str {
+ fn read(data: &mut &'de [u8]) -> Result<&'de str, MessageDecodeError> {
+ let len = parse_vlq_int(data)? as usize;
+ if data.len() < len {
+ Err(MessageDecodeError::UnexpectedEof)
+ } else {
+ let ret = &data[..len];
+ *data = &data[len..];
+ Ok(std::str::from_utf8(ret)?)
+ }
+ }
+
+ fn skip(data: &mut &[u8]) -> Result<(), MessageDecodeError> {
+ let len = parse_vlq_int(data)? as usize;
+ if data.len() < len {
+ Err(MessageDecodeError::UnexpectedEof)
+ } else {
+ *data = &data[len..];
+ Ok(())
+ }
+ }
+}
+
+impl Writable for &str {
+ fn write(&self, output: &mut Vec<u8>) {
+ let bytes = self.as_bytes();
+ encode_vlq_int(output, bytes.len() as u32);
+ output.extend_from_slice(bytes);
+ }
+}
+
+impl Borrowable for String {
+ type Borrowed<'a> = &'a str;
+ fn from_borrowed(src: Self::Borrowed<'_>) -> Self {
+ src.to_string()
+ }
+}
+
+impl ToFieldType for String {
+ fn as_field_type() -> FieldType {
+ FieldType::String
+ }
+}
+
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum FieldType {
+ U32,
+ I32,
+ U16,
+ I16,
+ U8,
+ String,
+ ByteArray,
+}
+
+impl FieldType {
+ pub(crate) fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> {
+ match self {
+ Self::U32 => <u32 as Readable>::skip(input),
+ Self::I32 => <i32 as Readable>::skip(input),
+ Self::U16 => <u16 as Readable>::skip(input),
+ Self::I16 => <i16 as Readable>::skip(input),
+ Self::U8 => <u8 as Readable>::skip(input),
+ Self::String => <&str as Readable>::skip(input),
+ Self::ByteArray => <&[u8] as Readable>::skip(input),
+ }
+ }
+
+ pub(crate) fn read(&self, input: &mut &[u8]) -> Result<FieldValue, MessageDecodeError> {
+ Ok(match self {
+ Self::U32 => FieldValue::U32(<u32 as Readable>::read(input)?),
+ Self::I32 => FieldValue::I32(<i32 as Readable>::read(input)?),
+ Self::U16 => FieldValue::U16(<u16 as Readable>::read(input)?),
+ Self::I16 => FieldValue::I16(<i16 as Readable>::read(input)?),
+ Self::U8 => FieldValue::U8(<u8 as Readable>::read(input)?),
+ Self::String => FieldValue::String(<&str as Readable>::read(input)?.into()),
+ Self::ByteArray => FieldValue::ByteArray(<&[u8] as Readable>::read(input)?.into()),
+ })
+ }
+
+ pub(crate) fn from_msg(s: &str) -> Result<Self, MessageSkipperError> {
+ match s {
+ "%u" => Ok(Self::U32),
+ "%i" => Ok(Self::I32),
+ "%hu" => Ok(Self::U16),
+ "%hi" => Ok(Self::I16),
+ "%c" => Ok(Self::U8),
+ "%s" => Ok(Self::String),
+ "%*s" => Ok(Self::ByteArray),
+ "%.*s" => Ok(Self::ByteArray),
+ s => Err(MessageSkipperError::InvalidFormatFieldType(s.to_string())),
+ }
+ }
+
+ pub(crate) fn from_format(s: &str) -> Result<(Self, &str), MessageSkipperError> {
+ if let Some(rest) = s.strip_prefix("%u") {
+ Ok((Self::U32, rest))
+ } else if let Some(rest) = s.strip_prefix("%i") {
+ Ok((Self::I32, rest))
+ } else if let Some(rest) = s.strip_prefix("%hu") {
+ Ok((Self::U16, rest))
+ } else if let Some(rest) = s.strip_prefix("%hi") {
+ Ok((Self::I16, rest))
+ } else if let Some(rest) = s.strip_prefix("%c") {
+ Ok((Self::U8, rest))
+ } else if let Some(rest) = s.strip_prefix("%.*s") {
+ Ok((Self::ByteArray, rest))
+ } else if let Some(rest) = s.strip_prefix("%*s") {
+ Ok((Self::String, rest))
+ } else {
+ Err(MessageSkipperError::InvalidFormatFieldType(s.to_string()))
+ }
+ }
+}
+
+#[derive(Debug)]
+pub enum FieldValue {
+ U32(u32),
+ I32(i32),
+ U16(u16),
+ I16(i16),
+ U8(u8),
+ String(String),
+ ByteArray(Vec<u8>),
+}
+
+impl Display for FieldValue {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Self::U32(v) => write!(f, "{}", v),
+ Self::I32(v) => write!(f, "{}", v),
+ Self::U16(v) => write!(f, "{}", v),
+ Self::I16(v) => write!(f, "{}", v),
+ Self::U8(v) => write!(f, "{}", v),
+ Self::String(v) => write!(f, "{}", v),
+ Self::ByteArray(v) => write!(f, "{:?}", v),
+ }
+ }
+}
diff --git a/crates/windlass/src/lib.rs b/crates/windlass/src/lib.rs
new file mode 100644
index 0000000..c578bf2
--- /dev/null
+++ b/crates/windlass/src/lib.rs
@@ -0,0 +1,21 @@
+#[macro_use]
+#[doc(hidden)]
+pub mod macros;
+
+#[doc(hidden)]
+pub mod dictionary;
+
+#[doc(hidden)]
+pub mod encoding;
+
+mod mcu;
+
+#[doc(hidden)]
+pub mod messages;
+
+mod transport;
+
+pub use dictionary::DictionaryError;
+pub use encoding::MessageDecodeError;
+pub use mcu::{McuConnection, McuConnectionError};
+pub use messages::EncodedMessage;
diff --git a/crates/windlass/src/macros.rs b/crates/windlass/src/macros.rs
new file mode 100644
index 0000000..6f57a70
--- /dev/null
+++ b/crates/windlass/src/macros.rs
@@ -0,0 +1,200 @@
+pub use crate::transport::MESSAGE_LENGTH_PAYLOAD_MAX;
+
+#[macro_export]
+#[doc(hidden)]
+macro_rules! mcu_message_impl_oid_check {
+ ($ty_name:ident) => {
+ impl $crate::messages::WithoutOid for $ty_name {}
+ };
+ ($ty_name:ident, oid $(, $args:ident)*) => {
+ impl $crate::messages::WithOid for $ty_name {}
+ };
+ ($ty_name:ident, $arg:ident $(, $args:ident)*) => {
+ $crate::mcu_message_impl_oid_check!($ty_name $(, $args)*);
+ };
+}
+
+#[macro_export]
+#[doc(hidden)]
+macro_rules! mcu_message_impl {
+ ($ty_name:ident, $cmd_name:literal = $cmd_id:expr $(, $arg:ident : $kind:ty)*) => {
+ paste::paste! {
+ #[derive(Debug)]
+ struct $ty_name;
+
+ #[allow(dead_code)]
+ impl $ty_name {
+ #[allow(clippy::extra_unused_lifetimes)]
+ pub fn encode<'a>($($arg: <$kind as $crate::encoding::Borrowable>::Borrowed<'a>, )*) -> $crate::messages::EncodedMessage<Self> {
+ let payload = [<$ty_name Data>] {
+ $($arg,)*
+ _lifetime: Default::default(),
+ }.encode();
+ $crate::messages::EncodedMessage {
+ payload,
+ _message_kind: Default::default(),
+ }
+ }
+
+ pub fn decode<'a>(input: &mut &'a [u8]) -> Result<[<$ty_name Data>]<'a>, $crate::encoding::MessageDecodeError> {
+ [<$ty_name Data>]::decode(input)
+ }
+ }
+
+ #[allow(dead_code)]
+ struct [<$ty_name Data>]<'a> {
+ // $(pub $arg: $kind,)*
+ $(pub $arg: <$kind as $crate::encoding::Borrowable>::Borrowed<'a>,)*
+
+ _lifetime: std::marker::PhantomData<&'a ()>,
+ }
+
+ #[allow(dead_code)]
+ impl<'a> [<$ty_name Data>]<'a> {
+ fn encode(&self) -> $crate::messages::FrontTrimmableBuffer {
+ use $crate::encoding::Writable;
+ let mut buf = Vec::with_capacity($crate::macros::MESSAGE_LENGTH_PAYLOAD_MAX);
+ buf.push(0);
+ buf.push(0);
+ $(self.$arg.write(&mut buf);)*
+ $crate::messages::FrontTrimmableBuffer { content: buf, offset: 0 }
+ }
+
+ fn decode(input: &mut &'a [u8]) -> Result<Self, $crate::encoding::MessageDecodeError> {
+ $(let $arg = $crate::encoding::Readable::read(input)?;)*
+ Ok(Self { $($arg,)* _lifetime: Default::default() })
+ }
+
+ }
+
+ impl<'a> std::fmt::Debug for [<$ty_name Data>]<'a> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let mut ds = f.debug_struct(stringify!([<$ty_name Data>]));
+ $( let ds = ds.field(stringify!($arg), &self.$arg); )*
+ ds.finish()
+ }
+ }
+
+ #[derive(Clone)]
+ #[allow(dead_code)]
+ struct [<$ty_name DataOwned>] {
+ $(pub $arg: $kind,)*
+ }
+
+ impl<'a> std::convert::From<[<$ty_name Data>]<'a>> for [<$ty_name DataOwned>] {
+ fn from(value: [<$ty_name Data>]) -> Self {
+ Self {
+ $($arg: $crate::encoding::Borrowable::from_borrowed(value.$arg),)*
+ }
+ }
+ }
+
+ impl std::fmt::Debug for [<$ty_name DataOwned>] {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let mut ds = f.debug_struct(stringify!([<$ty_name DataOwned>]));
+ $( let ds = ds.field(stringify!($arg), &self.$arg); )*
+ ds.finish()
+ }
+ }
+
+ #[allow(dead_code)]
+ impl $crate::messages::Message for $ty_name {
+ type Pod<'a> = [<$ty_name Data>]<'a>;
+ type PodOwned = [<$ty_name DataOwned>];
+
+ fn get_id(dict: Option<&$crate::dictionary::Dictionary>) -> Option<u16> {
+ $cmd_id.or_else(|| dict.and_then(|dict| dict.message_ids.get($cmd_name).copied()))
+ }
+
+ fn get_name() -> &'static str {
+ $cmd_name
+ }
+
+ fn decode<'a>(input: &mut &'a [u8]) -> Result<Self::Pod<'a>, $crate::encoding::MessageDecodeError> {
+ Self::decode(input)
+ }
+
+ fn fields<'a>() -> Vec<(&'static str, $crate::encoding::FieldType)> {
+ vec![
+ $( ( stringify!($arg), <$kind as $crate::encoding::ToFieldType>::as_field_type() ), )*
+ ]
+ }
+ }
+
+ $crate::mcu_message_impl_oid_check!($ty_name $(, $arg)*);
+ }
+ };
+}
+
+/// Declare a host -> device message
+///
+/// Declares the format of a message sent from host to device. The general format is as follows:
+///
+/// ```text
+/// mcu_command!(<struct name>, "<command name>"( = <id>) [, arg: type, ..]);
+/// ```
+///
+/// A struct with the given name will be defined, with relevant interfaces. The message name and
+/// arguments will be matched to the MCU at runtime. The struct name has no restrictions, but it is
+/// recommended to pick a SnakeCased version of the command name.
+///
+/// Optionally an `id` can be directly specified. Generally this is not needed. When not specified,
+/// it will be automatically inferred at runtime and matched to the dictionary retrieved from the
+/// target MCU.
+///
+/// Arguments can be specified, they will be mapped to the relevant Klipper argument types. The
+/// supported types and mappings are as follows:
+///
+/// | Rust type | Format string |
+/// |------------|---------------|
+/// | `u32` | `%u` |
+/// | `i32` | `%i` |
+/// | `u16` | `%hu` |
+/// | `i16` | `%hi` |
+/// | `u8` | `%c` |
+/// | `&'a [u8]` | `%.*s`, `%*s` |
+/// | `&'a str` | `%s` |
+///
+/// Note that the buffer types take a lifetime. This must always be `'a`.
+///
+/// # Examples
+///
+/// ```ignore
+/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c'
+/// mcu_command!(ConfigEndstop, "config_endstop", oid: u8, pin: u8, pull_up: u8);
+/// ```
+#[macro_export]
+macro_rules! mcu_command {
+ ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
+ $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*);
+ };
+ ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
+ $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*);
+ };
+}
+
+/// Declare a device -> host message
+///
+/// Declares the format of a message sent from device to host. The general format is as follows:
+///
+/// ```text
+/// mcu_reply!(<struct name>, "<reply name>"( = <id>) [, arg: type, ..]);
+/// ```
+///
+/// For more information on the various fields, see the documentation for [mcu_command].
+///
+/// # Examples
+///
+/// ```ignore
+/// // This defines 'config_endstop oid=%c pin=%c pull_up=%c'
+/// mcu_reply!(Uptime, "uptime", high: u32, clock: u32);
+/// ```
+#[macro_export]
+macro_rules! mcu_reply {
+ ($ty_name:ident, $cmd_name:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
+ $crate::mcu_message_impl!($ty_name, $cmd_name = None $(, $arg: $kind)*);
+ };
+ ($ty_name:ident, $cmd_name:literal = $cmd_id:literal $(, $arg:ident : $kind:ty)* $(,)?) => {
+ $crate::mcu_message_impl!($ty_name, $cmd_name = Some($cmd_id) $(, $arg: $kind)*);
+ };
+}
diff --git a/crates/windlass/src/mcu.rs b/crates/windlass/src/mcu.rs
new file mode 100644
index 0000000..53fea7f
--- /dev/null
+++ b/crates/windlass/src/mcu.rs
@@ -0,0 +1,585 @@
+use std::{
+ any::TypeId,
+ collections::{BTreeMap, BTreeSet},
+ io::Read,
+ sync::{atomic::AtomicUsize, Arc, Mutex},
+ time::{Duration, Instant},
+};
+
+use tokio::{
+ io::{AsyncRead, AsyncWrite},
+ select,
+ sync::{
+ mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
+ oneshot, OnceCell,
+ },
+ task::{spawn, JoinHandle},
+};
+use tracing::{debug, error, trace};
+
+use crate::messages::{format_command_args, EncodedMessage, Message, WithOid, WithoutOid};
+use crate::transport::{Transport, TransportReceiver};
+use crate::{
+ dictionary::{Dictionary, DictionaryError, RawDictionary},
+ transport::TransportError,
+};
+use crate::{
+ encoding::{parse_vlq_int, MessageDecodeError},
+ messages::FrontTrimmableBuffer,
+};
+
+mcu_command!(Identify, "identify" = 1, offset: u32, count: u8);
+mcu_reply!(
+ IdentifyResponse,
+ "identify_response" = 0,
+ offset: u32,
+ data: Vec<u8>,
+);
+
+/// MCU Connection Errors
+#[derive(thiserror::Error, Debug)]
+pub enum McuConnectionError {
+ /// Encoding a message was attempted but a command with a matching name doesn't exist in the
+ /// dictionary.
+ #[error("unknown message ID for command '{0}'")]
+ UnknownMessageId(&'static str),
+ /// An issue was encountered while decoding the received frame
+ #[error("error decoding message with ID '{0}'")]
+ DecodingError(#[from] MessageDecodeError),
+ /// An error was encountered while fetching the dictionary
+ #[error("error obtaining identify data: {0}")]
+ DictionaryFetch(Box<dyn std::error::Error + Send + Sync>),
+ /// An error was encountered while parsing the dictionary
+ #[error("dictionary issue: {0}")]
+ Dictionary(#[from] DictionaryError),
+ /// Received an unknown command from the MCU
+ #[error("unknown command {0}")]
+ UnknownCommand(u16),
+ /// There was a mismatch between the command arguments on the remote and local sides.
+ #[error("mismatched command {0}: '{1}' vs '{2}'")]
+ CommandMismatch(&'static str, String, String),
+ /// There was a transport-level issue
+ #[error("transport error: {0}")]
+ Transport(#[from] TransportError),
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum SendReceiveError {
+ #[error("timeout")]
+ Timeout,
+ #[error("mcu connection error: {0}")]
+ McuConnection(#[from] McuConnectionError),
+}
+
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)]
+struct HandlerId(u16, Option<u8>);
+
+trait Handler: Send + Sync {
+ fn handle(&mut self, data: &mut &[u8]) -> HandlerResult;
+}
+
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+pub struct RegisteredResponseHandle(usize);
+
+#[derive(Default)]
+struct Handlers {
+ handlers: std::sync::Mutex<BTreeMap<HandlerId, UniqueHandler>>,
+ next_handler: AtomicUsize,
+}
+
+type UniqueHandler = (usize, Box<dyn Handler>);
+
+impl std::fmt::Debug for Handlers {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let h = self
+ .handlers
+ .try_lock()
+ .map(|h| h.keys().cloned().collect::<BTreeSet<_>>())
+ .unwrap_or_default();
+ f.debug_struct("Handlers")
+ .field("handlers", &h)
+ .finish_non_exhaustive()
+ }
+}
+
+impl Handlers {
+ fn register(
+ &self,
+ command: u16,
+ oid: Option<u8>,
+ handler: Box<dyn Handler>,
+ ) -> RegisteredResponseHandle {
+ let id = HandlerId(command, oid);
+ let uniq = self
+ .next_handler
+ .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
+ self.handlers.lock().unwrap().insert(id, (uniq, handler));
+ RegisteredResponseHandle(uniq)
+ }
+
+ fn call(&self, command: u16, oid: Option<u8>, data: &mut &[u8]) -> bool {
+ let id = HandlerId(command, oid);
+ let mut handlers = self.handlers.lock().unwrap();
+ let handler = handlers.get_mut(&id);
+ if let Some((_, handler)) = handler {
+ if matches!(handler.handle(data), HandlerResult::Deregister) {
+ handlers.remove(&id);
+ }
+ true
+ } else {
+ false
+ }
+ }
+
+ fn remove_handler(
+ &self,
+ command: u16,
+ oid: Option<u8>,
+ uniq: Option<RegisteredResponseHandle>,
+ ) {
+ let id = HandlerId(command, oid);
+ let mut handlers = self.handlers.lock().unwrap();
+ if let Some(RegisteredResponseHandle(uniq)) = uniq {
+ match handlers.get(&id) {
+ None => return,
+ Some((u, _)) if *u != uniq => return,
+ _ => {}
+ }
+ }
+ handlers.remove(&id);
+ }
+}
+
+#[derive(Debug)]
+struct McuConnectionInner {
+ dictionary: OnceCell<Dictionary>,
+ handlers: Handlers,
+ good_types: Mutex<BTreeSet<TypeId>>,
+}
+
+impl McuConnectionInner {
+ async fn receive_loop(&self, mut inbox: TransportReceiver) -> Result<(), McuConnectionError> {
+ while let Some(frame) = inbox.recv().await {
+ let frame = match frame {
+ Ok(frame) => frame,
+ Err(e) => return Err(McuConnectionError::Transport(e)),
+ };
+ let frame = &mut frame.as_slice();
+ self.parse_frame(frame)?;
+ }
+ Ok(())
+ }
+
+ fn parse_frame(&self, frame: &mut &[u8]) -> Result<(), McuConnectionError> {
+ while !frame.is_empty() {
+ let cmd = parse_vlq_int(frame)? as u16;
+
+ // If a raw handler is registered for this, call it
+ if self.handlers.call(cmd, None, frame) {
+ continue;
+ }
+
+ // There was no registered raw handler, check the dictionary to decode
+ if let Some(dict) = self.dictionary.get() {
+ if let Some(parser) = dict.message_parsers.get(&cmd) {
+ if let Some(msg) = parser.parse_and_output(frame) {
+ let msg = msg?;
+ debug!(msg = msg, "Output message from MCU");
+ continue;
+ }
+
+ let mut curmsg = &frame[..];
+ let oid = parser.skip_with_oid(frame)?;
+ if tracing::enabled!(tracing::Level::TRACE) {
+ #[allow(clippy::redundant_slicing)] // The parse call changes the slice!
+ let mut tmp = &curmsg[..];
+ trace!(
+ cmd_id = cmd,
+ cmd_name = parser.name,
+ data = ?parser.parse(&mut tmp).unwrap(), // Safe because we skipped before and it passed
+ "Received message",
+ );
+ }
+ if !self.handlers.call(cmd, oid, &mut curmsg) {
+ debug!(
+ cmd_id = cmd,
+ cmd_name = parser.name,
+ // Safe because we skipped before and it passed. Since call returned
+ // false, we know no handler took the data and can safely take curmsg
+ // here.
+ data = ?parser.parse(&mut curmsg).unwrap(),
+ "Unhandled message",
+ );
+ }
+ } else {
+ return Err(McuConnectionError::UnknownCommand(cmd));
+ }
+ } else {
+ break;
+ }
+ }
+ Ok(())
+ }
+}
+
+/// Manages a connection to a Klipper MCU
+///
+///
+#[derive(Debug)]
+pub struct McuConnection {
+ inner: Arc<McuConnectionInner>,
+ transport: Transport,
+ _receiver: JoinHandle<()>,
+ exit_code: ExitCode,
+}
+
+#[derive(Debug)]
+enum ExitCode {
+ Waiting(oneshot::Receiver<Result<(), McuConnectionError>>),
+ Exited(Result<(), McuConnectionError>),
+}
+
+impl McuConnection {
+ /// Connect to an MCU
+ ///
+ /// Attempts to connect to a MCU on the given data stream interface. The MCU is contacted and
+ /// the dictionary is loaded and applied. The returned [McuConnection] is fully ready to
+ /// communicate with the MCU.
+ pub async fn connect<R>(stream: R) -> Result<Self, McuConnectionError>
+ where
+ R: AsyncRead + AsyncWrite + Send + 'static,
+ {
+ let (transport, inbox) = Transport::connect(stream).await;
+
+ let (exit_tx, exit_rx) = oneshot::channel();
+ let inner = Arc::new(McuConnectionInner {
+ dictionary: OnceCell::new(),
+ handlers: Handlers::default(),
+ good_types: Mutex::default(),
+ });
+
+ let receiver_inner = inner.clone();
+ let receiver = tokio::spawn(async move {
+ let _ = exit_tx.send(receiver_inner.receive_loop(inbox).await);
+ });
+
+ let conn = McuConnection {
+ inner,
+ transport,
+ _receiver: receiver,
+ exit_code: ExitCode::Waiting(exit_rx),
+ };
+
+ let mut identify_data = Vec::new();
+ let identify_start = Instant::now();
+ loop {
+ let data = conn
+ .send_receive(
+ Identify::encode(identify_data.len() as u32, 40),
+ IdentifyResponse,
+ )
+ .await;
+ let mut data = match data {
+ Ok(data) => data,
+ Err(e) => {
+ if identify_start.elapsed() > Duration::from_secs(10) {
+ return Err(McuConnectionError::DictionaryFetch(Box::new(e)));
+ }
+ continue;
+ }
+ };
+ if data.offset as usize == identify_data.len() {
+ if data.data.is_empty() {
+ break;
+ }
+ identify_data.append(&mut data.data);
+ }
+ }
+ let mut decoder = flate2::read::ZlibDecoder::new(identify_data.as_slice());
+ let mut buf = Vec::new();
+ decoder
+ .read_to_end(&mut buf)
+ .map_err(|err| McuConnectionError::DictionaryFetch(Box::new(err)))?;
+ let raw_dict: RawDictionary = serde_json::from_slice(&buf)
+ .map_err(|err| McuConnectionError::DictionaryFetch(Box::new(err)))?;
+ let dict =
+ Dictionary::from_raw_dictionary(raw_dict).map_err(McuConnectionError::Dictionary)?;
+ debug!(dictionary = ?dict, "MCU dictionary");
+ conn.inner
+ .dictionary
+ .set(dict)
+ .expect("Dictionary already set");
+
+ Ok(conn)
+ }
+
+ fn verify_command_matches<C: Message>(&self) -> Result<(), McuConnectionError> {
+ // Special handling for identify/identify_response
+ if C::get_id(None) == Some(0) || C::get_id(None) == Some(1) {
+ return Ok(());
+ }
+
+ if self
+ .inner
+ .good_types
+ .lock()
+ .unwrap()
+ .contains(&TypeId::of::<C>())
+ {
+ return Ok(());
+ }
+
+ // We can only get here if we _have_ a dictionary as only Identify/IdentifyResponse are
+ // tested before then.
+ let dictionary = self.inner.dictionary.get().unwrap();
+ let id = C::get_id(Some(dictionary))
+ .ok_or_else(|| McuConnectionError::UnknownMessageId(C::get_name()))?;
+
+ // Must exist because we know the tag
+ let parser = dictionary.message_parsers.get(&id).unwrap();
+ let remote_fields = parser.fields.iter().map(|(s, t)| (s.as_str(), *t));
+ let local_fields = C::fields().into_iter();
+
+ if !remote_fields.eq(local_fields) {
+ return Err(McuConnectionError::CommandMismatch(
+ C::get_name(),
+ format_command_args(parser.fields.iter().map(|(s, t)| (s.as_str(), *t))),
+ format_command_args(C::fields().into_iter()),
+ ));
+ }
+
+ self.inner
+ .good_types
+ .lock()
+ .unwrap()
+ .insert(TypeId::of::<C>());
+
+ Ok(())
+ }
+
+ fn encode_command<C: Message>(
+ &self,
+ command: EncodedMessage<C>,
+ ) -> Result<FrontTrimmableBuffer, McuConnectionError> {
+ let id = command
+ .message_id(self.inner.dictionary.get())
+ .ok_or_else(|| McuConnectionError::UnknownMessageId(command.message_name()))?;
+ let mut payload = command.payload;
+ if id >= 0x80 {
+ payload.content[0] = ((id >> 7) & 0x7F) as u8 | 0x80;
+ } else {
+ payload.offset = 1;
+ }
+ payload.content[1] = (id & 0x7F) as u8;
+ Ok(payload)
+ }
+
+ /// Sends a command to the MCU
+ pub async fn send<C: Message>(
+ &self,
+ command: EncodedMessage<C>,
+ ) -> Result<(), McuConnectionError> {
+ let cmd = self.encode_command(command)?;
+ self.transport
+ .send(cmd.as_slice())
+ .map_err(|e| McuConnectionError::Transport(TransportError::Transmitter(e)))
+ }
+
+ async fn send_receive_impl<C: Message, R: Message>(
+ &self,
+ command: EncodedMessage<C>,
+ reply: R,
+ oid: Option<u8>,
+ ) -> Result<R::PodOwned, SendReceiveError> {
+ struct RespHandler<R: Message>(
+ Option<oneshot::Sender<Result<R::PodOwned, McuConnectionError>>>,
+ );
+
+ impl<R: Message> Handler for RespHandler<R> {
+ fn handle(&mut self, data: &mut &[u8]) -> HandlerResult {
+ if let Some(tx) = self.0.take() {
+ let _ = match R::decode(data) {
+ Ok(msg) => tx.send(Ok(msg.into())),
+ Err(e) => tx.send(Err(McuConnectionError::DecodingError(e))),
+ };
+ }
+ HandlerResult::Deregister
+ }
+ }
+
+ let cmd = self.encode_command(command)?;
+
+ self.verify_command_matches::<C>()?;
+ self.verify_command_matches::<R>()?;
+
+ let (tx, mut rx) = tokio::sync::oneshot::channel::<Result<R::PodOwned, _>>();
+ self.register_raw_response(reply, oid, Box::new(RespHandler::<R>(Some(tx))))?;
+
+ let mut retry_delay = 0.01;
+ for _retry in 0..=5 {
+ self.transport
+ .send(cmd.as_slice())
+ .map_err(|e| McuConnectionError::Transport(TransportError::Transmitter(e)))?;
+
+ let sleep = tokio::time::sleep(Duration::from_secs_f32(retry_delay));
+ tokio::pin!(sleep);
+
+ select! {
+ reply = &mut rx => {
+ return match reply {
+ Ok(Err(e)) => Err(SendReceiveError::McuConnection(e)),
+ Ok(Ok(v)) => Ok(v),
+ Err(_) => Err(SendReceiveError::Timeout)
+ };
+ },
+ _ = &mut sleep => {},
+ }
+
+ retry_delay *= 2.0;
+ }
+
+ Err(SendReceiveError::Timeout)
+ }
+
+ /// Sends a message to the MCU, and await a reply
+ ///
+ /// This sends a message to the MCU and awaits a reply matching the given `reply`.
+ /// This version works with replies that have an `oid` field.
+ pub async fn send_receive_oid<C: Message, R: Message + WithOid>(
+ &self,
+ command: EncodedMessage<C>,
+ reply: R,
+ oid: u8,
+ ) -> Result<R::PodOwned, SendReceiveError> {
+ self.send_receive_impl(command, reply, Some(oid)).await
+ }
+
+ /// Sends a message to the MCU, and await a reply
+ ///
+ /// This sends a message to the MCU and awaits a reply matching the given `reply`.
+ /// This version works with replies that do not have an `oid` field.
+ pub async fn send_receive<C: Message, R: Message + WithoutOid>(
+ &self,
+ command: EncodedMessage<C>,
+ reply: R,
+ ) -> Result<R::PodOwned, SendReceiveError> {
+ self.send_receive_impl(command, reply, None).await
+ }
+
+ fn register_raw_response<R: Message>(
+ &self,
+ _reply: R,
+ oid: Option<u8>,
+ handler: Box<dyn Handler>,
+ ) -> Result<RegisteredResponseHandle, McuConnectionError> {
+ let id = R::get_id(self.inner.dictionary.get())
+ .ok_or_else(|| McuConnectionError::UnknownMessageId(R::get_name()))?;
+ Ok(self.inner.handlers.register(id, oid, handler))
+ }
+
+ fn register_response_impl<R: Message>(
+ &self,
+ reply: R,
+ oid: Option<u8>,
+ ) -> Result<UnboundedReceiver<R::PodOwned>, McuConnectionError> {
+ struct RespHandler<R: Message>(UnboundedSender<R::PodOwned>, Option<oneshot::Sender<()>>);
+
+ impl<R: Message> Drop for RespHandler<R> {
+ fn drop(&mut self) {
+ self.1.take();
+ }
+ }
+
+ impl<R: Message> Handler for RespHandler<R> {
+ fn handle(&mut self, data: &mut &[u8]) -> HandlerResult {
+ let msg = R::decode(data)
+ .expect("Parser should already have assured this could parse")
+ .into();
+ match self.0.send(msg) {
+ Ok(_) => HandlerResult::Continue,
+ Err(_) => HandlerResult::Deregister,
+ }
+ }
+ }
+
+ self.verify_command_matches::<R>()?;
+
+ let (tx, rx) = unbounded_channel();
+ let tx_closed = tx.clone();
+ let (closer_tx, closer_rx) = oneshot::channel();
+ let uniq = self.register_raw_response(
+ reply,
+ oid,
+ Box::new(RespHandler::<R>(tx, Some(closer_tx))),
+ )?;
+
+ // Safe because register_raw_response already verified this
+ let id = R::get_id(self.inner.dictionary.get()).unwrap();
+
+ let inner = self.inner.clone();
+ spawn(async move {
+ select! {
+ _ = tx_closed.closed() => {},
+ _ = closer_rx => {},
+ }
+ inner.handlers.remove_handler(id, oid, Some(uniq));
+ });
+
+ Ok(rx)
+ }
+
+ /// Registers a subscriber for a reply message
+ ///
+ /// Received replies matching the type will be delivered to the returned channel.
+ /// To end the subscription, simply drop the returned receiver.
+ /// This version works with replies that have an `oid` field.
+ pub fn register_response_oid<R: Message + WithOid>(
+ &self,
+ reply: R,
+ oid: u8,
+ ) -> Result<UnboundedReceiver<R::PodOwned>, McuConnectionError> {
+ self.register_response_impl(reply, Some(oid))
+ }
+
+ /// Registers a subscriber for a reply message
+ ///
+ /// Received replies matching the type will be delivered to the returned channel.
+ /// To end the subscription, simply drop the returned receiver.
+ /// This version works with replies that do not have an `oid` field.
+ pub fn register_response<R: Message + WithoutOid>(
+ &self,
+ reply: R,
+ ) -> Result<UnboundedReceiver<R::PodOwned>, McuConnectionError> {
+ self.register_response_impl(reply, None)
+ }
+
+ /// Returns a reference the dictionary the MCU returned during initial handshake
+ pub fn dictionary(&self) -> &Dictionary {
+ self.inner.dictionary.get().unwrap()
+ }
+
+ /// Closes the transport and ends all subscriptions. Returns when the transport is closed.
+ pub async fn close(self) {
+ self.transport.close().await;
+ }
+
+ /// Waits for the connection to close, returning the error that closed it if any.
+ pub async fn closed(&mut self) -> &Result<(), McuConnectionError> {
+ if let ExitCode::Waiting(chan) = &mut self.exit_code {
+ self.exit_code = ExitCode::Exited(match chan.await {
+ Ok(r) => r,
+ Err(_) => Ok(()),
+ });
+ }
+ match self.exit_code {
+ ExitCode::Exited(ref val) => val,
+ _ => unreachable!(),
+ }
+ }
+}
+
+#[derive(Debug)]
+enum HandlerResult {
+ Continue,
+ Deregister,
+}
diff --git a/crates/windlass/src/messages.rs b/crates/windlass/src/messages.rs
new file mode 100644
index 0000000..5b97d31
--- /dev/null
+++ b/crates/windlass/src/messages.rs
@@ -0,0 +1,237 @@
+use std::collections::BTreeMap;
+
+use crate::dictionary::Dictionary;
+use crate::encoding::{FieldType, FieldValue, MessageDecodeError};
+
+#[derive(thiserror::Error, Debug)]
+pub enum MessageSkipperError {
+ #[error("invalid argument format: {0}")]
+ InvalidArgumentFormat(String),
+ #[error("unknown type '{1}' for argument '{0}'")]
+ UnknownType(String, String),
+ #[error("invalid format field type '%{0}'")]
+ InvalidFormatFieldType(String),
+}
+
+pub(crate) fn format_command_args<'a>(
+ fields: impl Iterator<Item = (&'a str, FieldType)>,
+) -> String {
+ let mut buf = String::new();
+ for (idx, (name, ty)) in fields.enumerate() {
+ if idx != 0 {
+ buf.push(' ');
+ }
+ buf.push_str(name);
+ buf.push('=');
+ buf.push_str(match ty {
+ FieldType::U32 => "%u",
+ FieldType::I32 => "%i",
+ FieldType::U16 => "%hu",
+ FieldType::I16 => "%hi",
+ FieldType::U8 => "%c",
+ FieldType::String => "%s",
+ FieldType::ByteArray => "%*s",
+ });
+ }
+ buf
+}
+
+pub struct MessageParser {
+ pub name: String,
+ pub fields: Vec<(String, FieldType)>,
+ pub output: Option<OutputFormat>,
+}
+
+impl MessageParser {
+ pub(crate) fn new<'a>(
+ name: &str,
+ parts: impl Iterator<Item = &'a str>,
+ ) -> Result<MessageParser, MessageSkipperError> {
+ let mut fields = vec![];
+ for part in parts {
+ let (arg, ty) = part
+ .split_once('=')
+ .ok_or_else(|| MessageSkipperError::InvalidArgumentFormat(part.into()))?;
+
+ let field_type = FieldType::from_msg(ty)?;
+ fields.push((arg.to_string(), field_type));
+ }
+ Ok(Self {
+ name: name.to_string(),
+ fields,
+ output: None,
+ })
+ }
+
+ pub(crate) fn new_output(msg: &str) -> Result<MessageParser, MessageSkipperError> {
+ let mut fields = vec![];
+ let mut parts = vec![];
+
+ let mut work = msg;
+ while let Some(pos) = work.find('%') {
+ let (pre, rest) = work.split_at(pos);
+ if !pre.is_empty() {
+ parts.push(FormatBlock::Static(pre.to_string()));
+ }
+ if let Some(rest) = rest.strip_prefix("%%") {
+ parts.push(FormatBlock::Static("%".to_string()));
+ work = rest;
+ break;
+ }
+ let (format, rest) = FieldType::from_format(rest)?;
+ parts.push(FormatBlock::Field);
+ fields.push((format!("field_{}", fields.len()), format));
+ work = rest;
+ }
+ if !work.is_empty() {
+ parts.push(FormatBlock::Static(work.to_string()));
+ }
+
+ Ok(Self {
+ name: msg.to_string(),
+ fields,
+ output: Some(OutputFormat { parts }),
+ })
+ }
+
+ #[allow(dead_code)]
+ pub(crate) fn skip(&self, input: &mut &[u8]) -> Result<(), MessageDecodeError> {
+ for (_, field) in &self.fields {
+ field.skip(input)?;
+ }
+ Ok(())
+ }
+
+ pub(crate) fn skip_with_oid(
+ &self,
+ input: &mut &[u8],
+ ) -> Result<Option<u8>, MessageDecodeError> {
+ let mut oid = None;
+ for (name, field) in &self.fields {
+ if name == "oid" {
+ if let FieldValue::U8(read_oid) = field.read(input)? {
+ oid = Some(read_oid);
+ }
+ } else {
+ field.skip(input)?;
+ }
+ }
+ Ok(oid)
+ }
+
+ pub(crate) fn parse(
+ &self,
+ input: &mut &[u8],
+ ) -> Result<BTreeMap<String, FieldValue>, MessageDecodeError> {
+ let mut output = BTreeMap::new();
+ for (name, field) in &self.fields {
+ output.insert(name.to_string(), field.read(input)?);
+ }
+ Ok(output)
+ }
+
+ pub(crate) fn parse_and_output(
+ &self,
+ input: &mut &[u8],
+ ) -> Option<Result<String, MessageDecodeError>> {
+ self.output.as_ref().map(|output| {
+ let fields = self.parse(input)?;
+ Ok(output.format(fields.values()))
+ })
+ }
+}
+
+impl std::fmt::Debug for MessageParser {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_map()
+ .entry(&"name", &self.name)
+ .entry(&"fields", &self.fields)
+ .finish()
+ }
+}
+
+pub trait Message: 'static {
+ type Pod<'a>: Into<Self::PodOwned> + std::fmt::Debug;
+ type PodOwned: Clone + Send + std::fmt::Debug + 'static;
+ fn get_id(dict: Option<&Dictionary>) -> Option<u16>;
+ fn get_name() -> &'static str;
+ fn decode<'a>(input: &mut &'a [u8]) -> Result<Self::Pod<'a>, MessageDecodeError>;
+ fn fields() -> Vec<(&'static str, FieldType)>;
+}
+
+pub trait WithOid: 'static {}
+pub trait WithoutOid: 'static {}
+
+/// Represents an encoded message, with a type-level link to the message kind
+pub struct EncodedMessage<M> {
+ pub payload: FrontTrimmableBuffer,
+ pub _message_kind: std::marker::PhantomData<M>,
+}
+
+impl<R: Message> std::fmt::Debug for EncodedMessage<R> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("EncodedMessage")
+ .field("kind", &R::get_name())
+ .field("payload", &self.payload)
+ .finish()
+ }
+}
+
+impl<M: Message> EncodedMessage<M> {
+ #[doc(hidden)]
+ pub fn message_id(&self, dict: Option<&Dictionary>) -> Option<u16> {
+ M::get_id(dict)
+ }
+
+ #[doc(hidden)]
+ pub fn message_name(&self) -> &'static str {
+ M::get_name()
+ }
+}
+
+/// Wraps a `Vec<u8>` allowing removal of front bytes in a zero-copy way
+pub struct FrontTrimmableBuffer {
+ pub content: Vec<u8>,
+ pub offset: usize,
+}
+
+impl FrontTrimmableBuffer {
+ pub fn as_slice(&self) -> &[u8] {
+ &self.content[self.offset..]
+ }
+}
+
+impl std::fmt::Debug for FrontTrimmableBuffer {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ std::fmt::Debug::fmt(self.as_slice(), f)
+ }
+}
+
+/// Holds the format of a `output()` style debug message
+#[derive(Debug)]
+pub struct OutputFormat {
+ parts: Vec<FormatBlock>,
+}
+
+impl OutputFormat {
+ fn format<'a>(&self, mut fields: impl Iterator<Item = &'a FieldValue>) -> String {
+ let mut buf = String::new();
+ for part in &self.parts {
+ match part {
+ FormatBlock::Static(s) => buf.push_str(s),
+ FormatBlock::Field => {
+ if let Some(v) = fields.next() {
+ std::fmt::write(&mut buf, format_args!("{v}")).ok();
+ }
+ }
+ }
+ }
+ buf
+ }
+}
+
+#[derive(Debug)]
+enum FormatBlock {
+ Static(String),
+ Field,
+}
diff --git a/crates/windlass/src/transport.rs b/crates/windlass/src/transport.rs
new file mode 100644
index 0000000..e47cc8e
--- /dev/null
+++ b/crates/windlass/src/transport.rs
@@ -0,0 +1,615 @@
+use std::{collections::VecDeque, sync::Arc};
+use tokio::{
+ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader},
+ pin, select,
+ sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
+ task::{spawn, JoinHandle},
+ time::{sleep_until, Instant},
+};
+use tokio_util::sync::CancellationToken;
+use tracing::trace;
+
+const MESSAGE_HEADER_SIZE: usize = 2;
+const MESSAGE_TRAILER_SIZE: usize = 3;
+pub const MESSAGE_LENGTH_MIN: usize = MESSAGE_HEADER_SIZE + MESSAGE_TRAILER_SIZE;
+pub const MESSAGE_LENGTH_MAX: usize = 64;
+pub const MESSAGE_LENGTH_PAYLOAD_MAX: usize = MESSAGE_LENGTH_MAX - MESSAGE_LENGTH_MIN;
+const MESSAGE_POSITION_SEQ: usize = 1;
+const MESSAGE_TRAILER_CRC: usize = 3;
+const MESSAGE_VALUE_SYNC: u8 = 0x7E;
+const MESSAGE_DEST: u8 = 0x10;
+const MESSAGE_SEQ_MASK: u8 = 0x0F;
+
+#[derive(thiserror::Error, Debug)]
+pub enum ReceiverError {
+ #[error("io error: {0}")]
+ IoError(#[from] std::io::Error),
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum TransmitterError {
+ #[error("io error: {0}")]
+ IoError(#[from] std::io::Error),
+ #[error("connection closed")]
+ ConnectionClosed,
+}
+
+#[derive(Debug)]
+pub(crate) struct Transport {
+ _task_rdr: JoinHandle<()>,
+ _task_wr: JoinHandle<()>,
+ task_inner: JoinHandle<()>,
+ outbox_tx: UnboundedSender<TransportCommand>,
+}
+
+#[derive(Debug)]
+enum TransportCommand {
+ SendMessage(Vec<u8>),
+ Exit,
+}
+
+pub(crate) type TransportReceiver = UnboundedReceiver<Result<Vec<u8>, TransportError>>;
+
+impl Transport {
+ pub(crate) async fn connect<R>(stream: R) -> (Transport, TransportReceiver)
+ where
+ R: AsyncRead + AsyncWrite + Send + 'static,
+ {
+ let (rdr, wr) = tokio::io::split(stream);
+
+ let (raw_send_tx, raw_send_rx) = unbounded_channel();
+ let (raw_recv_tx, raw_recv_rx) = unbounded_channel();
+ let (app_inbox_tx, app_inbox_rx) = unbounded_channel();
+ let (app_outbox_tx, app_outbox_rx) = unbounded_channel();
+
+ let cancel_token = CancellationToken::new();
+
+ let cancel = cancel_token.clone();
+ let ait = app_inbox_tx.clone();
+ let task_rdr = spawn(async move {
+ if let Err(e) = LowlevelReader::run(raw_recv_tx, rdr, cancel).await {
+ let _ = ait.send(Err(TransportError::Receiver(e)));
+ }
+ });
+
+ let cancel = cancel_token.clone();
+ let ait = app_inbox_tx.clone();
+ let task_wr = spawn(async move {
+ if let Err(e) = LowlevelWriter::run(raw_send_rx, wr, cancel).await {
+ let _ = ait.send(Err(TransportError::Transmitter(e)));
+ }
+ });
+
+ let task_inner = spawn(async move {
+ let mut ts = TransportState::new(
+ raw_recv_rx,
+ raw_send_tx,
+ app_inbox_tx,
+ app_outbox_rx,
+ cancel_token,
+ );
+ if let Err(e) = ts.protocol_handler().await {
+ let _ = ts.app_inbox.send(Err(e));
+ }
+ });
+
+ (
+ Transport {
+ _task_rdr: task_rdr,
+ _task_wr: task_wr,
+ task_inner,
+ outbox_tx: app_outbox_tx,
+ },
+ app_inbox_rx,
+ )
+ }
+
+ pub(crate) fn send(&self, msg: &[u8]) -> Result<(), TransmitterError> {
+ self.outbox_tx
+ .send(TransportCommand::SendMessage(msg.into()))
+ .map_err(|_| TransmitterError::ConnectionClosed)
+ }
+
+ pub(crate) async fn close(self) {
+ let _ = self.outbox_tx.send(TransportCommand::Exit);
+ let _ = self.task_inner.await;
+ }
+}
+
+struct LowlevelReader<R> {
+ rdr: BufReader<R>,
+ synced: bool,
+}
+
+impl<R: AsyncRead + Unpin> LowlevelReader<R> {
+ async fn read_frame(&mut self) -> Result<Option<Frame>, ReceiverError> {
+ let mut buf = [0u8; MESSAGE_LENGTH_MAX];
+ let next_byte = self.rdr.read_u8().await?;
+ if next_byte == MESSAGE_VALUE_SYNC {
+ if !self.synced {
+ self.synced = true;
+ }
+ return Ok(None);
+ }
+ if !self.synced {
+ return Ok(None);
+ }
+
+ let receive_time = Instant::now();
+ let len = next_byte as usize;
+
+ if !(MESSAGE_LENGTH_MIN..=MESSAGE_LENGTH_MAX).contains(&len) {
+ self.synced = false;
+ return Ok(None);
+ }
+
+ self.rdr.read_exact(&mut buf[1..len]).await?;
+ buf[0] = len as u8;
+ let buf = &buf[..len];
+ let seq = buf[MESSAGE_POSITION_SEQ];
+ trace!(frame = ?buf, seq = seq, "Received frame");
+
+ if seq & !MESSAGE_SEQ_MASK != MESSAGE_DEST {
+ self.synced = false;
+ return Ok(None);
+ }
+
+ let actual_crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]);
+ let frame_crc = (buf[len - MESSAGE_TRAILER_CRC] as u16) << 8
+ | (buf[len - MESSAGE_TRAILER_CRC + 1] as u16);
+ if frame_crc != actual_crc {
+ self.synced = false;
+ return Ok(None);
+ }
+
+ Ok(Some(Frame {
+ receive_time,
+ sequence: seq & MESSAGE_SEQ_MASK,
+ payload: buf[MESSAGE_HEADER_SIZE..len - MESSAGE_TRAILER_SIZE].into(),
+ }))
+ }
+
+ async fn run(
+ outbox: UnboundedSender<Frame>,
+ rdr: R,
+ cancel_token: CancellationToken,
+ ) -> Result<(), ReceiverError>
+ where
+ R: AsyncRead + Unpin,
+ {
+ let mut state = Self {
+ rdr: BufReader::new(rdr),
+ synced: false,
+ };
+
+ loop {
+ match state.read_frame().await {
+ Ok(None) => {}
+ Ok(Some(frame)) => {
+ if outbox.send(frame).is_err() {
+ break Ok(());
+ }
+ }
+ Err(_) if cancel_token.is_cancelled() => break Ok(()),
+ Err(e) => break Err(e),
+ }
+ }
+ }
+}
+
+fn crc16(buf: &[u8]) -> u16 {
+ let mut crc = 0xFFFFu16;
+ for b in buf {
+ let b = *b ^ ((crc & 0xFF) as u8);
+ let b = b ^ (b << 4);
+ let b16 = b as u16;
+ crc = (b16 << 8 | crc >> 8) ^ (b16 >> 4) ^ (b16 << 3);
+ }
+ crc
+}
+
+#[derive(Debug)]
+struct Frame {
+ receive_time: Instant,
+ sequence: u8,
+ payload: Vec<u8>,
+}
+
+#[derive(Debug, Clone)]
+struct InflightFrame {
+ sent_at: Instant,
+ #[allow(dead_code)]
+ sequence: u64,
+ payload: Arc<Vec<u8>>,
+ is_retransmit: bool,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum MessageEncodeError {
+ #[error("message would exceed the maximum packet length of {MESSAGE_LENGTH_MAX} bytes")]
+ MessageTooLong,
+}
+
+fn encode_frame(sequence: u64, payload: &[u8]) -> Result<Vec<u8>, MessageEncodeError> {
+ let len = MESSAGE_LENGTH_MIN + payload.len();
+ if len > MESSAGE_LENGTH_MAX {
+ return Err(MessageEncodeError::MessageTooLong);
+ }
+ let mut buf = Vec::with_capacity(len);
+ buf.push(len as u8);
+ buf.push(MESSAGE_DEST | ((sequence as u8) & MESSAGE_SEQ_MASK));
+ buf.extend_from_slice(payload);
+ let crc = crc16(&buf[0..len - MESSAGE_TRAILER_SIZE]);
+ buf.push(((crc >> 8) & 0xFF) as u8);
+ buf.push((crc & 0xFF) as u8);
+ buf.push(MESSAGE_VALUE_SYNC);
+ Ok(buf)
+}
+
+struct LowlevelWriter {}
+
+impl LowlevelWriter {
+ async fn run<W>(
+ mut inbox: UnboundedReceiver<Arc<Vec<u8>>>,
+ mut wr: W,
+ cancel_token: CancellationToken,
+ ) -> Result<(), TransmitterError>
+ where
+ W: AsyncWrite + Unpin,
+ {
+ loop {
+ select! {
+ msg = inbox.recv() => {
+ let msg = match msg {
+ Some(msg) => msg,
+ None => break,
+ };
+ trace!(payload = ?msg, seq = msg[MESSAGE_POSITION_SEQ], "Sent frame");
+ wr.write_all(&msg).await?;
+ wr.flush().await?;
+ }
+
+ _ = cancel_token.cancelled() => {
+ break;
+ }
+ }
+ }
+
+ Ok(())
+ }
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum TransportError {
+ #[error("message encoding failed: {0}")]
+ MessageEncode(#[from] MessageEncodeError),
+ #[error("receiver error: {0}")]
+ Receiver(#[from] ReceiverError),
+ #[error("transmitter error: {0}")]
+ Transmitter(#[from] TransmitterError),
+}
+
+const MIN_RTO: f32 = 0.025;
+const MAX_RTO: f32 = 5.000;
+
+#[derive(Debug)]
+struct RttState {
+ srtt: f32,
+ rttvar: f32,
+ rto: f32,
+}
+
+impl RttState {
+ fn new() -> Self {
+ Self {
+ srtt: 0.0,
+ rttvar: 0.0,
+ rto: MIN_RTO,
+ }
+ }
+
+ fn rto(&self) -> std::time::Duration {
+ std::time::Duration::from_secs_f32(self.rto)
+ }
+
+ fn update(&mut self, rtt: std::time::Duration) {
+ let r = rtt.as_secs_f32();
+ if self.srtt == 0.0 {
+ self.rttvar = r / 2.0;
+ self.srtt = r * 10.0; // Klipper uses this, we'll copy it
+ } else {
+ self.rttvar = (3.0 * self.rttvar + (self.srtt - r).abs()) / 4.0;
+ self.srtt = (7.0 * self.srtt + r) / 8.0;
+ }
+ let rttvar4 = (self.rttvar * 4.0).max(0.001);
+ self.rto = (self.srtt + rttvar4).clamp(MIN_RTO, MAX_RTO);
+ }
+}
+
+#[derive(Debug)]
+struct TransportState {
+ link_inbox: UnboundedReceiver<Frame>,
+ link_outbox: UnboundedSender<Arc<Vec<u8>>>,
+ app_inbox: UnboundedSender<Result<Vec<u8>, TransportError>>,
+ app_outbox: UnboundedReceiver<TransportCommand>,
+ cancel: CancellationToken,
+
+ is_synchronized: bool,
+ rtt_state: RttState,
+ receive_sequence: u64,
+ send_sequence: u64,
+ last_ack_sequence: u64,
+ ignore_nak_seq: u64,
+ retransmit_seq: u64,
+ retransmit_now: bool,
+
+ corked_until: Option<Instant>,
+
+ inflight_messages: VecDeque<InflightFrame>,
+ ready_messages: MessageQueue<Vec<u8>>,
+}
+
+impl TransportState {
+ fn new(
+ link_inbox: UnboundedReceiver<Frame>,
+ link_outbox: UnboundedSender<Arc<Vec<u8>>>,
+ app_inbox: UnboundedSender<Result<Vec<u8>, TransportError>>,
+ app_outbox: UnboundedReceiver<TransportCommand>,
+ cancel: CancellationToken,
+ ) -> Self {
+ Self {
+ link_inbox,
+ link_outbox,
+ app_inbox,
+ app_outbox,
+ cancel,
+
+ is_synchronized: false,
+ rtt_state: RttState::new(),
+ receive_sequence: 1,
+ send_sequence: 1,
+ last_ack_sequence: 0,
+ ignore_nak_seq: 0,
+ retransmit_seq: 0,
+ retransmit_now: false,
+
+ corked_until: None,
+
+ inflight_messages: VecDeque::new(),
+ ready_messages: MessageQueue::new(),
+ }
+ }
+
+ fn update_receive_seq(&mut self, receive_time: Instant, sequence: u64) {
+ let mut sent_seq = self.receive_sequence;
+
+ loop {
+ if let Some(msg) = self.inflight_messages.pop_front() {
+ sent_seq += 1;
+ if sequence == sent_seq {
+ // Found the matching sent message
+ if !msg.is_retransmit {
+ let elapsed = receive_time.saturating_duration_since(msg.sent_at);
+ self.rtt_state.update(elapsed);
+ }
+ break;
+ }
+ } else {
+ // Ack with no outstanding messages, happens during connection init
+ self.send_sequence = sequence;
+ break;
+ }
+ }
+
+ self.receive_sequence = sequence;
+ self.is_synchronized = true;
+ }
+
+ fn handle_frame(&mut self, frame: Frame) {
+ let rseq = self.receive_sequence;
+ let mut sequence = (rseq & !(MESSAGE_SEQ_MASK as u64)) | (frame.sequence as u64);
+ if sequence < rseq {
+ sequence += (MESSAGE_SEQ_MASK as u64) + 1;
+ }
+ if !self.is_synchronized || sequence != rseq {
+ if sequence > self.send_sequence && self.is_synchronized {
+ // Ack for unsent message
+ return;
+ }
+ self.update_receive_seq(frame.receive_time, sequence);
+ }
+
+ if frame.payload.is_empty() {
+ if self.last_ack_sequence < sequence {
+ self.last_ack_sequence = sequence;
+ } else if sequence > self.ignore_nak_seq && !self.inflight_messages.is_empty() {
+ // Trigger retransmit from NAK
+ self.retransmit_now = true;
+ }
+ } else {
+ // Data message, we deliver this directly to the application as the MCU can't actually
+ // retransmit anyway.
+ let _ = self.app_inbox.send(Ok(frame.payload));
+ }
+ }
+
+ fn can_send(&self) -> bool {
+ self.corked_until.is_none() && self.inflight_messages.len() < 12
+ }
+
+ fn send_new_frame(&mut self, mut initial: Vec<u8>) -> Result<(), TransportError> {
+ while let Some(next) = self.ready_messages.try_peek() {
+ if initial.len() + next.len() <= MESSAGE_LENGTH_PAYLOAD_MAX {
+ // Add to the end of the message. Unwrap is safe because we already peeked.
+ let mut next = self.ready_messages.try_recv().unwrap();
+ initial.append(&mut next);
+ } else {
+ break;
+ }
+ }
+ let frame = Arc::new(encode_frame(self.send_sequence, &initial)?);
+ self.send_sequence += 1;
+ self.inflight_messages.push_back(InflightFrame {
+ sent_at: Instant::now(),
+ sequence: self.send_sequence,
+ payload: frame.clone(),
+ is_retransmit: false,
+ });
+ let _ = self.link_outbox.send(frame);
+ Ok(())
+ }
+
+ fn send_more_frames(&mut self) -> Result<(), TransportError> {
+ while self.can_send() && self.ready_messages.try_peek().is_some() {
+ let msg = self.ready_messages.try_recv().unwrap();
+ self.send_new_frame(msg)?;
+ }
+ Ok(())
+ }
+
+ fn retransmit_pending(&mut self) {
+ let len: usize = self
+ .inflight_messages
+ .iter()
+ .map(|msg| msg.payload.len())
+ .sum();
+ let mut buf = Vec::with_capacity(1 + len);
+ buf.push(MESSAGE_VALUE_SYNC);
+ let now = Instant::now();
+ for msg in self.inflight_messages.iter_mut() {
+ buf.extend_from_slice(&msg.payload);
+ msg.is_retransmit = true;
+ msg.sent_at = now;
+ }
+ let _ = self.link_outbox.send(Arc::new(buf));
+
+ if self.retransmit_now {
+ self.ignore_nak_seq = self.receive_sequence;
+ if self.receive_sequence < self.retransmit_seq {
+ self.ignore_nak_seq = self.retransmit_seq;
+ }
+ self.retransmit_now = false;
+ } else {
+ self.rtt_state.rto = (self.rtt_state.rto * 2.0).clamp(MIN_RTO, MAX_RTO);
+ self.ignore_nak_seq = self.send_sequence;
+ }
+ self.retransmit_seq = self.send_sequence;
+ }
+
+ async fn protocol_handler(&mut self) -> Result<(), TransportError> {
+ loop {
+ if self.retransmit_now {
+ self.retransmit_pending();
+ }
+
+ let retransmit_deadline = self
+ .inflight_messages
+ .front()
+ .map(|msg| msg.sent_at + self.rtt_state.rto());
+ let retransmit_timeout: futures::future::OptionFuture<_> =
+ retransmit_deadline.map(sleep_until).into();
+ pin!(retransmit_timeout);
+
+ let corked_timeout: futures::future::OptionFuture<_> =
+ self.corked_until.map(sleep_until).into();
+ pin!(corked_timeout);
+
+ select! {
+ frame = self.link_inbox.recv() => {
+ let frame = match frame {
+ Some(frame) => frame,
+ None => break,
+ };
+ self.handle_frame(frame);
+ },
+
+ msg = self.app_outbox.recv() => {
+ match msg {
+ Some(TransportCommand::SendMessage(msg)) => {
+ self.ready_messages.send(msg);
+ },
+ Some(TransportCommand::Exit) => {
+ self.cancel.cancel();
+ }
+ None => break,
+ };
+ },
+
+ _ = self.ready_messages.recv_peek(), if self.can_send() => {
+ self.send_more_frames()?;
+ },
+
+ _ = &mut retransmit_timeout, if retransmit_deadline.is_some() => {
+ self.retransmit_now = true;
+ },
+
+ _ = &mut corked_timeout, if self.corked_until.is_some() => {
+ // Timeout for when we are able to send again
+ }
+
+ _ = self.cancel.cancelled() => {
+ break;
+ },
+ }
+ }
+ Ok(())
+ }
+}
+
+#[derive(Debug)]
+struct MessageQueue<T> {
+ sender: UnboundedSender<T>,
+ receiver: UnboundedReceiver<T>,
+ peeked: Option<T>,
+}
+
+impl<T> MessageQueue<T> {
+ fn new() -> Self {
+ let (sender, receiver) = unbounded_channel();
+ Self {
+ sender,
+ receiver,
+ peeked: None,
+ }
+ }
+
+ async fn recv_peek(&mut self) -> Option<&T> {
+ if self.peeked.is_some() {
+ self.peeked.as_ref()
+ } else {
+ match self.receiver.recv().await {
+ Some(msg) => {
+ self.peeked = Some(msg);
+ self.peeked.as_ref()
+ }
+ None => None,
+ }
+ }
+ }
+
+ fn try_recv(&mut self) -> Option<T> {
+ if let Some(msg) = self.peeked.take() {
+ Some(msg)
+ } else {
+ self.receiver.try_recv().ok()
+ }
+ }
+
+ fn try_peek(&mut self) -> Option<&T> {
+ if self.peeked.is_some() {
+ self.peeked.as_ref()
+ } else {
+ match self.receiver.try_recv() {
+ Ok(msg) => {
+ self.peeked = Some(msg);
+ self.peeked.as_ref()
+ }
+ Err(_) => None,
+ }
+ }
+ }
+
+ fn send(&mut self, msg: T) {
+ let _ = self.sender.send(msg);
+ }
+}