1use super::types::SleError;
9
/// Tag class of a BER identifier octet.
///
/// The discriminant of each variant is the numeric class value (0 through 3).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum Class {
    /// Class value 0.
    Universal = 0,
    /// Class value 1.
    Application = 1,
    /// Class value 2.
    Context = 2,
    /// Class value 3.
    Private = 3,
}

impl Class {
    /// Converts a numeric class value into its variant.
    ///
    /// `0`, `1` and `2` select `Universal`, `Application` and `Context`; every other
    /// input (including `3`) yields `Private`.
    fn from_bits(bits: u8) -> Self {
        // Lookup table for the three explicitly matched values; anything past the end
        // of the table falls back to `Private`.
        const BY_VALUE: [Class; 3] = [Class::Universal, Class::Application, Class::Context];
        BY_VALUE
            .get(usize::from(bits))
            .copied()
            .unwrap_or(Class::Private)
    }
}
34
/// Tag numbers for the element types handled by this module.
pub mod tags {
    /// Tag number written/expected for BOOLEAN elements.
    pub const BOOLEAN: u8 = 0x01;
    /// Tag number written/expected for INTEGER elements.
    pub const INTEGER: u8 = 0x02;
    /// Tag number written for BIT STRING elements.
    pub const BIT_STRING: u8 = 0x03;
    /// Tag number written/expected for OCTET STRING elements.
    pub const OCTET_STRING: u8 = 0x04;
    /// Tag number written/expected for NULL elements.
    pub const NULL: u8 = 0x05;
    /// Tag number written/expected for ENUMERATED elements.
    pub const ENUMERATED: u8 = 0x0A;
    /// Tag number written/expected for SEQUENCE elements.
    pub const SEQUENCE: u8 = 0x10;
}
52
53pub fn encode_tag(tag: u8, class: Class, constructed: bool) -> u8 {
58 let class_bits = (class as u8) << 6;
59 let constructed_bit = if constructed { 0x20 } else { 0x00 };
60 class_bits | constructed_bit | (tag & 0x1F)
61}
62
63pub fn encode_length(len: usize, buf: &mut [u8]) -> Result<usize, SleError> {
68 if len < 128 {
69 if buf.is_empty() {
70 return Err(SleError::BufferTooSmall);
71 }
72 buf[0] = len as u8;
73 Ok(1)
74 } else if len <= 0xFF {
75 if buf.len() < 2 {
76 return Err(SleError::BufferTooSmall);
77 }
78 buf[0] = 0x81;
79 buf[1] = len as u8;
80 Ok(2)
81 } else if len <= 0xFFFF {
82 if buf.len() < 3 {
83 return Err(SleError::BufferTooSmall);
84 }
85 buf[0] = 0x82;
86 buf[1] = (len >> 8) as u8;
87 buf[2] = len as u8;
88 Ok(3)
89 } else if len <= 0xFF_FFFF {
90 if buf.len() < 4 {
91 return Err(SleError::BufferTooSmall);
92 }
93 buf[0] = 0x83;
94 buf[1] = (len >> 16) as u8;
95 buf[2] = (len >> 8) as u8;
96 buf[3] = len as u8;
97 Ok(4)
98 } else {
99 if buf.len() < 5 {
100 return Err(SleError::BufferTooSmall);
101 }
102 buf[0] = 0x84;
103 buf[1] = (len >> 24) as u8;
104 buf[2] = (len >> 16) as u8;
105 buf[3] = (len >> 8) as u8;
106 buf[4] = len as u8;
107 Ok(5)
108 }
109}
110
111pub fn decode_tag(
116 buf: &[u8],
117) -> Result<(u8, Class, bool, usize), SleError> {
118 if buf.is_empty() {
119 return Err(SleError::Truncated);
120 }
121 let b = buf[0];
122 let class = Class::from_bits(b >> 6);
123 let constructed = (b & 0x20) != 0;
124 let tag = b & 0x1F;
125 if tag == 0x1F {
126 return Err(SleError::UnexpectedTag);
127 }
128 Ok((tag, class, constructed, 1))
129}
130
131pub fn decode_length(buf: &[u8]) -> Result<(usize, usize), SleError> {
135 if buf.is_empty() {
136 return Err(SleError::Truncated);
137 }
138 let first = buf[0];
139 if first < 128 {
140 Ok((first as usize, 1))
141 } else {
142 let num_bytes = (first & 0x7F) as usize;
143 if num_bytes == 0 || num_bytes > 4 {
144 return Err(SleError::Truncated);
145 }
146 if buf.len() < 1 + num_bytes {
147 return Err(SleError::Truncated);
148 }
149 let mut val: usize = 0;
150 for i in 0..num_bytes {
151 val = (val << 8) | buf[1 + i] as usize;
152 }
153 Ok((val, 1 + num_bytes))
154 }
155}
156
157pub struct BerWriter<'a> {
159 buf: &'a mut [u8],
160 pos: usize,
161}
162
163impl<'a> BerWriter<'a> {
164 pub fn new(buf: &'a mut [u8]) -> Self {
166 Self { buf, pos: 0 }
167 }
168
169 pub fn len(&self) -> usize {
171 self.pos
172 }
173
174 pub fn is_empty(&self) -> bool {
176 self.pos == 0
177 }
178
179 pub fn as_bytes(&self) -> &[u8] {
181 &self.buf[..self.pos]
182 }
183
184 fn remaining(&self) -> usize {
186 self.buf.len() - self.pos
187 }
188
189 fn write_byte(&mut self, b: u8) -> Result<(), SleError> {
190 if self.remaining() == 0 {
191 return Err(SleError::BufferTooSmall);
192 }
193 self.buf[self.pos] = b;
194 self.pos += 1;
195 Ok(())
196 }
197
198 fn write_bytes(&mut self, data: &[u8]) -> Result<(), SleError> {
199 if self.remaining() < data.len() {
200 return Err(SleError::BufferTooSmall);
201 }
202 self.buf[self.pos..self.pos + data.len()].copy_from_slice(data);
203 self.pos += data.len();
204 Ok(())
205 }
206
207 fn write_tag(
208 &mut self,
209 tag: u8,
210 class: Class,
211 constructed: bool,
212 ) -> Result<(), SleError> {
213 self.write_byte(encode_tag(tag, class, constructed))
214 }
215
216 fn write_length(&mut self, len: usize) -> Result<(), SleError> {
217 let n = encode_length(len, &mut self.buf[self.pos..])?;
218 self.pos += n;
219 Ok(())
220 }
221
222 pub fn write_bool(&mut self, value: bool) -> Result<(), SleError> {
224 self.write_tag(tags::BOOLEAN, Class::Universal, false)?;
225 self.write_length(1)?;
226 self.write_byte(if value { 0xFF } else { 0x00 })
227 }
228
229 pub fn write_integer(&mut self, value: i64) -> Result<(), SleError> {
231 self.write_tag(tags::INTEGER, Class::Universal, false)?;
232 let encoded = encode_i64(value);
233 self.write_length(encoded.len)?;
234 self.write_bytes(&encoded.bytes[..encoded.len])
235 }
236
237 pub fn write_octet_string(
239 &mut self,
240 data: &[u8],
241 ) -> Result<(), SleError> {
242 self.write_tag(tags::OCTET_STRING, Class::Universal, false)?;
243 self.write_length(data.len())?;
244 self.write_bytes(data)
245 }
246
247 pub fn write_enum(&mut self, value: i64) -> Result<(), SleError> {
249 self.write_tag(tags::ENUMERATED, Class::Universal, false)?;
250 let encoded = encode_i64(value);
251 self.write_length(encoded.len)?;
252 self.write_bytes(&encoded.bytes[..encoded.len])
253 }
254
255 pub fn write_null(&mut self) -> Result<(), SleError> {
257 self.write_tag(tags::NULL, Class::Universal, false)?;
258 self.write_length(0)
259 }
260
261 pub fn write_bit_string(
263 &mut self,
264 data: &[u8],
265 ) -> Result<(), SleError> {
266 self.write_tag(tags::BIT_STRING, Class::Universal, false)?;
267 self.write_length(data.len() + 1)?;
268 self.write_byte(0)?; self.write_bytes(data)
270 }
271
272 pub fn begin_sequence(&mut self) -> Result<SeqStart, SleError> {
275 self.write_tag(tags::SEQUENCE, Class::Universal, true)?;
276 let len_pos = self.pos;
277 if self.remaining() < 3 {
280 return Err(SleError::BufferTooSmall);
281 }
282 self.pos += 3;
283 Ok(SeqStart {
284 len_pos,
285 content_start: self.pos,
286 })
287 }
288
289 pub fn end_sequence(
292 &mut self,
293 start: SeqStart,
294 ) -> Result<(), SleError> {
295 let content_len = self.pos - start.content_start;
296 if content_len < 128 {
297 let src = start.content_start;
299 let dst = start.len_pos + 1;
300 self.buf.copy_within(src..self.pos, dst);
301 self.buf[start.len_pos] = content_len as u8;
302 self.pos -= 2;
303 } else if content_len <= 0xFF {
304 let src = start.content_start;
306 let dst = start.len_pos + 2;
307 self.buf.copy_within(src..self.pos, dst);
308 self.buf[start.len_pos] = 0x81;
309 self.buf[start.len_pos + 1] = content_len as u8;
310 self.pos -= 1;
311 } else {
312 self.buf[start.len_pos] = 0x82;
314 self.buf[start.len_pos + 1] = (content_len >> 8) as u8;
315 self.buf[start.len_pos + 2] = content_len as u8;
316 }
317 Ok(())
318 }
319
320 pub fn begin_context(
323 &mut self,
324 tag: u8,
325 constructed: bool,
326 ) -> Result<SeqStart, SleError> {
327 self.write_tag(tag, Class::Context, constructed)?;
328 let len_pos = self.pos;
329 if self.remaining() < 3 {
330 return Err(SleError::BufferTooSmall);
331 }
332 self.pos += 3;
333 Ok(SeqStart {
334 len_pos,
335 content_start: self.pos,
336 })
337 }
338}
339
/// Marker returned when a length-deferred element is opened on a `BerWriter`.
///
/// It records where the element's length bytes and content live in the writer's buffer
/// so the length can be filled in once the content has been written.
#[derive(Clone, Copy, Debug)]
pub struct SeqStart {
    /// Buffer offset of the first reserved length byte (directly after the identifier).
    len_pos: usize,
    /// Buffer offset at which the element's content begins.
    content_start: usize,
}
346
347pub struct BerReader<'a> {
349 buf: &'a [u8],
350 pos: usize,
351}
352
353impl<'a> BerReader<'a> {
354 pub fn new(buf: &'a [u8]) -> Self {
356 Self { buf, pos: 0 }
357 }
358
359 pub fn pos(&self) -> usize {
361 self.pos
362 }
363
364 pub fn remaining(&self) -> usize {
366 self.buf.len() - self.pos
367 }
368
369 pub fn is_empty(&self) -> bool {
371 self.remaining() == 0
372 }
373
374 pub fn peek_tag(
377 &self,
378 ) -> Result<(u8, Class, bool), SleError> {
379 let (tag, class, constructed, _) =
380 decode_tag(&self.buf[self.pos..])?;
381 Ok((tag, class, constructed))
382 }
383
384 pub fn read_tag(
386 &mut self,
387 ) -> Result<(u8, Class, bool), SleError> {
388 let (tag, class, constructed, consumed) =
389 decode_tag(&self.buf[self.pos..])?;
390 self.pos += consumed;
391 Ok((tag, class, constructed))
392 }
393
394 pub fn read_length(&mut self) -> Result<usize, SleError> {
396 let (len, consumed) =
397 decode_length(&self.buf[self.pos..])?;
398 self.pos += consumed;
399 Ok(len)
400 }
401
402 pub fn read_raw(
404 &mut self,
405 len: usize,
406 ) -> Result<&'a [u8], SleError> {
407 if self.remaining() < len {
408 return Err(SleError::Truncated);
409 }
410 let data = &self.buf[self.pos..self.pos + len];
411 self.pos += len;
412 Ok(data)
413 }
414
415 pub fn read_bool(&mut self) -> Result<bool, SleError> {
417 let (tag, class, _, consumed) =
418 decode_tag(&self.buf[self.pos..])?;
419 if tag != tags::BOOLEAN || class != Class::Universal {
420 return Err(SleError::UnexpectedTag);
421 }
422 self.pos += consumed;
423 let len = self.read_length()?;
424 if len != 1 {
425 return Err(SleError::Truncated);
426 }
427 let val = self.buf[self.pos];
428 self.pos += 1;
429 Ok(val != 0)
430 }
431
432 pub fn read_integer(&mut self) -> Result<i64, SleError> {
434 let (tag, class, _, consumed) =
435 decode_tag(&self.buf[self.pos..])?;
436 if tag != tags::INTEGER || class != Class::Universal {
437 return Err(SleError::UnexpectedTag);
438 }
439 self.pos += consumed;
440 let len = self.read_length()?;
441 if len == 0 || len > 8 {
442 return Err(SleError::IntegerOverflow);
443 }
444 let data = self.read_raw(len)?;
445 Ok(decode_i64(data))
446 }
447
448 pub fn read_octet_string(
450 &mut self,
451 ) -> Result<&'a [u8], SleError> {
452 let (tag, class, _, consumed) =
453 decode_tag(&self.buf[self.pos..])?;
454 if tag != tags::OCTET_STRING || class != Class::Universal {
455 return Err(SleError::UnexpectedTag);
456 }
457 self.pos += consumed;
458 let len = self.read_length()?;
459 self.read_raw(len)
460 }
461
462 pub fn read_enum(&mut self) -> Result<i64, SleError> {
464 let (tag, class, _, consumed) =
465 decode_tag(&self.buf[self.pos..])?;
466 if tag != tags::ENUMERATED || class != Class::Universal {
467 return Err(SleError::UnexpectedTag);
468 }
469 self.pos += consumed;
470 let len = self.read_length()?;
471 if len == 0 || len > 8 {
472 return Err(SleError::IntegerOverflow);
473 }
474 let data = self.read_raw(len)?;
475 Ok(decode_i64(data))
476 }
477
478 pub fn read_null(&mut self) -> Result<(), SleError> {
480 let (tag, class, _, consumed) =
481 decode_tag(&self.buf[self.pos..])?;
482 if tag != tags::NULL || class != Class::Universal {
483 return Err(SleError::UnexpectedTag);
484 }
485 self.pos += consumed;
486 let len = self.read_length()?;
487 if len != 0 {
488 return Err(SleError::Truncated);
489 }
490 Ok(())
491 }
492
493 pub fn read_sequence(&mut self) -> Result<usize, SleError> {
496 let (tag, class, constructed, consumed) =
497 decode_tag(&self.buf[self.pos..])?;
498 if tag != tags::SEQUENCE
499 || class != Class::Universal
500 || !constructed
501 {
502 return Err(SleError::UnexpectedTag);
503 }
504 self.pos += consumed;
505 self.read_length()
506 }
507
508 pub fn read_context_tag(
511 &mut self,
512 ) -> Result<(u8, usize), SleError> {
513 let (tag, class, _, consumed) =
514 decode_tag(&self.buf[self.pos..])?;
515 if class != Class::Context {
516 return Err(SleError::UnexpectedTag);
517 }
518 self.pos += consumed;
519 let len = self.read_length()?;
520 Ok((tag, len))
521 }
522
523 pub fn sub_reader(
526 &mut self,
527 len: usize,
528 ) -> Result<BerReader<'a>, SleError> {
529 if self.remaining() < len {
530 return Err(SleError::Truncated);
531 }
532 let sub = BerReader::new(
533 &self.buf[self.pos..self.pos + len],
534 );
535 self.pos += len;
536 Ok(sub)
537 }
538}
539
/// Fixed-size holder for the variable-length encoding of an integer.
struct EncodedInt {
    /// Encoded bytes, left-aligned; only the first `len` entries are meaningful and the
    /// rest are zero.
    bytes: [u8; 8],
    /// Number of meaningful bytes in `bytes` (1 to 8).
    len: usize,
}

