Added support for constant-folding reals, and for constant-folding RETYPES

This commit is contained in:
Neil Brown 2009-03-26 18:33:14 +00:00
parent 96984250b7
commit d35825ec50
2 changed files with 157 additions and 22 deletions

View File

@ -20,6 +20,7 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
module EvalConstants
( constantFold
, evalIntExpression
, getConstantName
, isConstantName
) where
@ -29,6 +30,7 @@ import Data.Bits
import Data.Char
import Data.Int
import Data.Maybe
import Foreign
import Text.Printf
import qualified AST as A
@ -75,6 +77,11 @@ getConstantName n
if isConst
then return $ Just e'
else return Nothing
A.RetypesExpr m A.ValAbbrev t e
-> do ps <- getCompState
case runEvaluator ps (evalRetypes m t =<< evalSimpleExpression e) of
Left _ -> return Nothing
Right ov -> renderValue m t ov >>* Just
_ -> return Nothing
-- | Is a name defined as a constant expression?
@ -126,6 +133,81 @@ evalSubscript (A.Subscript m _ e) (OccArray vs)
else throwError (Just m, "subscript out of range")
evalSubscript s _ = throwError (Just $ findMeta s, "invalid subscript")
conv :: Real a => a -> Rational
conv = toRational
occToRational :: OccValue -> Rational
occToRational (OccByte x) = conv x
occToRational (OccUInt16 x) = conv x
occToRational (OccUInt32 x) = conv x
occToRational (OccUInt64 x) = conv x
occToRational (OccInt8 x) = conv x
occToRational (OccInt16 x) = conv x
occToRational (OccInt x) = conv x
occToRational (OccInt32 x) = conv x
occToRational (OccInt64 x) = conv x
occToRational (OccReal32 x) = conv x
occToRational (OccReal64 x) = conv x
evalConversion :: A.ConversionMode -> A.Type -> OccValue -> EvalM OccValue
evalConversion cm t ov
= case t of
A.Byte -> into OccByte
A.UInt16 -> into OccUInt16
A.UInt32 -> into OccUInt32
A.UInt64 -> into OccUInt64
A.Int8 -> into OccInt8
A.Int16 -> into OccInt16
A.Int -> into OccInt
A.Int32 -> into OccInt32
A.Int64 -> into OccInt64
A.Real32 -> intoF OccReal32
A.Real64 -> intoF OccReal64
_ -> throwError (Nothing, "Cannot convert")
where
into cons = return $ cons $ fromInteger iv
intoF cons = return $ cons $ fromRational cv
cv = occToRational ov
iv :: Integer
iv = case cm of
A.DefaultConversion -> round cv
A.Round -> round cv
A.Trunc -> truncate cv
evalRetypes :: Die m => Meta -> A.Type -> OccValue -> m OccValue
evalRetypes m t ov
= case t of
A.Byte -> into OccByte
A.UInt16 -> into OccUInt16
A.UInt32 -> into OccUInt32
A.UInt64 -> into OccUInt64
A.Int8 -> into OccInt8
A.Int16 -> into OccInt16
A.Int -> into OccInt
A.Int32 -> into OccInt32
A.Int64 -> into OccInt64
A.Real32 -> into OccReal32
A.Real64 -> into OccReal64
_ -> dieP m "Cannot retype to that type"
where
-- unsafePerformIO is nasty -- Adam made me do it!
into cons = return $ unsafePerformIO $ getBytesFor ov (liftM cons . peek)
getBytesFor :: OccValue -> (Ptr b -> IO c) -> IO c
getBytesFor (OccByte x) f = with x (f . castPtr)
getBytesFor (OccUInt16 x) f = with x (f . castPtr)
getBytesFor (OccUInt32 x) f = with x (f . castPtr)
getBytesFor (OccUInt64 x) f = with x (f . castPtr)
getBytesFor (OccInt8 x) f = with x (f . castPtr)
getBytesFor (OccInt16 x) f = with x (f . castPtr)
getBytesFor (OccInt x) f = with x (f . castPtr)
getBytesFor (OccInt32 x) f = with x (f . castPtr)
getBytesFor (OccInt64 x) f = with x (f . castPtr)
getBytesFor (OccReal32 x) f = with x (f . castPtr)
getBytesFor (OccReal64 x) f = with x (f . castPtr)
evalExpression :: A.Expression -> EvalM OccValue
evalExpression (A.Monadic _ op e)
= do v <- evalExpression e
@ -152,6 +234,9 @@ evalExpression (A.MostPos _ A.Int32) = return $ OccInt32 maxBound
evalExpression (A.MostNeg _ A.Int32) = return $ OccInt32 minBound
evalExpression (A.MostPos _ A.Int64) = return $ OccInt64 maxBound
evalExpression (A.MostNeg _ A.Int64) = return $ OccInt64 minBound
evalExpression (A.Conversion _ cm t e)
= do e' <- evalExpression e
evalConversion cm t e'
evalExpression (A.SizeExpr m e)
= do t <- astTypeOf e >>= underlyingType m
case t of
@ -204,17 +289,45 @@ evalMonadic A.MonadicBitNot a = evalMonadicOp complement a
evalMonadic A.MonadicNot (OccBool b) = return $ OccBool (not b)
evalMonadic op _ = throwError (Nothing, "bad monadic op: " ++ show op)
evalDyadicOp :: (forall t. (Num t, Integral t, Bounded t, Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalDyadicOp f (OccByte a) (OccByte b) = return $ OccByte (f a b)
evalDyadicOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b)
evalDyadicOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b)
evalDyadicOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b)
evalDyadicOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b)
evalDyadicOp f (OccInt a) (OccInt b) = return $ OccInt (f a b)
evalDyadicOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b)
evalDyadicOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b)
evalDyadicOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b)
evalDyadicOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
evalArithOp :: (forall t. (Num t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalArithOp f (OccByte a) (OccByte b) = return $ OccByte (f a b)
evalArithOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b)
evalArithOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b)
evalArithOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b)
evalArithOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b)
evalArithOp f (OccInt a) (OccInt b) = return $ OccInt (f a b)
evalArithOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b)
evalArithOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b)
evalArithOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b)
evalArithOp f (OccReal32 a) (OccReal32 b) = return $ OccReal32 (f a b)
evalArithOp f (OccReal64 a) (OccReal64 b) = return $ OccReal64 (f a b)
evalArithOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
evalArithIntOp :: (forall t. (Num t, Integral t, Bounded t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalArithIntOp f (OccByte a) (OccByte b) = return $ OccByte (f a b)
evalArithIntOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b)
evalArithIntOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b)
evalArithIntOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b)
evalArithIntOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b)
evalArithIntOp f (OccInt a) (OccInt b) = return $ OccInt (f a b)
evalArithIntOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b)
evalArithIntOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b)
evalArithIntOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b)
evalArithIntOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
evalLogicOp :: (forall t. (Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
evalLogicOp f (OccByte a) (OccByte b) = return $ OccByte (f a b)
evalLogicOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b)
evalLogicOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b)
evalLogicOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b)
evalLogicOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b)
evalLogicOp f (OccInt a) (OccInt b) = return $ OccInt (f a b)
evalLogicOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b)
evalLogicOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b)
evalLogicOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b)
evalLogicOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
evalCompareOp :: (forall t. (Eq t, Ord t) => t -> t -> Bool) -> OccValue -> OccValue -> EvalM OccValue
evalCompareOp f (OccByte a) (OccByte b) = return $ OccBool (f a b)
@ -248,18 +361,20 @@ safeRem a b = rem a b
evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
-- FIXME These should check for overflow.
evalDyadic A.Add a b = evalDyadicOp (+) a b
evalDyadic A.Subtr a b = evalDyadicOp (-) a b
evalDyadic A.Mul a b = evalDyadicOp (*) a b
evalDyadic A.Div a b = evalDyadicOp safeDiv a b
evalDyadic A.Rem a b = evalDyadicOp safeRem a b
evalDyadic A.Add a b = evalArithOp (+) a b
evalDyadic A.Subtr a b = evalArithOp (-) a b
evalDyadic A.Mul a b = evalArithOp (*) a b
evalDyadic A.Div (OccReal32 a) (OccReal32 b) = return $ OccReal32 $ a / b
evalDyadic A.Div (OccReal64 a) (OccReal64 b) = return $ OccReal64 $ a / b
evalDyadic A.Div a b = evalArithIntOp safeDiv a b
evalDyadic A.Rem a b = evalArithIntOp safeRem a b
-- ... end FIXME
evalDyadic A.Plus a b = evalDyadicOp (+) a b
evalDyadic A.Minus a b = evalDyadicOp (-) a b
evalDyadic A.Times a b = evalDyadicOp (*) a b
evalDyadic A.BitAnd a b = evalDyadicOp (.&.) a b
evalDyadic A.BitOr a b = evalDyadicOp (.|.) a b
evalDyadic A.BitXor a b = evalDyadicOp xor a b
evalDyadic A.Plus a b = evalArithOp (+) a b
evalDyadic A.Minus a b = evalArithOp (-) a b
evalDyadic A.Times a b = evalArithOp (*) a b
evalDyadic A.BitAnd a b = evalLogicOp (.&.) a b
evalDyadic A.BitOr a b = evalLogicOp (.|.) a b
evalDyadic A.BitXor a b = evalLogicOp xor a b
evalDyadic A.LeftShift a (OccInt b)
= evalMonadicOp (\v -> shiftL v (fromIntegral b)) a
evalDyadic A.RightShift a (OccInt b)
@ -303,6 +418,8 @@ renderLiteral m t v
OccInt64 i -> renderInt i
OccArray vs -> renderArray vs
OccRecord _ vs -> renderRecord vs
OccReal32 n -> renderReal n
OccReal64 n -> renderReal n
where
renderChar :: Char -> String
renderChar '\'' = "*'"
@ -319,6 +436,9 @@ renderLiteral m t v
renderInt :: Show s => s -> m (A.Type, A.LiteralRepr)
renderInt i = return (t, A.IntLiteral m $ show i)
renderReal :: Show s => s -> m (A.Type, A.LiteralRepr)
renderReal i = return (t, A.RealLiteral m $ show i)
renderArray :: [OccValue] -> m (A.Type, A.LiteralRepr)
renderArray vs
= do (t', aes) <- renderArrayElems t vs

View File

@ -52,6 +52,8 @@ data OccValue =
| OccInt CIntReplacement
| OccInt32 Int32
| OccInt64 Int64
| OccReal32 Float
| OccReal64 Double
| OccArray [OccValue]
| OccRecord A.Name [OccValue]
deriving (Show, Eq, Typeable, Data)
@ -117,6 +119,8 @@ evalSimpleLiteral (A.Literal _ t lr)
A.Int -> into OccInt
A.Int32 -> into OccInt32
A.Int64 -> into OccInt64
A.Real32 -> intoF OccReal32
A.Real64 -> intoF OccReal64
_ -> bad
where
defaults :: EvalM OccValue
@ -125,6 +129,7 @@ evalSimpleLiteral (A.Literal _ t lr)
A.ByteLiteral _ s -> evalByteLiteral m OccByte s
A.IntLiteral _ s -> fromRead m OccInt (readSigned readDec) s
A.HexLiteral _ s -> fromRead m OccInt readHex s
A.RealLiteral _ s -> fromRead m OccReal32 readFloat s
_ -> bad
into :: (Num t, Real t) => (t -> OccValue) -> EvalM OccValue
@ -135,6 +140,16 @@ evalSimpleLiteral (A.Literal _ t lr)
A.HexLiteral _ s -> fromRead m cons readHex s
_ -> bad
intoF :: RealFrac t => (t -> OccValue) -> EvalM OccValue
intoF cons
= case lr of
A.ByteLiteral _ s -> evalByteLiteral m cons s
A.IntLiteral _ s -> fromRead m cons (readSigned readDec) s
A.HexLiteral _ s -> fromRead m cons readHex s
A.RealLiteral _ s -> fromRead m cons readFloat s
_ -> bad
bad :: EvalM OccValue
bad = throwError (Just m, "Cannot evaluate literal")