diff --git a/pass/Pass.hs b/pass/Pass.hs index fea3daf..1c6d77b 100644 --- a/pass/Pass.hs +++ b/pass/Pass.hs @@ -37,20 +37,67 @@ import TreeUtils import Utils -- | The monad in which AST-mangling passes operate. -type PassM = ErrorT ErrorReport (StateT CompState (WriterT [WarningReport] IO)) -type PassMR = ErrorT ErrorReport (ReaderT CompState (WriterT [WarningReport] IO)) +-- The old monad stacks: +--type PassM = ErrorT ErrorReport (StateT CompState (WriterT [WarningReport] IO)) +--type PassMR = ErrorT ErrorReport (ReaderT CompState (WriterT [WarningReport] IO)) + +newtype PassM a = PassMInternal { runPassM :: CompState -> IO (Either ErrorReport (a, CompState, [WarningReport])) } +newtype PassMR a = PassMRInternal { runPassMR :: CompState -> IO (Either ErrorReport (a, [WarningReport])) } + +instance Monad PassM where + return x = PassMInternal $ \cs -> return (Right (x, cs, [])) + m >>= b = PassMInternal $ \cs -> + do mresult <- runPassM m cs + case mresult of + Left err -> return $ Left err + Right (x, cs', w') -> + do bresult <- runPassM (b x) cs' + case bresult of + Left err -> return $ Left err + Right (x', cs'', w'') -> return $ Right (x', cs'', w' ++ w'') + +instance Monad PassMR where + return x = PassMRInternal $ \cs -> return (Right (x, [])) + m >>= b = PassMRInternal $ \cs -> + do mresult <- runPassMR m cs + case mresult of + Left err -> return $ Left err + Right (x, w') -> + do bresult <- runPassMR (b x) cs + case bresult of + Left err -> return $ Left err + Right (x', w'') -> return $ Right (x', w' ++ w'') + instance Die PassM where - dieReport = throwError + dieReport err = PassMInternal $ const $ return $ Left err instance Die PassMR where - dieReport = throwError + dieReport err = PassMRInternal $ const $ return $ Left err instance Warn PassM where - warnReport w = tell [w] + warnReport w = PassMInternal $ \cs -> return (Right ((), cs, [w])) instance Warn PassMR where - warnReport w = tell [w] + warnReport w = PassMRInternal $ \cs -> return (Right ((), [w])) + +instance MonadIO PassM where + liftIO a = PassMInternal $ \cs -> do a' <- a + return $ Right (a', cs, []) + +instance MonadIO PassMR where + liftIO a = PassMRInternal $ const $ do a' <- a + return $ Right (a', []) + +instance CSMR PassM where + getCompState = PassMInternal $ \cs -> return (Right (cs, cs, [])) + +instance CSMR PassMR where + getCompState = PassMRInternal $ \cs -> return (Right (cs, [])) + +instance MonadState CompState PassM where + get = getCompState + put cs = PassMInternal $ const $ return (Right ((), cs, [])) -- | The type of an AST-mangling pass. data Monad m => Pass_ m = Pass { @@ -83,12 +130,10 @@ instance Ord Property where compare x y = compare (propName x) (propName y) runPassR :: (A.AST -> PassMR A.AST) -> (A.AST -> PassM A.AST) -runPassR p t - = do st <- get - (r,w) <- liftIO $ runWriterT $ runReaderT (runErrorT (p t)) st - case r of - Left err -> throwError err - Right result -> tell w >> return result +runPassR p t = PassMInternal $ \cs -> do result <- runPassMR (p t) cs + case result of + Left err -> return $ Left err + Right (t', w) -> return $ Right (t', cs, w) makePassesDep :: [(String, A.AST -> PassM A.AST, [Property], [Property])] -> [Pass] makePassesDep = map (\(s, p, pre, post) -> Pass p s (Set.fromList pre) (Set.fromList post) (const True))