// leodos_protocols/coding/fec/convolutional.rs

//! CCSDS Convolutional Code (Rate 1/2, K=7) with Viterbi Decoding
//!
//! Implements the convolutional code specified in CCSDS 131.0-B-5
//! (TM Synchronization and Channel Coding).
//!
//! # Parameters
//!
//! - Constraint length: K = 7 (64 encoder states)
//! - Code rate: 1/2 (2 output symbols per input bit)
//! - Generator polynomials: G1 = 171₈ (0x79), G2 = 133₈ (0x5B)
//! - Tail: K−1 = 6 zero bits flush the encoder to state 0
//!
//! # Encoding
//!
//! Each input bit plus the 6-bit shift register state produces two
//! coded symbols (G1 first, then G2). After all data bits, six zero
//! tail bits terminate the trellis at state 0.
//!
//! NOTE(review): CCSDS 131.0-B specifies an inverter on the G2 output
//! path; this encoder emits the G2 symbol uninverted. Confirm against
//! the spec and the peer implementation before relying on interoperability.
//!
//! # Decoding
//!
//! Soft-decision Viterbi algorithm with i16 LLR inputs (positive
//! means bit 0 is more likely), matching the LDPC decoder convention.
//! Uses a sliding-window traceback (depth 5K = 35) to bound stack
//! usage regardless of frame length; the final bits of the frame are
//! traced back from the known terminal state 0.

/// Constraint length: the number of register bits (current input plus
/// stored history) that determine each output symbol pair.
pub const K: usize = 7;

/// Encoder memory: the shift register keeps the previous K − 1 input bits.
const MEMORY: usize = K - 1;

/// Trellis size: one state per possible register content, 2^(K − 1) = 64.
const NUM_STATES: usize = 2usize.pow(MEMORY as u32);

/// Generator polynomial G1, octal 171 (= 0x79).
const G1: u8 = 0o171;

/// Generator polynomial G2, octal 133 (= 0x5B).
const G2: u8 = 0o133;

/// Survivor history kept by the sliding-window Viterbi decoder (5 × K steps).
const TRACEBACK_DEPTH: usize = K * 5;

/// Initial path metric for unreachable states: half of `i32::MIN`, which
/// leaves headroom so adding branch metrics cannot wrap around.
const NEGINF: i32 = -(1 << 30);

/// Errors reported by the encoding and decoding functions in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConvError {
    /// The caller-supplied output slice cannot hold the result.
    BufferTooSmall {
        /// Number of bytes the operation needs.
        required: usize,
        /// Number of bytes the caller actually supplied.
        provided: usize,
    },
    /// The LLR slice has an odd length; every trellis step consumes a
    /// (G1, G2) pair.
    OddLlrCount,
    /// Fewer than K − 1 trellis steps were supplied, so not even the tail fits.
    FrameTooShort,
}

