Skip to content

Commit

Permalink
Better Filters Error Handling (#171)
Browse files Browse the repository at this point in the history
* better error handling

* dropping dead code.
  • Loading branch information
soldni authored Aug 6, 2024
1 parent a72c76b commit b406546
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 46 deletions.
172 changes: 143 additions & 29 deletions src/filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,16 +296,23 @@ impl JqDocFilter {
Ok(filters)
}

fn evaluate_match(&self, result: &Result<Val, jaq_interpret::Error>) -> bool {
fn evaluate_match(
&self,
result: &Result<Val, jaq_interpret::Error>,
) -> Result<bool, io::Error> {
match result {
Ok(jaq_interpret::Val::Bool(b)) => *b,
Ok(jaq_interpret::Val::Null) => false,
Ok(jaq_interpret::Val::Int(i)) => *i != 0,
Ok(jaq_interpret::Val::Float(f)) => *f != 0.0,
Ok(jaq_interpret::Val::Str(s)) => !s.is_empty(),
Ok(jaq_interpret::Val::Arr(a)) => !a.is_empty(),
Ok(jaq_interpret::Val::Obj(d)) => !d.is_empty(),
_ => true,
Ok(jaq_interpret::Val::Bool(b)) => Ok(*b),
Ok(jaq_interpret::Val::Null) => Ok(false),
Ok(jaq_interpret::Val::Int(i)) => Ok(*i != 0),
Ok(jaq_interpret::Val::Float(f)) => Ok(*f != 0.0),
Ok(jaq_interpret::Val::Str(s)) => Ok(!s.is_empty()),
Ok(jaq_interpret::Val::Arr(a)) => Ok(!a.is_empty()),
Ok(jaq_interpret::Val::Obj(d)) => Ok(!d.is_empty()),
Err(err) => Err(io::Error::new(
io::ErrorKind::Other,
format!("Error evaluating filter: {:?}", err),
)),
_ => Ok(true),
}
}

Expand All @@ -317,8 +324,9 @@ impl JqDocFilter {
exclude: exclude_filters,
})
}
pub fn should_keep(&self, json: &Value) -> Result<bool, String> {
pub fn should_keep(&self, json: &Value) -> Result<bool, io::Error> {
let mut keep = self.include.is_empty();

let inputs: RcIter<std::iter::Empty<_>> = RcIter::new(core::iter::empty());
for filter in self.include.iter() {
// exit early if keep is already true
Expand All @@ -329,18 +337,49 @@ impl JqDocFilter {
let out: Vec<Result<jaq_interpret::Val, jaq_interpret::Error>> = filter
.run((Ctx::new(Vec::new(), &inputs), Val::from(json.clone())))
.collect();
// if out is not empty and all its elements are true, then keep is true
keep = !out.is_empty() && out.iter().all(|x| self.evaluate_match(x));

// if the filter returns something, evaluate each result and update keep;
// keep will be true at the end of the loop if all results are true and the output is not empty
// if at any point an error is encountered, immediately return the error
keep = match out.is_empty() {
true => false,
false => {
let mut partial_keep = true;
for result in out.iter() {
match self.evaluate_match(result) {
Ok(val) => partial_keep = partial_keep && val,
Err(e) => return Err(e),
}
}
partial_keep
}
};
}

for filter in self.exclude.iter() {
// exit early if keep is already false
if !keep {
break;
}
let out: Vec<_> = filter
.run((Ctx::new(Vec::new(), &inputs), Val::from(json.clone())))
.collect();
keep = out.is_empty() || !out.iter().all(|x| self.evaluate_match(x));

// if the filter returns nothing, we keep the document; otherwise, evaluate each result
// and check if they all evaluate to false; if any result is true, we remove the document
keep = match out.is_empty() {
true => true,
false => {
let mut partial_keep = true;
for result in out.iter() {
match self.evaluate_match(result) {
Ok(val) => partial_keep = partial_keep && !val,
Err(e) => return Err(e),
}
}
partial_keep
}
};
}
Ok(keep)
}
Expand All @@ -353,10 +392,21 @@ impl JsonPathFilter {
exclude: filter_config.exclude.clone(),
})
}
pub fn should_keep(&self, json: &Value) -> Result<bool, String> {
pub fn should_keep(&self, json: &Value) -> Result<bool, io::Error> {
let mut keep = self.include.is_empty();
for pattern in self.include.iter() {
let mut finder = JsonPathFinder::from_str("{}", pattern)?;
let mut finder = match JsonPathFinder::from_str("{}", pattern) {
Ok(finder) => finder,
Err(e) => {
return Err(io::Error::new(
io::ErrorKind::Other,
format!(
"Error making include pattern {} into filter: {:?}",
pattern, e
),
))
}
};
finder.set_json(Box::new(json.clone()));
keep = finder.find() != Value::Null;
if keep {
Expand All @@ -365,7 +415,18 @@ impl JsonPathFilter {
}
if keep {
for pattern in self.exclude.iter() {
let mut finder = JsonPathFinder::from_str("{}", pattern)?;
let mut finder = match JsonPathFinder::from_str("{}", pattern) {
Ok(finder) => finder,
Err(e) => {
return Err(io::Error::new(
io::ErrorKind::Other,
format!(
"Error making exclude pattern {} into filter: {:?}",
pattern, e
),
))
}
};
finder.set_json(Box::new(json.clone()));
keep = finder.find() == Value::Null;
if !keep {
Expand All @@ -383,38 +444,38 @@ impl AllowAllFilter {
pub fn new() -> Result<AllowAllFilter, io::Error> {
Ok(AllowAllFilter)
}
pub fn should_keep(&self, _json: &Value) -> Result<bool, String> {
pub fn should_keep(&self, _json: &Value) -> Result<bool, io::Error> {
Ok(true)
}
}

pub enum DocFilter {
JqDocFilter(JqDocFilter),
JsonPathFilter(JsonPathFilter),
AllowAllFilter(AllowAllFilter),
Jq(JqDocFilter),
JsonPath(JsonPathFilter),
AllowAll(AllowAllFilter),
}

impl DocFilter {
pub fn new(filter_config: Option<&FilterConfig>) -> Result<DocFilter, io::Error> {
match filter_config {
Some(filter_config) => match filter_config.syntax.as_deref() {
Some("jq") => Ok(DocFilter::JqDocFilter(JqDocFilter::new(filter_config)?)),
Some("jsonpath") | None => Ok(DocFilter::JsonPathFilter(JsonPathFilter::new(
filter_config,
)?)),
Some("jq") => Ok(DocFilter::Jq(JqDocFilter::new(filter_config)?)),
Some("jsonpath") | None => {
Ok(DocFilter::JsonPath(JsonPathFilter::new(filter_config)?))
}
_ => Err(io::Error::new(
io::ErrorKind::Other,
format!("Unknown filter syntax: {:?}", filter_config.syntax),
)),
},
None => Ok(DocFilter::AllowAllFilter(AllowAllFilter::new()?)),
None => Ok(DocFilter::AllowAll(AllowAllFilter::new()?)),
}
}
pub fn should_keep(&self, json: &Value) -> Result<bool, String> {
pub fn should_keep(&self, json: &Value) -> Result<bool, io::Error> {
match self {
DocFilter::JqDocFilter(f) => f.should_keep(json),
DocFilter::JsonPathFilter(f) => f.should_keep(json),
DocFilter::AllowAllFilter(f) => f.should_keep(json),
DocFilter::Jq(f) => f.should_keep(json),
DocFilter::JsonPath(f) => f.should_keep(json),
DocFilter::AllowAll(f) => f.should_keep(json),
}
}
}
Expand Down Expand Up @@ -565,4 +626,57 @@ mod filter_tests {
let result = DocFilter::new(Some(&filter_config));
assert!(result.is_err());
}

/// A document that satisfies the include rule but also matches one of several
/// exclude rules must be dropped.
///
/// In the sample document `dedupe_para_ngrams_13_1` is empty, so the include
/// rule matches. The only exclude rule the sample data satisfies is the one on
/// `cc_multi_bin__cc_multi_bin__hq`, whose last value (0.00564) is <= 0.01.
#[test]
fn test_jq_multiple_conditions() {
    // Keep when there are no duplicated paragraph n-grams, or when the
    // duplicated span mass is at most 30% of the text length.
    let include: Vec<String> = [
        "(.attributes.dedupe_para_ngrams_13_1 | length == 0) or ((.attributes.dedupe_para_ngrams_13_1 | map(.[2] * (.[1] - .[0])) | add) / (.text | length) <= 0.3)",
    ]
    .iter()
    .map(|expr| expr.to_string())
    .collect();

    // Any one of these matching is enough to discard the document.
    let exclude: Vec<String> = [
        ".attributes.paloma_documents != null",
        "(.attributes.paloma_paragraphs | length) > 0",
        "(.tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__doc_max_score_repetition != null) and (.tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__doc_max_score_repetition[0][-1] > 10)",
        ".attributes.cc_multi_bin__cc_multi_bin__hq[0][-1] <= 0.01",
        ".attributes.pii_regex_with_counts_fast_v2__pii_regex_with_counts_fast_v2__doc_count[0][-1] > 5",
    ]
    .iter()
    .map(|expr| expr.to_string())
    .collect();

    let config = FilterConfig {
        include,
        exclude,
        syntax: Some(String::from("jq")),
    };
    let doc_filter = DocFilter::new(Some(&config)).unwrap();

    let document = json!({
        "attributes": {
            "cc_multi_bin__cc_multi_bin__lq": [[0, 1533, 0.99438]],
            "cc_multi_bin__cc_multi_bin__hq": [[0, 1533, 0.00564]],
            "dedupe_para_ngrams_13_1": [],
            "paloma_paragraphs": [],
            "pii_regex_with_counts_fast_v2__pii_regex_with_counts_fast_v2__doc_count": [[0, 1533, 0.0]],
            "pii_regex_with_counts_fast_v2__pii_regex_with_counts_fast_v2__doc_frac": [[0, 1533, 1.0]],
            "tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__repetition": [[493, 533, 10.0]],
            "tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__doc_max_score_repetition": [[0, 1533, 10.0]],
            "tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__doc_max_length_repetition": [[0, 1533, 40.0]],
            "tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__doc_frac_repetition": [[0, 1533, 0.02609]]
        }
    });

    let kept = doc_filter.should_keep(&document).unwrap();
    assert!(!kept);
}

/// A jq include rule that walks through an attribute the document does not
/// have (`.attributes.b.b`, while only `.attributes.a` exists) must make
/// `should_keep` return an error instead of a keep/drop decision.
#[test]
fn test_jq_missing_attr() {
    // The document only carries attribute "a"; the filter below asks for "b".
    let document = json!({
        "text": "test",
        "id": "0",
        "attributes": {"a": [[0, 3, 1]]},
        "source": "test"
    });

    let config = FilterConfig {
        syntax: Some(String::from("jq")),
        include: vec![String::from(".attributes.b.b != null")],
        exclude: Vec::new(),
    };
    let doc_filter = DocFilter::new(Some(&config)).unwrap();

    assert!(doc_filter.should_keep(&document).is_err());
}
}
16 changes: 0 additions & 16 deletions src/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,6 @@ impl Shard {
)
.reader()?;

// let input_file = OpenOptions::new()
// .read(true)
// .write(false)
// .create(false)
// .open(&local_docs_file)?;
// let reader = BufReader::with_capacity(1024 * 1024, MultiGzDecoder::new(input_file));

let mut line_number = 0;
let mut lines_written = 0;

Expand Down Expand Up @@ -332,15 +325,6 @@ impl Shard {
.map_err(|s| IoError::new(IoErrorKind::Other, s))?;

if should_write {
// if self.span_replacements.is_some() {
// let mut replacements = self
// .span_replacements
// .as_ref()
// .unwrap()
// .iter()
// .flat_map(|r| r.find_spans_to_replace(&data).unwrap())
// .collect::<Vec<SpanReplacement>>();

let mut replacements = span_replacers
.iter()
.map(|replacer| replacer.find_spans_to_replace(&data))
Expand Down
33 changes: 32 additions & 1 deletion tests/python/test_mixer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
from pathlib import Path
from tempfile import NamedTemporaryFile
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import List
from unittest import TestCase

import smart_open

from dolma.cli.__main__ import main

from .utils import (
Expand Down Expand Up @@ -221,3 +223,32 @@ def test_min_length(self):
self.assertEqual(len(new_docs), 2)
self.assertEqual(new_docs[0]["text"], self.combineIntoDoc(to_keep_head, ""))
self.assertEqual(new_docs[1]["text"], self.combineIntoDoc(to_keep_head, to_keep_tail))

def test_fail_on_unk_attribute(self):
    """The mixer must raise when a jq filter references an attribute the documents lack."""
    with TemporaryDirectory() as temp_dir:
        root = Path(temp_dir)

        # Layout: <tmp>/documents/0000.jsonl.gz as input, <tmp>/documents/output as destination.
        docs_dir = root / "documents"
        docs_dir.mkdir(exist_ok=True, parents=True)
        src_fp = docs_dir / "0000.jsonl.gz"
        dst_fp = docs_dir / "output"
        dst_fp.mkdir(exist_ok=True, parents=True)

        # The single document only has attribute "a"; the filter below asks for "b".
        docs = [{"text": "test", "id": "0", "attributes": {"a": [[0, 3, 1]]}, "source": __file__}]
        serialized = "\n".join(json.dumps(doc) for doc in docs)
        with smart_open.open(src_fp, "wt") as f:
            f.write(serialized)

        stream = {
            "name": "test",
            "documents": [str(src_fp)],
            "output": {"path": str(dst_fp), "max_size_in_bytes": 10000000},
            "filter": {"include": [".attributes.b.b != null"], "syntax": "jq"},
        }
        config = {"streams": [stream], "processes": 1}

        config_fp = root / "config.json"
        config_fp.write_text(json.dumps(config))

        with self.assertRaises(Exception):
            main(argv=["-c", str(config_fp), "mix"])

0 comments on commit b406546

Please sign in to comment.