From 0f713c2b9de282c96f969c7b00005ceae50ab47e Mon Sep 17 00:00:00 2001 From: morsecodist Date: Mon, 3 Jul 2023 08:28:03 -0700 Subject: [PATCH 1/5] find iter lifetimes --- src/core/src/index/mod.rs | 8 +-- src/core/src/index/sbt/mod.rs | 110 +++++++++++++++++++++++++++------- 2 files changed, 94 insertions(+), 24 deletions(-) diff --git a/src/core/src/index/mod.rs b/src/core/src/index/mod.rs index 4e43074ebe..8f721a5831 100644 --- a/src/core/src/index/mod.rs +++ b/src/core/src/index/mod.rs @@ -44,9 +44,9 @@ pub trait Index<'a> { //type SignatureIterator: Iterator; fn find( - &self, + &'a self, search_fn: F, - sig: &Self::Item, + sig: &'a Self::Item, threshold: f64, ) -> Result, Error> where @@ -66,8 +66,8 @@ pub trait Index<'a> { } fn search( - &self, - sig: &Self::Item, + &'a self, + sig: &'a Self::Item, threshold: f64, containment: bool, ) -> Result, Error> { diff --git a/src/core/src/index/sbt/mod.rs b/src/core/src/index/sbt/mod.rs index 5245defe1f..c7945bd572 100644 --- a/src/core/src/index/sbt/mod.rs +++ b/src/core/src/index/sbt/mod.rs @@ -374,43 +374,113 @@ where } } -impl<'a, N, L> Index<'a> for SBT +pub struct SBTFindIter<'a, N, L, F> where N: Comparable + Comparable + Update + Debug + Default, L: Comparable + Update + Clone + Debug + Default, + F: Fn(&dyn Comparable, &L, f64) -> bool, SBT: FromFactory, SigStore: From + ReadData, { - type Item = L; + queue: Vec, + visited: HashSet, + sbt: &'a SBT, + search_fn: F, + sig: &'a L, + threshold: f64, +} - fn find(&self, search_fn: F, sig: &L, threshold: f64) -> Result, Error> - where - F: Fn(&dyn Comparable, &Self::Item, f64) -> bool, - { - let mut matches = Vec::new(); - let mut visited = HashSet::new(); - let mut queue = vec![0u64]; +impl<'a, N, L, F> SBTFindIter<'a, N, L, F> +where + N: Comparable + Comparable + Update + Debug + Default, + L: Comparable + Update + Clone + Debug + Default, + F: Fn(&dyn Comparable, &L, f64) -> bool, + SBT: FromFactory, + SigStore: From + ReadData, +{ + pub fn new(sbt: &'a SBT, search_fn: F, sig: &'a L, threshold: f64) -> Self { + SBTFindIter { + queue: vec![0u64], + visited: HashSet::new(), + sbt, + search_fn, + sig, + threshold, + } + } +} - while let Some(pos) = queue.pop() { - if !visited.contains(&pos) { - visited.insert(pos); +impl<'a, N, L, F> Iterator for SBTFindIter<'a, N, L, F> +where + N: Comparable + Comparable + Update + Debug + Default, + L: Comparable + Update + Clone + Debug + Default, + F: Fn(&dyn Comparable, &L, f64) -> bool, + SBT: FromFactory, + SigStore: From + ReadData, +{ + type Item = &'a L; + + fn next(&mut self) -> Option { + while let Some(pos) = self.queue.pop() { + if !self.visited.contains(&pos) { + self.visited.insert(pos); - if let Some(node) = self.nodes.get(&pos) { - if search_fn(&node, sig, threshold) { - for c in self.children(pos) { - queue.push(c); + if let Some(node) = self.sbt.nodes.get(&pos) { + if (self.search_fn)(&node, self.sig, self.threshold) { + for c in self.sbt.children(pos) { + self.queue.push(c); } } - } else if let Some(leaf) = self.leaves.get(&pos) { + } else if let Some(leaf) = self.sbt.leaves.get(&pos) { let data = leaf.data().expect("Error reading data"); - if search_fn(data, sig, threshold) { - matches.push(data); + if (self.search_fn)(data, self.sig, self.threshold) { + return Some(data); } } } } + None + } +} - Ok(matches) +impl<'a, N, L> SBT +where + N: Comparable + Comparable + Update + Debug + Default, + L: Comparable + Update + Clone + Debug + Default, + SBT: FromFactory, + SigStore: From + ReadData, +{ + pub fn find_iter(&'a self, search_fn: impl Fn(&dyn Comparable, &L, f64) -> bool, sig: &'a L, threshold: f64) -> SBTFindIter<'a, N, L, impl Fn(&dyn Comparable, &L, f64) -> bool> { + SBTFindIter::new(self, search_fn, sig, threshold) + } + + pub fn find_any(&'a self, search_fn: impl Fn(&dyn Comparable, &L, f64) -> bool, sig: &'a L, threshold: f64) -> bool { + SBTFindIter::new(self, search_fn, sig, threshold).next().is_some() + } + + pub fn find_one(&'a self, search_fn: impl Fn(&dyn Comparable, &L, f64) -> bool, sig: &'a L, threshold: f64) -> Option<&'a L> { + SBTFindIter::new(self, search_fn, sig, threshold).next() + } + + pub fn find_n(&'a self, search_fn: impl Fn(&dyn Comparable, &L, f64) -> bool, sig: &'a L, threshold: f64) -> std::iter::Take, &L, f64) -> bool>> { + SBTFindIter::new(self, search_fn, sig, threshold).take(n) + } +} + +impl<'a, N, L> Index<'a> for SBT +where + N: Comparable + Comparable + Update + Debug + Default, + L: Comparable + Update + Clone + Debug + Default, + SBT: FromFactory, + SigStore: From + ReadData, +{ + type Item = L; + + fn find(&'a self, search_fn: F, sig: &'a L, threshold: f64) -> Result, Error> + where + F: Fn(&dyn Comparable, &Self::Item, f64) -> bool, + { + Ok(self.find_iter(search_fn, sig, threshold).collect()) } fn insert(&mut self, dataset: L) -> Result<(), Error> { From fd17e0b35c179f7adb31f2405ab2a0d9962218ae Mon Sep 17 00:00:00 2001 From: morsecodist Date: Mon, 3 Jul 2023 08:39:00 -0700 Subject: [PATCH 2/5] docs --- src/core/src/index/sbt/mod.rs | 65 ++++++++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 13 deletions(-) diff --git a/src/core/src/index/sbt/mod.rs b/src/core/src/index/sbt/mod.rs index c7945bd572..51e4b98788 100644 --- a/src/core/src/index/sbt/mod.rs +++ b/src/core/src/index/sbt/mod.rs @@ -374,7 +374,7 @@ where } } -pub struct SBTFindIter<'a, N, L, F> +pub struct SBTFindIter<'a, N, L, F> where N: Comparable + Comparable + Update + Debug + Default, L: Comparable + Update + Clone + Debug + Default, @@ -410,8 +410,8 @@ where } } -impl<'a, N, L, F> Iterator for SBTFindIter<'a, N, L, F> -where +impl<'a, N, L, F> Iterator for SBTFindIter<'a, N, L, F> +where N: Comparable + Comparable + Update + Debug + Default, L: Comparable + Update + Clone + Debug + Default, F: Fn(&dyn Comparable, &L, f64) -> bool, @@ -450,21 +450,60 @@ where SBT: FromFactory, SigStore: From + ReadData, { - pub fn find_iter(&'a self, search_fn: impl Fn(&dyn Comparable, &L, f64) -> bool, sig: &'a L, threshold: f64) -> SBTFindIter<'a, N, L, impl Fn(&dyn Comparable, &L, f64) -> bool> { + /// Creates an iterator that will visit each element in the SBT that satisfies + /// the provided `search_fn` function. + /// + /// # Arguments + /// + /// * `search_fn` - A function that takes a reference to a `Comparable` item, + /// a reference to an `L` item, and a `f64` threshold value, and returns true if the comprable elements are within `threshold`. + /// * `sig` - A signature against which the `search_fn` will compare each item. + /// * `threshold` - A threshold value passed to `search_fn`. + pub fn find_iter( + &'a self, + search_fn: impl Fn(&dyn Comparable, &L, f64) -> bool, + sig: &'a L, + threshold: f64, + ) -> SBTFindIter<'a, N, L, impl Fn(&dyn Comparable, &L, f64) -> bool> { SBTFindIter::new(self, search_fn, sig, threshold) } - pub fn find_any(&'a self, search_fn: impl Fn(&dyn Comparable, &L, f64) -> bool, sig: &'a L, threshold: f64) -> bool { - SBTFindIter::new(self, search_fn, sig, threshold).next().is_some() - } - - pub fn find_one(&'a self, search_fn: impl Fn(&dyn Comparable, &L, f64) -> bool, sig: &'a L, threshold: f64) -> Option<&'a L> { + /// Checks if there exists at least one element in the SBT that satisfies the provided `search_fn` function for the given threshold. + /// + /// # Arguments + /// + /// * `search_fn` - A function that takes a reference to a `Comparable` item, + /// a reference to an `L` item, and a `f64` threshold value, and returns true if the comprable elements are within `threshold`. + /// * `sig` - A signature against which the `search_fn` will compare each item. + /// * `threshold` - A threshold value passed to `search_fn`. + pub fn find_any( + &'a self, + search_fn: impl Fn(&dyn Comparable, &L, f64) -> bool, + sig: &'a L, + threshold: f64, + ) -> bool { + SBTFindIter::new(self, search_fn, sig, threshold) + .next() + .is_some() + } + + /// Finds an element in the SBT that satisfies the provided `search_fn` function for the given threshold. + /// There is no guarantee that the element returned is the closest to the provided signature. + /// + /// # Arguments + /// + /// * `search_fn` - A function that takes a reference to a `Comparable` item, + /// a reference to an `L` item, and a `f64` threshold value, and returns true if the comprable elements are within `threshold`. + /// * `sig` - A signature against which the `search_fn` will compare each item. + /// * `threshold` - A threshold value passed to `search_fn`. + pub fn find_one( + &'a self, + search_fn: impl Fn(&dyn Comparable, &L, f64) -> bool, + sig: &'a L, + threshold: f64, + ) -> Option<&'a L> { SBTFindIter::new(self, search_fn, sig, threshold).next() } - - pub fn find_n(&'a self, search_fn: impl Fn(&dyn Comparable, &L, f64) -> bool, sig: &'a L, threshold: f64) -> std::iter::Take, &L, f64) -> bool>> { - SBTFindIter::new(self, search_fn, sig, threshold).take(n) - } } impl<'a, N, L> Index<'a> for SBT From bc854a5e73a07d59b1952e1966ba03c93863804c Mon Sep 17 00:00:00 2001 From: morsecodist Date: Mon, 3 Jul 2023 09:02:51 -0700 Subject: [PATCH 3/5] author --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index b9315c71e6..5c9faa574e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ authors = [ { name="Katrin Leinweber", orcid="0000-0001-5135-5758" }, { name="Marisa Lim", orcid="0000-0003-2097-8818" }, { name="Ricky Lim", orcid="0000-0003-1313-7076" }, + { name="Todd Morse", orcid="0009-0007-1711-5938" }, { name="Ivan Ogasawara", orcid="0000-0001-5049-4289" }, { name="N. Tessa Pierce", orcid="0000-0002-2942-5331" }, { name="Taylor Reiter", orcid="0000-0002-7388-421X" }, From 857404f0e3762f5bfa2a171000ed75f619cf133e Mon Sep 17 00:00:00 2001 From: morsecodist Date: Mon, 3 Jul 2023 09:03:30 -0700 Subject: [PATCH 4/5] initial --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5c9faa574e..7e6353fedf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ authors = [ { name="Katrin Leinweber", orcid="0000-0001-5135-5758" }, { name="Marisa Lim", orcid="0000-0003-2097-8818" }, { name="Ricky Lim", orcid="0000-0003-1313-7076" }, - { name="Todd Morse", orcid="0009-0007-1711-5938" }, + { name="R. Todd Morse", orcid="0009-0007-1711-5938" }, { name="Ivan Ogasawara", orcid="0000-0001-5049-4289" }, { name="N. Tessa Pierce", orcid="0000-0002-2942-5331" }, { name="Taylor Reiter", orcid="0000-0002-7388-421X" }, From 97030e1d86519e777dc13c509ec2b0fd5644d527 Mon Sep 17 00:00:00 2001 From: morsecodist Date: Mon, 3 Jul 2023 12:52:34 -0700 Subject: [PATCH 5/5] test updates --- src/core/src/index/sbt/mhbt.rs | 35 ++++++++++++++++++++++++++++++++++ src/core/src/index/sbt/mod.rs | 6 ++---- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/core/src/index/sbt/mhbt.rs b/src/core/src/index/sbt/mhbt.rs index 2d4ceb3fb8..f7c310f8e4 100644 --- a/src/core/src/index/sbt/mhbt.rs +++ b/src/core/src/index/sbt/mhbt.rs @@ -331,6 +331,41 @@ mod test { Ok(()) } + #[test] + #[ignore] + fn find_one_or_any_sbt() -> Result<(), Box> { + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("../../tests/test-data/v5.sbt.json"); + + let sbt = MHBT::from_path(filename)?; + + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("../../tests/test-data/.sbt.v3/60f7e23c24a8d94791cc7a8680c493f9"); + + let mut reader = BufReader::new(File::open(filename)?); + let sigs = Signature::load_signatures( + &mut reader, + Some(31), + Some("DNA".try_into().unwrap()), + None, + )?; + let sig_data = sigs[0].clone(); + + let leaf: SigStore<_> = sig_data.into(); + + let find_results = sbt.find(search_minhashes, &leaf, 0.5)?; + assert_eq!(find_results.len(), 1); + let find_one_results = sbt.find_one(search_minhashes, &leaf, 0.5); + assert!(find_one_results.is_some()); + assert_eq!(find_results[0], find_one_results.unwrap()); + assert!(sbt.find_any(search_minhashes, &leaf, 0.5)); + + assert!(sbt.find_one(|_, _, _| false, &leaf, 0.9).is_none()); + assert!(!sbt.find_any(|_, _, _| false, &leaf, 0.9)); + + Ok(()) + } + #[test] fn scaffold_sbt() { let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); diff --git a/src/core/src/index/sbt/mod.rs b/src/core/src/index/sbt/mod.rs index 51e4b98788..1d88bc4ccc 100644 --- a/src/core/src/index/sbt/mod.rs +++ b/src/core/src/index/sbt/mod.rs @@ -482,9 +482,7 @@ where sig: &'a L, threshold: f64, ) -> bool { - SBTFindIter::new(self, search_fn, sig, threshold) - .next() - .is_some() + self.find_one(search_fn, sig, threshold).is_some() } /// Finds an element in the SBT that satisfies the provided `search_fn` function for the given threshold. @@ -502,7 +500,7 @@ where sig: &'a L, threshold: f64, ) -> Option<&'a L> { - SBTFindIter::new(self, search_fn, sig, threshold).next() + self.find_iter(search_fn, sig, threshold).next() } }