diff --git a/frontends/RainPassesTest.hs b/frontends/RainPassesTest.hs index 5a072fa..e0110ed 100644 --- a/frontends/RainPassesTest.hs +++ b/frontends/RainPassesTest.hs @@ -296,10 +296,10 @@ testParamPass :: testParamPass testName formals params transParams = case transParams of - Just act -> TestList [TestCase $ testPass (testName ++ "/process") (expProc act) (matchParamPass origProc) startStateProc, - TestCase $ testPass (testName ++ "/function") (expFunc act) (matchParamPass origFunc) startStateFunc] - Nothing -> TestList [TestCase $ testPassShouldFail (testName ++ "/process") (matchParamPass origProc) startStateProc, - TestCase $ testPassShouldFail (testName ++ "/function") (matchParamPass origFunc) startStateFunc] + Just act -> TestList [TestCase $ testPass (testName ++ "/process") (expProc act) (performTypeUnification origProc) startStateProc, + TestCase $ testPass (testName ++ "/function") (expFunc act) (performTypeUnification origFunc) startStateFunc] + Nothing -> TestList [TestCase $ testPassShouldFail (testName ++ "/process") (performTypeUnification origProc) startStateProc, + TestCase $ testPassShouldFail (testName ++ "/function") (performTypeUnification origFunc) startStateFunc] where startStateProc :: State CompState () startStateProc = do defineName (simpleName "x") $ simpleDefDecl "x" (A.UInt16) @@ -378,8 +378,8 @@ testParamPass7 = testParamPass "testParamPass7" -- | Test calling something that is not a process: testParamPass8 :: Test -testParamPass8 = TestList [TestCase $ testPassShouldFail "testParamPass8/process" (matchParamPass origProc) (startState'), - TestCase $ testPassShouldFail "testParamPass8/function" (matchParamPass origFunc) (startState')] +testParamPass8 = TestList [TestCase $ testPassShouldFail "testParamPass8/process" (performTypeUnification origProc) (startState'), + TestCase $ testPassShouldFail "testParamPass8/function" (performTypeUnification origFunc) (startState')] where startState' :: State CompState () startState' = do defineName (simpleName "x") $ simpleDefDecl "x" (A.UInt16) diff --git a/frontends/RainTypes.hs b/frontends/RainTypes.hs index 206ac41..829df8c 100644 --- a/frontends/RainTypes.hs +++ b/frontends/RainTypes.hs @@ -93,6 +93,7 @@ performTypeUnification x put st {csUnifyPairs = [], csUnifyLookup = ul} -- Then we markup all the types in the tree: x' <- markConditionalTypes + <.< markParamPass <.< markAssignmentTypes <.< markCommTypes $ x --TODO markup everything else @@ -206,58 +207,35 @@ annotateListLiteralTypes = applyDepthM doExpression -- | A pass that finds all the 'A.ProcCall' and 'A.FunctionCall' in the -- AST, and checks that the actual parameters are valid inputs, given -- the 'A.Formal' parameters in the process's type -matchParamPass :: Data t => t -> PassM t -matchParamPass = everywhereM ((mkM matchParamPassProc) `extM` matchParamPassFunc) +markParamPass :: Data t => t -> PassM t +markParamPass = checkDepthM2 matchParamPassProc matchParamPassFunc where --Picks out the parameters of a process call, checks the number is correct, and maps doParam over them - matchParamPassProc :: A.Process -> PassM A.Process + matchParamPassProc :: Check A.Process matchParamPassProc (A.ProcCall m n actualParams) = do def <- lookupNameOrError n $ dieP m ("Process name is unknown: \"" ++ (show $ A.nameName n) ++ "\"") case A.ndType def of A.Proc _ _ expectedParams _ -> if (length expectedParams) == (length actualParams) - then do transActualParams <- mapM (doParam m (A.nameName n)) (zip3 [0..] expectedParams actualParams) - return $ A.ProcCall m n transActualParams + then mapM_ (uncurry markUnify) (zip expectedParams actualParams) else dieP m $ "Wrong number of parameters given to process call; expected: " ++ show (length expectedParams) ++ " but found: " ++ show (length actualParams) _ -> dieP m $ "You cannot run things that are not processes, such as: \"" ++ (show $ A.nameName n) ++ "\"" - matchParamPassProc p = return p + matchParamPassProc _ = return () --Picks out the parameters of a function call, checks the number is correct, and maps doExpParam over them - matchParamPassFunc :: A.Expression -> PassM A.Expression + matchParamPassFunc :: Check A.Expression matchParamPassFunc (A.FunctionCall m n actualParams) = do def <- lookupNameOrError n $ dieP m ("Function name is unknown: \"" ++ (show $ A.nameName n) ++ "\"") case A.ndType def of A.Function _ _ _ expectedParams _ -> if (length expectedParams) == (length actualParams) - then do transActualParams <- mapM (doExpParam m (A.nameName n)) (zip3 [0..] expectedParams actualParams) - return $ A.FunctionCall m n transActualParams + then mapM_ (uncurry markUnify) (zip expectedParams actualParams) else dieP m $ "Wrong number of parameters given to function call; expected: " ++ show (length expectedParams) ++ " but found: " ++ show (length actualParams) _ -> dieP m $ "Attempt to make a function call with something" ++ " that is not a function: \"" ++ A.nameName n ++ "\"; is actually: " ++ showConstr (toConstr $ A.ndType def) - matchParamPassFunc e = return e - - --Checks the type of a parameter (A.Actual), and inserts a cast if it is safe to do so - doParam :: Meta -> String -> (Int,A.Formal, A.Actual) -> PassM A.Actual - doParam m n (index, A.Formal formalAbbrev formalType formalName, A.ActualVariable v) - = do actualType <- astTypeOf v - if (actualType == formalType) - then return $ A.ActualVariable v - else (liftM A.ActualExpression) $ doCast index formalType actualType (A.ExprVariable (findMeta v) v ) - doParam m n (index, for@(A.Formal _ formalType _), A.ActualExpression e) - = (liftM A.ActualExpression) $ doExpParam m n (index, for, e) - - --Checks the type of a parameter (A.Expression), and inserts a cast if it is safe to do so - doExpParam :: Meta -> String -> (Int, A.Formal, A.Expression) -> PassM A.Expression - doExpParam m n (index, A.Formal formalAbbrev formalType formalName, e) - = do actualType <- astTypeOf e - if (actualType == formalType) - then return e - else doCast index formalType actualType e - - doCast :: Int -> A.Type -> A.Type -> A.Expression -> PassM A.Expression - doCast index = coerceType $ " for parameter (zero-based): " ++ (show index) + matchParamPassFunc _ = return () --Adds a cast between two types if it is safe to do so, otherwise gives an error coerceType :: String -> A.Type -> A.Type -> A.Expression -> PassM A.Expression