tock-mirror/common/EvalConstants.hs
Neil Brown c97d1d00c8 Changed the error type from String to ErrorReport throughout the code
ErrorReport is of type (Maybe Meta, String), thereby adding an optional code position to error messages.

Die has been changed so that die and dieP are now implemented in terms of dieReport (:: ErrorReport -> m a).  This involved changing less code than changing die to be of type ErrorReport -> m a.  All that had to be changed directly was that Die instances now implement dieReport instead of die.

Any bits of code that "caught" errors have been changed so that they handle ErrorReport instead of String.  This ErrorReport is eventually passed, in Main, to dieIO, which will soon be changed to read the file in and provide the context.  Accordingly, MonadIO m has been added as a constraint to dieIO, and dieInternal has been changed to no longer use dieIO (because we can't add the MonadIO constraint to dieInternal).

Various error messages have been changed.  Notably, all instances of fail in ParseOccam have been changed to use die or, wherever possible, dieP.  A similar thing has been done in EvalConstants and EvalLiterals.
2007-09-18 10:17:38 +00:00

298 lines
13 KiB
Haskell

{-
Tock: a compiler for parallel languages
Copyright (C) 2007 University of Kent
This program is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation, either version 2 of the License, or (at your
option) any later version.
This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program. If not, see <http://www.gnu.org/licenses/>.
-}
-- | Evaluate constant expressions.
module EvalConstants (constantFold, isConstantName) where
import Control.Monad.Error
import Control.Monad.Identity
import Control.Monad.State
import Data.Bits
import Data.Char
import Data.Generics
import Data.Int
import Data.Maybe
import Data.Word
import Numeric
import Text.Printf
import qualified AST as A
import CompState
import Errors
import EvalLiterals
import Metadata
import Pass
import ShowCode
import Types
-- | Simplify an expression by constant folding, and also return whether it's a
-- constant after that.
--
-- The third component explains the outcome: the evaluator's error report if
-- the expression could not be evaluated (in which case the expression is
-- returned unchanged), or @(Nothing, "already folded")@ if evaluation
-- succeeded.
constantFold :: CSM m => A.Expression -> m (A.Expression, Bool, ErrorReport)
constantFold e
  = do ps <- get
       let (folded, report) = either keepOriginal useFolded (simplifyExpression ps e)
           keepOriginal err = (e, err)
           useFolded val = (val, (Nothing, "already folded"))
       return (folded, isConstant folded, report)
-- | Is a name defined as a constant expression? If so, return its definition.
--
-- Only a 'A.ValAbbrev' abbreviation of an expression accepted by
-- 'isConstant' counts; every other specification yields 'Nothing'.
getConstantName :: (CSM m, Die m) => A.Name -> m (Maybe A.Expression)
getConstantName n
  = do st <- specTypeOfName n
       return $ case st of
         A.IsExpr _ A.ValAbbrev _ e | isConstant e -> Just e
         _ -> Nothing
-- | Is a name defined as a constant expression?
-- This is 'getConstantName' with the definition itself discarded.
isConstantName :: (CSM m, Die m) => A.Name -> m Bool
isConstantName n = liftM isJust (getConstantName n)
-- | Attempt to simplify an expression by evaluating it as a constant.
--
-- Evaluation is all-or-nothing: on success the resulting value is rendered
-- back into an expression carrying the original expression's metadata; on
-- failure the evaluator's error report is returned unchanged.
simplifyExpression :: CompState -> A.Expression -> Either ErrorReport A.Expression
simplifyExpression ps e
  = either Left (Right . backToExpression) (runEvaluator ps (evalExpression e))
  where
    backToExpression = snd . renderValue (findMeta e)
--{{{ expression evaluator
-- | Evaluate a literal expression to a value.
--
-- Empty array literals are rejected; non-empty array literals and record
-- literals are evaluated element by element; anything else is delegated to
-- 'evalSimpleLiteral'.
evalLiteral :: A.Expression -> EvalM OccValue
evalLiteral lit
  = case lit of
      A.Literal m _ (A.ArrayLiteral _ []) ->
        throwError (Just m, "empty array")
      A.Literal _ _ (A.ArrayLiteral _ elems) ->
        OccArray `liftM` mapM evalLiteralArray elems
      A.Literal _ (A.Record n) (A.RecordLiteral _ fields) ->
        OccRecord n `liftM` mapM evalExpression fields
      _ -> evalSimpleLiteral lit
-- | Evaluate one element of an array literal; nested array elements become
-- nested 'OccArray' values.
evalLiteralArray :: A.ArrayElem -> EvalM OccValue
evalLiteralArray ae
  = case ae of
      A.ArrayElemArray subElems -> OccArray `liftM` mapM evalLiteralArray subElems
      A.ArrayElemExpr e -> evalExpression e
-- | Evaluate a variable reference. A plain name must be defined as a constant
-- (see 'getConstantName'); subscripts are applied to the evaluated base, and
-- directions are ignored.
evalVariable :: A.Variable -> EvalM OccValue
evalVariable (A.Variable m n)
  = getConstantName n >>= maybe notConstant evalExpression
  where
    notConstant = throwError (Just m, "non-constant variable " ++ show n ++ " used")
evalVariable (A.SubscriptedVariable _ sub v)
  = do base <- evalVariable v
       evalSubscript sub base
evalVariable (A.DirectedVariable _ _ v)
  = evalVariable v
-- | Evaluate an expression used as an array index; it must evaluate to an
-- 'OccInt'.
evalIndex :: A.Expression -> EvalM Int
evalIndex e
  = evalExpression e >>= toIndex
  where
    toIndex (OccInt n) = return (fromIntegral n)
    toIndex _ = throwError (Just $ findMeta e, "index has non-INT type")
-- | Apply a subscript to a constant value. Only plain 'A.Subscript's of an
-- 'OccArray' are supported, and the index must lie within the array.
evalSubscript :: A.Subscript -> OccValue -> EvalM OccValue
evalSubscript (A.Subscript m e) (OccArray vs)
  = do index <- evalIndex e
       -- A negative index would make 'drop' return the whole list, hence the guard.
       case drop index vs of
         (v:_) | index >= 0 -> return v
         _ -> throwError (Just m, "subscript out of range")
evalSubscript s _ = throwError (Just $ findMeta s, "invalid subscript")
-- | Evaluate an expression as a compile-time constant.
--
-- Fails via 'throwError' (with the best available source position) for any
-- expression that is not constant, e.g. a reference to a non-constant
-- variable or the SIZE/BYTESIN of something whose size is not fixed.
evalExpression :: A.Expression -> EvalM OccValue
evalExpression lit@(A.Literal {}) = evalLiteral lit
evalExpression (A.ExprVariable _ v) = evalVariable v
evalExpression (A.True _) = return (OccBool True)
evalExpression (A.False _) = return (OccBool False)
evalExpression (A.Monadic _ op operand)
  = evalExpression operand >>= evalMonadic op
evalExpression (A.Dyadic _ op lhs rhs)
  = do l <- evalExpression lhs
       r <- evalExpression rhs
       evalDyadic op l r
-- MOSTPOS of each integer type is the corresponding Haskell maxBound ...
evalExpression (A.MostPos _ A.Byte) = return (OccByte maxBound)
evalExpression (A.MostPos _ A.UInt16) = return (OccUInt16 maxBound)
evalExpression (A.MostPos _ A.UInt32) = return (OccUInt32 maxBound)
evalExpression (A.MostPos _ A.UInt64) = return (OccUInt64 maxBound)
evalExpression (A.MostPos _ A.Int8) = return (OccInt8 maxBound)
evalExpression (A.MostPos _ A.Int) = return (OccInt maxBound)
evalExpression (A.MostPos _ A.Int16) = return (OccInt16 maxBound)
evalExpression (A.MostPos _ A.Int32) = return (OccInt32 maxBound)
evalExpression (A.MostPos _ A.Int64) = return (OccInt64 maxBound)
-- ... and MOSTNEG is the corresponding minBound.
evalExpression (A.MostNeg _ A.Byte) = return (OccByte minBound)
evalExpression (A.MostNeg _ A.UInt16) = return (OccUInt16 minBound)
evalExpression (A.MostNeg _ A.UInt32) = return (OccUInt32 minBound)
evalExpression (A.MostNeg _ A.UInt64) = return (OccUInt64 minBound)
evalExpression (A.MostNeg _ A.Int8) = return (OccInt8 minBound)
evalExpression (A.MostNeg _ A.Int) = return (OccInt minBound)
evalExpression (A.MostNeg _ A.Int16) = return (OccInt16 minBound)
evalExpression (A.MostNeg _ A.Int32) = return (OccInt32 minBound)
evalExpression (A.MostNeg _ A.Int64) = return (OccInt64 minBound)
evalExpression (A.SizeExpr m e)
  -- Prefer the statically-known outer dimension; otherwise fall back to
  -- evaluating the expression and counting the elements of the array.
  = do t <- typeOfExpression e >>= underlyingType
       case t of
         A.Array (A.Dimension n:_) _ -> return (OccInt (fromIntegral n))
         _ -> evalExpression e >>= sizeOfValue
  where
    sizeOfValue (OccArray vs) = return (OccInt (fromIntegral (length vs)))
    sizeOfValue _ = throwError (Just m, "size of non-constant expression " ++ show e ++ " used")
evalExpression (A.SizeVariable m v)
  = do t <- typeOfVariable v >>= underlyingType
       case t of
         A.Array (A.Dimension n:_) _ -> return (OccInt (fromIntegral n))
         _ -> throwError (Just m, "size of non-fixed-size variable " ++ show v ++ " used")
evalExpression (A.SubscriptedExpr _ sub e)
  = evalExpression e >>= evalSubscript sub
evalExpression (A.BytesInExpr m e)
  = do b <- typeOfExpression e >>= underlyingType >>= bytesInType
       case b of
         BIJust n -> return (OccInt (fromIntegral n))
         _ -> throwError (Just m, "BYTESIN non-constant-size expression " ++ show e ++ " used")
evalExpression (A.BytesInType m t)
  = do b <- underlyingType t >>= bytesInType
       case b of
         BIJust n -> return (OccInt (fromIntegral n))
         _ -> throwErrorC (Just m, formatCode "BYTESIN non-constant-size type % used" t)
evalExpression e = throwError (Just (findMeta e), "bad expression")
-- | Apply a type-generic unary integer function to an integer value of any
-- width, preserving its constructor. Non-integer values are rejected.
evalMonadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalMonadicOp f v
  = case v of
      OccByte x -> return (OccByte (f x))
      OccUInt16 x -> return (OccUInt16 (f x))
      OccUInt32 x -> return (OccUInt32 (f x))
      OccUInt64 x -> return (OccUInt64 (f x))
      OccInt8 x -> return (OccInt8 (f x))
      OccInt x -> return (OccInt (f x))
      OccInt16 x -> return (OccInt16 (f x))
      OccInt32 x -> return (OccInt32 (f x))
      OccInt64 x -> return (OccInt64 (f x))
      _ -> throwError (Nothing, "monadic operator not implemented for this type: " ++ show v)
-- | Apply a monadic (unary) occam operator to a constant value.
evalMonadic :: A.MonadicOp -> OccValue -> EvalM OccValue
evalMonadic op v
  = case (op, v) of
      -- This, oddly, is probably the most important rule here: "-4" isn't a
      -- literal in occam, it's an operator applied to a literal.
      (A.MonadicSubtr, _) -> evalMonadicOp negate v
      (A.MonadicMinus, _) -> evalMonadicOp negate v
      (A.MonadicBitNot, _) -> evalMonadicOp complement v
      (A.MonadicNot, OccBool b) -> return (OccBool (not b))
      _ -> throwError (Nothing, "bad monadic op: " ++ show op)
-- | Apply a type-generic binary integer function to two integer values of
-- the same width, producing a value of that width. Mismatched or
-- non-integer operands are rejected.
evalDyadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalDyadicOp f lhs rhs
  = case (lhs, rhs) of
      (OccByte x, OccByte y) -> return (OccByte (f x y))
      (OccUInt16 x, OccUInt16 y) -> return (OccUInt16 (f x y))
      (OccUInt32 x, OccUInt32 y) -> return (OccUInt32 (f x y))
      (OccUInt64 x, OccUInt64 y) -> return (OccUInt64 (f x y))
      (OccInt8 x, OccInt8 y) -> return (OccInt8 (f x y))
      (OccInt x, OccInt y) -> return (OccInt (f x y))
      (OccInt16 x, OccInt16 y) -> return (OccInt16 (f x y))
      (OccInt32 x, OccInt32 y) -> return (OccInt32 (f x y))
      (OccInt64 x, OccInt64 y) -> return (OccInt64 (f x y))
      _ -> throwError (Nothing, "dyadic operator not implemented for these types: " ++ show lhs ++ " and " ++ show rhs)
-- | Apply a comparison function to two constant values of the same type,
-- producing a BOOL.
--
-- Supported for every integer width and also for BOOL operands, so that
-- constant expressions such as @TRUE = FALSE@ can be folded. Mismatched or
-- unsupported operand types are reported as an evaluation error.
evalCompareOp :: (forall t. (Eq t, Ord t) => t -> t -> Bool) -> OccValue -> OccValue -> EvalM OccValue
evalCompareOp f (OccByte a) (OccByte b) = return $ OccBool (f a b)
evalCompareOp f (OccUInt16 a) (OccUInt16 b) = return $ OccBool (f a b)
evalCompareOp f (OccUInt32 a) (OccUInt32 b) = return $ OccBool (f a b)
evalCompareOp f (OccUInt64 a) (OccUInt64 b) = return $ OccBool (f a b)
evalCompareOp f (OccInt8 a) (OccInt8 b) = return $ OccBool (f a b)
evalCompareOp f (OccInt a) (OccInt b) = return $ OccBool (f a b)
evalCompareOp f (OccInt16 a) (OccInt16 b) = return $ OccBool (f a b)
evalCompareOp f (OccInt32 a) (OccInt32 b) = return $ OccBool (f a b)
evalCompareOp f (OccInt64 a) (OccInt64 b) = return $ OccBool (f a b)
-- Previously missing: without this, comparing two constant BOOLs could not be
-- folded and fell through to the "not implemented" error below.
evalCompareOp f (OccBool a) (OccBool b) = return $ OccBool (f a b)
evalCompareOp _ v0 v1 = throwError (Nothing, "comparison operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
-- | Apply a dyadic (binary) occam operator to two constant values.
--
-- Arithmetic and bitwise operators need two integers of the same width;
-- shifts take an INT shift count; AND/OR take BOOLs; comparisons produce a
-- BOOL. Anything else is reported as a "bad dyadic op" error.
evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
-- FIXME These should check for overflow (including MOSTNEG / -1, which
-- makes Haskell's quot raise an exception for the fixed-width types).
evalDyadic A.Add a b = evalDyadicOp (+) a b
evalDyadic A.Subtr a b = evalDyadicOp (-) a b
evalDyadic A.Mul a b = evalDyadicOp (*) a b
-- Use quot/rem as a matching pair: both truncate towards zero, as occam's
-- "/" and "\" do (and, presumably, the C code generated for non-constant
-- divisions). The previous 'div' rounded towards negative infinity, so
-- (a / b) * b + (a \ b) /= a for negative operands.
evalDyadic A.Div a b = evalDivisionOp quot a b
evalDyadic A.Rem a b = evalDivisionOp rem a b
-- ... end FIXME
evalDyadic A.Plus a b = evalDyadicOp (+) a b
evalDyadic A.Minus a b = evalDyadicOp (-) a b
evalDyadic A.Times a b = evalDyadicOp (*) a b
evalDyadic A.BitAnd a b = evalDyadicOp (.&.) a b
evalDyadic A.BitOr a b = evalDyadicOp (.|.) a b
evalDyadic A.BitXor a b = evalDyadicOp xor a b
evalDyadic A.LeftShift a (OccInt b)
  = evalMonadicOp (\v -> shiftL v (fromIntegral b)) a
evalDyadic A.RightShift a (OccInt b)
  = evalMonadicOp (\v -> shiftR v (fromIntegral b)) a
evalDyadic A.And (OccBool a) (OccBool b) = return $ OccBool (a && b)
evalDyadic A.Or (OccBool a) (OccBool b) = return $ OccBool (a || b)
evalDyadic A.Eq a b = evalCompareOp (==) a b
evalDyadic A.NotEq a b = evalCompareOp (/=) a b
evalDyadic A.Less a b = evalCompareOp (<) a b
evalDyadic A.More a b = evalCompareOp (>) a b
evalDyadic A.LessEq a b = evalCompareOp (<=) a b
evalDyadic A.MoreEq a b = evalCompareOp (>=) a b
evalDyadic A.After (OccInt a) (OccInt b) = return $ OccBool ((a - b) > 0)
evalDyadic op _ _ = throwError (Nothing, "bad dyadic op: " ++ show op)

-- | Like 'evalDyadicOp', but for division-like operators. A zero divisor is
-- reported as an evaluation error. Otherwise it would be left inside a lazy
-- result thunk that raises a Haskell 'DivideByZero' exception later, when the
-- folded value is rendered.
evalDivisionOp :: (forall t. (Num t, Integral t, Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalDivisionOp f a b
  | isZero b = throwError (Nothing, "division by zero")
  | otherwise = evalDyadicOp f a b
  where
    isZero (OccByte n) = n == 0
    isZero (OccUInt16 n) = n == 0
    isZero (OccUInt32 n) = n == 0
    isZero (OccUInt64 n) = n == 0
    isZero (OccInt8 n) = n == 0
    isZero (OccInt n) = n == 0
    isZero (OccInt16 n) = n == 0
    isZero (OccInt32 n) = n == 0
    isZero (OccInt64 n) = n == 0
    -- Non-integer divisors are left for evalDyadicOp to reject.
    isZero _ = False
--}}}
--{{{ rendering values
-- | Convert a value back into an expression, returning its type as well.
-- BOOLs become 'A.True' / 'A.False'; everything else becomes an 'A.Literal'
-- built by 'renderLiteral'.
renderValue :: Meta -> OccValue -> (A.Type, A.Expression)
renderValue m v
  = case v of
      OccBool True -> (A.Bool, A.True m)
      OccBool False -> (A.Bool, A.False m)
      _ -> let (t, lr) = renderLiteral m v
           in (t, A.Literal m t lr)
-- | Convert a non-BOOL value into a literal representation and its type.
-- Bytes become byte literals, other integers become decimal integer
-- literals, and arrays/records are rendered element by element.
renderLiteral :: Meta -> OccValue -> (A.Type, A.LiteralRepr)
renderLiteral m v
  = case v of
      OccByte c -> (A.Byte, A.ByteLiteral m (renderChar (chr (fromIntegral c))))
      OccUInt16 i -> intLit A.UInt16 i
      OccUInt32 i -> intLit A.UInt32 i
      OccUInt64 i -> intLit A.UInt64 i
      OccInt8 i -> intLit A.Int8 i
      OccInt i -> intLit A.Int i
      OccInt16 i -> intLit A.Int16 i
      OccInt32 i -> intLit A.Int32 i
      OccInt64 i -> intLit A.Int64 i
      OccArray vs ->
        -- The array's type is its length prepended to the first element's type.
        let (ts, aes) = unzip (map (renderLiteralArray m) vs)
        in (addDimensions [A.Dimension (length vs)] (head ts), A.ArrayLiteral m aes)
      OccRecord n vs -> (A.Record n, A.RecordLiteral m (map (snd . renderValue m) vs))
  where
    intLit :: Show a => A.Type -> a -> (A.Type, A.LiteralRepr)
    intLit t i = (t, A.IntLiteral m (show i))
-- | Render a character for use inside an occam byte literal, using occam's
-- @*@ escape sequences for quotes, @*@ itself, CR, LF and TAB. Any other
-- character outside printable ASCII (32..126) is written as a @*#hh@
-- hexadecimal escape.
renderChar :: Char -> String
renderChar '\'' = "*'"
renderChar '\"' = "*\""
renderChar '*' = "**"
renderChar '\r' = "*c"
renderChar '\n' = "*n"
renderChar '\t' = "*t"
renderChar c
  -- 127 is DEL, a control character; the old bound (o > 127) let it through
  -- unescaped.
  | o < 32 || o > 126 = printf "*#%02x" o
  | otherwise = [c]
  where o = ord c
-- | Convert a value into an array-literal element and its type. Nested arrays
-- become nested 'A.ArrayElemArray's; anything else is rendered with
-- 'renderValue' and wrapped as an 'A.ArrayElemExpr'.
renderLiteralArray :: Meta -> OccValue -> (A.Type, A.ArrayElem)
renderLiteralArray m v
  = case v of
      OccArray vs ->
        let (ts, aes) = unzip (map (renderLiteralArray m) vs)
            t = addDimensions [A.Dimension (length vs)] (head ts)
        in (t, A.ArrayElemArray aes)
      _ ->
        let (t, e) = renderValue m v
        in (t, A.ArrayElemExpr e)
--}}}