Skip to main content

leodos_protocols/misc/sle/
ber.rs

1//! Minimal ASN.1 BER encoder/decoder for SLE.
2//!
3//! Implements only the subset needed by SLE: INTEGER, OCTET STRING,
4//! SEQUENCE, CHOICE, ENUMERATED, NULL, BIT STRING, BOOLEAN.
5//! No heap allocation — all operations work on caller-provided
6//! byte slices.
7
8use super::types::SleError;
9
/// The two-bit class field of a BER identifier octet (its top two bits).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum Class {
    /// Built-in ASN.1 types (INTEGER, SEQUENCE, ...).
    Universal = 0b00,
    /// Application-wide tags.
    Application = 0b01,
    /// Tags whose meaning comes from the enclosing type (CHOICE
    /// alternatives, tagged fields).
    Context = 0b10,
    /// Privately assigned tags.
    Private = 0b11,
}

impl Class {
    /// Converts the class bits, already shifted down to the low end of
    /// the byte, into a [`Class`].
    ///
    /// 0, 1 and 2 map to Universal, Application and Context; every other
    /// value maps to [`Class::Private`].
    fn from_bits(bits: u8) -> Self {
        const LOW_CLASSES: [Class; 3] =
            [Class::Universal, Class::Application, Class::Context];
        LOW_CLASSES
            .get(usize::from(bits))
            .copied()
            .unwrap_or(Class::Private)
    }
}
34
/// Universal-class tag numbers for the ASN.1 types this codec handles.
pub mod tags {
    /// BOOLEAN (UNIVERSAL 1)
    pub const BOOLEAN: u8 = 0x01;
    /// INTEGER (UNIVERSAL 2)
    pub const INTEGER: u8 = 0x02;
    /// BIT STRING (UNIVERSAL 3)
    pub const BIT_STRING: u8 = 0x03;
    /// OCTET STRING (UNIVERSAL 4)
    pub const OCTET_STRING: u8 = 0x04;
    /// NULL (UNIVERSAL 5)
    pub const NULL: u8 = 0x05;
    /// ENUMERATED (UNIVERSAL 10)
    pub const ENUMERATED: u8 = 0x0A;
    /// SEQUENCE and SEQUENCE OF (UNIVERSAL 16)
    pub const SEQUENCE: u8 = 0x10;
}
52
53/// Encodes a single-byte tag octet.
54///
55/// Only supports tag numbers 0..30 (short form). Returns the
56/// encoded tag byte.
57pub fn encode_tag(tag: u8, class: Class, constructed: bool) -> u8 {
58    let class_bits = (class as u8) << 6;
59    let constructed_bit = if constructed { 0x20 } else { 0x00 };
60    class_bits | constructed_bit | (tag & 0x1F)
61}
62
63/// Encodes a BER length into `buf`. Returns number of bytes written.
64///
65/// Short form: lengths 0..127 use one byte.
66/// Long form: lengths 128.. use 1 byte for count + N value bytes.
67pub fn encode_length(len: usize, buf: &mut [u8]) -> Result<usize, SleError> {
68    if len < 128 {
69        if buf.is_empty() {
70            return Err(SleError::BufferTooSmall);
71        }
72        buf[0] = len as u8;
73        Ok(1)
74    } else if len <= 0xFF {
75        if buf.len() < 2 {
76            return Err(SleError::BufferTooSmall);
77        }
78        buf[0] = 0x81;
79        buf[1] = len as u8;
80        Ok(2)
81    } else if len <= 0xFFFF {
82        if buf.len() < 3 {
83            return Err(SleError::BufferTooSmall);
84        }
85        buf[0] = 0x82;
86        buf[1] = (len >> 8) as u8;
87        buf[2] = len as u8;
88        Ok(3)
89    } else if len <= 0xFF_FFFF {
90        if buf.len() < 4 {
91            return Err(SleError::BufferTooSmall);
92        }
93        buf[0] = 0x83;
94        buf[1] = (len >> 16) as u8;
95        buf[2] = (len >> 8) as u8;
96        buf[3] = len as u8;
97        Ok(4)
98    } else {
99        if buf.len() < 5 {
100            return Err(SleError::BufferTooSmall);
101        }
102        buf[0] = 0x84;
103        buf[1] = (len >> 24) as u8;
104        buf[2] = (len >> 16) as u8;
105        buf[3] = (len >> 8) as u8;
106        buf[4] = len as u8;
107        Ok(5)
108    }
109}
110
111/// Decodes a BER tag byte.
112///
113/// Returns `(tag_number, class, constructed, bytes_consumed)`.
114/// Only supports short-form tags (tag number 0..30).
115pub fn decode_tag(
116    buf: &[u8],
117) -> Result<(u8, Class, bool, usize), SleError> {
118    if buf.is_empty() {
119        return Err(SleError::Truncated);
120    }
121    let b = buf[0];
122    let class = Class::from_bits(b >> 6);
123    let constructed = (b & 0x20) != 0;
124    let tag = b & 0x1F;
125    if tag == 0x1F {
126        return Err(SleError::UnexpectedTag);
127    }
128    Ok((tag, class, constructed, 1))
129}
130
131/// Decodes a BER length field.
132///
133/// Returns `(length_value, bytes_consumed)`.
134pub fn decode_length(buf: &[u8]) -> Result<(usize, usize), SleError> {
135    if buf.is_empty() {
136        return Err(SleError::Truncated);
137    }
138    let first = buf[0];
139    if first < 128 {
140        Ok((first as usize, 1))
141    } else {
142        let num_bytes = (first & 0x7F) as usize;
143        if num_bytes == 0 || num_bytes > 4 {
144            return Err(SleError::Truncated);
145        }
146        if buf.len() < 1 + num_bytes {
147            return Err(SleError::Truncated);
148        }
149        let mut val: usize = 0;
150        for i in 0..num_bytes {
151            val = (val << 8) | buf[1 + i] as usize;
152        }
153        Ok((val, 1 + num_bytes))
154    }
155}
156
/// Serialises BER TLV elements into a caller-supplied byte slice
/// without allocating.
pub struct BerWriter<'a> {
    /// Destination storage. Only `buf[..pos]` holds encoded output.
    buf: &'a mut [u8],
    /// Number of bytes emitted so far, i.e. the next write offset.
    pos: usize,
}
162
163impl<'a> BerWriter<'a> {
164    /// Creates a new writer over the given buffer.
165    pub fn new(buf: &'a mut [u8]) -> Self {
166        Self { buf, pos: 0 }
167    }
168
169    /// Returns the number of bytes written so far.
170    pub fn len(&self) -> usize {
171        self.pos
172    }
173
174    /// Returns true if no bytes have been written.
175    pub fn is_empty(&self) -> bool {
176        self.pos == 0
177    }
178
179    /// Returns the bytes written so far.
180    pub fn as_bytes(&self) -> &[u8] {
181        &self.buf[..self.pos]
182    }
183
184    /// Returns the remaining writable capacity.
185    fn remaining(&self) -> usize {
186        self.buf.len() - self.pos
187    }
188
189    fn write_byte(&mut self, b: u8) -> Result<(), SleError> {
190        if self.remaining() == 0 {
191            return Err(SleError::BufferTooSmall);
192        }
193        self.buf[self.pos] = b;
194        self.pos += 1;
195        Ok(())
196    }
197
198    fn write_bytes(&mut self, data: &[u8]) -> Result<(), SleError> {
199        if self.remaining() < data.len() {
200            return Err(SleError::BufferTooSmall);
201        }
202        self.buf[self.pos..self.pos + data.len()].copy_from_slice(data);
203        self.pos += data.len();
204        Ok(())
205    }
206
207    fn write_tag(
208        &mut self,
209        tag: u8,
210        class: Class,
211        constructed: bool,
212    ) -> Result<(), SleError> {
213        self.write_byte(encode_tag(tag, class, constructed))
214    }
215
216    fn write_length(&mut self, len: usize) -> Result<(), SleError> {
217        let n = encode_length(len, &mut self.buf[self.pos..])?;
218        self.pos += n;
219        Ok(())
220    }
221
222    /// Writes an ASN.1 BOOLEAN.
223    pub fn write_bool(&mut self, value: bool) -> Result<(), SleError> {
224        self.write_tag(tags::BOOLEAN, Class::Universal, false)?;
225        self.write_length(1)?;
226        self.write_byte(if value { 0xFF } else { 0x00 })
227    }
228
229    /// Writes an ASN.1 INTEGER (signed, variable length).
230    pub fn write_integer(&mut self, value: i64) -> Result<(), SleError> {
231        self.write_tag(tags::INTEGER, Class::Universal, false)?;
232        let encoded = encode_i64(value);
233        self.write_length(encoded.len)?;
234        self.write_bytes(&encoded.bytes[..encoded.len])
235    }
236
237    /// Writes an ASN.1 OCTET STRING.
238    pub fn write_octet_string(
239        &mut self,
240        data: &[u8],
241    ) -> Result<(), SleError> {
242        self.write_tag(tags::OCTET_STRING, Class::Universal, false)?;
243        self.write_length(data.len())?;
244        self.write_bytes(data)
245    }
246
247    /// Writes an ASN.1 ENUMERATED value.
248    pub fn write_enum(&mut self, value: i64) -> Result<(), SleError> {
249        self.write_tag(tags::ENUMERATED, Class::Universal, false)?;
250        let encoded = encode_i64(value);
251        self.write_length(encoded.len)?;
252        self.write_bytes(&encoded.bytes[..encoded.len])
253    }
254
255    /// Writes an ASN.1 NULL.
256    pub fn write_null(&mut self) -> Result<(), SleError> {
257        self.write_tag(tags::NULL, Class::Universal, false)?;
258        self.write_length(0)
259    }
260
261    /// Writes an ASN.1 BIT STRING with zero unused bits.
262    pub fn write_bit_string(
263        &mut self,
264        data: &[u8],
265    ) -> Result<(), SleError> {
266        self.write_tag(tags::BIT_STRING, Class::Universal, false)?;
267        self.write_length(data.len() + 1)?;
268        self.write_byte(0)?; // unused bits = 0
269        self.write_bytes(data)
270    }
271
272    /// Begins a SEQUENCE. Returns the position of the length field
273    /// so it can be patched later with `end_sequence`.
274    pub fn begin_sequence(&mut self) -> Result<SeqStart, SleError> {
275        self.write_tag(tags::SEQUENCE, Class::Universal, true)?;
276        let len_pos = self.pos;
277        // Reserve space for a 3-byte length (0x82 + 2 bytes).
278        // This supports sequences up to 65535 bytes.
279        if self.remaining() < 3 {
280            return Err(SleError::BufferTooSmall);
281        }
282        self.pos += 3;
283        Ok(SeqStart {
284            len_pos,
285            content_start: self.pos,
286        })
287    }
288
289    /// Ends a SEQUENCE started with `begin_sequence`, patching the
290    /// length field.
291    pub fn end_sequence(
292        &mut self,
293        start: SeqStart,
294    ) -> Result<(), SleError> {
295        let content_len = self.pos - start.content_start;
296        if content_len < 128 {
297            // Shift content left by 2 bytes (we reserved 3, need 1).
298            let src = start.content_start;
299            let dst = start.len_pos + 1;
300            self.buf.copy_within(src..self.pos, dst);
301            self.buf[start.len_pos] = content_len as u8;
302            self.pos -= 2;
303        } else if content_len <= 0xFF {
304            // Shift content left by 1 byte (we reserved 3, need 2).
305            let src = start.content_start;
306            let dst = start.len_pos + 2;
307            self.buf.copy_within(src..self.pos, dst);
308            self.buf[start.len_pos] = 0x81;
309            self.buf[start.len_pos + 1] = content_len as u8;
310            self.pos -= 1;
311        } else {
312            // Exact fit: 3-byte length encoding.
313            self.buf[start.len_pos] = 0x82;
314            self.buf[start.len_pos + 1] = (content_len >> 8) as u8;
315            self.buf[start.len_pos + 2] = content_len as u8;
316        }
317        Ok(())
318    }
319
320    /// Writes a context-tagged constructed wrapper (implicit tag).
321    /// Returns a SeqStart for use with `end_sequence`.
322    pub fn begin_context(
323        &mut self,
324        tag: u8,
325        constructed: bool,
326    ) -> Result<SeqStart, SleError> {
327        self.write_tag(tag, Class::Context, constructed)?;
328        let len_pos = self.pos;
329        if self.remaining() < 3 {
330            return Err(SleError::BufferTooSmall);
331        }
332        self.pos += 3;
333        Ok(SeqStart {
334            len_pos,
335            content_start: self.pos,
336        })
337    }
338}
339
/// Position bookkeeping for an open SEQUENCE or context-tagged
/// wrapper. It is returned by `begin_sequence` / `begin_context` and
/// consumed by `end_sequence`.
#[derive(Clone, Copy, Debug)]
pub struct SeqStart {
    /// Offset of the reserved length octets (just after the tag byte).
    len_pos: usize,
    /// Offset where the element's content begins.
    content_start: usize,
}
346
/// Walks BER TLV elements in a borrowed byte slice without copying it.
pub struct BerReader<'a> {
    /// Input being decoded.
    buf: &'a [u8],
    /// Offset of the next unread byte.
    pos: usize,
}
352
353impl<'a> BerReader<'a> {
354    /// Creates a new reader over the given buffer.
355    pub fn new(buf: &'a [u8]) -> Self {
356        Self { buf, pos: 0 }
357    }
358
359    /// Returns the current read position.
360    pub fn pos(&self) -> usize {
361        self.pos
362    }
363
364    /// Returns the remaining unread bytes.
365    pub fn remaining(&self) -> usize {
366        self.buf.len() - self.pos
367    }
368
369    /// Returns true if all bytes have been consumed.
370    pub fn is_empty(&self) -> bool {
371        self.remaining() == 0
372    }
373
374    /// Peeks at the next tag without advancing the position.
375    /// Returns `(tag_number, class, constructed)`.
376    pub fn peek_tag(
377        &self,
378    ) -> Result<(u8, Class, bool), SleError> {
379        let (tag, class, constructed, _) =
380            decode_tag(&self.buf[self.pos..])?;
381        Ok((tag, class, constructed))
382    }
383
384    /// Reads and validates a tag, returning its components.
385    pub fn read_tag(
386        &mut self,
387    ) -> Result<(u8, Class, bool), SleError> {
388        let (tag, class, constructed, consumed) =
389            decode_tag(&self.buf[self.pos..])?;
390        self.pos += consumed;
391        Ok((tag, class, constructed))
392    }
393
394    /// Reads a length field.
395    pub fn read_length(&mut self) -> Result<usize, SleError> {
396        let (len, consumed) =
397            decode_length(&self.buf[self.pos..])?;
398        self.pos += consumed;
399        Ok(len)
400    }
401
402    /// Reads raw bytes of the given length.
403    pub fn read_raw(
404        &mut self,
405        len: usize,
406    ) -> Result<&'a [u8], SleError> {
407        if self.remaining() < len {
408            return Err(SleError::Truncated);
409        }
410        let data = &self.buf[self.pos..self.pos + len];
411        self.pos += len;
412        Ok(data)
413    }
414
415    /// Reads an ASN.1 BOOLEAN.
416    pub fn read_bool(&mut self) -> Result<bool, SleError> {
417        let (tag, class, _, consumed) =
418            decode_tag(&self.buf[self.pos..])?;
419        if tag != tags::BOOLEAN || class != Class::Universal {
420            return Err(SleError::UnexpectedTag);
421        }
422        self.pos += consumed;
423        let len = self.read_length()?;
424        if len != 1 {
425            return Err(SleError::Truncated);
426        }
427        let val = self.buf[self.pos];
428        self.pos += 1;
429        Ok(val != 0)
430    }
431
432    /// Reads an ASN.1 INTEGER as i64.
433    pub fn read_integer(&mut self) -> Result<i64, SleError> {
434        let (tag, class, _, consumed) =
435            decode_tag(&self.buf[self.pos..])?;
436        if tag != tags::INTEGER || class != Class::Universal {
437            return Err(SleError::UnexpectedTag);
438        }
439        self.pos += consumed;
440        let len = self.read_length()?;
441        if len == 0 || len > 8 {
442            return Err(SleError::IntegerOverflow);
443        }
444        let data = self.read_raw(len)?;
445        Ok(decode_i64(data))
446    }
447
448    /// Reads an ASN.1 OCTET STRING, returning the raw bytes.
449    pub fn read_octet_string(
450        &mut self,
451    ) -> Result<&'a [u8], SleError> {
452        let (tag, class, _, consumed) =
453            decode_tag(&self.buf[self.pos..])?;
454        if tag != tags::OCTET_STRING || class != Class::Universal {
455            return Err(SleError::UnexpectedTag);
456        }
457        self.pos += consumed;
458        let len = self.read_length()?;
459        self.read_raw(len)
460    }
461
462    /// Reads an ASN.1 ENUMERATED as i64.
463    pub fn read_enum(&mut self) -> Result<i64, SleError> {
464        let (tag, class, _, consumed) =
465            decode_tag(&self.buf[self.pos..])?;
466        if tag != tags::ENUMERATED || class != Class::Universal {
467            return Err(SleError::UnexpectedTag);
468        }
469        self.pos += consumed;
470        let len = self.read_length()?;
471        if len == 0 || len > 8 {
472            return Err(SleError::IntegerOverflow);
473        }
474        let data = self.read_raw(len)?;
475        Ok(decode_i64(data))
476    }
477
478    /// Reads an ASN.1 NULL.
479    pub fn read_null(&mut self) -> Result<(), SleError> {
480        let (tag, class, _, consumed) =
481            decode_tag(&self.buf[self.pos..])?;
482        if tag != tags::NULL || class != Class::Universal {
483            return Err(SleError::UnexpectedTag);
484        }
485        self.pos += consumed;
486        let len = self.read_length()?;
487        if len != 0 {
488            return Err(SleError::Truncated);
489        }
490        Ok(())
491    }
492
493    /// Reads a SEQUENCE tag+length, returning the content length.
494    /// The caller should then read the contained elements.
495    pub fn read_sequence(&mut self) -> Result<usize, SleError> {
496        let (tag, class, constructed, consumed) =
497            decode_tag(&self.buf[self.pos..])?;
498        if tag != tags::SEQUENCE
499            || class != Class::Universal
500            || !constructed
501        {
502            return Err(SleError::UnexpectedTag);
503        }
504        self.pos += consumed;
505        self.read_length()
506    }
507
508    /// Reads a context-tagged wrapper, returning the tag number
509    /// and content length.
510    pub fn read_context_tag(
511        &mut self,
512    ) -> Result<(u8, usize), SleError> {
513        let (tag, class, _, consumed) =
514            decode_tag(&self.buf[self.pos..])?;
515        if class != Class::Context {
516            return Err(SleError::UnexpectedTag);
517        }
518        self.pos += consumed;
519        let len = self.read_length()?;
520        Ok((tag, len))
521    }
522
523    /// Creates a sub-reader limited to `len` bytes from the
524    /// current position, advancing past them.
525    pub fn sub_reader(
526        &mut self,
527        len: usize,
528    ) -> Result<BerReader<'a>, SleError> {
529        if self.remaining() < len {
530            return Err(SleError::Truncated);
531        }
532        let sub = BerReader::new(
533            &self.buf[self.pos..self.pos + len],
534        );
535        self.pos += len;
536        Ok(sub)
537    }
538}
539
/// Minimal two's-complement content octets of an INTEGER/ENUMERATED.
///
/// Only `bytes[..len]` is meaningful. The tail is zero padding.
struct EncodedInt {
    /// Big-endian content octets, left-aligned.
    bytes: [u8; 8],
    /// Number of valid octets in `bytes` (1..=8).
    len: usize,
}

