//! Client implementation for the WireGuard broker protocol.
//!
//! This module provides a client implementation that communicates with a WireGuard broker server
//! using a binary protocol. The client handles serialization and deserialization of messages,
//! error handling, and the core interaction flow.
//!
//! # Examples
//!
//! ```
//! use rosenpass_wireguard_broker::api::client::{BrokerClient, BrokerClientIo};
//! #[derive(Debug)]
//! struct MyIo;
//!
//! impl BrokerClientIo for MyIo {
//!     type SendError = std::io::Error;
//!     type RecvError = std::io::Error;
//!
//!     fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError> {
//!         // Implement sending logic
//!         Ok(())
//!     }
//!
//!     fn recv_msg(&mut self) -> Result<Option<&[u8]>, Self::RecvError> {
//!         // Implement receiving logic
//!         Ok(None)
//!     }
//! }
//!
//! // Create client with custom IO implementation
//! let mut client = BrokerClient::new(MyIo);
//! assert!(client.poll_response().unwrap().is_none());
//! ```
//!
//! # Protocol
//!
//! The client implements a simple request-response protocol for setting WireGuard pre-shared keys.
//! Messages are serialized using a binary format defined in the [`crate::api::msgs`] module.

use std::{borrow::BorrowMut, fmt::Debug};

use crate::{
    api::{
        config::NetworkBrokerConfig,
        msgs::{self, REQUEST_MSG_BUFFER_SIZE},
    },
    SerializedBrokerConfig, WireGuardBroker,
};

use super::{
    config::NetworkBrokerConfigErr,
    msgs::{Envelope, SetPskResponse},
};

