Rework how constant evaluation is done

This commit is contained in:
Adam Sampson 2007-04-26 16:05:04 +00:00
parent 9e69317d7b
commit 939205670b
4 changed files with 60 additions and 34 deletions

View File

@ -1,5 +1,5 @@
-- | Evaluate constant expressions.
module EvalConstants (constantFold) where
module EvalConstants (constantFold, isConstantName) where
import Control.Monad.Error
import Control.Monad.Identity
@ -11,6 +11,7 @@ import Data.Maybe
import Numeric
import qualified AST as A
import Errors
import Metadata
import ParseState
import Pass
@ -36,21 +37,43 @@ isConstant (A.True _) = True
isConstant (A.False _) = True
isConstant _ = False
-- | Is a name defined as a constant expression? If so, return its definition.
getConstantName :: (PSM m, Die m) => A.Name -> m (Maybe A.Expression)
getConstantName n
= do st <- specTypeOfName n
case st of
A.IsExpr _ A.ValAbbrev _ e ->
if isConstant e then return $ Just e
else return Nothing
_ -> return Nothing
-- | Is a name defined as a constant expression?
isConstantName :: (PSM m, Die m) => A.Name -> m Bool
isConstantName n
= do me <- getConstantName n
return $ case me of
Just _ -> True
Nothing -> False
-- | Attempt to simplify an expression as far as possible by precomputing
-- constant bits.
simplifyExpression :: ParseState -> A.Expression -> Either String A.Expression
simplifyExpression ps e
= case runIdentity (evalStateT (runErrorT (evalExpression e)) ps) of
Left err -> Left err
Right val -> Right $ renderValue (metaOfExpression e) val
Right val -> Right $ snd $ renderValue (metaOfExpression e) val
--{{{ expression evaluator
type EvalM a = ErrorT String (StateT ParseState Identity) a
type EvalM = ErrorT String (StateT ParseState Identity)
instance Die EvalM where
die = throwError
-- | Occam values of various types.
data OccValue =
OccBool Bool
| OccInt Int32
| OccArray [OccValue]
deriving (Show, Eq, Typeable, Data)
-- | Turn the result of one of the read* functions into an OccValue,
@ -62,6 +85,8 @@ fromRead _ _ = throwError "cannot parse literal"
evalLiteral :: A.Literal -> EvalM OccValue
evalLiteral (A.Literal _ A.Int (A.IntLiteral _ s)) = fromRead OccInt $ readDec s
evalLiteral (A.Literal _ A.Int (A.HexLiteral _ s)) = fromRead OccInt $ readHex s
evalLiteral (A.Literal _ _ (A.ArrayLiteral _ es))
= liftM OccArray (mapM evalExpression es)
evalLiteral _ = throwError "bad literal"
evalExpression :: A.Expression -> EvalM OccValue
@ -76,8 +101,8 @@ evalExpression (A.MostPos _ A.Int) = return $ OccInt maxBound
evalExpression (A.MostNeg _ A.Int) = return $ OccInt minBound
evalExpression (A.ExprLiteral _ l) = evalLiteral l
evalExpression (A.ExprVariable _ (A.Variable _ n))
= do ps <- get
case lookup (A.nameName n) (psConstants ps) of
= do me <- getConstantName n
case me of
Just e -> evalExpression e
Nothing -> throwError $ "non-constant variable " ++ show n ++ " used"
evalExpression (A.True _) = return $ OccBool True
@ -85,11 +110,16 @@ evalExpression (A.False _) = return $ OccBool False
evalExpression _ = throwError "bad expression"
evalMonadic :: A.MonadicOp -> OccValue -> EvalM OccValue
-- This, oddly, is probably the most important rule here: "-4" isn't a literal
-- in occam, it's an operator applied to a literal.
evalMonadic A.MonadicSubtr (OccInt i) = return $ OccInt (0 - i)
evalMonadic A.MonadicBitNot (OccInt i) = return $ OccInt (complement i)
evalMonadic A.MonadicNot (OccBool b) = return $ OccBool (not b)
evalMonadic _ _ = throwError "bad monadic op"
int32ToInt :: Int32 -> Int
int32ToInt n = fromInteger (toInteger n)
evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
-- FIXME These should check for overflow.
evalDyadic A.Add (OccInt a) (OccInt b) = return $ OccInt (a + b)
@ -104,6 +134,8 @@ evalDyadic A.Times (OccInt a) (OccInt b) = return $ OccInt (a * b)
evalDyadic A.BitAnd (OccInt a) (OccInt b) = return $ OccInt (a .&. b)
evalDyadic A.BitOr (OccInt a) (OccInt b) = return $ OccInt (a .|. b)
evalDyadic A.BitXor (OccInt a) (OccInt b) = return $ OccInt (a `xor` b)
evalDyadic A.LeftShift (OccInt a) (OccInt b) = return $ OccInt (shiftL a (int32ToInt b))
evalDyadic A.RightShift (OccInt a) (OccInt b) = return $ OccInt (shiftR a (int32ToInt b))
evalDyadic A.And (OccBool a) (OccBool b) = return $ OccBool (a && b)
evalDyadic A.Or (OccBool a) (OccBool b) = return $ OccBool (a || b)
evalDyadic A.Eq a b = return $ OccBool (a == b)
@ -118,8 +150,17 @@ evalDyadic A.After (OccInt a) (OccInt b) = return $ OccBool ((a - b) > 0)
evalDyadic _ _ _ = throwError "bad dyadic op"
-- | Convert a value back into a literal.
renderValue :: Meta -> OccValue -> A.Expression
renderValue m (OccInt i) = A.ExprLiteral m (A.Literal m A.Int (A.IntLiteral m $ show i))
renderValue m (OccBool True) = A.True m
renderValue m (OccBool False) = A.False m
renderValue :: Meta -> OccValue -> (A.Type, A.Expression)
renderValue m (OccBool True) = (A.Bool, A.True m)
renderValue m (OccBool False) = (A.Bool, A.False m)
renderValue m v = (t, A.ExprLiteral m (A.Literal m t lr))
where (t, lr) = renderLiteral m v
renderLiteral :: Meta -> OccValue -> (A.Type, A.LiteralRepr)
renderLiteral m (OccInt i) = (A.Int, A.IntLiteral m $ show i)
renderLiteral m (OccArray vs)
= (t, A.ArrayLiteral m es)
where
t = makeArrayType (A.Dimension $ makeConstant m (length vs)) (head ts)
(ts, es) = unzip $ map (renderValue m) vs
--}}}

