Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small refactoring + tests #27

Merged
merged 1 commit into from
May 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 1 addition & 68 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ path = "src/pisa2ciff.rs"
[dependencies]
protobuf = "2.22"
structopt = "0.3"
num = "0.4"
num-traits = "0"
indicatif = "0.15"
anyhow = "1.0"
memmap = "0.7"
Expand Down
126 changes: 103 additions & 23 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,17 @@ use anyhow::{anyhow, Context};
use indicatif::ProgressIterator;
use indicatif::{ProgressBar, ProgressStyle};
use memmap::Mmap;
use num::ToPrimitive;
use num_traits::ToPrimitive;
use protobuf::{CodedInputStream, CodedOutputStream};
use std::borrow::Borrow;
use std::convert::TryFrom;
use std::fmt;
use std::fs::File;
use std::io::{self, BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};

mod proto;
pub use proto::{DocRecord, Header, Posting, PostingsList};
pub use proto::{DocRecord, Posting, PostingsList};
mod binary_collection;
pub use binary_collection::{BinaryCollection, BinarySequence, InvalidFormat};

Expand All @@ -43,6 +44,43 @@ type Result<T> = anyhow::Result<T>;
const DEFAULT_PROGRESS_TEMPLATE: &str =
"{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {count}/{total} ({eta})";

/// Wraps [`proto::Header`] and additionally provides some important counts that are already cast
/// to an unsigned type.
///
/// Instances are built by `Header::from_stream`, which performs the conversions once so that the
/// rest of the code can use the counts without repeating the checks.
#[derive(PartialEq, Clone, Default)]
struct Header {
/// Number of postings lists, converted to `u32` from the protobuf header's
/// `num_postings_lists` value.
num_postings_lists: u32,
/// Number of documents, converted to `u32` from the protobuf header's `num_docs` value.
num_documents: u32,
/// Used for printing: the `Display` implementation of this type delegates to this message.
protobuf_header: proto::Header,
}

impl Header {
    /// Reads the protobuf header from `input` and converts its counts to unsigned integers,
    /// so that a header containing negative values fails fast instead of being used later.
    ///
    /// # Errors
    ///
    /// Returns an error if the header message cannot be read from `input`, or if the number
    /// of documents or the number of postings lists cannot be converted to `u32` (for example
    /// because it is negative). The error context names the offending count.
    fn from_stream(input: &mut CodedInputStream<'_>) -> Result<Self> {
        let header = input.read_message::<proto::Header>()?;
        let num_documents = u32::try_from(header.get_num_docs())
            .context("Number of documents must be non-negative.")?;
        // Each conversion carries its own message so that the reported error points at the
        // field that is actually invalid.
        let num_postings_lists = u32::try_from(header.get_num_postings_lists())
            .context("Number of postings lists must be non-negative.")?;
        Ok(Self {
            // The original message is kept as-is; it is what gets printed for this header.
            protobuf_header: header,
            num_documents,
            num_postings_lists,
        })
    }
}

impl fmt::Display for Header {
    /// Prints the header by formatting the wrapped protobuf message with `{}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = &self.protobuf_header;
        f.write_fmt(format_args!("{}", message))
    }
}

