diff --git a/frontends/RainTypes.hs b/frontends/RainTypes.hs index abe4322..9c54aca 100644 --- a/frontends/RainTypes.hs +++ b/frontends/RainTypes.hs @@ -101,9 +101,21 @@ checkExpressionTypes = everywhereASTM checkExpression trhs <- typeOfExpression rhs if (tlhs == trhs) then return e - else case (leastGeneralSharedTypeRain [tlhs,trhs]) of - Nothing -> dieP m $ "Cannot find a suitable type to convert expression to, types are: " ++ show tlhs ++ " and " ++ show trhs - Just t -> return $ A.Dyadic m op (convert t tlhs lhs) (convert t trhs rhs) + else if (isIntegerType tlhs && isIntegerType trhs) + then case (leastGeneralSharedTypeRain [tlhs,trhs]) of + Nothing -> dieP m $ "Cannot find a suitable type to convert expression to, types are: " ++ show tlhs ++ " and " ++ show trhs + Just t -> return $ A.Dyadic m op (convert t tlhs lhs) (convert t trhs rhs) + else return e --TODO + checkExpression e@(A.Monadic m op rhs) + = do trhs <- typeOfExpression rhs + if (op == A.MonadicMinus) + then case trhs of + A.Byte -> return $ A.Monadic m op $ convert A.Int16 trhs rhs + A.UInt16 -> return $ A.Monadic m op $ convert A.Int32 trhs rhs + A.UInt32 -> return $ A.Monadic m op $ convert A.Int64 trhs rhs + A.UInt64 -> dieP m $ "Cannot apply unary minus to type: " ++ show trhs ++ " because there is no type large enough to safely contain the result" + _ -> if (isIntegerType trhs) then return e else dieP m $ "Trying to apply unary minus to non-integer type: " ++ show trhs + else return e checkExpression e = return e convert :: A.Type -> A.Type -> A.Expression -> A.Expression diff --git a/frontends/RainTypesTest.hs b/frontends/RainTypesTest.hs index f9c95fa..412b0d5 100644 --- a/frontends/RainTypesTest.hs +++ b/frontends/RainTypesTest.hs @@ -103,6 +103,12 @@ checkExpressionTest = TestList ,pass 401 A.Int16 (Dy (Cast A.Int16 $ Var "x8") A.Plus (int A.Int16 200)) (Dy (Var "x8") A.Plus (int A.Int16 200)) --This fails because you are trying to add a signed constant to an unsigned integer that cannot be expanded: ,fail 402 $ Dy (Var "xu64") A.Plus (int A.Int64 0) + + ,passSame 500 A.Int32 (Mon A.MonadicMinus (Var "x32")) + ,pass 501 A.Int32 (Mon A.MonadicMinus (Cast A.Int32 $ Var "xu16")) (Mon A.MonadicMinus (Var "xu16")) + ,fail 502 $ Mon A.MonadicMinus (Var "xu64") + ,pass 503 A.Int64 (Dy (Var "x") A.Plus (Cast A.Int64 $ Mon A.MonadicMinus (Var "x32"))) (Dy (Var "x") A.Plus (Mon A.MonadicMinus (Var "x32"))) + ] where passSame :: Int -> A.Type -> ExprHelper -> Test @@ -146,6 +152,7 @@ checkExpressionTest = TestList defVar "xu16" A.UInt16 defVar "xu32" A.UInt32 defVar "xu64" A.UInt64 + defVar "x32" A.Int32 defVar "x16" A.Int16 defVar "x8" A.Int8