Better constant folding
This commit is contained in:
parent
a93439dfc7
commit
9e69317d7b
|
@ -1,5 +1,5 @@
|
|||
-- | Evaluate constant expressions.
|
||||
module EvalConstants where
|
||||
module EvalConstants (constantFold) where
|
||||
|
||||
import Control.Monad.Error
|
||||
import Control.Monad.Identity
|
||||
|
@ -13,17 +13,32 @@ import Numeric
|
|||
import qualified AST as A
|
||||
import Metadata
|
||||
import ParseState
|
||||
import Pass
|
||||
import Types
|
||||
|
||||
-- | Simplify an expression by constant folding, and also return whether it's a
|
||||
-- constant after that.
|
||||
constantFold :: PSM m => A.Expression -> m (A.Expression, Bool, String)
|
||||
constantFold e
|
||||
= do ps <- get
|
||||
let (e', msg) = case simplifyExpression ps e of
|
||||
Left err -> (e, err)
|
||||
Right val -> (val, "already folded")
|
||||
return (e', isConstant e', msg)
|
||||
|
||||
-- | Is an expression a constant literal?
|
||||
isConstant :: A.Expression -> Bool
|
||||
-- Array literals are only constant if all their components are.
|
||||
isConstant (A.ExprLiteral _ (A.Literal _ _ (A.ArrayLiteral _ es)))
|
||||
= and $ map isConstant es
|
||||
isConstant (A.ExprLiteral _ _) = True
|
||||
isConstant (A.True _) = True
|
||||
isConstant (A.False _) = True
|
||||
isConstant _ = False
|
||||
|
||||
-- | Attempt to simplify an expression as far as possible by precomputing
|
||||
-- constant bits.
|
||||
simplifyExpression :: ParseState -> A.Expression -> Either String A.Expression
|
||||
-- Non-array literals are "simple" already.
|
||||
simplifyExpression _ e@(A.ExprLiteral _ (A.Literal _ _ (A.ArrayLiteral _ _)))
|
||||
= Left "array literal"
|
||||
simplifyExpression _ e@(A.ExprLiteral _ _) = Right e
|
||||
simplifyExpression _ e@(A.True _) = Right e
|
||||
simplifyExpression _ e@(A.False _) = Right e
|
||||
simplifyExpression ps e
|
||||
= case runIdentity (evalStateT (runErrorT (evalExpression e)) ps) of
|
||||
Left err -> Left err
|
||||
|
|
|
@ -512,16 +512,16 @@ scopeOutRep (A.For m n b c) = scopeOut n
|
|||
scopeInSpec :: A.Specification -> OccParser A.Specification
|
||||
scopeInSpec (A.Specification m n st)
|
||||
= do ps <- getState
|
||||
let (st', isConst) = case st of
|
||||
(A.IsExpr m A.ValAbbrev t e) ->
|
||||
case simplifyExpression ps e of
|
||||
Left _ -> (st, False)
|
||||
Right e' -> (A.IsExpr m A.ValAbbrev t e', True)
|
||||
_ -> (st, False)
|
||||
(st', isConst) <- case st of
|
||||
(A.IsExpr m A.ValAbbrev t e) ->
|
||||
do (e', isConst, msg) <- constantFold e
|
||||
if isConst
|
||||
then return (A.IsExpr m A.ValAbbrev t e', True)
|
||||
else return (st, False)
|
||||
_ -> return (st, False)
|
||||
n' <- scopeIn n st' (abbrevModeOfSpec st')
|
||||
if isConst
|
||||
then updateState (\ps -> ps { psConstants = (A.nameName n', case st' of A.IsExpr _ _ _ e' -> e') : psConstants ps })
|
||||
else return ()
|
||||
when isConst $
|
||||
updateState (\ps -> ps { psConstants = (A.nameName n', case st' of A.IsExpr _ _ _ e' -> e') : psConstants ps })
|
||||
return $ A.Specification m n' st'
|
||||
|
||||
scopeOutSpec :: A.Specification -> OccParser ()
|
||||
|
@ -845,10 +845,10 @@ booleanExpr = expressionOfType A.Bool <?> "boolean expression"
|
|||
constExprOfType :: A.Type -> OccParser A.Expression
|
||||
constExprOfType wantT
|
||||
= do e <- expressionOfType wantT
|
||||
ps <- getState
|
||||
case simplifyExpression ps e of
|
||||
Left err -> fail $ "expected constant expression (" ++ err ++ ")"
|
||||
Right e' -> return e'
|
||||
(e', isConst, msg) <- constantFold e
|
||||
when (not isConst) $
|
||||
fail $ "expression is not constant (" ++ msg ++ ")"
|
||||
return e'
|
||||
|
||||
constIntExpr = constExprOfType A.Int <?> "constant integer expression"
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user