diff --git a/core/include/librmcs/data/datas.hpp b/core/include/librmcs/data/datas.hpp index daecdf6..0a3de49 100644 --- a/core/include/librmcs/data/datas.hpp +++ b/core/include/librmcs/data/datas.hpp @@ -30,6 +30,20 @@ enum class DataId : uint8_t { kUart3 = 14, kImu = 15, + + kSession = 16, +}; + +enum class SessionType : uint8_t { + kStart = 0, + kStartAck = 1, + kKeepalive = 2, + kKeepaliveAck = 3, +}; + +struct SessionControlView { + SessionType type; + uint32_t nonce; }; struct CanDataView { diff --git a/core/src/protocol/deserializer.cpp b/core/src/protocol/deserializer.cpp index 7a80ccd..f746abb 100644 --- a/core/src/protocol/deserializer.cpp +++ b/core/src/protocol/deserializer.cpp @@ -53,6 +53,7 @@ coroutine::LifoTask Deserializer::process_stream() { case FieldId::kUart3: success = co_await process_uart_field(id); break; case FieldId::kGpio: success = co_await process_gpio_field(id); break; case FieldId::kImu: success = co_await process_imu_field(id); break; + case FieldId::kSession: success = co_await process_session_field(id); break; default: break; } if (!success) @@ -317,4 +318,20 @@ coroutine::LifoTask Deserializer::process_imu_field(FieldId) { co_return true; } +coroutine::LifoTask Deserializer::process_session_field(FieldId) { + const auto* header_bytes = co_await peek_bytes(sizeof(SessionHeader)); + if (!header_bytes) [[unlikely]] + co_return false; + + auto header = SessionHeader::CRef{header_bytes}; + data::SessionControlView data_view{}; + data_view.type = header.get(); + data_view.nonce = header.get(); + consume_peeked(); + + callback_.session_control_deserialized_callback(data_view); + + co_return true; +} + } // namespace librmcs::core::protocol diff --git a/core/src/protocol/deserializer.hpp b/core/src/protocol/deserializer.hpp index ea0dc87..4e346a0 100644 --- a/core/src/protocol/deserializer.hpp +++ b/core/src/protocol/deserializer.hpp @@ -46,6 +46,8 @@ class DeserializeCallback { virtual void temperature_deserialized_callback(const data::TemperatureDataView& data) = 0; + virtual void session_control_deserialized_callback(const data::SessionControlView& data) = 0; + virtual void error_callback() = 0; }; @@ -120,6 +122,8 @@ class Deserializer : private coroutine::InlineLifoContext<1024> { coroutine::LifoTask process_imu_field(FieldId field_id); + coroutine::LifoTask process_session_field(FieldId field_id); + // Await until at least `size` contiguous bytes are available at the current read position. // Returns a pointer to a contiguous region of at least `size` bytes. // (from input buffer or pending cache) diff --git a/core/src/protocol/protocol.hpp b/core/src/protocol/protocol.hpp index c8be741..2fb690d 100644 --- a/core/src/protocol/protocol.hpp +++ b/core/src/protocol/protocol.hpp @@ -48,6 +48,11 @@ struct UartHeaderExtendedLayout { using DataLengthExtended = BitfieldMember<6, 10>; }; +struct SessionHeaderLayout { + using Type = BitfieldMember<4, 4, data::SessionType>; + using Nonce = BitfieldMember<8, 32, uint32_t>; +}; + } // namespace layouts struct FieldHeader @@ -82,6 +87,10 @@ struct UartHeaderExtended , layouts::UartHeaderLayout , layouts::UartHeaderExtendedLayout {}; +struct SessionHeader + : utility::Bitfield<5> + , layouts::SessionHeaderLayout {}; + struct GpioHeader : utility::Bitfield<2> { enum class PayloadEnum : uint8_t { kDigitalWriteLow = 0b0000, diff --git a/core/src/protocol/serializer.hpp b/core/src/protocol/serializer.hpp index 06943d7..db85988 100644 --- a/core/src/protocol/serializer.hpp +++ b/core/src/protocol/serializer.hpp @@ -380,6 +380,25 @@ class Serializer { return SerializeResult::kSuccess; } + SerializeResult write_session_control(const data::SessionControlView& view) noexcept { + const std::size_t required = required_session_size(); + + auto dst = buffer_.allocate(required); + LIBRMCS_VERIFY_LIKELY(!dst.empty(), SerializeResult::kBadAlloc); + utility::assert_debug(dst.size() == required); + std::byte* cursor = dst.data(); + + write_field_header(cursor, FieldId::kSession); + + auto header = SessionHeader::Ref(cursor); + cursor += sizeof(SessionHeader); + header.set(view.type); + header.set(view.nonce); + + utility::assert_debug(cursor == dst.data() + dst.size()); + return SerializeResult::kSuccess; + } + private: static constexpr bool use_extended_field_header(FieldId field_id) { utility::assert_debug(field_id != FieldId::kExtend); @@ -489,6 +508,13 @@ class Serializer { return total; } + static constexpr std::size_t required_session_size() { + constexpr std::size_t total = required_field_header_size(FieldId::kSession) + + sizeof(SessionHeader) - sizeof(FieldHeader); + utility::assert_debug(total <= kProtocolBufferSize); + return total; + } + SerializeBuffer& buffer_; }; diff --git a/firmware/c_board/app/src/usb/interrupt_safe_buffer.hpp b/firmware/c_board/app/src/usb/interrupt_safe_buffer.hpp index d45ea36..e52e35f 100644 --- a/firmware/c_board/app/src/usb/interrupt_safe_buffer.hpp +++ b/firmware/c_board/app/src/usb/interrupt_safe_buffer.hpp @@ -27,8 +27,9 @@ class InterruptSafeBuffer final std::span allocate(size_t size) noexcept override { core::utility::assert_debug(size <= core::protocol::kProtocolBufferSize); - if (is_locked_.test(std::memory_order::relaxed)) + if (is_locked_.test(std::memory_order::relaxed)) { return {}; + } auto out = out_.load(std::memory_order::relaxed); @@ -102,36 +103,27 @@ class InterruptSafeBuffer final } void clear() { + const bool was_locked = is_locked_.test_and_set(std::memory_order::relaxed); + core::utility::assert_debug(!was_locked); + auto in = in_.load(std::memory_order::relaxed); auto out = out_.load(std::memory_order::relaxed); auto readable = in - out; - if (!readable) - return; + if (readable) { + auto offset = out & kMask; + auto slice = std::min(readable, kBatchCount - offset); - auto offset = out & kMask; - auto slice = std::min(readable, kBatchCount - offset); + for (size_t i = 0; i < slice; i++) + batches_[offset + i].reset(); + for (size_t i = 0; i < readable - slice; i++) + batches_[i].reset(); - for (size_t i = 0; i < slice; i++) - batches_[offset + i].reset(); - for (size_t i = 0; i < readable - slice; i++) - batches_[i].reset(); - - std::atomic_signal_fence(std::memory_order::release); - out_.store(in, std::memory_order::relaxed); - } - - bool try_lock() { return !is_locked_.test_and_set(std::memory_order::relaxed); } - - bool try_unlock_and_clear() { - if (!is_locked_.test(std::memory_order::relaxed)) - return false; + std::atomic_signal_fence(std::memory_order::release); + out_.store(in, std::memory_order::relaxed); + } - // Unlocking drops stale queued batches from the last not-ready cycle before - // new ISR writes are accepted. - clear(); is_locked_.clear(std::memory_order::relaxed); - return true; } private: diff --git a/firmware/c_board/app/src/usb/vendor.cpp b/firmware/c_board/app/src/usb/vendor.cpp index 58f06f4..4f53c0a 100644 --- a/firmware/c_board/app/src/usb/vendor.cpp +++ b/firmware/c_board/app/src/usb/vendor.cpp @@ -31,13 +31,20 @@ void tud_dfu_runtime_reboot_to_dfu_cb() { NVIC_SystemReset(); } -void tud_suspend_cb(bool remote_wakeup_en) { (void)remote_wakeup_en; } +void tud_suspend_cb(bool remote_wakeup_en) { + (void)remote_wakeup_en; + usb::vendor->deactivate_session(); + usb::vendor->finish_downlink_transfer(); +} void tud_resume_cb() {} void tud_mount_cb() {} -void tud_umount_cb() {} +void tud_umount_cb() { + usb::vendor->deactivate_session(); + usb::vendor->finish_downlink_transfer(); +} } // extern "C" diff --git a/firmware/c_board/app/src/usb/vendor.hpp b/firmware/c_board/app/src/usb/vendor.hpp index e882c3c..b8ac443 100644 --- a/firmware/c_board/app/src/usb/vendor.hpp +++ b/firmware/c_board/app/src/usb/vendor.hpp @@ -1,13 +1,13 @@ #pragma once #include +#include #include #include #include #include #include -#include #include #include "core/include/librmcs/data/datas.hpp" @@ -20,6 +20,7 @@ #include "core/src/utility/immovable.hpp" #include "firmware/c_board/app/src/can/can.hpp" #include "firmware/c_board/app/src/gpio/gpio.hpp" +#include "firmware/c_board/app/src/timer/timer.hpp" #include "firmware/c_board/app/src/uart/uart.hpp" #include "firmware/c_board/app/src/usb/interrupt_safe_buffer.hpp" #include "firmware/c_board/app/src/usb/usb_descriptors.hpp" @@ -35,6 +36,7 @@ class Vendor static constexpr size_t kMaxPacketSize = 64; static constexpr std::size_t kGpioChannelCount = std::size(spec::c_board::kGpioDescriptors); + static constexpr auto kSessionLease = std::chrono::milliseconds{1000}; Vendor() { usb::usb_descriptors.init(); @@ -43,15 +45,20 @@ class Vendor core::protocol::Serializer& serializer() { return serializer_; } + void deactivate_session() { session_established_ = false; } + void handle_downlink(std::span buffer, bool finished) { deserializer_.feed(buffer); if (finished) deserializer_.finish_transfer(); } + void finish_downlink_transfer() { deserializer_.finish_transfer(); } + bool try_transmit() { - if (!tud_ready()) { - transmit_buffer_.try_lock(); + refresh_session_state(); + + if (!session_established_) { return false; } @@ -59,7 +66,6 @@ class Vendor return false; if (!transmitting_batch_) { - transmit_buffer_.try_unlock_and_clear(); transmitting_batch_ = transmit_buffer_.pop_batch(); } if (!transmitting_batch_) @@ -87,8 +93,23 @@ class Vendor } private: + void activate_session(uint32_t nonce) { + if (transmitting_batch_) { + transmit_buffer_.release_batch(transmitting_batch_); + transmitting_batch_ = nullptr; + transmitted_size_ = 0; + } + transmit_buffer_.clear(); + + current_session_nonce_ = nonce; + last_session_refresh_ = timer::timer->timepoint48(); + session_established_ = true; + } + void can_deserialized_callback( core::protocol::FieldId id, const data::CanDataView& data) override { + if (!session_established_) + return; switch (id) { case data::DataId::kCan1: can::can1->handle_downlink(data); break; case data::DataId::kCan2: can::can2->handle_downlink(data); break; @@ -98,6 +119,8 @@ class Vendor void uart_deserialized_callback( core::protocol::FieldId id, const data::UartDataView& data) override { + if (!session_established_) + return; switch (id) { case data::DataId::kUart1: uart::uart1->handle_downlink(data); break; case data::DataId::kUart2: uart::uart2->handle_downlink(data); break; @@ -107,6 +130,8 @@ class Vendor void gpio_digital_data_deserialized_callback( uint8_t channel_index, const data::GpioDigitalDataView& data) override { + if (!session_established_) + return; if (channel_index >= kGpioChannelCount) return; @@ -119,6 +144,8 @@ class Vendor void gpio_analog_data_deserialized_callback( uint8_t channel_index, const data::GpioAnalogDataView& data) override { + if (!session_established_) + return; if (channel_index >= kGpioChannelCount) return; @@ -131,6 +158,8 @@ class Vendor void gpio_digital_read_config_deserialized_callback( uint8_t channel_index, const data::GpioReadConfigView& data) override { + if (!session_established_) + return; if (channel_index >= kGpioChannelCount) return; @@ -159,7 +188,52 @@ class Vendor (void)data; } - void error_callback() override { core::utility::assert_failed_always(); } + void session_control_deserialized_callback(const data::SessionControlView& data) override { + switch (data.type) { + case data::SessionType::kStart: { + const bool same_session = session_established_ && data.nonce == current_session_nonce_; + + if (!same_session) + activate_session(data.nonce); + else + last_session_refresh_ = timer::timer->timepoint48(); + + const auto result = serializer_.write_session_control( + {.type = data::SessionType::kStartAck, .nonce = data.nonce}); + core::utility::assert_always( + result != core::protocol::Serializer::SerializeResult::kInvalidArgument); + break; + } + case data::SessionType::kKeepalive: + if (!session_established_ || data.nonce != current_session_nonce_) + return; + + last_session_refresh_ = timer::timer->timepoint48(); + { + const auto result = serializer_.write_session_control( + {.type = data::SessionType::kKeepaliveAck, .nonce = data.nonce}); + core::utility::assert_always( + result != core::protocol::Serializer::SerializeResult::kInvalidArgument); + } + break; + default: return; + } + } + + void error_callback() override { + // TODO: Report USB downlink deserialization errors through a dedicated error path. + } + + void refresh_session_state() { + if (!session_established_) + return; + + if (!timer::timer->check_expired( + last_session_refresh_, timer::Timer::to_duration48_checked(kSessionLease))) + return; + + deactivate_session(); + } core::protocol::Deserializer deserializer_{*this}; @@ -168,6 +242,9 @@ class Vendor const InterruptSafeBuffer::Batch* transmitting_batch_ = nullptr; size_t transmitted_size_ = 0; + bool session_established_ = false; + uint32_t current_session_nonce_ = 0; + timer::Timer::TimePoint48 last_session_refresh_ = timer::Timer::TimePoint48::min(); }; inline constinit Vendor::Lazy vendor; diff --git a/firmware/rmcs_board/app/src/usb/interrupt_safe_buffer.hpp b/firmware/rmcs_board/app/src/usb/interrupt_safe_buffer.hpp index 724b4b2..b8bdde6 100644 --- a/firmware/rmcs_board/app/src/usb/interrupt_safe_buffer.hpp +++ b/firmware/rmcs_board/app/src/usb/interrupt_safe_buffer.hpp @@ -27,8 +27,9 @@ class InterruptSafeBuffer final std::span allocate(size_t size) noexcept override { core::utility::assert_debug(size <= core::protocol::kProtocolBufferSize); - if (is_locked_.test(std::memory_order::relaxed)) + if (is_locked_.test(std::memory_order::relaxed)) { return {}; + } auto out = out_.load(std::memory_order::relaxed); @@ -102,36 +103,27 @@ class InterruptSafeBuffer final } void clear() { + const bool was_locked = is_locked_.test_and_set(std::memory_order::relaxed); + core::utility::assert_debug(!was_locked); + auto in = in_.load(std::memory_order::relaxed); auto out = out_.load(std::memory_order::relaxed); auto readable = in - out; - if (!readable) - return; + if (readable) { + auto offset = out & kMask; + auto slice = std::min(readable, kBatchCount - offset); - auto offset = out & kMask; - auto slice = std::min(readable, kBatchCount - offset); + for (size_t i = 0; i < slice; i++) + batches_[offset + i].reset(); + for (size_t i = 0; i < readable - slice; i++) + batches_[i].reset(); - for (size_t i = 0; i < slice; i++) - batches_[offset + i].reset(); - for (size_t i = 0; i < readable - slice; i++) - batches_[i].reset(); - - std::atomic_signal_fence(std::memory_order_release); - out_.store(in, std::memory_order::relaxed); - } - - bool try_lock() { return !is_locked_.test_and_set(std::memory_order::relaxed); } - - bool try_unlock_and_clear() { - if (!is_locked_.test(std::memory_order::relaxed)) - return false; + std::atomic_signal_fence(std::memory_order_release); + out_.store(in, std::memory_order::relaxed); + } - // Unlocking drops stale queued batches from the last not-ready cycle before - // new ISR writes are accepted. - clear(); is_locked_.clear(std::memory_order::relaxed); - return true; } private: diff --git a/firmware/rmcs_board/app/src/usb/vendor.cpp b/firmware/rmcs_board/app/src/usb/vendor.cpp index 544188d..4edd4f6 100644 --- a/firmware/rmcs_board/app/src/usb/vendor.cpp +++ b/firmware/rmcs_board/app/src/usb/vendor.cpp @@ -27,13 +27,20 @@ void tud_vendor_rx_cb(uint8_t itf, const uint8_t* buffer, uint16_t size) { void tud_dfu_runtime_reboot_to_dfu_cb() { boot::BootMailbox::reboot_to_bootloader(); } -void tud_suspend_cb(bool remote_wakeup_en) { (void)remote_wakeup_en; } +void tud_suspend_cb(bool remote_wakeup_en) { + (void)remote_wakeup_en; + usb::vendor->deactivate_session(); + usb::vendor->finish_downlink_transfer(); +} void tud_resume_cb() {} void tud_mount_cb() {} -void tud_umount_cb() {} +void tud_umount_cb() { + usb::vendor->deactivate_session(); + usb::vendor->finish_downlink_transfer(); +} } // extern "C" diff --git a/firmware/rmcs_board/app/src/usb/vendor.hpp b/firmware/rmcs_board/app/src/usb/vendor.hpp index 55aa5f3..9a11d4b 100644 --- a/firmware/rmcs_board/app/src/usb/vendor.hpp +++ b/firmware/rmcs_board/app/src/usb/vendor.hpp @@ -20,6 +20,7 @@ #include "core/src/utility/immovable.hpp" #include "firmware/rmcs_board/app/src/can/can.hpp" #include "firmware/rmcs_board/app/src/gpio/gpio.hpp" +#include "firmware/rmcs_board/app/src/timer/timer.hpp" #include "firmware/rmcs_board/app/src/uart/uart.hpp" #include "firmware/rmcs_board/app/src/usb/interrupt_safe_buffer.hpp" #include "firmware/rmcs_board/app/src/usb/usb_descriptors.hpp" @@ -33,6 +34,8 @@ class Vendor public: using Lazy = utility::Lazy; + static constexpr uint64_t kSessionLeaseQuarterUs = 4'000'000; + Vendor() { usb::usb_descriptors.init(); @@ -46,15 +49,20 @@ class Vendor core::protocol::Serializer& serializer() { return serializer_; } + void deactivate_session() { session_established_ = false; } + void handle_downlink(std::span buffer, bool finished) { deserializer_.feed(buffer); if (finished) deserializer_.finish_transfer(); } + void finish_downlink_transfer() { deserializer_.finish_transfer(); } + bool try_transmit() { - if (!tud_ready()) { - transmit_buffer_.try_lock(); + refresh_session_state(); + + if (!session_established_) { return false; } @@ -62,7 +70,6 @@ class Vendor return false; if (!transmitting_batch_) { - transmit_buffer_.try_unlock_and_clear(); transmitting_batch_ = transmit_buffer_.pop_batch(); } if (!transmitting_batch_) @@ -91,8 +98,23 @@ class Vendor } private: + void activate_session(uint32_t nonce) { + if (transmitting_batch_) { + transmit_buffer_.release_batch(transmitting_batch_); + transmitting_batch_ = nullptr; + transmitted_size_ = 0; + } + transmit_buffer_.clear(); + + current_session_nonce_ = nonce; + last_session_refresh_quarter_us_ = timer::Timer::timestamp64_quarter_us(); + session_established_ = true; + } + void can_deserialized_callback( core::protocol::FieldId id, const data::CanDataView& data) override { + if (!session_established_) + return; switch (id) { case data::DataId::kCan0: can::can_array[0]->handle_downlink(data); break; case data::DataId::kCan1: can::can_array[1]->handle_downlink(data); break; @@ -104,6 +126,8 @@ class Vendor void uart_deserialized_callback( core::protocol::FieldId id, const data::UartDataView& data) override { + if (!session_established_) + return; switch (id) { case data::DataId::kUart0: uart::uart_array[0]->handle_downlink(data); break; case data::DataId::kUart1: uart::uart_array[1]->handle_downlink(data); break; @@ -119,6 +143,8 @@ class Vendor void gpio_digital_data_deserialized_callback( uint8_t channel_index, const data::GpioDigitalDataView& data) override { + if (!session_established_) + return; if (channel_index >= board::spec::kGpioDescriptors.size()) return; @@ -131,6 +157,8 @@ class Vendor void gpio_analog_data_deserialized_callback( uint8_t channel_index, const data::GpioAnalogDataView& data) override { + if (!session_established_) + return; if (channel_index >= board::spec::kGpioDescriptors.size()) return; @@ -143,6 +171,8 @@ class Vendor void gpio_digital_read_config_deserialized_callback( uint8_t channel_index, const data::GpioReadConfigView& data) override { + if (!session_established_) + return; if (channel_index >= board::spec::kGpioDescriptors.size()) return; @@ -171,7 +201,52 @@ class Vendor (void)data; } - void error_callback() override { core::utility::assert_failed_always(); } + void session_control_deserialized_callback(const data::SessionControlView& data) override { + switch (data.type) { + case data::SessionType::kStart: { + const bool same_session = session_established_ && data.nonce == current_session_nonce_; + + if (!same_session) + activate_session(data.nonce); + else + last_session_refresh_quarter_us_ = timer::Timer::timestamp64_quarter_us(); + + const auto result = serializer_.write_session_control( + {.type = data::SessionType::kStartAck, .nonce = data.nonce}); + core::utility::assert_always( + result != core::protocol::Serializer::SerializeResult::kInvalidArgument); + break; + } + case data::SessionType::kKeepalive: + if (!session_established_ || data.nonce != current_session_nonce_) + return; + + last_session_refresh_quarter_us_ = timer::Timer::timestamp64_quarter_us(); + { + const auto result = serializer_.write_session_control( + {.type = data::SessionType::kKeepaliveAck, .nonce = data.nonce}); + core::utility::assert_always( + result != core::protocol::Serializer::SerializeResult::kInvalidArgument); + } + break; + default: return; + } + } + + void error_callback() override { + // TODO: Report USB downlink deserialization errors through a dedicated error path. + } + + void refresh_session_state() { + if (!session_established_) + return; + + const uint64_t now = timer::Timer::timestamp64_quarter_us(); + if (now - last_session_refresh_quarter_us_ < kSessionLeaseQuarterUs) + return; + + deactivate_session(); + } core::protocol::Deserializer deserializer_{*this}; @@ -180,6 +255,9 @@ class Vendor const InterruptSafeBuffer::Batch* transmitting_batch_ = nullptr; size_t transmitted_size_ = 0; + bool session_established_ = false; + uint32_t current_session_nonce_ = 0; + uint64_t last_session_refresh_quarter_us_ = 0; }; inline constinit Vendor::Lazy vendor; diff --git a/host/src/protocol/handler.cpp b/host/src/protocol/handler.cpp index be9372d..b7bc504 100644 --- a/host/src/protocol/handler.cpp +++ b/host/src/protocol/handler.cpp @@ -1,12 +1,20 @@ #include "librmcs/protocol/handler.hpp" +#include +#include +#include #include #include #include +#include #include +#include #include +#include #include +#include #include +#include #include #include "core/src/protocol/deserializer.hpp" @@ -23,38 +31,63 @@ namespace librmcs::host::protocol { class Handler::Impl : public core::protocol::DeserializeCallback { public: + static constexpr auto kSessionAckTimeout = std::chrono::milliseconds{200}; + static constexpr size_t kSessionAckRetryCount = 5; + static constexpr auto kSessionRefreshInterval = std::chrono::milliseconds{250}; + explicit Impl(std::unique_ptr transport, data::DataCallback& callback) - : transport_(std::move(transport)) - , callback_(callback) - , deserializer_(*this) { + : callback_(callback) + , deserializer_(*this) + , expected_session_nonce_(generate_session_nonce()) + , transport_(std::move(transport)) { transport_->receive([this](std::span buffer) { // Operating system automatically assembles the packet deserializer_.feed(buffer); deserializer_.finish_transfer(); }); + + establish_session(); + keepalive_thread_ = std::thread{[this] { keepalive_loop(); }}; + } + + ~Impl() override { + stop_keepalive_.store(true, std::memory_order_relaxed); + session_cv_.notify_all(); + if (keepalive_thread_.joinable()) + keepalive_thread_.join(); + + transport_.reset(); } PacketBuilder start_transmit() { return PacketBuilder{transport_.get()}; } void can_deserialized_callback( core::protocol::FieldId id, const data::CanDataView& data) override { + if (!session_established()) + return; if (!callback_.can_receive_callback(id, data)) logging::get_logger().error("Unexpected can field id: ", static_cast(id)); } void uart_deserialized_callback( core::protocol::FieldId id, const data::UartDataView& data) override { + if (!session_established()) + return; if (!callback_.uart_receive_callback(id, data)) logging::get_logger().error("Unexpected uart field id: ", static_cast(id)); } void gpio_digital_data_deserialized_callback( uint8_t channel_index, const data::GpioDigitalDataView& data) override { + if (!session_established()) + return; callback_.gpio_digital_read_result_callback(channel_index, data); } void gpio_analog_data_deserialized_callback( uint8_t channel_index, const data::GpioAnalogDataView& data) override { + if (!session_established()) + return; callback_.gpio_analog_read_result_callback(channel_index, data); } @@ -73,25 +106,160 @@ class Handler::Impl : public core::protocol::DeserializeCallback { } void accelerometer_deserialized_callback(const data::AccelerometerDataView& data) override { + if (!session_established()) + return; callback_.accelerometer_receive_callback(data); } void gyroscope_deserialized_callback(const data::GyroscopeDataView& data) override { + if (!session_established()) + return; callback_.gyroscope_receive_callback(data); } void temperature_deserialized_callback(const data::TemperatureDataView& data) override { + if (!session_established()) + return; callback_.temperature_receive_callback(data); } + void session_control_deserialized_callback(const data::SessionControlView& data) override { + if (data.nonce != expected_session_nonce_) + return; + + bool notify = false; + { + const std::scoped_lock guard{session_mutex_}; + switch (data.type) { + case data::SessionType::kStartAck: + session_established_.store(true, std::memory_order_relaxed); + ++session_start_ack_count_; + notify = true; + break; + case data::SessionType::kKeepaliveAck: + ++session_keepalive_ack_count_; + notify = true; + break; + default: break; + } + } + if (notify) + session_cv_.notify_all(); + } + void error_callback() override { logging::get_logger().error("Deserializer encountered an error while parsing input"); } private: - std::unique_ptr transport_; + [[nodiscard]] bool session_established() const { + return session_established_.load(std::memory_order_relaxed); + } + + void establish_session() { + for (size_t attempt = 0; attempt < kSessionAckRetryCount; ++attempt) { + uint64_t previous_session_start_ack_count = 0; + { + const std::scoped_lock guard{session_mutex_}; + previous_session_start_ack_count = session_start_ack_count_; + } + + send_session_start(); + + std::unique_lock lock{session_mutex_}; + if (session_cv_.wait_for( + lock, kSessionAckTimeout, [this, previous_session_start_ack_count] { + return session_start_ack_count_ > previous_session_start_ack_count; + })) { + return; + } + } + + throw std::runtime_error{"Timed out waiting for SESSION_ACK"}; + } + + void send_session_start() { send_session_control(data::SessionType::kStart, "Session Start"); } + + void refresh_session() { + for (size_t attempt = 0; attempt < kSessionAckRetryCount; ++attempt) { + uint64_t previous_session_keepalive_ack_count = 0; + { + const std::scoped_lock guard{session_mutex_}; + previous_session_keepalive_ack_count = session_keepalive_ack_count_; + } + + send_session_keepalive(); + + std::unique_lock lock{session_mutex_}; + if (session_cv_.wait_for( + lock, kSessionAckTimeout, [this, previous_session_keepalive_ack_count] { + return stop_keepalive_.load(std::memory_order_relaxed) + || session_keepalive_ack_count_ > previous_session_keepalive_ack_count; + })) { + return; + } + } + + throw std::runtime_error{"Timed out waiting for SESSION_KEEPALIVE_ACK"}; + } + + void send_session_keepalive() { + send_session_control(data::SessionType::kKeepalive, "Session Keepalive"); + } + + void send_session_control(data::SessionType type, std::string_view operation_name) { + core::protocol::Serializer::SerializeResult result; + { + StreamBuffer buffer{*transport_}; + core::protocol::Serializer serializer{buffer}; + result = + serializer.write_session_control({.type = type, .nonce = expected_session_nonce_}); + } + + core::utility::assert_debug( + result != core::protocol::Serializer::SerializeResult::kInvalidArgument); + if (result == core::protocol::Serializer::SerializeResult::kBadAlloc) [[unlikely]] + throw std::runtime_error( + std::string{"Failed to transmit "} + std::string{operation_name} + + ": Transmit buffer unavailable (acquire failed)"); + } + + void keepalive_loop() { + while (!stop_keepalive_.load(std::memory_order_relaxed)) { + std::this_thread::sleep_for(kSessionRefreshInterval); + if (stop_keepalive_.load(std::memory_order_relaxed)) + break; + + try { + refresh_session(); + } catch (const std::exception& exception) { + logging::get_logger().error( + "Failed to refresh session: {}. Terminating...", exception.what()); + std::terminate(); + } + } + } + + static uint32_t generate_session_nonce() { + std::random_device random_device; + std::uniform_int_distribution distribution; + return distribution(random_device); + } + data::DataCallback& callback_; core::protocol::Deserializer deserializer_; + + mutable std::mutex session_mutex_; + std::condition_variable session_cv_; + std::atomic session_established_{false}; + uint64_t session_start_ack_count_ = 0; + uint64_t session_keepalive_ack_count_ = 0; + uint32_t expected_session_nonce_ = 0; + + std::unique_ptr transport_; + + std::atomic stop_keepalive_{false}; + std::thread keepalive_thread_; }; namespace { diff --git a/host/src/transport/usb/usb.cpp b/host/src/transport/usb/usb.cpp index ca79d56..cc93fc2 100644 --- a/host/src/transport/usb/usb.cpp +++ b/host/src/transport/usb/usb.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include @@ -227,7 +226,8 @@ class Usb : public Transport { wrapper->self_.usb_transmit_complete_callback(wrapper); }, wrapper, 0); - transfer->flags = libusb_transfer_flags::LIBUSB_TRANSFER_FREE_BUFFER; + transfer->flags = libusb_transfer_flags::LIBUSB_TRANSFER_FREE_BUFFER + | libusb_transfer_flags::LIBUSB_TRANSFER_ADD_ZERO_PACKET; } } catch (...) { for (auto& wrapper : transmit_transfers) { @@ -294,11 +294,7 @@ class Usb : public Transport { return; } - const auto now = std::chrono::steady_clock::now(); - const bool should_drop = now > last_rx_callback_timepoint_ + std::chrono::seconds{1}; - last_rx_callback_timepoint_ = now; - - if (!should_drop && transfer->actual_length > 0) { + if (transfer->actual_length > 0) { const auto* first = reinterpret_cast(transfer->buffer); const auto size = static_cast(transfer->actual_length); receive_callback_({first, size}); @@ -363,8 +359,6 @@ class Usb : public Transport { std::mutex transmit_transfer_pop_mutex_, transmit_transfer_push_mutex_; std::function)> receive_callback_; - std::chrono::steady_clock::time_point last_rx_callback_timepoint_ = - std::chrono::steady_clock::time_point::min(); }; std::unique_ptr create_transport(