63/// Precomputed branch output symbols for each (state, input_bit).
64///
65/// `BRANCH[state][input]` = `(G1_output, G2_output)`.
66static BRANCH: [[(u8, u8); 2]; NUM_STATES] = {
67    let mut t = [[(0u8, 0u8); 2]; NUM_STATES];
68    let mut s = 0;
69    while s < NUM_STATES {
70        let mut b = 0u8;
71        while b < 2 {
72            let reg = (b << 6) | (s as u8);
73            let g1 = (reg & G1).count_ones() as u8 & 1;
74            let g2 = (reg & G2).count_ones() as u8 & 1;
75            t[s][b as usize] = (g1, g2);
76            b += 1;
77        }
78        s += 1;
79    }
80    t
81};
82
83/// Returns the number of encoded bytes for `data_len` data bytes.
84///
85/// Accounts for K−1 tail bits and rate-1/2 expansion.
86pub fn encoded_len(data_len: usize) -> usize {
87    let coded_bits = (data_len * 8 + MEMORY) * 2;
88    (coded_bits + 7) / 8
89}
90
91/// Encodes data using the rate-1/2, K=7 convolutional code.
92///
93/// Appends K−1 = 6 zero tail bits to terminate the trellis.
94/// Output symbols are packed MSB-first: for each input bit the
95/// G1 symbol comes first, then G2.
96///
97/// Returns the number of bytes written to `output`.
98pub fn encode(data: &[u8], output: &mut [u8]) -> Result<usize, ConvError> {
99    let out_len = encoded_len(data.len());
100    if output.len() < out_len {
101        return Err(ConvError::BufferTooSmall {
102            required: out_len,
103            provided: output.len(),
104        });
105    }
106
107    output[..out_len].fill(0);
108    let mut state = 0u8;
109    let mut oi = 0usize;
110
111    for &byte in data {
112        for bit_pos in (0..8).rev() {
113            let b = (byte >> bit_pos) & 1;
114            let (g1, g2) = BRANCH[state as usize][b as usize];
115            output[oi / 8] |= g1 << (7 - oi % 8);
116            oi += 1;
117            output[oi / 8] |= g2 << (7 - oi % 8);
118            oi += 1;
119            state = (state >> 1) | (b << 5);
120        }
121    }
122
123    for _ in 0..MEMORY {
124        let (g1, g2) = BRANCH[state as usize][0];
125        output[oi / 8] |= g1 << (7 - oi % 8);
126        oi += 1;
127        output[oi / 8] |= g2 << (7 - oi % 8);
128        oi += 1;
129        state >>= 1;
130    }
131
132    Ok(out_len)
133}
134
/// Converts hard bits (packed bytes, MSB-first) to LLR values.
///
/// Stream bit `i` (bit `7 - i % 8` of `bits[i / 8]`) becomes `llrs[i]`:
/// a 0 bit maps to `+magnitude` and a 1 bit to `-magnitude`. This is
/// useful for feeding hard-decision input to the Viterbi decoder.
///
/// At most `min(num_bits, llrs.len(), bits.len() * 8)` values are written.
/// Requests beyond the packed input are ignored rather than panicking, and
/// the remaining entries of `llrs` are left untouched.
///
/// `magnitude` is expected to be positive. With `i16::MIN`, negating it
/// overflows, which panics in debug builds.
pub fn hard_to_llr(bits: &[u8], num_bits: usize, magnitude: i16, llrs: &mut [i16]) {
    // Clamp to every buffer involved so a short `bits` slice cannot be
    // indexed out of bounds.
    let count = num_bits
        .min(llrs.len())
        .min(bits.len().saturating_mul(8));

    for (i, llr) in llrs[..count].iter_mut().enumerate() {
        let bit = (bits[i / 8] >> (7 - i % 8)) & 1;
        *llr = if bit == 0 { magnitude } else { -magnitude };
    }
}

