From 00b4e3bd29bdb51205dfbbb59429c168d132c0f6 Mon Sep 17 00:00:00 2001 From: Simon Dugas Date: Tue, 7 Jul 2020 16:37:00 -0400 Subject: [PATCH] stream: Streaming API for decompression Create a new struct `Stream` that uses the `std::io::Write` interface to read chunks of compressed data and write them to an output sink. Add a streaming mode so processing can work with streaming chunks of data. This is required because process() assumed the input reader contained a complete stream. Update flags and try_process_next() were added to handle when the decompressor requests more input bytes than are available. Data is temporarily buffered in the DecoderState if more input bytes are required to make progress. This commit also adds utility functions to the rangecoder for working with streaming data. Adds an allow_incomplete option to disable end of stream checks when calling `finish()` on a stream. This is because some users may want to retrieve partially decompressed data. --- Cargo.toml | 1 + benches/lzma.rs | 26 ++ src/decode/lzbuffer.rs | 18 +- src/decode/lzma.rs | 292 +++++++++++++++++----- src/decode/mod.rs | 3 + src/decode/options.rs | 26 +- src/decode/rangecoder.rs | 88 +++++-- src/decode/stream.rs | 510 +++++++++++++++++++++++++++++++++++++++ src/encode/rangecoder.rs | 8 +- src/error.rs | 5 +- src/lib.rs | 2 + tests/files/small.txt | 1 + tests/lzma.rs | 218 ++++++++++++----- 13 files changed, 1062 insertions(+), 136 deletions(-) create mode 100644 src/decode/stream.rs create mode 100644 tests/files/small.txt diff --git a/Cargo.toml b/Cargo.toml index 77b529fe..c482136f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ env_logger = { version = "^0.7.1", optional = true } [features] enable_logging = ["env_logger", "log"] +stream = [] [badges] travis-ci = { repository = "gendx/lzma-rs", branch = "master" } diff --git a/benches/lzma.rs b/benches/lzma.rs index dc567ad0..57b9f5c6 100644 --- a/benches/lzma.rs +++ b/benches/lzma.rs @@ -34,6 +34,16 @@ fn decompress_bench(compressed: &[u8], b: &mut Bencher) { }); } +#[cfg(feature = "stream")] +fn decompress_stream_bench(compressed: &[u8], b: &mut Bencher) { + use std::io::Write; + b.iter(|| { + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + stream.write_all(compressed).unwrap(); + stream.finish().unwrap() + }); +} + fn decompress_bench_file(compfile: &str, b: &mut Bencher) { let mut f = std::fs::File::open(compfile).unwrap(); let mut compressed = Vec::new(); @@ -41,6 +51,14 @@ fn decompress_bench_file(compfile: &str, b: &mut Bencher) { decompress_bench(&compressed, b); } +#[cfg(feature = "stream")] +fn decompress_stream_bench_file(compfile: &str, b: &mut Bencher) { + let mut f = std::fs::File::open(compfile).unwrap(); + let mut compressed = Vec::new(); + f.read_to_end(&mut compressed).unwrap(); + decompress_stream_bench(&compressed, b); +} + #[bench] fn compress_empty(b: &mut Bencher) { #[cfg(feature = "enable_logging")] @@ -90,6 +108,14 @@ fn decompress_big_file(b: &mut Bencher) { decompress_bench_file("tests/files/foo.txt.lzma", b); } +#[cfg(feature = "stream")] +#[bench] +fn decompress_stream_big_file(b: &mut Bencher) { + #[cfg(feature = "enable_logging")] + let _ = env_logger::try_init(); + decompress_stream_bench_file("tests/files/foo.txt.lzma", b); +} + #[bench] fn decompress_huge_dict(b: &mut Bencher) { #[cfg(feature = "enable_logging")] diff --git a/src/decode/lzbuffer.rs b/src/decode/lzbuffer.rs index fecd5930..9d77331a 100644 --- a/src/decode/lzbuffer.rs +++ b/src/decode/lzbuffer.rs @@ -18,8 +18,10 @@ where fn get_output(&self) -> &W; // Get a mutable reference to the output sink fn get_output_mut(&mut self) -> &mut W; - // Flush the buffer to the output + // Consumes this buffer and flushes any data fn finish(self) -> io::Result; + // Consumes this buffer without flushing any data + fn into_output(self) -> W; } // An accumulating buffer for LZ sequences @@ -143,12 +145,17 @@ where &mut self.stream } - // Flush the buffer to the output + // Consumes this buffer and flushes any data fn finish(mut self) -> io::Result { self.stream.write_all(self.buf.as_slice())?; self.stream.flush()?; Ok(self.stream) } + + // Consumes this buffer without flushing any data + fn into_output(self) -> W { + self.stream + } } // A circular buffer for LZ sequences @@ -291,7 +298,7 @@ where &mut self.stream } - // Flush the buffer to the output + // Consumes this buffer and flushes any data fn finish(mut self) -> io::Result { if self.cursor > 0 { self.stream.write_all(&self.buf[0..self.cursor])?; @@ -299,4 +306,9 @@ where } Ok(self.stream) } + + // Consumes this buffer without flushing any data + fn into_output(self) -> W { + self.stream + } } diff --git a/src/decode/lzma.rs b/src/decode/lzma.rs index bc3c2c00..05dedbf7 100644 --- a/src/decode/lzma.rs +++ b/src/decode/lzma.rs @@ -8,6 +8,24 @@ use std::marker::PhantomData; use crate::decompress::Options; use crate::decompress::UnpackedSize; +/// Maximum input data that can be processed in one iteration +const MAX_REQUIRED_INPUT: usize = 20; + +/// Processing mode for decompression. +/// +/// Tells the decompressor if we should expect more data after parsing the +/// current input. +#[derive(Debug, PartialEq)] +pub enum Mode { + /// Streaming mode. Process the input bytes but assume there will be more + /// chunks of input data to receive in future calls to `process_mode()`. + Run, + /// Sync mode. Process the input bytes and confirm end of stream has been reached. + /// Use this mode if you are processing a fixed buffer of compressed data, or after + /// using `Mode::Run` to check for the end of stream. + Finish, +} + pub struct LZMAParams { // most lc significant bits of previous byte are part of the literal context lc: u32, // 0..8 @@ -24,9 +42,7 @@ impl LZMAParams { R: io::BufRead, { // Properties - let props = input - .read_u8() - .map_err(|e| error::Error::LZMAError(format!("LZMA header too short: {}", e)))?; + let props = input.read_u8().map_err(error::Error::HeaderTooShort)?; let mut pb = props as u32; if pb >= 225 { @@ -46,7 +62,7 @@ impl LZMAParams { // Dictionary let dict_size_provided = input .read_u32::() - .map_err(|e| error::Error::LZMAError(format!("LZMA header too short: {}", e)))?; + .map_err(error::Error::HeaderTooShort)?; let dict_size = if dict_size_provided < 0x1000 { 0x1000 } else { @@ -58,9 +74,9 @@ impl LZMAParams { // Unpacked size let unpacked_size: Option = match options.unpacked_size { UnpackedSize::ReadFromHeader => { - let unpacked_size_provided = input.read_u64::().map_err(|e| { - error::Error::LZMAError(format!("LZMA header too short: {}", e)) - })?; + let unpacked_size_provided = input + .read_u64::() + .map_err(error::Error::HeaderTooShort)?; let marker_mandatory: bool = unpacked_size_provided == 0xFFFF_FFFF_FFFF_FFFF; if marker_mandatory { None @@ -69,7 +85,9 @@ impl LZMAParams { } } UnpackedSize::ReadHeaderButUseProvided(x) => { - input.read_u64::()?; + input + .read_u64::() + .map_err(error::Error::HeaderTooShort)?; x } UnpackedSize::UseProvided(x) => x, @@ -95,6 +113,9 @@ where LZB: lzbuffer::LZBuffer, { _phantom: PhantomData, + // buffer input data here if we need more for decompression, 20 is the max + // number of bytes that can be consumed during one iteration + tmp: std::io::Cursor<[u8; MAX_REQUIRED_INPUT]>, pub output: LZB, // most lc significant bits of previous byte are part of the literal context pub lc: u32, // 0..8 @@ -131,6 +152,7 @@ where { DecoderState { _phantom: PhantomData, + tmp: std::io::Cursor::new([0; MAX_REQUIRED_INPUT]), output, lc, lp, @@ -181,6 +203,7 @@ where params.dict_size as usize, memlimit, ), + tmp: std::io::Cursor::new([0; MAX_REQUIRED_INPUT]), lc: params.lc, lp: params.lp, pb: params.pb, @@ -237,23 +260,38 @@ where &mut self, rangecoder: &mut rangecoder::RangeDecoder<'a, R>, ) -> error::Result<()> { - loop { - if let Some(unpacked_size) = self.unpacked_size { - if self.output.len() as u64 >= unpacked_size { - break; - } - } else if rangecoder.is_finished_ok()? { - break; - } + self.process_mode(rangecoder, Mode::Finish) + } - let pos_state = self.output.len() & ((1 << self.pb) - 1); + #[cfg(feature = "stream")] + pub fn process_stream<'a, R: io::BufRead>( + &mut self, + rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + ) -> error::Result<()> { + self.process_mode(rangecoder, Mode::Run) + } - // Literal - if !rangecoder.decode_bit( - // TODO: assumes pb = 2 ?? - &mut self.is_match[(self.state << 4) + pos_state], - )? { - let byte: u8 = self.decode_literal(rangecoder)?; + /// Process the next iteration of the loop. + /// + /// If the update flag is true, the decoder's state will be updated. + /// + /// Returns true if we should continue processing the loop, false otherwise. + fn process_next_inner<'a, R: io::BufRead>( + &mut self, + rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + update: bool, + ) -> error::Result { + let pos_state = self.output.len() & ((1 << self.pb) - 1); + + // Literal + if !rangecoder.decode_bit( + // TODO: assumes pb = 2 ?? + &mut self.is_match[(self.state << 4) + pos_state], + update, + )? { + let byte: u8 = self.decode_literal(rangecoder, update)?; + + if update { lzma_debug!("Literal: {}", byte); self.output.append_literal(byte)?; @@ -264,35 +302,40 @@ where } else { self.state - 6 }; - continue; } + return Ok(true); + } - // LZ - let mut len: usize; - // Distance is repeated from LRU - if rangecoder.decode_bit(&mut self.is_rep[self.state])? { - // dist = rep[0] - if !rangecoder.decode_bit(&mut self.is_rep_g0[self.state])? { - // len = 1 - if !rangecoder - .decode_bit(&mut self.is_rep_0long[(self.state << 4) + pos_state])? - { - // update state (short rep) + // LZ + let mut len: usize; + // Distance is repeated from LRU + if rangecoder.decode_bit(&mut self.is_rep[self.state], update)? { + // dist = rep[0] + if !rangecoder.decode_bit(&mut self.is_rep_g0[self.state], update)? { + // len = 1 + if !rangecoder.decode_bit( + &mut self.is_rep_0long[(self.state << 4) + pos_state], + update, + )? { + // update state (short rep) + if update { self.state = if self.state < 7 { 9 } else { 11 }; let dist = self.rep[0] + 1; self.output.append_lz(1, dist)?; - continue; } - // dist = rep[i] + return Ok(true); + } + // dist = rep[i] + } else { + let idx: usize; + if !rangecoder.decode_bit(&mut self.is_rep_g1[self.state], update)? { + idx = 1; + } else if !rangecoder.decode_bit(&mut self.is_rep_g2[self.state], update)? { + idx = 2; } else { - let idx: usize; - if !rangecoder.decode_bit(&mut self.is_rep_g1[self.state])? { - idx = 1; - } else if !rangecoder.decode_bit(&mut self.is_rep_g2[self.state])? { - idx = 2; - } else { - idx = 3; - } + idx = 3; + } + if update { // Update LRU let dist = self.rep[idx]; for i in (0..idx).rev() { @@ -300,40 +343,171 @@ where } self.rep[0] = dist } + } + + len = self.rep_len_decoder.decode(rangecoder, pos_state, update)?; - len = self.rep_len_decoder.decode(rangecoder, pos_state)?; + if update { // update state (rep) self.state = if self.state < 7 { 8 } else { 11 }; - // New distance - } else { + } + // New distance + } else { + if update { // Update LRU self.rep[3] = self.rep[2]; self.rep[2] = self.rep[1]; self.rep[1] = self.rep[0]; - len = self.len_decoder.decode(rangecoder, pos_state)?; + } + len = self.len_decoder.decode(rangecoder, pos_state, update)?; + + if update { // update state (match) self.state = if self.state < 7 { 7 } else { 10 }; - self.rep[0] = self.decode_distance(rangecoder, len)?; + } + let rep_0 = self.decode_distance(rangecoder, len, update)?; + + if update { + self.rep[0] = rep_0; if self.rep[0] == 0xFFFF_FFFF { if rangecoder.is_finished_ok()? { - break; + return Ok(false); } return Err(error::Error::LZMAError(String::from( "Found end-of-stream marker but more bytes are available", ))); } } + } + if update { len += 2; let dist = self.rep[0] + 1; self.output.append_lz(len, dist)?; } + Ok(true) + } + + fn process_next<'a, R: io::BufRead>( + &mut self, + rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + ) -> error::Result { + self.process_next_inner(rangecoder, true) + } + + /// Try to process the next iteration of the loop. + /// + /// This will check to see if there is enough data to consume and advance the + /// decompressor. Needed in streaming mode to avoid corrupting the state while + /// processing incomplete chunks of data. + #[allow(clippy::if_same_then_else)] + fn try_process_next<'a>(&mut self, buf: &'a [u8], range: u32, code: u32) -> error::Result<()> { + let mut temp = std::io::Cursor::new(buf); + let mut rangecoder = rangecoder::RangeDecoder::from_parts(&mut temp, range, code); + let _ = self.process_next_inner(&mut rangecoder, false)?; + Ok(()) + } + + pub fn process_mode<'a, R: io::BufRead>( + &mut self, + mut rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + mode: Mode, + ) -> error::Result<()> { + loop { + if let Some(unpacked_size) = self.unpacked_size { + if self.output.len() as u64 >= unpacked_size { + break; + } + } else if match mode { + Mode::Run => rangecoder.is_eof()?, + Mode::Finish => rangecoder.is_finished_ok()?, + } { + break; + } + + if self.tmp.position() as usize > 0 { + // Fill as much of the tmp buffer as possible + let start = self.tmp.position() as usize; + let bytes_read = rangecoder.read_into(&mut self.tmp.get_mut()[start..])?; + let bytes_read = if bytes_read < std::u64::MAX as usize { + bytes_read as u64 + } else { + return Err(error::Error::LZMAError( + "Failed to convert integer to u64.".to_string(), + )); + }; + self.tmp.set_position(self.tmp.position() + bytes_read); + let tmp = *self.tmp.get_ref(); + + // Check if we need more data to advance the decompressor + if Mode::Run == mode + && (self.tmp.position() as usize) < MAX_REQUIRED_INPUT + && self + .try_process_next( + &tmp[0..self.tmp.position() as usize], + rangecoder.range(), + rangecoder.code(), + ) + .is_err() + { + return Ok(()); + } + + // Run the decompressor on the tmp buffer + let mut tmp_reader = io::Cursor::new(&tmp[0..self.tmp.position() as usize]); + let mut tmp_rangecoder = rangecoder::RangeDecoder::from_parts( + &mut tmp_reader, + rangecoder.range(), + rangecoder.code(), + ); + let res = self.process_next(&mut tmp_rangecoder)?; + + // Update the actual rangecoder + let (range, code) = tmp_rangecoder.into_parts(); + rangecoder.set(range, code); + + // Update tmp buffer + let end = self.tmp.position(); + let new_len = end - tmp_reader.position(); + self.tmp.get_mut()[0..new_len as usize] + .copy_from_slice(&tmp[tmp_reader.position() as usize..end as usize]); + self.tmp.set_position(new_len); + + if !res { + break; + }; + } else { + if (Mode::Run == mode) && (rangecoder.remaining()? < MAX_REQUIRED_INPUT) { + let range = rangecoder.range(); + let code = rangecoder.code(); + let buf = rangecoder.buf()?; + + if self.try_process_next(buf, range, code).is_err() { + let bytes_read = rangecoder.read_into(&mut self.tmp.get_mut()[..])?; + let bytes_read = if bytes_read < std::u64::MAX as usize { + bytes_read as u64 + } else { + return Err(error::Error::LZMAError( + "Failed to convert integer to u64.".to_string(), + )); + }; + self.tmp.set_position(bytes_read); + return Ok(()); + } + } + + if !self.process_next(&mut rangecoder)? { + break; + }; + } + } + if let Some(len) = self.unpacked_size { - if self.output.len() as u64 != len { + if Mode::Finish == mode && self.output.len() as u64 != len { return Err(error::Error::LZMAError(format!( "Expected unpacked size of {} but decompressed to {}", len, @@ -348,6 +522,7 @@ where fn decode_literal<'a, R: io::BufRead>( &mut self, rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + update: bool, ) -> error::Result { let def_prev_byte = 0u8; let prev_byte = self.output.last_or(def_prev_byte) as usize; @@ -363,8 +538,9 @@ where while result < 0x100 { let match_bit = (match_byte >> 7) & 1; match_byte <<= 1; - let bit = - rangecoder.decode_bit(&mut probs[((1 + match_bit) << 8) + result])? as usize; + let bit = rangecoder + .decode_bit(&mut probs[((1 + match_bit) << 8) + result], update)? + as usize; result = (result << 1) ^ bit; if match_bit != bit { break; @@ -373,7 +549,7 @@ where } while result < 0x100 { - result = (result << 1) ^ (rangecoder.decode_bit(&mut probs[result])? as usize); + result = (result << 1) ^ (rangecoder.decode_bit(&mut probs[result], update)? as usize); } Ok((result - 0x100) as u8) @@ -383,10 +559,11 @@ where &mut self, rangecoder: &mut rangecoder::RangeDecoder<'a, R>, length: usize, + update: bool, ) -> error::Result { let len_state = if length > 3 { 3 } else { length }; - let pos_slot = self.pos_slot_decoder[len_state].parse(rangecoder)? as usize; + let pos_slot = self.pos_slot_decoder[len_state].parse(rangecoder, update)? as usize; if pos_slot < 4 { return Ok(pos_slot); } @@ -399,10 +576,11 @@ where num_direct_bits, &mut self.pos_decoders, result - pos_slot, + update, )? as usize; } else { result += (rangecoder.get(num_direct_bits - 4)? as usize) << 4; - result += self.align_decoder.parse_reverse(rangecoder)? as usize; + result += self.align_decoder.parse_reverse(rangecoder, update)? as usize; } Ok(result) diff --git a/src/decode/mod.rs b/src/decode/mod.rs index 4a81d621..2a7b0b8a 100644 --- a/src/decode/mod.rs +++ b/src/decode/mod.rs @@ -7,3 +7,6 @@ pub mod options; pub mod rangecoder; pub mod util; pub mod xz; + +#[cfg(feature = "stream")] +pub mod stream; diff --git a/src/decode/options.rs b/src/decode/options.rs index 6d84362d..cea2b588 100644 --- a/src/decode/options.rs +++ b/src/decode/options.rs @@ -1,5 +1,5 @@ /// Options to tweak decompression behavior. -#[derive(Clone, Copy, Debug, Default)] +#[derive(Clone, Copy, Debug, PartialEq, Default)] pub struct Options { /// Defines whether the unpacked size should be read from the header or provided. /// @@ -10,10 +10,16 @@ pub struct Options { /// /// The default is unlimited. pub memlimit: Option, + /// Determines whether to bypass end of stream validation. + /// + /// This option only applies to the [`Stream`](struct.Stream.html) API. + /// + /// The default is false (always do completion check). + pub allow_incomplete: bool, } /// Alternatives for defining the unpacked size of the decoded data. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum UnpackedSize { /// Assume that the 8 bytes used to specify the unpacked size are present in the header. /// If the bytes are `0xFFFF_FFFF_FFFF_FFFF`, assume that there is an end-of-payload marker in @@ -38,3 +44,19 @@ impl Default for UnpackedSize { UnpackedSize::ReadFromHeader } } + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_options() { + assert_eq!( + Options { + unpacked_size: UnpackedSize::ReadFromHeader, + memlimit: None, + allow_incomplete: false, + }, + Options::default() + ); + } +} diff --git a/src/decode/rangecoder.rs b/src/decode/rangecoder.rs index 7643ea10..70cf7735 100644 --- a/src/decode/rangecoder.rs +++ b/src/decode/rangecoder.rs @@ -28,9 +28,51 @@ where Ok(dec) } + pub fn from_parts(stream: &'a mut R, range: u32, code: u32) -> Self { + Self { + stream, + range, + code, + } + } + + pub fn into_parts(self) -> (u32, u32) { + (self.range, self.code) + } + + pub fn set(&mut self, range: u32, code: u32) { + self.range = range; + self.code = code; + } + + pub fn range(&self) -> u32 { + self.range + } + + pub fn code(&self) -> u32 { + self.code + } + + pub fn buf(&mut self) -> io::Result<&[u8]> { + self.stream.fill_buf() + } + + pub fn remaining(&mut self) -> io::Result { + Ok(self.buf()?.len()) + } + + pub fn read_into(&mut self, dst: &mut [u8]) -> io::Result { + self.stream.read(dst) + } + #[inline] pub fn is_finished_ok(&mut self) -> io::Result { - Ok(self.code == 0 && util::is_eof(self.stream)?) + Ok(self.code == 0 && self.is_eof()?) + } + + #[inline] + pub fn is_eof(&mut self) -> io::Result { + util::is_eof(self.stream) } #[inline] @@ -67,7 +109,7 @@ where } #[inline] - pub fn decode_bit(&mut self, prob: &mut u16) -> io::Result { + pub fn decode_bit(&mut self, prob: &mut u16, update: bool) -> io::Result { let bound: u32 = (self.range >> 11) * (*prob as u32); lzma_trace!( @@ -77,13 +119,17 @@ where (self.code > bound) as u8 ); if self.code < bound { - *prob += (0x800_u16 - *prob) >> 5; + if update { + *prob += (0x800_u16 - *prob) >> 5; + } self.range = bound; self.normalize()?; Ok(false) } else { - *prob -= *prob >> 5; + if update { + *prob -= *prob >> 5; + } self.code -= bound; self.range -= bound; @@ -92,10 +138,15 @@ where } } - fn parse_bit_tree(&mut self, num_bits: usize, probs: &mut [u16]) -> io::Result { + fn parse_bit_tree( + &mut self, + num_bits: usize, + probs: &mut [u16], + update: bool, + ) -> io::Result { let mut tmp: u32 = 1; for _ in 0..num_bits { - let bit = self.decode_bit(&mut probs[tmp as usize])?; + let bit = self.decode_bit(&mut probs[tmp as usize], update)?; tmp = (tmp << 1) ^ (bit as u32); } Ok(tmp - (1 << num_bits)) @@ -106,11 +157,12 @@ where num_bits: usize, probs: &mut [u16], offset: usize, + update: bool, ) -> io::Result { let mut result = 0u32; let mut tmp: usize = 1; for i in 0..num_bits { - let bit = self.decode_bit(&mut probs[offset + tmp])?; + let bit = self.decode_bit(&mut probs[offset + tmp], update)?; tmp = (tmp << 1) ^ (bit as usize); result ^= (bit as u32) << i; } @@ -133,15 +185,20 @@ impl BitTree { } } - pub fn parse(&mut self, rangecoder: &mut RangeDecoder) -> io::Result { - rangecoder.parse_bit_tree(self.num_bits, self.probs.as_mut_slice()) + pub fn parse( + &mut self, + rangecoder: &mut RangeDecoder, + update: bool, + ) -> io::Result { + rangecoder.parse_bit_tree(self.num_bits, self.probs.as_mut_slice(), update) } pub fn parse_reverse( &mut self, rangecoder: &mut RangeDecoder, + update: bool, ) -> io::Result { - rangecoder.parse_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0) + rangecoder.parse_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0, update) } } @@ -168,13 +225,14 @@ impl LenDecoder { &mut self, rangecoder: &mut RangeDecoder, pos_state: usize, + update: bool, ) -> io::Result { - if !rangecoder.decode_bit(&mut self.choice)? { - Ok(self.low_coder[pos_state].parse(rangecoder)? as usize) - } else if !rangecoder.decode_bit(&mut self.choice2)? { - Ok(self.mid_coder[pos_state].parse(rangecoder)? as usize + 8) + if !rangecoder.decode_bit(&mut self.choice, update)? { + Ok(self.low_coder[pos_state].parse(rangecoder, update)? as usize) + } else if !rangecoder.decode_bit(&mut self.choice2, update)? { + Ok(self.mid_coder[pos_state].parse(rangecoder, update)? as usize + 8) } else { - Ok(self.high_coder.parse(rangecoder)? as usize + 16) + Ok(self.high_coder.parse(rangecoder, update)? as usize + 16) } } } diff --git a/src/decode/stream.rs b/src/decode/stream.rs new file mode 100644 index 00000000..e4312b6e --- /dev/null +++ b/src/decode/stream.rs @@ -0,0 +1,510 @@ +use crate::decode::lzbuffer::{LZBuffer, LZCircularBuffer}; +use crate::decode::lzma::{new_circular, new_circular_with_memlimit, DecoderState, LZMAParams}; +use crate::decode::rangecoder::RangeDecoder; +use crate::decompress::Options; +use crate::error::Error; +use std::fmt::Debug; +use std::io::{BufRead, Cursor, Read, Write}; + +/// Minimum header length to be read. +/// - props: u8 (1 byte) +/// - dict_size: u32 (4 bytes) +const MIN_HEADER_LEN: usize = 5; + +/// Max header length to be read. +/// - unpacked_size: u64 (8 bytes) +const MAX_HEADER_LEN: usize = MIN_HEADER_LEN + 8; + +/// Required bytes after the header. +/// - ignore: u8 (1 byte) +/// - code: u32 (4 bytes) +const START_BYTES: usize = 5; + +/// Maximum number of bytes to buffer while reading the header. +const MAX_TMP_LEN: usize = MAX_HEADER_LEN + START_BYTES; + +/// Internal state of this streaming decoder. This is needed because we have to +/// initialize the stream before processing any data. +#[derive(Debug)] +enum State +where + W: Write, +{ + /// Stream is initialized but header values have not yet been read. + Header(W), + /// Header values have been read and the stream is ready to process more data. + Data(RunState), +} + +/// Structures needed while decoding data. +struct RunState +where + W: Write, +{ + decoder: DecoderState>, + range: u32, + code: u32, +} + +impl Debug for RunState +where + W: Write, +{ + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_struct("RunState") + .field("range", &self.range) + .field("code", &self.code) + .finish() + } +} + +/// Lzma decompressor that can process multiple chunks of data using the +/// `std::io::Write` interface. +pub struct Stream +where + W: Write, +{ + /// Temporary buffer to hold data while the header is being read. + tmp: Cursor<[u8; MAX_TMP_LEN]>, + /// Whether the stream is initialized and ready to process data. + /// An `Option` is used to avoid interior mutability when updating the state. + state: Option>, + /// Options given when a stream is created. + options: Options, +} + +impl Stream +where + W: Write, +{ + /// Initialize the stream. This will consume the `output` which is the sink + /// implementing `std::io::Write` that will receive decompressed bytes. + pub fn new(output: W) -> Self { + Self::new_with_options(&Options::default(), output) + } + + /// Initialize the stream with the given `options`. This will consume the + /// `output` which is the sink implementing `std::io::Write` that will + /// receive decompressed bytes. + pub fn new_with_options(options: &Options, output: W) -> Self { + Self { + tmp: Cursor::new([0; MAX_TMP_LEN]), + state: Some(State::Header(output)), + options: options.clone(), + } + } + + /// Get a reference to the output sink + pub fn get_output(&self) -> Option<&W> { + self.state.as_ref().map(|state| match state { + State::Header(output) => &output, + State::Data(state) => state.decoder.output.get_output(), + }) + } + + /// Get a mutable reference to the output sink + pub fn get_output_mut(&mut self) -> Option<&mut W> { + self.state.as_mut().map(|state| match state { + State::Header(output) => output, + State::Data(state) => state.decoder.output.get_output_mut(), + }) + } + + /// Consumes the stream and returns the output sink. This also makes sure + /// we have properly reached the end of the stream. + pub fn finish(mut self) -> crate::error::Result { + if let Some(state) = self.state.take() { + match state { + State::Header(output) => { + if self.tmp.position() > 0 { + Err(Error::LZMAError("failed to read header".to_string())) + } else { + Ok(output) + } + } + State::Data(mut state) => { + if !self.options.allow_incomplete { + // Process one last time with empty input to force end of + // stream checks + let mut stream = + Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]); + let mut range_decoder = + RangeDecoder::from_parts(&mut stream, state.range, state.code); + state.decoder.process(&mut range_decoder)?; + } + let output = state.decoder.output.finish()?; + Ok(output) + } + } + } else { + // this will occur if a call to `write()` fails + Err(Error::LZMAError( + "can't finish stream because of previous write error".to_string(), + )) + } + } + + /// Attempts to read the header and transition into a running state. + /// + /// This function will consume the state, returning the next state on both + /// error and success. + fn read_header( + output: W, + mut input: &mut R, + options: &Options, + ) -> crate::error::Result> { + match LZMAParams::read_header(&mut input, options) { + Ok(params) => { + let decoder = if let Some(memlimit) = options.memlimit { + new_circular_with_memlimit(output, params, memlimit) + } else { + new_circular(output, params) + }?; + + // The RangeDecoder is only kept temporarily as we are processing + // chunks of data. + if let Ok(range_decoder) = RangeDecoder::new(&mut input) { + let (range, code) = range_decoder.into_parts(); + + Ok(State::Data(RunState { + decoder, + range, + code, + })) + } else { + // Failed to create a RangeDecoder because we need more data, + // try again later. + Ok(State::Header(decoder.output.into_output())) + } + } + // Failed to read_header() because we need more data, try again later. + Err(Error::HeaderTooShort(_)) => Ok(State::Header(output)), + // Fatal error. Don't retry. + Err(e) => Err(e), + } + } + + /// Process compressed data + fn read_data( + mut state: RunState, + mut input: &mut R, + ) -> std::io::Result> { + // Construct our RangeDecoder from the previous range and code + // values. + let mut range_decoder = RangeDecoder::from_parts(&mut input, state.range, state.code); + + // Try to process all bytes of data. + state + .decoder + .process_stream(&mut range_decoder) + .map_err(|e| -> std::io::Error { e.into() })?; + + // Save the range and code for the next chunk of data. + let (range, code) = range_decoder.into_parts(); + state.range = range; + state.code = code; + + Ok(RunState { + decoder: state.decoder, + range, + code, + }) + } +} + +impl Debug for Stream +where + W: Write + Debug, +{ + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_struct("Stream") + .field("tmp", &self.tmp.position()) + .field("state", &self.state) + .field("options", &self.options) + .finish() + } +} + +impl Write for Stream +where + W: Write, +{ + fn write(&mut self, data: &[u8]) -> std::io::Result { + let mut input = Cursor::new(data); + + if let Some(state) = self.state.take() { + let state = match state { + // Read the header values and transition into a running state. + State::Header(state) => { + let res = if self.tmp.position() > 0 { + // attempt to fill the tmp buffer + let position = self.tmp.position(); + let bytes_read = + input.read(&mut self.tmp.get_mut()[position as usize..])?; + let bytes_read = if bytes_read < std::u64::MAX as usize { + bytes_read as u64 + } else { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to convert integer to u64.", + )); + }; + self.tmp.set_position(position + bytes_read); + + // attempt to read the header from our tmp buffer + let (position, res) = { + let mut tmp_input = + Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]); + let res = Stream::read_header(state, &mut tmp_input, &self.options); + (tmp_input.position(), res) + }; + + // discard all bytes up to position if reading the header + // was successful + match &res { + Ok(State::Data(_)) => { + let tmp = self.tmp.get_ref().clone(); + let end = self.tmp.position(); + let new_len = end - position; + (&mut self.tmp.get_mut()[0..new_len as usize]) + .copy_from_slice(&tmp[position as usize..end as usize]); + self.tmp.set_position(new_len); + } + _ => {} + } + res + } else { + Stream::read_header(state, &mut input, &self.options) + }; + + match res { + // occurs when not enough input bytes were provided to + // read the entire header + Ok(State::Header(val)) => { + if self.tmp.position() == 0 { + // reset the cursor because we may have partial reads + input.set_position(0); + let bytes_read = input.read(&mut self.tmp.get_mut()[..])?; + let bytes_read = if bytes_read < std::u64::MAX as usize { + bytes_read as u64 + } else { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to convert integer to u64.", + )); + }; + self.tmp.set_position(bytes_read); + } + State::Header(val) + } + + // occurs when the header was successfully read and we + // move on to the next state + Ok(State::Data(val)) => State::Data(val), + + // occurs when the output was consumed due to a + // non-recoverable error + Err(e) => { + return Err(match e { + Error::IOError(e) | Error::HeaderTooShort(e) => e, + Error::LZMAError(e) | Error::XZError(e) => { + std::io::Error::new(std::io::ErrorKind::Other, e) + } + }); + } + } + } + + // Process another chunk of data. + State::Data(state) => { + let state = if self.tmp.position() > 0 { + let mut tmp_input = + Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]); + let res = Stream::read_data(state, &mut tmp_input)?; + self.tmp.set_position(0); + res + } else { + state + }; + State::Data(Stream::read_data(state, &mut input)?) + } + }; + self.state.replace(state); + } + Ok(input.position() as usize) + } + + /// Flushes the output sink. The internal buffer isn't flushed to avoid + /// corrupting the internal state. Instead, call `finish()` to finalize the + /// stream and flush all remaining internal data. + fn flush(&mut self) -> std::io::Result<()> { + if let Some(ref mut state) = self.state { + match state { + State::Header(_) => Ok(()), + State::Data(state) => state.decoder.output.get_output_mut().flush(), + } + } else { + Ok(()) + } + } +} + +impl std::convert::Into for Error { + fn into(self) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::Other, format!("{:?}", self)) + } +} + +#[cfg(test)] +mod test { + use super::*; + + /// Test an empty stream + #[test] + fn test_stream_noop() { + let stream = Stream::new(Vec::new()); + assert!(stream.get_output().unwrap().is_empty()); + + let output = stream.finish().unwrap(); + assert!(output.is_empty()); + } + + /// Test writing an empty slice + #[test] + fn test_stream_zero() { + let mut stream = Stream::new(Vec::new()); + + stream.write_all(&[]).unwrap(); + stream.write_all(&[]).unwrap(); + + let output = stream.finish().unwrap(); + + assert!(output.is_empty()); + } + + /// Test a bad header value + #[test] + #[should_panic(expected = "LZMA header invalid properties: 255 must be < 225")] + fn test_bad_header() { + let input = [255u8; 32]; + + let mut stream = Stream::new(Vec::new()); + + stream.write_all(&input[..]).unwrap(); + + let output = stream.finish().unwrap(); + + assert!(output.is_empty()); + } + + /// Test processing only partial data + #[test] + fn test_stream_incomplete() { + let input = b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\ + \xfb\xff\xff\xc0\x00\x00\x00"; + // Process until this index is reached. + let mut end = 1u64; + + // Test when we fail to provide the minimum number of bytes required to + // read the header. Header size is 13 bytes but we also read the first 5 + // bytes of data. + while end < (MAX_HEADER_LEN + START_BYTES) as u64 { + let mut stream = Stream::new(Vec::new()); + stream.write_all(&input[..end as usize]).unwrap(); + assert_eq!(stream.tmp.position(), end); + + let err = stream.finish().unwrap_err(); + assert!( + err.to_string().contains("failed to read header"), + "error was: {}", + err + ); + + end += 1; + } + + // Test when we fail to provide enough bytes to terminate the stream. A + // properly terminated stream will have a code value of 0. + while end < input.len() as u64 { + let mut stream = Stream::new(Vec::new()); + stream.write_all(&input[..end as usize]).unwrap(); + + // Header bytes will be buffered until there are enough to read + if end < (MAX_HEADER_LEN + START_BYTES) as u64 { + assert_eq!(stream.tmp.position(), end); + } + + let err = stream.finish().unwrap_err(); + assert!(err.to_string().contains("failed to fill whole buffer")); + + end += 1; + } + } + + /// Test processing all chunk sizes + #[test] + fn test_stream_chunked() { + let small_input = include_bytes!("../../tests/files/small.txt"); + + let mut reader = std::io::Cursor::new(&small_input[..]); + let mut small_input_compressed = Vec::new(); + crate::lzma_compress(&mut reader, &mut small_input_compressed).unwrap(); + + let input : Vec<(&[u8], &[u8])> = vec![ + (b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\xfb\xff\xff\xc0\x00\x00\x00", b""), + (&small_input_compressed[..], small_input)]; + for (input, expected) in input { + for chunk in 1..input.len() { + let mut consumed = 0; + let mut stream = Stream::new(Vec::new()); + while consumed < input.len() { + let end = std::cmp::min(consumed + chunk, input.len()); + stream.write_all(&input[consumed..end]).unwrap(); + consumed = end; + } + let output = stream.finish().unwrap(); + assert_eq!(expected, &output[..]); + } + } + } + + #[test] + fn test_stream_corrupted() { + let mut stream = Stream::new(Vec::new()); + let err = stream + .write_all(b"corrupted bytes here corrupted bytes here") + .unwrap_err(); + assert!(err.to_string().contains("beyond output size")); + let err = stream.finish().unwrap_err(); + assert!(err + .to_string() + .contains("can\'t finish stream because of previous write error")); + } + + #[test] + fn test_allow_incomplete() { + let input = include_bytes!("../../tests/files/small.txt"); + + let mut reader = std::io::Cursor::new(&input[..]); + let mut compressed = Vec::new(); + crate::lzma_compress(&mut reader, &mut compressed).unwrap(); + let compressed = &compressed[..compressed.len() / 2]; + + // Should fail to finish() without the allow_incomplete option. + let mut stream = Stream::new(Vec::new()); + stream.write_all(&compressed[..]).unwrap(); + stream.finish().unwrap_err(); + + // Should succeed with the allow_incomplete option. + let mut stream = Stream::new_with_options( + &Options { + allow_incomplete: true, + ..Default::default() + }, + Vec::new(), + ); + stream.write_all(&compressed[..]).unwrap(); + let output = stream.finish().unwrap(); + assert_eq!(output, &input[..26]); + } +} diff --git a/src/encode/rangecoder.rs b/src/encode/rangecoder.rs index 5ab8582d..da5385d2 100644 --- a/src/encode/rangecoder.rs +++ b/src/encode/rangecoder.rs @@ -238,7 +238,7 @@ mod test { let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); let mut prob = prob_init; for &b in bits { - assert_eq!(decoder.decode_bit(&mut prob).unwrap(), b); + assert_eq!(decoder.decode_bit(&mut prob, true).unwrap(), b); } assert!(decoder.is_finished_ok().unwrap()); } @@ -267,7 +267,7 @@ mod test { let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); let mut tree = decode::rangecoder::BitTree::new(num_bits); for &v in values { - assert_eq!(tree.parse(&mut decoder).unwrap(), v); + assert_eq!(tree.parse(&mut decoder, true).unwrap(), v); } assert!(decoder.is_finished_ok().unwrap()); } @@ -309,7 +309,7 @@ mod test { let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); let mut tree = decode::rangecoder::BitTree::new(num_bits); for &v in values { - assert_eq!(tree.parse_reverse(&mut decoder).unwrap(), v); + assert_eq!(tree.parse_reverse(&mut decoder, true).unwrap(), v); } assert!(decoder.is_finished_ok().unwrap()); } @@ -352,7 +352,7 @@ mod test { let mut len_decoder = LenDecoder::new(); for &v in values { assert_eq!( - len_decoder.decode(&mut decoder, pos_state).unwrap(), + len_decoder.decode(&mut decoder, pos_state, true).unwrap(), v as usize ); } diff --git a/src/error.rs b/src/error.rs index a9ec01d1..8e156558 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,8 @@ use std::result; pub enum Error { /// I/O error. IOError(io::Error), + /// Not enough bytes to complete header + HeaderTooShort(io::Error), /// LZMA error. LZMAError(String), /// XZ error. @@ -28,6 +30,7 @@ impl Display for Error { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Error::IOError(e) => write!(fmt, "io error: {}", e), + Error::HeaderTooShort(e) => write!(fmt, "header too short: {}", e), Error::LZMAError(e) => write!(fmt, "lzma error: {}", e), Error::XZError(e) => write!(fmt, "xz error: {}", e), } @@ -37,7 +40,7 @@ impl Display for Error { impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - Error::IOError(e) => Some(e), + Error::IOError(e) | Error::HeaderTooShort(e) => Some(e), Error::LZMAError(_) | Error::XZError(_) => None, } } diff --git a/src/lib.rs b/src/lib.rs index fca86e5d..ca70bb0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,8 @@ pub mod compress { /// Decompression helpers. pub mod decompress { pub use crate::decode::options::*; + #[cfg(feature = "stream")] + pub use crate::decode::stream::Stream; } /// Decompress LZMA data with default [`Options`](decompress/struct.Options.html). diff --git a/tests/files/small.txt b/tests/files/small.txt new file mode 100644 index 00000000..ddb2f8d6 --- /dev/null +++ b/tests/files/small.txt @@ -0,0 +1 @@ +Project Gutenberg's Alice's Adventures in Wonderland, by Lewis Carroll diff --git a/tests/lzma.rs b/tests/lzma.rs index cec9576f..86a480a9 100644 --- a/tests/lzma.rs +++ b/tests/lzma.rs @@ -1,6 +1,9 @@ #[cfg(feature = "enable_logging")] use log::{debug, info}; +#[cfg(feature = "stream")] +use std::io::Write; + fn round_trip(x: &[u8]) { round_trip_no_options(x); @@ -22,17 +25,30 @@ fn round_trip_no_options(x: &[u8]) { info!("Compressed {} -> {} bytes", x.len(), compressed.len()); #[cfg(feature = "enable_logging")] debug!("Compressed content: {:?}", compressed); - let mut bf = std::io::BufReader::new(compressed.as_slice()); - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress(&mut bf, &mut decomp).unwrap(); - assert_eq!(decomp, x) + + // test non-streaming decompression + { + let mut bf = std::io::BufReader::new(compressed.as_slice()); + let mut decomp: Vec = Vec::new(); + lzma_rs::lzma_decompress(&mut bf, &mut decomp).unwrap(); + assert_eq!(decomp, x); + } + + #[cfg(feature = "stream")] + // test streaming decompression + { + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + stream.write_all(&compressed).unwrap(); + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, x); + } } -fn round_trip_with_options( +fn assert_round_trip_with_options( x: &[u8], encode_options: &lzma_rs::compress::Options, decode_options: &lzma_rs::decompress::Options, -) -> lzma_rs::error::Result> { +) { let mut compressed: Vec = Vec::new(); lzma_rs::lzma_compress_with_options( &mut std::io::BufReader::new(x), @@ -44,21 +60,30 @@ fn round_trip_with_options( info!("Compressed {} -> {} bytes", x.len(), compressed.len()); #[cfg(feature = "enable_logging")] debug!("Compressed content: {:?}", compressed); - let mut bf = std::io::BufReader::new(compressed.as_slice()); - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress_with_options(&mut bf, &mut decomp, decode_options)?; - Ok(decomp) -} -fn assert_round_trip_with_options( - x: &[u8], - encode_options: &lzma_rs::compress::Options, - decode_options: &lzma_rs::decompress::Options, -) { - assert_eq!( - round_trip_with_options(x, encode_options, decode_options).unwrap(), - x - ) + // test non-streaming decompression + { + let mut bf = std::io::BufReader::new(compressed.as_slice()); + let mut decomp: Vec = Vec::new(); + lzma_rs::lzma_decompress_with_options(&mut bf, &mut decomp, decode_options).unwrap(); + assert_eq!(decomp, x); + } + + #[cfg(feature = "stream")] + // test streaming decompression + { + let mut stream = lzma_rs::decompress::Stream::new_with_options(decode_options, Vec::new()); + + if let Err(error) = stream.write_all(&compressed) { + // WriteZero could indicate that the unpacked_size was reached before the + // end of the stream + if std::io::ErrorKind::WriteZero != error.kind() { + panic!(error); + } + } + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, x); + } } fn round_trip_file(filename: &str) { @@ -75,30 +100,79 @@ fn round_trip_file(filename: &str) { fn decomp_big_file(compfile: &str, plainfile: &str) { use std::io::Read; - let mut expected = Vec::new(); - std::fs::File::open(plainfile) - .unwrap() - .read_to_end(&mut expected) - .unwrap(); - let mut f = std::io::BufReader::new(std::fs::File::open(compfile).unwrap()); - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress(&mut f, &mut decomp).unwrap(); - assert!(decomp == expected) + let expected = { + let mut expected = Vec::new(); + std::fs::File::open(plainfile) + .unwrap() + .read_to_end(&mut expected) + .unwrap(); + expected + }; + + // test non-streaming decompression + { + let input = { + let mut input = Vec::new(); + std::fs::File::open(compfile) + .unwrap() + .read_to_end(&mut input) + .unwrap(); + input + }; + + let mut input = std::io::BufReader::new(input.as_slice()); + let mut decomp: Vec = Vec::new(); + lzma_rs::lzma_decompress(&mut input, &mut decomp).unwrap(); + assert_eq!(decomp, expected); + } + + #[cfg(feature = "stream")] + // test streaming decompression + { + let mut compfile = std::fs::File::open(compfile).unwrap(); + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + + // read file in chunks + let mut tmp = [0u8; 1024]; + while { + let n = compfile.read(&mut tmp).unwrap(); + stream.write_all(&tmp[0..n]).unwrap(); + + n > 0 + } {} + + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, expected); + } +} + +fn assert_decomp_eq(input: &[u8], expected: &[u8]) { + // test non-streaming decompression + { + let mut input = std::io::BufReader::new(input); + let mut decomp: Vec = Vec::new(); + lzma_rs::lzma_decompress(&mut input, &mut decomp).unwrap(); + assert_eq!(decomp, expected) + } + + #[cfg(feature = "stream")] + // test streaming decompression + { + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + stream.write_all(input).unwrap(); + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, expected); + } } #[test] +#[should_panic(expected = "HeaderTooShort")] fn decompress_short_header() { #[cfg(feature = "enable_logging")] let _ = env_logger::try_init(); let mut decomp: Vec = Vec::new(); // TODO: compare io::Errors? - assert_eq!( - format!( - "{:?}", - lzma_rs::lzma_decompress(&mut (b"" as &[u8]), &mut decomp).unwrap_err() - ), - String::from("LZMAError(\"LZMA header too short: failed to fill whole buffer\")") - ) + lzma_rs::lzma_decompress(&mut (b"" as &[u8]), &mut decomp).unwrap(); } #[test] @@ -142,23 +216,23 @@ fn big_file() { fn decompress_empty_world() { #[cfg(feature = "enable_logging")] let _ = env_logger::try_init(); - let mut x: &[u8] = b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\ - \xfb\xff\xff\xc0\x00\x00\x00"; - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress(&mut x, &mut decomp).unwrap(); - assert_eq!(decomp, b"") + assert_decomp_eq( + b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\ + \xfb\xff\xff\xc0\x00\x00\x00", + b"", + ); } #[test] fn decompress_hello_world() { #[cfg(feature = "enable_logging")] let _ = env_logger::try_init(); - let mut x: &[u8] = b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x24\x19\ - \x49\x98\x6f\x10\x19\xc6\xd7\x31\xeb\x36\x50\xb2\x98\x48\xff\xfe\ - \xa5\xb0\x00"; - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress(&mut x, &mut decomp).unwrap(); - assert_eq!(decomp, b"Hello world\x0a") + assert_decomp_eq( + b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x24\x19\ + \x49\x98\x6f\x10\x19\xc6\xd7\x31\xeb\x36\x50\xb2\x98\x48\xff\xfe\ + \xa5\xb0\x00", + b"Hello world\x0a", + ); } #[test] @@ -166,12 +240,12 @@ fn decompress_huge_dict() { // Hello world with a dictionary of size 0x7F7F7F7F #[cfg(feature = "enable_logging")] let _ = env_logger::try_init(); - let mut x: &[u8] = b"\x5d\x7f\x7f\x7f\x7f\xff\xff\xff\xff\xff\xff\xff\xff\x00\x24\x19\ + assert_decomp_eq( + b"\x5d\x7f\x7f\x7f\x7f\xff\xff\xff\xff\xff\xff\xff\xff\x00\x24\x19\ \x49\x98\x6f\x10\x19\xc6\xd7\x31\xeb\x36\x50\xb2\x98\x48\xff\xfe\ - \xa5\xb0\x00"; - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress(&mut x, &mut decomp).unwrap(); - assert_eq!(decomp, b"Hello world\x0a") + \xa5\xb0\x00", + b"Hello world\x0a", + ); } #[test] @@ -244,7 +318,6 @@ fn unpacked_size_write_none_to_header_and_use_provided_none_on_read() { } #[test] -#[should_panic(expected = "exceeded memory limit of 0")] fn memlimit() { let data = b"Some data"; let encode_options = lzma_rs::compress::Options { @@ -253,6 +326,43 @@ fn memlimit() { let decode_options = lzma_rs::decompress::Options { unpacked_size: lzma_rs::decompress::UnpackedSize::ReadHeaderButUseProvided(None), memlimit: Some(0), + ..Default::default() }; - round_trip_with_options(&data[..], &encode_options, &decode_options).unwrap(); + + let mut compressed: Vec = Vec::new(); + lzma_rs::lzma_compress_with_options( + &mut std::io::BufReader::new(&data[..]), + &mut compressed, + &encode_options, + ) + .unwrap(); + + // test non-streaming decompression + { + let mut bf = std::io::BufReader::new(compressed.as_slice()); + let mut decomp: Vec = Vec::new(); + let error = lzma_rs::lzma_decompress_with_options(&mut bf, &mut decomp, &decode_options) + .unwrap_err(); + assert!( + error.to_string().contains("exceeded memory limit of 0"), + error.to_string() + ); + } + + #[cfg(feature = "stream")] + // test streaming decompression + { + let mut stream = lzma_rs::decompress::Stream::new_with_options(&decode_options, Vec::new()); + + let error = stream.write_all(&compressed).unwrap_err(); + assert!( + error.to_string().contains("exceeded memory limit of 0"), + error.to_string() + ); + let error = stream.finish().unwrap_err(); + assert!( + error.to_string().contains("previous write error"), + error.to_string() + ); + } }