
This also fixes a bug in the original algorithm: it used to let you retype []INT to BYTE.
325 lines
14 KiB
Haskell
325 lines
14 KiB
Haskell
{-
|
|
Tock: a compiler for parallel languages
|
|
Copyright (C) 2007 University of Kent
|
|
|
|
This program is free software; you can redistribute it and/or modify it
|
|
under the terms of the GNU General Public License as published by the
|
|
Free Software Foundation, either version 2 of the License, or (at your
|
|
option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful, but
|
|
WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License along
|
|
with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
-}
|
|
|
|
-- | Evaluate constant expressions.
|
|
module EvalConstants (constantFold, maybeEvalIntExpression, isConstantName) where
|
|
|
|
import Control.Monad.Error
|
|
import Control.Monad.State
|
|
import Data.Bits
|
|
import Data.Char
|
|
import Data.Int
|
|
import Data.Maybe
|
|
import Text.Printf
|
|
|
|
import qualified AST as A
|
|
import CompState hiding (CSM) -- everything here is read-only
|
|
import Errors
|
|
import EvalLiterals
|
|
import Metadata
|
|
import ShowCode
|
|
import Types
|
|
import Utils
|
|
|
|
-- | Constant-fold an expression as far as possible.
--
-- The result is a triple of: the (possibly simplified) expression, whether
-- that expression is a constant according to 'isConstant', and an
-- 'ErrorReport'.  When folding fails the original expression is returned
-- unchanged together with the evaluator's error; when it succeeds the report
-- is just a placeholder.
constantFold :: CSMR m => A.Expression -> m (A.Expression, Bool, ErrorReport)
constantFold e
    =  do ps <- getCompState
          let (folded, report) = either failed succeeded (simplifyExpression ps e)
          return (folded, isConstant folded, report)
  where
    -- Folding failed: keep the expression we were given.
    failed err = (e, err)
    -- Folding worked: hand back the simplified expression.
    succeeded val = (val, (Nothing, "already folded"))
|
|
|
|
-- | Fold an expression and, if the result is a constant, evaluate it as an
-- integer.  Non-constant expressions yield 'Nothing'.
maybeEvalIntExpression :: (CSMR m, Die m) => A.Expression -> m (Maybe Int)
maybeEvalIntExpression e
    =  do (folded, constant, _) <- constantFold e
          case constant of
            True -> evalIntExpression folded >>* Just
            False -> return Nothing
|
|
|
|
-- | Look up a name and, if it is a value abbreviation of a constant
-- expression, return that expression.  Anything else gives 'Nothing'.
getConstantName :: (CSMR m, Die m) => A.Name -> m (Maybe A.Expression)
getConstantName n = liftM constantDef (specTypeOfName n)
  where
    -- Only a VAL abbreviation whose right-hand side is constant qualifies;
    -- a failed guard falls through to the catch-all below.
    constantDef (A.IsExpr _ A.ValAbbrev _ e)
      | isConstant e = Just e
    constantDef _ = Nothing
|
|
|
|
-- | Does this name stand for a constant expression?  True exactly when
-- 'getConstantName' finds a definition for it.
isConstantName :: (CSMR m, Die m) => A.Name -> m Bool
isConstantName n = liftM isJust (getConstantName n)
|
|
|
|
-- | Run the constant evaluator over an expression.  On success the computed
-- value is rendered back into a literal expression (tagged with the original
-- expression's metadata); on failure the evaluator's error is passed through.
simplifyExpression :: CompState -> A.Expression -> Either ErrorReport A.Expression
simplifyExpression ps e
    = either Left (Right . render) (runEvaluator ps (evalExpression e))
  where
    render = snd . renderValue (findMeta e)
|
|
|
|
--{{{ expression evaluator
|
|
-- | Evaluate a literal expression.  Array and record literals are handled
-- here by evaluating their components; empty array literals are rejected;
-- everything else is delegated to 'evalSimpleLiteral'.
evalLiteral :: A.Expression -> EvalM OccValue
evalLiteral lit
    = case lit of
        A.Literal m _ (A.ArrayLiteral _ []) ->
          throwError (Just m, "empty array")
        A.Literal _ _ (A.ArrayLiteral _ elems) ->
          do vals <- mapM evalLiteralArray elems
             return $ OccArray vals
        A.Literal _ (A.Record n) (A.RecordLiteral _ fields) ->
          do vals <- mapM evalExpression fields
             return $ OccRecord n vals
        _ -> evalSimpleLiteral lit
|
|
|
|
-- | Evaluate one element of an array literal: nested arrays are evaluated
-- element by element, plain expressions with 'evalExpression'.
evalLiteralArray :: A.ArrayElem -> EvalM OccValue
evalLiteralArray el
    = case el of
        A.ArrayElemArray inner ->
          do vals <- mapM evalLiteralArray inner
             return $ OccArray vals
        A.ArrayElemExpr e -> evalExpression e
|
|
|
|
-- | Evaluate a variable.  A plain variable is only evaluable when
-- 'getConstantName' finds a constant definition for it; subscripted
-- variables evaluate the base and then apply the subscript; directed
-- variables evaluate the variable they wrap.
evalVariable :: A.Variable -> EvalM OccValue
evalVariable (A.Variable m n)
    = getConstantName n >>= maybe notConstant evalExpression
  where
    notConstant = throwError (Just m, "non-constant variable " ++ show n ++ " used")
evalVariable (A.SubscriptedVariable _ sub v)
    =  do base <- evalVariable v
          evalSubscript sub base
evalVariable (A.DirectedVariable _ _ v) = evalVariable v
|
|
|
|
-- | Evaluate an expression that is used as an array index.  Only values of
-- the 'OccInt' kind are accepted; anything else is an error.
evalIndex :: A.Expression -> EvalM Int
evalIndex e = evalExpression e >>= toIndex
  where
    toIndex (OccInt n) = return $ fromIntegral n
    toIndex _ = throwError (Just $ findMeta e, "index has non-INT type")
|
|
|
|
-- TODO should we obey the no-checking here, or not?
-- If it's not in bounds, we can't constant fold it, so no-checking would preclude constant folding...

-- | Apply a subscript to an already-evaluated value.  Only a simple index
-- into an array value is supported, and the index must be in range.
evalSubscript :: A.Subscript -> OccValue -> EvalM OccValue
evalSubscript (A.Subscript m _ e) (OccArray vs)
    =  do index <- evalIndex e
          -- 'drop' leaves the wanted element at the head when the index is
          -- within bounds, and an empty list when it is too large; negative
          -- indices are rejected by the guard.
          case drop index vs of
            (v:_) | index >= 0 -> return v
            _ -> throwError (Just m, "subscript out of range")
evalSubscript s _ = throwError (Just $ findMeta s, "invalid subscript")
|
|
|
|
-- | Evaluate an expression to a constant value, or fail with an
-- 'ErrorReport' if it is not something the evaluator can compute.
evalExpression :: A.Expression -> EvalM OccValue
-- Operators: evaluate the operands first, then apply the operator.
evalExpression (A.Monadic _ op e) = evalExpression e >>= evalMonadic op
evalExpression (A.Dyadic _ op e1 e2)
    =  do lhs <- evalExpression e1
          rhs <- evalExpression e2
          evalDyadic op lhs rhs
-- MOSTPOS / MOSTNEG for each supported integer type.
evalExpression (A.MostPos _ A.Byte) = return (OccByte maxBound)
evalExpression (A.MostNeg _ A.Byte) = return (OccByte minBound)
evalExpression (A.MostPos _ A.UInt16) = return (OccUInt16 maxBound)
evalExpression (A.MostNeg _ A.UInt16) = return (OccUInt16 minBound)
evalExpression (A.MostPos _ A.UInt32) = return (OccUInt32 maxBound)
evalExpression (A.MostNeg _ A.UInt32) = return (OccUInt32 minBound)
evalExpression (A.MostPos _ A.UInt64) = return (OccUInt64 maxBound)
evalExpression (A.MostNeg _ A.UInt64) = return (OccUInt64 minBound)
evalExpression (A.MostPos _ A.Int8) = return (OccInt8 maxBound)
evalExpression (A.MostNeg _ A.Int8) = return (OccInt8 minBound)
evalExpression (A.MostPos _ A.Int) = return (OccInt maxBound)
evalExpression (A.MostNeg _ A.Int) = return (OccInt minBound)
evalExpression (A.MostPos _ A.Int16) = return (OccInt16 maxBound)
evalExpression (A.MostNeg _ A.Int16) = return (OccInt16 minBound)
evalExpression (A.MostPos _ A.Int32) = return (OccInt32 maxBound)
evalExpression (A.MostNeg _ A.Int32) = return (OccInt32 minBound)
evalExpression (A.MostPos _ A.Int64) = return (OccInt64 maxBound)
evalExpression (A.MostNeg _ A.Int64) = return (OccInt64 minBound)
-- SIZE of an expression: prefer the dimension recorded in the type, and
-- fall back to counting the elements of the evaluated array.
evalExpression (A.SizeExpr m e)
    =  do t <- typeOfExpression e >>= underlyingType m
          case t of
            A.Array (A.Dimension n:_) _ -> evalExpression n
            _ -> evalExpression e >>= sizeOfValue
  where
    sizeOfValue (OccArray vs) = return $ OccInt (fromIntegral $ length vs)
    sizeOfValue _ = throwError (Just m, "size of non-constant expression " ++ show e ++ " used")
-- SIZE of a variable: only known when the type carries a fixed dimension.
evalExpression (A.SizeVariable m v)
    =  do t <- typeOfVariable v >>= underlyingType m
          case t of
            A.Array (A.Dimension n:_) _ -> evalExpression n
            _ -> throwError (Just m, "size of non-fixed-size variable " ++ show v ++ " used")
evalExpression e@(A.Literal _ _ _) = evalLiteral e
evalExpression (A.ExprVariable _ v) = evalVariable v
evalExpression (A.True _) = return (OccBool True)
evalExpression (A.False _) = return (OccBool False)
evalExpression (A.SubscriptedExpr _ sub e)
    =  do base <- evalExpression e
          evalSubscript sub base
-- BYTESIN: only evaluable when the size of the type is itself an expression.
evalExpression (A.BytesInExpr m e)
    =  do size <- typeOfExpression e >>= underlyingType m >>= bytesInType
          case size of
            BIJust n -> evalExpression n
            _ -> throwError (Just m, "BYTESIN non-constant-size expression " ++ show e ++ " used")
evalExpression (A.BytesInType m t)
    =  do size <- underlyingType m t >>= bytesInType
          case size of
            BIJust n -> evalExpression n
            _ -> throwErrorC (Just m, formatCode "BYTESIN non-constant-size type % used" t)
-- Anything not matched above cannot be constant-folded.
evalExpression e = throwError (Just $ findMeta e, "bad expression")
|
|
|
|
-- | Apply a polymorphic one-argument integer function to a value, keeping
-- the value's integer type.  Non-integer values are an error.
evalMonadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t) -> OccValue -> EvalM OccValue
evalMonadicOp f v
    = case v of
        OccByte a -> return $ OccByte (f a)
        OccUInt16 a -> return $ OccUInt16 (f a)
        OccUInt32 a -> return $ OccUInt32 (f a)
        OccUInt64 a -> return $ OccUInt64 (f a)
        OccInt8 a -> return $ OccInt8 (f a)
        OccInt a -> return $ OccInt (f a)
        OccInt16 a -> return $ OccInt16 (f a)
        OccInt32 a -> return $ OccInt32 (f a)
        OccInt64 a -> return $ OccInt64 (f a)
        _ -> throwError (Nothing, "monadic operator not implemented for this type: " ++ show v)
|
|
|
|
-- | Evaluate a monadic (unary) operator applied to a value.
--
-- This, oddly, is probably the most important rule here: "-4" isn't a literal
-- in occam, it's an operator applied to a literal.
evalMonadic :: A.MonadicOp -> OccValue -> EvalM OccValue
evalMonadic op v
    = case (op, v) of
        (A.MonadicSubtr, _) -> evalMonadicOp negate v
        (A.MonadicMinus, _) -> evalMonadicOp negate v
        (A.MonadicBitNot, _) -> evalMonadicOp complement v
        (A.MonadicNot, OccBool b) -> return $ OccBool (not b)
        _ -> throwError (Nothing, "bad monadic op: " ++ show op)
|
|
|
|
-- | Apply a polymorphic two-argument integer function to two values of the
-- same integer type, producing a value of that type.  Mismatched or
-- non-integer operands are an error.
evalDyadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalDyadicOp f v0 v1
    = case (v0, v1) of
        (OccByte a, OccByte b) -> return $ OccByte (f a b)
        (OccUInt16 a, OccUInt16 b) -> return $ OccUInt16 (f a b)
        (OccUInt32 a, OccUInt32 b) -> return $ OccUInt32 (f a b)
        (OccUInt64 a, OccUInt64 b) -> return $ OccUInt64 (f a b)
        (OccInt8 a, OccInt8 b) -> return $ OccInt8 (f a b)
        (OccInt a, OccInt b) -> return $ OccInt (f a b)
        (OccInt16 a, OccInt16 b) -> return $ OccInt16 (f a b)
        (OccInt32 a, OccInt32 b) -> return $ OccInt32 (f a b)
        (OccInt64 a, OccInt64 b) -> return $ OccInt64 (f a b)
        _ -> throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
|
|
|
|
-- | Apply a polymorphic comparison to two values of the same integer type,
-- producing a boolean value.  Mismatched or non-integer operands are an
-- error.
evalCompareOp :: (forall t. (Eq t, Ord t) => t -> t -> Bool) -> OccValue -> OccValue -> EvalM OccValue
evalCompareOp f v0 v1
    = case (v0, v1) of
        (OccByte a, OccByte b) -> return $ OccBool (f a b)
        (OccUInt16 a, OccUInt16 b) -> return $ OccBool (f a b)
        (OccUInt32 a, OccUInt32 b) -> return $ OccBool (f a b)
        (OccUInt64 a, OccUInt64 b) -> return $ OccBool (f a b)
        (OccInt8 a, OccInt8 b) -> return $ OccBool (f a b)
        (OccInt a, OccInt b) -> return $ OccBool (f a b)
        (OccInt16 a, OccInt16 b) -> return $ OccBool (f a b)
        (OccInt32 a, OccInt32 b) -> return $ OccBool (f a b)
        (OccInt64 a, OccInt64 b) -> return $ OccBool (f a b)
        _ -> throwError (Nothing, "comparison operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
|
|
|
|
-- | Logical (non-sign-extending) right shift by @n@ bits.
--
-- The low @n@ bits are cleared first, so that rotating right by @n@ brings
-- zeroes, rather than the discarded bits, into the top of the word.
logicalShiftR :: Bits a => a -> Int -> a
logicalShiftR val n
    = case n of
        0 -> val
        _ -> rotateR lowBitsCleared n
  where
    -- Clearing individual bits is order-independent, so a right fold gives
    -- the same word as clearing them from bit 0 upwards.
    lowBitsCleared = foldr (flip clearBit) val [0 .. (n - 1)]
|
|
|
|
-- | Equivalent to 'div', but handles @minBound `div` (-1)@ without raising
-- an overflow exception.
--
-- Division by -1 is computed as negation, which wraps for @minBound@ on
-- fixed-width signed types instead of trapping.  The divisor is also required
-- to be negative: on unsigned types the literal @-1@ denotes @maxBound@, and
-- dividing by @maxBound@ must go through ordinary 'div'.
--
-- A zero divisor is not handled here; 'div' raises its usual exception.
safeDiv :: Integral a => a -> a -> a
safeDiv a b
    | b < 0 && b == (-1) = negate a
    | otherwise = div a b
|
|
|
|
-- | Equivalent to 'rem', but handles @minBound `rem` (-1)@ without raising
-- an overflow exception.
--
-- The remainder of any value divided by -1 is 0, so that case is answered
-- directly.  The divisor is also required to be negative: on unsigned types
-- the literal @-1@ denotes @maxBound@, and the remainder modulo @maxBound@
-- must go through ordinary 'rem'.
--
-- A zero divisor is not handled here; 'rem' raises its usual exception.
safeRem :: Integral a => a -> a -> a
safeRem a b
    | b < 0 && b == (-1) = 0
    | otherwise = rem a b
|
|
|
|
-- | Is this value an integer zero?  Used to reject a constant zero divisor
-- before 'div' or 'rem' can raise a Haskell exception.  Non-integer values
-- are never considered zero.
isZeroDivisor :: OccValue -> Bool
isZeroDivisor (OccByte 0) = True
isZeroDivisor (OccUInt16 0) = True
isZeroDivisor (OccUInt32 0) = True
isZeroDivisor (OccUInt64 0) = True
isZeroDivisor (OccInt8 0) = True
isZeroDivisor (OccInt 0) = True
isZeroDivisor (OccInt16 0) = True
isZeroDivisor (OccInt32 0) = True
isZeroDivisor (OccInt64 0) = True
isZeroDivisor _ = False

-- | Evaluate a dyadic (binary) operator applied to two values.
--
-- Arithmetic and bitwise operators require both operands to have the same
-- integer type; shifts take an 'OccInt' shift count; the boolean operators
-- and @AFTER@ only accept the operand types matched below.  Division and
-- remainder by a constant zero are reported as evaluation errors rather than
-- being left to raise an arithmetic exception when the result is forced.
evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
-- FIXME These should check for overflow.
evalDyadic A.Add a b = evalDyadicOp (+) a b
evalDyadic A.Subtr a b = evalDyadicOp (-) a b
evalDyadic A.Mul a b = evalDyadicOp (*) a b
evalDyadic A.Div a b
    | isZeroDivisor b = throwError (Nothing, "division by zero")
    | otherwise = evalDyadicOp safeDiv a b
evalDyadic A.Rem a b
    | isZeroDivisor b = throwError (Nothing, "division by zero")
    | otherwise = evalDyadicOp safeRem a b
-- ... end FIXME
evalDyadic A.Plus a b = evalDyadicOp (+) a b
evalDyadic A.Minus a b = evalDyadicOp (-) a b
evalDyadic A.Times a b = evalDyadicOp (*) a b
evalDyadic A.BitAnd a b = evalDyadicOp (.&.) a b
evalDyadic A.BitOr a b = evalDyadicOp (.|.) a b
evalDyadic A.BitXor a b = evalDyadicOp xor a b
evalDyadic A.LeftShift a (OccInt b)
    = evalMonadicOp (\v -> shiftL v (fromIntegral b)) a
evalDyadic A.RightShift a (OccInt b)
    -- occam shifts are logical (no sign-extending) but Haskell only has the signed
    -- shift. So we use a custom shift
    = evalMonadicOp (\v -> logicalShiftR v (fromIntegral b)) a
evalDyadic A.And (OccBool a) (OccBool b) = return $ OccBool (a && b)
evalDyadic A.Or (OccBool a) (OccBool b) = return $ OccBool (a || b)
evalDyadic A.Eq a b = evalCompareOp (==) a b
evalDyadic A.NotEq a b = evalCompareOp (/=) a b
evalDyadic A.Less a b = evalCompareOp (<) a b
evalDyadic A.More a b = evalCompareOp (>) a b
evalDyadic A.LessEq a b = evalCompareOp (<=) a b
evalDyadic A.MoreEq a b = evalCompareOp (>=) a b
-- AFTER compares via the (wrapping) difference of the two values.
evalDyadic A.After (OccInt a) (OccInt b) = return $ OccBool ((a - b) > 0)
evalDyadic op _ _ = throwError (Nothing, "bad dyadic op: " ++ show op)
|
|
--}}}
|
|
|
|
--{{{ rendering values
|
|
-- | Turn an evaluated value back into an expression, together with its type.
-- Booleans become the dedicated true/false expressions; every other value
-- becomes a literal built by 'renderLiteral'.
renderValue :: Meta -> OccValue -> (A.Type, A.Expression)
renderValue m v
    = case v of
        OccBool True -> (A.Bool, A.True m)
        OccBool False -> (A.Bool, A.False m)
        _ -> let (t, lr) = renderLiteral m v
             in (t, A.Literal m t lr)
|
|
|
|
-- | Build the literal representation of a value, together with its type.
--
-- Bytes are rendered as character literals, the other integer types as
-- decimal integer literals, arrays element by element, and records field by
-- field.
renderLiteral :: Meta -> OccValue -> (A.Type, A.LiteralRepr)
renderLiteral m v
    = case v of
        OccByte c -> (A.Byte, A.ByteLiteral m $ renderChar (chr $ fromIntegral c))
        OccUInt16 i -> intLit A.UInt16 i
        OccUInt32 i -> intLit A.UInt32 i
        OccUInt64 i -> intLit A.UInt64 i
        OccInt8 i -> intLit A.Int8 i
        OccInt i -> intLit A.Int i
        OccInt16 i -> intLit A.Int16 i
        OccInt32 i -> intLit A.Int32 i
        OccInt64 i -> intLit A.Int64 i
        OccArray vs ->
          let (ts, aes) = unzip $ map (renderLiteralArray m) vs
              -- NOTE(review): the element type comes from the first element,
              -- so this is partial on an empty array value.
              t = addDimensions [makeDimension m $ length vs] (head ts)
          in (t, A.ArrayLiteral m aes)
        OccRecord n vs ->
          (A.Record n, A.RecordLiteral m (map (snd . renderValue m) vs))
  where
    -- An integer literal of the given type, written in decimal.
    intLit :: Show a => A.Type -> a -> (A.Type, A.LiteralRepr)
    intLit t i = (t, A.IntLiteral m $ show i)
|
|
|
|
-- | Render a character as the body of an occam character literal.
--
-- Quotes, the asterisk, carriage return, newline and tab use their named
-- @*@ escapes.  Any other character outside the printable ASCII range
-- (below 32, or 127 and above -- 127 is DEL, a control character) is written
-- as a @*#@ hexadecimal escape with at least two digits.  Everything else is
-- emitted as itself.
renderChar :: Char -> String
renderChar '\'' = "*'"
renderChar '\"' = "*\""
renderChar '*' = "**"
renderChar '\r' = "*c"
renderChar '\n' = "*n"
renderChar '\t' = "*t"
renderChar c
    | o < 32 || o >= 127 = printf "*#%02x" o
    | otherwise = [c]
  where
    o = ord c
|
|
|
|
-- | Render one element of an array literal, together with its type.  Nested
-- array values become nested array elements; any other value becomes an
-- expression element built by 'renderValue'.
renderLiteralArray :: Meta -> OccValue -> (A.Type, A.ArrayElem)
renderLiteralArray m v
    = case v of
        OccArray vs ->
          let (ts, aes) = unzip $ map (renderLiteralArray m) vs
              -- NOTE(review): the element type comes from the first element,
              -- so this is partial on an empty array value.
              t = addDimensions [makeDimension m $ length vs] (head ts)
          in (t, A.ArrayElemArray aes)
        _ ->
          let (t, e) = renderValue m v
          in (t, A.ArrayElemExpr e)
|
|
--}}}
|