Skip to main content

leodos_protocols/transport/srspp/api/cfs/
receiver.rs

1use core::future::poll_fn;
2use core::task::Poll;
3
4use futures::FutureExt;
5
6use leodos_libcfs::cfe::duration::Duration;
7use leodos_libcfs::cfe::time::SysTime;
8use leodos_libcfs::runtime::time::sleep;
9
10use crate::network::NetworkRead;
11use crate::network::NetworkWrite;
12use crate::network::isl::address::Address;
13use crate::network::spp::SequenceCount;
14use crate::transport::srspp::machine::receiver::AckInfo;
15use crate::transport::srspp::machine::receiver::AckState;
16use crate::transport::srspp::machine::receiver::HandleResult;
17use crate::transport::srspp::machine::receiver::ReceiverBackend;
18use crate::transport::srspp::machine::receiver::ReceiverConfig;
19use crate::transport::srspp::machine::receiver::ReceiverMachine;
20use crate::transport::srspp::machine::receiver::TimerAction;
21use crate::transport::srspp::packet::SrsppAckPacket;
22use crate::transport::srspp::packet::SrsppDataPacket;
23use crate::transport::srspp::packet::SrsppPacket;
24use crate::transport::srspp::packet::SrsppType;
25use crate::utils::cell::SyncRefCell;
26use heapless::LinearMap;
27
28use super::TransportError;
29use super::sender::duration_until;
30
31/// Per-stream receiver state for a single remote sender.
32pub(super) struct StreamState<R: ReceiverBackend> {
33    /// Receiver backend for this stream.
34    pub(super) machine: R,
35    /// ACK and timer state for this stream.
36    pub(super) ack_state: AckState,
37    /// Deadline for the delayed ACK timer.
38    pub(super) ack_deadline: Option<SysTime>,
39    /// Deadline for the progress (inactivity) timer.
40    pub(super) progress_deadline: Option<SysTime>,
41}
42
43/// Shared mutable state for the multi-stream receiver channel.
44pub(super) struct MultiReceiverState<E, R: ReceiverBackend, const MAX_STREAMS: usize> {
45    /// Configuration shared across all streams.
46    pub(super) config: ReceiverConfig,
47    /// Per-sender stream states keyed by source address.
48    pub(super) streams: LinearMap<Address, StreamState<R>, MAX_STREAMS>,
49    /// Delayed ACK duration.
50    pub(super) ack_delay: Duration,
51    /// Whether the handle has signaled no more receives.
52    pub(super) closed: bool,
53    /// First error encountered, propagated to the handle.
54    pub(super) error: Option<TransportError<E>>,
55}
56
57// ── Channel and driver ──
58
59/// Channel that owns the receiver state. Split into handle + driver.
60pub struct SrsppReceiver<
61    E,
62    R: ReceiverBackend = ReceiverMachine<8, 4096, 8192>,
63    const MAX_STREAMS: usize = 1,
64> {
65    /// Interior-mutable receiver state shared between handle and driver.
66    state: SyncRefCell<MultiReceiverState<E, R, MAX_STREAMS>>,
67}
68
69#[bon::bon]
70impl<E: Clone, R: ReceiverBackend, const MAX_STREAMS: usize> SrsppReceiver<E, R, MAX_STREAMS> {
71    /// Creates a new multi-stream receiver.
72    #[builder]
73    pub fn new(
74        local_address: Address,
75        apid: crate::network::spp::Apid,
76        #[builder(default)] function_code: u8,
77        #[builder(default)] immediate_ack: bool,
78        #[builder(default = 100)] ack_delay_ticks: u32,
79    ) -> Self {
80        let config = ReceiverConfig::builder()
81            .local_address(local_address)
82            .apid(apid)
83            .function_code(function_code)
84            .immediate_ack(immediate_ack)
85            .ack_delay_ticks(ack_delay_ticks)
86            .build();
87        let ack_delay = Duration::from_millis(config.ack_delay_ticks);
88        Self {
89            state: SyncRefCell::new(MultiReceiverState {
90                config,
91                streams: LinearMap::new(),
92                ack_delay,
93                closed: false,
94                error: None,
95            }),
96        }
97    }
98
99    /// Splits into a handle for receiving and a driver for I/O.
100    pub fn split(
101        &self,
102    ) -> (
103        SrsppRxHandle<'_, E, R, MAX_STREAMS>,
104        SrsppReceiverDriver<'_, E, R, MAX_STREAMS>,
105    ) {
106        (
107            SrsppRxHandle {
108                receiver: &self.state,
109            },
110            SrsppReceiverDriver::new(&self.state),
111        )
112    }
113}
114
115/// Driver that handles I/O. Runs as a concurrent task.
116pub struct SrsppReceiverDriver<'a, E, R: ReceiverBackend, const MAX_STREAMS: usize> {
117    pub(super) state: &'a SyncRefCell<MultiReceiverState<E, R, MAX_STREAMS>>,
118    ack_buffer: [u8; 32],
119}
120
121impl<'a, E, R: ReceiverBackend, const MAX_STREAMS: usize>
122    SrsppReceiverDriver<'a, E, R, MAX_STREAMS>
123{
124    pub(super) fn new(state: &'a SyncRefCell<MultiReceiverState<E, R, MAX_STREAMS>>) -> Self {
125        Self {
126            state,
127            ack_buffer: [0u8; 32],
128        }
129    }
130}
131
132impl<'a, E: Clone, R: ReceiverBackend, const MAX_STREAMS: usize>
133    SrsppReceiverDriver<'a, E, R, MAX_STREAMS>
134{
135    /// Processes a received data packet and dispatches to the correct stream.
136    pub(super) async fn process_data(
137        &mut self,
138        packet: &[u8],
139        link: &mut impl NetworkWrite<Error = E>,
140    ) -> Result<(), TransportError<E>> {
141        if let Ok(SrsppType::Data) = SrsppPacket::parse(packet).and_then(|p| p.srspp_type()) {
142            if let Ok(data) = SrsppDataPacket::parse(packet) {
143                let source_address = data.srspp_header.source_address();
144                let seq = data.primary.sequence_count();
145                let flags = data.primary.sequence_flag();
146
147                let result =
148                    self.state
149                        .with_mut(|s| -> Result<HandleResult, TransportError<E>> {
150                            if !s.streams.contains_key(&source_address) {
151                                let _ = s.streams.insert(
152                                    source_address,
153                                    StreamState {
154                                        machine: R::new(),
155                                        ack_state: AckState::new(&s.config, source_address),
156                                        ack_deadline: None,
157                                        progress_deadline: None,
158                                    },
159                                );
160                            }
161                            if let Some(stream) = s.streams.get_mut(&source_address) {
162                                let outcome =
163                                    stream.machine.handle_data(seq, flags, &data.payload)?;
164                                Ok(stream.ack_state.on_data(
165                                    outcome,
166                                    stream.machine.expected_seq(),
167                                    stream.machine.recv_bitmap(),
168                                ))
169                            } else {
170                                Ok(HandleResult::default())
171                            }
172                        })?;
173
174                self.drive_actions(source_address, result, link).await?;
175            }
176        }
177
178        Ok(())
179    }
180
181    /// Processes expired ACK and progress timers across all streams.
182    pub(super) async fn handle_timeouts(
183        &mut self,
184        link: &mut impl NetworkWrite<Error = E>,
185    ) -> Result<(), TransportError<E>> {
186        let now = SysTime::now();
187
188        let ack_expired = self.state.with(|s| {
189            s.streams
190                .iter()
191                .filter_map(|(source, stream)| {
192                    stream.ack_deadline.filter(|&d| now >= d).map(|_| *source)
193                })
194                .collect::<heapless::Vec<_, MAX_STREAMS>>()
195        });
196
197        for source in ack_expired {
198            let result = self.state.with_mut(|s| {
199                if let Some(stream) = s.streams.get_mut(&source) {
200                    stream.ack_deadline = None;
201                    stream
202                        .ack_state
203                        .on_ack_timeout(stream.machine.expected_seq(), stream.machine.recv_bitmap())
204                } else {
205                    HandleResult::default()
206                }
207            });
208            self.drive_actions(source, result, link).await?;
209        }
210
211        let progress_expired = self.state.with(|s| {
212            s.streams
213                .iter()
214                .filter_map(|(source, stream)| {
215                    stream
216                        .progress_deadline
217                        .filter(|&d| now >= d)
218                        .map(|_| *source)
219                })
220                .collect::<heapless::Vec<_, MAX_STREAMS>>()
221        });
222
223        for source in progress_expired {
224            let result = self
225                .state
226                .with_mut(|s| -> Result<HandleResult, TransportError<E>> {
227                    if let Some(stream) = s.streams.get_mut(&source) {
228                        stream.progress_deadline = None;
229                        let outcome = stream.machine.skip_gap()?;
230                        Ok(stream.ack_state.on_gap_skip(outcome))
231                    } else {
232                        Ok(HandleResult::default())
233                    }
234                })?;
235            self.drive_actions(source, result, link).await?;
236        }
237
238        Ok(())
239    }
240
241    /// Returns the earliest receiver deadline (ACK or progress).
242    pub(super) fn next_deadline(&self) -> Option<SysTime> {
243        self.state.with(|s| {
244            s.streams
245                .iter()
246                .map(|(_, s)| s)
247                .flat_map(|s| [s.ack_deadline, s.progress_deadline])
248                .flatten()
249                .min()
250        })
251    }
252
253    /// Sends ACK and updates timers based on a state machine result.
254    async fn drive_actions(
255        &mut self,
256        source: Address,
257        result: HandleResult,
258        link: &mut impl NetworkWrite<Error = E>,
259    ) -> Result<(), TransportError<E>> {
260        if let Some(AckInfo {
261            destination,
262            cumulative_ack,
263            selective_bitmap,
264        }) = result.ack
265        {
266            let (local_address, apid, function_code) = self.state.with(|s| {
267                (
268                    s.config.local_address,
269                    s.config.apid,
270                    s.config.function_code,
271                )
272            });
273            let ack = SrsppAckPacket::builder()
274                .buffer(&mut self.ack_buffer)
275                .source_address(local_address)
276                .target(destination)
277                .apid(apid)
278                .function_code(function_code)
279                .cumulative_ack(cumulative_ack)
280                .selective_bitmap(selective_bitmap)
281                .sequence_count(SequenceCount::from(0))
282                .build()?;
283            link.write(zerocopy::IntoBytes::as_bytes(ack))
284                .await
285                .map_err(TransportError::Network)?;
286        }
287
288        let ack_delay = self.state.with(|s| s.ack_delay);
289        self.state.with_mut(|s| {
290            if let Some(action) = result.ack_timer {
291                if let Some(entry) = s.streams.get_mut(&source) {
292                    entry.ack_deadline = match action {
293                        TimerAction::Start { .. } => {
294                            Some(SysTime::now() + SysTime::from(ack_delay))
295                        }
296                        TimerAction::Stop => None,
297                    };
298                }
299            }
300            if let Some(action) = result.progress_timer {
301                if let Some(entry) = s.streams.get_mut(&source) {
302                    entry.progress_deadline = match action {
303                        TimerAction::Start { ticks } => {
304                            let delay = Duration::from_millis(ticks);
305                            Some(SysTime::now() + SysTime::from(delay))
306                        }
307                        TimerAction::Stop => None,
308                    };
309                }
310            }
311        });
312
313        Ok(())
314    }
315
316    /// Run the driver loop.
317    pub async fn run<const MTU: usize>(
318        &mut self,
319        link: &mut (impl NetworkWrite<Error = E> + NetworkRead<Error = E>),
320    ) -> Result<(), TransportError<E>> {
321        let mut recv_buffer = [0u8; MTU];
322        loop {
323            if self.state.with(|s| s.closed) {
324                return Ok(());
325            }
326
327            let timeout = duration_until(self.next_deadline());
328
329            let event = {
330                let read_fut = link.read(&mut recv_buffer).fuse();
331                let sleep_fut = sleep(timeout).fuse();
332                pin_utils::pin_mut!(read_fut, sleep_fut);
333                futures::select_biased! {
334                    r = read_fut => Some(r),
335                    _ = sleep_fut => None,
336                }
337            };
338
339            match event {
340                Some(Ok(len)) => {
341                    if let Err(e) = self.process_data(&recv_buffer[..len], link).await {
342                        self.state.with_mut(|s| s.error = Some(e.clone()));
343                        return Err(e);
344                    }
345                }
346                Some(Err(e)) => {
347                    let err = TransportError::Network(e);
348                    self.state.with_mut(|s| s.error = Some(err.clone()));
349                    return Err(err);
350                }
351                None => {
352                    if let Err(e) = self.handle_timeouts(link).await {
353                        self.state.with_mut(|s| s.error = Some(e.clone()));
354                        return Err(e);
355                    }
356                }
357            }
358        }
359    }
360}
361
362/// Handle for receiving data from an SRSPP receiver.
363pub struct SrsppRxHandle<'a, E, R: ReceiverBackend, const MAX_STREAMS: usize> {
364    /// Reference to the shared multi-stream receiver state.
365    pub(super) receiver: &'a SyncRefCell<MultiReceiverState<E, R, MAX_STREAMS>>,
366}
367
368impl<'a, E: Clone, R: ReceiverBackend, const MAX_STREAMS: usize>
369    SrsppRxHandle<'a, E, R, MAX_STREAMS>
370{
371    /// Receives the next message, copying it into `buf`.
372    pub async fn recv(&mut self, buf: &mut [u8]) -> Result<(Address, usize), TransportError<E>> {
373        poll_fn(|_cx| {
374            self.receiver.with_mut(|s| {
375                if let Some(ref e) = s.error {
376                    return Poll::Ready(Err(e.clone()));
377                }
378                for (source, stream) in s.streams.iter_mut() {
379                    if let Some(msg) = stream.machine.take_message() {
380                        let len = msg.len().min(buf.len());
381                        buf[..len].copy_from_slice(&msg[..len]);
382                        return Poll::Ready(Ok((*source, len)));
383                    }
384                }
385                Poll::Pending
386            })
387        })
388        .await
389    }
390
391    /// Signal that no more receives are expected.
392    /// Driver will exit.
393    pub fn close(&mut self) {
394        self.receiver.with_mut(|s| s.closed = true);
395    }
396
397    /// Check if there's a message ready from any sender.
398    pub fn has_message(&self) -> bool {
399        self.receiver
400            .with(|s| s.streams.iter().any(|(_, s)| s.machine.has_message()))
401    }
402
403    /// Get the number of active streams.
404    pub fn stream_count(&self) -> usize {
405        self.receiver.with(|s| s.streams.len())
406    }
407
408    /// Wait for a complete message to become available.
409    ///
410    /// Returns a [`DeliveryToken`] that borrows `&mut self`,
411    /// preventing further receives while the token is held.
412    /// The driver keeps running — the cell is **not** borrowed
413    /// until [`DeliveryToken::consume`] is called.
414    pub async fn wait_for_message(
415        &mut self,
416    ) -> Result<DeliveryToken<'_, 'a, E, R, MAX_STREAMS>, TransportError<E>> {
417        let (source, msg_len) = poll_fn(|_cx| {
418            self.receiver.with(|s| {
419                if let Some(ref e) = s.error {
420                    return Poll::Ready(Err(e.clone()));
421                }
422                for (source, stream) in s.streams.iter() {
423                    if let Some(len) = stream.machine.message_len() {
424                        return Poll::Ready(Ok((*source, len)));
425                    }
426                }
427                Poll::Pending
428            })
429        })
430        .await?;
431        Ok(DeliveryToken {
432            rx: self,
433            source,
434            msg_len,
435        })
436    }
437
438    /// Wait for a message and process it in-place with a closure.
439    ///
440    /// Equivalent to `wait_for_message().await?.consume(f)` but
441    /// more concise when you don't need the [`DeliveryToken`]
442    /// metadata (source address, length).
443    pub async fn recv_with<F, Ret>(&mut self, f: F) -> Result<Ret, TransportError<E>>
444    where
445        F: FnOnce(&[u8]) -> Ret,
446    {
447        let token = self.wait_for_message().await?;
448        Ok(token.consume(f))
449    }
450}
451
452/// Zero-copy delivery token returned by
453/// [`SrsppRxHandle::wait_for_message`].
454///
455/// Holds `&mut SrsppRxHandle`, preventing another receive while
456/// the token is alive.  The cell is **not** borrowed — the
457/// driver freely delivers new segments in the background.
458///
459/// Call [`consume`](Self::consume) with a synchronous closure to
460/// read the message and release the token in one step.
461pub struct DeliveryToken<'a, 'rx, E, R: ReceiverBackend, const MAX_STREAMS: usize> {
462    rx: &'a mut SrsppRxHandle<'rx, E, R, MAX_STREAMS>,
463    source: Address,
464    msg_len: usize,
465}
466
467impl<'a, 'rx, E: Clone, R: ReceiverBackend, const MAX_STREAMS: usize>
468    DeliveryToken<'a, 'rx, E, R, MAX_STREAMS>
469{
470    /// Byte length of the pending message.
471    pub fn len(&self) -> usize {
472        self.msg_len
473    }
474
475    /// Source address of the sender that produced this message.
476    pub fn source(&self) -> Address {
477        self.source
478    }
479
480    /// Pass the message data to `f`, consume the token, and
481    /// return whatever `f` returns.
482    ///
483    /// The cell is borrowed only for the duration of `f`.
484    pub fn consume<F, Ret>(self, f: F) -> Ret
485    where
486        F: FnOnce(&[u8]) -> Ret,
487    {
488        self.rx.receiver.with_mut(|s| {
489            let stream = s.streams.get_mut(&self.source).unwrap();
490            stream.machine.consume_message(f).unwrap()
491        })
492    }
493}