View File

@ -508,21 +508,10 @@ scopeInRep (A.For m n b c)
scopeOutRep :: A.Replicator -> OccParser ()
scopeOutRep (A.For m n b c) = scopeOut n
-- This one's more complicated because we need to check if we're introducing a constant.
scopeInSpec :: A.Specification -> OccParser A.Specification
scopeInSpec (A.Specification m n st)
= do ps <- getState
(st', isConst) <- case st of
(A.IsExpr m A.ValAbbrev t e) ->
do (e', isConst, msg) <- constantFold e
if isConst
then return (A.IsExpr m A.ValAbbrev t e', True)
else return (st, False)
_ -> return (st, False)
n' <- scopeIn n st' (abbrevModeOfSpec st')
when isConst $
updateState (\ps -> ps { psConstants = (A.nameName n', case st' of A.IsExpr _ _ _ e' -> e') : psConstants ps })
return $ A.Specification m n' st'
= do n' <- scopeIn n st (abbrevModeOfSpec st)
return $ A.Specification m n' st
scopeOutSpec :: A.Specification -> OccParser ()
scopeOutSpec (A.Specification _ n _) = scopeOut n
@ -1100,7 +1089,10 @@ valIsAbbrev
= do m <- md
(n, t, e) <- do { n <- tryXVX sVAL newVariableName sIS; e <- expression; sColon; eol; t <- typeOfExpression e; return (n, t, e) }
<|> do { (s, n) <- tryXVVX sVAL specifier newVariableName sIS; e <- expressionOfType s; sColon; eol; return (n, s, e) }
return $ A.Specification m n $ A.IsExpr m A.ValAbbrev t e
-- Do constant folding early, so that we can use names defined this
-- way as constants elsewhere.
(e', _, _) <- constantFold e
return $ A.Specification m n $ A.IsExpr m A.ValAbbrev t e'
<?> "VAL IS abbreviation"
isAbbrev :: OccParser A.Name -> OccParser A.Variable -> OccParser A.Specification

View File

@ -26,7 +26,6 @@ data ParseState = ParseState {
psNames :: [(String, A.NameDef)],
psNameCounter :: Int,
psTypeContext :: [Maybe A.Type],
psConstants :: [(String, A.Expression)],
psLoadedFiles :: [String],
-- Set by passes
@ -55,7 +54,6 @@ emptyState = ParseState {
psNames = [],
psNameCounter = 0,
psTypeContext = [],
psConstants = [],
psLoadedFiles = [],
psNonceCounter = 0,
@ -157,11 +155,3 @@ makeNonceVariable :: PSM m => String -> Meta -> A.Type -> A.NameType -> A.Abbrev
makeNonceVariable s m t nt am
= defineNonce m s (A.Declaration m t) nt am
-- | Is a name on the list of constants?
isConstantName :: PSM m => A.Name -> m Bool
isConstantName n
= do ps <- get
case lookup (A.nameName n) (psConstants ps) of
Just _ -> return True
Nothing -> return False

View File

@ -7,10 +7,11 @@ import qualified Data.Map as Map
import Data.Maybe
import qualified AST as A
import EvalConstants
import Metadata
import ParseState
import Types
import Pass
import Types
unnest :: A.Process -> PassM A.Process
unnest = runPasses passes
@ -123,6 +124,7 @@ removeFreeNames = doGeneric `extM` doSpecification `extM` doProcess
A.Abbrev -> A.ActualVariable am t (A.Variable m n)
_ -> A.ActualExpression t (A.ExprVariable m (A.Variable m n))
| (am, n, t) <- zip3 ams freeNames types]
progress $ show n ++ " has new args " ++ show newAs
case newAs of
[] -> return ()
_ -> modify $ (\ps -> ps { psAdditionalArgs = (A.nameName n, newAs) : psAdditionalArgs ps })
@ -133,6 +135,7 @@ removeFreeNames = doGeneric `extM` doSpecification `extM` doProcess
doProcess :: A.Process -> PassM A.Process
doProcess p@(A.ProcCall m n as)
= do st <- get
progress $ "adding args to call of " ++ show n
case lookup (A.nameName n) (psAdditionalArgs st) of
Just add -> doGeneric $ A.ProcCall m n (as ++ add)
Nothing -> doGeneric p