Added support for constant-folding reals, and for constant-folding RETYPES
This commit is contained in:
parent
96984250b7
commit
d35825ec50
|
@ -20,6 +20,7 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
module EvalConstants
|
||||
( constantFold
|
||||
, evalIntExpression
|
||||
, getConstantName
|
||||
, isConstantName
|
||||
) where
|
||||
|
||||
|
@ -29,6 +30,7 @@ import Data.Bits
|
|||
import Data.Char
|
||||
import Data.Int
|
||||
import Data.Maybe
|
||||
import Foreign
|
||||
import Text.Printf
|
||||
|
||||
import qualified AST as A
|
||||
|
@ -75,6 +77,11 @@ getConstantName n
|
|||
if isConst
|
||||
then return $ Just e'
|
||||
else return Nothing
|
||||
A.RetypesExpr m A.ValAbbrev t e
|
||||
-> do ps <- getCompState
|
||||
case runEvaluator ps (evalRetypes m t =<< evalSimpleExpression e) of
|
||||
Left _ -> return Nothing
|
||||
Right ov -> renderValue m t ov >>* Just
|
||||
_ -> return Nothing
|
||||
|
||||
-- | Is a name defined as a constant expression?
|
||||
|
@ -126,6 +133,81 @@ evalSubscript (A.Subscript m _ e) (OccArray vs)
|
|||
else throwError (Just m, "subscript out of range")
|
||||
evalSubscript s _ = throwError (Just $ findMeta s, "invalid subscript")
|
||||
|
||||
conv :: Real a => a -> Rational
|
||||
conv = toRational
|
||||
|
||||
occToRational :: OccValue -> Rational
|
||||
occToRational (OccByte x) = conv x
|
||||
occToRational (OccUInt16 x) = conv x
|
||||
occToRational (OccUInt32 x) = conv x
|
||||
occToRational (OccUInt64 x) = conv x
|
||||
occToRational (OccInt8 x) = conv x
|
||||
occToRational (OccInt16 x) = conv x
|
||||
occToRational (OccInt x) = conv x
|
||||
occToRational (OccInt32 x) = conv x
|
||||
occToRational (OccInt64 x) = conv x
|
||||
occToRational (OccReal32 x) = conv x
|
||||
occToRational (OccReal64 x) = conv x
|
||||
|
||||
|
||||
evalConversion :: A.ConversionMode -> A.Type -> OccValue -> EvalM OccValue
|
||||
evalConversion cm t ov
|
||||
= case t of
|
||||
A.Byte -> into OccByte
|
||||
A.UInt16 -> into OccUInt16
|
||||
A.UInt32 -> into OccUInt32
|
||||
A.UInt64 -> into OccUInt64
|
||||
A.Int8 -> into OccInt8
|
||||
A.Int16 -> into OccInt16
|
||||
A.Int -> into OccInt
|
||||
A.Int32 -> into OccInt32
|
||||
A.Int64 -> into OccInt64
|
||||
A.Real32 -> intoF OccReal32
|
||||
A.Real64 -> intoF OccReal64
|
||||
_ -> throwError (Nothing, "Cannot convert")
|
||||
where
|
||||
into cons = return $ cons $ fromInteger iv
|
||||
intoF cons = return $ cons $ fromRational cv
|
||||
|
||||
cv = occToRational ov
|
||||
iv :: Integer
|
||||
iv = case cm of
|
||||
A.DefaultConversion -> round cv
|
||||
A.Round -> round cv
|
||||
A.Trunc -> truncate cv
|
||||
|
||||
evalRetypes :: Die m => Meta -> A.Type -> OccValue -> m OccValue
|
||||
evalRetypes m t ov
|
||||
= case t of
|
||||
A.Byte -> into OccByte
|
||||
A.UInt16 -> into OccUInt16
|
||||
A.UInt32 -> into OccUInt32
|
||||
A.UInt64 -> into OccUInt64
|
||||
A.Int8 -> into OccInt8
|
||||
A.Int16 -> into OccInt16
|
||||
A.Int -> into OccInt
|
||||
A.Int32 -> into OccInt32
|
||||
A.Int64 -> into OccInt64
|
||||
A.Real32 -> into OccReal32
|
||||
A.Real64 -> into OccReal64
|
||||
_ -> dieP m "Cannot retype to that type"
|
||||
where
|
||||
-- unsafePerformIO is nasty -- Adam made me do it!
|
||||
into cons = return $ unsafePerformIO $ getBytesFor ov (liftM cons . peek)
|
||||
|
||||
getBytesFor :: OccValue -> (Ptr b -> IO c) -> IO c
|
||||
getBytesFor (OccByte x) f = with x (f . castPtr)
|
||||
getBytesFor (OccUInt16 x) f = with x (f . castPtr)
|
||||
getBytesFor (OccUInt32 x) f = with x (f . castPtr)
|
||||
getBytesFor (OccUInt64 x) f = with x (f . castPtr)
|
||||
getBytesFor (OccInt8 x) f = with x (f . castPtr)
|
||||
getBytesFor (OccInt16 x) f = with x (f . castPtr)
|
||||
getBytesFor (OccInt x) f = with x (f . castPtr)
|
||||
getBytesFor (OccInt32 x) f = with x (f . castPtr)
|
||||
getBytesFor (OccInt64 x) f = with x (f . castPtr)
|
||||
getBytesFor (OccReal32 x) f = with x (f . castPtr)
|
||||
getBytesFor (OccReal64 x) f = with x (f . castPtr)
|
||||
|
||||
evalExpression :: A.Expression -> EvalM OccValue
|
||||
evalExpression (A.Monadic _ op e)
|
||||
= do v <- evalExpression e
|
||||
|
@ -152,6 +234,9 @@ evalExpression (A.MostPos _ A.Int32) = return $ OccInt32 maxBound
|
|||
evalExpression (A.MostNeg _ A.Int32) = return $ OccInt32 minBound
|
||||
evalExpression (A.MostPos _ A.Int64) = return $ OccInt64 maxBound
|
||||
evalExpression (A.MostNeg _ A.Int64) = return $ OccInt64 minBound
|
||||
evalExpression (A.Conversion _ cm t e)
|
||||
= do e' <- evalExpression e
|
||||
evalConversion cm t e'
|
||||
evalExpression (A.SizeExpr m e)
|
||||
= do t <- astTypeOf e >>= underlyingType m
|
||||
case t of
|
||||
|
@ -204,17 +289,45 @@ evalMonadic A.MonadicBitNot a = evalMonadicOp complement a
|
|||
evalMonadic A.MonadicNot (OccBool b) = return $ OccBool (not b)
|
||||
evalMonadic op _ = throwError (Nothing, "bad monadic op: " ++ show op)
|
||||
|
||||
evalDyadicOp :: (forall t. (Num t, Integral t, Bounded t, Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
|
||||
evalDyadicOp f (OccByte a) (OccByte b) = return $ OccByte (f a b)
|
||||
evalDyadicOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b)
|
||||
evalDyadicOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b)
|
||||
evalDyadicOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b)
|
||||
evalDyadicOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b)
|
||||
evalDyadicOp f (OccInt a) (OccInt b) = return $ OccInt (f a b)
|
||||
evalDyadicOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b)
|
||||
evalDyadicOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b)
|
||||
evalDyadicOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b)
|
||||
evalDyadicOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
|
||||
evalArithOp :: (forall t. (Num t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
|
||||
evalArithOp f (OccByte a) (OccByte b) = return $ OccByte (f a b)
|
||||
evalArithOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b)
|
||||
evalArithOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b)
|
||||
evalArithOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b)
|
||||
evalArithOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b)
|
||||
evalArithOp f (OccInt a) (OccInt b) = return $ OccInt (f a b)
|
||||
evalArithOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b)
|
||||
evalArithOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b)
|
||||
evalArithOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b)
|
||||
evalArithOp f (OccReal32 a) (OccReal32 b) = return $ OccReal32 (f a b)
|
||||
evalArithOp f (OccReal64 a) (OccReal64 b) = return $ OccReal64 (f a b)
|
||||
evalArithOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
|
||||
|
||||
evalArithIntOp :: (forall t. (Num t, Integral t, Bounded t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
|
||||
evalArithIntOp f (OccByte a) (OccByte b) = return $ OccByte (f a b)
|
||||
evalArithIntOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b)
|
||||
evalArithIntOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b)
|
||||
evalArithIntOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b)
|
||||
evalArithIntOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b)
|
||||
evalArithIntOp f (OccInt a) (OccInt b) = return $ OccInt (f a b)
|
||||
evalArithIntOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b)
|
||||
evalArithIntOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b)
|
||||
evalArithIntOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b)
|
||||
evalArithIntOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
|
||||
|
||||
|
||||
evalLogicOp :: (forall t. (Bits t) => t -> t -> t) -> OccValue -> OccValue -> EvalM OccValue
|
||||
evalLogicOp f (OccByte a) (OccByte b) = return $ OccByte (f a b)
|
||||
evalLogicOp f (OccUInt16 a) (OccUInt16 b) = return $ OccUInt16 (f a b)
|
||||
evalLogicOp f (OccUInt32 a) (OccUInt32 b) = return $ OccUInt32 (f a b)
|
||||
evalLogicOp f (OccUInt64 a) (OccUInt64 b) = return $ OccUInt64 (f a b)
|
||||
evalLogicOp f (OccInt8 a) (OccInt8 b) = return $ OccInt8 (f a b)
|
||||
evalLogicOp f (OccInt a) (OccInt b) = return $ OccInt (f a b)
|
||||
evalLogicOp f (OccInt16 a) (OccInt16 b) = return $ OccInt16 (f a b)
|
||||
evalLogicOp f (OccInt32 a) (OccInt32 b) = return $ OccInt32 (f a b)
|
||||
evalLogicOp f (OccInt64 a) (OccInt64 b) = return $ OccInt64 (f a b)
|
||||
evalLogicOp _ v0 v1 = throwError (Nothing, "dyadic operator not implemented for these types: " ++ show v0 ++ " and " ++ show v1)
|
||||
|
||||
|
||||
evalCompareOp :: (forall t. (Eq t, Ord t) => t -> t -> Bool) -> OccValue -> OccValue -> EvalM OccValue
|
||||
evalCompareOp f (OccByte a) (OccByte b) = return $ OccBool (f a b)
|
||||
|
@ -248,18 +361,20 @@ safeRem a b = rem a b
|
|||
|
||||
evalDyadic :: A.DyadicOp -> OccValue -> OccValue -> EvalM OccValue
|
||||
-- FIXME These should check for overflow.
|
||||
evalDyadic A.Add a b = evalDyadicOp (+) a b
|
||||
evalDyadic A.Subtr a b = evalDyadicOp (-) a b
|
||||
evalDyadic A.Mul a b = evalDyadicOp (*) a b
|
||||
evalDyadic A.Div a b = evalDyadicOp safeDiv a b
|
||||
evalDyadic A.Rem a b = evalDyadicOp safeRem a b
|
||||
evalDyadic A.Add a b = evalArithOp (+) a b
|
||||
evalDyadic A.Subtr a b = evalArithOp (-) a b
|
||||
evalDyadic A.Mul a b = evalArithOp (*) a b
|
||||
evalDyadic A.Div (OccReal32 a) (OccReal32 b) = return $ OccReal32 $ a / b
|
||||
evalDyadic A.Div (OccReal64 a) (OccReal64 b) = return $ OccReal64 $ a / b
|
||||
evalDyadic A.Div a b = evalArithIntOp safeDiv a b
|
||||
evalDyadic A.Rem a b = evalArithIntOp safeRem a b
|
||||
-- ... end FIXME
|
||||
evalDyadic A.Plus a b = evalDyadicOp (+) a b
|
||||
evalDyadic A.Minus a b = evalDyadicOp (-) a b
|
||||
evalDyadic A.Times a b = evalDyadicOp (*) a b
|
||||
evalDyadic A.BitAnd a b = evalDyadicOp (.&.) a b
|
||||
evalDyadic A.BitOr a b = evalDyadicOp (.|.) a b
|
||||
evalDyadic A.BitXor a b = evalDyadicOp xor a b
|
||||
evalDyadic A.Plus a b = evalArithOp (+) a b
|
||||
evalDyadic A.Minus a b = evalArithOp (-) a b
|
||||
evalDyadic A.Times a b = evalArithOp (*) a b
|
||||
evalDyadic A.BitAnd a b = evalLogicOp (.&.) a b
|
||||
evalDyadic A.BitOr a b = evalLogicOp (.|.) a b
|
||||
evalDyadic A.BitXor a b = evalLogicOp xor a b
|
||||
evalDyadic A.LeftShift a (OccInt b)
|
||||
= evalMonadicOp (\v -> shiftL v (fromIntegral b)) a
|
||||
evalDyadic A.RightShift a (OccInt b)
|
||||
|
@ -303,6 +418,8 @@ renderLiteral m t v
|
|||
OccInt64 i -> renderInt i
|
||||
OccArray vs -> renderArray vs
|
||||
OccRecord _ vs -> renderRecord vs
|
||||
OccReal32 n -> renderReal n
|
||||
OccReal64 n -> renderReal n
|
||||
where
|
||||
renderChar :: Char -> String
|
||||
renderChar '\'' = "*'"
|
||||
|
@ -319,6 +436,9 @@ renderLiteral m t v
|
|||
renderInt :: Show s => s -> m (A.Type, A.LiteralRepr)
|
||||
renderInt i = return (t, A.IntLiteral m $ show i)
|
||||
|
||||
renderReal :: Show s => s -> m (A.Type, A.LiteralRepr)
|
||||
renderReal i = return (t, A.RealLiteral m $ show i)
|
||||
|
||||
renderArray :: [OccValue] -> m (A.Type, A.LiteralRepr)
|
||||
renderArray vs
|
||||
= do (t', aes) <- renderArrayElems t vs
|
||||
|
|
|
@ -52,6 +52,8 @@ data OccValue =
|
|||
| OccInt CIntReplacement
|
||||
| OccInt32 Int32
|
||||
| OccInt64 Int64
|
||||
| OccReal32 Float
|
||||
| OccReal64 Double
|
||||
| OccArray [OccValue]
|
||||
| OccRecord A.Name [OccValue]
|
||||
deriving (Show, Eq, Typeable, Data)
|
||||
|
@ -117,6 +119,8 @@ evalSimpleLiteral (A.Literal _ t lr)
|
|||
A.Int -> into OccInt
|
||||
A.Int32 -> into OccInt32
|
||||
A.Int64 -> into OccInt64
|
||||
A.Real32 -> intoF OccReal32
|
||||
A.Real64 -> intoF OccReal64
|
||||
_ -> bad
|
||||
where
|
||||
defaults :: EvalM OccValue
|
||||
|
@ -125,6 +129,7 @@ evalSimpleLiteral (A.Literal _ t lr)
|
|||
A.ByteLiteral _ s -> evalByteLiteral m OccByte s
|
||||
A.IntLiteral _ s -> fromRead m OccInt (readSigned readDec) s
|
||||
A.HexLiteral _ s -> fromRead m OccInt readHex s
|
||||
A.RealLiteral _ s -> fromRead m OccReal32 readFloat s
|
||||
_ -> bad
|
||||
|
||||
into :: (Num t, Real t) => (t -> OccValue) -> EvalM OccValue
|
||||
|
@ -135,6 +140,16 @@ evalSimpleLiteral (A.Literal _ t lr)
|
|||
A.HexLiteral _ s -> fromRead m cons readHex s
|
||||
_ -> bad
|
||||
|
||||
intoF :: RealFrac t => (t -> OccValue) -> EvalM OccValue
|
||||
intoF cons
|
||||
= case lr of
|
||||
A.ByteLiteral _ s -> evalByteLiteral m cons s
|
||||
A.IntLiteral _ s -> fromRead m cons (readSigned readDec) s
|
||||
A.HexLiteral _ s -> fromRead m cons readHex s
|
||||
A.RealLiteral _ s -> fromRead m cons readFloat s
|
||||
_ -> bad
|
||||
|
||||
|
||||
bad :: EvalM OccValue
|
||||
bad = throwError (Just m, "Cannot evaluate literal")
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user