From 4ef1ff71965c97f1c7951171e1535a65fb4c6bfb Mon Sep 17 00:00:00 2001 From: Neil Brown Date: Thu, 3 Apr 2008 12:21:59 +0000 Subject: [PATCH] Changed to a state monad for warnings, and added a runPassM function to remove duplicate code for running passes --- Main.hs | 6 +++--- backends/GenerateCTest.hs | 5 ++++- checks/UsageCheckTest.hs | 4 ++-- common/TestHarness.hs | 2 +- common/TestUtils.hs | 5 +++-- frontends/RainTypesTest.hs | 4 +++- pass/Pass.hs | 18 ++++++++++++------ 7 files changed, 28 insertions(+), 16 deletions(-) diff --git a/Main.hs b/Main.hs index 6074c44..b31d1e5 100644 --- a/Main.hs +++ b/Main.hs @@ -161,10 +161,10 @@ main = do ModeFull -> evalStateT (compileFull fn fileStem) [] -- Run the compiler. - v <- runWriterT $ evalStateT (runErrorT operation) initState + v <- runPassM initState operation case v of - (Left e, ws) -> showWarnings ws >> dieIO e - (Right r, ws) -> showWarnings ws + (Left e, _, ws) -> showWarnings ws >> dieIO e + (Right r, _, ws) -> showWarnings ws removeFiles :: [FilePath] -> IO () removeFiles = mapM_ (\file -> catch (removeFile file) doNothing) diff --git a/backends/GenerateCTest.hs b/backends/GenerateCTest.hs index f469c40..ff2a2f3 100644 --- a/backends/GenerateCTest.hs +++ b/backends/GenerateCTest.hs @@ -48,6 +48,7 @@ import GenerateC import GenerateCBased import GenerateCPPCSP import Metadata +import Pass import TestUtils import Utils @@ -109,7 +110,9 @@ evalCGen :: CGen () -> GenOps -> CompState -> IO (Either Errors.ErrorReport [Str evalCGen act ops state = evalCGen' (runReaderT act ops) state evalCGen' :: CGen' () -> CompState -> IO (Either Errors.ErrorReport [String]) -evalCGen' act state = runWriterT (evalStateT (runErrorT $ execStateT act (Left []) >>* (\(Left x) -> x)) state) >>* fst +evalCGen' act state = runPassM state pass >>* (\(x,_,_) -> x) + where + pass = execStateT act (Left []) >>* (\(Left x) -> x) -- | Checks that running the test for the C and C++ backends produces the right output for each. testBothS :: diff --git a/checks/UsageCheckTest.hs b/checks/UsageCheckTest.hs index 30ffb6a..6e70096 100644 --- a/checks/UsageCheckTest.hs +++ b/checks/UsageCheckTest.hs @@ -111,9 +111,9 @@ testGetVarProc = TestList (map doTest tests) doTest (index, r, w, u, proc) = TestCase $ do result <- runPass (getVarProc proc) startState case result of - Left err -> + (_, Left err) -> testFailure $ name ++ " failed: " ++ show err - Right (_, result) -> + (_, Right result) -> assertEqual name (vars r w u) result where name = "testGetVarProc" ++ show index diff --git a/common/TestHarness.hs b/common/TestHarness.hs index f4ea8c0..a45a4ff 100644 --- a/common/TestHarness.hs +++ b/common/TestHarness.hs @@ -59,7 +59,7 @@ defaultState = emptyState {csUsageChecking = True} -- | Tests if compiling the given source gives any errors. -- If there are errors, they are returned. Upon success, Nothing is returned testOccam :: String -> IO (Maybe String) -testOccam source = do (result,_) <- runWriterT $ evalStateT (runErrorT compilation) defaultState +testOccam source = do (result,_,_) <- runPassM defaultState compilation return $ case result of Left (_,err) -> Just err Right _ -> Nothing diff --git a/common/TestUtils.hs b/common/TestUtils.hs index 16856bc..8d6ea4c 100644 --- a/common/TestUtils.hs +++ b/common/TestUtils.hs @@ -495,7 +495,8 @@ runPass :: TestMonad m r => PassM b -- ^ The actual pass. -> CompState -- ^ The state to use to run the pass. -> m (CompState, Either ErrorReport b) -- ^ The resultant state, and either an error or the successful outcome of the pass. -runPass actualPass startState = liftM (\((x,y),_) -> (y,x)) $ runIO (runWriterT $ runStateT (runErrorT actualPass) startState) +runPass actualPass startState = liftM (\(x,y,_) -> (y,x)) $ + runIO (runPassM startState actualPass) -- | A test that runs a given AST pass and checks that it succeeds. testPass :: @@ -585,7 +586,7 @@ testPassShouldFail testName actualPass startStateTrans = do ret <- runPass actualPass (execState startStateTrans emptyState) case ret of (_,Left err) -> return () - Right (state, output) -> testFailure $ testName ++ " pass succeeded when expected to fail; output: " ++ pshow output + (state, Right output) -> testFailure $ testName ++ " pass succeeded when expected to fail; output: " ++ pshow output --}}} --{{{ miscellaneous utilities diff --git a/frontends/RainTypesTest.hs b/frontends/RainTypesTest.hs index 71f385f..6c7a542 100644 --- a/frontends/RainTypesTest.hs +++ b/frontends/RainTypesTest.hs @@ -448,7 +448,9 @@ checkExpressionTest = TestList if (e /= act) then pass' (10000 + n) t (mkPattern e) e else return () where errorOrType :: IO (Either ErrorReport A.Type) - errorOrType = ((runWriterT (evalStateT (runErrorT $ typeOfExpression e) (execState state emptyState))) :: IO (Either ErrorReport A.Type, [WarningReport])) >>* fst + errorOrType + = (flip runPassM (typeOfExpression e) (execState state emptyState)) + >>* \(x,_,_) -> x fail :: Int -> ExprHelper -> Test diff --git a/pass/Pass.hs b/pass/Pass.hs index 67472dc..a1436f5 100644 --- a/pass/Pass.hs +++ b/pass/Pass.hs @@ -37,8 +37,8 @@ import TreeUtils import Utils -- | The monad in which AST-mangling passes operate. -type PassM = ErrorT ErrorReport (StateT CompState (WriterT [WarningReport] IO)) -type PassMR = ErrorT ErrorReport (ReaderT CompState (WriterT [WarningReport] IO)) +type PassM = ErrorT ErrorReport (StateT CompState (StateT [WarningReport] IO)) +type PassMR = ErrorT ErrorReport (ReaderT CompState (StateT [WarningReport] IO)) instance Die PassM where dieReport = throwError @@ -47,10 +47,10 @@ instance Die PassMR where dieReport = throwError instance Warn PassM where - warnReport w = tell [w] + warnReport w = lift $ lift $ modify (++ [w]) instance Warn PassMR where - warnReport w = tell [w] + warnReport w = lift $ lift $ modify (++ [w]) -- | The type of an AST-mangling pass. data Monad m => Pass_ m = Pass { @@ -85,10 +85,16 @@ instance Ord Property where runPassR :: (A.AST -> PassMR A.AST) -> (A.AST -> PassM A.AST) runPassR p t = do st <- get - (r,w) <- liftIO $ runWriterT $ runReaderT (runErrorT (p t)) st + (r,w) <- liftIO $ flip runStateT [] $ runReaderT (runErrorT (p t)) st case r of Left err -> throwError err - Right result -> tell w >> return result + Right result -> mapM_ warnReport w >> return result + +runPassM :: CompState -> PassM a -> IO (Either ErrorReport a, CompState, [WarningReport]) +runPassM cs pass = liftM flatten $ flip runStateT [] $ flip runStateT cs $ runErrorT pass + where + flatten :: ((a, b),c) -> (a, b, c) + flatten ((x, y), z) = (x, y, z) makePassesDep :: [(String, A.AST -> PassM A.AST, [Property], [Property])] -> [Pass] makePassesDep = map (\(s, p, pre, post) -> Pass p s (Set.fromList pre) (Set.fromList post) (const True))