/// Error type for polling responses from the broker server.
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)] pub enum BrokerClientPollResponseError { /// An IO error occurred while receiving the response #[error(transparent)] IoError(RecvError), /// The received message was invalid or malformed #[error("Invalid message.")] InvalidMessage, } impl From for BrokerClientPollResponseError { fn from(value: msgs::InvalidMessageTypeError) -> Self { let msgs::InvalidMessageTypeError = value; // Assert that this is a unit type BrokerClientPollResponseError::::InvalidMessage } } /// Helper function that wraps a receive error into a `BrokerClientPollResponseError::IoError` fn io_poller(e: RecvError) -> BrokerClientPollResponseError { BrokerClientPollResponseError::::IoError(e) } /// Helper function that returns a `BrokerClientPollResponseError::InvalidMessage` error fn invalid_msg_poller() -> BrokerClientPollResponseError { BrokerClientPollResponseError::::InvalidMessage } /// Error type for setting pre-shared keys through the broker client. #[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)] pub enum BrokerClientSetPskError { /// Error encoding or decoding the message #[error("Error with encoding/decoding message")] MsgError, /// Error in the broker configuration #[error("Network Broker Config error: {0}")] BrokerError(NetworkBrokerConfigErr), /// IO error while sending the request #[error(transparent)] IoError(SendError), /// Interface name exceeds maximum length #[error("Interface name out of bounds")] IfaceOutOfBounds, } /// Trait defining the IO operations required by the broker client. /// /// Implementors must provide methods for sending and receiving binary messages. 
pub trait BrokerClientIo { /// Error type returned by send operations type SendError; /// Error type returned by receive operations type RecvError; /// Send a binary message fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError>; /// Receive a binary message, returning None if no message is available fn recv_msg(&mut self) -> Result, Self::RecvError>; } /// Client for interacting with a WireGuard broker server. /// /// The client handles the protocol-level communication with the server, /// including message serialization and response handling. #[derive(Debug)] pub struct BrokerClient where Io: BrokerClientIo + Debug, { io: Io, } impl BrokerClient where Io: BrokerClientIo + Debug, { /// Creates a new `BrokerClient` with the given IO implementation. pub fn new(io: Io) -> Self { Self { io } } /// Returns a reference to the underlying IO implementation. pub fn io(&self) -> &Io { &self.io } /// Returns a mutable reference to the underlying IO implementation. pub fn io_mut(&mut self) -> &mut Io { &mut self.io } /// Polls for a response from the broker server. /// /// This method attempts to receive and parse a SetPsk response message from the server. /// If no message is available, returns `Ok(None)`. If a message is received, it is /// parsed and validated before being returned as `Ok(Some(result))`. /// /// # Returns /// - `Ok(Some(result))` if a valid response was received /// - `Ok(None)` if no message was available /// - `Err(BrokerClientPollResponseError)` if an error occurred during receiving or parsing /// /// # Errors /// Returns an error if: /// - An IO error occurs while receiving the message /// - The received message is invalid or malformed /// - The message type is incorrect pub fn poll_response( &mut self, ) -> Result, BrokerClientPollResponseError> { let res: &[u8] = match self.io.borrow_mut().recv_msg().map_err(io_poller)? 
{ Some(r) => r, None => return Ok(None), }; let typ = res.first().ok_or(invalid_msg_poller())?; let typ = msgs::MsgType::try_from(*typ)?; let msgs::MsgType::SetPsk = typ; // Assert type let res = zerocopy::Ref::<&[u8], Envelope>::new(res) .ok_or(invalid_msg_poller())?; let res: &msgs::SetPskResponse = &res.payload; let res: msgs::SetPskResponseReturnCode = res .return_code .try_into() .map_err(|_| invalid_msg_poller())?; let res: msgs::SetPskResult = res.into(); Ok(Some(res)) } } impl WireGuardBroker for BrokerClient where Io: BrokerClientIo + Debug, { type Error = BrokerClientSetPskError; fn set_psk(&mut self, config: SerializedBrokerConfig) -> Result<(), Self::Error> { let config: Result = config.try_into(); let config = config.map_err(BrokerClientSetPskError::BrokerError)?; use BrokerClientSetPskError::*; const BUF_SIZE: usize = REQUEST_MSG_BUFFER_SIZE; // Allocate message let mut req = [0u8; BUF_SIZE]; // Construct message view let mut req = zerocopy::Ref::<&mut [u8], Envelope>::new(&mut req) .ok_or(MsgError)?; // Populate envelope req.msg_type = msgs::MsgType::SetPsk as u8; { // Derived payload let req = &mut req.payload; // Populate payload req.peer_id.copy_from_slice(&config.peer_id.value); req.psk.copy_from_slice(config.psk.secret()); req.set_iface(config.iface.as_ref()) .ok_or(IfaceOutOfBounds)?; } // Send message self.io .borrow_mut() .send_msg(req.bytes()) .map_err(IoError)?; Ok(()) } } #[cfg(test)] mod tests { use super::*; use msgs::{MsgType, SetPskError, SetPskResponseReturnCode}; // Mock IO implementation for testing #[derive(Debug)] struct MockIo { recv_data: Vec, } impl MockIo { fn new() -> Self { Self { recv_data: Vec::new(), } } fn set_recv_data(&mut self, data: Option>) { self.recv_data = data.unwrap_or_default(); } } impl BrokerClientIo for MockIo { type SendError = std::io::Error; type RecvError = std::io::Error; fn send_msg(&mut self, _buf: &[u8]) -> Result<(), Self::SendError> { Ok(()) } fn recv_msg(&mut self) -> Result, Self::RecvError> { 
if self.recv_data.is_empty() { Ok(None) } else { Ok(Some(&self.recv_data)) } } } fn create_response_msg(return_code: u8) -> Vec { let mut msg = vec![ MsgType::SetPsk as u8, // msg_type 0, 0, 0, // reserved bytes ]; msg.push(return_code); // return_code msg } #[test] fn test_poll_response_no_message() { let io = MockIo::new(); let mut client = BrokerClient::new(io); assert_eq!(client.poll_response().unwrap(), None); } #[test] fn test_poll_response_success() { let mut io = MockIo::new(); io.set_recv_data(Some(create_response_msg( SetPskResponseReturnCode::Success as u8, ))); let mut client = BrokerClient::new(io); assert_eq!(client.poll_response().unwrap(), Some(Ok(()))); } #[test] fn test_poll_response_no_such_peer() { let mut io = MockIo::new(); io.set_recv_data(Some(create_response_msg( SetPskResponseReturnCode::NoSuchPeer as u8, ))); let mut client = BrokerClient::new(io); assert_eq!( client.poll_response().unwrap(), Some(Err(SetPskError::NoSuchPeer)) ); } #[test] fn test_poll_response_invalid_message_type() { let mut io = MockIo::new(); io.set_recv_data(Some(vec![0xFF, 0, 0, 0, 0])); // Invalid message type let mut client = BrokerClient::new(io); assert!(matches!( client.poll_response(), Err(BrokerClientPollResponseError::InvalidMessage) )); } #[test] fn test_poll_response_invalid_return_code() { let mut io = MockIo::new(); io.set_recv_data(Some(create_response_msg(0xFF))); // Invalid return code let mut client = BrokerClient::new(io); assert!(matches!( client.poll_response(), Err(BrokerClientPollResponseError::InvalidMessage) )); } }