diff --git a/benches/lzma.rs b/benches/lzma.rs index 7883c95..0c41136 100644 --- a/benches/lzma.rs +++ b/benches/lzma.rs @@ -2,7 +2,7 @@ extern crate test; -use std::io::Read; +use std::io::{Read, Write}; use test::Bencher; fn compress_bench(x: &[u8], b: &mut Bencher) { @@ -31,6 +31,14 @@ fn decompress_bench(compressed: &[u8], b: &mut Bencher) { }); } +fn decompress_stream_bench(compressed: &[u8], b: &mut Bencher) { + b.iter(|| { + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + stream.write_all(compressed).unwrap(); + stream.finish().unwrap(); + }); +} + fn decompress_bench_file(compfile: &str, b: &mut Bencher) { let mut f = std::fs::File::open(compfile).unwrap(); let mut compressed = Vec::new(); @@ -38,6 +46,13 @@ fn decompress_bench_file(compfile: &str, b: &mut Bencher) { decompress_bench(&compressed, b); } +fn decompress_stream_bench_file(compfile: &str, b: &mut Bencher) { + let mut f = std::fs::File::open(compfile).unwrap(); + let mut compressed = Vec::new(); + f.read_to_end(&mut compressed).unwrap(); + decompress_stream_bench(&compressed, b); +} + #[bench] fn compress_empty(b: &mut Bencher) { #[cfg(feature = "enable_logging")] @@ -87,6 +102,13 @@ fn decompress_big_file(b: &mut Bencher) { decompress_bench_file("tests/files/foo.txt.lzma", b); } +#[bench] +fn decompress_stream_big_file(b: &mut Bencher) { + #[cfg(feature = "enable_logging")] + let _ = env_logger::try_init(); + decompress_stream_bench_file("tests/files/foo.txt.lzma", b); +} + #[bench] fn decompress_huge_dict(b: &mut Bencher) { #[cfg(feature = "enable_logging")] diff --git a/src/decode/lzbuffer.rs b/src/decode/lzbuffer.rs index d54ac67..763afa1 100644 --- a/src/decode/lzbuffer.rs +++ b/src/decode/lzbuffer.rs @@ -1,38 +1,51 @@ use crate::error; use std::io; -pub trait LZBuffer { +pub trait LZBuffer +where + W: io::Write, +{ fn len(&self) -> usize; // Retrieve the last byte or return a default fn last_or(&self, lit: u8) -> u8; // Retrieve the n-th last byte fn last_n(&self, dist: 
usize) -> error::Result; // Append a literal - fn append_literal(&mut self, lit: u8) -> io::Result<()>; + fn append_literal(&mut self, lit: u8) -> error::Result<()>; // Fetch an LZ sequence (length, distance) from inside the buffer fn append_lz(&mut self, len: usize, dist: usize) -> error::Result<()>; + // Get a reference to the output sink + fn get_ref(&self) -> &W; + // Get a mutable reference to the output sink + fn get_mut(&mut self) -> &mut W; // Flush the buffer to the output - fn finish(self) -> io::Result<()>; + fn finish(self) -> io::Result; } // An accumulating buffer for LZ sequences -pub struct LZAccumBuffer<'a, W> +pub struct LZAccumBuffer where - W: 'a + io::Write, + W: io::Write, { - stream: &'a mut W, // Output sink - buf: Vec, // Buffer - len: usize, // Total number of bytes sent through the buffer + stream: W, // Output sink + buf: Vec, // Buffer + memlimit: usize, // Buffer memory limit + len: usize, // Total number of bytes sent through the buffer } -impl<'a, W> LZAccumBuffer<'a, W> +impl LZAccumBuffer where W: io::Write, { - pub fn from_stream(stream: &'a mut W) -> Self { + pub fn from_stream(stream: W) -> Self { + Self::from_stream_with_memlimit(stream, std::usize::MAX) + } + + pub fn from_stream_with_memlimit(stream: W, memlimit: usize) -> Self { Self { stream, buf: Vec::new(), + memlimit, len: 0, } } @@ -52,7 +65,7 @@ where } } -impl<'a, W> LZBuffer for LZAccumBuffer<'a, W> +impl LZBuffer for LZAccumBuffer where W: io::Write, { @@ -84,10 +97,19 @@ where } // Append a literal - fn append_literal(&mut self, lit: u8) -> io::Result<()> { - self.buf.push(lit); - self.len += 1; - Ok(()) + fn append_literal(&mut self, lit: u8) -> error::Result<()> { + let new_len = self.len + 1; + + if new_len > self.memlimit { + Err(error::Error::LZMAError(format!( + "exceeded memory limit of {}", + self.memlimit + ))) + } else { + self.buf.push(lit); + self.len = new_len; + Ok(()) + } } // Fetch an LZ sequence (length, distance) from inside the buffer @@ -111,36 
+133,48 @@ where Ok(()) } + // Get a reference to the output sink + fn get_ref(&self) -> &W { + &self.stream + } + + // Get a mutable reference to the output sink + fn get_mut(&mut self) -> &mut W { + &mut self.stream + } + // Flush the buffer to the output - fn finish(self) -> io::Result<()> { + fn finish(mut self) -> io::Result { self.stream.write_all(self.buf.as_slice())?; self.stream.flush()?; - Ok(()) + Ok(self.stream) } } // A circular buffer for LZ sequences -pub struct LZCircularBuffer<'a, W> +pub struct LZCircularBuffer where - W: 'a + io::Write, + W: io::Write, { - stream: &'a mut W, // Output sink - buf: Vec, // Circular buffer - dict_size: usize, // Length of the buffer - cursor: usize, // Current position - len: usize, // Total number of bytes sent through the buffer + stream: W, // Output sink + buf: Vec, // Circular buffer + dict_size: usize, // Length of the buffer + memlimit: usize, // Buffer memory limit + cursor: usize, // Current position + len: usize, // Total number of bytes sent through the buffer } -impl<'a, W> LZCircularBuffer<'a, W> +impl LZCircularBuffer where W: io::Write, { - pub fn from_stream(stream: &'a mut W, dict_size: usize) -> Self { + pub fn from_stream_with_memlimit(stream: W, dict_size: usize, memlimit: usize) -> Self { lzma_info!("Dict size in LZ buffer: {}", dict_size); Self { stream, buf: Vec::new(), dict_size, + memlimit, cursor: 0, len: 0, } @@ -150,15 +184,25 @@ where *self.buf.get(index).unwrap_or(&0) } - fn set(&mut self, index: usize, value: u8) { - if self.buf.len() < index + 1 { - self.buf.resize(index + 1, 0); + fn set(&mut self, index: usize, value: u8) -> error::Result<()> { + let new_len = index + 1; + + if self.buf.len() < new_len { + if new_len <= self.memlimit { + self.buf.resize(new_len, 0); + } else { + return Err(error::Error::LZMAError(format!( + "exceeded memory limit of {}", + self.memlimit + ))); + } } self.buf[index] = value; + Ok(()) } } -impl<'a, W> LZBuffer for LZCircularBuffer<'a, W> +impl 
LZBuffer for LZCircularBuffer where W: io::Write, { @@ -195,8 +239,8 @@ where } // Append a literal - fn append_literal(&mut self, lit: u8) -> io::Result<()> { - self.set(self.cursor, lit); + fn append_literal(&mut self, lit: u8) -> error::Result<()> { + self.set(self.cursor, lit)?; self.cursor += 1; self.len += 1; @@ -237,12 +281,22 @@ where Ok(()) } + // Get a reference to the output sink + fn get_ref(&self) -> &W { + &self.stream + } + + // Get a mutable reference to the output sink + fn get_mut(&mut self) -> &mut W { + &mut self.stream + } + // Flush the buffer to the output - fn finish(self) -> io::Result<()> { + fn finish(mut self) -> io::Result { if self.cursor > 0 { self.stream.write_all(&self.buf[0..self.cursor])?; self.stream.flush()?; } - Ok(()) + Ok(self.stream) } } diff --git a/src/decode/lzma.rs b/src/decode/lzma.rs index 184d0a2..805a9ab 100644 --- a/src/decode/lzma.rs +++ b/src/decode/lzma.rs @@ -7,6 +7,45 @@ use std::io; use crate::decompress::Options; use crate::decompress::UnpackedSize; +/// Minimum header length to be read. +/// - props: u8 (1 byte) +/// - dict_size: u32 (4 bytes) +pub const MIN_HEADER_LEN: usize = 5; + +/// Max header length to be read. +/// - unpacked_size: u64 (8 bytes) +pub const MAX_HEADER_LEN: usize = MIN_HEADER_LEN + 8; + +/// Required bytes after the header. +/// - ignore: u8 (1 byte) +/// - code: u32 (4 bytes) +pub const START_BYTES: usize = 5; + +/// Maximum input data that can be processed in one iteration +const MAX_REQUIRED_INPUT: usize = 20; + +/// Processing mode for decompression. +/// +/// Tells the decompressor if we should expect more data after parsing the +/// current input. +#[derive(Debug, PartialEq)] +pub enum Mode { + /// Streaming mode. Process the input bytes but assume there will be more + /// chunks of input data to receive in future calls to `process_mode()`. + Run, + /// Sync mode. Process the input bytes and confirm end of stream has been reached. 
+ /// Use this mode if you are processing a fixed buffer of compressed data, or after + /// using `Mode::Run` to check for the end of stream. + Finish, +} + +/// Used during stream processing to identify the next state. +pub enum CheckState { + Lit, + Rep, + Match, +} + pub struct LZMAParams { // most lc significant bits of previous byte are part of the literal context lc: u32, // 0..8 @@ -97,10 +136,16 @@ impl LZMAParams { } } -pub struct DecoderState +pub struct DecoderState where - LZB: lzbuffer::LZBuffer, + W: io::Write, + LZB: lzbuffer::LZBuffer, { + _phantom: std::marker::PhantomData, + // buffer input data here if we need more for decompression, 20 is the max + // number of bytes that can be consumed during one iteration + tmp: [u8; MAX_REQUIRED_INPUT], + tmp_len: usize, pub output: LZB, // most lc significant bits of previous byte are part of the literal context pub lc: u32, // 0..8 @@ -125,17 +170,20 @@ where } // Initialize decoder with accumulating buffer -pub fn new_accum<'a, W>( - output: lzbuffer::LZAccumBuffer<'a, W>, +pub fn new_accum( + output: lzbuffer::LZAccumBuffer, lc: u32, lp: u32, pb: u32, unpacked_size: Option, -) -> DecoderState> +) -> DecoderState> where W: io::Write, { DecoderState { + _phantom: std::marker::PhantomData, + tmp: [0; MAX_REQUIRED_INPUT], + tmp_len: 0, output, lc, lp, @@ -159,16 +207,35 @@ where } // Initialize decoder with circular buffer -pub fn new_circular<'a, W>( - output: &'a mut W, +pub fn new_circular( + output: W, + params: LZMAParams, +) -> error::Result>> +where + W: io::Write, +{ + new_circular_with_memlimit(output, params, std::usize::MAX) +} + +// Initialize decoder with circular buffer +pub fn new_circular_with_memlimit( + output: W, params: LZMAParams, -) -> error::Result>> + memlimit: usize, +) -> error::Result>> where W: io::Write, { // Decoder let decoder = DecoderState { - output: lzbuffer::LZCircularBuffer::from_stream(output, params.dict_size as usize), + _phantom: std::marker::PhantomData, + output: 
lzbuffer::LZCircularBuffer::from_stream_with_memlimit( + output, + params.dict_size as usize, + memlimit, + ), + tmp: [0; MAX_REQUIRED_INPUT], + tmp_len: 0, lc: params.lc, lp: params.lp, pb: params.pb, @@ -192,9 +259,10 @@ where Ok(decoder) } -impl DecoderState +impl DecoderState where - LZB: lzbuffer::LZBuffer, + W: io::Write, + LZB: lzbuffer::LZBuffer, { pub fn reset_state(&mut self, lc: u32, lp: u32, pb: u32) { self.lc = lc; @@ -223,104 +291,230 @@ where pub fn process<'a, R: io::BufRead>( &mut self, rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + ) -> error::Result<()> { + self.process_mode(rangecoder, Mode::Finish) + } + + pub fn process_stream<'a, R: io::BufRead>( + &mut self, + rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + ) -> error::Result<()> { + self.process_mode(rangecoder, Mode::Run) + } + + /// Process the next iteration of the loop. + /// + /// Returns true if we should continue processing the loop, false otherwise. + fn process_next<'a, R: io::BufRead>( + &mut self, + rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + ) -> error::Result { + let pos_state = self.output.len() & ((1 << self.pb) - 1); + + // Literal + if !rangecoder.decode_bit( + // TODO: assumes pb = 2 ?? + &mut self.is_match[(self.state << 4) + pos_state], + )? { + let byte: u8 = self.decode_literal(rangecoder)?; + lzma_debug!("Literal: {}", byte); + self.output.append_literal(byte)?; + + self.state = if self.state < 4 { + 0 + } else if self.state < 10 { + self.state - 3 + } else { + self.state - 6 + }; + return Ok(true); + } + + // LZ + let mut len: usize; + // Distance is repeated from LRU + if rangecoder.decode_bit(&mut self.is_rep[self.state])? { + // dist = rep[0] + if !rangecoder.decode_bit(&mut self.is_rep_g0[self.state])? { + // len = 1 + if !rangecoder.decode_bit(&mut self.is_rep_0long[(self.state << 4) + pos_state])? 
{ + // update state (short rep) + self.state = if self.state < 7 { 9 } else { 11 }; + let dist = self.rep[0] + 1; + self.output.append_lz(1, dist)?; + return Ok(true); + } + // dist = rep[i] + } else { + let idx: usize; + if !rangecoder.decode_bit(&mut self.is_rep_g1[self.state])? { + idx = 1; + } else if !rangecoder.decode_bit(&mut self.is_rep_g2[self.state])? { + idx = 2; + } else { + idx = 3; + } + // Update LRU + let dist = self.rep[idx]; + for i in (0..idx).rev() { + self.rep[i + 1] = self.rep[i]; + } + self.rep[0] = dist + } + + len = self.rep_len_decoder.decode(rangecoder, pos_state)?; + // update state (rep) + self.state = if self.state < 7 { 8 } else { 11 }; + // New distance + } else { + // Update LRU + self.rep[3] = self.rep[2]; + self.rep[2] = self.rep[1]; + self.rep[1] = self.rep[0]; + len = self.len_decoder.decode(rangecoder, pos_state)?; + + // update state (match) + self.state = if self.state < 7 { 7 } else { 10 }; + self.rep[0] = self.decode_distance(rangecoder, len)?; + + if self.rep[0] == 0xFFFF_FFFF { + if rangecoder.is_finished_ok()? { + return Ok(false); + } + return Err(error::Error::LZMAError(String::from( + "Found end-of-stream marker but more bytes are available", + ))); + } + } + + len += 2; + + let dist = self.rep[0] + 1; + self.output.append_lz(len, dist)?; + + Ok(true) + } + + /// Try to process the next iteration of the loop. + /// + /// This will check to see if there is enough data to consume and advance the + /// decompressor. Needed in streaming mode to avoid corrupting the state while + /// processing incomplete chunks of data. + fn try_process_next<'a>( + &self, + buf: &'a [u8], + range: u32, + code: u32, + ) -> error::Result { + let mut temp = std::io::Cursor::new(buf); + let mut rangecoder = rangecoder::RangeDecoder::from_parts(&mut temp, range, code); + let pos_state = self.output.len() & ((1 << self.pb) - 1); + + // Literal + if !rangecoder.decode_bit_check(self.is_match[(self.state << 4) + pos_state])? 
{ + self.decode_literal_check(&mut rangecoder)?; + return Ok(CheckState::Lit); + } + + // LZ + if rangecoder.decode_bit_check(self.is_rep[self.state])? { + if !rangecoder.decode_bit_check(self.is_rep_g0[self.state])? { + if !rangecoder.decode_bit_check(self.is_rep_0long[(self.state << 4) + pos_state])? { + return Ok(CheckState::Rep); + } + } else if !rangecoder.decode_bit_check(self.is_rep_g1[self.state])? { + } else if !rangecoder.decode_bit_check(self.is_rep_g2[self.state])? { + } + + self.rep_len_decoder + .decode_check(&mut rangecoder, pos_state)?; + Ok(CheckState::Rep) + // New distance + } else { + let len = self.len_decoder.decode_check(&mut rangecoder, pos_state)?; + self.decode_distance_check(&mut rangecoder, len)?; + Ok(CheckState::Match) + } + } + + pub fn process_mode<'a, R: io::BufRead>( + &mut self, + mut rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + mode: Mode, ) -> error::Result<()> { loop { if let Some(unpacked_size) = self.unpacked_size { if self.output.len() as u64 >= unpacked_size { break; } - } else if rangecoder.is_finished_ok()? { + } else if match mode { + Mode::Run => rangecoder.is_eof()?, + Mode::Finish => rangecoder.is_finished_ok()?, + } { break; } - let pos_state = self.output.len() & ((1 << self.pb) - 1); - - // Literal - if !rangecoder.decode_bit( - // TODO: assumes pb = 2 ?? - &mut self.is_match[(self.state << 4) + pos_state], - )? { - let byte: u8 = self.decode_literal(rangecoder)?; - lzma_debug!("Literal: {}", byte); - self.output.append_literal(byte)?; - - self.state = if self.state < 4 { - 0 - } else if self.state < 10 { - self.state - 3 - } else { - self.state - 6 - }; - continue; - } - - // LZ - let mut len: usize; - // Distance is repeated from LRU - if rangecoder.decode_bit(&mut self.is_rep[self.state])? { - // dist = rep[0] - if !rangecoder.decode_bit(&mut self.is_rep_g0[self.state])? { - // len = 1 - if !rangecoder - .decode_bit(&mut self.is_rep_0long[(self.state << 4) + pos_state])? 
+ if self.tmp_len > 0 { + // Fill as much of the tmp buffer as possible + self.tmp_len += rangecoder.read_into(&mut self.tmp[self.tmp_len..])?; + + // Check if we need more data to advance the decompressor + if Mode::Run == mode && self.tmp_len < MAX_REQUIRED_INPUT { + if self + .try_process_next( + &self.tmp[0..self.tmp_len], + rangecoder.range(), + rangecoder.code(), + ) + .is_err() { - // update state (short rep) - self.state = if self.state < 7 { 9 } else { 11 }; - let dist = self.rep[0] + 1; - self.output.append_lz(1, dist)?; - continue; - } - // dist = rep[i] - } else { - let idx: usize; - if !rangecoder.decode_bit(&mut self.is_rep_g1[self.state])? { - idx = 1; - } else if !rangecoder.decode_bit(&mut self.is_rep_g2[self.state])? { - idx = 2; - } else { - idx = 3; - } - // Update LRU - let dist = self.rep[idx]; - for i in (0..idx).rev() { - self.rep[i + 1] = self.rep[i]; + return Ok(()); } - self.rep[0] = dist } - len = self.rep_len_decoder.decode(rangecoder, pos_state)?; - // update state (rep) - self.state = if self.state < 7 { 8 } else { 11 }; - // New distance + // Run the decompressor on the tmp buffer + let tmp = self.tmp; + let mut tmp_reader = io::Cursor::new(&tmp[0..self.tmp_len]); + let mut tmp_rangecoder = rangecoder::RangeDecoder::from_parts( + &mut tmp_reader, + rangecoder.range(), + rangecoder.code(), + ); + let res = self.process_next(&mut tmp_rangecoder)?; + + // Update the actual rangecoder + let (range, code) = tmp_rangecoder.into_parts(); + rangecoder.set(range, code); + + // Update tmp buffer + let new_len = self.tmp_len - tmp_reader.position() as usize; + self.tmp[0..new_len] + .copy_from_slice(&tmp[tmp_reader.position() as usize..self.tmp_len]); + self.tmp_len = new_len; + + if !res { + break; + }; } else { - // Update LRU - self.rep[3] = self.rep[2]; - self.rep[2] = self.rep[1]; - self.rep[1] = self.rep[0]; - len = self.len_decoder.decode(rangecoder, pos_state)?; - - // update state (match) - self.state = if self.state < 7 { 7 } else { 
10 }; - self.rep[0] = self.decode_distance(rangecoder, len)?; - - if self.rep[0] == 0xFFFF_FFFF { - if rangecoder.is_finished_ok()? { - break; + if (Mode::Run == mode) && (rangecoder.remaining()? < MAX_REQUIRED_INPUT) { + let range = rangecoder.range(); + let code = rangecoder.code(); + let buf = rangecoder.buf()?; + + if self.try_process_next(buf, range, code).is_err() { + self.tmp_len = rangecoder.read_into(&mut self.tmp)?; + return Ok(()); } - return Err(error::Error::LZMAError(String::from( - "Found end-of-stream marker but more bytes are available", - ))); } - } - len += 2; - - let dist = self.rep[0] + 1; - self.output.append_lz(len, dist)?; + if !self.process_next(&mut rangecoder)? { + break; + }; + } } if let Some(len) = self.unpacked_size { - if self.output.len() as u64 != len { + if Mode::Finish == mode && self.output.len() as u64 != len { return Err(error::Error::LZMAError(format!( "Expected unpacked size of {} but decompressed to {}", len, @@ -366,6 +560,41 @@ where Ok((result - 0x100) as u8) } + /// Attempts to decode a literal without mutating state + fn decode_literal_check<'a, R: io::BufRead>( + &self, + rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + ) -> error::Result { + let def_prev_byte = 0u8; + let prev_byte = self.output.last_or(def_prev_byte) as usize; + + let mut result: usize = 1; + let lit_state = + ((self.output.len() & ((1 << self.lp) - 1)) << self.lc) + (prev_byte >> (8 - self.lc)); + let probs = &self.literal_probs[lit_state]; + + if self.state >= 7 { + let mut match_byte = self.output.last_n(self.rep[0] + 1)? as usize; + + while result < 0x100 { + let match_bit = (match_byte >> 7) & 1; + match_byte <<= 1; + let bit = + rangecoder.decode_bit_check(probs[((1 + match_bit) << 8) + result])? as usize; + result = (result << 1) ^ bit; + if match_bit != bit { + break; + } + } + } + + while result < 0x100 { + result = (result << 1) ^ (rangecoder.decode_bit_check(probs[result])? 
as usize); + } + + Ok((result - 0x100) as u8) + } + fn decode_distance<'a, R: io::BufRead>( &mut self, rangecoder: &mut rangecoder::RangeDecoder<'a, R>, @@ -394,4 +623,34 @@ where Ok(result) } + + /// Attempts to decode distance without mutating state + fn decode_distance_check<'a, R: io::BufRead>( + &self, + rangecoder: &mut rangecoder::RangeDecoder<'a, R>, + length: usize, + ) -> error::Result { + let len_state = if length > 3 { 3 } else { length }; + + let pos_slot = self.pos_slot_decoder[len_state].parse_check(rangecoder)? as usize; + if pos_slot < 4 { + return Ok(pos_slot); + } + + let num_direct_bits = (pos_slot >> 1) - 1; + let mut result = (2 ^ (pos_slot & 1)) << num_direct_bits; + + if pos_slot < 14 { + result += rangecoder.parse_reverse_bit_tree_check( + num_direct_bits, + &self.pos_decoders, + result - pos_slot, + )? as usize; + } else { + result += (rangecoder.get(num_direct_bits - 4)? as usize) << 4; + result += self.align_decoder.parse_reverse_check(rangecoder)? as usize; + } + + Ok(result) + } } diff --git a/src/decode/lzma2.rs b/src/decode/lzma2.rs index a50b9fd..678ccd4 100644 --- a/src/decode/lzma2.rs +++ b/src/decode/lzma2.rs @@ -43,8 +43,8 @@ where Ok(()) } -fn parse_lzma<'a, R, W>( - decoder: &mut lzma::DecoderState>, +fn parse_lzma( + decoder: &mut lzma::DecoderState>, input: &mut R, status: u8, ) -> error::Result<()> @@ -166,7 +166,7 @@ where } fn parse_uncompressed<'a, R, W>( - decoder: &mut lzma::DecoderState>, + decoder: &mut lzma::DecoderState>, input: &mut R, reset_dict: bool, ) -> error::Result<()> diff --git a/src/decode/mod.rs b/src/decode/mod.rs index 4a81d62..b830240 100644 --- a/src/decode/mod.rs +++ b/src/decode/mod.rs @@ -5,5 +5,6 @@ pub mod lzma; pub mod lzma2; pub mod options; pub mod rangecoder; +pub mod stream; pub mod util; pub mod xz; diff --git a/src/decode/options.rs b/src/decode/options.rs index a40737d..2c4316e 100644 --- a/src/decode/options.rs +++ b/src/decode/options.rs @@ -6,6 +6,16 @@ pub struct Options { /// The 
default is /// [`UnpackedSize::ReadFromHeader`](enum.UnpackedSize.html#variant.ReadFromHeader). pub unpacked_size: UnpackedSize, + /// Defines whether the dictionary's dynamic size should be limited during decompression. + /// + /// The default is unlimited. + pub memlimit: Option, + /// Determines whether to bypass end of stream validation. + /// + /// This option only applies to the [`Stream`](struct.Stream.html) API. + /// + /// The default is false (always do completion check). + pub allow_incomplete: bool, } /// Alternatives for defining the unpacked size of the decoded data. diff --git a/src/decode/rangecoder.rs b/src/decode/rangecoder.rs index 7643ea1..f90877d 100644 --- a/src/decode/rangecoder.rs +++ b/src/decode/rangecoder.rs @@ -28,9 +28,51 @@ where Ok(dec) } + pub fn from_parts(stream: &'a mut R, range: u32, code: u32) -> Self { + Self { + stream, + range, + code, + } + } + + pub fn into_parts(self) -> (u32, u32) { + (self.range, self.code) + } + + pub fn set(&mut self, range: u32, code: u32) { + self.range = range; + self.code = code; + } + + pub fn range(&self) -> u32 { + self.range + } + + pub fn code(&self) -> u32 { + self.code + } + + pub fn buf(&mut self) -> io::Result<&[u8]> { + self.stream.fill_buf() + } + + pub fn remaining(&mut self) -> io::Result { + Ok(self.buf()?.len()) + } + + pub fn read_into(&mut self, dst: &mut [u8]) -> io::Result { + self.stream.read(dst) + } + #[inline] pub fn is_finished_ok(&mut self) -> io::Result { - Ok(self.code == 0 && util::is_eof(self.stream)?) + Ok(self.code == 0 && self.is_eof()?) 
+ } + + #[inline] + pub fn is_eof(&mut self) -> io::Result { + util::is_eof(self.stream) } #[inline] @@ -92,6 +134,23 @@ where } } + #[inline] + pub fn decode_bit_check(&mut self, prob: u16) -> io::Result { + let bound: u32 = (self.range >> 11) * (prob as u32); + if self.code < bound { + self.range = bound; + + self.normalize()?; + Ok(false) + } else { + self.code -= bound; + self.range -= bound; + + self.normalize()?; + Ok(true) + } + } + fn parse_bit_tree(&mut self, num_bits: usize, probs: &mut [u16]) -> io::Result { let mut tmp: u32 = 1; for _ in 0..num_bits { @@ -101,6 +160,15 @@ where Ok(tmp - (1 << num_bits)) } + fn parse_bit_tree_check(&mut self, num_bits: usize, probs: &[u16]) -> io::Result { + let mut tmp: u32 = 1; + for _ in 0..num_bits { + let bit = self.decode_bit_check(probs[tmp as usize])?; + tmp = (tmp << 1) ^ (bit as u32); + } + Ok(tmp - (1 << num_bits)) + } + pub fn parse_reverse_bit_tree( &mut self, num_bits: usize, @@ -116,6 +184,22 @@ where } Ok(result) } + + pub fn parse_reverse_bit_tree_check( + &mut self, + num_bits: usize, + probs: &[u16], + offset: usize, + ) -> io::Result { + let mut result = 0u32; + let mut tmp: usize = 1; + for i in 0..num_bits { + let bit = self.decode_bit_check(probs[offset + tmp])?; + tmp = (tmp << 1) ^ (bit as usize); + result ^= (bit as u32) << i; + } + Ok(result) + } } // TODO: parametrize by constant and use [u16; 1 << num_bits] as soon as Rust supports this @@ -137,12 +221,23 @@ impl BitTree { rangecoder.parse_bit_tree(self.num_bits, self.probs.as_mut_slice()) } + pub fn parse_check(&self, rangecoder: &mut RangeDecoder) -> io::Result { + rangecoder.parse_bit_tree_check(self.num_bits, self.probs.as_slice()) + } + pub fn parse_reverse( &mut self, rangecoder: &mut RangeDecoder, ) -> io::Result { rangecoder.parse_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0) } + + pub fn parse_reverse_check( + &self, + rangecoder: &mut RangeDecoder, + ) -> io::Result { + 
rangecoder.parse_reverse_bit_tree_check(self.num_bits, self.probs.as_slice(), 0) + } } pub struct LenDecoder { @@ -177,4 +272,18 @@ impl LenDecoder { Ok(self.high_coder.parse(rangecoder)? as usize + 16) } } + + pub fn decode_check( + &self, + rangecoder: &mut RangeDecoder, + pos_state: usize, + ) -> io::Result { + if !rangecoder.decode_bit_check(self.choice)? { + Ok(self.low_coder[pos_state].parse_check(rangecoder)? as usize) + } else if !rangecoder.decode_bit_check(self.choice2)? { + Ok(self.mid_coder[pos_state].parse_check(rangecoder)? as usize + 8) + } else { + Ok(self.high_coder.parse_check(rangecoder)? as usize + 16) + } + } } diff --git a/src/decode/stream.rs b/src/decode/stream.rs new file mode 100644 index 0000000..d77f599 --- /dev/null +++ b/src/decode/stream.rs @@ -0,0 +1,476 @@ +use crate::decode::lzbuffer::{LZBuffer, LZCircularBuffer}; +use crate::decode::lzma::{ + new_circular, new_circular_with_memlimit, DecoderState, LZMAParams, MAX_HEADER_LEN, START_BYTES, +}; +use crate::decode::rangecoder::RangeDecoder; +use crate::decompress::Options; +use crate::error::Error; +use std::io::{BufRead, Cursor, Read, Write}; + +/// Maximum number of bytes to buffer while reading the header. +const MAX_TMP_LEN: usize = MAX_HEADER_LEN + START_BYTES; + +/// Internal state of this streaming decoder. This is needed because we have to +/// initialize the stream before processing any data. +#[derive(Debug)] +enum State +where + W: Write, +{ + /// Stream is initialized but header values have not yet been read. + Init(W), + /// Header values have been read and the stream is ready to process more data. + Run(RunState), +} + +/// Structures needed while decoding data. 
+struct RunState +where + W: Write, +{ + decoder: DecoderState>, + range: u32, + code: u32, +} + +impl RunState +where + W: Write, +{ + fn new(decoder: DecoderState>, range: u32, code: u32) -> Self { + Self { + decoder, + range, + code, + } + } +} + +impl std::fmt::Debug for RunState +where + W: Write, +{ + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { + fmt.debug_struct("RunState") + .field("range", &self.range) + .field("code", &self.code) + .finish() + } +} + +/// Lzma decompressor that can process multiple chunks of data using the +/// `std::io::Write` interface. +pub struct Stream +where + W: Write, +{ + /// Temporary buffer to hold data while the header is being read. + tmp: [u8; MAX_TMP_LEN], + /// How many bytes of the temp buffer are in use. + tmp_len: usize, + /// Whether the stream is initialized and ready to process data. + /// An `Option` is used to avoid interior mutability when updating the state. + state: Option>, + /// Options given when a stream is created. + options: Options, +} + +impl Stream +where + W: Write, +{ + /// Initialize the stream. This will consume the `output` which is the sink + /// implementing `std::io::Write` that will receive decompressed bytes. + pub fn new(output: W) -> Self { + Self::new_with_options(Options::default(), output) + } + + /// Initialize the stream with the given `options`. This will consume the + /// `output` which is the sink implementing `std::io::Write` that will + /// receive decompressed bytes. 
+ pub fn new_with_options(options: Options, output: W) -> Self { + Self { + tmp: [0; MAX_TMP_LEN], + tmp_len: 0, + state: Some(State::Init(output)), + options, + } + } + + /// Get a reference to the output sink + pub fn get_ref(&self) -> Option<&W> { + match &self.state { + Some(State::Init(output)) => Some(&output), + Some(State::Run(state)) => Some(state.decoder.output.get_ref()), + None => None, + } + } + + /// Get a mutable reference to the output sink + pub fn get_mut(&mut self) -> Option<&mut W> { + match &mut self.state { + Some(State::Init(output)) => Some(output), + Some(State::Run(state)) => Some(state.decoder.output.get_mut()), + None => None, + } + } + + /// Consumes the stream and returns the output sink. This also makes sure + /// we have properly reached the end of the stream. + pub fn finish(mut self) -> Result { + if let Some(state) = self.state.take() { + match state { + State::Init(output) => { + if self.tmp_len > 0 { + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "failed to read header", + )) + } else { + Ok(output) + } + } + State::Run(mut state) => { + if !self.options.allow_incomplete { + // Process one last time with empty input to force end of + // stream checks + let mut stream = Cursor::new(&self.tmp[0..self.tmp_len]); + let mut range_decoder = + RangeDecoder::from_parts(&mut stream, state.range, state.code); + state + .decoder + .process(&mut range_decoder) + .map_err(|e| -> std::io::Error { e.into() })?; + } + state.decoder.output.finish() + } + } + } else { + // this will occur if a call to `write()` fails + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "can't finish stream because of previous write error", + )) + } + } + + /// Attempts to read the header and transition into a running state. + /// + /// This function will consume the state, returning the next state on both + /// error and success. 
+ fn read_header( + output: W, + mut input: &mut R, + options: &Options, + ) -> Result, (Option>, std::io::Error)> { + let params = match LZMAParams::read_header(&mut input, options) { + Ok(params) => params, + Err(e) => { + return Err((Some(State::Init(output)), e.into())); + } + }; + + let len = match input.fill_buf() { + Ok(val) => val, + Err(_) => { + return Err(( + Some(State::Init(output)), + std::io::Error::new(std::io::ErrorKind::Other, "need more input"), + )); + } + } + .len(); + + if len < START_BYTES { + return Err(( + Some(State::Init(output)), + std::io::Error::new(std::io::ErrorKind::Other, "need more input"), + )); + }; + + let decoder = if let Some(memlimit) = options.memlimit { + new_circular_with_memlimit(output, params, memlimit) + } else { + new_circular(output, params) + } + .map_err(|e| -> (Option>, std::io::Error) { (None, e.into()) })?; + + // The RangeDecoder is only kept temporarily as we are processing + // chunks of data. + let range_decoder = RangeDecoder::new(&mut input) + .map_err(|e| -> (Option>, std::io::Error) { (None, e) })?; + let (range, code) = range_decoder.into_parts(); + + Ok(State::Run(RunState::new(decoder, range, code))) + } + + /// Process compressed data + fn read_data( + mut state: RunState, + mut input: &mut R, + ) -> Result, std::io::Error> { + // Construct our RangeDecoder from the previous range and code + // values. + let mut range_decoder = RangeDecoder::from_parts(&mut input, state.range, state.code); + + // Try to process all bytes of data. + state + .decoder + .process_stream(&mut range_decoder) + .map_err(|e| -> std::io::Error { e.into() })?; + + // Save the range and code for the next chunk of data. 
+ let (range, code) = range_decoder.into_parts(); + state.range = range; + state.code = code; + + Ok(RunState::new(state.decoder, range, code)) + } +} + +impl std::fmt::Debug for Stream +where + W: Write + std::fmt::Debug, +{ + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { + fmt.debug_struct("Stream") + .field("tmp_len", &self.tmp_len) + .field("state", &self.state) + .field("options", &self.options) + .finish() + } +} + +impl Write for Stream +where + W: Write, +{ + fn write(&mut self, data: &[u8]) -> std::io::Result { + let mut input = Cursor::new(data); + + if let Some(state) = self.state.take() { + let state = match state { + // Read the header values and transition into a running state. + State::Init(state) => { + let res = if self.tmp_len > 0 { + // attempt to fill the tmp buffer + self.tmp_len += input.read(&mut self.tmp[self.tmp_len..])?; + + // attempt to read the header from our tmp buffer + let (position, res) = { + let mut tmp_input = Cursor::new(&self.tmp[0..self.tmp_len]); + let res = Stream::read_header(state, &mut tmp_input, &self.options); + (tmp_input.position() as usize, res) + }; + + // discard all bytes up to position if reading the header + // was successful + if res.is_ok() { + let tmp = self.tmp; + let new_len = self.tmp_len - position; + (&mut self.tmp[0..new_len]) + .copy_from_slice(&tmp[position..self.tmp_len]); + self.tmp_len = new_len; + } + res + } else { + Stream::read_header(state, &mut input, &self.options) + }; + + match res { + Ok(state) => state, + // occurs when not enough input bytes were provided to + // read the entire header + Err((Some(state), _)) => { + if self.tmp_len == 0 { + // reset the cursor because we may have partial reads + input.set_position(0); + self.tmp_len = input.read(&mut self.tmp)?; + } + state + } + // occurs when the output was consumed due to a + // non-recoverable error + Err((None, e)) => { + return Err(e); + } + } + } + + // Process another chunk of data. 
+                State::Run(state) => {
+                    let state = if self.tmp_len > 0 {
+                        let mut tmp_input = Cursor::new(&self.tmp[0..self.tmp_len]);
+                        let res = Stream::read_data(state, &mut tmp_input)?;
+                        self.tmp_len = 0;
+                        res
+                    } else {
+                        state
+                    };
+                    State::Run(Stream::read_data(state, &mut input)?)
+                }
+            };
+            self.state.replace(state);
+        }
+        Ok(input.position() as usize)
+    }
+
+    /// Flushes the output sink. The internal buffer isn't flushed to avoid
+    /// corrupting the internal state. Instead, call `finish()` to finalize the
+    /// stream and flush all remaining internal data.
+    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
+        if let Some(ref mut state) = self.state {
+            match state {
+                State::Init(_) => Ok(()),
+                State::Run(state) => state.decoder.output.get_mut().flush(),
+            }
+        } else {
+            Ok(())
+        }
+    }
+}
+
+impl std::convert::Into<std::io::Error> for Error {
+    fn into(self) -> std::io::Error {
+        std::io::Error::new(std::io::ErrorKind::Other, format!("{:?}", self))
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    /// Test an empty stream
+    #[test]
+    fn test_stream_noop() {
+        let stream = Stream::new(Vec::new());
+        assert!(stream.get_ref().unwrap().is_empty());
+
+        let output = stream.finish().unwrap();
+        assert!(output.is_empty());
+    }
+
+    /// Test writing an empty slice
+    #[test]
+    fn test_stream_zero() {
+        let mut stream = Stream::new(Vec::new());
+
+        stream.write_all(&[]).unwrap();
+        stream.write_all(&[]).unwrap();
+
+        let output = stream.finish().unwrap();
+
+        assert!(output.is_empty());
+    }
+
+    /// Test processing only partial data
+    #[test]
+    fn test_stream_incomplete() {
+        let input = b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\
+                      \xfb\xff\xff\xc0\x00\x00\x00";
+        // Process until this index is reached.
+        let mut end = 1;
+
+        // Test when we fail to provide the minimum number of bytes required to
+        // read the header. Header size is 13 bytes but we also read the first 5
+        // bytes of data.
+ while end < MAX_HEADER_LEN + START_BYTES { + let mut stream = Stream::new(Vec::new()); + stream.write_all(&input[..end]).unwrap(); + assert_eq!(stream.tmp_len, end); + + let err = stream.finish().unwrap_err(); + assert!( + err.to_string().contains("failed to read header"), + "error was: {}", + err + ); + + end += 1; + } + + // Test when we fail to provide enough bytes to terminate the stream. A + // properly terminated stream will have a code value of 0. + while end < input.len() { + let mut stream = Stream::new(Vec::new()); + stream.write_all(&input[..end]).unwrap(); + + // Header bytes will be buffered until there are enough to read + if end < MAX_HEADER_LEN + START_BYTES { + assert_eq!(stream.tmp_len, end); + } + + let err = stream.finish().unwrap_err(); + assert!(err.to_string().contains("failed to fill whole buffer")); + + end += 1; + } + } + + /// Test processing all chunk sizes + #[test] + fn test_stream_chunked() { + let small_input = b"Project Gutenberg's Alice's Adventures in Wonderland, by Lewis Carroll"; + + let mut reader = std::io::Cursor::new(&small_input[..]); + let mut small_input_compressed = Vec::new(); + crate::lzma_compress(&mut reader, &mut small_input_compressed).unwrap(); + + let input : Vec<(&[u8], &[u8])> = vec![ + (b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\xfb\xff\xff\xc0\x00\x00\x00", b""), + (&small_input_compressed[..], small_input)]; + for (input, expected) in input { + for chunk in 1..input.len() { + let mut consumed = 0; + let mut stream = Stream::new(Vec::new()); + while consumed < input.len() { + let end = std::cmp::min(consumed + chunk, input.len()); + stream.write_all(&input[consumed..end]).unwrap(); + consumed = end; + } + let output = stream.finish().unwrap(); + assert_eq!(expected, &output[..]); + } + } + } + + #[test] + fn test_stream_corrupted() { + let mut stream = Stream::new(Vec::new()); + let err = stream + .write_all(b"corrupted bytes here corrupted bytes here") + .unwrap_err(); + 
assert!(err.to_string().contains("beyond output size")); + let err = stream.finish().unwrap_err(); + assert!(err + .to_string() + .contains("can\'t finish stream because of previous write error")); + } + + #[test] + fn test_allow_incomplete() { + let input = b"Project Gutenberg's Alice's Adventures in Wonderland, by Lewis Carroll"; + + let mut reader = std::io::Cursor::new(&input[..]); + let mut compressed = Vec::new(); + crate::lzma_compress(&mut reader, &mut compressed).unwrap(); + let compressed = &compressed[..compressed.len() / 2]; + + // Should fail to finish() without the allow_incomplete option. + let mut stream = Stream::new(Vec::new()); + stream.write_all(&compressed[..]).unwrap(); + stream.finish().unwrap_err(); + + // Should succeed with the allow_incomplete option. + let mut stream = Stream::new_with_options( + Options { + allow_incomplete: true, + ..Default::default() + }, + Vec::new(), + ); + stream.write_all(&compressed[..]).unwrap(); + let output = stream.finish().unwrap(); + assert_eq!(output, &input[..26]); + } +} diff --git a/src/lib.rs b/src/lib.rs index 355bdb9..f4e8c45 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,7 @@ pub mod compress { /// Decompression helpers. pub mod decompress { pub use crate::decode::options::*; + pub use crate::decode::stream::Stream; } /// Decompress LZMA data with default [`Options`](decompress/struct.Options.html). @@ -40,7 +41,12 @@ pub fn lzma_decompress_with_options( options: &decompress::Options, ) -> error::Result<()> { let params = decode::lzma::LZMAParams::read_header(input, options)?; - let mut decoder = decode::lzma::new_circular(output, params)?; + let mut decoder = if let Some(memlimit) = options.memlimit { + decode::lzma::new_circular_with_memlimit(output, params, memlimit)? + } else { + decode::lzma::new_circular(output, params)? 
+ }; + let mut rangecoder = decode::rangecoder::RangeDecoder::new(input).or_else(|e| { Err(error::Error::LZMAError(format!( "LZMA stream too short: {}", diff --git a/tests/lzma.rs b/tests/lzma.rs index b933f28..0b86c7e 100644 --- a/tests/lzma.rs +++ b/tests/lzma.rs @@ -10,8 +10,9 @@ fn round_trip(x: &[u8]) { }; let decode_options = lzma_rs::decompress::Options { unpacked_size: lzma_rs::decompress::UnpackedSize::ReadFromHeader, + ..Default::default() }; - round_trip_with_options(x, &encode_options, &decode_options); + assert_round_trip_with_options(x, &encode_options, &decode_options); } fn round_trip_no_options(x: &[u8]) { @@ -31,7 +32,7 @@ fn round_trip_with_options( x: &[u8], encode_options: &lzma_rs::compress::Options, decode_options: &lzma_rs::decompress::Options, -) { +) -> lzma_rs::error::Result> { let mut compressed: Vec = Vec::new(); lzma_rs::lzma_compress_with_options( &mut std::io::BufReader::new(x), @@ -45,8 +46,19 @@ fn round_trip_with_options( debug!("Compressed content: {:?}", compressed); let mut bf = std::io::BufReader::new(compressed.as_slice()); let mut decomp: Vec = Vec::new(); - lzma_rs::lzma_decompress_with_options(&mut bf, &mut decomp, decode_options).unwrap(); - assert_eq!(decomp, x) + lzma_rs::lzma_decompress_with_options(&mut bf, &mut decomp, decode_options)?; + Ok(decomp) +} + +fn assert_round_trip_with_options( + x: &[u8], + encode_options: &lzma_rs::compress::Options, + decode_options: &lzma_rs::decompress::Options, +) { + assert_eq!( + round_trip_with_options(x, encode_options, decode_options).unwrap(), + x + ) } fn round_trip_file(filename: &str) { @@ -170,8 +182,9 @@ fn unpacked_size_write_to_header() { }; let decode_options = lzma_rs::decompress::Options { unpacked_size: lzma_rs::decompress::UnpackedSize::ReadFromHeader, + ..Default::default() }; - round_trip_with_options(&data[..], &encode_options, &decode_options); + assert_round_trip_with_options(&data[..], &encode_options, &decode_options); } #[test] @@ -182,8 +195,9 @@ fn 
unpacked_size_provided_outside() { }; let decode_options = lzma_rs::decompress::Options { unpacked_size: lzma_rs::decompress::UnpackedSize::UseProvided(Some(data.len() as u64)), + ..Default::default() }; - round_trip_with_options(&data[..], &encode_options, &decode_options); + assert_round_trip_with_options(&data[..], &encode_options, &decode_options); } #[test] @@ -196,8 +210,9 @@ fn unpacked_size_write_some_to_header_but_use_provided_on_read() { unpacked_size: lzma_rs::decompress::UnpackedSize::ReadHeaderButUseProvided(Some( data.len() as u64, )), + ..Default::default() }; - round_trip_with_options(&data[..], &encode_options, &decode_options); + assert_round_trip_with_options(&data[..], &encode_options, &decode_options); } #[test] @@ -210,8 +225,9 @@ fn unpacked_size_write_none_to_header_and_use_provided_on_read() { unpacked_size: lzma_rs::decompress::UnpackedSize::ReadHeaderButUseProvided(Some( data.len() as u64, )), + ..Default::default() }; - round_trip_with_options(&data[..], &encode_options, &decode_options); + assert_round_trip_with_options(&data[..], &encode_options, &decode_options); } #[test] @@ -222,6 +238,22 @@ fn unpacked_size_write_none_to_header_and_use_provided_none_on_read() { }; let decode_options = lzma_rs::decompress::Options { unpacked_size: lzma_rs::decompress::UnpackedSize::ReadHeaderButUseProvided(None), + ..Default::default() + }; + assert_round_trip_with_options(&data[..], &encode_options, &decode_options); +} + +#[test] +#[should_panic(expected = "exceeded memory limit of 0")] +fn memlimit() { + let data = b"Some data"; + let encode_options = lzma_rs::compress::Options { + unpacked_size: lzma_rs::compress::UnpackedSize::WriteToHeader(None), + }; + let decode_options = lzma_rs::decompress::Options { + unpacked_size: lzma_rs::decompress::UnpackedSize::ReadHeaderButUseProvided(None), + memlimit: Some(0), + ..Default::default() }; - round_trip_with_options(&data[..], &encode_options, &decode_options); + round_trip_with_options(&data[..], 
&encode_options, &decode_options).unwrap(); } diff --git a/tests/stream.rs b/tests/stream.rs new file mode 100644 index 0000000..bad4c88 --- /dev/null +++ b/tests/stream.rs @@ -0,0 +1,273 @@ +use std::io::Write; + +fn round_trip(x: &[u8]) { + round_trip_no_options(x); + + // Do another round trip, but this time also write it to the header + let encode_options = lzma_rs::compress::Options { + unpacked_size: lzma_rs::compress::UnpackedSize::WriteToHeader(Some(x.len() as u64)), + }; + let decode_options = lzma_rs::decompress::Options { + unpacked_size: lzma_rs::decompress::UnpackedSize::ReadFromHeader, + ..Default::default() + }; + round_trip_with_options(x, &encode_options, &decode_options); +} + +fn round_trip_no_options(x: &[u8]) { + let mut compressed: Vec = Vec::new(); + lzma_rs::lzma_compress(&mut std::io::BufReader::new(x), &mut compressed).unwrap(); + #[cfg(feature = "enable_logging")] + info!("Compressed {} -> {} bytes", x.len(), compressed.len()); + #[cfg(feature = "enable_logging")] + debug!("Compressed content: {:?}", compressed); + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + stream.write_all(&compressed).unwrap(); + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, x) +} + +fn round_trip_with_options( + x: &[u8], + encode_options: &lzma_rs::compress::Options, + decode_options: &lzma_rs::decompress::Options, +) { + let mut compressed: Vec = Vec::new(); + lzma_rs::lzma_compress_with_options( + &mut std::io::BufReader::new(x), + &mut compressed, + encode_options, + ) + .unwrap(); + #[cfg(feature = "enable_logging")] + info!("Compressed {} -> {} bytes", x.len(), compressed.len()); + #[cfg(feature = "enable_logging")] + debug!("Compressed content: {:?}", compressed); + let mut stream = + lzma_rs::decompress::Stream::new_with_options(decode_options.clone(), Vec::new()); + + if let Err(error) = stream.write_all(&compressed) { + // WriteZero could indicate that the unpacked_size was reached before the + // end of the stream + if 
std::io::ErrorKind::WriteZero != error.kind() { + panic!(error); + } + } + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, x) +} + +fn round_trip_file(filename: &str) { + use std::io::Read; + + let mut x = Vec::new(); + std::fs::File::open(filename) + .unwrap() + .read_to_end(&mut x) + .unwrap(); + round_trip(x.as_slice()); +} + +fn decomp_big_file(compfile: &str, plainfile: &str) { + use std::io::Read; + + let mut expected = Vec::new(); + std::fs::File::open(plainfile) + .unwrap() + .read_to_end(&mut expected) + .unwrap(); + + let mut buf = [0u8; 1024]; + let mut file = std::fs::File::open(compfile).unwrap(); + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + + loop { + let consumed = file.read(&mut buf).unwrap(); + + if consumed == 0 { + break; + } + stream.write_all(&buf[..consumed]).unwrap(); + } + + let decomp = stream.finish().unwrap(); + assert!(decomp == expected) +} + +#[test] +fn round_trip_basics() { + #[cfg(feature = "enable_logging")] + let _ = env_logger::try_init(); + round_trip(b""); + // Note: we use vec! 
to avoid storing the slice in the binary + round_trip(vec![0x00; 1_000_000].as_slice()); + round_trip(vec![0xFF; 1_000_000].as_slice()); +} + +#[test] +fn round_trip_hello() { + #[cfg(feature = "enable_logging")] + let _ = env_logger::try_init(); + round_trip(b"Hello world"); +} + +#[test] +fn round_trip_files() { + #[cfg(feature = "enable_logging")] + let _ = env_logger::try_init(); + round_trip_file("tests/files/foo.txt"); + round_trip_file("tests/files/range-coder-edge-case"); +} + +#[test] +fn big_file() { + #[cfg(feature = "enable_logging")] + let _ = env_logger::try_init(); + decomp_big_file("tests/files/foo.txt.lzma", "tests/files/foo.txt"); + decomp_big_file("tests/files/hugedict.txt.lzma", "tests/files/foo.txt"); + decomp_big_file( + "tests/files/range-coder-edge-case.lzma", + "tests/files/range-coder-edge-case", + ); +} + +#[test] +fn stream_decompress_empty() { + let input: &[u8] = b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x83\xff\ + \xfb\xff\xff\xc0\x00\x00\x00"; + + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + stream.write_all(&input).unwrap(); + + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, b"") +} + +#[test] +fn decompress_hello_world() { + #[cfg(feature = "enable_logging")] + let _ = env_logger::try_init(); + let x: &[u8] = b"\x5d\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x24\x19\ + \x49\x98\x6f\x10\x19\xc6\xd7\x31\xeb\x36\x50\xb2\x98\x48\xff\xfe\ + \xa5\xb0\x00"; + let mut stream = lzma_rs::decompress::Stream::new(Vec::new()); + stream.write_all(x).unwrap(); + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, b"Hello world\x0a") +} + +#[test] +fn decompress_huge_dict() { + // Hello world with a dictionary of size 0x7F7F7F7F + #[cfg(feature = "enable_logging")] + let _ = env_logger::try_init(); + let x = b"\x5d\x7f\x7f\x7f\x7f\xff\xff\xff\xff\xff\xff\xff\xff\x00\x24\x19\ + \x49\x98\x6f\x10\x19\xc6\xd7\x31\xeb\x36\x50\xb2\x98\x48\xff\xfe\ + \xa5\xb0\x00"; + let mut stream = 
lzma_rs::decompress::Stream::new(Vec::new()); + stream.write_all(x).unwrap(); + let decomp = stream.finish().unwrap(); + assert_eq!(decomp, b"Hello world\x0a") +} + +#[test] +fn unpacked_size_write_to_header() { + let data = b"Some data"; + let encode_options = lzma_rs::compress::Options { + unpacked_size: lzma_rs::compress::UnpackedSize::WriteToHeader(Some(data.len() as u64)), + }; + let decode_options = lzma_rs::decompress::Options { + unpacked_size: lzma_rs::decompress::UnpackedSize::ReadFromHeader, + ..Default::default() + }; + round_trip_with_options(&data[..], &encode_options, &decode_options); +} + +#[test] +fn unpacked_size_provided_outside() { + let data = b"Some data"; + let encode_options = lzma_rs::compress::Options { + unpacked_size: lzma_rs::compress::UnpackedSize::SkipWritingToHeader, + }; + let decode_options = lzma_rs::decompress::Options { + unpacked_size: lzma_rs::decompress::UnpackedSize::UseProvided(Some(data.len() as u64)), + ..Default::default() + }; + round_trip_with_options(&data[..], &encode_options, &decode_options); +} + +#[test] +fn unpacked_size_write_some_to_header_but_use_provided_on_read() { + let data = b"Some data"; + let encode_options = lzma_rs::compress::Options { + unpacked_size: lzma_rs::compress::UnpackedSize::WriteToHeader(Some(data.len() as u64)), + }; + let decode_options = lzma_rs::decompress::Options { + unpacked_size: lzma_rs::decompress::UnpackedSize::ReadHeaderButUseProvided(Some( + data.len() as u64, + )), + ..Default::default() + }; + round_trip_with_options(&data[..], &encode_options, &decode_options); +} + +#[test] +fn unpacked_size_write_none_to_header_and_use_provided_on_read() { + let data = b"Some data"; + let encode_options = lzma_rs::compress::Options { + unpacked_size: lzma_rs::compress::UnpackedSize::WriteToHeader(None), + }; + let decode_options = lzma_rs::decompress::Options { + unpacked_size: lzma_rs::decompress::UnpackedSize::ReadHeaderButUseProvided(Some( + data.len() as u64, + )), + 
..Default::default() + }; + round_trip_with_options(&data[..], &encode_options, &decode_options); +} + +#[test] +fn unpacked_size_write_none_to_header_and_use_provided_none_on_read() { + let data = b"Some data"; + let encode_options = lzma_rs::compress::Options { + unpacked_size: lzma_rs::compress::UnpackedSize::WriteToHeader(None), + }; + let decode_options = lzma_rs::decompress::Options { + unpacked_size: lzma_rs::decompress::UnpackedSize::ReadHeaderButUseProvided(None), + ..Default::default() + }; + round_trip_with_options(&data[..], &encode_options, &decode_options); +} + +#[test] +fn memlimit() { + let data = b"Some data"; + let encode_options = lzma_rs::compress::Options { + unpacked_size: lzma_rs::compress::UnpackedSize::WriteToHeader(None), + }; + let decode_options = lzma_rs::decompress::Options { + unpacked_size: lzma_rs::decompress::UnpackedSize::ReadHeaderButUseProvided(None), + memlimit: Some(0), + ..Default::default() + }; + let mut compressed: Vec = Vec::new(); + lzma_rs::lzma_compress_with_options( + &mut std::io::BufReader::new(&data[..]), + &mut compressed, + &encode_options, + ) + .unwrap(); + let mut stream = + lzma_rs::decompress::Stream::new_with_options(decode_options.clone(), Vec::new()); + + let error = stream.write_all(&compressed).unwrap_err(); + assert!( + error.to_string().contains("exceeded memory limit of 0"), + error.to_string() + ); + let error = stream.finish().unwrap_err(); + assert!( + error.to_string().contains("previous write error"), + error.to_string() + ); +}