/// Returns default progress style.
fn pb_style() -> ProgressStyle {
ProgressStyle::default_bar()
Expand Down Expand Up @@ -146,18 +184,15 @@ pub fn ciff_to_pisa(input: &Path, output: &Path) -> Result<()> {
let mut frequencies = BufWriter::new(File::create(format!("{}.freqs", output.display()))?);
let mut terms = BufWriter::new(File::create(format!("{}.terms", output.display()))?);

// Read protobuf header
let header = input.read_message::<Header>()?;
let num_documents = u32::try_from(header.get_num_docs())
.context("Number of documents must be non-negative.")?;
println!("{}", &header);
let header = Header::from_stream(&mut input)?;
println!("{}", header);

eprintln!("Processing postings");
encode_u32_sequence(&mut documents, 1, [num_documents].iter())?;
let progress = ProgressBar::new(u64::try_from(header.get_num_postings_lists())?);
encode_u32_sequence(&mut documents, 1, [header.num_documents].iter())?;
let progress = ProgressBar::new(u64::try_from(header.num_postings_lists)?);
progress.set_style(pb_style());
progress.set_draw_delta(10);
for _ in 0..header.get_num_postings_lists() {
for _ in 0..header.num_postings_lists {
write_posting_list(
&input.read_message::<PostingsList>()?,
&mut documents,
Expand All @@ -176,17 +211,12 @@ pub fn ciff_to_pisa(input: &Path, output: &Path) -> Result<()> {
let mut sizes = BufWriter::new(File::create(format!("{}.sizes", output.display()))?);
let mut trecids = BufWriter::new(File::create(format!("{}.documents", output.display()))?);

let progress = ProgressBar::new(u64::from(num_documents));
let progress = ProgressBar::new(u64::from(header.num_documents));
progress.set_style(pb_style());
progress.set_draw_delta(u64::from(num_documents) / 100);
sizes.write_all(&num_documents.to_le_bytes())?;

let expected_docs: usize = header
.get_num_docs()
.to_usize()
.ok_or_else(|| anyhow!("Cannot cast num docs to usize: {}", header.get_num_docs()))?;
progress.set_draw_delta(u64::from(header.num_documents) / 100);
sizes.write_all(&header.num_documents.to_le_bytes())?;

for docs_seen in 0..expected_docs {
for docs_seen in 0..header.num_documents {
let doc_record = input.read_message::<DocRecord>()?;

let docid: u32 = doc_record
Expand All @@ -202,7 +232,7 @@ pub fn ciff_to_pisa(input: &Path, output: &Path) -> Result<()> {
)
})?;

if docid as usize != docs_seen {
if docid != docs_seen {
anyhow::bail!("Document sizes must come in order");
}

Expand All @@ -226,7 +256,7 @@ fn read_document_count(
.ok_or_else(invalid)
}

fn header(documents_bytes: &[u8], sizes_bytes: &[u8], description: &str) -> Result<Header> {
fn header(documents_bytes: &[u8], sizes_bytes: &[u8], description: &str) -> Result<proto::Header> {
let mut num_postings_lists = 0;

eprintln!("Collecting posting lists statistics");
Expand All @@ -251,7 +281,7 @@ fn header(documents_bytes: &[u8], sizes_bytes: &[u8], description: &str) -> Resu
.progress_with(progress)
.sum();

let mut header = Header::default();
let mut header = proto::Header::default();
header.set_version(1);
header.set_description(description.into());
header.set_num_postings_lists(num_postings_lists);
Expand Down Expand Up @@ -384,10 +414,12 @@ fn pisa_to_ciff_from_paths(

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_size_sequence() {
let empty_memory = Vec::<u8>::new();
let sizes = super::sizes(&empty_memory);
let sizes = sizes(&empty_memory);
assert!(sizes.is_err());
assert_eq!(
"Invalid binary collection format: sizes collection is empty",
Expand All @@ -413,4 +445,52 @@ mod test {
vec![1_u32, 2, 3, 4, 5]
);
}

/// Serializes `header` into a fresh byte vector using `write_message_no_tag`, flushing the
/// output stream before the bytes are returned.
fn header_to_buf(header: &proto::Header) -> Result<Vec<u8>> {
    let mut bytes: Vec<u8> = Vec::new();
    {
        // The stream mutably borrows `bytes`; the borrow ends with this scope.
        let mut stream = CodedOutputStream::vec(&mut bytes);
        stream.write_message_no_tag(header)?;
        stream.flush()?;
    }
    Ok(bytes)
}

/// A header with non-negative counts is read back with the same protobuf message and with
/// both counts exposed as unsigned values.
#[test]
fn test_read_default_header() -> Result<()> {
    let mut expected = proto::Header::default();
    expected.set_num_docs(17);
    expected.set_num_postings_lists(1234);

    let bytes = header_to_buf(&expected)?;
    let mut stream = CodedInputStream::from_bytes(&bytes);
    let parsed = Header::from_stream(&mut stream)?;

    assert_eq!(parsed.protobuf_header, expected);
    assert_eq!(parsed.num_documents, 17);
    assert_eq!(parsed.num_postings_lists, 1234);
    Ok(())
}

/// Reading a header whose document count is negative must fail.
#[test]
fn test_read_negative_num_documents() -> Result<()> {
    let mut negative_docs = proto::Header::default();
    negative_docs.set_num_docs(-17);

    let bytes = header_to_buf(&negative_docs)?;
    let mut stream = CodedInputStream::from_bytes(&bytes);

    let outcome = Header::from_stream(&mut stream);
    assert!(outcome.is_err());
    Ok(())
}

/// Reading a header whose postings-list count is negative must fail.
#[test]
fn test_read_negative_num_posting_lists() -> Result<()> {
    let mut negative_lists = proto::Header::default();
    negative_lists.set_num_postings_lists(-1234);

    let bytes = header_to_buf(&negative_lists)?;
    let mut stream = CodedInputStream::from_bytes(&bytes);

    let outcome = Header::from_stream(&mut stream);
    assert!(outcome.is_err());
    Ok(())
}
}