tock-mirror/fco2/EvalConstants.hs
2007-04-30 04:03:29 +00:00

194 lines
7.6 KiB
Haskell

-- | Evaluate constant expressions.
module EvalConstants (constantFold, isConstantName) where
import Control.Monad.Error
import Control.Monad.Identity
import Control.Monad.State
import Data.Bits
import Data.Char
import Data.Generics
import Data.Int
import Data.Maybe
import Numeric
import Text.Printf
import qualified AST as A
import Errors
import EvalLiterals
import Metadata
import ParseState
import Pass
import Types
-- | Try to constant-fold an expression using the current parse state.
--
-- The result is a triple of:
--
--   * the folded expression, or the original one if evaluation failed;
--
--   * whether that resulting expression is constant (per 'isConstant');
--
--   * a message: the evaluator's error text on failure, or
--     @"already folded"@ when evaluation succeeded.
constantFold :: PSM m => A.Expression -> m (A.Expression, Bool, String)
constantFold e = do
    ps <- get
    let (folded, msg) =
            either (\err -> (e, err))
                   (\val -> (val, "already folded"))
                   (simplifyExpression ps e)
    return (folded, isConstant folded, msg)
-- | Look up the definition of a name and, if it is a value abbreviation of
-- an expression that 'isConstant' accepts, return that expression.
-- Every other kind of definition yields 'Nothing'.
getConstantName :: (PSM m, Die m) => A.Name -> m (Maybe A.Expression)
getConstantName n = liftM constantDefinition (specTypeOfName n)
  where
    -- Only a VAL abbreviation of a constant expression counts; a
    -- non-constant abbreviation falls through to the catch-all below.
    constantDefinition (A.IsExpr _ A.ValAbbrev _ e)
        | isConstant e = Just e
    constantDefinition _ = Nothing
-- | Report whether a name is defined as a constant expression, i.e. whether
-- 'getConstantName' finds a definition for it.
isConstantName :: (PSM m, Die m) => A.Name -> m Bool
isConstantName n = liftM isJust (getConstantName n)
-- | Evaluate an expression against the given parse state and, on success,
-- render the resulting value back into a literal expression carrying the
-- original expression's metadata. An evaluation failure is passed through
-- unchanged as 'Left' with the evaluator's error message.
simplifyExpression :: ParseState -> A.Expression -> Either String A.Expression
simplifyExpression ps e =
    either Left (Right . toLiteral) (runEvaluator ps (evalExpression e))
  where
    -- 'renderValue' also reports the literal's type; only the expression
    -- is needed here.
    toLiteral = snd . renderValue (findMeta e)
