Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
as a GraphViz DOT graph. The format is reminiscent of the output produced by
the `egglog` library.

* The library now supports a new variant of `Rewrite`: computed rewrites
(spelled `:=>`). This allows the user to rewrite a given pattern
to a right-hand-side computed by a Haskell function of the matching
substitution and `EGraph`.


## 0.6.0.0 -- 2024-07-13

* Fix a soundness bug that would cause equality saturation to be broken when
Expand Down
2 changes: 1 addition & 1 deletion hegg.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ test-suite hegg-test
type: exitcode-stdio-1.0
hs-source-dirs: test
main-is: Test.hs
other-modules: Invariants, Sym, Lambda, SimpleSym,
other-modules: Computed, Invariants, Sym, Lambda, SimpleSym,
T1, T2, T3, T32, T45, T51
if flag(vizdot)
other-modules: VizDot
Expand Down
35 changes: 33 additions & 2 deletions src/Data/Equality/Saturation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ module Data.Equality.Saturation
) where

import qualified Data.IntMap.Strict as IM
import qualified Data.Map.Strict as M

import Control.Monad

Expand Down Expand Up @@ -168,7 +169,10 @@ runEqualitySaturation schd rewrites = runEqualitySaturation' 0 mempty where -- S
-- Accumulate conditions while recursing through conditional rewrites
go :: [RewriteCondition a l] -> Rewrite a l -> ([Match], IM.IntMap (Stat l schd), VarsState)
go conds (rw' :| cond) = go (cond:conds) rw'
go conds (lhs := _) = do
go conds (lhs :=> _) = doPattern conds lhs
go conds (lhs := _) = doPattern conds lhs

doPattern conds lhs = do
let (lhs_query, varsState) = compileToQuery lhs

case IM.lookup rw_id stats of
Expand Down Expand Up @@ -225,11 +229,38 @@ runEqualitySaturation schd rewrites = runEqualitySaturation' 0 mempty where -- S
_ <- merge eclass eclass'
return ()

-- | Represent a pattern in the e-graph a pattern given substitions
(_ :=> f, Match subst eclass, vss) -> do
egr <- get
let matchCtx = buildMatchContext vss subst egr
case f matchCtx of
Just rhs -> do
eclass' <- reprExpr rhs
_ <- merge eclass eclass'
return ()
Nothing ->
return ()

-- | Build the match context mapping variable names to their matched class info
buildMatchContext :: VarsState -> Subst -> G.EGraph a l -> M.Map String (MatchInfo a l)
buildMatchContext vss subst egr =
M.mapWithKey lookupInfo (varNames vss)
where
lookupInfo :: String -> Var -> MatchInfo a l
lookupInfo _name var =
let classId = findSubst var subst
canonId = G.find classId egr
eclass = egr ^. _class canonId
in MatchInfo (eclass ^. _data) (eclass ^. _nodes)

-- | Represent a pattern in the e-graph given substitutions
reprPat :: VarsState -> Subst -> l (Pattern l) -> EGraphM a l ClassId
reprPat vss subst = add . Node <=< traverse \case
VariablePattern v -> pure $
findSubst (findVarName vss v) subst
NonVariablePattern p -> reprPat vss subst p

-- | Represent an expression (Fix l) in the e-graph
reprExpr :: Fix l -> EGraphM a l ClassId
reprExpr (Fix e) = add . Node =<< traverse reprExpr e
{-# INLINEABLE runEqualitySaturation #-}

85 changes: 82 additions & 3 deletions src/Data/Equality/Saturation/Rewrites.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@ Rewrite rules are applied to all represented expressions in an e-graph every
iteration of equality saturation.

-}
module Data.Equality.Saturation.Rewrites where
module Data.Equality.Saturation.Rewrites
( Rewrite(..)
, RewriteCondition
, RewriteFun
, MatchInfo(..)
) where

import Data.Map.Strict (Map)
import Data.Set (Set)

import Data.Equality.Graph
import Data.Equality.Matching
import Data.Equality.Matching.Database
import Data.Equality.Utils (Fix)

-- | A rewrite rule that might have conditions for being applied
--
Expand All @@ -32,9 +41,49 @@ import Data.Equality.Matching.Database
--
-- See the definition of @is_not_zero@ in the documentation for
-- 'RewriteCondition'
data Rewrite anl lang = !(Pattern lang) := !(Pattern lang) -- ^ Trivial Rewrite
| !(Rewrite anl lang) :| !(RewriteCondition anl lang) -- ^ Conditional Rewrite
--
-- === __Soundness__
--
-- Conditional rewrites (':|') and computed rewrites (':=>') must satisfy
-- a /monotonicity/ property to ensure confluence:
--
-- Given expressions \(e_0\) and \(e_1\) matching the rule's LHS where
-- \(\mathit{analysis}(e_0) \sqsubseteq \mathit{analysis}(e_1)\) (with respect
-- to the analysis join semilattice), if the rewrite fires for \(e_0\), it
-- must also fire for \(e_1\).
--
-- In other words, learning more information (via the analysis) should never
-- cause a rewrite to stop firing.
--
data Rewrite anl lang
= !(Pattern lang) := !(Pattern lang)
-- ^ Trivial Rewrite
| !(Pattern lang) :=> (RewriteFun anl lang)
-- ^ A computed rewrite. The RHS is computed by a function that receives
-- a 'MatchInfo' for each pattern variable, providing access to the
-- analysis data and e-nodes in the matched e-classes.
--
-- The function returns 'Just' with the computed expression to add to the
-- e-graph, or 'Nothing' to skip this match.
--
-- === __Example: Constant Folding__
-- @
-- foldAdd :: Rewrite () Lang
-- foldAdd = pat (x \`Add\` y) :=> \\ctx -> do
-- MatchInfo _ xNodes <- Map.lookup "x" ctx
-- MatchInfo _ yNodes <- Map.lookup "y" ctx
-- Node (Lit xn) <- findLit xNodes
-- Node (Lit yn) <- findLit yNodes
-- pure $ Fix $ Lit (xn + yn)
-- @
--
-- The 'MatchInfo' interface ensures computed rewrites can only inspect
-- the matched e-classes, not arbitrary e-graph structure. This makes it
-- easier to write confluent rewrites.
| !(Rewrite anl lang) :| !(RewriteCondition anl lang)
-- ^ Conditional Rewrite
infix 3 :=
infixl 3 :=>
infixl 2 :|

-- | A rewrite condition. With a substitution from bound variables in the
Expand All @@ -51,7 +100,37 @@ infixl 2 :|
-- @
type RewriteCondition anl lang = VarsState -> Subst -> EGraph anl lang -> Bool

-- | Information about a matched pattern variable, providing access to the
-- analysis data and nodes in the matched e-class.
--
-- This is the restricted interface provided to 'RewriteFun', ensuring that
-- computed rewrites can only inspect matched classes rather than arbitrary
-- e-graph structure.
data MatchInfo anl lang = MatchInfo
{ matchAnalysis :: anl
-- ^ The analysis data for the matched e-class
, matchNodes :: Set (ENode lang)
-- ^ The e-nodes in the matched e-class
}

-- | A function to compute the RHS of a rewrite.
--
-- The function receives a 'Map' from pattern variable names to 'MatchInfo',
-- containing the analysis data and nodes for each matched e-class.
--
-- Return 'Just' with the computed expression to add it to the e-graph and
-- merge it with the matched e-class, or 'Nothing' to skip this match.
--
-- __Soundness requirement__: The returned expression must be semantically
-- equivalent to the expression matched by the LHS pattern.
--
-- __Confluence requirement__: The result should be deterministic — it should
-- depend only on the semantic content of the matched classes (analysis data,
-- node values), not on incidental details like iteration order over sets.
type RewriteFun anl lang = Map String (MatchInfo anl lang) -> Maybe (Fix lang)


instance (∀ a. Show a => Show (lang a)) => Show (Rewrite anl lang) where
show (rw :| _) = show rw <> " :| <cond>"
show (lhs := rhs) = show lhs <> " := " <> show rhs
show (lhs :=> _) = show lhs <> " :=> <fun>"
48 changes: 48 additions & 0 deletions test/Computed.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE OverloadedStrings #-}
module Computed where

-- Tests computed rewrites

import Prelude hiding (not)

import Test.Tasty.HUnit
import Data.Equality.Graph.Nodes (ENode(..))
import Data.Equality.Matching (Pattern, pat)
import Data.Equality.Extraction
import Data.Equality.Saturation.Rewrites
import Data.Equality.Saturation
import Data.Maybe (listToMaybe)
import qualified Data.Map.Strict as M
import qualified Data.Set as S

data Lang a = Add a a
| Lit Int
deriving (Functor, Foldable, Traversable, Eq, Ord, Show)

x, y :: Pattern Lang
x = "x"
y = "y"

foldAddRule :: Rewrite () Lang
foldAddRule = pat (x `Add` y) :=> f
where
f :: RewriteFun () Lang
f ctx = do
xn <- isLit =<< M.lookup "x" ctx
yn <- isLit =<< M.lookup "y" ctx
return $ Fix $ Lit (xn + yn)

isLit :: MatchInfo () Lang -> Maybe Int
isLit (MatchInfo _ nodes) =
listToMaybe [ n | Node (Lit n) <- S.toList nodes ]

rules :: [Rewrite () Lang]
rules = [foldAddRule]

main :: IO ()
main = do
fst (equalitySaturation (Fix $ (Fix $ Lit 1) `Add` (Fix $ Lit 1)) rules depthCost) @?= Fix (Lit 2)
pure ()

2 changes: 2 additions & 0 deletions test/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Lambda
import SimpleSym
import T32

import qualified Computed
import qualified T1
import qualified T2
import qualified T3
Expand All @@ -30,6 +31,7 @@ tests =testGroup "Tests"
, T45.testT45
, invariants
, T51.testConditionalBan
, testCase "Computed" Computed.main
, testCase "T1" (T1.main `catch` (\(e :: SomeException) -> assertFailure (show e)))
, testCase "T2" (T2.main `catch` (\(e :: SomeException) -> assertFailure (show e)))
, testCase "T3" (T3.main `catch` (\(e :: SomeException) -> assertFailure (show e)))
Expand Down
Loading