Rework how constant evaluation is done
This commit is contained in:
parent
9e69317d7b
commit
939205670b
|
@ -1,5 +1,5 @@
|
|||
-- | Evaluate constant expressions.
|
||||
module EvalConstants (constantFold) where
|
||||
module EvalConstants (constantFold, isConstantName) where
|
||||
|
||||
import Control.Monad.Error
|
||||
import Control.Monad.Identity
|
||||
|
@ -11,6 +11,7 @@ import Data.Maybe
|
|||
import Numeric
|
||||
|
||||
import qualified AST as A
|
||||
import Errors
|
||||
import Metadata
|
||||
import ParseState
|
||||
import Pass
|
||||
|
@ -36,21 +37,43 @@ isConstant (A.True _) = True
|
|||
isConstant (A.False _) = True
|
||||
isConstant _ = False
|
||||
|
||||
-- | Is a name defined as a constant expression? If so, return its definition.
|
||||
getConstantName :: (PSM m, Die m) => A.Name -> m (Maybe A.Expression)
|
||||
getConstantName n
|
||||
= do st <- specTypeOfName n
|
||||
case st of
|
||||
A.IsExpr _ A.ValAbbrev _ e ->
|
||||
if isConstant e then return $ Just e
|
||||
else return Nothing
|
||||
_ -> return Nothing
|
||||
|
||||
-- | Is a name defined as a constant expression?
|
||||
isConstantName :: (PSM m, Die m) => A.Name -> m Bool
|
||||
isConstantName n
|
||||
= do me <- getConstantName n
|
||||
return $ case me of
|
||||
Just _ -> True
|
||||
Nothing -> False
|
||||
|
||||
-- | Attempt to simplify an expression as far as possible by precomputing
|
||||
-- constant bits.
|
||||
simplifyExpression :: ParseState -> A.Expression -> Either String A.Expression
|
||||
simplifyExpression ps e
|
||||
= case runIdentity (evalStateT (runErrorT (evalExpression e)) ps) of
|
||||
Left err -> Left err
|
||||
Right val -> Right $ renderValue (metaOfExpression e) val
|
||||
Right val -> Right $ snd $ renderValue (metaOfExpression e) val
|
||||
|
||||
--{{{ expression evaluator
|
||||
type EvalM a = ErrorT String (StateT ParseState Identity) a
|
||||
type EvalM = ErrorT String (StateT ParseState Identity)
|
||||
|
||||
instance Die EvalM where
|
||||
die = throwError
|
||||
|
||||
-- | Occam values of various types.
|
||||
data OccValue =
|
||||
OccBool Bool
|
||||
| OccInt Int32
|
||||
| OccArray [OccValue]
|
||||
deriving (Show, Eq, Typeable, Data)
|
||||
|
||||
-- | Turn the result of one of the read* functions into an OccValue,
|
||||
|
@ -62,6 +85,8 @@ fromRead _ _ = throwError "cannot parse literal"
|
|||
evalLiteral :: A.Literal -> EvalM OccValue
|
||||
evalLiteral (A.Literal _ A.Int (A.IntLiteral _ s)) = fromRead OccInt $ readDec s
|
||||
evalLiteral (A.Literal _ A.Int (A.HexLiteral _ s)) = fromRead OccInt $ readHex s
|
||||
evalLiteral (A.Literal _ _ (A.ArrayLiteral _ es))
|
||||
= liftM OccArray (mapM evalExpression es)
|
||||
evalLiteral _ = throwError "bad literal"
|
||||
|
||||
evalExpression :: A.Expression -> EvalM OccValue
|
||||
|
@ -76,8 +101,8 @@ evalExpression (A.MostPos _ A.Int) = return $ OccInt maxBound
|
|||
evalExpression (A.MostNeg _ A.Int) = return $ OccInt minBound
|
||||
evalExpression (A.ExprLiteral _ l) = evalLiteral l
|
||||
evalExpression (A.ExprVariable _ (A.Variable _ n))
|
||||
= do ps <- get
|
||||
case lookup (A.nameName n) (psConstants ps) of
|
||||
= do me <- getConstantName n
|
||||
case me of
|
||||
Just e -> evalExpression e
|
||||
Nothing -> throwError $ "non-constant variable " ++ show n ++ " used"
|
||||
evalExpression (A.True _) = return $ OccBool True
|
||||
|
@ -85,11 +110,16 @@ evalExpression (A.False _) = return $ OccBool False
|
|||
evalExpression _ = throwError "bad expression"
|
||||
|
||||
evalMonadic :: A.MonadicOp -> OccValue -> EvalM OccValue
|
||||
-- This, oddly, is probably the most important rule here: "-4" isn't a literal
|
||||
-- in occam, it's an operator applied to a literal.
|
||||
evalMonadic A.MonadicSubtr (OccInt i) = return $ OccInt (0 - i)
|
||||
evalMonadic A.MonadicBitNot (OccInt i) = return $ OccInt (complement i)
|
||||
evalMonadic A.MonadicNot (OccBool b) = return $ OccBool (not b)
|
||||
evalMonadic _ _ = throwError "bad monadic op"
|
||||
|
||||
int32ToInt :: Int32 -> Int
|
||||
int32ToInt n = fromInteger (toInteger n)
|
||||
|
||||
evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
|
||||
-- FIXME These should check for overflow.
|
||||
evalDyadic A.Add (OccInt a) (OccInt b) = return $ OccInt (a + b)
|
||||
|
@ -104,6 +134,8 @@ evalDyadic A.Times (OccInt a) (OccInt b) = return $ OccInt (a * b)
|
|||
evalDyadic A.BitAnd (OccInt a) (OccInt b) = return $ OccInt (a .&. b)
|
||||
evalDyadic A.BitOr (OccInt a) (OccInt b) = return $ OccInt (a .|. b)
|
||||
evalDyadic A.BitXor (OccInt a) (OccInt b) = return $ OccInt (a `xor` b)
|
||||
evalDyadic A.LeftShift (OccInt a) (OccInt b) = return $ OccInt (shiftL a (int32ToInt b))
|
||||
evalDyadic A.RightShift (OccInt a) (OccInt b) = return $ OccInt (shiftR a (int32ToInt b))
|
||||
evalDyadic A.And (OccBool a) (OccBool b) = return $ OccBool (a && b)
|
||||
evalDyadic A.Or (OccBool a) (OccBool b) = return $ OccBool (a || b)
|
||||
evalDyadic A.Eq a b = return $ OccBool (a == b)
|
||||
|
@ -118,8 +150,17 @@ evalDyadic A.After (OccInt a) (OccInt b) = return $ OccBool ((a - b) > 0)
|
|||
evalDyadic _ _ _ = throwError "bad dyadic op"
|
||||
|
||||
-- | Convert a value back into a literal.
|
||||
renderValue :: Meta -> OccValue -> A.Expression
|
||||
renderValue m (OccInt i) = A.ExprLiteral m (A.Literal m A.Int (A.IntLiteral m $ show i))
|
||||
renderValue m (OccBool True) = A.True m
|
||||
renderValue m (OccBool False) = A.False m
|
||||
renderValue :: Meta -> OccValue -> (A.Type, A.Expression)
|
||||
renderValue m (OccBool True) = (A.Bool, A.True m)
|
||||
renderValue m (OccBool False) = (A.Bool, A.False m)
|
||||
renderValue m v = (t, A.ExprLiteral m (A.Literal m t lr))
|
||||
where (t, lr) = renderLiteral m v
|
||||
|
||||
renderLiteral :: Meta -> OccValue -> (A.Type, A.LiteralRepr)
|
||||
renderLiteral m (OccInt i) = (A.Int, A.IntLiteral m $ show i)
|
||||
renderLiteral m (OccArray vs)
|
||||
= (t, A.ArrayLiteral m es)
|
||||
where
|
||||
t = makeArrayType (A.Dimension $ makeConstant m (length vs)) (head ts)
|
||||
(ts, es) = unzip $ map (renderValue m) vs
|
||||
--}}}
|
||||
|
|
|
@ -508,21 +508,10 @@ scopeInRep (A.For m n b c)
|
|||
scopeOutRep :: A.Replicator -> OccParser ()
|
||||
scopeOutRep (A.For m n b c) = scopeOut n
|
||||
|
||||
-- This one's more complicated because we need to check if we're introducing a constant.
|
||||
scopeInSpec :: A.Specification -> OccParser A.Specification
|
||||
scopeInSpec (A.Specification m n st)
|
||||
= do ps <- getState
|
||||
(st', isConst) <- case st of
|
||||
(A.IsExpr m A.ValAbbrev t e) ->
|
||||
do (e', isConst, msg) <- constantFold e
|
||||
if isConst
|
||||
then return (A.IsExpr m A.ValAbbrev t e', True)
|
||||
else return (st, False)
|
||||
_ -> return (st, False)
|
||||
n' <- scopeIn n st' (abbrevModeOfSpec st')
|
||||
when isConst $
|
||||
updateState (\ps -> ps { psConstants = (A.nameName n', case st' of A.IsExpr _ _ _ e' -> e') : psConstants ps })
|
||||
return $ A.Specification m n' st'
|
||||
= do n' <- scopeIn n st (abbrevModeOfSpec st)
|
||||
return $ A.Specification m n' st
|
||||
|
||||
scopeOutSpec :: A.Specification -> OccParser ()
|
||||
scopeOutSpec (A.Specification _ n _) = scopeOut n
|
||||
|
@ -1100,7 +1089,10 @@ valIsAbbrev
|
|||
= do m <- md
|
||||
(n, t, e) <- do { n <- tryXVX sVAL newVariableName sIS; e <- expression; sColon; eol; t <- typeOfExpression e; return (n, t, e) }
|
||||
<|> do { (s, n) <- tryXVVX sVAL specifier newVariableName sIS; e <- expressionOfType s; sColon; eol; return (n, s, e) }
|
||||
return $ A.Specification m n $ A.IsExpr m A.ValAbbrev t e
|
||||
-- Do constant folding early, so that we can use names defined this
|
||||
-- way as constants elsewhere.
|
||||
(e', _, _) <- constantFold e
|
||||
return $ A.Specification m n $ A.IsExpr m A.ValAbbrev t e'
|
||||
<?> "VAL IS abbreviation"
|
||||
|
||||
isAbbrev :: OccParser A.Name -> OccParser A.Variable -> OccParser A.Specification
|
||||
|
|
|
@ -26,7 +26,6 @@ data ParseState = ParseState {
|
|||
psNames :: [(String, A.NameDef)],
|
||||
psNameCounter :: Int,
|
||||
psTypeContext :: [Maybe A.Type],
|
||||
psConstants :: [(String, A.Expression)],
|
||||
psLoadedFiles :: [String],
|
||||
|
||||
-- Set by passes
|
||||
|
@ -55,7 +54,6 @@ emptyState = ParseState {
|
|||
psNames = [],
|
||||
psNameCounter = 0,
|
||||
psTypeContext = [],
|
||||
psConstants = [],
|
||||
psLoadedFiles = [],
|
||||
|
||||
psNonceCounter = 0,
|
||||
|
@ -157,11 +155,3 @@ makeNonceVariable :: PSM m => String -> Meta -> A.Type -> A.NameType -> A.Abbrev
|
|||
makeNonceVariable s m t nt am
|
||||
= defineNonce m s (A.Declaration m t) nt am
|
||||
|
||||
-- | Is a name on the list of constants?
|
||||
isConstantName :: PSM m => A.Name -> m Bool
|
||||
isConstantName n
|
||||
= do ps <- get
|
||||
case lookup (A.nameName n) (psConstants ps) of
|
||||
Just _ -> return True
|
||||
Nothing -> return False
|
||||
|
||||
|
|
|
@ -7,10 +7,11 @@ import qualified Data.Map as Map
|
|||
import Data.Maybe
|
||||
|
||||
import qualified AST as A
|
||||
import EvalConstants
|
||||
import Metadata
|
||||
import ParseState
|
||||
import Types
|
||||
import Pass
|
||||
import Types
|
||||
|
||||
unnest :: A.Process -> PassM A.Process
|
||||
unnest = runPasses passes
|
||||
|
@ -123,6 +124,7 @@ removeFreeNames = doGeneric `extM` doSpecification `extM` doProcess
|
|||
A.Abbrev -> A.ActualVariable am t (A.Variable m n)
|
||||
_ -> A.ActualExpression t (A.ExprVariable m (A.Variable m n))
|
||||
| (am, n, t) <- zip3 ams freeNames types]
|
||||
progress $ show n ++ " has new args " ++ show newAs
|
||||
case newAs of
|
||||
[] -> return ()
|
||||
_ -> modify $ (\ps -> ps { psAdditionalArgs = (A.nameName n, newAs) : psAdditionalArgs ps })
|
||||
|
@ -133,6 +135,7 @@ removeFreeNames = doGeneric `extM` doSpecification `extM` doProcess
|
|||
doProcess :: A.Process -> PassM A.Process
|
||||
doProcess p@(A.ProcCall m n as)
|
||||
= do st <- get
|
||||
progress $ "adding args to call of " ++ show n
|
||||
case lookup (A.nameName n) (psAdditionalArgs st) of
|
||||
Just add -> doGeneric $ A.ProcCall m n (as ++ add)
|
||||
Nothing -> doGeneric p
|
||||
|
|
Loading…
Reference in New Issue
Block a user