diff --git a/Cargo.toml b/Cargo.toml index 77b529fe..c482136f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ env_logger = { version = "^0.7.1", optional = true } [features] enable_logging = ["env_logger", "log"] +stream = [] [badges] travis-ci = { repository = "gendx/lzma-rs", branch = "master" } diff --git a/benches/lzma.rs b/benches/lzma.rs index dc567ad0..57b9f5c6 100644 --- a/benches/lzma.rs +++ b/benches/lzma.rs @@ -34,6 +34,16 @@ fn decompress_bench(compressed: &[u8], b: &mut Bencher) { }); } +#[cfg(feature = "stream")] +fn decompress_stream_bench(compressed: &[u8], b: &mut Bencher) { + use std::io::Write; + b.iter(|| { + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + stream.write_all(compressed).unwrap(); + stream.finish().unwrap() + }); +} + fn decompress_bench_file(compfile: &str, b: &mut Bencher) { let mut f = std::fs::File::open(compfile).unwrap(); let mut compressed = Vec::new(); @@ -41,6 +51,14 @@ fn decompress_bench_file(compfile: &str, b: &mut Bencher) { decompress_bench(&compressed, b); } +#[cfg(feature = "stream")] +fn decompress_stream_bench_file(compfile: &str, b: &mut Bencher) { + let mut f = std::fs::File::open(compfile).unwrap(); + let mut compressed = Vec::new(); + f.read_to_end(&mut compressed).unwrap(); + decompress_stream_bench(&compressed, b); +} + #[bench] fn compress_empty(b: &mut Bencher) { #[cfg(feature = "enable_logging")] @@ -90,6 +108,14 @@ fn decompress_big_file(b: &mut Bencher) { decompress_bench_file("tests/files/foo.txt.lzma", b); } +#[cfg(feature = "stream")] +#[bench] +fn decompress_stream_big_file(b: &mut Bencher) { + #[cfg(feature = "enable_logging")] + let _ = env_logger::try_init(); + decompress_stream_bench_file("tests/files/foo.txt.lzma", b); +} + #[bench] fn decompress_huge_dict(b: &mut Bencher) { #[cfg(feature = "enable_logging")] diff --git a/src/decode/lzbuffer.rs b/src/decode/lzbuffer.rs index fecd5930..9d77331a 100644 --- a/src/decode/lzbuffer.rs +++ b/src/decode/lzbuffer.rs @@ -18,8 +18,10 @@ where fn get_output(&self) -> &W; // Get a mutable reference to the output sink fn get_output_mut(&mut self) -> &mut W; - // Flush the buffer to the output + // Consumes this buffer and flushes any data fn finish(self) -> io::Result; + // Consumes this buffer without flushing any data + fn into_output(self) -> W; } // An accumulating buffer for LZ sequences @@ -143,12 +145,17 @@ where &mut self.stream } - // Flush the buffer to the output + // Consumes this buffer and flushes any data fn finish(mut self) -> io::Result { self.stream.write_all(self.buf.as_slice())?; self.stream.flush()?; Ok(self.stream) } + + // Consumes this buffer without flushing any data + fn into_output(self) -> W { + self.stream + } } // A circular buffer for LZ sequences @@ -291,7 +298,7 @@ where &mut self.stream } - // Flush the buffer to the output + // Consumes this buffer and flushes any data fn finish(mut self) -> io::Result { if self.cursor > 0 { self.stream.write_all(&self.buf[0..self.cursor])?; @@ -299,4 +306,9 @@ where } Ok(self.stream) } + + // Consumes this buffer without flushing any data + fn into_output(self) -> W { + self.stream + } } diff --git a/src/decode/lzma.rs b/src/decode/lzma.rs index bc3c2c00..05dedbf7 100644 --- a/src/decode/lzma.rs +++ b/src/decode/lzma.rs @@ -8,6 +8,24 @@ use std::marker::PhantomData; use crate::decompress::Options; use crate::decompress::UnpackedSize; +/// Maximum input data that can be processed in one iteration +const MAX_REQUIRED_INPUT: usize = 20; + +/// Processing mode for decompression. +/// +/// Tells the decompressor if we should expect more data after parsing the +/// current input. +#[derive(Debug, PartialEq)] +pub enum Mode { + /// Streaming mode. Process the input bytes but assume there will be more + /// chunks of input data to receive in future calls to `process_mode()`. + Run, + /// Sync mode. Process the input bytes and confirm end of stream has been reached. + /// Use this mode if you are processing a fixed buffer of compressed data, or after + /// using `Mode::Run` to check for the end of stream. + Finish, +} + pub struct LZMAParams { // most lc significant bits of previous byte are part of the literal context lc: u32, // 0..8 @@ -24,9 +42,7 @@ impl LZMAParams { R: io::BufRead, { // Properties - let props = input - .read_u8() - .map_err(|e| error::Error::LZMAError(format!("LZMA header too short: {}", e)))?; + let props = input.read_u8().map_err(error::Error::HeaderTooShort)?; let mut pb = props as u32; if pb >= 225 { @@ -46,7 +62,7 @@ impl LZMAParams { // Dictionary let dict_size_provided = input .read_u32::() - .map_err(|e| error::Error::LZMAError(format!("LZMA header too short: {}", e)))?; + .map_err(error::Error::HeaderTooShort)?; let dict_size = if dict_size_provided < 0x1000 { 0x1000 } else { @@ -58,9 +74,9 @@ impl LZMAParams { // Unpacked size let unpacked_size: Option = match options.unpacked_size { UnpackedSize::ReadFromHeader => { - let unpacked_size_provided = input.read_u64::().map_err(|e| { - error::Error::LZMAError(format!("LZMA header too short: {}", e)) - })?; + let unpacked_size_provided = input + .read_u64::() + .map_err(error::Error::HeaderTooShort)?; let marker_mandatory: bool = unpacked_size_provided == 0xFFFF_FFFF_FFFF_FFFF; if marker_mandatory { None @@ -69,7 +85,9 @@ impl LZMAParams { } } UnpackedSize::ReadHeaderButUseProvided(x) => { - input.read_u64::()?; + input + .read_u64::() + .map_err(error::Error::HeaderTooShort)?; x } UnpackedSize::UseProvided(x) => x, @@ -95,6 +113,9 @@ where LZB: lzbuffer::LZBuffer, { _phantom: PhantomData, + // buffer input data here if we need more for decompression, 20 is the max + // number of bytes that can be consumed during one iteration + tmp: std::io::Cursor<[u8; MAX_REQUIRED_INPUT]>, pub output: LZB, // most lc significant bits of previous byte are part of the literal context pub lc: u32, // 0..8 @@ -131,6 +152,7 @@ where { DecoderState { _phantom: PhantomData, + tmp: std::io::Cursor::new([0; MAX_REQUIRED_INPUT]), output, lc, lp, @@ -181,6 +203,7 @@ where params.dict_size as usize, memlimit, ), + tmp: std::io::Cursor::new([0; MAX_REQUIRED_INPUT]), lc: params.lc, lp: params.lp, pb: params.pb, @@ -237,23 +260,38 @@ where &mut self, rangecoder: &mut rangecoder::RangeDecoder<'a, R>, ) -> error::Result<()> { - loop { - if let Some(unpacked_size) = self.unpacked_size { - if self.output.len() as u64 >= unpacked_size { - break; - } - } else if rangecoder.is_finished_ok()? { - break; - } + self.process_mode(rangecoder, Mode::Finish) + } - let pos_state = self.output.len() & ((1 << self.pb) - 1); + #[cfg(feature = "stream")] + pub fn process_stream<'a, R: io::BufRead>( + &mut self, + rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + ) -> error::Result<()> { + self.process_mode(rangecoder, Mode::Run) + } - // Literal - if !rangecoder.decode_bit( - // TODO: assumes pb = 2 ?? - &mut self.is_match[(self.state << 4) + pos_state], - )? { - let byte: u8 = self.decode_literal(rangecoder)?; + /// Process the next iteration of the loop. + /// + /// If the update flag is true, the decoder's state will be updated. + /// + /// Returns true if we should continue processing the loop, false otherwise. + fn process_next_inner<'a, R: io::BufRead>( + &mut self, + rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + update: bool, + ) -> error::Result { + let pos_state = self.output.len() & ((1 << self.pb) - 1); + + // Literal + if !rangecoder.decode_bit( + // TODO: assumes pb = 2 ?? + &mut self.is_match[(self.state << 4) + pos_state], + update, + )? { + let byte: u8 = self.decode_literal(rangecoder, update)?; + + if update { lzma_debug!("Literal: {}", byte); self.output.append_literal(byte)?; @@ -264,35 +302,40 @@ where } else { self.state - 6 }; - continue; } + return Ok(true); + } - // LZ - let mut len: usize; - // Distance is repeated from LRU - if rangecoder.decode_bit(&mut self.is_rep[self.state])? { - // dist = rep[0] - if !rangecoder.decode_bit(&mut self.is_rep_g0[self.state])? { - // len = 1 - if !rangecoder - .decode_bit(&mut self.is_rep_0long[(self.state << 4) + pos_state])? - { - // update state (short rep) + // LZ + let mut len: usize; + // Distance is repeated from LRU + if rangecoder.decode_bit(&mut self.is_rep[self.state], update)? { + // dist = rep[0] + if !rangecoder.decode_bit(&mut self.is_rep_g0[self.state], update)? { + // len = 1 + if !rangecoder.decode_bit( + &mut self.is_rep_0long[(self.state << 4) + pos_state], + update, + )? { + // update state (short rep) + if update { self.state = if self.state < 7 { 9 } else { 11 }; let dist = self.rep[0] + 1; self.output.append_lz(1, dist)?; - continue; } - // dist = rep[i] + return Ok(true); + } + // dist = rep[i] + } else { + let idx: usize; + if !rangecoder.decode_bit(&mut self.is_rep_g1[self.state], update)? { + idx = 1; + } else if !rangecoder.decode_bit(&mut self.is_rep_g2[self.state], update)? { + idx = 2; } else { - let idx: usize; - if !rangecoder.decode_bit(&mut self.is_rep_g1[self.state])? { - idx = 1; - } else if !rangecoder.decode_bit(&mut self.is_rep_g2[self.state])? { - idx = 2; - } else { - idx = 3; - } + idx = 3; + } + if update { // Update LRU let dist = self.rep[idx]; for i in (0..idx).rev() { @@ -300,40 +343,171 @@ where } self.rep[0] = dist } + } + + len = self.rep_len_decoder.decode(rangecoder, pos_state, update)?; - len = self.rep_len_decoder.decode(rangecoder, pos_state)?; + if update { // update state (rep) self.state = if self.state < 7 { 8 } else { 11 }; - // New distance - } else { + } + // New distance + } else { + if update { // Update LRU self.rep[3] = self.rep[2]; self.rep[2] = self.rep[1]; self.rep[1] = self.rep[0]; - len = self.len_decoder.decode(rangecoder, pos_state)?; + } + len = self.len_decoder.decode(rangecoder, pos_state, update)?; + + if update { // update state (match) self.state = if self.state < 7 { 7 } else { 10 }; - self.rep[0] = self.decode_distance(rangecoder, len)?; + } + let rep_0 = self.decode_distance(rangecoder, len, update)?; + + if update { + self.rep[0] = rep_0; if self.rep[0] == 0xFFFF_FFFF { if rangecoder.is_finished_ok()? { - break; + return Ok(false); } return Err(error::Error::LZMAError(String::from( "Found end-of-stream marker but more bytes are available", ))); } } + } + if update { len += 2; let dist = self.rep[0] + 1; self.output.append_lz(len, dist)?; } + Ok(true) + } + + fn process_next<'a, R: io::BufRead>( + &mut self, + rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + ) -> error::Result { + self.process_next_inner(rangecoder, true) + } + + /// Try to process the next iteration of the loop. + /// + /// This will check to see if there is enough data to consume and advance the + /// decompressor. Needed in streaming mode to avoid corrupting the state while + /// processing incomplete chunks of data. + #[allow(clippy::if_same_then_else)] + fn try_process_next<'a>(&mut self, buf: &'a [u8], range: u32, code: u32) -> error::Result<()> { + let mut temp = std::io::Cursor::new(buf); + let mut rangecoder = rangecoder::RangeDecoder::from_parts(&mut temp, range, code); + let _ = self.process_next_inner(&mut rangecoder, false)?; + Ok(()) + } + + pub fn process_mode<'a, R: io::BufRead>( + &mut self, + mut rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + mode: Mode, + ) -> error::Result<()> { + loop { + if let Some(unpacked_size) = self.unpacked_size { + if self.output.len() as u64 >= unpacked_size { + break; + } + } else if match mode { + Mode::Run => rangecoder.is_eof()?, + Mode::Finish => rangecoder.is_finished_ok()?, + } { + break; + } + + if self.tmp.position() as usize > 0 { + // Fill as much of the tmp buffer as possible + let start = self.tmp.position() as usize; + let bytes_read = rangecoder.read_into(&mut self.tmp.get_mut()[start..])?; + let bytes_read = if bytes_read < std::u64::MAX as usize { + bytes_read as u64 + } else { + return Err(error::Error::LZMAError( + "Failed to convert integer to u64.".to_string(), + )); + }; + self.tmp.set_position(self.tmp.position() + bytes_read); + let tmp = *self.tmp.get_ref(); + + // Check if we need more data to advance the decompressor + if Mode::Run == mode + && (self.tmp.position() as usize) < MAX_REQUIRED_INPUT + && self + .try_process_next( + &tmp[0..self.tmp.position() as usize], + rangecoder.range(), + rangecoder.code(), + ) + .is_err() + { + return Ok(()); + } + + // Run the decompressor on the tmp buffer + let mut tmp_reader = io::Cursor::new(&tmp[0..self.tmp.position() as usize]); + let mut tmp_rangecoder = rangecoder::RangeDecoder::from_parts( + &mut tmp_reader, + rangecoder.range(), + rangecoder.code(), + ); + let res = self.process_next(&mut tmp_rangecoder)?; + + // Update the actual rangecoder + let (range, code) = tmp_rangecoder.into_parts(); + rangecoder.set(range, code); + + // Update tmp buffer + let end = self.tmp.position(); + let new_len = end - tmp_reader.position(); + self.tmp.get_mut()[0..new_len as usize] + .copy_from_slice(&tmp[tmp_reader.position() as usize..end as usize]); + self.tmp.set_position(new_len); + + if !res { + break; + }; + } else { + if (Mode::Run == mode) && (rangecoder.remaining()? < MAX_REQUIRED_INPUT) { + let range = rangecoder.range(); + let code = rangecoder.code(); + let buf = rangecoder.buf()?; + + if self.try_process_next(buf, range, code).is_err() { + let bytes_read = rangecoder.read_into(&mut self.tmp.get_mut()[..])?; + let bytes_read = if bytes_read < std::u64::MAX as usize { + bytes_read as u64 + } else { + return Err(error::Error::LZMAError( + "Failed to convert integer to u64.".to_string(), + )); + }; + self.tmp.set_position(bytes_read); + return Ok(()); + } + } + + if !self.process_next(&mut rangecoder)? { + break; + }; + } + } + if let Some(len) = self.unpacked_size { - if self.output.len() as u64 != len { + if Mode::Finish == mode && self.output.len() as u64 != len { return Err(error::Error::LZMAError(format!( "Expected unpacked size of {} but decompressed to {}", len, @@ -348,6 +522,7 @@ where fn decode_literal<'a, R: io::BufRead>( &mut self, rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + update: bool, ) -> error::Result { let def_prev_byte = 0u8; let prev_byte = self.output.last_or(def_prev_byte) as usize; @@ -363,8 +538,9 @@ where while result < 0x100 { let match_bit = (match_byte >> 7) & 1; match_byte <<= 1; - let bit = - rangecoder.decode_bit(&mut probs[((1 + match_bit) << 8) + result])? as usize; + let bit = rangecoder + .decode_bit(&mut probs[((1 + match_bit) << 8) + result], update)? + as usize; result = (result << 1) ^ bit; if match_bit != bit { break; @@ -373,7 +549,7 @@ where } while result < 0x100 { - result = (result << 1) ^ (rangecoder.decode_bit(&mut probs[result])? as usize); + result = (result << 1) ^ (rangecoder.decode_bit(&mut probs[result], update)? as usize); } Ok((result - 0x100) as u8) @@ -383,10 +559,11 @@ where &mut self, rangecoder: &mut rangecoder::RangeDecoder<'a, R>, length: usize, + update: bool, ) -> error::Result { let len_state = if length > 3 { 3 } else { length }; - let pos_slot = self.pos_slot_decoder[len_state].parse(rangecoder)? as usize; + let pos_slot = self.pos_slot_decoder[len_state].parse(rangecoder, update)? as usize; if pos_slot < 4 { return Ok(pos_slot); } @@ -399,10 +576,11 @@ where num_direct_bits, &mut self.pos_decoders, result - pos_slot, + update, )? as usize; } else { result += (rangecoder.get(num_direct_bits - 4)? as usize) << 4; - result += self.align_decoder.parse_reverse(rangecoder)? as usize; + result += self.align_decoder.parse_reverse(rangecoder, update)? as usize; } Ok(result) diff --git a/src/decode/mod.rs b/src/decode/mod.rs index 4a81d621..2a7b0b8a 100644 --- a/src/decode/mod.rs +++ b/src/decode/mod.rs @@ -7,3 +7,6 @@ pub mod options; pub mod rangecoder; pub mod util; pub mod xz; + +#[cfg(feature = "stream")] +pub mod stream; diff --git a/src/decode/options.rs b/src/decode/options.rs index 6d84362d..cea2b588 100644 --- a/src/decode/options.rs +++ b/src/decode/options.rs @@ -1,5 +1,5 @@ /// Options to tweak decompression behavior. -#[derive(Clone, Copy, Debug, Default)] +#[derive(Clone, Copy, Debug, PartialEq, Default)] pub struct Options { /// Defines whether the unpacked size should be read from the header or provided. /// @@ -10,10 +10,16 @@ pub struct Options { /// /// The default is unlimited. pub memlimit: Option, + /// Determines whether to bypass end of stream validation. + /// + /// This option only applies to the [`Stream`](struct.Stream.html) API. + /// + /// The default is false (always do completion check). + pub allow_incomplete: bool, } /// Alternatives for defining the unpacked size of the decoded data. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum UnpackedSize { /// Assume that the 8 bytes used to specify the unpacked size are present in the header. /// If the bytes are `0xFFFF_FFFF_FFFF_FFFF`, assume that there is an end-of-payload marker in @@ -38,3 +44,19 @@ impl Default for UnpackedSize { UnpackedSize::ReadFromHeader } } + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_options() { + assert_eq!( + Options { + unpacked_size: UnpackedSize::ReadFromHeader, + memlimit: None, + allow_incomplete: false, + }, + Options::default() + ); + } +} diff --git a/src/decode/rangecoder.rs b/src/decode/rangecoder.rs index 7643ea10..70cf7735 100644 --- a/src/decode/rangecoder.rs +++ b/src/decode/rangecoder.rs @@ -28,9 +28,51 @@ where Ok(dec) } + pub fn from_parts(stream: &'a mut R, range: u32, code: u32) -> Self { + Self { + stream, + range, + code, + } + } + + pub fn into_parts(self) -> (u32, u32) { + (self.range, self.code) + } + + pub fn set(&mut self, range: u32, code: u32) { + self.range = range; + self.code = code; + } + + pub fn range(&self) -> u32 { + self.range + } + + pub fn code(&self) -> u32 { + self.code + } + + pub fn buf(&mut self) -> io::Result<&[u8]> { + self.stream.fill_buf() + } + + pub fn remaining(&mut self) -> io::Result { + Ok(self.buf()?.len()) + } + + pub fn read_into(&mut self, dst: &mut [u8]) -> io::Result { + self.stream.read(dst) + } + #[inline] pub fn is_finished_ok(&mut self) -> io::Result { - Ok(self.code == 0 && util::is_eof(self.stream)?) + Ok(self.code == 0 && self.is_eof()?) + } + + #[inline] + pub fn is_eof(&mut self) -> io::Result { + util::is_eof(self.stream) } #[inline] @@ -67,7 +109,7 @@ where } #[inline] - pub fn decode_bit(&mut self, prob: &mut u16) -> io::Result { + pub fn decode_bit(&mut self, prob: &mut u16, update: bool) -> io::Result { let bound: u32 = (self.range >> 11) * (*prob as u32); lzma_trace!( @@ -77,13 +119,17 @@ where (self.code > bound) as u8 ); if self.code < bound { - *prob += (0x800_u16 - *prob) >> 5; + if update { + *prob += (0x800_u16 - *prob) >> 5; + } self.range = bound; self.normalize()?; Ok(false) } else { - *prob -= *prob >> 5; + if update { + *prob -= *prob >> 5; + } self.code -= bound; self.range -= bound; @@ -92,10 +138,15 @@ where } } - fn parse_bit_tree(&mut self, num_bits: usize, probs: &mut [u16]) -> io::Result { + fn parse_bit_tree( + &mut self, + num_bits: usize, + probs: &mut [u16], + update: bool, + ) -> io::Result { let mut tmp: u32 = 1; for _ in 0..num_bits { - let bit = self.decode_bit(&mut probs[tmp as usize])?; + let bit = self.decode_bit(&mut probs[tmp as usize], update)?; tmp = (tmp << 1) ^ (bit as u32); } Ok(tmp - (1 << num_bits)) @@ -106,11 +157,12 @@ where num_bits: usize, probs: &mut [u16], offset: usize, + update: bool, ) -> io::Result { let mut result = 0u32; let mut tmp: usize = 1; for i in 0..num_bits { - let bit = self.decode_bit(&mut probs[offset + tmp])?; + let bit = self.decode_bit(&mut probs[offset + tmp], update)?; tmp = (tmp << 1) ^ (bit as usize); result ^= (bit as u32) << i; } @@ -133,15 +185,20 @@ impl BitTree { } } - pub fn parse(&mut self, rangecoder: &mut RangeDecoder) -> io::Result { - rangecoder.parse_bit_tree(self.num_bits, self.probs.as_mut_slice()) + pub fn parse( + &mut self, + rangecoder: &mut RangeDecoder, + update: bool, + ) -> io::Result { + rangecoder.parse_bit_tree(self.num_bits, self.probs.as_mut_slice(), update) } pub fn parse_reverse( &mut self, rangecoder: &mut RangeDecoder, + update: bool, ) -> io::Result { - rangecoder.parse_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0) + rangecoder.parse_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0, update) } } @@ -168,13 +225,14 @@ impl LenDecoder { &mut self, rangecoder: &mut RangeDecoder, pos_state: usize, + update: bool, ) -> io::Result { - if !rangecoder.decode_bit(&mut self.choice)? { - Ok(self.low_coder[pos_state].parse(rangecoder)? as usize) - } else if !rangecoder.decode_bit(&mut self.choice2)? { - Ok(self.mid_coder[pos_state].parse(rangecoder)? as usize + 8) + if !rangecoder.decode_bit(&mut self.choice, update)? { + Ok(self.low_coder[pos_state].parse(rangecoder, update)? as usize) + } else if !rangecoder.decode_bit(&mut self.choice2, update)? { + Ok(self.mid_coder[pos_state].parse(rangecoder, update)? as usize + 8) } else { - Ok(self.high_coder.parse(rangecoder)? as usize + 16) + Ok(self.high_coder.parse(rangecoder, update)? as usize + 16) } } } diff --git a/src/decode/stream.rs b/src/decode/stream.rs new file mode 100644 index 00000000..e4312b6e --- /dev/null +++ b/src/decode/stream.rs @@ -0,0 +1,510 @@ +use crate::decode::lzbuffer::{LZBuffer, LZCircularBuffer}; +use crate::decode::lzma::{new_circular, new_circular_with_memlimit, DecoderState, LZMAParams}; +use crate::decode::rangecoder::RangeDecoder; +use crate::decompress::Options; +use crate::error::Error; +use std::fmt::Debug; +use std::io::{BufRead, Cursor, Read, Write}; + +/// Minimum header length to be read. +/// - props: u8 (1 byte) +/// - dict_size: u32 (4 bytes) +const MIN_HEADER_LEN: usize = 5; + +/// Max header length to be read. +/// - unpacked_size: u64 (8 bytes) +const MAX_HEADER_LEN: usize = MIN_HEADER_LEN + 8; + +/// Required bytes after the header. +/// - ignore: u8 (1 byte) +/// - code: u32 (4 bytes) +const START_BYTES: usize = 5; + +/// Maximum number of bytes to buffer while reading the header. +const MAX_TMP_LEN: usize = MAX_HEADER_LEN + START_BYTES; + +/// Internal state of this streaming decoder. This is needed because we have to +/// initialize the stream before processing any data. +#[derive(Debug)] +enum State +where + W: Write, +{ + /// Stream is initialized but header values have not yet been read. + Header(W), + /// Header values have been read and the stream is ready to process more data. + Data(RunState), +} + +/// Structures needed while decoding data. +struct RunState +where + W: Write, +{ + decoder: DecoderState>, + range: u32, + code: u32, +} + +impl Debug for RunState +where + W: Write, +{ + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_struct("RunState") + .field("range", &self.range) + .field("code", &self.code) + .finish() + } +} + +/// Lzma decompressor that can process multiple chunks of data using the +/// `std::io::Write` interface. +pub struct Stream +where + W: Write, +{ + /// Temporary buffer to hold data while the header is being read. + tmp: Cursor<[u8; MAX_TMP_LEN]>, + /// Whether the stream is initialized and ready to process data. + /// An `Option` is used to avoid interior mutability when updating the state. + state: Option>, + /// Options given when a stream is created. + options: Options, +} + +impl Stream +where + W: Write, +{ + /// Initialize the stream. This will consume the `output` which is the sink + /// implementing `std::io::Write` that will receive decompressed bytes. + pub fn new(output: W) -> Self { + Self::new_with_options(&Options::default(), output) + } + + /// Initialize the stream with the given `options`. This will consume the + /// `output` which is the sink implementing `std::io::Write` that will + /// receive decompressed bytes. + pub fn new_with_options(options: &Options, output: W) -> Self { + Self { + tmp: Cursor::new([0; MAX_TMP_LEN]), + state: Some(State::Header(output)), + options: options.clone(), + } + } + + /// Get a reference to the output sink + pub fn get_output(&self) -> Option<&W> { + self.state.as_ref().map(|state| match state { + State::Header(output) => &output, + State::Data(state) => state.decoder.output.get_output(), + }) + } + + /// Get a mutable reference to the output sink + pub fn get_output_mut(&mut self) -> Option<&mut W> { + self.state.as_mut().map(|state| match state { + State::Header(output) => output, + State::Data(state) => state.decoder.output.get_output_mut(), + }) + } + + /// Consumes the stream and returns the output sink. This also makes sure + /// we have properly reached the end of the stream. + pub fn finish(mut self) -> crate::error::Result { + if let Some(state) = self.state.take() { + match state { + State::Header(output) => { + if self.tmp.position() > 0 { + Err(Error::LZMAError("failed to read header".to_string())) + } else { + Ok(output) + } + } + State::Data(mut state) => { + if !self.options.allow_incomplete { + // Process one last time with empty input to force end of + // stream checks + let mut stream = + Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]); + let mut range_decoder = + RangeDecoder::from_parts(&mut stream, state.range, state.code); + state.decoder.process(&mut range_decoder)?; + } + let output = state.decoder.output.finish()?; + Ok(output) + } + } + } else { + // this will occur if a call to `write()` fails + Err(Error::LZMAError( + "can't finish stream because of previous write error".to_string(), + )) + } + } + + /// Attempts to read the header and transition into a running state. + /// + /// This function will consume the state, returning the next state on both + /// error and success. + fn read_header( + output: W, + mut input: &mut R, + options: &Options, + ) -> crate::error::Result> { + match LZMAParams::read_header(&mut input, options) { + Ok(params) => { + let decoder = if let Some(memlimit) = options.memlimit { + new_circular_with_memlimit(output, params, memlimit) + } else { + new_circular(output, params) + }?; + + // The RangeDecoder is only kept temporarily as we are processing + // chunks of data. + if let Ok(range_decoder) = RangeDecoder::new(&mut input) { + let (range, code) = range_decoder.into_parts(); + + Ok(State::Data(RunState { + decoder, + range, + code, + })) + } else { + // Failed to create a RangeDecoder because we need more data, + // try again later. + Ok(State::Header(decoder.output.into_output())) + } + } + // Failed to read_header() because we need more data, try again later. + Err(Error::HeaderTooShort(_)) => Ok(State::Header(output)), + // Fatal error. Don't retry. + Err(e) => Err(e), + } + } + + /// Process compressed data + fn read_data( + mut state: RunState, + mut input: &mut R, + ) -> std::io::Result> { + // Construct our RangeDecoder from the previous range and code + // values. + let mut range_decoder = RangeDecoder::from_parts(&mut input, state.range, state.code); + + // Try to process all bytes of data. + state + .decoder + .process_stream(&mut range_decoder) + .map_err(|e| -> std::io::Error { e.into() })?; + + // Save the range and code for the next chunk of data. + let (range, code) = range_decoder.into_parts(); + state.range = range; + state.code = code; + + Ok(RunState { + decoder: state.decoder, + range, + code, + }) + } +} + +impl Debug for Stream +where + W: Write + Debug, +{ + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_struct("Stream") + .field("tmp", &self.tmp.position()) + .field("state", &self.state) + .field("options", &self.options) + .finish() + } +} + +impl Write for Stream +where + W: Write, +{ + fn write(&mut self, data: &[u8]) -> std::io::Result { + let mut input = Cursor::new(data); + + if let Some(state) = self.state.take() { + let state = match state { + // Read the header values and transition into a running state. + State::Header(state) => { + let res = if self.tmp.position() > 0 { + // attempt to fill the tmp buffer + let position = self.tmp.position(); + let bytes_read = + input.read(&mut self.tmp.get_mut()[position as usize..])?; + let bytes_read = if bytes_read < std::u64::MAX as usize { + bytes_read as u64 + } else { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to convert integer to u64.", + )); + }; + self.tmp.set_position(position + bytes_read); + + // attempt to read the header from our tmp buffer + let (position, res) = { + let mut tmp_input = + Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]); + let res = Stream::read_header(state, &mut tmp_input, &self.options); + (tmp_input.position(), res) + }; + + // discard all bytes up to position if reading the header + // was successful + match &res { + Ok(State::Data(_)) => { + let tmp = self.tmp.get_ref().clone(); + let end = self.tmp.position(); + let new_len = end - position; + (&mut self.tmp.get_mut()[0..new_len as usize]) + .copy_from_slice(&tmp[position as usize..end as usize]); + self.tmp.set_position(new_len); + } + _ => {} + } + res + } else { + Stream::read_header(state, &mut input, &self.options) + }; + + match res { + // occurs when not enough input bytes were provided to + // read the entire header + Ok(State::Header(val)) => { + if self.tmp.position() == 0 { + // reset the cursor because we may have partial reads + input.set_position(0); + let bytes_read = input.read(&mut self.tmp.get_mut()[..])?; + let bytes_read = if bytes_read < std::u64::MAX as usize { + bytes_read as u64 + } else { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to convert integer to u64.", + )); + }; + self.tmp.set_position(bytes_read); + } + State::Header(val) + } + + // occurs when the header was successfully read and we + // move on to the next state + Ok(State::Data(val)) => State::Data(val), + + // occurs when the output was consumed due to a + // non-recoverable error + Err(e) => { + return Err(match e { + Error::IOError(e) | Error::HeaderTooShort(e) => e, + Error::LZMAError(e) | Error::XZError(e) => { + std::io::Error::new(std::io::ErrorKind::Other, e) + } + }); + } + } + } + + // Process another chunk of data. + State::Data(state) => { + let state = if self.tmp.position() > 0 { + let mut tmp_input = + Cursor::new(&self.tmp.get_ref()[0..self.tmp.position() as usize]); + let res = Stream::read_data(state, &mut tmp_input)?; + self.tmp.set_position(0); + res + } else { + state + }; + State::Data(Stream::read_data(state, &mut input)?) + } + }; + self.state.replace(state); + } + Ok(input.position() as usize) + } + + /// Flushes the output sink. The internal buffer isn't flushed to avoid + /// corrupting the internal state. Instead, call `finish()` to finalize the + /// stream and flush all remaining internal data. + fn flush(&mut self) -> std::io::Result<()> { + if let Some(ref mut state) = self.state { + match state { + State::Header(_) => Ok(()), + State::Data(state) => state.decoder.output.get_output_mut().flush(), + } + } else { + Ok(()) + } + } +} + +impl std::convert::Into for Error { + fn into(self) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::Other, format!("{:?}", self)) + } +} + +#[cfg(test)] +mod test { + use super::*; + + /// Test an empty stream + #[test] + fn test_stream_noop() { + let stream = Stream::new(Vec::new()); + assert!(stream.get_output().unwrap().is_empty()); + + let output = stream.finish().unwrap(); + assert!(output.is_empty()); + } + + /// Test writing an empty slice + #[test] + fn test_stream_zero() { + let mut stream = Stream::new(Vec::new()); + + stream.write_all(&[]).unwrap(); + stream.write_all(&[]).unwrap(); + + let output = stream.finish().unwrap(); + + assert!(output.is_empty()); + } + + /// Test a bad header value + #[test] + #[should_panic(expected = "LZMA header invalid properties: 255 must be < 225")] + fn test_bad_header() { + let input = [255u8; 32]; + + let mut stream = Stream::new(Vec::new()); + + stream.write_all(&input[..]).unwrap(); + + let output = stream.finish().unwrap(); + + assert!(output.is_empty()); + } + + /// Test processing only partial data + #[test] + fn test_stream_incomplete() { + let input = b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\ + \xfb\xff\xff\xc0\x00\x00\x00"; + // Process until this index is reached. + let mut end = 1u64; + + // Test when we fail to provide the minimum number of bytes required to + // read the header. Header size is 13 bytes but we also read the first 5 + // bytes of data. + while end < (MAX_HEADER_LEN + START_BYTES) as u64 { + let mut stream = Stream::new(Vec::new()); + stream.write_all(&input[..end as usize]).unwrap(); + assert_eq!(stream.tmp.position(), end); + + let err = stream.finish().unwrap_err(); + assert!( + err.to_string().contains("failed to read header"), + "error was: {}", + err + ); + + end += 1; + } + + // Test when we fail to provide enough bytes to terminate the stream. A + // properly terminated stream will have a code value of 0. + while end < input.len() as u64 { + let mut stream = Stream::new(Vec::new()); + stream.write_all(&input[..end as usize]).unwrap(); + + // Header bytes will be buffered until there are enough to read + if end < (MAX_HEADER_LEN + START_BYTES) as u64 { + assert_eq!(stream.tmp.position(), end); + } + + let err = stream.finish().unwrap_err(); + assert!(err.to_string().contains("failed to fill whole buffer")); + + end += 1; + } + } + + /// Test processing all chunk sizes + #[test] + fn test_stream_chunked() { + let small_input = include_bytes!("../../tests/files/small.txt"); + + let mut reader = std::io::Cursor::new(&small_input[..]); + let mut small_input_compressed = Vec::new(); + crate::lzma_compress(&mut reader, &mut small_input_compressed).unwrap(); + + let input : Vec<(&[u8], &[u8])> = vec![ + (b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\xfb\xff\xff\xc0\x00\x00\x00", b""), + (&small_input_compressed[..], small_input)]; + for (input, expected) in input { + for chunk in 1..input.len() { + let mut consumed = 0; + let mut stream = Stream::new(Vec::new()); + while consumed < input.len() { + let end = std::cmp::min(consumed + chunk, input.len()); + stream.write_all(&input[consumed..end]).unwrap(); + consumed = end; + } + let output = stream.finish().unwrap(); + assert_eq!(expected, &output[..]); + } + } + } + + #[test] + fn test_stream_corrupted() { + let mut stream = Stream::new(Vec::new()); + let err = stream + .write_all(b"corrupted bytes here corrupted bytes here") + .unwrap_err(); + assert!(err.to_string().contains("beyond output size")); + let err = stream.finish().unwrap_err(); + assert!(err + .to_string() + .contains("can\'t finish stream because of previous write error")); + } + + #[test] + fn test_allow_incomplete() { + let input = include_bytes!("../../tests/files/small.txt"); + + let mut reader = std::io::Cursor::new(&input[..]); + let mut compressed = Vec::new(); + crate::lzma_compress(&mut reader, &mut compressed).unwrap(); + let compressed = &compressed[..compressed.len() / 2]; + + // Should fail to finish() without the allow_incomplete option. + let mut stream = Stream::new(Vec::new()); + stream.write_all(&compressed[..]).unwrap(); + stream.finish().unwrap_err(); + + // Should succeed with the allow_incomplete option. + let mut stream = Stream::new_with_options( + &Options { + allow_incomplete: true, + ..Default::default() + }, + Vec::new(), + ); + stream.write_all(&compressed[..]).unwrap(); + let output = stream.finish().unwrap(); + assert_eq!(output, &input[..26]); + } +} diff --git a/src/encode/rangecoder.rs b/src/encode/rangecoder.rs index 5ab8582d..da5385d2 100644 --- a/src/encode/rangecoder.rs +++ b/src/encode/rangecoder.rs @@ -238,7 +238,7 @@ mod test { let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); let mut prob = prob_init; for &b in bits { - assert_eq!(decoder.decode_bit(&mut prob).unwrap(), b); + assert_eq!(decoder.decode_bit(&mut prob, true).unwrap(), b); } assert!(decoder.is_finished_ok().unwrap()); } @@ -267,7 +267,7 @@ mod test { let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); let mut tree = decode::rangecoder::BitTree::new(num_bits); for &v in values { - assert_eq!(tree.parse(&mut decoder).unwrap(), v); + assert_eq!(tree.parse(&mut decoder, true).unwrap(), v); } assert!(decoder.is_finished_ok().unwrap()); } @@ -309,7 +309,7 @@ mod test { let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); let mut tree = decode::rangecoder::BitTree::new(num_bits); for &v in values { - assert_eq!(tree.parse_reverse(&mut decoder).unwrap(), v); + assert_eq!(tree.parse_reverse(&mut decoder, true).unwrap(), v); } assert!(decoder.is_finished_ok().unwrap()); } @@ -352,7 +352,7 @@ mod test { let mut len_decoder = LenDecoder::new(); for &v in values { assert_eq!( - len_decoder.decode(&mut decoder, pos_state).unwrap(), + len_decoder.decode(&mut decoder, pos_state, true).unwrap(), v as usize ); } diff --git a/src/error.rs b/src/error.rs index a9ec01d1..8e156558 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,8 @@ use std::result; pub enum Error { /// I/O error. IOError(io::Error), + /// Not enough bytes to complete header + HeaderTooShort(io::Error), /// LZMA error. LZMAError(String), /// XZ error. @@ -28,6 +30,7 @@ impl Display for Error { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Error::IOError(e) => write!(fmt, "io error: {}", e), + Error::HeaderTooShort(e) => write!(fmt, "header too short: {}", e), Error::LZMAError(e) => write!(fmt, "lzma error: {}", e), Error::XZError(e) => write!(fmt, "xz error: {}", e), } @@ -37,7 +40,7 @@ impl Display for Error { impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - Error::IOError(e) => Some(e), + Error::IOError(e) | Error::HeaderTooShort(e) => Some(e), Error::LZMAError(_) | Error::XZError(_) => None, } } diff --git a/src/lib.rs b/src/lib.rs index fca86e5d..ca70bb0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,8 @@ pub mod compress { /// Decompression helpers. pub mod decompress { pub use crate::decode::options::*; + #[cfg(feature = "stream")] + pub use crate::decode::stream::Stream; } /// Decompress LZMA data with default [`Options`](decompress/struct.Options.html). diff --git a/tests/files/small.txt b/tests/files/small.txt new file mode 100644 index 00000000..ddb2f8d6 --- /dev/null +++ b/tests/files/small.txt @@ -0,0 +1 @@ +Project Gutenberg's Alice's Adventures in Wonderland, by Lewis Carroll diff --git a/tests/lzma.rs b/tests/lzma.rs index cec9576f..86a480a9 100644 --- a/tests/lzma.rs +++ b/tests/lzma.rs @@ -1,6 +1,9 @@ #[cfg(feature = "enable_logging")] use log::{debug, info}; +#[cfg(feature = "stream")] +use std::io::Write; + fn round_trip(x: &[u8]) { round_trip_no_options(x); @@ -22,17 +25,30 @@ fn round_trip_no_options(x: &[u8]) { info!("Compressed {} -> {} bytes", x.len(), compressed.len()); #[cfg(feature = "enable_logging")] debug!("Compressed content: {:?}", compressed); - let mut bf = std::io::BufReader::new(compressed.as_slice()); - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress(&mut bf, &mut decomp).unwrap(); - assert_eq!(decomp, x) + + // test non-streaming decompression + { + let mut bf = std::io::BufReader::new(compressed.as_slice()); + let mut decomp: Vec = Vec::new(); + lzma_rs::lzma_decompress(&mut bf, &mut decomp).unwrap(); + assert_eq!(decomp, x); + } + + #[cfg(feature = "stream")] + // test streaming decompression + { + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + stream.write_all(&compressed).unwrap(); + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, x); + } } -fn round_trip_with_options( +fn assert_round_trip_with_options( x: &[u8], encode_options: &lzma_rs::compress::Options, decode_options: &lzma_rs::decompress::Options, -) -> lzma_rs::error::Result> { +) { let mut compressed: Vec = Vec::new(); lzma_rs::lzma_compress_with_options( &mut std::io::BufReader::new(x), @@ -44,21 +60,30 @@ fn round_trip_with_options( info!("Compressed {} -> {} bytes", x.len(), compressed.len()); #[cfg(feature = "enable_logging")] debug!("Compressed content: {:?}", compressed); - let mut bf = std::io::BufReader::new(compressed.as_slice()); - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress_with_options(&mut bf, &mut decomp, decode_options)?; - Ok(decomp) -} -fn assert_round_trip_with_options( - x: &[u8], - encode_options: &lzma_rs::compress::Options, - decode_options: &lzma_rs::decompress::Options, -) { - assert_eq!( - round_trip_with_options(x, encode_options, decode_options).unwrap(), - x - ) + // test non-streaming decompression + { + let mut bf = std::io::BufReader::new(compressed.as_slice()); + let mut decomp: Vec = Vec::new(); + lzma_rs::lzma_decompress_with_options(&mut bf, &mut decomp, decode_options).unwrap(); + assert_eq!(decomp, x); + } + + #[cfg(feature = "stream")] + // test streaming decompression + { + let mut stream = lzma_rs::decompress::Stream::new_with_options(decode_options, Vec::new()); + + if let Err(error) = stream.write_all(&compressed) { + // WriteZero could indicate that the unpacked_size was reached before the + // end of the stream + if std::io::ErrorKind::WriteZero != error.kind() { + panic!(error); + } + } + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, x); + } } fn round_trip_file(filename: &str) { @@ -75,30 +100,79 @@ fn round_trip_file(filename: &str) { fn decomp_big_file(compfile: &str, plainfile: &str) { use std::io::Read; - let mut expected = Vec::new(); - std::fs::File::open(plainfile) - .unwrap() - .read_to_end(&mut expected) - .unwrap(); - let mut f = std::io::BufReader::new(std::fs::File::open(compfile).unwrap()); - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress(&mut f, &mut decomp).unwrap(); - assert!(decomp == expected) + let expected = { + let mut expected = Vec::new(); + std::fs::File::open(plainfile) + .unwrap() + .read_to_end(&mut expected) + .unwrap(); + expected + }; + + // test non-streaming decompression + { + let input = { + let mut input = Vec::new(); + std::fs::File::open(compfile) + .unwrap() + .read_to_end(&mut input) + .unwrap(); + input + }; + + let mut input = std::io::BufReader::new(input.as_slice()); + let mut decomp: Vec = Vec::new(); + lzma_rs::lzma_decompress(&mut input, &mut decomp).unwrap(); + assert_eq!(decomp, expected); + } + + #[cfg(feature = "stream")] + // test streaming decompression + { + let mut compfile = std::fs::File::open(compfile).unwrap(); + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + + // read file in chunks + let mut tmp = [0u8; 1024]; + while { + let n = compfile.read(&mut tmp).unwrap(); + stream.write_all(&tmp[0..n]).unwrap(); + + n > 0 + } {} + + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, expected); + } +} + +fn assert_decomp_eq(input: &[u8], expected: &[u8]) { + // test non-streaming decompression + { + let mut input = std::io::BufReader::new(input); + let mut decomp: Vec = Vec::new(); + lzma_rs::lzma_decompress(&mut input, &mut decomp).unwrap(); + assert_eq!(decomp, expected) + } + + #[cfg(feature = "stream")] + // test streaming decompression + { + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + stream.write_all(input).unwrap(); + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, expected); + } } #[test] +#[should_panic(expected = "HeaderTooShort")] fn decompress_short_header() { #[cfg(feature = "enable_logging")] let _ = env_logger::try_init(); let mut decomp: Vec = Vec::new(); // TODO: compare io::Errors? - assert_eq!( - format!( - "{:?}", - lzma_rs::lzma_decompress(&mut (b"" as &[u8]), &mut decomp).unwrap_err() - ), - String::from("LZMAError(\"LZMA header too short: failed to fill whole buffer\")") - ) + lzma_rs::lzma_decompress(&mut (b"" as &[u8]), &mut decomp).unwrap(); } #[test] @@ -142,23 +216,23 @@ fn big_file() { fn decompress_empty_world() { #[cfg(feature = "enable_logging")] let _ = env_logger::try_init(); - let mut x: &[u8] = b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\ - \xfb\xff\xff\xc0\x00\x00\x00"; - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress(&mut x, &mut decomp).unwrap(); - assert_eq!(decomp, b"") + assert_decomp_eq( + b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\ + \xfb\xff\xff\xc0\x00\x00\x00", + b"", + ); } #[test] fn decompress_hello_world() { #[cfg(feature = "enable_logging")] let _ = env_logger::try_init(); - let mut x: &[u8] = b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x24\x19\ - \x49\x98\x6f\x10\x19\xc6\xd7\x31\xeb\x36\x50\xb2\x98\x48\xff\xfe\ - \xa5\xb0\x00"; - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress(&mut x, &mut decomp).unwrap(); - assert_eq!(decomp, b"Hello world\x0a") + assert_decomp_eq( + b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x24\x19\ + \x49\x98\x6f\x10\x19\xc6\xd7\x31\xeb\x36\x50\xb2\x98\x48\xff\xfe\ + \xa5\xb0\x00", + b"Hello world\x0a", + ); } #[test] @@ -166,12 +240,12 @@ fn decompress_huge_dict() { // Hello world with a dictionary of size 0x7F7F7F7F #[cfg(feature = "enable_logging")] let _ = env_logger::try_init(); - let mut x: &[u8] = b"\x5d\x7f\x7f\x7f\x7f\xff\xff\xff\xff\xff\xff\xff\xff\x00\x24\x19\ + assert_decomp_eq( + b"\x5d\x7f\x7f\x7f\x7f\xff\xff\xff\xff\xff\xff\xff\xff\x00\x24\x19\ \x49\x98\x6f\x10\x19\xc6\xd7\x31\xeb\x36\x50\xb2\x98\x48\xff\xfe\ - \xa5\xb0\x00"; - let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress(&mut x, &mut decomp).unwrap(); - assert_eq!(decomp, b"Hello world\x0a") + \xa5\xb0\x00", + b"Hello world\x0a", + ); } #[test] @@ -244,7 +318,6 @@ fn unpacked_size_write_none_to_header_and_use_provided_none_on_read() { } #[test] -#[should_panic(expected = "exceeded memory limit of 0")] fn memlimit() { let data = b"Some data"; let encode_options = lzma_rs::compress::Options { @@ -253,6 +326,43 @@ fn memlimit() { let decode_options = lzma_rs::decompress::Options { unpacked_size: lzma_rs::decompress::UnpackedSize::ReadHeaderButUseProvided(None), memlimit: Some(0), + ..Default::default() }; - round_trip_with_options(&data[..], &encode_options, &decode_options).unwrap(); + + let mut compressed: Vec = Vec::new(); + lzma_rs::lzma_compress_with_options( + &mut std::io::BufReader::new(&data[..]), + &mut compressed, + &encode_options, + ) + .unwrap(); + + // test non-streaming decompression + { + let mut bf = std::io::BufReader::new(compressed.as_slice()); + let mut decomp: Vec = Vec::new(); + let error = lzma_rs::lzma_decompress_with_options(&mut bf, &mut decomp, &decode_options) + .unwrap_err(); + assert!( + error.to_string().contains("exceeded memory limit of 0"), + error.to_string() + ); + } + + #[cfg(feature = "stream")] + // test streaming decompression + { + let mut stream = lzma_rs::decompress::Stream::new_with_options(&decode_options, Vec::new()); + + let error = stream.write_all(&compressed).unwrap_err(); + assert!( + error.to_string().contains("exceeded memory limit of 0"), + error.to_string() + ); + let error = stream.finish().unwrap_err(); + assert!( + error.to_string().contains("previous write error"), + error.to_string() + ); + } }