diff --git a/frontends/OccamPasses.hs b/frontends/OccamPasses.hs index d270e34..da9260f 100644 --- a/frontends/OccamPasses.hs +++ b/frontends/OccamPasses.hs @@ -22,7 +22,6 @@ module OccamPasses (occamPasses, foldConstants, checkConstants, import Control.Monad.State import Data.Generics -import System.IO import qualified AST as A import CompState @@ -56,44 +55,36 @@ occamPasses = makePassesDep' ((== FrontendOccam) . csFrontend) -- | Fold constant expressions. foldConstants :: Data t => t -> PassM t -foldConstants = doGeneric `extM` doSpecification `extM` doExpression +foldConstants = applyDepthM2 doExpression doSpecification where - doGeneric :: Data t => t -> PassM t - doGeneric = makeGeneric foldConstants - - -- When an expression is abbreviated, try to fold it, and update its - -- definition so that it can be used when folding later expressions. - doSpecification :: A.Specification -> PassM A.Specification - doSpecification s@(A.Specification m n (A.IsExpr m' am t e)) - = do e' <- doExpression e - let st' = A.IsExpr m' am t e' - modifyName n (\nd -> nd { A.ndType = st' }) - return $ A.Specification m n st' - doSpecification s = doGeneric s - - -- For all other expressions, just try to fold them. - -- We recurse into the expression first so that we fold subexpressions of - -- non-constant expressions too. + -- Try to fold all expressions we encounter. Since we've recursed into the + -- expression first, this'll also fold subexpressions of non-constant + -- expressions. doExpression :: A.Expression -> PassM A.Expression doExpression e - = do e' <- doGeneric e - (e'', _, _) <- constantFold e' - return e'' + = do (e', _, _) <- constantFold e + return e' + + -- When an expression is abbreviated, update its definition so that it can + -- be used when folding later expressions. + doSpecification :: A.Specification -> PassM A.Specification + doSpecification s@(A.Specification _ n st@(A.IsExpr _ _ _ _)) + = do modifyName n (\nd -> nd { A.ndType = st }) + return s + doSpecification s = return s + -- | Check that things that must be constant are. checkConstants :: Data t => t -> PassM t -checkConstants = doGeneric `extM` doDimension `extM` doOption +checkConstants = applyDepthM2 doDimension doOption where - doGeneric :: Data t => t -> PassM t - doGeneric = makeGeneric checkConstants - -- Check array dimensions are constant. doDimension :: A.Dimension -> PassM A.Dimension doDimension d@(A.Dimension e) = do when (not $ isConstant e) $ diePC (findMeta e) $ formatCode "Array dimension must be constant: %" e - doGeneric d - doDimension d = doGeneric d + return d + doDimension d = return d -- Check case options are constant. doOption :: A.Option -> PassM A.Option @@ -101,12 +92,12 @@ checkConstants = doGeneric `extM` doDimension `extM` doOption = do sequence_ [when (not $ isConstant e) $ diePC (findMeta e) $ formatCode "Case option must be constant: %" e | e <- es] - doGeneric o - doOption o = doGeneric o + return o + doOption o = return o -- | Check that retyping is safe. checkRetypes :: Data t => t -> PassM t -checkRetypes = everywhereASTM doSpecType +checkRetypes = applyDepthM doSpecType where doSpecType :: A.SpecType -> PassM A.SpecType doSpecType st@(A.Retypes m _ t v) diff --git a/frontends/RainTypes.hs b/frontends/RainTypes.hs index e4c8949..cc44394 100644 --- a/frontends/RainTypes.hs +++ b/frontends/RainTypes.hs @@ -49,14 +49,14 @@ recordInfNameTypes = everywhereM (mkM recordInfNameTypes') -- | Folds all constants. constantFoldPass :: Data t => t -> PassM t -constantFoldPass = everywhereASTM doExpression +constantFoldPass = applyDepthM doExpression where doExpression :: A.Expression -> PassM A.Expression doExpression = (liftM (\(x,_,_) -> x)) . constantFold -- | Annotates all integer literal types annnotateIntLiteralTypes :: Data t => t -> PassM t -annnotateIntLiteralTypes = everywhereASTM doExpression +annnotateIntLiteralTypes = applyDepthM doExpression where --Function is separated out to easily provide the type description of Integer powOf2 :: Integer -> Integer @@ -148,7 +148,7 @@ coerceType customMsg to from item -- | Checks the types in expressions checkExpressionTypes :: Data t => t -> PassM t -checkExpressionTypes = everywhereASTM checkExpression +checkExpressionTypes = applyDepthM checkExpression where checkExpression :: A.Expression -> PassM A.Expression checkExpression e@(A.Dyadic m op lhs rhs) @@ -231,7 +231,7 @@ checkExpressionTypes = everywhereASTM checkExpression -- | Checks the types in assignments checkAssignmentTypes :: Data t => t -> PassM t -checkAssignmentTypes = everywhereASTM checkAssignment +checkAssignmentTypes = applyDepthM checkAssignment where checkAssignment :: A.Process -> PassM A.Process checkAssignment ass@(A.Assign m [v] (A.ExpressionList m' [e])) @@ -246,7 +246,7 @@ checkAssignmentTypes = everywhereASTM checkAssignment -- | Checks the types in if and while conditionals checkConditionalTypes :: Data t => t -> PassM t -checkConditionalTypes t = (everywhereASTM checkWhile t) >>= (everywhereASTM checkIf) +checkConditionalTypes = applyDepthM2 checkWhile checkIf where checkWhile :: A.Process -> PassM A.Process checkWhile w@(A.While m exp _) @@ -265,7 +265,7 @@ checkConditionalTypes t = (everywhereASTM checkWhile t) >>= (everywhereASTM chec -- | Checks the types in inputs and outputs, including inputs in alts checkCommTypes :: Data t => t -> PassM t -checkCommTypes p = (everywhereASTM checkInputOutput p) >>= (everywhereASTM checkAltInput) +checkCommTypes = applyDepthM2 checkInputOutput checkAltInput where checkInput :: A.Variable -> A.Variable -> Meta -> a -> PassM a checkInput chanVar destVar m p @@ -307,7 +307,7 @@ checkCommTypes p = (everywhereASTM checkInputOutput p) >>= (everywhereASTM check -- | Checks the types in now and wait statements, and wait guards: checkGetTimeTypes :: Data t => t -> PassM t -checkGetTimeTypes p = (everywhereASTM checkGetTime p) >>= (everywhereASTM checkTimeGuards) +checkGetTimeTypes = applyDepthM2 checkGetTime checkTimeGuards where checkGetTime :: A.Process -> PassM A.Process checkGetTime p@(A.GetTime m v) diff --git a/pass/Pass.hs b/pass/Pass.hs index f967512..424d16c 100644 --- a/pass/Pass.hs +++ b/pass/Pass.hs @@ -196,11 +196,23 @@ makeGeneric top -- | Apply a monadic operation everywhere that it matches in the AST, going -- depth-first. -everywhereASTM :: (Data s, Data t) => (s -> PassM s) -> t -> PassM t -everywhereASTM f = doGeneric `extM` (doSpecific f) +applyDepthM :: (Data a, Data t) => (a -> PassM a) -> t -> PassM t +applyDepthM f = doGeneric `extM` (doSpecific f) where doGeneric :: Data t => t -> PassM t - doGeneric = makeGeneric (everywhereASTM f) + doGeneric = makeGeneric (applyDepthM f) + + doSpecific :: Data t => (t -> PassM t) -> t -> PassM t + doSpecific f x = (doGeneric x >>= f) + +-- | Apply two monadic operations everywhere they match in the AST, going +-- depth-first. +applyDepthM2 :: (Data a, Data b, Data t) => + (a -> PassM a) -> (b -> PassM b) -> t -> PassM t +applyDepthM2 f1 f2 = doGeneric `extM` (doSpecific f1) `extM` (doSpecific f2) + where + doGeneric :: Data t => t -> PassM t + doGeneric = makeGeneric (applyDepthM2 f1 f2) doSpecific :: Data t => (t -> PassM t) -> t -> PassM t doSpecific f x = (doGeneric x >>= f)