/// Produces the shortest big-endian two's-complement representation of `value`.
///
/// A leading byte is dropped while it is pure sign extension: `0x00` followed by a byte
/// with the top bit clear for non-negative values, or `0xFF` followed by a byte with the
/// top bit set for negative ones. At least one byte always remains.
fn encode_i64(value: i64) -> EncodedInt {
    let raw = value.to_be_bytes();
    let filler: u8 = if value < 0 { 0xFF } else { 0x00 };

    // Each window is (candidate byte, the byte that would become the new leading byte).
    let redundant = raw
        .windows(2)
        .take_while(|pair| pair[0] == filler && (pair[1] & 0x80) == (filler & 0x80))
        .count();

    let len = raw.len() - redundant;
    let mut bytes = [0u8; 8];
    bytes[..len].copy_from_slice(&raw[redundant..]);
    EncodedInt { bytes, len }
}
568
/// Interprets `data` as a big-endian two's-complement integer.
///
/// The accumulator starts at `-1` when the first byte has its top bit set (so the
/// result is sign-extended) and at `0` otherwise; each byte is then shifted in from the
/// right. An empty slice decodes to `0`.
fn decode_i64(data: &[u8]) -> i64 {
    let seed: i64 = match data.first() {
        Some(&msb) if msb & 0x80 != 0 => -1,
        _ => 0,
    };
    data.iter()
        .fold(seed, |acc, &octet| (acc << 8) | i64::from(octet))
}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// INTEGER values around every byte-width boundary survive a write/read cycle.
    #[test]
    fn roundtrip_integer() {
        for &v in &[
            0i64,
            1,
            -1,
            127,
            128,
            -128,
            -129,
            256,
            i64::MAX,
            i64::MIN,
            0x7FFF,
            -32768,
        ] {
            let mut buf = [0u8; 32];
            let mut w = BerWriter::new(&mut buf);
            w.write_integer(v).unwrap();
            let mut r = BerReader::new(w.as_bytes());
            let got = r.read_integer().unwrap();
            assert_eq!(got, v, "failed for {v}");
            assert!(r.is_empty(), "trailing bytes after {v}");
        }
    }

    #[test]
    fn roundtrip_bool() {
        let mut buf = [0u8; 16];
        let mut w = BerWriter::new(&mut buf);
        w.write_bool(true).unwrap();
        w.write_bool(false).unwrap();
        let mut r = BerReader::new(w.as_bytes());
        assert!(r.read_bool().unwrap());
        assert!(!r.read_bool().unwrap());
        assert!(r.is_empty());
    }

    #[test]
    fn roundtrip_octet_string() {
        let data = b"hello SLE";
        let mut buf = [0u8; 32];
        let mut w = BerWriter::new(&mut buf);
        w.write_octet_string(data).unwrap();
        let mut r = BerReader::new(w.as_bytes());
        let got = r.read_octet_string().unwrap();
        assert_eq!(got, data);
    }

    #[test]
    fn roundtrip_enum() {
        let mut buf = [0u8; 16];
        let mut w = BerWriter::new(&mut buf);
        w.write_enum(2).unwrap();
        let mut r = BerReader::new(w.as_bytes());
        assert_eq!(r.read_enum().unwrap(), 2);
    }

    #[test]
    fn roundtrip_null() {
        let mut buf = [0u8; 8];
        let mut w = BerWriter::new(&mut buf);
        w.write_null().unwrap();
        let mut r = BerReader::new(w.as_bytes());
        r.read_null().unwrap();
        assert!(r.is_empty());
    }

    /// The bit string content is a zero unused-bits byte followed by the data.
    #[test]
    fn bit_string_encoding() {
        let mut buf = [0u8; 8];
        let mut w = BerWriter::new(&mut buf);
        w.write_bit_string(&[0xAA]).unwrap();
        assert_eq!(w.as_bytes(), &[0x03, 0x02, 0x00, 0xAA]);
    }

    /// The back-patched length must describe exactly the bytes that follow the header.
    #[test]
    fn roundtrip_sequence() {
        let mut buf = [0u8; 64];
        let mut w = BerWriter::new(&mut buf);
        let seq = w.begin_sequence().unwrap();
        w.write_integer(42).unwrap();
        w.write_octet_string(b"test").unwrap();
        w.end_sequence(seq).unwrap();

        let mut r = BerReader::new(w.as_bytes());
        let seq_len = r.read_sequence().unwrap();
        assert_eq!(seq_len, r.remaining());
        assert_eq!(r.read_integer().unwrap(), 42);
        assert_eq!(r.read_octet_string().unwrap(), b"test");
        assert!(r.is_empty());
    }

    /// Content of 128..=255 bytes and of 256+ bytes exercises the two long-form
    /// back-patch paths of `end_sequence`.
    #[test]
    fn roundtrip_sequence_long_form_lengths() {
        let backing = [0xA5u8; 300];
        for &n in &[150usize, 300] {
            let payload = &backing[..n];
            let mut buf = [0u8; 320];
            let mut w = BerWriter::new(&mut buf);
            let seq = w.begin_sequence().unwrap();
            w.write_octet_string(payload).unwrap();
            w.end_sequence(seq).unwrap();

            let mut r = BerReader::new(w.as_bytes());
            let seq_len = r.read_sequence().unwrap();
            assert_eq!(seq_len, r.remaining(), "sequence length for payload {n}");
            assert_eq!(r.read_octet_string().unwrap(), payload);
            assert!(r.is_empty());
        }
    }

    #[test]
    fn roundtrip_context_tag() {
        let mut buf = [0u8; 16];
        let mut w = BerWriter::new(&mut buf);
        let ctx = w.begin_context(1, true).unwrap();
        w.write_integer(5).unwrap();
        w.end_sequence(ctx).unwrap();
        assert_eq!(w.as_bytes(), &[0xA1, 0x03, 0x02, 0x01, 0x05]);

        let mut r = BerReader::new(w.as_bytes());
        assert_eq!(r.read_context_tag().unwrap(), (1, 3));
        assert_eq!(r.read_integer().unwrap(), 5);
        assert!(r.is_empty());
    }

    #[test]
    fn encode_decode_length_short() {
        let mut buf = [0u8; 8];
        let n = encode_length(42, &mut buf).unwrap();
        assert_eq!(n, 1);
        let (val, consumed) = decode_length(&buf).unwrap();
        assert_eq!(val, 42);
        assert_eq!(consumed, 1);
    }

    #[test]
    fn encode_decode_length_long() {
        let mut buf = [0u8; 8];
        let n = encode_length(300, &mut buf).unwrap();
        assert_eq!(n, 3);
        let (val, consumed) = decode_length(&buf).unwrap();
        assert_eq!(val, 300);
        assert_eq!(consumed, 3);
    }

    #[test]
    fn peek_tag_does_not_advance() {
        let mut buf = [0u8; 8];
        let mut w = BerWriter::new(&mut buf);
        w.write_integer(7).unwrap();
        let r = BerReader::new(w.as_bytes());
        let (tag, class, _) = r.peek_tag().unwrap();
        assert_eq!(tag, tags::INTEGER);
        assert_eq!(class, Class::Universal);
        assert_eq!(r.pos(), 0);
    }

    #[test]
    fn write_reports_buffer_too_small() {
        let mut buf = [0u8; 2];
        let mut w = BerWriter::new(&mut buf);
        assert!(matches!(
            w.write_integer(1),
            Err(SleError::BufferTooSmall)
        ));
    }

    #[test]
    fn read_rejects_wrong_tag_and_truncation() {
        let mut r = BerReader::new(&[0x04, 0x00]);
        assert!(matches!(r.read_integer(), Err(SleError::UnexpectedTag)));

        let mut r = BerReader::new(&[0x02, 0x02, 0x01]);
        assert!(matches!(r.read_integer(), Err(SleError::Truncated)));

        let mut r = BerReader::new(&[]);
        assert!(matches!(r.read_tag(), Err(SleError::Truncated)));
    }
}