diff --git a/CHANGELOG.md b/CHANGELOG.md index 8308ff1..f3c05f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,12 @@ as a GraphViz DOT graph. The format is reminiscent of the output produced by the `egglog` library. +* The library now supports a new variant of `Rewrite`: computed rewrites + (spelled `:=>`). The right-hand side is computed by a Haskell function of + the `MatchInfo` (analysis data and e-nodes) of each matched pattern + variable; returning `Nothing` skips the match. + + ## 0.6.0.0 -- 2024-07-13 * Fix a soundness bug that would cause equality saturation to be broken when diff --git a/hegg.cabal b/hegg.cabal index 82e587d..76a07e7 100644 --- a/hegg.cabal +++ b/hegg.cabal @@ -117,7 +117,7 @@ test-suite hegg-test type: exitcode-stdio-1.0 hs-source-dirs: test main-is: Test.hs - other-modules: Invariants, Sym, Lambda, SimpleSym, + other-modules: Computed, Invariants, Sym, Lambda, SimpleSym, T1, T2, T3, T32, T45, T51 if flag(vizdot) other-modules: VizDot diff --git a/src/Data/Equality/Saturation.hs b/src/Data/Equality/Saturation.hs index 60e97fc..6040a19 100644 --- a/src/Data/Equality/Saturation.hs +++ b/src/Data/Equality/Saturation.hs @@ -45,6 +45,7 @@ module Data.Equality.Saturation ) where import qualified Data.IntMap.Strict as IM +import qualified Data.Map.Strict as M import Control.Monad @@ -168,7 +169,10 @@ runEqualitySaturation schd rewrites = runEqualitySaturation' 0 mempty where -- S -- Accumulate conditions while recursing through conditional rewrites go :: [RewriteCondition a l] -> Rewrite a l -> ([Match], IM.IntMap (Stat l schd), VarsState) go conds (rw' :| cond) = go (cond:conds) rw' - go conds (lhs := _) = do + go conds (lhs :=> _) = doPattern conds lhs + go conds (lhs := _) = doPattern conds lhs + + doPattern conds lhs = do let (lhs_query, varsState) = compileToQuery lhs case IM.lookup rw_id stats of @@ -225,11 +229,38 @@ runEqualitySaturation schd rewrites = runEqualitySaturation' 0 mempty where -- S _ <- merge eclass eclass' return () - -- | Represent a
pattern in the e-graph a pattern given substitions + (_ :=> f, Match subst eclass, vss) -> do + egr <- get + let matchCtx = buildMatchContext vss subst egr + case f matchCtx of + Just rhs -> do + eclass' <- reprExpr rhs + _ <- merge eclass eclass' + return () + Nothing -> + return () + + -- | Build the match context mapping variable names to their matched class info + buildMatchContext :: VarsState -> Subst -> G.EGraph a l -> M.Map String (MatchInfo a l) + buildMatchContext vss subst egr = + M.mapWithKey lookupInfo (varNames vss) + where + lookupInfo :: String -> Var -> MatchInfo a l + lookupInfo _name var = + let classId = findSubst var subst + canonId = G.find classId egr + eclass = egr ^. _class canonId + in MatchInfo (eclass ^. _data) (eclass ^. _nodes) + + -- | Represent a pattern in the e-graph given substitutions reprPat :: VarsState -> Subst -> l (Pattern l) -> EGraphM a l ClassId reprPat vss subst = add . Node <=< traverse \case VariablePattern v -> pure $ findSubst (findVarName vss v) subst NonVariablePattern p -> reprPat vss subst p + + -- | Represent an expression (Fix l) in the e-graph + reprExpr :: Fix l -> EGraphM a l ClassId + reprExpr (Fix e) = add . Node =<< traverse reprExpr e {-# INLINEABLE runEqualitySaturation #-} diff --git a/src/Data/Equality/Saturation/Rewrites.hs b/src/Data/Equality/Saturation/Rewrites.hs index df216a9..6f94bd3 100644 --- a/src/Data/Equality/Saturation/Rewrites.hs +++ b/src/Data/Equality/Saturation/Rewrites.hs @@ -7,11 +7,20 @@ Rewrite rules are applied to all represented expressions in an e-graph every iteration of equality saturation. -} -module Data.Equality.Saturation.Rewrites where +module Data.Equality.Saturation.Rewrites + ( Rewrite(..) + , RewriteCondition + , RewriteFun + , MatchInfo(..) 
+ ) where + +import Data.Map.Strict (Map) +import Data.Set (Set) import Data.Equality.Graph import Data.Equality.Matching import Data.Equality.Matching.Database +import Data.Equality.Utils (Fix) -- | A rewrite rule that might have conditions for being applied -- @@ -32,9 +41,49 @@ import Data.Equality.Matching.Database -- -- See the definition of @is_not_zero@ in the documentation for -- 'RewriteCondition' -data Rewrite anl lang = !(Pattern lang) := !(Pattern lang) -- ^ Trivial Rewrite - | !(Rewrite anl lang) :| !(RewriteCondition anl lang) -- ^ Conditional Rewrite +-- +-- === __Soundness__ +-- +-- Conditional rewrites (':|') and computed rewrites (':=>') must satisfy +-- a /monotonicity/ property to ensure confluence: +-- +-- Given expressions \(e_0\) and \(e_1\) matching the rule's LHS where +-- \(\mathit{analysis}(e_0) \sqsubseteq \mathit{analysis}(e_1)\) (with respect +-- to the analysis join semilattice), if the rewrite fires for \(e_0\), it +-- must also fire for \(e_1\). +-- +-- In other words, learning more information (via the analysis) should never +-- cause a rewrite to stop firing. +-- +data Rewrite anl lang + = !(Pattern lang) := !(Pattern lang) + -- ^ Trivial Rewrite + | !(Pattern lang) :=> (RewriteFun anl lang) + -- ^ A computed rewrite. The RHS is computed by a function that receives + -- a 'MatchInfo' for each pattern variable, providing access to the + -- analysis data and e-nodes in the matched e-classes. + -- + -- The function returns 'Just' with the computed expression to add to the + -- e-graph, or 'Nothing' to skip this match. 
+ -- + -- === __Example: Constant Folding__ + -- @ + -- foldAdd :: Rewrite () Lang + -- foldAdd = pat (x \`Add\` y) :=> \\ctx -> do + -- MatchInfo _ xNodes <- Map.lookup "x" ctx + -- MatchInfo _ yNodes <- Map.lookup "y" ctx + -- Node (Lit xn) <- findLit xNodes + -- Node (Lit yn) <- findLit yNodes + -- pure $ Fix $ Lit (xn + yn) + -- @ + -- + -- The 'MatchInfo' interface ensures computed rewrites can only inspect + -- the matched e-classes, not arbitrary e-graph structure. This makes it + -- easier to write confluent rewrites. + | !(Rewrite anl lang) :| !(RewriteCondition anl lang) + -- ^ Conditional Rewrite infix 3 := +infixl 3 :=> infixl 2 :| -- | A rewrite condition. With a substitution from bound variables in the @@ -51,7 +100,37 @@ infixl 2 :| -- @ type RewriteCondition anl lang = VarsState -> Subst -> EGraph anl lang -> Bool +-- | Information about a matched pattern variable, providing access to the +-- analysis data and nodes in the matched e-class. +-- +-- This is the restricted interface provided to 'RewriteFun', ensuring that +-- computed rewrites can only inspect matched classes rather than arbitrary +-- e-graph structure. +data MatchInfo anl lang = MatchInfo + { matchAnalysis :: anl + -- ^ The analysis data for the matched e-class + , matchNodes :: Set (ENode lang) + -- ^ The e-nodes in the matched e-class + } + +-- | A function to compute the RHS of a rewrite. +-- +-- The function receives a 'Map' from pattern variable names to 'MatchInfo', +-- containing the analysis data and nodes for each matched e-class. +-- +-- Return 'Just' with the computed expression to add it to the e-graph and +-- merge it with the matched e-class, or 'Nothing' to skip this match. +-- +-- __Soundness requirement__: The returned expression must be semantically +-- equivalent to the expression matched by the LHS pattern. 
+-- +-- __Confluence requirement__: The result should be deterministic — it should +-- depend only on the semantic content of the matched classes (analysis data, +-- node values), not on incidental details like iteration order over sets. +type RewriteFun anl lang = Map String (MatchInfo anl lang) -> Maybe (Fix lang) + instance (∀ a. Show a => Show (lang a)) => Show (Rewrite anl lang) where show (rw :| _) = show rw <> " :| " show (lhs := rhs) = show lhs <> " := " <> show rhs + show (lhs :=> _) = show lhs <> " :=> " diff --git a/test/Computed.hs b/test/Computed.hs new file mode 100644 index 0000000..07181d1 --- /dev/null +++ b/test/Computed.hs @@ -0,0 +1,48 @@ +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE OverloadedStrings #-} +module Computed where + +-- Tests computed rewrites + +import Prelude hiding (not) + +import Test.Tasty.HUnit +import Data.Equality.Graph.Nodes (ENode(..)) +import Data.Equality.Matching (Pattern, pat) +import Data.Equality.Extraction +import Data.Equality.Saturation.Rewrites +import Data.Equality.Saturation +import Data.Maybe (listToMaybe) +import qualified Data.Map.Strict as M +import qualified Data.Set as S + +data Lang a = Add a a + | Lit Int + deriving (Functor, Foldable, Traversable, Eq, Ord, Show) + +x, y :: Pattern Lang +x = "x" +y = "y" + +foldAddRule :: Rewrite () Lang +foldAddRule = pat (x `Add` y) :=> f + where + f :: RewriteFun () Lang + f ctx = do + xn <- isLit =<< M.lookup "x" ctx + yn <- isLit =<< M.lookup "y" ctx + return $ Fix $ Lit (xn + yn) + + isLit :: MatchInfo () Lang -> Maybe Int + isLit (MatchInfo _ nodes) = + listToMaybe [ n | Node (Lit n) <- S.toList nodes ] + +rules :: [Rewrite () Lang] +rules = [foldAddRule] + +main :: IO () +main = do + fst (equalitySaturation (Fix $ (Fix $ Lit 1) `Add` (Fix $ Lit 1)) rules depthCost) @?= Fix (Lit 2) + pure () + diff --git a/test/Test.hs b/test/Test.hs index cfffb87..91680ae 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -13,6 +13,7 @@ import 
Lambda import SimpleSym import T32 +import qualified Computed import qualified T1 import qualified T2 import qualified T3 @@ -30,6 +31,7 @@ tests =testGroup "Tests" , T45.testT45 , invariants , T51.testConditionalBan + , testCase "Computed" Computed.main , testCase "T1" (T1.main `catch` (\(e :: SomeException) -> assertFailure (show e))) , testCase "T2" (T2.main `catch` (\(e :: SomeException) -> assertFailure (show e))) , testCase "T3" (T3.main `catch` (\(e :: SomeException) -> assertFailure (show e)))