
This touches an awful lot of code, but cgtest07/17 (arrays and retyping) pass. This is useful because there are going to be places in the future where we'll want to represent dimensions that are known at runtime but not at compile time -- for example, mobile allocations, or dynamically-sized arrays. It also simplifies the code in a number of places. However, we now need to be careful that expressions containing variables do not leak into the State, since they won't be affected by later passes.

Two caveats (marked as FIXMEs in the source):
- Retypes checking in the occam parser is disabled, since the plan is to move it out to a pass anyway.
- There's some (now very obvious) duplication, particularly in the backend, of code that constructs expressions for the total size of an array (either in bytes or in elements); this should be moved into a couple of helper functions that everything can use.
303 lines
14 KiB
Haskell
303 lines
14 KiB
Haskell
{-
|
|
Tock: a compiler for parallel languages
|
|
Copyright (C) 2007 University of Kent
|
|
|
|
This program is free software; you can redistribute it and/or modify it
|
|
under the terms of the GNU General Public License as published by the
|
|
Free Software Foundation, either version 2 of the License, or (at your
|
|
option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful, but
|
|
WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License along
|
|
with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
-}
|
|
|
|
-- | Evaluate constant expressions.
|
|
module EvalConstants (constantFold, isConstantName) where
|
|
|
|
import Control.Monad.Error
|
|
import Control.Monad.State
|
|
import Data.Bits
|
|
import Data.Char
|
|
import Data.Int
|
|
import Data.Maybe
|
|
import Text.Printf
|
|
|
|
import qualified AST as A
|
|
import CompState hiding (CSM) -- everything here is read-only
|
|
import Errors
|
|
import EvalLiterals
|
|
import Metadata
|
|
import ShowCode
|
|
import Types
|
|
|
|
-- | Simplify an expression by constant folding, and also report whether the
-- result is a constant.
--
-- If evaluation fails, the original expression comes back unchanged along
-- with the evaluator's error. On success, the folded expression comes back
-- with the placeholder report @(Nothing, "already folded")@.
constantFold :: CSMR m => A.Expression -> m (A.Expression, Bool, ErrorReport)
constantFold e
  = do ps <- getCompState
       let (folded, report) = either keepOriginal useFolded (simplifyExpression ps e)
           keepOriginal err = (e, err)
           useFolded val = (val, (Nothing, "already folded"))
       return (folded, isConstant folded, report)
|
|
|
|
-- | Is a name defined as a constant expression? If so, return its definition.
--
-- Only a 'A.ValAbbrev' 'A.IsExpr' whose defining expression satisfies
-- 'isConstant' counts; every other specification yields 'Nothing'.
getConstantName :: (CSMR m, Die m) => A.Name -> m (Maybe A.Expression)
getConstantName n
  = do st <- specTypeOfName n
       return $ case st of
         A.IsExpr _ A.ValAbbrev _ e | isConstant e -> Just e
         _ -> Nothing
|
|
|
|
-- | Is a name defined as a constant expression?
-- (True exactly when 'getConstantName' finds a definition.)
isConstantName :: (CSMR m, Die m) => A.Name -> m Bool
isConstantName n = liftM isJust (getConstantName n)
|
|
|
|
-- | Attempt to simplify an expression as far as possible by precomputing
-- constant bits.
--
-- Returns the evaluator's error if the expression cannot be evaluated.
-- Otherwise returns the computed value rendered back into an expression,
-- tagged with the original expression's metadata.
simplifyExpression :: CompState -> A.Expression -> Either ErrorReport A.Expression
simplifyExpression ps e
  = either Left (Right . snd . renderValue (findMeta e)) (runEvaluator ps (evalExpression e))
|
|
|
|
--{{{ expression evaluator
|
|
-- | Evaluate a literal expression.
--
-- Array literals are evaluated element by element (an empty one is an
-- error), record literals field by field, and anything else is handed to
-- 'evalSimpleLiteral'.
evalLiteral :: A.Expression -> EvalM OccValue
evalLiteral (A.Literal m _ (A.ArrayLiteral _ aes))
  | null aes = throwError (Just m, "empty array")
  | otherwise = liftM OccArray (mapM evalLiteralArray aes)
evalLiteral (A.Literal _ (A.Record n) (A.RecordLiteral _ es))
  = liftM (OccRecord n) (mapM evalExpression es)
evalLiteral l = evalSimpleLiteral l
|
|
|
|
-- | Evaluate one element of an array literal. The element is either a nested
-- array, whose elements are evaluated recursively, or an ordinary expression.
--
-- A nested empty array is rejected, just as 'evalLiteral' rejects an empty
-- top-level array. Rendering the folded value back into a literal
-- ('renderLiteralArray') takes the element type from the head of the
-- element list. An empty 'OccArray' would therefore crash the compiler
-- instead of simply leaving the expression unfolded.
evalLiteralArray :: A.ArrayElem -> EvalM OccValue
evalLiteralArray (A.ArrayElemArray [])
  = throwError (Nothing, "empty array")
evalLiteralArray (A.ArrayElemArray aes) = liftM OccArray (mapM evalLiteralArray aes)
evalLiteralArray (A.ArrayElemExpr e) = evalExpression e
|
|
|
|
-- | Evaluate a variable reference.
--
-- A plain name is evaluated through its constant definition (see
-- 'getConstantName'). A subscripted variable applies the subscript to the
-- evaluated base. A directed variable evaluates to the value of the
-- underlying variable.
evalVariable :: A.Variable -> EvalM OccValue
evalVariable (A.Variable m n)
  = getConstantName n >>= maybe nonConstant evalExpression
  where
    nonConstant = throwError (Just m, "non-constant variable " ++ show n ++ " used")
evalVariable (A.SubscriptedVariable _ sub v) = evalVariable v >>= evalSubscript sub
evalVariable (A.DirectedVariable _ _ v) = evalVariable v
|
|
|
|
-- | Evaluate an expression used as an array index. It must evaluate to an
-- INT value.
evalIndex :: A.Expression -> EvalM Int
evalIndex e
  = evalExpression e >>= \index ->
      case index of
        OccInt n -> return (fromIntegral n)
        _ -> throwError (Just $ findMeta e, "index has non-INT type")
|
|
|
|
-- | Apply a subscript to a constant array value. Only plain subscripts of
-- arrays are supported, and an out-of-range index is an error.
--
-- TODO should we obey the no-checking here, or not?
-- If it's not in bounds, we can't constant fold it, so no-checking would
-- preclude constant folding...
evalSubscript :: A.Subscript -> OccValue -> EvalM OccValue
evalSubscript (A.Subscript m _ e) (OccArray vs)
  = do index <- evalIndex e
       -- 'drop' returns the whole list for a negative index, hence the guard.
       case drop index vs of
         (v:_) | index >= 0 -> return v
         _ -> throwError (Just m, "subscript out of range")
evalSubscript s _ = throwError (Just $ findMeta s, "invalid subscript")
|
|
|
|
-- | Evaluate an expression to a constant 'OccValue'. Throws an error in
-- 'EvalM' if some part of it cannot be computed at compile time.
evalExpression :: A.Expression -> EvalM OccValue
evalExpression (A.Monadic _ op e) = evalExpression e >>= evalMonadic op
evalExpression (A.Dyadic _ op e1 e2)
  = do lhs <- evalExpression e1
       rhs <- evalExpression e2
       evalDyadic op lhs rhs
-- MOSTPOS: the 'maxBound' of the value's representation type.
evalExpression (A.MostPos _ A.Byte) = return (OccByte maxBound)
evalExpression (A.MostPos _ A.UInt16) = return (OccUInt16 maxBound)
evalExpression (A.MostPos _ A.UInt32) = return (OccUInt32 maxBound)
evalExpression (A.MostPos _ A.UInt64) = return (OccUInt64 maxBound)
evalExpression (A.MostPos _ A.Int8) = return (OccInt8 maxBound)
evalExpression (A.MostPos _ A.Int) = return (OccInt maxBound)
evalExpression (A.MostPos _ A.Int16) = return (OccInt16 maxBound)
evalExpression (A.MostPos _ A.Int32) = return (OccInt32 maxBound)
evalExpression (A.MostPos _ A.Int64) = return (OccInt64 maxBound)
-- MOSTNEG: the 'minBound' of the value's representation type.
evalExpression (A.MostNeg _ A.Byte) = return (OccByte minBound)
evalExpression (A.MostNeg _ A.UInt16) = return (OccUInt16 minBound)
evalExpression (A.MostNeg _ A.UInt32) = return (OccUInt32 minBound)
evalExpression (A.MostNeg _ A.UInt64) = return (OccUInt64 minBound)
evalExpression (A.MostNeg _ A.Int8) = return (OccInt8 minBound)
evalExpression (A.MostNeg _ A.Int) = return (OccInt minBound)
evalExpression (A.MostNeg _ A.Int16) = return (OccInt16 minBound)
evalExpression (A.MostNeg _ A.Int32) = return (OccInt32 minBound)
evalExpression (A.MostNeg _ A.Int64) = return (OccInt64 minBound)
-- SIZE of an expression. If its type records a first dimension, evaluate
-- that dimension. Otherwise evaluate the expression itself and count the
-- elements of the resulting array.
evalExpression (A.SizeExpr m e)
  = do t <- typeOfExpression e >>= underlyingType m
       case t of
         A.Array (A.Dimension n : _) _ -> evalExpression n
         _ -> evalExpression e >>= \v ->
                case v of
                  OccArray vs -> return (OccInt (fromIntegral (length vs)))
                  _ -> throwError (Just m, "size of non-constant expression " ++ show e ++ " used")
-- SIZE of a variable. Only possible when its type records a first dimension.
evalExpression (A.SizeVariable m v)
  = do t <- typeOfVariable v >>= underlyingType m
       case t of
         A.Array (A.Dimension n : _) _ -> evalExpression n
         _ -> throwError (Just m, "size of non-fixed-size variable " ++ show v ++ " used")
evalExpression e@(A.Literal {}) = evalLiteral e
evalExpression (A.ExprVariable _ v) = evalVariable v
evalExpression (A.True _) = return (OccBool True)
evalExpression (A.False _) = return (OccBool False)
evalExpression (A.SubscriptedExpr _ sub e) = evalExpression e >>= evalSubscript sub
-- BYTESIN is only foldable when 'bytesInType' yields a definite size
-- expression ('BIJust'); that expression is then evaluated in turn.
evalExpression (A.BytesInExpr m e)
  = do size <- typeOfExpression e >>= underlyingType m >>= bytesInType
       case size of
         BIJust n -> evalExpression n
         _ -> throwError (Just m, "BYTESIN non-constant-size expression " ++ show e ++ " used")
evalExpression (A.BytesInType m t)
  = do size <- underlyingType m t >>= bytesInType
       case size of
         BIJust n -> evalExpression n
         _ -> throwErrorC (Just m, formatCode "BYTESIN non-constant-size type % used" t)
evalExpression e = throwError (Just $ findMeta e, "bad expression")
|
|
|
|
-- | Apply an integer function that works at every width to a single integer
-- value, keeping the value's type. Any non-integer value is rejected.
evalMonadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t) -> OccValue -> EvalM OccValue
evalMonadicOp f v
  = case v of
      OccByte a -> return (OccByte (f a))
      OccUInt16 a -> return (OccUInt16 (f a))
      OccUInt32 a -> return (OccUInt32 (f a))
      OccUInt64 a -> return (OccUInt64 (f a))
      OccInt8 a -> return (OccInt8 (f a))
      OccInt a -> return (OccInt (f a))
      OccInt16 a -> return (OccInt16 (f a))
      OccInt32 a -> return (OccInt32 (f a))
      OccInt64 a -> return (OccInt64 (f a))
      _ -> throwError (Nothing, "monadic operator not implemented for this type: " ++ show v)
|
|
|
|
-- | Apply a monadic (unary) operator to a constant value.
--
-- Negation is probably the most important rule here: "-4" isn't a literal
-- in occam, it's an operator applied to a literal.
evalMonadic :: A.MonadicOp -> OccValue -> EvalM OccValue
evalMonadic op a
  = case (op, a) of
      (A.MonadicSubtr, _) -> evalMonadicOp negate a
      (A.MonadicMinus, _) -> evalMonadicOp negate a
      (A.MonadicBitNot, _) -> evalMonadicOp complement a
      (A.MonadicNot, OccBool b) -> return (OccBool (not b))
      _ -> throwError (Nothing, "bad monadic op: " ++ show op)
|
|
|
|
-- | Apply an integer function that works at every width to two integer
-- values. Both operands must have the same type, which the result keeps.
-- Mismatched or non-integer operands are rejected.
evalDyadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalDyadicOp f v0 v1
  = case (v0, v1) of
      (OccByte a, OccByte b) -> return (OccByte (f a b))
      (OccUInt16 a, OccUInt16 b) -> return (OccUInt16 (f a b))
      (OccUInt32 a, OccUInt32 b) -> return (OccUInt32 (f a b))
      (OccUInt64 a, OccUInt64 b) -> return (OccUInt64 (f a b))
      (OccInt8 a, OccInt8 b) -> return (OccInt8 (f a b))
      (OccInt a, OccInt b) -> return (OccInt (f a b))
      (OccInt16 a, OccInt16 b) -> return (OccInt16 (f a b))
      (OccInt32 a, OccInt32 b) -> return (OccInt32 (f a b))
      (OccInt64 a, OccInt64 b) -> return (OccInt64 (f a b))
      _ -> throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
|
|
|
|
-- | Compare two integer values of the same type, producing a BOOL value.
-- Mismatched or non-integer operands are rejected.
evalCompareOp :: (forall t. (Eq t, Ord t) => t -> t -> Bool) -> OccValue -> OccValue -> EvalM OccValue
evalCompareOp f v0 v1
  = case (v0, v1) of
      (OccByte a, OccByte b) -> return (OccBool (f a b))
      (OccUInt16 a, OccUInt16 b) -> return (OccBool (f a b))
      (OccUInt32 a, OccUInt32 b) -> return (OccBool (f a b))
      (OccUInt64 a, OccUInt64 b) -> return (OccBool (f a b))
      (OccInt8 a, OccInt8 b) -> return (OccBool (f a b))
      (OccInt a, OccInt b) -> return (OccBool (f a b))
      (OccInt16 a, OccInt16 b) -> return (OccBool (f a b))
      (OccInt32 a, OccInt32 b) -> return (OccBool (f a b))
      (OccInt64 a, OccInt64 b) -> return (OccBool (f a b))
      _ -> throwError (Nothing, "comparison operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
|
|
|
|
-- | Logical (zero-filling) right shift by @n@ bits.
--
-- The idea: clear the lowest @n@ bits, then rotate right by @n@. The cleared
-- bits wrap round to the top, so the vacated positions are filled with
-- zeroes rather than copies of the sign bit.
logicalShiftR :: Bits a => a -> Int -> a
logicalShiftR val n
  | n == 0 = val
  | otherwise = rotateR (foldr (flip clearBit) val [0 .. n - 1]) n
|
|
|
|
-- | Apply a dyadic (binary) operator to two constant values.
--
-- Arithmetic and bitwise operators go through 'evalDyadicOp' (operands must
-- share an integer type). Shifts take an INT shift count. Comparisons go
-- through 'evalCompareOp' and yield BOOL.
evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
-- FIXME These should check for overflow.
evalDyadic A.Add a b = evalDyadicOp (+) a b
evalDyadic A.Subtr a b = evalDyadicOp (-) a b
evalDyadic A.Mul a b = evalDyadicOp (*) a b
-- occam integer division truncates towards zero, and REM takes the sign of
-- the dividend: that is Haskell's quot/rem pair. div floors instead, so
-- -7 / 2 would fold to -4 rather than -3.
-- A zero divisor is reported as an evaluation error so the expression is
-- left unfolded. Passing it to quot/rem would raise a Haskell exception when
-- the result is forced, crashing the compiler.
-- NOTE(review): MOSTNEG x / (-1) may still raise an overflow exception in
-- quot -- part of the overflow FIXME.
evalDyadic A.Div a b
  | isZeroValue b = throwError (Nothing, "division by zero")
  | otherwise = evalDyadicOp quot a b
evalDyadic A.Rem a b
  | isZeroValue b = throwError (Nothing, "division by zero")
  | otherwise = evalDyadicOp rem a b
-- ... end FIXME
evalDyadic A.Plus a b = evalDyadicOp (+) a b
evalDyadic A.Minus a b = evalDyadicOp (-) a b
evalDyadic A.Times a b = evalDyadicOp (*) a b
evalDyadic A.BitAnd a b = evalDyadicOp (.&.) a b
evalDyadic A.BitOr a b = evalDyadicOp (.|.) a b
evalDyadic A.BitXor a b = evalDyadicOp xor a b
evalDyadic A.LeftShift a (OccInt b)
  = evalMonadicOp (\v -> shiftL v (fromIntegral b)) a
evalDyadic A.RightShift a (OccInt b)
  -- occam shifts are logical (no sign-extending) but Haskell only has the
  -- signed shift. So we use a custom shift.
  = evalMonadicOp (\v -> logicalShiftR v (fromIntegral b)) a
evalDyadic A.And (OccBool a) (OccBool b) = return $ OccBool (a && b)
evalDyadic A.Or (OccBool a) (OccBool b) = return $ OccBool (a || b)
evalDyadic A.Eq a b = evalCompareOp (==) a b
evalDyadic A.NotEq a b = evalCompareOp (/=) a b
evalDyadic A.Less a b = evalCompareOp (<) a b
evalDyadic A.More a b = evalCompareOp (>) a b
evalDyadic A.LessEq a b = evalCompareOp (<=) a b
evalDyadic A.MoreEq a b = evalCompareOp (>=) a b
evalDyadic A.After (OccInt a) (OccInt b) = return $ OccBool ((a - b) > 0)
evalDyadic op _ _ = throwError (Nothing, "bad dyadic op: " ++ show op)

-- | Is this value an integer zero (of any width)? Used to refuse to fold a
-- division or remainder by zero.
isZeroValue :: OccValue -> Bool
isZeroValue v
  = case v of
      OccByte 0 -> True
      OccUInt16 0 -> True
      OccUInt32 0 -> True
      OccUInt64 0 -> True
      OccInt8 0 -> True
      OccInt 0 -> True
      OccInt16 0 -> True
      OccInt32 0 -> True
      OccInt64 0 -> True
      _ -> False
|
|
--}}}
|
|
|
|
--{{{ rendering values
|
|
-- | Convert a value back into an expression, together with its type.
-- Booleans become 'A.True' / 'A.False'. Everything else becomes an
-- 'A.Literal' built by 'renderLiteral'.
renderValue :: Meta -> OccValue -> (A.Type, A.Expression)
renderValue m v
  = case v of
      OccBool True -> (A.Bool, A.True m)
      OccBool False -> (A.Bool, A.False m)
      _ -> let (t, lr) = renderLiteral m v
           in (t, A.Literal m t lr)
|
|
|
|
-- | Render a non-boolean value as a literal representation plus its type.
--
-- BYTEs become (escaped) byte literals and other integers decimal integer
-- literals. Arrays get a type with one extra dimension whose element type is
-- taken from the first element. Records render each field via 'renderValue'.
renderLiteral :: Meta -> OccValue -> (A.Type, A.LiteralRepr)
renderLiteral m v
  = case v of
      OccByte c -> (A.Byte, A.ByteLiteral m (renderChar (chr (fromIntegral c))))
      OccUInt16 i -> (A.UInt16, A.IntLiteral m (show i))
      OccUInt32 i -> (A.UInt32, A.IntLiteral m (show i))
      OccUInt64 i -> (A.UInt64, A.IntLiteral m (show i))
      OccInt8 i -> (A.Int8, A.IntLiteral m (show i))
      OccInt i -> (A.Int, A.IntLiteral m (show i))
      OccInt16 i -> (A.Int16, A.IntLiteral m (show i))
      OccInt32 i -> (A.Int32, A.IntLiteral m (show i))
      OccInt64 i -> (A.Int64, A.IntLiteral m (show i))
      OccArray vs ->
        let (elemTypes, elems) = unzip (map (renderLiteralArray m) vs)
            arrayType = addDimensions [makeDimension m (length vs)] (head elemTypes)
        in (arrayType, A.ArrayLiteral m elems)
      OccRecord n vs -> (A.Record n, A.RecordLiteral m (map (snd . renderValue m) vs))
|
|
|
|
-- | Render a character in occam literal syntax. Quotes, @*@, CR, LF and TAB
-- use their @*@-escapes. Other characters outside 32..127 become @*#hh@
-- hex escapes, and the rest are emitted as-is.
renderChar :: Char -> String
renderChar c
  = case lookup c escapes of
      Just s -> s
      Nothing
        | code < 32 || code > 127 -> printf "*#%02x" code
        | otherwise -> [c]
  where
    code = ord c
    escapes = [ ('\'', "*'")
              , ('\"', "*\"")
              , ('*', "**")
              , ('\r', "*c")
              , ('\n', "*n")
              , ('\t', "*t")
              ]
|
|
|
|
-- | Render one element of an array literal, together with its type. Nested
-- arrays recurse and add a dimension, with the element type taken from the
-- first element. Anything else is rendered with 'renderValue' and wrapped as
-- an element expression.
renderLiteralArray :: Meta -> OccValue -> (A.Type, A.ArrayElem)
renderLiteralArray m v
  = case v of
      OccArray vs ->
        let (elemTypes, elems) = unzip (map (renderLiteralArray m) vs)
            arrayType = addDimensions [makeDimension m (length vs)] (head elemTypes)
        in (arrayType, A.ArrayElemArray elems)
      _ ->
        let (t, e) = renderValue m v
        in (t, A.ArrayElemExpr e)
|
|
--}}}
|