Skip to main content

leodos_protocols/transport/srspp/api/tokio/
sender.rs

1use zerocopy::{Immutable, IntoBytes};
2
3use crate::network::{NetworkRead, NetworkWrite};
4use crate::network::isl::address::Address;
5use crate::network::spp::SequenceCount;
6use crate::transport::srspp::machine::sender::SenderAction;
7use crate::transport::srspp::machine::sender::SenderActions;
8use crate::transport::srspp::machine::sender::SenderConfig;
9use crate::transport::srspp::machine::sender::SenderEvent;
10use crate::transport::srspp::machine::sender::SenderMachine;
11use crate::transport::srspp::packet::SrsppDataPacket;
12use crate::transport::srspp::packet::SrsppAckPacket;
13use crate::transport::srspp::packet::SrsppPacket;
14use crate::transport::srspp::packet::SrsppType;
15use crate::transport::srspp::rto::RtoPolicy;
16use std::collections::HashMap;
17use tokio::time::Instant;
18
19use super::SrsppError;
20use super::sleep_until;
21use super::ticks_to_duration;
22
23/// Async srspp sender.
24///
25/// Sends messages reliably over the link, handling segmentation and retransmission.
26/// Receives ACKs from the remote receiver.
27pub struct SrsppSender<L: NetworkWrite + NetworkRead<Error = <L as NetworkWrite>::Error>, P: RtoPolicy, const WIN: usize, const BUF: usize, const MTU: usize> {
28    /// Network link for sending data and receiving ACKs.
29    link: L,
30    /// Policy for computing retransmission timeouts.
31    rto_policy: P,
32    /// Sender state machine.
33    machine: SenderMachine<WIN, BUF, MTU>,
34    /// Pending actions from the state machine.
35    actions: SenderActions,
36    /// Per-packet retransmission deadlines keyed by sequence number.
37    retransmit_timers: HashMap<u16, Instant>,
38    /// Tick rate used to convert RTO ticks to durations.
39    ticks_per_sec: u32,
40    /// Instant when this sender was created, used for elapsed time.
41    start_time: Instant,
42    /// Buffer for receiving ACK packets from the link.
43    recv_buffer: [u8; MTU],
44    /// Buffer for building outgoing data packets.
45    tx_buffer: [u8; MTU],
46}
47
48impl<L: NetworkWrite + NetworkRead<Error = <L as NetworkWrite>::Error>, P: RtoPolicy, const WIN: usize, const BUF: usize, const MTU: usize>
49    SrsppSender<L, P, WIN, BUF, MTU>
50{
51    /// Create a new sender.
52    pub fn new(config: SenderConfig, link: L, rto_policy: P, ticks_per_sec: u32) -> Self {
53        Self {
54            link,
55            rto_policy,
56            machine: SenderMachine::new(config),
57            actions: SenderActions::new(),
58            retransmit_timers: HashMap::new(),
59            ticks_per_sec,
60            start_time: Instant::now(),
61            recv_buffer: [0u8; MTU],
62            tx_buffer: [0u8; MTU],
63        }
64    }
65
66    /// Send a message.
67    ///
68    /// The message is segmented if necessary and queued for transmission.
69    /// This returns when all packets have been transmitted (but not necessarily ACKed).
70    ///
71    /// For guaranteed delivery, call `flush()` after sending.
72    pub async fn send(&mut self, target: Address, data: &(impl IntoBytes + Immutable + ?Sized)) -> Result<(), SrsppError> {
73        let data = data.as_bytes();
74        self.machine
75            .handle(SenderEvent::SendRequest { target, data }, &mut self.actions)?;
76
77        self.process_actions().await?;
78        Ok(())
79    }
80
81    /// Wait for all sent data to be acknowledged.
82    pub async fn flush(&mut self) -> Result<(), SrsppError> {
83        while !self.machine.is_idle() {
84            self.poll().await?;
85        }
86        Ok(())
87    }
88
89    /// Poll for incoming ACKs and handle timeouts.
90    ///
91    /// Call this periodically if you want to process ACKs without blocking on flush.
92    pub async fn poll(&mut self) -> Result<(), SrsppError> {
93        let next_deadline = self.next_timer_deadline();
94
95        tokio::select! {
96            biased;
97
98            result = self.link.read(&mut self.recv_buffer) => {
99                let len = result.map_err(|e| SrsppError::Network(e.to_string()))?;
100                self.handle_incoming(&self.recv_buffer[..len].to_vec()).await?;
101            }
102
103            _ = sleep_until(next_deadline) => {
104                self.handle_timeouts().await?;
105            }
106        }
107
108        Ok(())
109    }
110
111    /// Check if all data has been acknowledged.
112    pub fn is_idle(&self) -> bool {
113        self.machine.is_idle()
114    }
115
116    /// Available buffer space in bytes.
117    pub fn available_bytes(&self) -> usize {
118        self.machine.available_bytes()
119    }
120
121    /// Executes pending actions: transmits packets and manages timers.
122    async fn process_actions(&mut self) -> Result<(), SrsppError> {
123        let actions: heapless::Vec<SenderAction, 32> =
124            self.actions.iter().copied().collect();
125
126        for action in &actions {
127            match action {
128                SenderAction::Transmit { seq, .. } => {
129                    let cfg = self.machine.config();
130                    let source_address = cfg.source_address;
131                    let apid = cfg.apid;
132                    let function_code = cfg.function_code;
133
134                    let packet_len =
135                        if let Some(info) = self.machine.get_payload(*seq) {
136                            let pkt = SrsppDataPacket::builder()
137                                .buffer(&mut self.tx_buffer)
138                                .source_address(source_address)
139                                .target(info.target)
140                                .apid(apid)
141                                .function_code(function_code)
142                                .sequence_count(*seq)
143                                .sequence_flag(info.flags)
144                                .payload_len(info.payload.len())
145                                .build()
146                                .map_err(|e| {
147                                    SrsppError::PacketError(
148                                        format!("{:?}", e),
149                                    )
150                                })?;
151                            pkt.payload.copy_from_slice(info.payload);
152                            Some(
153                                SrsppDataPacket::HEADER_SIZE
154                                    + info.payload.len(),
155                            )
156                        } else {
157                            None
158                        };
159
160                    if let Some(len) = packet_len {
161                        self.link
162                            .write(&self.tx_buffer[..len])
163                            .await
164                            .map_err(|e| {
165                                SrsppError::Network(e.to_string())
166                            })?;
167
168                        self.machine.mark_transmitted(*seq);
169
170                        let now = Instant::now();
171                        let elapsed = now.duration_since(self.start_time);
172                        let now_secs = elapsed.as_secs() as u32;
173                        let rto = self.rto_policy.rto_ticks(now_secs);
174                        let deadline =
175                            now + ticks_to_duration(rto, self.ticks_per_sec);
176                        self.retransmit_timers
177                            .insert(seq.value(), deadline);
178                    }
179                }
180                SenderAction::StopTimer { seq } => {
181                    self.retransmit_timers.remove(&seq.value());
182                }
183                SenderAction::PacketLost { seq } => {
184                    eprintln!(
185                        "srspp: Packet {} lost after max retransmits",
186                        seq.value()
187                    );
188                }
189                SenderAction::SpaceAvailable { .. } => {}
190                SenderAction::MessageLost => {
191                    eprintln!("srspp: Segmented message lost");
192                }
193            }
194        }
195        Ok(())
196    }
197
198    /// Parses an incoming packet and processes it if it is an ACK.
199    async fn handle_incoming(&mut self, packet: &[u8]) -> Result<(), SrsppError> {
200        let parsed = SrsppPacket::parse(packet)
201            .map_err(|e| SrsppError::PacketError(format!("{:?}", e)))?;
202        let srspp_type = parsed.srspp_type()
203            .map_err(|e| SrsppError::PacketError(format!("{:?}", e)))?;
204
205        if srspp_type == SrsppType::Ack {
206            let ack = SrsppAckPacket::parse(packet)
207                .map_err(|e| SrsppError::PacketError(format!("{:?}", e)))?;
208
209            self.machine.handle(
210                SenderEvent::AckReceived {
211                    cumulative_ack: ack.ack_payload.cumulative_ack(),
212                    selective_bitmap: ack.ack_payload.selective_ack_bitmap(),
213                },
214                &mut self.actions,
215            )?;
216
217            self.process_actions().await?;
218        }
219
220        Ok(())
221    }
222
223    /// Retransmits packets whose retransmission timers have expired.
224    async fn handle_timeouts(&mut self) -> Result<(), SrsppError> {
225        let now = Instant::now();
226
227        let expired: Vec<u16> = self
228            .retransmit_timers
229            .iter()
230            .filter(|(_, deadline)| **deadline <= now)
231            .map(|(seq, _)| *seq)
232            .collect();
233
234        for seq_val in expired {
235            self.retransmit_timers.remove(&seq_val);
236            self.machine.handle(
237                SenderEvent::RetransmitTimeout {
238                    seq: SequenceCount::from(seq_val),
239                },
240                &mut self.actions,
241            )?;
242            self.process_actions().await?;
243        }
244
245        Ok(())
246    }
247
248    /// Returns the earliest retransmission deadline, if any.
249    fn next_timer_deadline(&self) -> Option<Instant> {
250        self.retransmit_timers.values().min().copied()
251    }
252}