147/// Decodes a convolutionally coded frame using soft-decision Viterbi.
148///
149/// `llrs` contains paired (G1, G2) log-likelihood ratios per trellis
150/// step. Positive LLR means bit 0 is more likely. The trellis must
151/// be terminated (K−1 tail steps appended by the encoder).
152///
153/// Returns the number of decoded data bytes written to `output`.
154pub fn decode(llrs: &[i16], output: &mut [u8]) -> Result<usize, ConvError> {
155    if llrs.len() % 2 != 0 {
156        return Err(ConvError::OddLlrCount);
157    }
158
159    let num_steps = llrs.len() / 2;
160    if num_steps < MEMORY {
161        return Err(ConvError::FrameTooShort);
162    }
163
164    let info_bits = num_steps - MEMORY;
165    let out_bytes = (info_bits + 7) / 8;
166
167    if output.len() < out_bytes {
168        return Err(ConvError::BufferTooSmall {
169            required: out_bytes,
170            provided: output.len(),
171        });
172    }
173
174    output[..out_bytes].fill(0);
175
176    let mut pm = [[NEGINF; NUM_STATES]; 2];
177    pm[0][0] = 0;
178    let mut cur = 0usize;
179    let mut decisions = [0u64; TRACEBACK_DEPTH];
180    let mut decoded = 0usize;
181
182    for step in 0..num_steps {
183        let l1 = llrs[2 * step] as i32;
184        let l2 = llrs[2 * step + 1] as i32;
185
186        // Branch metrics: bm[g1_out][g2_out]
187        let bm = [
188            [l1 + l2, l1 - l2],
189            [-l1 + l2, -l1 - l2],
190        ];
191
192        let prev = cur;
193        cur = 1 - cur;
194        let dslot = step % TRACEBACK_DEPTH;
195        let mut dword = 0u64;
196
197        for ns in 0..NUM_STATES {
198            let input = (ns >> 5) & 1;
199            let p0 = (ns << 1) & (NUM_STATES - 1);
200            let p1 = p0 | 1;
201
202            let (g1_0, g2_0) = BRANCH[p0][input];
203            let (g1_1, g2_1) = BRANCH[p1][input];
204
205            let m0 = pm[prev][p0]
206                .saturating_add(bm[g1_0 as usize][g2_0 as usize]);
207            let m1 = pm[prev][p1]
208                .saturating_add(bm[g1_1 as usize][g2_1 as usize]);
209
210            if m1 > m0 {
211                pm[cur][ns] = m1;
212                dword |= 1 << ns;
213            } else {
214                pm[cur][ns] = m0;
215            }
216        }
217
218        decisions[dslot] = dword;
219
220        // Normalize metrics to prevent overflow
221        if step & 0xFF == 0xFF {
222            let mut min_val = pm[cur][0];
223            for s in 1..NUM_STATES {
224                if pm[cur][s] < min_val {
225                    min_val = pm[cur][s];
226                }
227            }
228            for m in &mut pm[cur] {
229                *m -= min_val;
230            }
231        }
232
233        // Sliding-window traceback: output one decoded bit per step
234        // once the window is full.
235        if step >= TRACEBACK_DEPTH - 1 && decoded < info_bits {
236            // Find the state with the best metric
237            let mut best_s = 0;
238            let mut best_m = pm[cur][0];
239            for s in 1..NUM_STATES {
240                if pm[cur][s] > best_m {
241                    best_m = pm[cur][s];
242                    best_s = s;
243                }
244            }
245
246            // Trace back through the decision buffer
247            let mut s = best_s;
248            let mut slot = dslot;
249            for _ in 0..TRACEBACK_DEPTH - 1 {
250                let d = ((decisions[slot] >> s) & 1) as usize;
251                s = ((s << 1) | d) & (NUM_STATES - 1);
252                slot = if slot == 0 {
253                    TRACEBACK_DEPTH - 1
254                } else {
255                    slot - 1
256                };
257            }
258
259            // The MSB of the traced-back state is the decoded bit
260            let bit = (s >> 5) & 1;
261            if bit == 1 {
262                output[decoded / 8] |= 1 << (7 - decoded % 8);
263            }
264            decoded += 1;
265        }
266    }
267
268    // Final traceback from the known terminal state (0) for the
269    // remaining info bits that the sliding window didn't cover.
270    if decoded < info_bits {
271        let mut s = 0usize;
272        let last_slot = (num_steps - 1) % TRACEBACK_DEPTH;
273        let available = num_steps.min(TRACEBACK_DEPTH);
274
275        // Collect input bits by tracing backward from state 0.
276        // bits[i] = input bit at trellis step (num_steps − 1 − i).
277        let mut bits = [0u8; TRACEBACK_DEPTH];
278        bits[0] = ((s >> 5) & 1) as u8;
279
280        let mut slot = last_slot;
281        for i in 1..available {
282            let d = ((decisions[slot] >> s) & 1) as usize;
283            s = ((s << 1) | d) & (NUM_STATES - 1);
284            slot = if slot == 0 {
285                TRACEBACK_DEPTH - 1
286            } else {
287                slot - 1
288            };
289            bits[i] = ((s >> 5) & 1) as u8;
290        }
291
292        while decoded < info_bits {
293            let rev_idx = num_steps - 1 - decoded;
294            if rev_idx < available && bits[rev_idx] == 1 {
295                output[decoded / 8] |= 1 << (7 - decoded % 8);
296            }
297            decoded += 1;
298        }
299    }
300
301    Ok(out_bytes)
302}
303
304/// Convolutional encoder implementing [`FecEncoder`](crate::coding::FecEncoder).
305pub struct ConvolutionalEncoder;
306
307impl crate::coding::FecEncoder for ConvolutionalEncoder {
308    type Error = ConvError;
309
310    fn encode(&self, data: &[u8], output: &mut [u8]) -> Result<usize, Self::Error> {
311        encode(data, output)
312    }
313}
314
315/// Hard-decision Viterbi decoder implementing [`FecDecoder`](crate::coding::FecDecoder).
316pub struct ViterbiDecoder {
317    llr_magnitude: i16,
318}
319
320impl ViterbiDecoder {
321    /// Creates a decoder with the given hard-decision LLR magnitude.
322    pub fn new(llr_magnitude: i16) -> Self {
323        Self { llr_magnitude }
324    }
325}
326
327impl crate::coding::FecDecoder for ViterbiDecoder {
328    type Error = ConvError;
329
330    fn decode(&self, data: &mut [u8]) -> Result<usize, Self::Error> {
331        let num_bits = data.len() * 8;
332        let mut llrs = [0i16; 8192];
333        hard_to_llr(data, num_bits, self.llr_magnitude, &mut llrs[..num_bits]);
334        let mut output = [0u8; 1024];
335        let len = decode(&llrs[..num_bits], &mut output)?;
336        data[..len].copy_from_slice(&output[..len]);
337        Ok(len)
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of coded bits (one LLR each) for `data_len` message bytes,
    /// excluding byte-alignment padding.
    fn coded_bits(data_len: usize) -> usize {
        (data_len * 8 + MEMORY) * 2
    }

    #[test]
    fn encoded_len_calculation() {
        // 0 bytes: 6 tail steps × 2 = 12 bits → 2 bytes.
        assert_eq!(encoded_len(0), 2);
        // 1 byte: (8 + 6) steps × 2 = 28 bits → 4 bytes (3.5 rounded up).
        assert_eq!(encoded_len(1), 4);
        // 10 bytes: (80 + 6) steps × 2 = 172 bits → 22 bytes.
        assert_eq!(encoded_len(10), 22);
    }

    #[test]
    fn encode_zeros() {
        // From state 0, input 0 emits (0, 0) and stays in state 0, so an
        // all-zero message encodes to all zeros.
        let mut out = [0u8; 10];
        let len = encode(&[0u8; 4], &mut out).unwrap();

        assert_eq!(len, encoded_len(4));
        assert!(out[..len].iter().all(|&b| b == 0));
    }

    #[test]
    fn encode_buffer_too_small() {
        let mut out = [0u8; 2]; // needs encoded_len(4) = 10
        let result = encode(&[0u8; 4], &mut out);
        assert!(matches!(result, Err(ConvError::BufferTooSmall { .. })));
    }

    #[test]
    fn encoder_state_returns_to_zero() {
        // For an all-zero message the encoder never leaves state 0, so every
        // coded bit (data and tail) must be zero.
        let mut out = [0u8; 4];
        encode(&[0u8; 1], &mut out).unwrap();
        assert!(out.iter().all(|&b| b == 0));
    }

    #[test]
    fn encode_known_pattern() {
        // Message 0x80 = 1000_0000.
        // Step 0: state 0, input 1 → register 0b100_0000.
        //   G1 = parity(0b100_0000 & 0x79) = 1, G2 = parity(0b100_0000 & 0x5B) = 1.
        //   Next state 32.
        // Step 1: state 32, input 0 → register 0b010_0000.
        //   G1 = parity(0b010_0000 & 0x79) = 1 (bit 5 is tapped).
        //   G2 = parity(0b010_0000 & 0x5B) = 0 (bit 5 is not tapped).
        // The first four coded bits are therefore 1, 1, 1, 0 → high nibble 0xE.
        let mut out = [0u8; 4];
        encode(&[0x80], &mut out).unwrap();
        assert_eq!(out[0] & 0xF0, 0xE0);
    }

    #[test]
    fn roundtrip_hard_decision() {
        let data = [0xDE, 0xAD, 0xBE, 0xEF];
        let mut encoded = [0u8; 32];
        encode(&data, &mut encoded).unwrap();

        let n = coded_bits(data.len());
        let mut llrs = [0i16; 256];
        hard_to_llr(&encoded, n, 127, &mut llrs);

        let mut decoded = [0u8; 4];
        let len = decode(&llrs[..n], &mut decoded).unwrap();
        assert_eq!(len, 4);
        assert_eq!(decoded, data);
    }

    #[test]
    fn roundtrip_single_byte() {
        for val in [0x00, 0x01, 0x55, 0xAA, 0xFF] {
            let data = [val];
            let mut encoded = [0u8; 4];
            encode(&data, &mut encoded).unwrap();

            let n = coded_bits(1);
            let mut llrs = [0i16; 32];
            hard_to_llr(&encoded, n, 127, &mut llrs);

            let mut decoded = [0u8; 1];
            decode(&llrs[..n], &mut decoded).unwrap();
            assert_eq!(decoded[0], val, "failed for 0x{val:02X}");
        }
    }

    #[test]
    fn roundtrip_large_frame() {
        let mut data = [0u8; 128];
        for (i, byte) in data.iter_mut().enumerate() {
            *byte = i as u8;
        }
        let mut encoded = [0u8; 300];
        encode(&data, &mut encoded).unwrap();

        let n = coded_bits(data.len());
        let mut llrs = [0i16; 2200];
        hard_to_llr(&encoded, n, 127, &mut llrs);

        let mut decoded = [0u8; 128];
        decode(&llrs[..n], &mut decoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn corrects_bit_errors() {
        let data = [0x42, 0x37, 0x99, 0x10];
        let mut encoded = [0u8; 32];
        encode(&data, &mut encoded).unwrap();

        let n = coded_bits(data.len());
        let mut llrs = [0i16; 256];
        hard_to_llr(&encoded, n, 50, &mut llrs);

        // Simulate scattered channel errors by negating a few LLRs.
        for idx in [0usize, 15, 30, 50] {
            llrs[idx] = -llrs[idx];
        }

        let mut decoded = [0u8; 4];
        decode(&llrs[..n], &mut decoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn soft_decision_advantage() {
        // A hard-flipped symbol carrying only a weak (low-magnitude) LLR
        // should not disturb decoding when the other symbols are confident.
        let data = [0xAB];
        let mut encoded = [0u8; 4];
        encode(&data, &mut encoded).unwrap();

        let n = coded_bits(1);
        let mut llrs = [0i16; 32];
        hard_to_llr(&encoded, n, 100, &mut llrs);

        // Flip symbol 3, but only weakly.
        llrs[3] = -llrs[3].signum() * 5;

        let mut decoded = [0u8; 1];
        decode(&llrs[..n], &mut decoded).unwrap();
        assert_eq!(decoded[0], 0xAB);
    }

    #[test]
    fn decode_odd_llr_count() {
        let mut out = [0u8; 1];
        assert_eq!(decode(&[0i16; 5], &mut out), Err(ConvError::OddLlrCount));
    }

    #[test]
    fn decode_frame_too_short() {
        // 4 LLRs = 2 steps, fewer than MEMORY = 6.
        let mut out = [0u8; 1];
        assert_eq!(decode(&[0i16; 4], &mut out), Err(ConvError::FrameTooShort));
    }

    #[test]
    fn decode_output_too_small() {
        // 20 LLRs = 10 steps → 10 − 6 = 4 info bits → 1 output byte needed.
        let mut out = [0u8; 0];
        assert!(matches!(
            decode(&[100i16; 20], &mut out),
            Err(ConvError::BufferTooSmall { .. })
        ));
    }

    #[test]
    fn branch_table_symmetry() {
        // State 0, input 0: register is empty, so both outputs are 0.
        assert_eq!(BRANCH[0][0], (0, 0));
        // State 0, input 1: register = 0b100_0000; both generators tap bit 6.
        //   G1: 0b100_0000 & 0b111_1001 = 0b100_0000 → parity 1
        //   G2: 0b100_0000 & 0b101_1011 = 0b100_0000 → parity 1
        assert_eq!(BRANCH[0][1], (1, 1));
    }

    #[test]
    fn roundtrip_all_ones() {
        let data = [0xFF; 8];
        let mut encoded = [0u8; 24];
        encode(&data, &mut encoded).unwrap();

        let n = coded_bits(data.len());
        let mut llrs = [0i16; 256];
        hard_to_llr(&encoded, n, 127, &mut llrs);

        let mut decoded = [0u8; 8];
        decode(&llrs[..n], &mut decoded).unwrap();
        assert_eq!(decoded, data);
    }
}