diff --git a/common/EvalConstants.hs b/common/EvalConstants.hs index d4f760b..a171ade 100644 --- a/common/EvalConstants.hs +++ b/common/EvalConstants.hs @@ -20,6 +20,7 @@ with this program. If not, see . module EvalConstants ( constantFold , evalIntExpression + , getConstantName , isConstantName ) where @@ -29,6 +30,7 @@ import Data.Bits import Data.Char import Data.Int import Data.Maybe +import Foreign import Text.Printf import qualified AST as A @@ -75,6 +77,11 @@ getConstantName n if isConst then return $ Just e' else return Nothing + A.RetypesExpr m A.ValAbbrev t e + -> do ps <- getCompState + case runEvaluator ps (evalRetypes m t =<< evalSimpleExpression e) of + Left _ -> return Nothing + Right ov -> renderValue m t ov >>* Just _ -> return Nothing -- | Is a name defined as a constant expression? @@ -126,6 +133,81 @@ evalSubscript (A.Subscript m _ e) (OccArray vs) else throwError (Just m, "subscript out of range") evalSubscript s _ = throwError (Just $ findMeta s, "invalid subscript") +conv :: Real a => a -> Rational +conv = toRational + +occToRational :: OccValue -> Rational +occToRational (OccByte x) = conv x +occToRational (OccUInt16 x) = conv x +occToRational (OccUInt32 x) = conv x +occToRational (OccUInt64 x) = conv x +occToRational (OccInt8 x) = conv x +occToRational (OccInt16 x) = conv x +occToRational (OccInt x) = conv x +occToRational (OccInt32 x) = conv x +occToRational (OccInt64 x) = conv x +occToRational (OccReal32 x) = conv x +occToRational (OccReal64 x) = conv x + + +evalConversion :: A.ConversionMode -> A.Type -> OccValue -> EvalM OccValue +evalConversion cm t ov + = case t of + A.Byte -> into OccByte + A.UInt16 -> into OccUInt16 + A.UInt32 -> into OccUInt32 + A.UInt64 -> into OccUInt64 + A.Int8 -> into OccInt8 + A.Int16 -> into OccInt16 + A.Int -> into OccInt + A.Int32 -> into OccInt32 + A.Int64 -> into OccInt64 + A.Real32 -> intoF OccReal32 + A.Real64 -> intoF OccReal64 + _ -> throwError (Nothing, "Cannot convert") + where + into cons = return $ cons $ fromInteger iv + intoF cons = return $ cons $ fromRational cv + + cv = occToRational ov + iv :: Integer + iv = case cm of + A.DefaultConversion -> round cv + A.Round -> round cv + A.Trunc -> truncate cv + +evalRetypes :: Die m => Meta -> A.Type -> OccValue -> m OccValue +evalRetypes m t ov + = case t of + A.Byte -> into OccByte + A.UInt16 -> into OccUInt16 + A.UInt32 -> into OccUInt32 + A.UInt64 -> into OccUInt64 + A.Int8 -> into OccInt8 + A.Int16 -> into OccInt16 + A.Int -> into OccInt + A.Int32 -> into OccInt32 + A.Int64 -> into OccInt64 + A.Real32 -> into OccReal32 + A.Real64 -> into OccReal64 + _ -> dieP m "Cannot retype to that type" + where + -- unsafePerformIO is nasty -- Adam made me do it! + into cons = return $ unsafePerformIO $ getBytesFor ov (liftM cons . peek) + + getBytesFor :: OccValue -> (Ptr b -> IO c) -> IO c + getBytesFor (OccByte x) f = with x (f . castPtr) + getBytesFor (OccUInt16 x) f = with x (f . castPtr) + getBytesFor (OccUInt32 x) f = with x (f . castPtr) + getBytesFor (OccUInt64 x) f = with x (f . castPtr) + getBytesFor (OccInt8 x) f = with x (f . castPtr) + getBytesFor (OccInt16 x) f = with x (f . castPtr) + getBytesFor (OccInt x) f = with x (f . castPtr) + getBytesFor (OccInt32 x) f = with x (f . castPtr) + getBytesFor (OccInt64 x) f = with x (f . castPtr) + getBytesFor (OccReal32 x) f = with x (f . castPtr) + getBytesFor (OccReal64 x) f = with x (f . castPtr) + evalExpression :: A.Expression -> EvalM OccValue evalExpression (A.Monadic _ op e) = do v <- evalExpression e @@ -152,6 +234,9 @@ evalExpression (A.MostPos _ A.Int32) = return $ OccInt32 maxBound evalExpression (A.MostNeg _ A.Int32) = return $ OccInt32 minBound evalExpression (A.MostPos _ A.Int64) = return $ OccInt64 maxBound evalExpression (A.MostNeg _ A.Int64) = return $ OccInt64 minBound +evalExpression (A.Conversion _ cm t e) + = do e' <- evalExpression e + evalConversion cm t e' evalExpression (A.SizeExpr m e) = do t <- astTypeOf e >>= underlyingType m case t of @@ -204,17 +289,45 @@ evalMonadic A.MonadicBitNot a = evalMonadicOp complement a evalMonadic A.MonadicNot (OccBool b) = return $ OccBool (not b) evalMonadic op _ = throwError (Nothing, "bad monadic op: " ++ show op) -evalDyadicOp :: (forall t. (Num t, Integral t, Bounded t, Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue -evalDyadicOp f (OccByte a) (OccByte b) = return $ OccByte (f a b) -evalDyadicOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b) -evalDyadicOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b) -evalDyadicOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b) -evalDyadicOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b) -evalDyadicOp f (OccInt a) (OccInt b) = return $ OccInt (f a b) -evalDyadicOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b) -evalDyadicOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b) -evalDyadicOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b) -evalDyadicOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1) +evalArithOp :: (forall t. (Num t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue +evalArithOp f (OccByte a) (OccByte b) = return $ OccByte (f a b) +evalArithOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b) +evalArithOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b) +evalArithOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b) +evalArithOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b) +evalArithOp f (OccInt a) (OccInt b) = return $ OccInt (f a b) +evalArithOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b) +evalArithOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b) +evalArithOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b) +evalArithOp f (OccReal32 a) (OccReal32 b) = return $ OccReal32 (f a b) +evalArithOp f (OccReal64 a) (OccReal64 b) = return $ OccReal64 (f a b) +evalArithOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1) + +evalArithIntOp :: (forall t. (Num t, Integral t, Bounded t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue +evalArithIntOp f (OccByte a) (OccByte b) = return $ OccByte (f a b) +evalArithIntOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b) +evalArithIntOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b) +evalArithIntOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b) +evalArithIntOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b) +evalArithIntOp f (OccInt a) (OccInt b) = return $ OccInt (f a b) +evalArithIntOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b) +evalArithIntOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b) +evalArithIntOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b) +evalArithIntOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1) + + +evalLogicOp :: (forall t. (Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue +evalLogicOp f (OccByte a) (OccByte b) = return $ OccByte (f a b) +evalLogicOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b) +evalLogicOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b) +evalLogicOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b) +evalLogicOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b) +evalLogicOp f (OccInt a) (OccInt b) = return $ OccInt (f a b) +evalLogicOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b) +evalLogicOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b) +evalLogicOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b) +evalLogicOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1) + evalCompareOp :: (forall t. (Eq t, Ord t) => t -> t -> Bool) -> OccValue -> OccValue -> EvalM OccValue evalCompareOp f (OccByte a) (OccByte b) = return $ OccBool (f a b) @@ -248,18 +361,20 @@ safeRem a b = rem a b evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue -- FIXME These should check for overflow. -evalDyadic A.Add a b = evalDyadicOp (+) a b -evalDyadic A.Subtr a b = evalDyadicOp (-) a b -evalDyadic A.Mul a b = evalDyadicOp (*) a b -evalDyadic A.Div a b = evalDyadicOp safeDiv a b -evalDyadic A.Rem a b = evalDyadicOp safeRem a b +evalDyadic A.Add a b = evalArithOp (+) a b +evalDyadic A.Subtr a b = evalArithOp (-) a b +evalDyadic A.Mul a b = evalArithOp (*) a b +evalDyadic A.Div (OccReal32 a) (OccReal32 b) = return $ OccReal32 $ a / b +evalDyadic A.Div (OccReal64 a) (OccReal64 b) = return $ OccReal64 $ a / b +evalDyadic A.Div a b = evalArithIntOp safeDiv a b +evalDyadic A.Rem a b = evalArithIntOp safeRem a b -- ... end FIXME -evalDyadic A.Plus a b = evalDyadicOp (+) a b -evalDyadic A.Minus a b = evalDyadicOp (-) a b -evalDyadic A.Times a b = evalDyadicOp (*) a b -evalDyadic A.BitAnd a b = evalDyadicOp (.&.) a b -evalDyadic A.BitOr a b = evalDyadicOp (.|.) a b -evalDyadic A.BitXor a b = evalDyadicOp xor a b +evalDyadic A.Plus a b = evalArithOp (+) a b +evalDyadic A.Minus a b = evalArithOp (-) a b +evalDyadic A.Times a b = evalArithOp (*) a b +evalDyadic A.BitAnd a b = evalLogicOp (.&.) a b +evalDyadic A.BitOr a b = evalLogicOp (.|.) a b +evalDyadic A.BitXor a b = evalLogicOp xor a b evalDyadic A.LeftShift a (OccInt b) = evalMonadicOp (\v -> shiftL v (fromIntegral b)) a evalDyadic A.RightShift a (OccInt b) @@ -303,6 +418,8 @@ renderLiteral m t v OccInt64 i -> renderInt i OccArray vs -> renderArray vs OccRecord _ vs -> renderRecord vs + OccReal32 n -> renderReal n + OccReal64 n -> renderReal n where renderChar :: Char -> String renderChar '\'' = "*'" @@ -319,6 +436,9 @@ renderLiteral m t v renderInt :: Show s => s -> m (A.Type, A.LiteralRepr) renderInt i = return (t, A.IntLiteral m $ show i) + renderReal :: Show s => s -> m (A.Type, A.LiteralRepr) + renderReal i = return (t, A.RealLiteral m $ show i) + renderArray :: [OccValue] -> m (A.Type, A.LiteralRepr) renderArray vs = do (t', aes) <- renderArrayElems t vs diff --git a/common/EvalLiterals.hs b/common/EvalLiterals.hs index b8d566c..71dd761 100644 --- a/common/EvalLiterals.hs +++ b/common/EvalLiterals.hs @@ -52,6 +52,8 @@ data OccValue = | OccInt CIntReplacement | OccInt32 Int32 | OccInt64 Int64 + | OccReal32 Float + | OccReal64 Double | OccArray [OccValue] | OccRecord A.Name [OccValue] deriving (Show, Eq, Typeable, Data) @@ -117,6 +119,8 @@ evalSimpleLiteral (A.Literal _ t lr) A.Int -> into OccInt A.Int32 -> into OccInt32 A.Int64 -> into OccInt64 + A.Real32 -> intoF OccReal32 + A.Real64 -> intoF OccReal64 _ -> bad where defaults :: EvalM OccValue @@ -125,6 +129,7 @@ evalSimpleLiteral (A.Literal _ t lr) A.ByteLiteral _ s -> evalByteLiteral m OccByte s A.IntLiteral _ s -> fromRead m OccInt (readSigned readDec) s A.HexLiteral _ s -> fromRead m OccInt readHex s + A.RealLiteral _ s -> fromRead m OccReal32 readFloat s _ -> bad into :: (Num t, Real t) => (t -> OccValue) -> EvalM OccValue @@ -135,6 +140,16 @@ evalSimpleLiteral (A.Literal _ t lr) A.HexLiteral _ s -> fromRead m cons readHex s _ -> bad + intoF :: RealFrac t => (t -> OccValue) -> EvalM OccValue + intoF cons + = case lr of + A.ByteLiteral _ s -> evalByteLiteral m cons s + A.IntLiteral _ s -> fromRead m cons (readSigned readDec) s + A.HexLiteral _ s -> fromRead m cons readHex s + A.RealLiteral _ s -> fromRead m cons readFloat s + _ -> bad + + bad :: EvalM OccValue bad = throwError (Just m, "Cannot evaluate literal")