From 692d2338abdf7c6ca0f4a0e6d9efae5beed9ff51 Mon Sep 17 00:00:00 2001 From: romes Date: Sun, 28 Aug 2022 22:52:22 +0200 Subject: [PATCH] {find,canonicalize} vs unsafe{Find,Canonicalize} --- src/Data/Equality/Graph.hs | 39 ++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/Data/Equality/Graph.hs b/src/Data/Equality/Graph.hs index a4a0d7c..7c57dd6 100644 --- a/src/Data/Equality/Graph.hs +++ b/src/Data/Equality/Graph.hs @@ -60,10 +60,10 @@ import Data.Equality.Graph.Lens -- class it's already represented in will be returned. add :: forall l. Language l => ENode l -> EGraph l -> (ClassId, EGraph l) add uncanon_e egr = - let !new_en = {-# SCC "-2" #-} canonicalize uncanon_e egr + let !new_en = {-# SCC "-2" #-} unsafeCanonicalize uncanon_e egr in case {-# SCC "-1" #-} lookupNM new_en (memo egr) of - Just canon_enode_id -> {-# SCC "0" #-} (find canon_enode_id egr, egr) + Just canon_enode_id -> {-# SCC "0" #-} (unsafeFind canon_enode_id egr, egr) Nothing -> let @@ -136,8 +136,8 @@ merge a b egr0 = -- Use canonical ids let - a' = find a egr0 - b' = find b egr0 + a' = unsafeFind a egr0 + b' = unsafeFind b egr0 in if a' == b' then (a', egr0) @@ -229,7 +229,7 @@ rebuild (EGraph uf cls mm wl awl) = repair :: forall l. Language l => ENode l -> ClassId -> EGraph l -> EGraph l repair node repair_id egr = - case insertLookupNM (node `canonicalize` egr) (find repair_id egr) (deleteNM node $ memo egr) of-- TODO: I seem to really need it. Is find needed? (they don't use it) + case insertLookupNM (node `unsafeCanonicalize` egr) (unsafeFind repair_id egr) (deleteNM node $ memo egr) of-- TODO: I seem to really need it. Is find needed? (they don't use it) (Nothing, memo2) -> egr { memo = memo2 } -- Return new memo but delete uncanonicalized node @@ -240,7 +240,7 @@ repair node repair_id egr = repairAnal :: forall l. Language l => ENode l -> ClassId -> EGraph l -> EGraph l repairAnal node repair_id egr = let - canon_id = find repair_id egr + canon_id = unsafeFind repair_id egr c = egr^._class canon_id new_data = joinA @l (c^._data) (makeA node egr) in @@ -263,17 +263,36 @@ repairAnal node repair_id egr = -- that their e-class ids are represented by the same e-class canonical ids -- -- canonicalize(𝑓(𝑎,𝑏,𝑐,...)) = 𝑓((find 𝑎), (find 𝑏), (find 𝑐),...) -canonicalize :: Functor l => ENode l -> EGraph l -> ENode l -canonicalize (Node enode) eg = Node $ fmap (`find` eg) enode +-- +-- This will force the e-graph to be rebuilt, as canonicalizing a node... TODO +canonicalize :: Language l => ENode l -> EGraph l -> ENode l +canonicalize n = unsafeCanonicalize n . rebuild {-# SCC canonicalize #-} -- | Find the canonical representation of an e-class id in the e-graph +-- +-- This will force the e-graph to be rebuilt, as finding a canonical representation.... TODO +-- -- Invariant: The e-class id always exists. -find :: ClassId -> EGraph l -> ClassId -find cid = findRepr cid . unionFind +find :: Language l => ClassId -> EGraph l -> ClassId +find cid = findRepr cid . unionFind . rebuild {-# INLINE find #-} -- | The empty e-graph. Nothing is represented in it yet. emptyEGraph :: Language l => EGraph l emptyEGraph = EGraph emptyUF mempty mempty mempty mempty {-# INLINE emptyEGraph #-} + +-- | Like 'canonicalize' but doesn't force a rebuild. +-- +-- Should be used when +unsafeCanonicalize :: Functor l => ENode l -> EGraph l -> ENode l +unsafeCanonicalize (Node enode) eg = Node $ fmap (`unsafeFind` eg) enode +{-# SCC unsafeCanonicalize #-} + +-- | Find the canonical representation of an e-class id in the e-graph +-- +-- Invariant: The e-class id always exists. +unsafeFind :: ClassId -> EGraph l -> ClassId +unsafeFind cid = findRepr cid . unionFind +{-# INLINE unsafeFind #-}