194 lines
7.6 KiB
Haskell
194 lines
7.6 KiB
Haskell
-- | Evaluate constant expressions.
|
|
module EvalConstants (constantFold, isConstantName) where
|
|
|
|
import Control.Monad.Error
|
|
import Control.Monad.Identity
|
|
import Control.Monad.State
|
|
import Data.Bits
|
|
import Data.Char
|
|
import Data.Generics
|
|
import Data.Int
|
|
import Data.Maybe
|
|
import Numeric
|
|
import Text.Printf
|
|
|
|
import qualified AST as A
|
|
import Errors
|
|
import EvalLiterals
|
|
import Metadata
|
|
import ParseState
|
|
import Pass
|
|
import Types
|
|
|
|
-- | Try to constant-fold an expression.
--
-- The result is a triple of:
--
--   * the folded expression, or the original expression unchanged when
--     'simplifyExpression' reported an error;
--
--   * whether that resulting expression satisfies 'isConstant';
--
--   * a message: the evaluator's error text when folding failed, or
--     @"already folded"@ when it succeeded.
constantFold :: PSM m => A.Expression -> m (A.Expression, Bool, String)
constantFold expr = do
    state <- get
    let (folded, note) =
            either (\reason -> (expr, reason))
                   (\simplified -> (simplified, "already folded"))
                   (simplifyExpression state expr)
    return (folded, isConstant folded, note)
|
|
|
|
-- | Look up a name and, if it is a @ValAbbrev@ of an expression that
-- satisfies 'isConstant', return that expression.  Every other kind of
-- definition yields 'Nothing'.
getConstantName :: (PSM m, Die m) => A.Name -> m (Maybe A.Expression)
getConstantName name = specTypeOfName name >>= pick
  where
    -- Only value abbreviations of constant expressions qualify; a failed
    -- guard falls through to the catch-all equation below.
    pick (A.IsExpr _ A.ValAbbrev _ body)
        | isConstant body = return (Just body)
    pick _ = return Nothing
|
|
|
|
-- | Report whether a name is defined as a constant expression, i.e. whether
-- 'getConstantName' finds a definition for it.
isConstantName :: (PSM m, Die m) => A.Name -> m Bool
isConstantName name = liftM isJust (getConstantName name)
|
|
|
|
-- | Evaluate an expression as a constant and turn the resulting value back
-- into an expression.
--
-- The expression is run through 'evalExpression' under 'runEvaluator' with
-- the given parse state.  An evaluator error is passed through as 'Left';
-- a successful value is rendered with 'renderValue', tagged with the
-- metadata of the original expression, and returned as 'Right'.
simplifyExpression :: ParseState -> A.Expression -> Either String A.Expression
simplifyExpression ps e
    = either Left (Right . render) (runEvaluator ps (evalExpression e))
  where
    -- renderValue also returns the value's type; only the expression is kept.
    render = snd . renderValue (findMeta e)