--{{{ expression evaluator
-- | Evaluate a literal expression. Array literals are evaluated element by
-- element (an empty array literal is rejected with an error); every other
-- expression is handed to 'evalSimpleLiteral'.
evalLiteral :: A.Expression -> EvalM OccValue
evalLiteral (A.Literal _ _ (A.ArrayLiteral _ elems))
    | null elems = throwError "empty array"
    | otherwise  = mapM evalLiteralArray elems >>= return . OccArray
evalLiteral other = evalSimpleLiteral other
-- | Evaluate one element of an array literal: a nested array is evaluated
-- recursively into an 'OccArray', and an expression element is evaluated
-- with 'evalExpression'.
evalLiteralArray :: A.ArrayElem -> EvalM OccValue
evalLiteralArray (A.ArrayElemArray nested) =
    do vs <- mapM evalLiteralArray nested
       return (OccArray vs)
evalLiteralArray (A.ArrayElemExpr expr) = evalExpression expr
-- | Evaluate an expression to a constant value inside the evaluator monad.
--
-- Handles monadic and dyadic operators, MOSTPOS/MOSTNEG of INT, SIZE of
-- expressions and variables, literals, variables that name constants,
-- TRUE/FALSE and BYTESIN. Anything else, or anything whose value cannot be
-- determined, fails with an error message.
evalExpression :: A.Expression -> EvalM OccValue
evalExpression expr = case expr of
    A.Monadic _ op operand ->
        evalExpression operand >>= evalMonadic op
    A.Dyadic _ op lhs rhs ->
        do l <- evalExpression lhs
           r <- evalExpression rhs
           evalDyadic op l r
    A.MostPos _ A.Int -> return (OccInt maxBound)
    A.MostNeg _ A.Int -> return (OccInt minBound)
    A.SizeExpr _ e ->
        do t <- typeOfExpression e
           case t of
             -- A known first dimension answers the question directly.
             A.Array (A.Dimension n : _) _ -> intValue n
             -- Otherwise evaluate the expression and count its elements.
             _ -> do v <- evalExpression e
                     case v of
                       OccArray vs -> intValue (length vs)
                       _ -> throwError $ "size of non-constant expression " ++ show e ++ " used"
    A.SizeVariable _ v ->
        do t <- typeOfVariable v
           case t of
             A.Array (A.Dimension n : _) _ -> intValue n
             _ -> throwError $ "size of non-fixed-size variable " ++ show v ++ " used"
    A.Literal _ _ _ -> evalLiteral expr
    A.ExprVariable _ (A.Variable _ n) ->
        getConstantName n >>=
            maybe (throwError $ "non-constant variable " ++ show n ++ " used")
                  evalExpression
    A.True _ -> return (OccBool True)
    A.False _ -> return (OccBool False)
    A.BytesInExpr _ e ->
        do t <- typeOfExpression e
           b <- bytesInType t
           sizeOrFail ("BYTESIN non-constant-size expression " ++ show e ++ " used") b
    A.BytesInType _ t ->
        bytesInType t >>=
            sizeOrFail ("BYTESIN non-constant-size type " ++ show t ++ " used")
    _ -> throwError "bad expression"
  where
    -- Wrap an integral count as an integer value.
    intValue n = return $ OccInt (fromIntegral n)
    -- Turn a known byte count into a value, or fail with the given message.
    sizeOrFail _ (BIJust n) = intValue n
    sizeOrFail msg _ = throwError msg
-- | Apply a monadic operator to an already-evaluated operand. Unsupported
-- operator/operand combinations fail with @"bad monadic op"@.
evalMonadic :: A.MonadicOp -> OccValue -> EvalM OccValue
evalMonadic op v = case (op, v) of
    -- This, oddly, is probably the most important rule here: "-4" isn't a
    -- literal in occam, it's an operator applied to a literal.
    (A.MonadicSubtr, OccInt i) -> return $ OccInt (negate i)
    (A.MonadicBitNot, OccInt i) -> return $ OccInt (complement i)
    (A.MonadicNot, OccBool b) -> return $ OccBool (not b)
    _ -> throwError "bad monadic op"
-- | Apply a dyadic operator to two already-evaluated operands.
--
-- Integer operators require two 'OccInt' operands and the boolean
-- connectives two 'OccBool' operands; equality and inequality compare any
-- two values. Any other operator/operand combination fails with
-- @"bad dyadic op"@. Division or remainder by zero, and the one division
-- whose result is unrepresentable (@minBound / -1@), are reported as
-- evaluator errors instead of raising a pure arithmetic exception.
evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
-- FIXME Add, Subtr and Mul should check for overflow.
evalDyadic A.Add (OccInt a) (OccInt b) = return $ OccInt (a + b)
evalDyadic A.Subtr (OccInt a) (OccInt b) = return $ OccInt (a - b)
evalDyadic A.Mul (OccInt a) (OccInt b) = return $ OccInt (a * b)
-- ... end FIXME
-- occam integer division truncates towards zero and the remainder takes the
-- sign of the dividend, so these are `quot`/`rem` rather than `div`/`mod`.
evalDyadic A.Div (OccInt a) (OccInt b)
    | b == 0 = throwError "division by zero"
    | a == minBound && b == (-1) = throwError "overflow in division"
    | otherwise = return $ OccInt (a `quot` b)
evalDyadic A.Rem (OccInt a) (OccInt b)
    | b == 0 = throwError "remainder by zero"
    -- Anything REM -1 is 0; answering directly sidesteps the minBound case.
    | b == (-1) = return $ OccInt 0
    | otherwise = return $ OccInt (a `rem` b)
evalDyadic A.Plus (OccInt a) (OccInt b) = return $ OccInt (a + b)
evalDyadic A.Minus (OccInt a) (OccInt b) = return $ OccInt (a - b)
evalDyadic A.Times (OccInt a) (OccInt b) = return $ OccInt (a * b)
evalDyadic A.BitAnd (OccInt a) (OccInt b) = return $ OccInt (a .&. b)
evalDyadic A.BitOr (OccInt a) (OccInt b) = return $ OccInt (a .|. b)
evalDyadic A.BitXor (OccInt a) (OccInt b) = return $ OccInt (a `xor` b)
evalDyadic A.LeftShift (OccInt a) (OccInt b) = return $ OccInt (shiftL a (fromIntegral b))
evalDyadic A.RightShift (OccInt a) (OccInt b) = return $ OccInt (shiftR a (fromIntegral b))
evalDyadic A.And (OccBool a) (OccBool b) = return $ OccBool (a && b)
evalDyadic A.Or (OccBool a) (OccBool b) = return $ OccBool (a || b)
evalDyadic A.Eq a b = return $ OccBool (a == b)
evalDyadic A.NotEq a b = return $ OccBool (not (a == b))
evalDyadic A.Less (OccInt a) (OccInt b) = return $ OccBool (a < b)
evalDyadic A.More (OccInt a) (OccInt b) = return $ OccBool (a > b)
-- The non-strict comparisons must hold for equal operands; swapping the
-- operands of the strict comparisons does not give that.
evalDyadic A.LessEq (OccInt a) (OccInt b) = return $ OccBool (a <= b)
evalDyadic A.MoreEq (OccInt a) (OccInt b) = return $ OccBool (a >= b)
evalDyadic A.After (OccInt a) (OccInt b) = return $ OccBool ((a - b) > 0)
evalDyadic _ _ _ = throwError "bad dyadic op"
--}}}
--{{{ rendering values
-- | Turn an evaluated value back into an expression, together with its
-- type. Booleans become the TRUE/FALSE expressions; every other value
-- becomes a literal built by 'renderLiteral'. The supplied metadata is
-- attached to the generated nodes.
renderValue :: Meta -> OccValue -> (A.Type, A.Expression)
renderValue m v = case v of
    OccBool True -> (A.Bool, A.True m)
    OccBool False -> (A.Bool, A.False m)
    _ -> let (t, lr) = renderLiteral m v
         in (t, A.Literal m t lr)
-- | Render a byte, integer or array value as a literal representation,
-- paired with the literal's type. An array's type is built from its length
-- and the type of its first element.
renderLiteral :: Meta -> OccValue -> (A.Type, A.LiteralRepr)
renderLiteral m (OccByte c) = (A.Byte, A.ByteLiteral m (renderChar c))
renderLiteral m (OccInt i) = (A.Int, A.IntLiteral m (show i))
renderLiteral m (OccArray vs) =
    let (elemTypes, elems) = unzip [renderLiteralArray m v | v <- vs]
        -- NOTE: 'head' is only forced if the array type is inspected; it
        -- would fail for an empty array.
        arrayType = makeArrayType (A.Dimension $ length vs) (head elemTypes)
    in (arrayType, A.ArrayLiteral m elems)
-- | Render a character as it should appear inside a literal, using @*@ as
-- the escape character.
--
-- Quotes, @*@ itself, carriage return, newline and tab get their named
-- escapes. Any other character outside the printable ASCII range
-- (code below 32, or 127 and above) is written as a two-digit hex escape
-- @*#hh@. Printable characters are returned unchanged.
renderChar :: Char -> String
renderChar '\'' = "*'"
renderChar '\"' = "*\""
renderChar '*' = "**"
renderChar '\r' = "*c"
renderChar '\n' = "*n"
renderChar '\t' = "*t"
renderChar c
    -- 127 is DEL, a control character, so it must be escaped as well.
    | o < 32 || o >= 127 = printf "*#%02x" o
    | otherwise = [c]
  where
    o = ord c
-- | Render one element of an array literal, paired with the element's
-- type. A nested array is rendered recursively, with its type built from
-- its length and the type of its first element; any other value is
-- rendered through 'renderValue' and wrapped as an expression element.
renderLiteralArray :: Meta -> OccValue -> (A.Type, A.ArrayElem)
renderLiteralArray m (OccArray vs) =
    let (elemTypes, elems) = unzip [renderLiteralArray m v | v <- vs]
        -- NOTE: 'head' is only forced if the array type is inspected; it
        -- would fail for an empty array.
        arrayType = makeArrayType (A.Dimension $ length vs) (head elemTypes)
    in (arrayType, A.ArrayElemArray elems)
renderLiteralArray m v =
    let (t, e) = renderValue m v
    in (t, A.ArrayElemExpr e)
--}}}