diff --git a/Makefile.am b/Makefile.am
index 3f5eddc..d282dcc 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -233,7 +233,7 @@ tocktest_SOURCES += transformations/PassTest.hs
tocktest_SOURCES += transformations/SimplifyAbbrevsTest.hs
tocktest_SOURCES += transformations/SimplifyTypesTest.hs
-pregen_sources = data/AST.hs
+pregen_sources = data/AST.hs data/CompState.hs
pregen_sources += pregen/PregenUtils.hs
pregen_sources += polyplate/Data/Generics/Polyplate/GenInstances.hs
diff --git a/backends/BackendPasses.hs b/backends/BackendPasses.hs
index cd4e4d8..6505925 100644
--- a/backends/BackendPasses.hs
+++ b/backends/BackendPasses.hs
@@ -71,7 +71,96 @@ removeDirectionsForC
doVariable (A.DirectedVariable _ _ v) = v
doVariable v = v
-transformWaitFor :: Pass
+
+-- | Remove variable directions that are superfluous. This prevents confusing
+-- later passes, where the user has written something like:
+-- []CHAN INT da! IS ...:
+-- foo(da!)
+--
+-- The second direction specifier is unneeded, and will confuse passes such as
+-- those adding sizes parameters (which looks for plain variables, since directed
+-- arrays should already have been pulled up).
+removeUnneededDirections :: PassOn A.Variable
+removeUnneededDirections
+ = occamOnlyPass "Remove unneeded variable directions"
+ prereq
+ []
+ (applyBottomUpM doVariable)
+ where
+ doVariable :: Transform (A.Variable)
+ doVariable whole@(A.DirectedVariable m dir v)
+ = do t <- astTypeOf v
+ case t of
+ A.Chan {} -> return whole
+ A.Array _ (A.Chan {}) -> return whole
+ A.ChanEnd chanDir _ _ | dir == chanDir -> return v
+ A.Array _ (A.ChanEnd chanDir _ _) | dir == chanDir -> return v
+ _ -> diePC m $ formatCode "Direction applied to non-channel type: %" t
+ doVariable v = return v
+
+type AllocMobileOps = ExtOpMSP BaseOp `ExtOpMP` A.Process
+
+-- | Pulls up any initialisers for mobile allocations. I think, after all the
+-- other passes have run, the only place these initialisers should be left is in
+-- assignments (and maybe not even those?) and A.Is items.
+pullAllocMobile :: PassOnOps AllocMobileOps
+pullAllocMobile = cOnlyPass "Pull up mobile initialisers" [] [] recurse
+ where
+ ops :: AllocMobileOps
+ ops = baseOp `extOpMS` (ops, doStructured) `extOpM` doProcess
+
+ recurse :: RecurseM PassM AllocMobileOps
+ recurse = makeRecurseM ops
+ descend :: DescendM PassM AllocMobileOps
+ descend = makeDescendM ops
+
+ doProcess :: Transform A.Process
+ doProcess (A.Assign m [v] (A.ExpressionList me [A.AllocMobile ma t (Just e)]))
+ = return $ A.Seq m $ A.Several m $ map (A.Only m) $
+ [A.Assign m [v] $ A.ExpressionList me [A.AllocMobile ma t Nothing]
+ ,A.Assign m [A.DerefVariable m v] $ A.ExpressionList me [e]
+ ]
+ doProcess p = descend p
+
+ doStructured :: TransformStructured' AllocMobileOps
+ doStructured (A.Spec mspec (A.Specification mif n
+ (A.Is mis am t (A.ActualExpression (A.AllocMobile ma tm (Just e)))))
+ body)
+ = do body' <- recurse body
+ return $ A.Spec mspec (A.Specification mif n $
+ A.Is mis am t $ A.ActualExpression $ A.AllocMobile ma tm Nothing)
+ $ A.ProcThen ma
+ (A.Assign ma [A.DerefVariable mif $ A.Variable mif n] $ A.ExpressionList ma [e])
+ body'
+ doStructured s = descend s
+
+-- | Turns any literals equivalent to a MOSTNEG back into a MOSTNEG
+-- The reason for doing this is that C (and presumably C++) don't technically (according
+-- to the standard) allow you to write INT_MIN directly as a constant. GCC certainly
+-- warns about it. So this pass takes any MOSTNEG-equivalent values (that will have been
+-- converted to constants in the constant folding earlier) and turns them back
+-- into MOSTNEG, for which the C backend uses INT_MIN and similar, which avoid
+-- this problem.
+fixMinInt :: PassOn A.Expression
+fixMinInt
+ = cOrCppOnlyPass "Turn any literals that are equal to MOSTNEG INT back into MOSTNEG INT"
+ prereq
+ []
+ (applyBottomUpM doExpression)
+ where
+ doExpression :: Transform (A.Expression)
+ doExpression l@(A.Literal m t (A.IntLiteral m' s))
+ = do folded <- constantFold (A.MostNeg m t)
+ case folded of
+ (A.Literal _ _ (A.IntLiteral _ s'), _, _)
+ -> if (s == s')
+ then return $ A.MostNeg m t
+ else return l
+ _ -> return l -- This can happen as some literals retain the Infer
+ -- type which fails the constant folding
+ doExpression e = return e
+
+transformWaitFor :: PassOn A.Process
transformWaitFor = cOnlyPass "Transform wait for guards into wait until guards"
[]
[Prop.waitForRemoved]
diff --git a/backends/BackendPassesTest.hs b/backends/BackendPassesTest.hs
index 19d79a7..871e6b6 100644
--- a/backends/BackendPassesTest.hs
+++ b/backends/BackendPassesTest.hs
@@ -22,7 +22,7 @@ with this program. If not, see .
module BackendPassesTest (qcTests) where
import Control.Monad.State
-import Data.Generics
+import Data.Generics (Data)
import qualified Data.Map as Map
import Test.HUnit hiding (State)
import Test.QuickCheck
diff --git a/backends/GenerateCTest.hs b/backends/GenerateCTest.hs
index ea561a7..118fbde 100644
--- a/backends/GenerateCTest.hs
+++ b/backends/GenerateCTest.hs
@@ -35,7 +35,7 @@ import Control.Monad.Error
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Writer hiding (tell)
-import Data.Generics
+import Data.Generics (Data)
import Data.List (isInfixOf, intersperse)
import Data.Maybe (fromMaybe)
import Test.HUnit hiding (State)
diff --git a/checks/CheckFramework.hs b/checks/CheckFramework.hs
index 8cb6f1a..539e2e4 100644
--- a/checks/CheckFramework.hs
+++ b/checks/CheckFramework.hs
@@ -420,40 +420,14 @@ applyAccum _ ops = ops'
ops' :: ((t, Route t A.AST) -> StateT (AccumMap t) (RestartT CheckOptM) t, ops)
ops' = (accum, ops)
- extF ::
- (forall a. Data a => TransFuncS acc z a) ->
- (forall c. Data c => TransFuncS acc z c)
- extF = (`extMRAccS` (\(x,_) -> modify (accOneF x) >> return x))
-
- applyAccum' :: (forall a. Data a => TransFuncAcc acc a) ->
- (forall b. Data b => (b, Route b A.AST) -> StateT acc (RestartT CheckOptM) b)
- applyAccum' f (x, route)
- = do when (findMeta x /= emptyMeta) $ lift . lift . CheckOptM $ modify $ \d -> d {lastValidMeta = findMeta x}
- (x', acc) <- lift $ flip runStateT accEmpty (gmapMForRoute typeSet (extF wrap) x)
- r <- f' (x', route, acc)
- modify (`accJoinF` acc)
- return r
- where
- wrap (y, route') = applyAccum' f (y, route @-> route')
-
- -- Keep applying the function while there is a Left return (which indicates
- -- the value was replaced) until there is a Right return
- f' (x, route, acc) = do
- x' <- f (x, route, acc)
- case x' of
- Left y -> f' (y, route, foldl (flip accOneF) accEmpty (listify {-TODO-} (const True) y))
- Right y -> return y
+ accum xr = do x' <- transformMRoute () ops' xr
+ modify $ Map.insert (routeId $ snd xr) x'
+ return x'
-applyTopDown :: TypeSet -> (forall a. Data a => TransFunc a) ->
- (forall b. Data b => (b, Route b A.AST) -> RestartT CheckOptM b)
-applyTopDown typeSet f (x, route)
- = do when (findMeta x /= emptyMeta) $ lift . CheckOptM $ modify $ \d -> d {lastValidMeta = findMeta x}
- z <- f' (x, route)
- gmapMForRoute typeSet (\(y, route') -> applyTopDown typeSet f (y, route @-> route')) z
- where
- -- Keep applying the function while there is a Left return (which indicates
- -- the value was replaced) until there is a Right return
- f' (x, route) = do
+-- Keep applying the function while there is a Left return (which indicates
+-- the value was replaced) until there is a Right return
+keepApplying :: Monad m => ((t, Route t outer) -> m (Either t t)) -> ((t, Route t outer) -> m t)
+keepApplying f (x, route) = do
x' <- f (x, route)
case x' of
Left y -> keepApplying f (y, route)
diff --git a/common/CommonTest.hs b/common/CommonTest.hs
index 4582623..6545f6e 100644
--- a/common/CommonTest.hs
+++ b/common/CommonTest.hs
@@ -21,7 +21,7 @@ with this program. If not, see .
-- | A module with tests for various miscellaneous things in the common directory.
module CommonTest (tests) where
-import Data.Generics
+import Data.Generics (Constr, Data, Typeable)
import Test.HUnit hiding (State)
import qualified AST as A
diff --git a/common/OccamEDSL.hs b/common/OccamEDSL.hs
index 1ddf321..df5fbb4 100644
--- a/common/OccamEDSL.hs
+++ b/common/OccamEDSL.hs
@@ -30,7 +30,7 @@ module OccamEDSL (ExpInp, ExpInpT,
becomes) where
import Control.Monad.State hiding (guard)
-import Data.Generics
+import Data.Generics (Data)
import qualified Data.Map as Map
import qualified Data.Set as Set
import Test.HUnit hiding (State)
diff --git a/common/TestFramework.hs b/common/TestFramework.hs
index 0622549..6cc7604 100644
--- a/common/TestFramework.hs
+++ b/common/TestFramework.hs
@@ -21,7 +21,7 @@ with this program. If not, see .
module TestFramework where
import Control.Monad.Error
-import Data.Generics
+import Data.Generics (Data)
import System.IO.Unsafe
import Test.HUnit hiding (Testable)
import Test.QuickCheck hiding (check)
diff --git a/common/TestUtils.hs b/common/TestUtils.hs
index 6496145..4d9e01b 100644
--- a/common/TestUtils.hs
+++ b/common/TestUtils.hs
@@ -40,7 +40,7 @@ module TestUtils where
import Control.Monad.State
import Control.Monad.Writer
-import Data.Generics
+import Data.Generics (Data, Typeable)
import qualified Data.Map as Map
import System.Random
import Test.HUnit hiding (State,Testable)
diff --git a/data/CompState.hs b/data/CompState.hs
index 4785ccf..29698ff 100644
--- a/data/CompState.hs
+++ b/data/CompState.hs
@@ -23,7 +23,7 @@ import Control.Monad.Error
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
-import Data.Generics (Data, Typeable)
+import Data.Generics (Data, Typeable, listify)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe
@@ -496,6 +496,8 @@ instance FindMeta A.Name where
findMeta = A.nameMeta
-- Should stop being lazy, and put these as pattern matches:
+--
+-- TODO also, at least use Polyplate!
findMeta_Data :: Data a => a -> Meta
findMeta_Data = head . listify (const True)
diff --git a/data/Metadata.hs b/data/Metadata.hs
index 58718f5..be4be3b 100644
--- a/data/Metadata.hs
+++ b/data/Metadata.hs
@@ -21,7 +21,7 @@ module Metadata where
{-! global : Haskell2Xml !-}
-import Data.Generics (Data, Typeable, listify)
+import Data.Generics (Data, Typeable)
import Data.List
import Numeric
import Text.Printf
diff --git a/flow/FlowGraph.hs b/flow/FlowGraph.hs
index 7c8e701..1678c66 100644
--- a/flow/FlowGraph.hs
+++ b/flow/FlowGraph.hs
@@ -51,7 +51,8 @@ import Data.Graph.Inductive hiding (run)
import Data.Maybe
import qualified AST as A
-import GenericUtils
+import CompState
+import Data.Generics.Polyplate.Route
import Metadata
import FlowUtils
import Utils
diff --git a/flow/FlowGraphTest.hs b/flow/FlowGraphTest.hs
index ea3d9f9..21612f1 100644
--- a/flow/FlowGraphTest.hs
+++ b/flow/FlowGraphTest.hs
@@ -24,7 +24,7 @@ module FlowGraphTest (qcTests) where
import Control.Monad.Identity
import Control.Monad.State
-import Data.Generics
+import Data.Generics (Data)
import Data.Graph.Inductive
import Data.List
import qualified Data.Map as Map
@@ -34,11 +34,14 @@ import Test.HUnit hiding (Node, State, Testable)
import Test.QuickCheck
import qualified AST as A
+import CompState
+import Data.Generics.Polyplate.Route
import FlowGraph
import Metadata
import PrettyShow
import TestFramework
import TestUtils
+import Traversal
import Utils
-- | Makes a distinctive metatag for testing. The function is one-to-one.
@@ -789,7 +792,7 @@ pickFuncRep gr = Map.fromList $ filter ((/= emptyMeta) . fst) $ map (helpApplyFu
applyFunc (m,AlterSpec f) = routeModify f (g m)
applyFunc (m,AlterNothing _) = return
- g m = gmapM (mkM $ replaceM m (replaceMeta m))
+ g m = applyBottomUpM $ replaceM m (replaceMeta m)
-- | It is important to have these functions in the right ratio. The number of possible trees is
diff --git a/frontends/OccamPassesTest.hs b/frontends/OccamPassesTest.hs
index dfa046e..c062ba4 100644
--- a/frontends/OccamPassesTest.hs
+++ b/frontends/OccamPassesTest.hs
@@ -23,7 +23,7 @@ with this program. If not, see .
module OccamPassesTest (tests) where
import Control.Monad.State
-import Data.Generics
+import Data.Generics (Data)
import Test.HUnit hiding (State)
import qualified AST as A
@@ -32,6 +32,8 @@ import Metadata
import qualified OccamPasses
import Pass
import TestUtils
+import Traversal
+import Types
m :: Meta
m = emptyMeta
@@ -138,15 +140,15 @@ testCheckConstants = TestList
, testFail 33 (A.Option m [lit10, lit10, lit10, var] skip)
]
where
- testOK :: (PolyplateM a (TwoOpM PassM A.Dimension A.Option) () PassM
- ,PolyplateM a () (TwoOpM PassM A.Dimension A.Option) PassM
+ testOK :: (PolyplateM a (TwoOpM PassM A.Type A.Option) () PassM
+ ,PolyplateM a () (TwoOpM PassM A.Type A.Option) PassM
,Show a, Data a) => Int -> a -> Test
testOK n orig
= TestCase $ testPass ("testCheckConstants" ++ show n)
orig OccamPasses.checkConstants orig
(return ())
- testFail :: (PolyplateM a (TwoOpM PassM A.Dimension A.Option) () PassM
- ,PolyplateM a () (TwoOpM PassM A.Dimension A.Option) PassM
+ testFail :: (PolyplateM a (TwoOpM PassM A.Type A.Option) () PassM
+ ,PolyplateM a () (TwoOpM PassM A.Type A.Option) PassM
,Show a, Data a) => Int -> a -> Test
testFail n orig
= TestCase $ testPassShouldFail ("testCheckConstants" ++ show n)
diff --git a/frontends/OccamTypes.hs b/frontends/OccamTypes.hs
index 841d5b4..96e841d 100644
--- a/frontends/OccamTypes.hs
+++ b/frontends/OccamTypes.hs
@@ -648,6 +648,7 @@ type InferTypeOps
`ExtOpMP` A.Alternative
`ExtOpMP` A.Process
`ExtOpMP` A.Variable
+ `ExtOpMP` A.Variant
-- | Infer types.
inferTypes :: Pass A.AST
@@ -658,16 +659,18 @@ inferTypes = occamOnlyPass "Infer types"
where
ops :: InferTypeOps
ops = baseOp
- `extOp` doExpression
- `extOp` doDimension
- `extOp` doSubscript
- `extOp` doArrayConstr
- `extOp` doReplicator
- `extOp` doAlternative
- `extOp` doInputMode
- `extOp` doSpecification
- `extOp` doProcess
- `extOp` doVariable
+ `extOpMS` (ops, doStructured)
+ `extOpM` doExpression
+ `extOpM` doDimension
+ `extOpM` doSubscript
+ `extOpM` doReplicator
+ `extOpM` doAlternative
+ `extOpM` doProcess
+ `extOpM` doVariable
+ `extOpM` doVariant
+
+ recurse :: RecurseM PassM InferTypeOps
+ recurse = makeRecurseM ops
descend :: DescendM PassM InferTypeOps
descend = makeDescendM ops
@@ -834,7 +837,26 @@ inferTypes = occamOnlyPass "Infer types"
inTypeContext (Just ct) (recurse sv) >>* A.InputCase m
doInputMode _ im = inTypeContext (Just A.Int) $ descend im
- doStructured :: Data a => Transform (A.Structured a)
+ doVariant :: Transform A.Variant
+ doVariant (A.Variant m n iis p)
+ = do ctx <- getTypeContext
+ ets <- case ctx of
+ Just x -> protocolItems m x
+ Nothing -> dieP m "Could not deduce protocol"
+ case ets of
+ Left {} -> dieP m "Simple protocol expected during input CASE"
+ Right ps -> case lookup n ps of
+ Nothing -> diePC m $ formatCode "Name % is not part of protocol %"
+ n (fromJust ctx)
+ Just ts -> do iis' <- sequence [inTypeContext (Just t) $ recurse ii
+ | (t, ii) <- zip ts iis]
+ p' <- recurse p
+ return $ A.Variant m n iis' p'
+
+ doStructured :: ( PolyplateM (A.Structured t) InferTypeOps () PassM
+ , PolyplateM (A.Structured t) () InferTypeOps PassM
+ , Data t) => Transform (A.Structured t)
+
doStructured (A.Spec mspec s@(A.Specification m n st) body)
= do (st', wrap) <- runReaderT (doSpecType n st) body
-- Update the definition of each name after we handle it.
@@ -842,7 +864,11 @@ inferTypes = occamOnlyPass "Infer types"
wrap (recurse body) >>* A.Spec mspec (A.Specification m n st')
doStructured s = descend s
- doSpecType :: Data a => A.Name -> A.SpecType -> ReaderT (A.Structured a) PassM A.SpecType
+ -- The second parameter is a modifier (wrapper) for the descent into the body
+ doSpecType :: ( PolyplateM (A.Structured t) InferTypeOps () PassM
+ , PolyplateM (A.Structured t) () InferTypeOps PassM
+ , Data t) => A.Name -> A.SpecType -> ReaderT (A.Structured t) PassM
+ (A.SpecType, PassM (A.Structured a) -> PassM (A.Structured a))
doSpecType n st
= case st of
A.Place _ _ -> lift $ inTypeContext (Just A.Int) $ descend st >>* addId
@@ -1025,6 +1051,7 @@ inferTypes = occamOnlyPass "Infer types"
`extOpM` descend
`extOpM` descend
`extOpM` (doVariable r)
+ `extOpM` descend
descend :: DescendM PassM InferTypeOps
descend = makeDescendM ops
diff --git a/frontends/OccamTypesTest.hs b/frontends/OccamTypesTest.hs
index b25dacf..0a950af 100644
--- a/frontends/OccamTypesTest.hs
+++ b/frontends/OccamTypesTest.hs
@@ -22,7 +22,7 @@ with this program. If not, see .
module OccamTypesTest (vioTests) where
import Control.Monad.State
-import Data.Generics
+import Data.Generics (Data)
import Test.HUnit hiding (State)
import qualified AST as A
diff --git a/frontends/ParseRainTest.hs b/frontends/ParseRainTest.hs
index ba6a5ce..b418d14 100644
--- a/frontends/ParseRainTest.hs
+++ b/frontends/ParseRainTest.hs
@@ -35,7 +35,7 @@ with this program. If not, see .
-- and then turn these into Patterns where any Meta tag that is "m" is ignored during the comparison.
module ParseRainTest (tests) where
-import Data.Generics
+import Data.Generics (Data)
import Prelude hiding (fail)
import Test.HUnit
import Text.ParserCombinators.Parsec (runParser,eof)
diff --git a/frontends/RainPassesTest.hs b/frontends/RainPassesTest.hs
index 030b83e..8b11f45 100644
--- a/frontends/RainPassesTest.hs
+++ b/frontends/RainPassesTest.hs
@@ -31,7 +31,7 @@ module RainPassesTest (tests) where
import Control.Monad.State
import Control.Monad.Identity
-import Data.Generics
+import Data.Generics (Data, Typeable)
import qualified Data.Map as Map
import Test.HUnit hiding (State)
diff --git a/frontends/RainTypes.hs b/frontends/RainTypes.hs
index 79a5a2f..11d62f7 100644
--- a/frontends/RainTypes.hs
+++ b/frontends/RainTypes.hs
@@ -19,7 +19,7 @@ with this program. If not, see .
module RainTypes (constantFoldPass, performTypeUnification) where
import Control.Monad.State
-import Data.Generics (Data, showConstr, toConstr)
+import Data.Generics (Data, showConstr, toConstr, Typeable)
import Data.List
import qualified Data.Map as Map
import Data.Maybe
diff --git a/frontends/RainTypesTest.hs b/frontends/RainTypesTest.hs
index f3d05f3..a2d7592 100644
--- a/frontends/RainTypesTest.hs
+++ b/frontends/RainTypesTest.hs
@@ -22,7 +22,7 @@ module RainTypesTest (vioTests) where
import Control.Monad.State
import Control.Monad.Error
import Control.Monad.Writer
-import Data.Generics
+import Data.Generics (Data)
import qualified Data.Map as Map
import Test.HUnit hiding (State)
diff --git a/transformations/PassTest.hs b/transformations/PassTest.hs
index 20f2b5c..a3de92f 100644
--- a/transformations/PassTest.hs
+++ b/transformations/PassTest.hs
@@ -20,7 +20,7 @@ with this program. If not, see .
module PassTest (tests) where
import Control.Monad.State hiding (guard)
-import Data.Generics
+import Data.Generics (cast, Data, Typeable)
import qualified Data.Map as Map
import Test.HUnit hiding (State)
diff --git a/transformations/SimplifyAbbrevsTest.hs b/transformations/SimplifyAbbrevsTest.hs
index 4f44cb2..ce15297 100644
--- a/transformations/SimplifyAbbrevsTest.hs
+++ b/transformations/SimplifyAbbrevsTest.hs
@@ -20,7 +20,7 @@ with this program. If not, see .
module SimplifyAbbrevsTest (tests) where
import Control.Monad.State
-import Data.Generics
+import Data.Generics (Data)
import Test.HUnit hiding (State)
import CompState
diff --git a/transformations/SimplifyTypesTest.hs b/transformations/SimplifyTypesTest.hs
index a5db62e..d172d65 100644
--- a/transformations/SimplifyTypesTest.hs
+++ b/transformations/SimplifyTypesTest.hs
@@ -20,7 +20,7 @@ with this program. If not, see .
module SimplifyTypesTest (tests) where
import Control.Monad.State
-import Data.Generics
+import Data.Generics (Data)
import Test.HUnit hiding (State)
import CompState