diff --git a/src/filters.rs b/src/filters.rs index 30f4c290..945eb4f3 100644 --- a/src/filters.rs +++ b/src/filters.rs @@ -296,16 +296,23 @@ impl JqDocFilter { Ok(filters) } - fn evaluate_match(&self, result: &Result) -> bool { + fn evaluate_match( + &self, + result: &Result, + ) -> Result { match result { - Ok(jaq_interpret::Val::Bool(b)) => *b, - Ok(jaq_interpret::Val::Null) => false, - Ok(jaq_interpret::Val::Int(i)) => *i != 0, - Ok(jaq_interpret::Val::Float(f)) => *f != 0.0, - Ok(jaq_interpret::Val::Str(s)) => !s.is_empty(), - Ok(jaq_interpret::Val::Arr(a)) => !a.is_empty(), - Ok(jaq_interpret::Val::Obj(d)) => !d.is_empty(), - _ => true, + Ok(jaq_interpret::Val::Bool(b)) => Ok(*b), + Ok(jaq_interpret::Val::Null) => Ok(false), + Ok(jaq_interpret::Val::Int(i)) => Ok(*i != 0), + Ok(jaq_interpret::Val::Float(f)) => Ok(*f != 0.0), + Ok(jaq_interpret::Val::Str(s)) => Ok(!s.is_empty()), + Ok(jaq_interpret::Val::Arr(a)) => Ok(!a.is_empty()), + Ok(jaq_interpret::Val::Obj(d)) => Ok(!d.is_empty()), + Err(err) => Err(io::Error::new( + io::ErrorKind::Other, + format!("Error evaluating filter: {:?}", err), + )), + _ => Ok(true), } } @@ -317,8 +324,9 @@ impl JqDocFilter { exclude: exclude_filters, }) } - pub fn should_keep(&self, json: &Value) -> Result { + pub fn should_keep(&self, json: &Value) -> Result { let mut keep = self.include.is_empty(); + let inputs: RcIter> = RcIter::new(core::iter::empty()); for filter in self.include.iter() { // exit early if keep is already true @@ -329,18 +337,49 @@ impl JqDocFilter { let out: Vec> = filter .run((Ctx::new(Vec::new(), &inputs), Val::from(json.clone()))) .collect(); - // if out is not empty and all its elements are true, then keep is true - keep = !out.is_empty() && out.iter().all(|x| self.evaluate_match(x)); + + // if the filter returns something, evaluate each result and update keep; + // keep will be true at the end of the loop if all results are true and the filter is not empty + // if an any point an error is encountered, immediately return the error + keep = match out.is_empty() { + true => false, + false => { + let mut partial_keep = true; + for result in out.iter() { + match self.evaluate_match(result) { + Ok(val) => partial_keep = partial_keep && val, + Err(e) => return Err(e), + } + } + partial_keep + } + }; } for filter in self.exclude.iter() { + // exit early if keep is already false if !keep { break; } let out: Vec<_> = filter .run((Ctx::new(Vec::new(), &inputs), Val::from(json.clone()))) .collect(); - keep = out.is_empty() || !out.iter().all(|x| self.evaluate_match(x)); + + // if the filter returns nothing, we keep the document; otherwise, evaluate each result + // and check if they all evaluate to false; if any result is true, we remove the document + keep = match out.is_empty() { + true => true, + false => { + let mut partial_keep = true; + for result in out.iter() { + match self.evaluate_match(result) { + Ok(val) => partial_keep = partial_keep && !val, + Err(e) => return Err(e), + } + } + partial_keep + } + }; } Ok(keep) } @@ -353,10 +392,21 @@ impl JsonPathFilter { exclude: filter_config.exclude.clone(), }) } - pub fn should_keep(&self, json: &Value) -> Result { + pub fn should_keep(&self, json: &Value) -> Result { let mut keep = self.include.is_empty(); for pattern in self.include.iter() { - let mut finder = JsonPathFinder::from_str("{}", pattern)?; + let mut finder = match JsonPathFinder::from_str("{}", pattern) { + Ok(finder) => finder, + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!( + "Error making include pattern {} into filter: {:?}", + pattern, e + ), + )) + } + }; finder.set_json(Box::new(json.clone())); keep = finder.find() != Value::Null; if keep { @@ -365,7 +415,18 @@ impl JsonPathFilter { } if keep { for pattern in self.exclude.iter() { - let mut finder = JsonPathFinder::from_str("{}", pattern)?; + let mut finder = match JsonPathFinder::from_str("{}", pattern) { + Ok(finder) => finder, + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!( + "Error making exclude pattern {} into filter: {:?}", + pattern, e + ), + )) + } + }; finder.set_json(Box::new(json.clone())); keep = finder.find() == Value::Null; if !keep { @@ -383,38 +444,38 @@ impl AllowAllFilter { pub fn new() -> Result { Ok(AllowAllFilter) } - pub fn should_keep(&self, _json: &Value) -> Result { + pub fn should_keep(&self, _json: &Value) -> Result { Ok(true) } } pub enum DocFilter { - JqDocFilter(JqDocFilter), - JsonPathFilter(JsonPathFilter), - AllowAllFilter(AllowAllFilter), + Jq(JqDocFilter), + JsonPath(JsonPathFilter), + AllowAll(AllowAllFilter), } impl DocFilter { pub fn new(filter_config: Option<&FilterConfig>) -> Result { match filter_config { Some(filter_config) => match filter_config.syntax.as_deref() { - Some("jq") => Ok(DocFilter::JqDocFilter(JqDocFilter::new(filter_config)?)), - Some("jsonpath") | None => Ok(DocFilter::JsonPathFilter(JsonPathFilter::new( - filter_config, - )?)), + Some("jq") => Ok(DocFilter::Jq(JqDocFilter::new(filter_config)?)), + Some("jsonpath") | None => { + Ok(DocFilter::JsonPath(JsonPathFilter::new(filter_config)?)) + } _ => Err(io::Error::new( io::ErrorKind::Other, format!("Unknown filter syntax: {:?}", filter_config.syntax), )), }, - None => Ok(DocFilter::AllowAllFilter(AllowAllFilter::new()?)), + None => Ok(DocFilter::AllowAll(AllowAllFilter::new()?)), } } - pub fn should_keep(&self, json: &Value) -> Result { + pub fn should_keep(&self, json: &Value) -> Result { match self { - DocFilter::JqDocFilter(f) => f.should_keep(json), - DocFilter::JsonPathFilter(f) => f.should_keep(json), - DocFilter::AllowAllFilter(f) => f.should_keep(json), + DocFilter::Jq(f) => f.should_keep(json), + DocFilter::JsonPath(f) => f.should_keep(json), + DocFilter::AllowAll(f) => f.should_keep(json), } } } @@ -565,4 +626,57 @@ mod filter_tests { let result = DocFilter::new(Some(&filter_config)); assert!(result.is_err()); } + + #[test] + fn test_jq_multiple_conditions() { + let filter_config = FilterConfig { + include: vec![ + "(.attributes.dedupe_para_ngrams_13_1 | length == 0) or ((.attributes.dedupe_para_ngrams_13_1 | map(.[2] * (.[1] - .[0])) | add) / (.text | length) <= 0.3)".to_string(), + ], + exclude: vec![ + ".attributes.paloma_documents != null".to_string(), + "(.attributes.paloma_paragraphs | length) > 0".to_string(), + "(.tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__doc_max_score_repetition != null) and (.tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__doc_max_score_repetition[0][-1] > 10)".to_string(), + ".attributes.cc_multi_bin__cc_multi_bin__hq[0][-1] <= 0.01".to_string(), + ".attributes.pii_regex_with_counts_fast_v2__pii_regex_with_counts_fast_v2__doc_count[0][-1] > 5".to_string(), + ], + syntax: Some("jq".to_string()), + }; + let filters = DocFilter::new(Some(&filter_config)).unwrap(); + + let doc = json!({ + "attributes": { + "cc_multi_bin__cc_multi_bin__lq": [[0, 1533, 0.99438]], + "cc_multi_bin__cc_multi_bin__hq": [[0, 1533, 0.00564]], + "dedupe_para_ngrams_13_1": [], + "paloma_paragraphs": [], + "pii_regex_with_counts_fast_v2__pii_regex_with_counts_fast_v2__doc_count": [[0, 1533, 0.0]], + "pii_regex_with_counts_fast_v2__pii_regex_with_counts_fast_v2__doc_frac": [[0, 1533, 1.0]], + "tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__repetition": [[493, 533, 10.0]], + "tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__doc_max_score_repetition": [[0, 1533, 10.0]], + "tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__doc_max_length_repetition": [[0, 1533, 40.0]], + "tokenizer_repetitions_v2r2__tokenizer_repetitions_v2r2__doc_frac_repetition": [[0, 1533, 0.02609]] + } + }); + + assert_eq!(filters.should_keep(&doc).unwrap(), false); + } + + #[test] + fn test_jq_missing_attr() { + let filter_config = FilterConfig { + include: vec![".attributes.b.b != null".to_string()], + exclude: vec![], + syntax: Some("jq".to_string()), + }; + let filters = DocFilter::new(Some(&filter_config)).unwrap(); + let doc = json!({ + "text": "test", + "id": "0", + "attributes": {"a": [[0, 3, 1]]}, + "source": "test" + }); + let result = filters.should_keep(&doc); + assert!(result.is_err()); + } } diff --git a/src/shard.rs b/src/shard.rs index 89100e05..576ef167 100644 --- a/src/shard.rs +++ b/src/shard.rs @@ -213,13 +213,6 @@ impl Shard { ) .reader()?; - // let input_file = OpenOptions::new() - // .read(true) - // .write(false) - // .create(false) - // .open(&local_docs_file)?; - // let reader = BufReader::with_capacity(1024 * 1024, MultiGzDecoder::new(input_file)); - let mut line_number = 0; let mut lines_written = 0; @@ -332,15 +325,6 @@ impl Shard { .map_err(|s| IoError::new(IoErrorKind::Other, s))?; if should_write { - // if self.span_replacements.is_some() { - // let mut replacements = self - // .span_replacements - // .as_ref() - // .unwrap() - // .iter() - // .flat_map(|r| r.find_spans_to_replace(&data).unwrap()) - // .collect::>(); - let mut replacements = span_replacers .iter() .map(|replacer| replacer.find_spans_to_replace(&data)) diff --git a/tests/python/test_mixer.py b/tests/python/test_mixer.py index ddf404d6..68ea1721 100644 --- a/tests/python/test_mixer.py +++ b/tests/python/test_mixer.py @@ -1,9 +1,11 @@ import json from pathlib import Path -from tempfile import NamedTemporaryFile +from tempfile import NamedTemporaryFile, TemporaryDirectory from typing import List from unittest import TestCase +import smart_open + from dolma.cli.__main__ import main from .utils import ( @@ -221,3 +223,32 @@ def test_min_length(self): self.assertEqual(len(new_docs), 2) self.assertEqual(new_docs[0]["text"], self.combineIntoDoc(to_keep_head, "")) self.assertEqual(new_docs[1]["text"], self.combineIntoDoc(to_keep_head, to_keep_tail)) + + def test_fail_on_unk_attribute(self): + with TemporaryDirectory() as temp_dir: + src_fp = (docs_dir := Path(temp_dir) / "documents") / "0000.jsonl.gz" + docs_dir.mkdir(exist_ok=True, parents=True) + (dst_fp := (docs_dir / "output")).mkdir(exist_ok=True, parents=True) + + docs = [{"text": "test", "id": "0", "attributes": {"a": [[0, 3, 1]]}, "source": __file__}] + with smart_open.open(src_fp, "wt") as f: + f.write("\n".join(map(json.dumps, docs))) + + config = { + "streams": [ + { + "name": "test", + "documents": [str(src_fp)], + "output": {"path": str(dst_fp), "max_size_in_bytes": 10000000}, + "filter": {"include": [".attributes.b.b != null"], "syntax": "jq"}, + } + ], + "processes": 1, + } + + config_fp = Path(temp_dir) / "config.json" + with config_fp.open("w") as f: + json.dump(config, f) + + with self.assertRaises(Exception): + main(argv=["-c", str(config_fp), "mix"])