/// Encodes a signed i64 into the fewest BER integer octets.
///
/// The leading bits that merely repeat the sign are dropped, but one
/// sign bit is always kept. That is why 128 needs `00 80` while -128
/// fits in `80`.
fn encode_i64(value: i64) -> EncodedInt {
    let sign_copies = if value < 0 {
        value.leading_ones()
    } else {
        value.leading_zeros()
    };
    let len = (64 - sign_copies as usize) / 8 + 1;
    let mut bytes = [0u8; 8];
    bytes[..len].copy_from_slice(&value.to_be_bytes()[8 - len..]);
    EncodedInt { bytes, len }
}

/// Decodes big-endian two's-complement octets into an i64.
///
/// An empty slice decodes to 0. The caller must pass at most 8 octets,
/// otherwise the high bits are shifted out.
fn decode_i64(data: &[u8]) -> i64 {
    // Sign-extend from the first octet's top bit.
    let seed: i64 = match data.first() {
        Some(&first) if first & 0x80 != 0 => -1,
        _ => 0,
    };
    data.iter()
        .fold(seed, |acc, &octet| (acc << 8) | i64::from(octet))
}
578
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip_integer() {
        let samples = [
            0i64, 1, -1, 127, 128, -128, -129, 256, i64::MAX, i64::MIN, 0x7FFF,
            -32768,
        ];
        for &v in samples.iter() {
            let mut buf = [0u8; 32];
            let mut w = BerWriter::new(&mut buf);
            w.write_integer(v).unwrap();
            let mut r = BerReader::new(w.as_bytes());
            let got = r.read_integer().unwrap();
            assert_eq!(got, v, "failed for {}", v);
            assert!(r.is_empty());
        }
    }

    #[test]
    fn integer_uses_minimal_content_octets() {
        let cases: [(i64, &[u8]); 5] = [
            (0, &[0x02u8, 0x01, 0x00][..]),
            (127, &[0x02u8, 0x01, 0x7F][..]),
            (128, &[0x02u8, 0x02, 0x00, 0x80][..]),
            (-128, &[0x02u8, 0x01, 0x80][..]),
            (-129, &[0x02u8, 0x02, 0xFF, 0x7F][..]),
        ];
        for &(v, expected) in cases.iter() {
            let mut buf = [0u8; 16];
            let mut w = BerWriter::new(&mut buf);
            w.write_integer(v).unwrap();
            assert_eq!(w.as_bytes(), expected, "value {}", v);
        }
    }

    #[test]
    fn roundtrip_bool() {
        let mut buf = [0u8; 16];
        let mut w = BerWriter::new(&mut buf);
        w.write_bool(true).unwrap();
        w.write_bool(false).unwrap();
        let mut r = BerReader::new(w.as_bytes());
        assert!(r.read_bool().unwrap());
        assert!(!r.read_bool().unwrap());
    }

    #[test]
    fn roundtrip_octet_string() {
        let data = b"hello SLE";
        let mut buf = [0u8; 32];
        let mut w = BerWriter::new(&mut buf);
        w.write_octet_string(data).unwrap();
        let mut r = BerReader::new(w.as_bytes());
        let got = r.read_octet_string().unwrap();
        assert_eq!(got, data);
    }

    #[test]
    fn roundtrip_enum() {
        let mut buf = [0u8; 16];
        let mut w = BerWriter::new(&mut buf);
        w.write_enum(2).unwrap();
        let mut r = BerReader::new(w.as_bytes());
        assert_eq!(r.read_enum().unwrap(), 2);
    }

    #[test]
    fn roundtrip_null() {
        let mut buf = [0u8; 8];
        let mut w = BerWriter::new(&mut buf);
        w.write_null().unwrap();
        let mut r = BerReader::new(w.as_bytes());
        r.read_null().unwrap();
    }

    #[test]
    fn bit_string_has_zero_unused_bits_prefix() {
        let mut buf = [0u8; 8];
        let mut w = BerWriter::new(&mut buf);
        w.write_bit_string(&[0xA5]).unwrap();
        assert_eq!(w.as_bytes(), &[0x03u8, 0x02, 0x00, 0xA5][..]);
    }

    #[test]
    fn roundtrip_sequence() {
        let mut buf = [0u8; 64];
        let mut w = BerWriter::new(&mut buf);
        let seq = w.begin_sequence().unwrap();
        w.write_integer(42).unwrap();
        w.write_octet_string(b"test").unwrap();
        w.end_sequence(seq).unwrap();

        let mut r = BerReader::new(w.as_bytes());
        let _seq_len = r.read_sequence().unwrap();
        assert_eq!(r.read_integer().unwrap(), 42);
        assert_eq!(r.read_octet_string().unwrap(), b"test");
    }

    #[test]
    fn nested_sequences_are_patched_independently() {
        let mut buf = [0u8; 32];
        let mut w = BerWriter::new(&mut buf);
        let outer = w.begin_sequence().unwrap();
        let inner = w.begin_sequence().unwrap();
        w.write_integer(1).unwrap();
        w.end_sequence(inner).unwrap();
        w.write_bool(true).unwrap();
        w.end_sequence(outer).unwrap();
        assert_eq!(
            w.as_bytes(),
            &[0x30u8, 0x08, 0x30, 0x03, 0x02, 0x01, 0x01, 0x01, 0x01, 0xFF][..]
        );
    }

    #[test]
    fn sequence_uses_one_byte_long_form() {
        let payload = [0xABu8; 200];
        let mut buf = [0u8; 256];
        let mut w = BerWriter::new(&mut buf);
        let seq = w.begin_sequence().unwrap();
        w.write_octet_string(&payload).unwrap();
        w.end_sequence(seq).unwrap();
        // Content: 0x04 0x81 0xC8 + 200 bytes = 203 bytes.
        assert_eq!(&w.as_bytes()[..3], &[0x30u8, 0x81, 203][..]);
        assert_eq!(w.len(), 3 + 203);

        let mut r = BerReader::new(w.as_bytes());
        assert_eq!(r.read_sequence().unwrap(), 203);
        assert_eq!(r.read_octet_string().unwrap(), &payload[..]);
        assert!(r.is_empty());
    }

    #[test]
    fn sequence_uses_two_byte_long_form() {
        let payload = [0x5Au8; 300];
        let mut buf = [0u8; 512];
        let mut w = BerWriter::new(&mut buf);
        let seq = w.begin_sequence().unwrap();
        w.write_octet_string(&payload).unwrap();
        w.end_sequence(seq).unwrap();
        // Content: 0x04 0x82 0x01 0x2C + 300 bytes = 304 = 0x0130 bytes.
        assert_eq!(&w.as_bytes()[..4], &[0x30u8, 0x82, 0x01, 0x30][..]);
        assert_eq!(w.len(), 4 + 304);

        let mut r = BerReader::new(w.as_bytes());
        assert_eq!(r.read_sequence().unwrap(), 304);
        assert_eq!(r.read_octet_string().unwrap(), &payload[..]);
        assert!(r.is_empty());
    }

    #[test]
    fn roundtrip_context_tag() {
        let mut buf = [0u8; 16];
        let mut w = BerWriter::new(&mut buf);
        let ctx = w.begin_context(3, true).unwrap();
        w.write_null().unwrap();
        w.end_sequence(ctx).unwrap();
        assert_eq!(w.as_bytes(), &[0xA3u8, 0x02, 0x05, 0x00][..]);

        let mut r = BerReader::new(w.as_bytes());
        assert_eq!(r.read_context_tag().unwrap(), (3, 2));
        r.read_null().unwrap();
        assert!(r.is_empty());
    }

    #[test]
    fn sub_reader_limits_to_sequence_content() {
        let mut buf = [0u8; 32];
        let mut w = BerWriter::new(&mut buf);
        let seq = w.begin_sequence().unwrap();
        w.write_integer(5).unwrap();
        w.write_enum(1).unwrap();
        w.end_sequence(seq).unwrap();
        // A trailing element outside the sequence.
        w.write_null().unwrap();

        let mut r = BerReader::new(w.as_bytes());
        let len = r.read_sequence().unwrap();
        let mut inner = r.sub_reader(len).unwrap();
        assert_eq!(inner.read_integer().unwrap(), 5);
        assert_eq!(inner.read_enum().unwrap(), 1);
        assert!(inner.is_empty());
        r.read_null().unwrap();
        assert!(r.is_empty());
    }

    #[test]
    fn mismatched_tag_is_rejected_without_consuming() {
        let mut buf = [0u8; 8];
        let mut w = BerWriter::new(&mut buf);
        w.write_null().unwrap();
        let mut r = BerReader::new(w.as_bytes());
        assert!(r.read_integer().is_err());
        assert!(r.read_sequence().is_err());
        assert_eq!(r.pos(), 0);
        r.read_null().unwrap();
    }

    #[test]
    fn malformed_inputs_are_errors() {
        // INTEGER with empty content.
        assert!(BerReader::new(&[0x02u8, 0x00]).read_integer().is_err());
        // INTEGER with nine content octets does not fit in i64.
        assert!(BerReader::new(&[0x02u8, 0x09, 0, 0, 0, 0, 0, 0, 0, 0, 0])
            .read_integer()
            .is_err());
        // OCTET STRING claiming more bytes than are present.
        assert!(BerReader::new(&[0x04u8, 0x05, 0x01])
            .read_octet_string()
            .is_err());
        // Indefinite length and empty input.
        assert!(decode_length(&[0x80u8]).is_err());
        assert!(decode_length(&[0u8; 0]).is_err());
        // The high-tag-number identifier form is not supported.
        assert!(decode_tag(&[0x1Fu8, 0x64]).is_err());
    }

    #[test]
    fn writer_reports_buffer_too_small() {
        // INTEGER 1000 needs 4 bytes.
        let mut buf = [0u8; 3];
        let mut w = BerWriter::new(&mut buf);
        assert!(w.write_integer(1000).is_err());
        // A SEQUENCE needs the tag plus 3 reserved length bytes.
        let mut tiny = [0u8; 3];
        let mut w2 = BerWriter::new(&mut tiny);
        assert!(w2.begin_sequence().is_err());
    }

    #[test]
    fn encode_decode_length_short() {
        let mut buf = [0u8; 8];
        let n = encode_length(42, &mut buf).unwrap();
        assert_eq!(n, 1);
        let (val, consumed) = decode_length(&buf).unwrap();
        assert_eq!(val, 42);
        assert_eq!(consumed, 1);
    }

    #[test]
    fn encode_decode_length_long() {
        let mut buf = [0u8; 8];
        let n = encode_length(300, &mut buf).unwrap();
        assert_eq!(n, 3);
        let (val, consumed) = decode_length(&buf).unwrap();
        assert_eq!(val, 300);
        assert_eq!(consumed, 3);
    }

    #[test]
    fn encode_length_boundaries() {
        let cases: [(usize, usize); 7] = [
            (0, 1),
            (127, 1),
            (128, 2),
            (255, 2),
            (256, 3),
            (0xFFFF, 3),
            (0x1_0000, 4),
        ];
        for &(len, width) in cases.iter() {
            let mut buf = [0u8; 8];
            assert_eq!(encode_length(len, &mut buf).unwrap(), width, "len {}", len);
            assert_eq!(decode_length(&buf).unwrap(), (len, width), "len {}", len);
        }
    }

    #[test]
    fn peek_tag_does_not_advance() {
        let mut buf = [0u8; 8];
        let mut w = BerWriter::new(&mut buf);
        w.write_integer(7).unwrap();
        let r = BerReader::new(w.as_bytes());
        let (tag, class, _) = r.peek_tag().unwrap();
        assert_eq!(tag, tags::INTEGER);
        assert_eq!(class, Class::Universal);
        assert_eq!(r.pos(), 0);
    }
}