From 9dee755fd363ec12d8602b05cc51e62b79cd24fc Mon Sep 17 00:00:00 2001 From: Eytan Singher Date: Sat, 22 Jun 2024 09:29:36 +0300 Subject: [PATCH] Closes #316; Added test for existance + fix bad connection choice and bad drop of `rest_if_proof` --- src/explain.rs | 51 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/src/explain.rs b/src/explain.rs index c42eb635..7a9e467b 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1,19 +1,18 @@ -use crate::Symbol; -use crate::{ - util::pretty_print, Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, - PatternAst, RecExpr, Rewrite, UnionFind, Var, -}; - use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; use std::ops::{Deref, DerefMut}; use std::rc::Rc; -use symbolic_expressions::Sexp; - use num_bigint::BigUint; use num_traits::identities::{One, Zero}; +use symbolic_expressions::Sexp; + +use crate::{ + Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, PatternAst, + RecExpr, Rewrite, UnionFind, util::pretty_print, Var, +}; +use crate::Symbol; type ProofCost = BigUint; @@ -1276,18 +1275,26 @@ impl<'x, L: Language> ExplainNodes<'x, L> { if graphnode.parent_connection.next == existance || existance_node.parent_connection.next == node { - let mut connection = graphnode.parent_connection.clone(); + let mut connection = if graphnode.parent_connection.next == existance { + graphnode.parent_connection.clone() + } else { + existance_node.parent_connection.clone() + }; if graphnode.parent_connection.next == existance { connection.is_rewrite_forward = !connection.is_rewrite_forward; std::mem::swap(&mut connection.next, &mut connection.current); } - return self.explain_enode_existance( + + let adj = self.explain_adjacent(connection, cache, enode_cache, false); + let mut exp = self.explain_enode_existance( existance, - self.explain_adjacent(connection, cache, enode_cache, false), + adj, cache, enode_cache, ); + exp.push(rest_of_proof); + return exp; } // case 3) @@ -1944,7 +1951,6 @@ impl<'x, L: Language> ExplainNodes<'x, L> { #[cfg(test)] mod tests { - use super::super::*; #[test] @@ -2049,6 +2055,27 @@ mod tests { egraph.dot().to_dot("target/foo.dot").unwrap(); } + + #[test] + fn simple_explain_exists() { + //! Same as previous test, but now I want to make a rewrite add some term and see it exists in + //! more then one step + use crate::SymbolLang; + init_logger(); + + let rws: Vec> = [rewrite!("makeb"; "a" => "b"), rewrite!("makec"; "b" => "c")].iter().cloned().collect(); + let mut egraph = Runner::default().with_explanations_enabled() + .without_explanation_length_optimization() + .with_expr(&"a".parse().unwrap()) + .run(&rws).egraph; + egraph.rebuild(); + let a: Symbol = "a".parse().unwrap(); + let b: Symbol = "b".parse().unwrap(); + let c: Symbol = "c".parse().unwrap(); + let mut exp = egraph.explain_existance(&"c".parse().unwrap()); + println!("{:?}", exp.make_flat_explanation()); + assert_eq!(exp.make_flat_explanation().len(), 3, "Expected 3 steps, got {:?}", exp.make_flat_explanation()); + } } #[test]