/// Number of image rows in one strip; image heights and segment sizes are
/// converted to and from strips using this value.
const STRIP_HEIGHT: usize = 8;

/// Number of wavelet decomposition levels applied to every segment; also
/// written into each segment header.
const DWT_LEVELS: usize = 3;
31
/// Errors reported by `compress` and `decompress`.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Returned when `Config::validate` rejects the dimensions, bit depth or
    /// segment size.
    #[error("Invalid compressor configuration")]
    InvalidConfig,
    /// Returned when the bit writer runs out of output bytes, or when the
    /// image buffer handed to `decompress` holds fewer than `width * height`
    /// samples.
    #[error("Output buffer too small to hold compressed data")]
    OutputFull,
    /// Returned when the bit reader runs out of input bytes, or when the image
    /// handed to `compress` holds fewer than `width * height` samples.
    #[error("Input data truncated or malformed")]
    Truncated,
    /// Returned when the scratch slice is shorter than
    /// `scratch_len(width, seg_height)`.
    #[error("Provided scratch buffer is too small for required temporary storage")]
    ScratchTooSmall,
}
48
/// Image geometry and sample format used by `compress` and returned by
/// `decompress`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Image width in samples; validation requires a non-zero multiple of 8.
    pub width: u16,
    /// Image height in rows; validation requires a non-zero multiple of 8.
    pub height: u16,
    /// Sample bit depth (presumably "bits per sample" -- the abbreviation is
    /// not expanded in this file); validation requires `2..=16`.
    pub bps: u8,
    /// Number of 8-row strips per segment; validation requires
    /// `1..=height / 8`.
    pub segment_strips: u16,
    /// Flag written as a single bit into every segment header by `compress`.
    pub signed_samples: bool,
}
63
64impl Config {
65 fn validate(&self) -> Result<(), Error> {
66 if self.width == 0 || self.height == 0 || self.width % 8 != 0 || self.height % 8 != 0 {
67 return Err(Error::InvalidConfig);
68 }
69 if self.bps < 2 || self.bps > 16 {
70 return Err(Error::InvalidConfig);
71 }
72 let max_strips = self.height as usize / STRIP_HEIGHT;
73 if self.segment_strips == 0 || self.segment_strips as usize > max_strips {
74 return Err(Error::InvalidConfig);
75 }
76 Ok(())
77 }
78
79 fn strips(&self) -> usize {
80 self.height as usize / STRIP_HEIGHT
81 }
82
83 fn seg_height(&self) -> usize {
84 self.segment_strips as usize * STRIP_HEIGHT
85 }
86}
87
/// Returns the number of `i32` slots a scratch buffer must provide for the
/// given image width and segment height: one `width * seg_height` coefficient
/// plane plus `width` additional slots.
pub fn scratch_len(width: usize, seg_height: usize) -> usize {
    let plane = width * seg_height;
    plane + width
}
94
95fn dwt53_forward_1d(data: &mut [i32], n: usize) {
103 if n < 2 {
104 return;
105 }
106
107 let half = n / 2;
112
113 for i in 0..half {
115 let left = data[2 * i];
116 let right = if 2 * i + 2 < n {
117 data[2 * i + 2]
118 } else {
119 data[2 * i] };
121 data[2 * i + 1] -= (left + right) / 2;
122 }
123
124 for i in 0..half {
126 let d_left = if i > 0 {
127 data[2 * (i - 1) + 1]
128 } else {
129 data[1] };
131 let d_right = data[2 * i + 1];
132 data[2 * i] += (d_left + d_right + 2) / 4;
133 }
134
135 deinterleave(data, n);
140}
141
142fn dwt53_inverse_1d(data: &mut [i32], n: usize) {
144 if n < 2 {
145 return;
146 }
147
148 let half = n / 2;
149
150 interleave(data, n);
152
153 for i in 0..half {
155 let d_left = if i > 0 {
156 data[2 * (i - 1) + 1]
157 } else {
158 data[1]
159 };
160 let d_right = data[2 * i + 1];
161 data[2 * i] -= (d_left + d_right + 2) / 4;
162 }
163
164 for i in 0..half {
166 let left = data[2 * i];
167 let right = if 2 * i + 2 < n {
168 data[2 * i + 2]
169 } else {
170 data[2 * i]
171 };
172 data[2 * i + 1] += (left + right) / 2;
173 }
174}
175
176fn deinterleave(data: &mut [i32], n: usize) {
178 let half = n / 2;
179 if half <= 1 {
204 return;
205 }
206 deinterleave_inplace(data, n);
207}
208
/// Regroups the first `2 * (n / 2)` entries of `data` in place so that the
/// even-indexed samples occupy the front half and the odd-indexed samples the
/// back half, each keeping its relative order. With an odd `n` the trailing
/// sample stays where it is.
///
/// The permutation is done by divide and conquer: both halves of the pair
/// sequence are regrouped recursively, which leaves the layout
/// `E_left O_left E_right O_right`; rotating the middle `O_left E_right`
/// span then yields `E_left E_right O_left O_right`. This needs no temporary
/// buffer, so there is no upper bound on `n` and no per-call stack array to
/// clear.
///
/// # Panics
///
/// Panics if `data` holds fewer than `2 * (n / 2)` entries.
fn deinterleave_inplace(data: &mut [i32], n: usize) {
    let half = n / 2;
    if half <= 1 {
        return;
    }

    let left_pairs = half / 2;
    let right_pairs = half - left_pairs;
    let split = 2 * left_pairs;

    let (left, right) = data[..2 * half].split_at_mut(split);
    deinterleave_inplace(left, split);
    deinterleave_inplace(right, 2 * right_pairs);

    // Middle span is O_left (left_pairs) followed by E_right (right_pairs).
    data[left_pairs..left_pairs + half].rotate_left(left_pairs);
}
255
/// Inverse of `deinterleave`: takes the first `2 * (n / 2)` entries of `data`,
/// laid out as all even-position samples followed by all odd-position samples,
/// and weaves them back into alternating order in place. With an odd `n` the
/// trailing sample stays where it is.
///
/// The permutation is done by divide and conquer: the layout
/// `E_left E_right O_left O_right` is first turned into
/// `E_left O_left E_right O_right` by rotating the middle span, after which
/// each half is woven recursively. This needs no temporary buffer, so there is
/// no upper bound on `n` and no per-call stack array to clear.
///
/// # Panics
///
/// Panics if `data` holds fewer than `2 * (n / 2)` entries.
fn interleave(data: &mut [i32], n: usize) {
    let half = n / 2;
    if half <= 1 {
        return;
    }

    let left_pairs = half / 2;
    let right_pairs = half - left_pairs;

    // Middle span is E_right (right_pairs) followed by O_left (left_pairs).
    data[left_pairs..left_pairs + half].rotate_left(right_pairs);

    let split = 2 * left_pairs;
    let (left, right) = data[..2 * half].split_at_mut(split);
    interleave(left, split);
    interleave(right, 2 * right_pairs);
}
276
277fn dwt53_forward_2d(coeffs: &mut [i32], stride: usize, w: usize, h: usize) {
282 for y in 0..h {
284 let row_start = y * stride;
285 dwt53_forward_1d(&mut coeffs[row_start..row_start + w], w);
286 }
287
288 const MAX_COL: usize = 8192;
290 let mut col = [0i32; MAX_COL];
291 for x in 0..w {
292 for y in 0..h {
293 col[y] = coeffs[y * stride + x];
294 }
295 dwt53_forward_1d(&mut col, h);
296 for y in 0..h {
297 coeffs[y * stride + x] = col[y];
298 }
299 }
300}
301
302fn dwt53_inverse_2d(coeffs: &mut [i32], stride: usize, w: usize, h: usize) {
304 const MAX_COL: usize = 8192;
306 let mut col = [0i32; MAX_COL];
307 for x in 0..w {
308 for y in 0..h {
309 col[y] = coeffs[y * stride + x];
310 }
311 dwt53_inverse_1d(&mut col, h);
312 for y in 0..h {
313 coeffs[y * stride + x] = col[y];
314 }
315 }
316
317 for y in 0..h {
319 let row_start = y * stride;
320 dwt53_inverse_1d(&mut coeffs[row_start..row_start + w], w);
321 }
322}
323
324fn dwt_forward_3level(coeffs: &mut [i32], stride: usize, w: usize, h: usize) {
326 let mut cw = w;
327 let mut ch = h;
328 for _ in 0..DWT_LEVELS {
329 dwt53_forward_2d(coeffs, stride, cw, ch);
330 cw /= 2;
331 ch /= 2;
332 }
333}
334
335fn dwt_inverse_3level(coeffs: &mut [i32], stride: usize, w: usize, h: usize) {
337 let mut sizes = [(0usize, 0usize); DWT_LEVELS];
338 let mut cw = w;
339 let mut ch = h;
340 for i in 0..DWT_LEVELS {
341 cw /= 2;
342 ch /= 2;
343 sizes[i] = (cw * 2, ch * 2);
344 }
345 for i in (0..DWT_LEVELS).rev() {
347 let (sw, sh) = sizes[i];
348 dwt53_inverse_2d(coeffs, stride, sw, sh);
349 }
350}
351
/// Most-significant-bit-first bit packer over a caller-provided byte buffer.
struct BitWriter<'a> {
    /// Destination bytes. Bits are OR-ed in, so the buffer must already be
    /// zeroed wherever bits will be written.
    buf: &'a mut [u8],
    /// Index of the byte currently being filled.
    pos: usize,
    /// Number of bits already written into `buf[pos]`; wraps to 0 after 8.
    bit: u32,
}
373
374impl<'a> BitWriter<'a> {
375 fn new(buf: &'a mut [u8]) -> Self {
376 Self {
377 buf,
378 pos: 0,
379 bit: 0,
380 }
381 }
382
383 fn write_bits(&mut self, value: u64, n: u32) -> Result<(), Error> {
384 for i in (0..n).rev() {
385 let b = ((value >> i) & 1) as u8;
386 if self.pos >= self.buf.len() {
387 return Err(Error::OutputFull);
388 }
389 self.buf[self.pos] |= b << (7 - self.bit);
390 self.bit += 1;
391 if self.bit == 8 {
392 self.bit = 0;
393 self.pos += 1;
394 }
395 }
396 Ok(())
397 }
398
399 fn flush(&mut self) -> Result<(), Error> {
400 if self.bit > 0 {
401 self.bit = 0;
402 self.pos += 1;
403 }
404 Ok(())
405 }
406
407 fn bytes_written(&self) -> usize {
408 if self.bit > 0 { self.pos + 1 } else { self.pos }
409 }
410}
411
/// Most-significant-bit-first bit reader over a byte slice.
struct BitReader<'a> {
    /// Source bytes.
    buf: &'a [u8],
    /// Index of the byte currently being consumed.
    pos: usize,
    /// Number of bits already consumed from `buf[pos]`; wraps to 0 after 8.
    bit: u32,
}
417
418impl<'a> BitReader<'a> {
419 fn new(buf: &'a [u8]) -> Self {
420 Self {
421 buf,
422 pos: 0,
423 bit: 0,
424 }
425 }
426
427 fn read_bits(&mut self, n: u32) -> Result<u64, Error> {
428 let mut val = 0u64;
429 for _ in 0..n {
430 if self.pos >= self.buf.len() {
431 return Err(Error::Truncated);
432 }
433 let b = (self.buf[self.pos] >> (7 - self.bit)) & 1;
434 val = (val << 1) | b as u64;
435 self.bit += 1;
436 if self.bit == 8 {
437 self.bit = 0;
438 self.pos += 1;
439 }
440 }
441 Ok(val)
442 }
443}
444
445fn write_segment_header(
447 w: &mut BitWriter,
448 cfg: &Config,
449 seg_idx: usize,
450 max_bitplane: u8,
451) -> Result<(), Error> {
452 let total_segs = (cfg.strips() + cfg.segment_strips as usize - 1) / cfg.segment_strips as usize;
461 let start = if seg_idx == 0 { 1u64 } else { 0 };
462 let end = if seg_idx == total_segs - 1 { 1u64 } else { 0 };
463
464 w.write_bits(start, 1)?;
465 w.write_bits(end, 1)?;
466 w.write_bits(seg_idx as u64 % 256, 8)?;
467 w.write_bits(max_bitplane as u64, 4)?;
468 w.write_bits(DWT_LEVELS as u64, 3)?;
469 let signed = if cfg.signed_samples { 1u64 } else { 0 };
470 w.write_bits(signed, 1)?;
471 w.write_bits(0, 2)?;
472 Ok(())
473}
474
475fn read_segment_header(r: &mut BitReader) -> Result<(bool, bool, u8, u8, u8, bool), Error> {
476 let start = r.read_bits(1)? == 1;
477 let end = r.read_bits(1)? == 1;
478 let _seg_idx = r.read_bits(8)? as u8;
479 let max_bp = r.read_bits(4)? as u8;
480 let _levels = r.read_bits(3)? as u8;
481 let signed = r.read_bits(1)? == 1;
482 let _reserved = r.read_bits(2)?;
483 Ok((start, end, _seg_idx, max_bp, _levels, signed))
484}
485
486fn encode_segment(
492 w: &mut BitWriter,
493 coeffs: &[i32],
494 width: usize,
495 seg_h: usize,
496 max_bp: u8,
497) -> Result<(), Error> {
498 let n = width * seg_h;
499
500 w.write_bits(n as u64, 20)?;
502
503 for i in 0..n {
505 let v = coeffs[i];
506 let sign = if v < 0 { 1u64 } else { 0 };
507 let mag = v.unsigned_abs();
508 w.write_bits(sign, 1)?;
509 w.write_bits(mag as u64, max_bp as u32)?;
510 }
511 Ok(())
512}
513
514fn decode_segment(r: &mut BitReader, coeffs: &mut [i32], max_bp: u8) -> Result<usize, Error> {
516 let n = r.read_bits(20)? as usize;
517
518 for i in 0..n {
519 let sign = r.read_bits(1)?;
520 let mag = r.read_bits(max_bp as u32)? as i32;
521 coeffs[i] = if sign == 1 { -mag } else { mag };
522 }
523 Ok(n)
524}
525
526pub fn compress(
536 cfg: &Config,
537 image: &[u16],
538 out: &mut [u8],
539 scratch: &mut [i32],
540) -> Result<usize, Error> {
541 cfg.validate()?;
542
543 let w = cfg.width as usize;
544 let h = cfg.height as usize;
545 let seg_h = cfg.seg_height();
546 let needed = scratch_len(w, seg_h);
547
548 if scratch.len() < needed {
549 return Err(Error::ScratchTooSmall);
550 }
551 if image.len() < w * h {
552 return Err(Error::Truncated);
553 }
554
555 for b in out.iter_mut() {
556 *b = 0;
557 }
558
559 let mut bw = BitWriter::new(out);
560
561 bw.write_bits(w as u64, 16)?;
565 bw.write_bits(h as u64, 16)?;
566 bw.write_bits(cfg.bps as u64, 4)?;
567 bw.write_bits(cfg.segment_strips as u64, 16)?;
568
569 let total_strips = cfg.strips();
570 let seg_strips = cfg.segment_strips as usize;
571 let mut seg_idx = 0usize;
572
573 let mut strip = 0usize;
574 while strip < total_strips {
575 let cur_strips = core::cmp::min(seg_strips, total_strips - strip);
576 let cur_h = cur_strips * STRIP_HEIGHT;
577 let y_start = strip * STRIP_HEIGHT;
578
579 let coeffs = &mut scratch[..w * cur_h];
581 for y in 0..cur_h {
582 for x in 0..w {
583 coeffs[y * w + x] = image[(y_start + y) * w + x] as i32;
584 }
585 }
586
587 dwt_forward_3level(coeffs, w, w, cur_h);
589
590 let mut max_abs = 0u32;
592 for i in 0..(w * cur_h) {
593 let a = coeffs[i].unsigned_abs();
594 if a > max_abs {
595 max_abs = a;
596 }
597 }
598 let max_bp = if max_abs == 0 {
599 1
600 } else {
601 32 - max_abs.leading_zeros()
602 } as u8;
603
604 write_segment_header(&mut bw, cfg, seg_idx, max_bp)?;
605 encode_segment(&mut bw, coeffs, w, cur_h, max_bp)?;
606
607 strip += cur_strips;
608 seg_idx += 1;
609 }
610
611 bw.flush()?;
612 Ok(bw.bytes_written())
613}
614
615pub fn decompress(
619 data: &[u8],
620 image: &mut [u16],
621 scratch: &mut [i32],
622) -> Result<(Config, usize), Error> {
623 let mut br = BitReader::new(data);
624
625 let w = br.read_bits(16)? as u16;
626 let h = br.read_bits(16)? as u16;
627 let bps = br.read_bits(4)? as u8;
628 let seg_strips = br.read_bits(16)? as u16;
629
630 let cfg = Config {
631 width: w,
632 height: h,
633 bps,
634 segment_strips: seg_strips,
635 signed_samples: false, };
637 cfg.validate()?;
638
639 let wi = w as usize;
640 let hi = h as usize;
641 let seg_h = cfg.seg_height();
642 let needed = scratch_len(wi, seg_h);
643
644 if scratch.len() < needed {
645 return Err(Error::ScratchTooSmall);
646 }
647 if image.len() < wi * hi {
648 return Err(Error::OutputFull);
649 }
650
651 let total_strips = cfg.strips();
652 let seg_strips_n = seg_strips as usize;
653 let mut strip = 0usize;
654
655 while strip < total_strips {
656 let cur_strips = core::cmp::min(seg_strips_n, total_strips - strip);
657 let cur_h = cur_strips * STRIP_HEIGHT;
658 let y_start = strip * STRIP_HEIGHT;
659
660 let (_, _, _, max_bp, _, _signed) = read_segment_header(&mut br)?;
661
662 let coeffs = &mut scratch[..wi * cur_h];
663 let n = decode_segment(&mut br, coeffs, max_bp)?;
664 let _ = n;
665
666 dwt_inverse_3level(coeffs, wi, wi, cur_h);
668
669 for y in 0..cur_h {
671 for x in 0..wi {
672 let v = coeffs[y * wi + x];
673 image[(y_start + y) * wi + x] = v as u16;
674 }
675 }
676
677 strip += cur_strips;
678 }
679
680 Ok((cfg, wi * hi))
681}
682
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an 8-bit unsigned configuration whose single segment spans the
    /// whole image.
    fn default_config(w: u16, h: u16) -> Config {
        Config {
            width: w,
            height: h,
            bps: 8,
            segment_strips: h / 8,
            signed_samples: false,
        }
    }

    /// Compresses `image`, decompresses the result and asserts that every
    /// sample survives unchanged.
    ///
    /// The scratch slice is cut to exactly `scratch_len(..)` entries so the
    /// tests also check that the advertised scratch size is sufficient.
    fn roundtrip(cfg: &Config, image: &[u16]) {
        let w = cfg.width as usize;
        let slen = scratch_len(w, cfg.seg_height());
        let mut scratch_storage = [0i32; 4096];
        assert!(
            slen <= scratch_storage.len(),
            "test image needs more scratch than the fixture provides"
        );
        let scratch = &mut scratch_storage[..slen];
        let mut compressed = [0u8; 8192];

        let n = compress(cfg, image, &mut compressed, scratch).unwrap();
        assert!(n > 0);

        let n_px = image.len();
        let mut decoded = [0u16; 1024];
        // Clear the scratch so the decoder cannot reuse encoder leftovers.
        scratch.fill(0);
        let (_, count) = decompress(&compressed[..n], &mut decoded[..n_px], scratch).unwrap();
        assert_eq!(count, n_px);
        assert_eq!(&decoded[..n_px], image);
    }

    #[test]
    fn dwt53_roundtrip_1d() {
        let original = [10, 20, 30, 40, 50, 60, 70, 80];
        let mut data = original;
        dwt53_forward_1d(&mut data, 8);
        dwt53_inverse_1d(&mut data, 8);
        assert_eq!(data, original);
    }

    #[test]
    fn dwt53_roundtrip_2d() {
        let mut original = [0i32; 64];
        for (i, slot) in original.iter_mut().enumerate() {
            *slot = (i * 3 + 7) as i32;
        }
        let mut data = original;
        dwt53_forward_2d(&mut data, 8, 8, 8);
        dwt53_inverse_2d(&mut data, 8, 8, 8);
        assert_eq!(data, original);
    }

    #[test]
    fn dwt53_3level_roundtrip() {
        let mut original = [0i32; 64];
        for (i, slot) in original.iter_mut().enumerate() {
            *slot = (i * 5 + 13) as i32;
        }
        let mut data = original;
        dwt_forward_3level(&mut data, 8, 8, 8);
        dwt_inverse_3level(&mut data, 8, 8, 8);
        assert_eq!(data, original);
    }

    #[test]
    fn roundtrip_constant() {
        let cfg = default_config(8, 8);
        roundtrip(&cfg, &[128u16; 64]);
    }

    #[test]
    fn roundtrip_ramp() {
        let cfg = default_config(8, 8);
        let mut image = [0u16; 64];
        for (i, px) in image.iter_mut().enumerate() {
            *px = (i * 3) as u16;
        }
        roundtrip(&cfg, &image);
    }

    #[test]
    fn roundtrip_16x16() {
        let cfg = default_config(16, 16);
        let mut image = [0u16; 256];
        for (i, px) in image.iter_mut().enumerate() {
            *px = (i % 200) as u16;
        }
        roundtrip(&cfg, &image);
    }

    #[test]
    fn roundtrip_multi_segment() {
        let cfg = Config {
            width: 8,
            height: 16,
            bps: 8,
            segment_strips: 1,
            signed_samples: false,
        };
        let mut image = [0u16; 128];
        for (i, px) in image.iter_mut().enumerate() {
            *px = ((i * 7) % 256) as u16;
        }
        roundtrip(&cfg, &image);
    }

    #[test]
    fn deinterleave_roundtrip() {
        let original = [1, 2, 3, 4, 5, 6, 7, 8];
        let mut data = original;
        deinterleave(&mut data, 8);
        assert_eq!(data, [1, 3, 5, 7, 2, 4, 6, 8]);
        interleave(&mut data, 8);
        assert_eq!(data, original);
    }
}