Fixed the constant folding to understand the new operators-as-functions system, and only fold when the original built-in definition is being used

This commit is contained in:
Neil Brown 2009-04-08 10:28:26 +00:00
parent 6cb47fd1c3
commit 22cb9d35b6

View File

@ -222,13 +222,6 @@ evalRetypes m t ov
getBytesFor (OccReal64 x) f = with x (f . castPtr)
evalExpression :: A.Expression -> EvalM OccValue
evalExpression (A.Monadic _ op e)
= do v <- evalExpression e
evalMonadic op v
evalExpression (A.Dyadic _ op e1 e2)
= do v1 <- evalExpression e1
v2 <- evalExpression e2
evalDyadic op v1 v2
evalExpression (A.MostPos _ A.Byte) = return $ OccByte maxBound
evalExpression (A.MostNeg _ A.Byte) = return $ OccByte minBound
evalExpression (A.MostPos _ A.UInt16) = return $ OccUInt16 maxBound
@ -274,6 +267,21 @@ evalExpression (A.BytesInType m t)
case b of
BIJust n -> evalExpression n
_ -> throwErrorC (Just m, formatCode "BYTESIN non-constant-size type % used" t)
evalExpression (A.FunctionCall m n es)
= do mOp <- functionOperator n
ts <- mapM (underlyingTypeOf m) es
case (mOp, es) of
(Just op, [a, b])
-- Only fold if they're using the built-in version:
| A.nameName n == occamDefaultOperator op ts
-> do a' <- evalExpression a
b' <- evalExpression b
evalDyadic op a' b'
(Just op, [a])
-- Only fold if they're using the built-in version:
| A.nameName n == occamDefaultOperator op ts
-> evalExpression a >>= evalMonadic op
_ -> throwError (Just m, "bad expression")
evalExpression e = throwError (Just $ findMeta e, "bad expression")
evalMonadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t) -> OccValue -> EvalM OccValue
@ -288,13 +296,13 @@ evalMonadicOp f (OccInt32 a) = return $ OccInt32 (f a)
evalMonadicOp f (OccInt64 a) = return $ OccInt64 (f a)
evalMonadicOp _ v = throwError (Nothing, "monadic operator not implemented for this type: " ++ show v)
evalMonadic :: A.MonadicOp -> OccValue -> EvalM OccValue
evalMonadic :: String -> OccValue -> EvalM OccValue
-- This, oddly, is probably the most important rule here: "-4" isn't a literal
-- in occam, it's an operator applied to a literal.
evalMonadic A.MonadicSubtr a = evalMonadicOp negate a
evalMonadic A.MonadicMinus a = evalMonadicOp negate a
evalMonadic A.MonadicBitNot a = evalMonadicOp complement a
evalMonadic A.MonadicNot (OccBool b) = return $ OccBool (not b)
evalMonadic "-" a = evalMonadicOp negate a
evalMonadic "MINUS" a = evalMonadicOp negate a
evalMonadic "~" a = evalMonadicOp complement a
evalMonadic "NOT" (OccBool b) = return $ OccBool (not b)
evalMonadic op _ = throwError (Nothing, "bad monadic op: " ++ show op)
evalArithOp :: (forall t. (Num t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
@ -367,37 +375,38 @@ safeRem :: (Integral a, Bounded a) => a -> a -> a
safeRem a (-1) | a == minBound = 0 -- The correct answer
safeRem a b = rem a b
evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
evalDyadic :: String -> OccValue -> OccValue -> EvalM OccValue
-- FIXME These should check for overflow.
evalDyadic A.Add a b = evalArithOp (+) a b
evalDyadic A.Subtr a b = evalArithOp (-) a b
evalDyadic A.Mul a b = evalArithOp (*) a b
evalDyadic A.Div (OccReal32 a) (OccReal32 b) = return $ OccReal32 $ a / b
evalDyadic A.Div (OccReal64 a) (OccReal64 b) = return $ OccReal64 $ a / b
evalDyadic A.Div a b = evalArithIntOp safeDiv a b
evalDyadic A.Rem a b = evalArithIntOp safeRem a b
evalDyadic "+" a b = evalArithOp (+) a b
evalDyadic "-" a b = evalArithOp (-) a b
evalDyadic "*" a b = evalArithOp (*) a b
evalDyadic "/" (OccReal32 a) (OccReal32 b) = return $ OccReal32 $ a / b
evalDyadic "/" (OccReal64 a) (OccReal64 b) = return $ OccReal64 $ a / b
evalDyadic "/" a b = evalArithIntOp safeDiv a b
evalDyadic "\\" a b = evalArithIntOp safeRem a b
evalDyadic "REM" a b = evalArithIntOp safeRem a b
-- ... end FIXME
evalDyadic A.Plus a b = evalArithOp (+) a b
evalDyadic A.Minus a b = evalArithOp (-) a b
evalDyadic A.Times a b = evalArithOp (*) a b
evalDyadic A.BitAnd a b = evalLogicOp (.&.) a b
evalDyadic A.BitOr a b = evalLogicOp (.|.) a b
evalDyadic A.BitXor a b = evalLogicOp xor a b
evalDyadic A.LeftShift a (OccInt b)
evalDyadic "PLUS" a b = evalArithOp (+) a b
evalDyadic "MINUS" a b = evalArithOp (-) a b
evalDyadic "TIMES" a b = evalArithOp (*) a b
evalDyadic "/\\" a b = evalLogicOp (.&.) a b
evalDyadic "\\/" a b = evalLogicOp (.|.) a b
evalDyadic "><" a b = evalLogicOp xor a b
evalDyadic "<<" a (OccInt b)
= evalMonadicOp (\v -> shiftL v (fromIntegral b)) a
evalDyadic A.RightShift a (OccInt b)
evalDyadic ">>" a (OccInt b)
-- occam shifts are logical (no sign-extending) but Haskell only has the signed
-- shift. So we use a custom shift
= evalMonadicOp (\v -> logicalShiftR v (fromIntegral b)) a
evalDyadic A.And (OccBool a) (OccBool b) = return $ OccBool (a && b)
evalDyadic A.Or (OccBool a) (OccBool b) = return $ OccBool (a || b)
evalDyadic A.Eq a b = evalCompareOp (==) a b
evalDyadic A.NotEq a b = evalCompareOp (/=) a b
evalDyadic A.Less a b = evalCompareOp (<) a b
evalDyadic A.More a b = evalCompareOp (>) a b
evalDyadic A.LessEq a b = evalCompareOp (<=) a b
evalDyadic A.MoreEq a b = evalCompareOp (>=) a b
evalDyadic A.After (OccInt a) (OccInt b) = return $ OccBool ((a - b) > 0)
evalDyadic "AND" (OccBool a) (OccBool b) = return $ OccBool (a && b)
evalDyadic "OR" (OccBool a) (OccBool b) = return $ OccBool (a || b)
evalDyadic "=" a b = evalCompareOp (==) a b
evalDyadic "<>" a b = evalCompareOp (/=) a b
evalDyadic "<" a b = evalCompareOp (<) a b
evalDyadic ">" a b = evalCompareOp (>) a b
evalDyadic "<=" a b = evalCompareOp (<=) a b
evalDyadic ">=" a b = evalCompareOp (>=) a b
evalDyadic "AFTER" (OccInt a) (OccInt b) = return $ OccBool ((a - b) > 0)
evalDyadic op _ _ = throwError (Nothing, "bad dyadic op: " ++ show op)
--}}}