{-
Tock: a compiler for parallel languages
Copyright (C) 2007  University of Kent

This program is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation, either version 2 of the License, or (at your
option) any later version.

This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program.  If not, see <http://www.gnu.org/licenses/>.
-}

-- | Evaluate constant expressions.
module EvalConstants (constantFold, isConstantName) where

import Control.Monad.Error
import Control.Monad.State
import Data.Bits
import Data.Char
import Data.Int
import Data.Maybe
import Text.Printf

import qualified AST as A
import CompState
import Errors
import EvalLiterals
import Metadata
import ShowCode
import Types

-- | Simplify an expression by constant folding, and also return whether it's a
-- constant after that.
--
-- When folding fails, the original expression comes back together with the
-- evaluator's error report; when it succeeds, the folded expression comes back
-- with the placeholder report @(Nothing, "already folded")@.
constantFold :: CSM m => A.Expression -> m (A.Expression, Bool, ErrorReport)
constantFold e = do
  ps <- get
  let (e', msg) = either (\err -> (e, err))
                         (\val -> (val, (Nothing, "already folded")))
                         (simplifyExpression ps e)
  return (e', isConstant e', msg)

-- | Is a name defined as a constant expression? If so, return its definition.
--
-- Only @VAL@ abbreviations of an expression ('A.IsExpr' with 'A.ValAbbrev')
-- whose expression satisfies 'isConstant' count; anything else yields
-- 'Nothing'.
getConstantName :: (CSM m, Die m) => A.Name -> m (Maybe A.Expression)
getConstantName n = do
  st <- specTypeOfName n
  return $ case st of
    A.IsExpr _ A.ValAbbrev _ e | isConstant e -> Just e
    _ -> Nothing

-- | Is a name defined as a constant expression?
isConstantName :: (CSM m, Die m) => A.Name -> m Bool
isConstantName n = liftM isJust (getConstantName n)

-- | Attempt to simplify an expression as far as possible by precomputing
-- constant bits.  Returns the folded expression, or the first error the
-- evaluator reported.
simplifyExpression :: CompState -> A.Expression -> Either ErrorReport A.Expression
simplifyExpression ps e =
  case runEvaluator ps (evalExpression e) of
    Left err -> Left err
    Right val -> Right $ snd $ renderValue (findMeta e) val

--{{{  expression evaluator

-- | Evaluate a literal.  Empty array literals are rejected: when the value is
-- rendered back, the array's element type is taken from its first element.
evalLiteral :: A.Expression -> EvalM OccValue
evalLiteral (A.Literal m _ (A.ArrayLiteral _ [])) =
  throwError (Just m, "empty array")
evalLiteral (A.Literal _ _ (A.ArrayLiteral _ aes)) =
  liftM OccArray (mapM evalLiteralArray aes)
evalLiteral (A.Literal _ (A.Record n) (A.RecordLiteral _ es)) =
  liftM (OccRecord n) (mapM evalExpression es)
evalLiteral l = evalSimpleLiteral l

-- | Evaluate one element of an array literal.  A nested empty array is
-- rejected for the same reason as a top-level one (its element type could not
-- be recovered when rendering).
evalLiteralArray :: A.ArrayElem -> EvalM OccValue
evalLiteralArray (A.ArrayElemArray []) = throwError (Nothing, "empty array")
evalLiteralArray (A.ArrayElemArray aes) = liftM OccArray (mapM evalLiteralArray aes)
evalLiteralArray (A.ArrayElemExpr e) = evalExpression e

-- | Evaluate a variable.  A plain name must be bound to a constant
-- expression (see 'getConstantName'); subscripts are applied to the value,
-- and direction decorations are ignored.
evalVariable :: A.Variable -> EvalM OccValue
evalVariable (A.Variable m n) = do
  me <- getConstantName n
  case me of
    Just e -> evalExpression e
    Nothing -> throwError (Just m, "non-constant variable " ++ show n ++ " used")
evalVariable (A.SubscriptedVariable _ sub v) = evalVariable v >>= evalSubscript sub
evalVariable (A.DirectedVariable _ _ v) = evalVariable v

-- | Evaluate an expression that is used as an array index; it must be an INT.
evalIndex :: A.Expression -> EvalM Int
evalIndex e = do
  index <- evalExpression e
  case index of
    OccInt n -> return $ fromIntegral n
    _ -> throwError (Just $ findMeta e, "index has non-INT type")

-- | Apply a plain subscript to a constant array value, checking bounds.
evalSubscript :: A.Subscript -> OccValue -> EvalM OccValue
evalSubscript (A.Subscript m e) (OccArray vs) = do
  index <- evalIndex e
  if index >= 0 && index < length vs
    then return $ vs !! index
    else throwError (Just m, "subscript out of range")
evalSubscript s _ = throwError (Just $ findMeta s, "invalid subscript")

-- | Evaluate an expression to a constant value, or fail with an error report.
evalExpression :: A.Expression -> EvalM OccValue
evalExpression (A.Monadic _ op e) = do
  v <- evalExpression e
  evalMonadic op v
evalExpression (A.Dyadic _ op e1 e2) = do
  v1 <- evalExpression e1
  v2 <- evalExpression e2
  evalDyadic op v1 v2
evalExpression (A.MostPos _ A.Byte) = return $ OccByte maxBound
evalExpression (A.MostNeg _ A.Byte) = return $ OccByte minBound
evalExpression (A.MostPos _ A.UInt16) = return $ OccUInt16 maxBound
evalExpression (A.MostNeg _ A.UInt16) = return $ OccUInt16 minBound
evalExpression (A.MostPos _ A.UInt32) = return $ OccUInt32 maxBound
evalExpression (A.MostNeg _ A.UInt32) = return $ OccUInt32 minBound
evalExpression (A.MostPos _ A.UInt64) = return $ OccUInt64 maxBound
evalExpression (A.MostNeg _ A.UInt64) = return $ OccUInt64 minBound
evalExpression (A.MostPos _ A.Int8) = return $ OccInt8 maxBound
evalExpression (A.MostNeg _ A.Int8) = return $ OccInt8 minBound
evalExpression (A.MostPos _ A.Int) = return $ OccInt maxBound
evalExpression (A.MostNeg _ A.Int) = return $ OccInt minBound
evalExpression (A.MostPos _ A.Int16) = return $ OccInt16 maxBound
evalExpression (A.MostNeg _ A.Int16) = return $ OccInt16 minBound
evalExpression (A.MostPos _ A.Int32) = return $ OccInt32 maxBound
evalExpression (A.MostNeg _ A.Int32) = return $ OccInt32 minBound
evalExpression (A.MostPos _ A.Int64) = return $ OccInt64 maxBound
evalExpression (A.MostNeg _ A.Int64) = return $ OccInt64 minBound
evalExpression (A.SizeExpr m e) = do
  t <- typeOfExpression e >>= underlyingType m
  case t of
    A.Array (A.Dimension n:_) _ -> return $ OccInt (fromIntegral n)
    _ -> do
      -- The type doesn't fix the size, so try to evaluate the array itself.
      v <- evalExpression e
      case v of
        OccArray vs -> return $ OccInt (fromIntegral $ length vs)
        _ -> throwError (Just m, "size of non-constant expression " ++ show e ++ " used")
evalExpression (A.SizeVariable m v) = do
  t <- typeOfVariable v >>= underlyingType m
  case t of
    A.Array (A.Dimension n:_) _ -> return $ OccInt (fromIntegral n)
    _ -> throwError (Just m, "size of non-fixed-size variable " ++ show v ++ " used")
evalExpression e@(A.Literal _ _ _) = evalLiteral e
evalExpression (A.ExprVariable _ v) = evalVariable v
evalExpression (A.True _) = return $ OccBool True
evalExpression (A.False _) = return $ OccBool False
evalExpression (A.SubscriptedExpr _ sub e) = evalExpression e >>= evalSubscript sub
evalExpression (A.BytesInExpr m e) = do
  b <- typeOfExpression e >>= underlyingType m >>= bytesInType
  case b of
    BIJust n -> return $ OccInt (fromIntegral n)
    _ -> throwError (Just m, "BYTESIN non-constant-size expression " ++ show e ++ " used")
evalExpression (A.BytesInType m t) = do
  b <- underlyingType m t >>= bytesInType
  case b of
    BIJust n -> return $ OccInt (fromIntegral n)
    _ -> throwErrorC (Just m, formatCode "BYTESIN non-constant-size type % used" t)
evalExpression e = throwError (Just $ findMeta e, "bad expression")

-- | Apply a unary integer operation to a value of any integer type.
evalMonadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t) -> OccValue -> EvalM OccValue
evalMonadicOp f (OccByte a) = return $ OccByte (f a)
evalMonadicOp f (OccUInt16 a) = return $ OccUInt16 (f a)
evalMonadicOp f (OccUInt32 a) = return $ OccUInt32 (f a)
evalMonadicOp f (OccUInt64 a) = return $ OccUInt64 (f a)
evalMonadicOp f (OccInt8 a) = return $ OccInt8 (f a)
evalMonadicOp f (OccInt a) = return $ OccInt (f a)
evalMonadicOp f (OccInt16 a) = return $ OccInt16 (f a)
evalMonadicOp f (OccInt32 a) = return $ OccInt32 (f a)
evalMonadicOp f (OccInt64 a) = return $ OccInt64 (f a)
evalMonadicOp _ v = throwError (Nothing, "monadic operator not implemented for this type: " ++ show v)

evalMonadic :: A.MonadicOp -> OccValue -> EvalM OccValue
-- This, oddly, is probably the most important rule here: "-4" isn't a literal
-- in occam, it's an operator applied to a literal.
evalMonadic A.MonadicSubtr a = evalMonadicOp negate a
evalMonadic A.MonadicMinus a = evalMonadicOp negate a
evalMonadic A.MonadicBitNot a = evalMonadicOp complement a
evalMonadic A.MonadicNot (OccBool b) = return $ OccBool (not b)
evalMonadic op _ = throwError (Nothing, "bad monadic op: " ++ show op)

-- | Apply a binary integer operation that cannot fail to two values of the
-- same integer type.
evalDyadicOp :: (forall t. (Num t, Integral t, Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalDyadicOp f = evalDyadicOpChecked (\a b -> Right (f a b))

-- | Apply a binary integer operation that may fail to two values of the same
-- integer type.  A 'Left' from the operation becomes an evaluation error.
evalDyadicOpChecked :: (forall t. (Num t, Integral t, Bits t) => t -> t -> Either String t) -> OccValue -> OccValue -> EvalM OccValue
evalDyadicOpChecked f (OccByte a) (OccByte b) = liftChecked OccByte (f a b)
evalDyadicOpChecked f (OccUInt16 a) (OccUInt16 b) = liftChecked OccUInt16 (f a b)
evalDyadicOpChecked f (OccUInt32 a) (OccUInt32 b) = liftChecked OccUInt32 (f a b)
evalDyadicOpChecked f (OccUInt64 a) (OccUInt64 b) = liftChecked OccUInt64 (f a b)
evalDyadicOpChecked f (OccInt8 a) (OccInt8 b) = liftChecked OccInt8 (f a b)
evalDyadicOpChecked f (OccInt a) (OccInt b) = liftChecked OccInt (f a b)
evalDyadicOpChecked f (OccInt16 a) (OccInt16 b) = liftChecked OccInt16 (f a b)
evalDyadicOpChecked f (OccInt32 a) (OccInt32 b) = liftChecked OccInt32 (f a b)
evalDyadicOpChecked f (OccInt64 a) (OccInt64 b) = liftChecked OccInt64 (f a b)
evalDyadicOpChecked _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)

-- | Wrap the result of a checked operation with the given constructor,
-- turning a 'Left' into an evaluation error.  Only the 'Either' constructor is
-- forced here; the wrapped value itself stays lazy.
liftChecked :: (t -> OccValue) -> Either String t -> EvalM OccValue
liftChecked _ (Left err) = throwError (Nothing, err)
liftChecked wrap (Right v) = return (wrap v)

-- | Integer division truncating towards zero, so that it agrees with
-- 'checkedRem'.  A zero divisor and the one overflowing case (the most
-- negative value divided by -1) are reported as errors rather than raising a
-- Haskell arithmetic exception when the folded value is later forced.
checkedQuot :: Integral t => t -> t -> Either String t
checkedQuot x y
  | y == 0 = Left "division by zero"
  | toInteger result /= exact = Left "overflow in division"
  | otherwise = Right result
  where
    -- Computed in Integer so it cannot overflow; 'fromInteger' then wraps,
    -- and the comparison above detects when wrapping actually happened.
    exact = toInteger x `quot` toInteger y
    result = fromInteger exact `asTypeOf` x

-- | Integer remainder, taking the sign of the dividend (as 'rem' does).  It is
-- computed via 'Integer' so that no non-zero divisor can raise an arithmetic
-- exception; the result is always smaller in magnitude than the divisor, so
-- it fits back into the operand type.
checkedRem :: Integral t => t -> t -> Either String t
checkedRem x y
  | y == 0 = Left "division by zero"
  | otherwise = Right (fromInteger (toInteger x `rem` toInteger y))

-- | Apply a comparison to two values of the same type (BOOL or any integer
-- type), producing a BOOL.
evalCompareOp :: (forall t. (Eq t, Ord t) => t -> t -> Bool) -> OccValue -> OccValue -> EvalM OccValue
evalCompareOp f (OccBool a) (OccBool b) = return $ OccBool (f a b)
evalCompareOp f (OccByte a) (OccByte b) = return $ OccBool (f a b)
evalCompareOp f (OccUInt16 a) (OccUInt16 b) = return $ OccBool (f a b)
evalCompareOp f (OccUInt32 a) (OccUInt32 b) = return $ OccBool (f a b)
evalCompareOp f (OccUInt64 a) (OccUInt64 b) = return $ OccBool (f a b)
evalCompareOp f (OccInt8 a) (OccInt8 b) = return $ OccBool (f a b)
evalCompareOp f (OccInt a) (OccInt b) = return $ OccBool (f a b)
evalCompareOp f (OccInt16 a) (OccInt16 b) = return $ OccBool (f a b)
evalCompareOp f (OccInt32 a) (OccInt32 b) = return $ OccBool (f a b)
evalCompareOp f (OccInt64 a) (OccInt64 b) = return $ OccBool (f a b)
evalCompareOp _ v0 v1 = throwError (Nothing, "comparison operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)

evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
-- FIXME These should check for overflow.
evalDyadic A.Add a b = evalDyadicOp (+) a b
evalDyadic A.Subtr a b = evalDyadicOp (-) a b
evalDyadic A.Mul a b = evalDyadicOp (*) a b
-- ... end FIXME
-- Division pairs 'quot' with 'rem': 'div' rounds towards negative infinity,
-- so combined with 'rem' it breaks (a / b) * b + (a \ b) = a for negative
-- operands (e.g. -7 / 2 would give -4 while -7 \ 2 gives -1).
evalDyadic A.Div a b = evalDyadicOpChecked checkedQuot a b
evalDyadic A.Rem a b = evalDyadicOpChecked checkedRem a b
evalDyadic A.Plus a b = evalDyadicOp (+) a b
evalDyadic A.Minus a b = evalDyadicOp (-) a b
evalDyadic A.Times a b = evalDyadicOp (*) a b
evalDyadic A.BitAnd a b = evalDyadicOp (.&.) a b
evalDyadic A.BitOr a b = evalDyadicOp (.|.) a b
evalDyadic A.BitXor a b = evalDyadicOp xor a b
-- Haskell's shifts raise an exception for a negative count, so reject it here.
evalDyadic A.LeftShift a (OccInt b)
  | b < 0 = throwError (Nothing, "negative shift count")
  | otherwise = evalMonadicOp (\v -> shiftL v (fromIntegral b)) a
-- NOTE(review): 'shiftR' is an arithmetic shift on signed types; confirm this
-- matches occam's >> semantics for negative left operands.
evalDyadic A.RightShift a (OccInt b)
  | b < 0 = throwError (Nothing, "negative shift count")
  | otherwise = evalMonadicOp (\v -> shiftR v (fromIntegral b)) a
evalDyadic A.And (OccBool a) (OccBool b) = return $ OccBool (a && b)
evalDyadic A.Or (OccBool a) (OccBool b) = return $ OccBool (a || b)
evalDyadic A.Eq a b = evalCompareOp (==) a b
evalDyadic A.NotEq a b = evalCompareOp (/=) a b
evalDyadic A.Less a b = evalCompareOp (<) a b
evalDyadic A.More a b = evalCompareOp (>) a b
evalDyadic A.LessEq a b = evalCompareOp (<=) a b
evalDyadic A.MoreEq a b = evalCompareOp (>=) a b
evalDyadic A.After (OccInt a) (OccInt b) = return $ OccBool ((a - b) > 0)
evalDyadic op _ _ = throwError (Nothing, "bad dyadic op: " ++ show op)
--}}}

--{{{  rendering values
-- | Convert a value back into a literal.
renderValue :: Meta -> OccValue -> (A.Type, A.Expression)
renderValue m (OccBool True) = (A.Bool, A.True m)
renderValue m (OccBool False) = (A.Bool, A.False m)
renderValue m v = (t, A.Literal m t lr)
  where (t, lr) = renderLiteral m v

-- | Convert a non-BOOL value into a literal representation and its type.
renderLiteral :: Meta -> OccValue -> (A.Type, A.LiteralRepr)
renderLiteral m (OccByte c) = (A.Byte, A.ByteLiteral m $ renderChar (chr $ fromIntegral c))
renderLiteral m (OccUInt16 i) = (A.UInt16, A.IntLiteral m $ show i)
renderLiteral m (OccUInt32 i) = (A.UInt32, A.IntLiteral m $ show i)
renderLiteral m (OccUInt64 i) = (A.UInt64, A.IntLiteral m $ show i)
renderLiteral m (OccInt8 i) = (A.Int8, A.IntLiteral m $ show i)
renderLiteral m (OccInt i) = (A.Int, A.IntLiteral m $ show i)
renderLiteral m (OccInt16 i) = (A.Int16, A.IntLiteral m $ show i)
renderLiteral m (OccInt32 i) = (A.Int32, A.IntLiteral m $ show i)
renderLiteral m (OccInt64 i) = (A.Int64, A.IntLiteral m $ show i)
renderLiteral m (OccArray vs) = (t, A.ArrayLiteral m aes)
  where (t, aes) = renderArrayElems m vs
renderLiteral m (OccRecord n vs) = (A.Record n, A.RecordLiteral m (map (snd . renderValue m) vs))

-- | Render a character for use inside an occam character literal, using
-- occam's @*@ escapes.  Control characters, DEL (127) and anything outside
-- ASCII are written as hex escapes.
renderChar :: Char -> String
renderChar '\'' = "*'"
renderChar '\"' = "*\""
renderChar '*' = "**"
renderChar '\r' = "*c"
renderChar '\n' = "*n"
renderChar '\t' = "*t"
renderChar c
  | o < 32 || o >= 127 = printf "*#%02x" o
  | otherwise = [c]
  where o = ord c

-- | Render one element of an array literal; nested arrays become nested
-- array elements.
renderLiteralArray :: Meta -> OccValue -> (A.Type, A.ArrayElem)
renderLiteralArray m (OccArray vs) = (t, A.ArrayElemArray aes)
  where (t, aes) = renderArrayElems m vs
renderLiteralArray m v = (t, A.ArrayElemExpr e)
  where (t, e) = renderValue m v

-- | Render the elements of an array value, returning the array's type (taken
-- from the first element) and the rendered elements.  Empty arrays are
-- rejected by 'evalLiteral' and 'evalLiteralArray', so the element list is
-- expected to be non-empty.
renderArrayElems :: Meta -> [OccValue] -> (A.Type, [A.ArrayElem])
renderArrayElems m vs = (addDimensions [A.Dimension $ length vs] elemType, aes)
  where
    (ts, aes) = unzip $ map (renderLiteralArray m) vs
    elemType = case ts of
                 (t:_) -> t
                 [] -> error "EvalConstants.renderArrayElems: empty array value"
--}}}