
The compiler itself is under the GPLv2+; the support code that gets built into user programs is under the LGPLv2+. This matches the existing practice for the KRoC project. (As with Occade, I've used the new GPLv3-style license header in the source files, though, since that avoids having to update the FSF's postal address.)
271 lines
11 KiB
Haskell
271 lines
11 KiB
Haskell
{-
|
|
Tock: a compiler for parallel languages
|
|
Copyright (C) 2007 University of Kent
|
|
|
|
This program is free software; you can redistribute it and/or modify it
|
|
under the terms of the GNU General Public License as published by the
|
|
Free Software Foundation, either version 2 of the License, or (at your
|
|
option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful, but
|
|
WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License along
|
|
with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
-}
|
|
|
|
-- | Evaluate constant expressions.
|
|
module EvalConstants (constantFold, isConstantName) where
|
|
|
|
import Control.Monad.Error
|
|
import Control.Monad.Identity
|
|
import Control.Monad.State
|
|
import Data.Bits
|
|
import Data.Char
|
|
import Data.Generics
|
|
import Data.Int
|
|
import Data.Maybe
|
|
import Data.Word
|
|
import Numeric
|
|
import Text.Printf
|
|
|
|
import qualified AST as A
|
|
import CompState
|
|
import Errors
|
|
import EvalLiterals
|
|
import Metadata
|
|
import Pass
|
|
import Types
|
|
|
|
-- | Constant-fold an expression as far as the evaluator allows.
--
-- The result triple is:
--
-- * the folded expression, or the original expression unchanged if
--   evaluation failed;
--
-- * whether that resulting expression satisfies 'isConstant';
--
-- * the evaluator's error message on failure, or @"already folded"@ on
--   success.
constantFold :: CSM m => A.Expression -> m (A.Expression, Bool, String)
constantFold e = do
  ps <- get
  let (folded, msg) = either (\err -> (e, err))
                             (\val -> (val, "already folded"))
                             (simplifyExpression ps e)
  return (folded, isConstant folded, msg)
|
|
|
|
-- | Look up the specification of a name. If it is an 'A.IsExpr' with mode
-- 'A.ValAbbrev' whose expression satisfies 'isConstant', return that
-- expression; otherwise return 'Nothing'.
getConstantName :: (CSM m, Die m) => A.Name -> m (Maybe A.Expression)
getConstantName n = liftM constantDefinition (specTypeOfName n)
  where
    -- A failing guard falls through to the catch-all clause.
    constantDefinition (A.IsExpr _ A.ValAbbrev _ e)
      | isConstant e = Just e
    constantDefinition _ = Nothing
|
|
|
|
-- | True exactly when 'getConstantName' finds a constant definition for the
-- name.
isConstantName :: (CSM m, Die m) => A.Name -> m Bool
isConstantName n = liftM isJust (getConstantName n)
|
|
|
|
-- | Evaluate an expression in the given compiler state and, on success,
-- render the resulting value back into an expression. The rendered
-- expression carries the original expression's metadata.
--
-- Evaluation errors are passed through unchanged as 'Left'.
simplifyExpression :: CompState -> A.Expression -> Either String A.Expression
simplifyExpression ps e
  = fmap (snd . renderValue (findMeta e)) (runEvaluator ps (evalExpression e))
|
|
|
|
--{{{ expression evaluator
|
|
-- | Evaluate a literal expression.
--
-- * Array literals are evaluated element-wise; an empty array is an error.
--
-- * Record literals evaluate each field expression.
--
-- * Everything else is handed to 'evalSimpleLiteral'.
evalLiteral :: A.Expression -> EvalM OccValue
evalLiteral (A.Literal _ _ (A.ArrayLiteral _ elems))
  | null elems = throwError "empty array"
  | otherwise = liftM OccArray (mapM evalLiteralArray elems)
evalLiteral (A.Literal _ (A.Record n) (A.RecordLiteral _ fields))
  = liftM (OccRecord n) (mapM evalExpression fields)
evalLiteral other = evalSimpleLiteral other
|
|
|
|
-- | Evaluate one element of an array literal.
--
-- A nested row evaluates to an 'OccArray' of its elements; a plain
-- expression is evaluated normally.
--
-- An empty nested row is rejected with the same "empty array" error that
-- 'evalLiteral' uses for an empty top-level literal. Without this check, the
-- empty 'OccArray' would reach the renderers. They take the array's type from
-- @head@ of the element types and would crash.
evalLiteralArray :: A.ArrayElem -> EvalM OccValue
evalLiteralArray (A.ArrayElemArray []) = throwError "empty array"
evalLiteralArray (A.ArrayElemArray aes) = liftM OccArray (mapM evalLiteralArray aes)
evalLiteralArray (A.ArrayElemExpr e) = evalExpression e
|
|
|
|
-- | Evaluate a variable reference.
--
-- A plain name must have a constant definition (see 'getConstantName'),
-- which is then evaluated. A subscripted variable is evaluated and then
-- indexed.
evalVariable :: A.Variable -> EvalM OccValue
evalVariable (A.Variable _ n)
  = getConstantName n
      >>= maybe (throwError $ "non-constant variable " ++ show n ++ " used")
                evalExpression
evalVariable (A.SubscriptedVariable _ sub v)
  = evalVariable v >>= evalSubscript sub
|
|
|
|
-- | Evaluate an expression that is used as an array index. Only an
-- 'OccInt' result is accepted.
evalIndex :: A.Expression -> EvalM Int
evalIndex e = evalExpression e >>= asIndex
  where
    asIndex :: OccValue -> EvalM Int
    asIndex (OccInt i) = return (fromIntegral i)
    asIndex _ = throwError $ "index has non-INT type"
|
|
|
|
-- | Apply a simple 'A.Subscript' to a constant array, with bounds checking.
-- Any other combination of subscript and value is reported as an invalid
-- subscript.
evalSubscript :: A.Subscript -> OccValue -> EvalM OccValue
evalSubscript (A.Subscript _ e) (OccArray vs)
  = do i <- evalIndex e
       -- A negative index would make 'drop' return the whole list, so it is
       -- excluded explicitly; an index past the end leaves nothing to take.
       case drop i vs of
         (v:_) | i >= 0 -> return v
         _ -> throwError $ "subscript out of range"
evalSubscript _ _ = throwError $ "invalid subscript"
|
|
|
|
-- | Evaluate an expression to a constant value.
--
-- Anything that cannot be computed here is reported through 'throwError'.
-- That includes non-constant variables, sizes or BYTESIN of things without a
-- fixed size, and unsupported expression forms.
evalExpression :: A.Expression -> EvalM OccValue
evalExpression expr = case expr of
  A.Monadic _ op e -> evalExpression e >>= evalMonadic op
  A.Dyadic _ op lhs rhs ->
    do l <- evalExpression lhs
       r <- evalExpression rhs
       evalDyadic op l r
  A.MostPos _ t -> mostPos t
  A.MostNeg _ t -> mostNeg t
  A.SizeExpr _ e ->
    -- Prefer a statically-known first dimension; otherwise count the
    -- elements of the evaluated array.
    do t <- typeOfExpression e >>= underlyingType
       case t of
         A.Array (A.Dimension n:_) _ -> return $ OccInt (fromIntegral n)
         _ ->
           do v <- evalExpression e
              case v of
                OccArray vs -> return $ OccInt (fromIntegral $ length vs)
                _ -> throwError $ "size of non-constant expression " ++ show e ++ " used"
  A.SizeVariable _ v ->
    do t <- typeOfVariable v >>= underlyingType
       case t of
         A.Array (A.Dimension n:_) _ -> return $ OccInt (fromIntegral n)
         _ -> throwError $ "size of non-fixed-size variable " ++ show v ++ " used"
  A.Literal _ _ _ -> evalLiteral expr
  A.ExprVariable _ v -> evalVariable v
  A.True _ -> return $ OccBool True
  A.False _ -> return $ OccBool False
  A.SubscriptedExpr _ sub e -> evalExpression e >>= evalSubscript sub
  A.BytesInExpr _ e ->
    do b <- typeOfExpression e >>= underlyingType >>= bytesInType
       case b of
         BIJust n -> return $ OccInt (fromIntegral $ n)
         _ -> throwError $ "BYTESIN non-constant-size expression " ++ show e ++ " used"
  A.BytesInType _ t ->
    do b <- underlyingType t >>= bytesInType
       case b of
         BIJust n -> return $ OccInt (fromIntegral $ n)
         _ -> throwError $ "BYTESIN non-constant-size type " ++ show t ++ " used"
  _ -> badExpression
  where
    -- 'A.MostPos' / 'A.MostNeg' are only handled for the integer types;
    -- any other type is treated like any other unsupported expression.
    mostPos :: A.Type -> EvalM OccValue
    mostPos A.Byte = return $ OccByte maxBound
    mostPos A.Int = return $ OccInt maxBound
    mostPos A.Int16 = return $ OccInt16 maxBound
    mostPos A.Int32 = return $ OccInt32 maxBound
    mostPos A.Int64 = return $ OccInt64 maxBound
    mostPos _ = badExpression

    mostNeg :: A.Type -> EvalM OccValue
    mostNeg A.Byte = return $ OccByte minBound
    mostNeg A.Int = return $ OccInt minBound
    mostNeg A.Int16 = return $ OccInt16 minBound
    mostNeg A.Int32 = return $ OccInt32 minBound
    mostNeg A.Int64 = return $ OccInt64 minBound
    mostNeg _ = badExpression

    badExpression :: EvalM OccValue
    badExpression = throwError "bad expression"
|
|
|
|
-- | Apply a type-generic integer operation to a single integer value,
-- keeping the value's type. Non-integer values are an error.
evalMonadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t) -> OccValue -> EvalM OccValue
evalMonadicOp op val = case val of
  OccByte x -> return $ OccByte (op x)
  OccInt x -> return $ OccInt (op x)
  OccInt16 x -> return $ OccInt16 (op x)
  OccInt32 x -> return $ OccInt32 (op x)
  OccInt64 x -> return $ OccInt64 (op x)
  _ -> throwError "monadic operator not implemented for this type"
|
|
|
|
-- | Apply a monadic (unary) operator to a constant value.
--
-- Negation is probably the most important case: in occam, "-4" is not a
-- literal but the negation operator applied to the literal 4.
evalMonadic :: A.MonadicOp -> OccValue -> EvalM OccValue
evalMonadic op v = case (op, v) of
  (A.MonadicSubtr, _) -> evalMonadicOp negate v
  (A.MonadicBitNot, _) -> evalMonadicOp complement v
  (A.MonadicNot, OccBool b) -> return $ OccBool (not b)
  _ -> throwError "bad monadic op"
|
|
|
|
-- | Apply a type-generic binary integer operation to two values of the same
-- integer type, keeping that type. The result is whatever the operation
-- yields at that fixed-size type.
-- Mismatched or non-integer operands are an error.
evalDyadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalDyadicOp op lhs rhs = case (lhs, rhs) of
  (OccByte x, OccByte y) -> return $ OccByte (op x y)
  (OccInt x, OccInt y) -> return $ OccInt (op x y)
  (OccInt16 x, OccInt16 y) -> return $ OccInt16 (op x y)
  (OccInt32 x, OccInt32 y) -> return $ OccInt32 (op x y)
  (OccInt64 x, OccInt64 y) -> return $ OccInt64 (op x y)
  _ -> throwError "dyadic operator not implemented for this type"
|
|
|
|
-- | Apply a comparison to two values of the same integer type, producing an
-- 'OccBool'. Mismatched or non-integer operands are an error.
evalCompareOp :: (forall t. (Eq t, Ord t) => t -> t -> Bool) -> OccValue -> OccValue -> EvalM OccValue
evalCompareOp cmp lhs rhs = case (lhs, rhs) of
  (OccByte x, OccByte y) -> yield (cmp x y)
  (OccInt x, OccInt y) -> yield (cmp x y)
  (OccInt16 x, OccInt16 y) -> yield (cmp x y)
  (OccInt32 x, OccInt32 y) -> yield (cmp x y)
  (OccInt64 x, OccInt64 y) -> yield (cmp x y)
  _ -> throwError "comparison operator not implemented for this type"
  where
    yield :: Bool -> EvalM OccValue
    yield = return . OccBool
|
|
|
|
-- | Apply a dyadic (binary) operator to two constant values.
--
-- Checked arithmetic ('A.Add', 'A.Subtr', 'A.Mul', 'A.Div', 'A.Rem'):
--
-- * The result is computed exactly over 'Integer'.
--
-- * It is rejected with an error if it does not fit in the operand type,
--   or if a divisor is zero. Otherwise a constant overflow would either wrap
--   silently or raise a Haskell exception inside the compiler.
--
-- * Division truncates towards zero ('quot'), so it agrees with the
--   truncating 'rem' used for 'A.Rem'.
--
-- Wrapping arithmetic ('A.Plus', 'A.Minus', 'A.Times'): computed directly at
-- the fixed-size type.
evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
evalDyadic A.Add a b = evalCheckedDyadicOp (+) a b
evalDyadic A.Subtr a b = evalCheckedDyadicOp (-) a b
evalDyadic A.Mul a b = evalCheckedDyadicOp (*) a b
-- The zero check runs first, so the Integer division below never sees a
-- zero divisor.
evalDyadic A.Div a b = requireNonZeroDivisor b >> evalCheckedDyadicOp quot a b
evalDyadic A.Rem a b = requireNonZeroDivisor b >> evalCheckedDyadicOp rem a b
evalDyadic A.Plus a b = evalDyadicOp (+) a b
evalDyadic A.Minus a b = evalDyadicOp (-) a b
evalDyadic A.Times a b = evalDyadicOp (*) a b
evalDyadic A.BitAnd a b = evalDyadicOp (.&.) a b
evalDyadic A.BitOr a b = evalDyadicOp (.|.) a b
evalDyadic A.BitXor a b = evalDyadicOp xor a b
-- A negative shift count is rejected rather than handed to 'shiftL'/'shiftR'.
evalDyadic A.LeftShift a (OccInt b)
  | b < 0 = throwError "negative shift count"
  | otherwise = evalMonadicOp (\v -> shiftL v (fromIntegral b)) a
evalDyadic A.RightShift a (OccInt b)
  | b < 0 = throwError "negative shift count"
  | otherwise = evalMonadicOp (\v -> shiftR v (fromIntegral b)) a
evalDyadic A.And (OccBool a) (OccBool b) = return $ OccBool (a && b)
evalDyadic A.Or (OccBool a) (OccBool b) = return $ OccBool (a || b)
-- BOOL equality is handled here because 'evalCompareOp' only accepts
-- integer operands.
evalDyadic A.Eq (OccBool a) (OccBool b) = return $ OccBool (a == b)
evalDyadic A.NotEq (OccBool a) (OccBool b) = return $ OccBool (a /= b)
evalDyadic A.Eq a b = evalCompareOp (==) a b
evalDyadic A.NotEq a b = evalCompareOp (/=) a b
evalDyadic A.Less a b = evalCompareOp (<) a b
evalDyadic A.More a b = evalCompareOp (>) a b
evalDyadic A.LessEq a b = evalCompareOp (<=) a b
evalDyadic A.MoreEq a b = evalCompareOp (>=) a b
evalDyadic A.After (OccInt a) (OccInt b) = return $ OccBool ((a - b) > 0)
evalDyadic _ _ _ = throwError "bad dyadic op"

-- | Apply an arithmetic operation exactly (at type 'Integer') to two values of
-- the same integer type, and narrow the result back to that type. Fails if
-- the exact result is not representable, or if the operands are mismatched
-- or non-integer.
evalCheckedDyadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalCheckedDyadicOp f (OccByte a) (OccByte b) = narrowChecked OccByte (f (toInteger a) (toInteger b))
evalCheckedDyadicOp f (OccInt a) (OccInt b) = narrowChecked OccInt (f (toInteger a) (toInteger b))
evalCheckedDyadicOp f (OccInt16 a) (OccInt16 b) = narrowChecked OccInt16 (f (toInteger a) (toInteger b))
evalCheckedDyadicOp f (OccInt32 a) (OccInt32 b) = narrowChecked OccInt32 (f (toInteger a) (toInteger b))
evalCheckedDyadicOp f (OccInt64 a) (OccInt64 b) = narrowChecked OccInt64 (f (toInteger a) (toInteger b))
evalCheckedDyadicOp _ _ _ = throwError "dyadic operator not implemented for this type"

-- | Wrap an exact result with the given constructor if it fits that
-- constructor's integer type; otherwise fail with an overflow error.
narrowChecked :: Integral t => (t -> OccValue) -> Integer -> EvalM OccValue
narrowChecked con exact
  = maybe (throwError "overflow in constant expression") (return . con)
          (integerToChecked exact)

-- | Convert an 'Integer' to a fixed-size integral type if the value survives
-- the round trip unchanged, i.e. it is within the type's range.
integerToChecked :: Integral t => Integer -> Maybe t
integerToChecked r = check (fromInteger r)
  where
    -- A function binding (not a pattern binding), so the candidate's type
    -- is always the caller's t, whatever the monomorphism settings.
    check v = if toInteger v == r then Just v else Nothing

-- | Fail with "division by zero" if the value is an integer zero; otherwise
-- do nothing.
requireNonZeroDivisor :: OccValue -> EvalM ()
requireNonZeroDivisor (OccByte 0) = throwError "division by zero"
requireNonZeroDivisor (OccInt 0) = throwError "division by zero"
requireNonZeroDivisor (OccInt16 0) = throwError "division by zero"
requireNonZeroDivisor (OccInt32 0) = throwError "division by zero"
requireNonZeroDivisor (OccInt64 0) = throwError "division by zero"
requireNonZeroDivisor _ = return ()
|
|
--}}}
|
|
|
|
--{{{ rendering values
|
|
-- | Convert a value back into an expression, together with its type.
--
-- BOOL values become 'A.True' / 'A.False'; everything else is rendered by
-- 'renderLiteral' and wrapped in 'A.Literal'.
renderValue :: Meta -> OccValue -> (A.Type, A.Expression)
renderValue m v = case v of
  OccBool True -> (A.Bool, A.True m)
  OccBool False -> (A.Bool, A.False m)
  _ -> let (ty, repr) = renderLiteral m v
       in (ty, A.Literal m ty repr)
|
|
|
|
-- | Render a non-BOOL value as a literal representation plus its type.
--
-- * Bytes are written as character literals (via 'renderChar').
--
-- * Integers are written with 'show'.
--
-- * Arrays take their type from the first element's type, so an empty
--   'OccArray' would fail when that type is demanded.
--   NOTE(review): callers must not pass empty arrays.
--
-- * Records render each field with 'renderValue'.
renderLiteral :: Meta -> OccValue -> (A.Type, A.LiteralRepr)
renderLiteral m v = case v of
  OccByte c -> (A.Byte, A.ByteLiteral m (renderChar (chr (fromIntegral c))))
  OccInt i -> (A.Int, A.IntLiteral m (show i))
  OccInt16 i -> (A.Int16, A.IntLiteral m (show i))
  OccInt32 i -> (A.Int32, A.IntLiteral m (show i))
  OccInt64 i -> (A.Int64, A.IntLiteral m (show i))
  OccArray vs ->
    let (elemTypes, elems) = unzip (map (renderLiteralArray m) vs)
        arrType = makeArrayType (A.Dimension $ length vs) (head elemTypes)
    in (arrType, A.ArrayLiteral m elems)
  OccRecord n vs -> (A.Record n, A.RecordLiteral m [snd (renderValue m fv) | fv <- vs])
|
|
|
|
-- | Render a character as it appears inside an occam character/string
-- literal.
--
-- * Quotes, @*@, CR, LF and TAB use their @*@ escapes.
--
-- * Any other character outside printable ASCII (code below 32, or 127 and
--   above) becomes a @*#xx@ hex escape.
--
-- * Everything else is emitted as-is.
renderChar :: Char -> String
renderChar '\'' = "*'"
renderChar '\"' = "*\""
renderChar '*' = "**"
renderChar '\r' = "*c"
renderChar '\n' = "*n"
renderChar '\t' = "*t"
renderChar c
  -- 127 (DEL) is a control character, not a printable one, so it must be
  -- escaped too (the bound was previously "> 127", letting DEL through).
  | o < 32 || o >= 127 = printf "*#%02x" o
  | otherwise = [c]
  where o = ord c
|
|
|
|
-- | Render one element of an array literal, together with its type.
--
-- A nested 'OccArray' becomes a nested row. Its type comes from its first
-- element, so an empty nested array would fail when that type is demanded.
-- Any other value is rendered with 'renderValue' and wrapped as a plain
-- element expression.
renderLiteralArray :: Meta -> OccValue -> (A.Type, A.ArrayElem)
renderLiteralArray m v = case v of
  OccArray vs ->
    let (elemTypes, elems) = unzip (map (renderLiteralArray m) vs)
        rowType = makeArrayType (A.Dimension $ length vs) (head elemTypes)
    in (rowType, A.ArrayElemArray elems)
  _ ->
    let (ty, expr) = renderValue m v
    in (ty, A.ArrayElemExpr expr)
|
|
--}}}
|