|
|
|
|
--{{{ expression evaluator
|
|
-- | Evaluate a literal expression.
--
-- Array literals are evaluated element by element into an 'OccArray'; an
-- array literal with no elements is rejected with @"empty array"@.
-- Anything else is handed to 'evalSimpleLiteral'.
evalLiteral :: A.Expression -> EvalM OccValue
evalLiteral lit = case lit of
    A.Literal _ _ (A.ArrayLiteral _ elems)
        | null elems -> throwError "empty array"
        | otherwise  -> liftM OccArray (mapM evalLiteralArray elems)
    _ -> evalSimpleLiteral lit
|
|
|
|
-- | Evaluate one element of an array literal.
--
-- A nested array element is evaluated recursively into an 'OccArray'; an
-- expression element is evaluated with 'evalExpression'.
--
-- A nested array with no elements is rejected with @"empty array"@, in the
-- same way that 'evalLiteral' rejects an empty top-level array literal.
-- Without this check an empty 'OccArray' would be produced, and rendering
-- it back into a literal takes the 'head' of the element types, which is
-- undefined for an empty list.
evalLiteralArray :: A.ArrayElem -> EvalM OccValue
evalLiteralArray (A.ArrayElemArray [])
    = throwError "empty array"
evalLiteralArray (A.ArrayElemArray aes)
    = liftM OccArray (mapM evalLiteralArray aes)
evalLiteralArray (A.ArrayElemExpr e) = evalExpression e
|
|
|
|
-- | Evaluate an expression to a constant value, or fail with an error
-- message if it cannot be evaluated at compile time.
--
-- Any expression form not listed below (including @MostPos@/@MostNeg@ of a
-- type other than @A.Int@, and variables that are not a plain
-- @A.Variable@) fails with @"bad expression"@.
evalExpression :: A.Expression -> EvalM OccValue
evalExpression expr = case expr of
    -- Operators: evaluate the operands first, then apply the operator.
    A.Monadic _ op operand -> evalExpression operand >>= evalMonadic op
    A.Dyadic _ op lhs rhs -> do
        l <- evalExpression lhs
        r <- evalExpression rhs
        evalDyadic op l r

    A.MostPos _ A.Int -> return (OccInt maxBound)
    A.MostNeg _ A.Int -> return (OccInt minBound)

    -- SIZE of an expression: prefer the leading dimension recorded in the
    -- type; otherwise evaluate the expression and count its elements.
    A.SizeExpr _ inner -> do
        ty <- typeOfExpression inner
        case ty of
            A.Array (A.Dimension len : _) _ -> intResult len
            _ -> do
                val <- evalExpression inner
                case val of
                    OccArray items -> intResult (length items)
                    _ -> throwError $ "size of non-constant expression " ++ show inner ++ " used"

    -- SIZE of a variable: only available when the type carries a known
    -- leading dimension.
    A.SizeVariable _ var -> do
        ty <- typeOfVariable var
        case ty of
            A.Array (A.Dimension len : _) _ -> intResult len
            _ -> throwError $ "size of non-fixed-size variable " ++ show var ++ " used"

    A.Literal _ _ _ -> evalLiteral expr

    -- A variable is constant only if getConstantName finds a constant
    -- definition for it, in which case that definition is evaluated.
    A.ExprVariable _ (A.Variable _ name) ->
        getConstantName name
            >>= maybe (throwError $ "non-constant variable " ++ show name ++ " used")
                      evalExpression

    A.True _ -> return (OccBool True)
    A.False _ -> return (OccBool False)

    A.BytesInExpr _ inner -> do
        ty <- typeOfExpression inner
        size <- bytesInType ty
        case size of
            BIJust n -> intResult n
            _ -> throwError $ "BYTESIN non-constant-size expression " ++ show inner ++ " used"

    A.BytesInType _ ty -> do
        size <- bytesInType ty
        case size of
            BIJust n -> intResult n
            _ -> throwError $ "BYTESIN non-constant-size type " ++ show ty ++ " used"

    _ -> throwError "bad expression"
  where
    -- Wrap an integral count as an integer constant.
    intResult n = return $ OccInt (fromIntegral n)
|
|
|
|
-- | Apply a monadic (unary) operator to an already-evaluated operand.
--
-- Negation and bitwise NOT are defined on integers, logical NOT on
-- booleans; every other operator/operand combination fails with
-- @"bad monadic op"@.
evalMonadic :: A.MonadicOp -> OccValue -> EvalM OccValue
evalMonadic op operand = case (op, operand) of
    -- This, oddly, is probably the most important rule here: "-4" isn't a
    -- literal in occam, it's an operator applied to a literal.
    (A.MonadicSubtr, OccInt i)  -> return (OccInt (0 - i))
    (A.MonadicBitNot, OccInt i) -> return (OccInt (complement i))
    (A.MonadicNot, OccBool b)   -> return (OccBool (not b))
    _                           -> throwError "bad monadic op"
|
|
|
|
-- | Apply a dyadic (binary) operator to two already-evaluated operands.
--
-- Any operator/operand combination that is not handled below fails with
-- @"bad dyadic op"@.  Conditions that would otherwise surface as a pure
-- Haskell exception (division by zero, a negative shift count) are reported
-- through the evaluator's error channel instead, so that the caller can
-- simply leave the expression unfolded.
evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
-- Checked arithmetic: the operation is carried out in 'Integer' and a
-- result that does not fit the integer constant type is an error rather
-- than a silently wrapped value.
evalDyadic A.Add a b = checkedIntOp (+) a b
evalDyadic A.Subtr a b = checkedIntOp (-) a b
evalDyadic A.Mul a b = checkedIntOp (*) a b
-- occam integer division truncates towards zero and the remainder takes the
-- sign of the dividend, which is 'quot'/'rem' rather than 'div'/'mod'.
evalDyadic A.Div (OccInt _) (OccInt 0) = throwError "division by zero"
evalDyadic A.Div a b = checkedIntOp quot a b
evalDyadic A.Rem (OccInt _) (OccInt 0) = throwError "division by zero"
evalDyadic A.Rem a b = checkedIntOp rem a b
-- Unchecked arithmetic: performed directly in the underlying integer type,
-- with no overflow check.
evalDyadic A.Plus (OccInt a) (OccInt b) = return $ OccInt (a + b)
evalDyadic A.Minus (OccInt a) (OccInt b) = return $ OccInt (a - b)
evalDyadic A.Times (OccInt a) (OccInt b) = return $ OccInt (a * b)
evalDyadic A.BitAnd (OccInt a) (OccInt b) = return $ OccInt (a .&. b)
evalDyadic A.BitOr (OccInt a) (OccInt b) = return $ OccInt (a .|. b)
evalDyadic A.BitXor (OccInt a) (OccInt b) = return $ OccInt (a `xor` b)
-- NOTE(review): shiftR on a signed type is an arithmetic shift; confirm
-- that this matches the intended semantics of the source language's ">>".
evalDyadic A.LeftShift (OccInt a) (OccInt b)
    | b < 0 = throwError "negative shift count"
    | otherwise = return $ OccInt (shiftL a (fromIntegral b))
evalDyadic A.RightShift (OccInt a) (OccInt b)
    | b < 0 = throwError "negative shift count"
    | otherwise = return $ OccInt (shiftR a (fromIntegral b))
evalDyadic A.And (OccBool a) (OccBool b) = return $ OccBool (a && b)
evalDyadic A.Or (OccBool a) (OccBool b) = return $ OccBool (a || b)
evalDyadic A.Eq a b = return $ OccBool (a == b)
evalDyadic A.NotEq a b = return $ OccBool (not (a == b))
evalDyadic A.Less (OccInt a) (OccInt b) = return $ OccBool (a < b)
evalDyadic A.More (OccInt a) (OccInt b) = return $ OccBool (a > b)
-- These must be non-strict comparisons: "a <= b" holds when a == b, so
-- they cannot be expressed as More/Less with the operands swapped.
evalDyadic A.LessEq (OccInt a) (OccInt b) = return $ OccBool (a <= b)
evalDyadic A.MoreEq (OccInt a) (OccInt b) = return $ OccBool (a >= b)
-- AFTER compares using the (unchecked) difference of the operands.
evalDyadic A.After (OccInt a) (OccInt b) = return $ OccBool ((a - b) > 0)
evalDyadic _ _ _ = throwError "bad dyadic op"

-- | Apply an integer operation to two integer constants, computing in
-- 'Integer' and failing with @"arithmetic overflow"@ if the result lies
-- outside the bounds of the integer constant type.  Non-integer operands
-- fail with @"bad dyadic op"@.
checkedIntOp :: (Integer -> Integer -> Integer) -> OccValue -> OccValue -> EvalM OccValue
checkedIntOp op (OccInt a) (OccInt b)
    | result < toInteger (minBound `asTypeOf` a) = overflow
    | result > toInteger (maxBound `asTypeOf` a) = overflow
    | otherwise = return $ OccInt (fromInteger result)
  where
    result = toInteger a `op` toInteger b
    overflow = throwError "arithmetic overflow"
checkedIntOp _ _ _ = throwError "bad dyadic op"
|
|
--}}}
|
|
|
|
--{{{ rendering values
|
|
-- | Turn an evaluated value back into an expression, together with its type.
--
-- Booleans become the @A.True@ / @A.False@ expressions; every other value
-- becomes an @A.Literal@ built from 'renderLiteral'.  The given metadata is
-- attached to the generated expression.
renderValue :: Meta -> OccValue -> (A.Type, A.Expression)
renderValue m val = case val of
    OccBool True  -> (A.Bool, A.True m)
    OccBool False -> (A.Bool, A.False m)
    _             -> (ty, A.Literal m ty repr)
  where
    (ty, repr) = renderLiteral m val
|
|
|
|
-- | Turn a byte, integer or array value into a literal representation and
-- its type.
--
-- For arrays the element type is taken from the first rendered element, so
-- the array is expected to be non-empty.  Values other than bytes, integers
-- and arrays are not handled here.
renderLiteral :: Meta -> OccValue -> (A.Type, A.LiteralRepr)
renderLiteral m val = case val of
    OccByte ch -> (A.Byte, A.ByteLiteral m (renderChar ch))
    OccInt num -> (A.Int, A.IntLiteral m (show num))
    OccArray elems ->
        let rendered  = map (renderLiteralArray m) elems
            elemTypes = map fst rendered
            elemReprs = map snd rendered
            arrayType = makeArrayType (A.Dimension (length elems)) (head elemTypes)
        in (arrayType, A.ArrayLiteral m elemReprs)
|
|
|
|
-- | Render a character as the text to place inside a character or string
-- literal, using @*@-prefixed escape sequences.
--
-- Quotes, @*@ itself, carriage return, newline and tab have dedicated
-- escapes.  Any other character outside the printable ASCII range
-- (code below 32, or 127 and above) is written as @*#@ followed by its
-- code in hexadecimal, zero-padded to at least two digits.  This includes
-- DEL (127), which is a control character and must not be emitted raw.
-- All remaining characters are emitted unchanged.
renderChar :: Char -> String
renderChar c = case c of
    '\'' -> "*'"
    '\"' -> "*\""
    '*'  -> "**"
    '\r' -> "*c"
    '\n' -> "*n"
    '\t' -> "*t"
    _ | code < 32 || code >= 127 -> printf "*#%02x" code
      | otherwise                -> [c]
  where
    code = ord c
|
|
|
|
-- | Render one element of an array value as an array-literal element,
-- together with its type.
--
-- A nested array becomes an @A.ArrayElemArray@ whose element type is taken
-- from its first rendered element (so it is expected to be non-empty); any
-- other value is rendered with 'renderValue' and wrapped in
-- @A.ArrayElemExpr@.
renderLiteralArray :: Meta -> OccValue -> (A.Type, A.ArrayElem)
renderLiteralArray m val = case val of
    OccArray elems ->
        let rendered  = map (renderLiteralArray m) elems
            elemTypes = map fst rendered
            elemReprs = map snd rendered
            arrayType = makeArrayType (A.Dimension (length elems)) (head elemTypes)
        in (arrayType, A.ArrayElemArray elemReprs)
    _ ->
        let (ty, expr) = renderValue m val
        in (ty, A.ArrayElemExpr expr)
|
|
--}}}
|