diff --git a/frontends/RainTypes.hs b/frontends/RainTypes.hs index de2a830..4be2f02 100644 --- a/frontends/RainTypes.hs +++ b/frontends/RainTypes.hs @@ -250,3 +250,36 @@ checkConditionalTypes t = (everywhereASTM checkWhile t) >>= (everywhereASTM chec if (t == A.Bool) then return c else dieP m "Expression in if conditional must be of boolean type" + +-- | Checks the types in inputs and outputs +checkCommTypes :: Data t => t -> PassM t +checkCommTypes = everywhereASTM checkInputOutput + where + checkInputOutput :: A.Process -> PassM A.Process + checkInputOutput p@(A.Input m chanVar (A.InputSimple _ [A.InVariable _ destVar])) + = do chanType <- typeOfVariable chanVar + destType <- typeOfVariable destVar + case chanType of + A.Chan dir _ innerType -> + if (dir == A.DirOutput) + then dieP m $ "Tried to input from the writing end of a channel: " ++ show chanVar + else + if (innerType == destType) + then return p + else dieP m $ "Mis-matching types; channel: " ++ show chanVar ++ " has inner-type: " ++ show innerType ++ + " but destination variable: " ++ show destVar ++ " has type: " ++ show destType + _ -> dieP m $ "Tried to input from a variable that is not of type channel: " ++ show chanVar + checkInputOutput p@(A.Output m chanVar [A.OutExpression m' srcExp]) + = do chanType <- typeOfVariable chanVar + srcType <- typeOfExpression srcExp + case chanType of + A.Chan dir _ innerType -> + if (dir == A.DirInput) + then dieP m $ "Tried to output to the reading end of a channel: " ++ show chanVar + else + if (innerType == srcType) + then return p + else do castExp <- coerceType " for writing to channel" innerType srcType srcExp + return $ A.Output m chanVar [A.OutExpression m' castExp] + _ -> dieP m $ "Tried to output to a variable that is not of type channel: " ++ show chanVar + checkInputOutput p = return p diff --git a/frontends/RainTypesTest.hs b/frontends/RainTypesTest.hs index cc104e1..214817f 100644 --- a/frontends/RainTypesTest.hs +++ b/frontends/RainTypesTest.hs @@ -188,7 +188,7 @@ checkExpressionTest = TestList ,failWhileIf 4100 $ Var "x" ,failWhileIf 4101 $ Dy (Var "x") A.Plus (Var "x") - + ,testAllCheckCommTypesIn 5000 ] where passAssign :: Int -> String -> ExprHelper -> ExprHelper -> Test @@ -234,7 +234,48 @@ checkExpressionTest = TestList (checkConditionalTypes $ A.While m (buildExpr src) (A.Skip m)) state ] + + --Takes an index, the inner type of the channel and direction with a variable, then the type and variable for the RHS + --Expects a pass only if the inner type of the channel is the same as the type of the variable, and channel direction is unknown or input + testCheckCommTypesIn :: Int -> (A.Direction,A.Type,A.Variable) -> (A.Type,A.Variable) -> Test + testCheckCommTypesIn n (chanDir,chanType,chanVar) (destType,destVar) + = if (chanType == destType && chanDir /= A.DirOutput) + then TestCase $ testPass ("testCheckCommTypesIn " ++ show n) (mkPattern st) (checkCommTypes st) state + else TestCase $ testPassShouldFail ("testCheckCommTypesIn " ++ show n) (checkCommTypes st) state + where + st = A.Input m chanVar $ A.InputSimple m [A.InVariable m destVar] + --Automatically tests checking inputs for various combinations of channel type and direction + testAllCheckCommTypesIn :: Int -> Test + testAllCheckCommTypesIn n = TestList $ map (\(n,f) -> f n) $ zip [n..] $ + concat [[\ind -> testCheckCommTypesIn ind c d,\ind -> testCheckCommTypesOut ind c d] | c <- chans, d <- vars] + where + chans = concatMap allDirs [(A.Int64,variable "c"), (A.Bool,variable "cb"), (A.Byte, variable "cu8")] + vars = [(A.Bool, variable "b"), (A.Int64, variable "x"), (A.Byte, variable "xu8"), (A.Int16, variable "x16")] + allDirs :: (A.Type,A.Variable) -> [(A.Direction,A.Type,A.Variable)] + allDirs (t,v) = + [ + (A.DirInput,t,A.DirectedVariable m A.DirInput v) + ,(A.DirOutput,t,A.DirectedVariable m A.DirOutput v) + ,(A.DirUnknown,t,v) + ] + + --Takes an index, the inner type of the channel and direction with a variable, then the type and variable for the RHS + --Expects a pass only if the expression type can be cast to the inner type of the channel, and channel direction is unknown or output + testCheckCommTypesOut :: Int -> (A.Direction,A.Type,A.Variable) -> (A.Type,A.Variable) -> Test + testCheckCommTypesOut n (chanDir,chanType,chanVar) (srcType,srcVar) + = if (isImplicitConversionRain srcType chanType && chanDir /= A.DirInput) + then (if srcType == chanType + then TestCase $ testPass ("testCheckCommTypesOut " ++ show n) (mkPattern st) (checkCommTypes st) state + else TestCase $ testPass ("testCheckCommTypesOut " ++ show n) stCast (checkCommTypes st) state + ) + else TestCase $ testPassShouldFail ("testCheckCommTypesOut " ++ show n) (checkCommTypes st) state + where + st = A.Output m chanVar [A.OutExpression m $ A.ExprVariable m srcVar] + stCast = tag3 A.Output DontCare chanVar [tag2 A.OutExpression DontCare $ tag4 A.Conversion DontCare A.DefaultConversion chanType $ + A.ExprVariable m srcVar] + + passSame :: Int -> A.Type -> ExprHelper -> Test passSame n t e = pass n t e e @@ -279,6 +320,9 @@ checkExpressionTest = TestList defVar "x32" A.Int32 defVar "x16" A.Int16 defVar "x8" A.Int8 + defVar "c" $ A.Chan A.DirUnknown (A.ChanAttributes False False) A.Int64 + defVar "cu8" $ A.Chan A.DirUnknown (A.ChanAttributes False False) A.Byte + defVar "cb" $ A.Chan A.DirUnknown (A.ChanAttributes False False) A.Bool tests :: Test tests = TestList