diff --git a/data/CompState.hs b/data/CompState.hs index 22baa29..28faec8 100644 --- a/data/CompState.hs +++ b/data/CompState.hs @@ -152,15 +152,11 @@ data CompState = CompState { csGlobalSizes :: Map [Int] String, -- Set by passes - csTypeContext :: [Maybe A.Type], csNonceCounter :: Int, csFunctionReturns :: Map String [A.Type], csPulledItems :: [[PulledItem]], csParProcs :: Map A.Name ParOrFork, csUnifyId :: Int, - -- The string is the operator, the name is the munged function name, the single - -- type is the return type - csOperators :: [(String, A.Name, A.Type, [A.Type])], csWarnings :: [WarningReport] } deriving (Data, Typeable, Show) @@ -212,13 +208,11 @@ emptyState = CompState { csArraySizes = Map.empty, csGlobalSizes = Map.empty, - csTypeContext = [], csNonceCounter = 0, csFunctionReturns = Map.empty, csPulledItems = [], csParProcs = Map.empty, csUnifyId = 0, - csOperators = [], csWarnings = [] } @@ -381,26 +375,6 @@ applyPulled ast --}}} ---{{{ type contexts --- | Enter a type context. -pushTypeContext :: CSM m => Maybe A.Type -> m () -pushTypeContext t - = modifyCompState (\ps -> ps { csTypeContext = t : csTypeContext ps }) - --- | Leave a type context. -popTypeContext :: CSM m => m () -popTypeContext - = modifyCompState (\ps -> ps { csTypeContext = tail $ csTypeContext ps }) - --- | Get the current type context, if there is one. -getTypeContext :: CSMR m => m (Maybe A.Type) -getTypeContext - = do ps <- getCompState - case csTypeContext ps of - (Just c):_ -> return $ Just c - _ -> return Nothing ---}}} - --{{{ nonces -- | Generate a throwaway unique name. makeNonce :: CSM m => Meta -> String -> m String diff --git a/frontends/OccamInferTypes.hs b/frontends/OccamInferTypes.hs index 8a7f0e8..216db29 100644 --- a/frontends/OccamInferTypes.hs +++ b/frontends/OccamInferTypes.hs @@ -61,7 +61,7 @@ betterType t _ = t --{{{ type context management -- | Run an operation in a given type context. -inTypeContext :: Maybe A.Type -> PassM a -> PassM a +inTypeContext :: Maybe A.Type -> InferTypeM a -> InferTypeM a inTypeContext ctx body = do pushTypeContext (case ctx of Just A.Infer -> Nothing @@ -71,13 +71,13 @@ inTypeContext ctx body return v -- | Run an operation in the type context 'Nothing'. -noTypeContext :: PassM a -> PassM a +noTypeContext :: InferTypeM a -> InferTypeM a noTypeContext = inTypeContext Nothing -- | Run an operation in the type context that results from subscripting -- the current type context. -- If the current type context is 'Nothing', the resulting one will be too. -inSubscriptedContext :: Meta -> PassM a -> PassM a +inSubscriptedContext :: Meta -> InferTypeM a -> InferTypeM a inSubscriptedContext m body = do ctx <- getTypeContext subCtx <- case ctx of @@ -139,14 +139,14 @@ makeEnd m dir v -- If unsure (e.g. Infer), just shove a direction on it to be sure: _ -> return $ A.DirectedVariable m dir v -scrubMobile :: PassM a -> PassM a +scrubMobile :: InferTypeM a -> InferTypeM a scrubMobile m = do ctx <- getTypeContext case ctx of (Just (A.Mobile t)) -> inTypeContext (Just t) m _ -> m -inferAllocMobile :: Meta -> A.Type -> A.Expression -> PassM A.Expression +inferAllocMobile :: Meta -> A.Type -> Infer A.Expression inferAllocMobile m (A.Mobile {}) e = do t <- astTypeOf e >>= underlyingType m case t of @@ -154,27 +154,63 @@ inferAllocMobile m (A.Mobile {}) e _ -> return $ A.AllocMobile m (A.Mobile t) (Just e) inferAllocMobile _ _ e = return e +data InferTypeState = InferTypeState + -- The string is the operator, the name is the munged function name, the single + -- type is the return type + { csOperators :: [(String, A.Name, A.Type, [A.Type])] + , csTypeContext :: [Maybe A.Type] + } + +type InferTypeM = StateT InferTypeState PassM + +type ExtOpMI ops t = ExtOpM InferTypeM ops t + +--{{{ type contexts + +-- | Enter a type context. +pushTypeContext :: Maybe A.Type -> InferTypeM () +pushTypeContext t + = modify (\ps -> ps { csTypeContext = t : csTypeContext ps }) + +-- | Leave a type context. +popTypeContext :: InferTypeM () +popTypeContext + = modify (\ps -> ps { csTypeContext = tail $ csTypeContext ps }) + +-- | Get the current type context, if there is one. +getTypeContext :: InferTypeM (Maybe A.Type) +getTypeContext + = do ps <- get + case csTypeContext ps of + (Just c):_ -> return $ Just c + _ -> return Nothing + +--}}} + + --{{{ inferTypes -- I can't put this in the where clause of inferTypes, so it has to be out -- here. It should be the type of ops inside the inferTypes function below. type InferTypeOps - = ExtOpMSP BaseOp - `ExtOpMP` A.Expression - `ExtOpMP` A.Dimension - `ExtOpMP` A.Subscript - `ExtOpMP` A.Replicator - `ExtOpMP` A.Alternative - `ExtOpMP` A.Process - `ExtOpMP` A.Variable - `ExtOpMP` A.Variant + = ExtOpMS InferTypeM BaseOp + `ExtOpMI` A.Expression + `ExtOpMI` A.Dimension + `ExtOpMI` A.Subscript + `ExtOpMI` A.Replicator + `ExtOpMI` A.Alternative + `ExtOpMI` A.Process + `ExtOpMI` A.Variable + `ExtOpMI` A.Variant + +type Infer a = a -> InferTypeM a -- | Infer types. inferTypes :: Pass A.AST inferTypes = occamOnlyPass "Infer types" [] [Prop.inferredTypesRecorded] - recurse + (flip evalStateT (InferTypeState [] []) . recurse) where ops :: InferTypeOps ops = baseOp @@ -188,13 +224,13 @@ inferTypes = occamOnlyPass "Infer types" `extOpM` doVariable `extOpM` doVariant - recurse :: RecurseM PassM InferTypeOps + recurse :: RecurseM InferTypeM InferTypeOps recurse = makeRecurseM ops - descend :: DescendM PassM InferTypeOps + descend :: DescendM InferTypeM InferTypeOps descend = makeDescendM ops - doExpression :: Transform A.Expression + doExpression :: Infer A.Expression doExpression outer = case outer of -- Literals are what we're really looking for here. @@ -237,7 +273,7 @@ inferTypes = occamOnlyPass "Infer types" -- Other expressions don't modify the type context. _ -> descend outer - doFunctionCall :: Meta -> Transform (A.Name, [A.Expression]) + doFunctionCall :: Meta -> Infer (A.Name, [A.Expression]) doFunctionCall m (n, es) = do if isOperator (A.nameName n) then @@ -248,12 +284,12 @@ inferTypes = occamOnlyPass "Infer types" 2 -> "binary" n -> show n ++ "-ary" - cs <- getCompState + operators <- get >>* csOperators resolvedOps <- sequence [ do ts' <- mapM (underlyingType m) ts rt' <- underlyingType m rt return (op, n, rt', ts') - | (op, n, rt, ts) <- csOperators cs + | (op, n, rt, ts) <- operators ] -- The nubBy will ensure that only one definition remains for each -- set of type-arguments, and will keep the first definition in the @@ -305,7 +341,7 @@ inferTypes = occamOnlyPass "Infer types" posss -> dieP m $ "Ambigious " ++ opDescrip ++ " operator, matches definitions: " ++ show (map (transformPair A.nameMeta showOccam) posss) else - do (_, fs) <- checkFunction m n + do (_, fs) <- lift $ checkFunction m n doActuals m n fs (direct, const return) es >>* (,) n where direct = error "Cannot direct channels passed to FUNCTIONs" @@ -329,26 +365,26 @@ inferTypes = occamOnlyPass "Infer types" = typeEqForOp t t' typeEqForOp t t' = t == t' - doActuals :: (PolyplateM a InferTypeOps () PassM, Data a) => Meta -> A.Name -> [A.Formal] -> - (Meta -> A.Direction -> Transform a, A.Type -> Transform a) -> Transform [a] + doActuals :: (PolyplateM a InferTypeOps () InferTypeM, Data a) => Meta -> A.Name -> [A.Formal] -> + (Meta -> A.Direction -> Infer a, A.Type -> Infer a) -> Infer [a] doActuals m n fs applyDir_Deref as - = do checkActualCount m n fs as + = do lift $ checkActualCount m n fs as sequence [doActual m applyDir_Deref t a | (A.Formal _ t _, a) <- zip fs as] -- First function directs, second function dereferences if needed - doActual :: (PolyplateM a InferTypeOps () PassM, Data a) => - Meta -> (Meta -> A.Direction -> Transform a, A.Type -> Transform a) -> A.Type -> Transform a + doActual :: (PolyplateM a InferTypeOps () InferTypeM, Data a) => + Meta -> (Meta -> A.Direction -> Infer a, A.Type -> Infer a) -> A.Type -> Infer a doActual m (applyDir, _) (A.ChanEnd dir _ _) a = recurse a >>= applyDir m dir doActual m (_, deref) t a = inTypeContext (Just t) $ recurse a >>= deref t - doDimension :: Transform A.Dimension + doDimension :: Infer A.Dimension doDimension dim = inTypeContext (Just A.Int) $ descend dim - doSubscript :: Transform A.Subscript + doSubscript :: Infer A.Subscript doSubscript s = inTypeContext (Just A.Int) $ descend s - doExpressionList :: [A.Type] -> Transform A.ExpressionList + doExpressionList :: [A.Type] -> Infer A.ExpressionList doExpressionList ts el = case el of A.FunctionCallList m n es -> @@ -361,13 +397,13 @@ inferTypes = occamOnlyPass "Infer types" return $ A.ExpressionList m es'' A.AllocChannelBundle {} -> return el - doReplicator :: Transform A.Replicator + doReplicator :: Infer A.Replicator doReplicator rep = case rep of A.For _ _ _ _ -> inTypeContext (Just A.Int) $ descend rep A.ForEach _ _ -> noTypeContext $ descend rep - doAlternative :: Transform A.Alternative + doAlternative :: Infer A.Alternative doAlternative (A.Alternative m pre v im p) = do pre' <- inTypeContext (Just A.Bool) $ recurse pre v' <- recurse v >>= derefVariableIfNeeded Nothing @@ -379,7 +415,7 @@ inferTypes = occamOnlyPass "Infer types" p' <- recurse p return $ A.AlternativeSkip m pre' p' - doInputMode :: A.Variable -> Transform A.InputMode + doInputMode :: A.Variable -> Infer A.InputMode doInputMode v (A.InputSimple m iis) = do ts <- protocolItems m v >>* either id (const []) iis' <- sequence [doInputItem t ii @@ -392,7 +428,7 @@ inferTypes = occamOnlyPass "Infer types" = doInputItem A.Int ii >>* A.InputTimerRead m doInputMode _ im = inTypeContext (Just A.Int) $ descend im - doInputItem :: A.Type -> Transform A.InputItem + doInputItem :: A.Type -> Infer A.InputItem doInputItem t (A.InVariable m v) = (inTypeContext (Just t) (recurse v) >>= derefVariableIfNeeded (Just t) @@ -404,7 +440,7 @@ inferTypes = occamOnlyPass "Infer types" >>= derefVariableIfNeeded (Just t) return $ A.InCounted m cv' av' - doVariant :: Transform A.Variant + doVariant :: Infer A.Variant doVariant (A.Variant m n iis p) = do ctx <- getTypeContext ets <- case ctx of @@ -420,9 +456,9 @@ inferTypes = occamOnlyPass "Infer types" p' <- recurse p return $ A.Variant m n iis' p' - doStructured :: ( PolyplateM (A.Structured t) InferTypeOps () PassM - , PolyplateM (A.Structured t) () InferTypeOps PassM - , Data t) => Transform (A.Structured t) + doStructured :: ( PolyplateM (A.Structured t) InferTypeOps () InferTypeM + , PolyplateM (A.Structured t) () InferTypeOps InferTypeM + , Data t) => Infer (A.Structured t) doStructured (A.Spec mspec s@(A.Specification m n st) body) = do (st', wrap) <- runReaderT (doSpecType n st) body @@ -432,10 +468,10 @@ inferTypes = occamOnlyPass "Infer types" doStructured s = descend s -- The second parameter is a modifier (wrapper) for the descent into the body - doSpecType :: ( PolyplateM (A.Structured t) InferTypeOps () PassM - , PolyplateM (A.Structured t) () InferTypeOps PassM - , Data t) => A.Name -> A.SpecType -> ReaderT (A.Structured t) PassM - (A.SpecType, PassM (A.Structured a) -> PassM (A.Structured a)) + doSpecType :: ( PolyplateM (A.Structured t) InferTypeOps () InferTypeM + , PolyplateM (A.Structured t) () InferTypeOps InferTypeM + , Data t) => A.Name -> A.SpecType -> ReaderT (A.Structured t) InferTypeM + (A.SpecType, InferTypeM (A.Structured a) -> InferTypeM (A.Structured a)) doSpecType n st = case st of A.Place _ _ -> lift $ inTypeContext (Just A.Int) $ descend st >>* addId @@ -454,10 +490,10 @@ inferTypes = occamOnlyPass "Infer types" return (tEnd, A.DirectedVariable m dir v') _ -> return (vt, v') -- no direction, or two (A.Infer, _) -> return (vt, v') - (A.ChanEnd dir _ _, _) -> do v'' <- lift $ makeEnd m dir v' + (A.ChanEnd dir _ _, _) -> do v'' <- lift $ lift $ makeEnd m dir v' return (t', v'') (A.Array _ (A.ChanEnd dir _ _), _) -> - do v'' <- lift $ makeEnd m dir v' + do v'' <- lift $ lift $ makeEnd m dir v' return (t', v'') (A.Chan cattr cinnerT, A.ChanEnd dir _ einnerT) -> do cinnerT' <- lift $ recurse cinnerT @@ -506,7 +542,7 @@ inferTypes = occamOnlyPass "Infer types" vs' <- lift $ mapM recurse vs >>= case t' of A.Infer -> return A.Array _ (A.Chan {}) -> return - A.Array _ (A.ChanEnd dir _ _) -> mapM (makeEnd m dir) + A.Array _ (A.ChanEnd dir _ _) -> mapM (lift . makeEnd m dir) _ -> const $ dieP m "Cannot coerce non-channels into channels" let dim = makeDimension m $ length vs' t'' <- lift $ case (t', vs') of @@ -540,7 +576,7 @@ inferTypes = occamOnlyPass "Infer types" rt <- case ts' of [t] -> return t _ -> dieP m "Operator must have exactly one return" - let before, after :: PassM () + let before, after :: InferTypeM () before = modify $ \cs -> cs { csOperators = (raw, n, rt, ts) : csOperators cs } after = modify $ \cs -> cs { csOperators = tail (csOperators cs)} return (func @@ -588,7 +624,7 @@ inferTypes = occamOnlyPass "Infer types" -- ExpressionList that must be in there. -- (This can go away once we represent all functions in the new Process -- form.) - doFuncDef :: [A.Type] -> Transform (A.Structured A.ExpressionList) + doFuncDef :: [A.Type] -> Infer (A.Structured A.ExpressionList) doFuncDef ts (A.Spec m (A.Specification m' n st) s) = do (st', wrap) <- runReaderT (doSpecType n st) s modifyName n (\nd -> nd { A.ndSpecType = st' }) @@ -615,9 +651,9 @@ inferTypes = occamOnlyPass "Infer types" -- Also, to fit with the normal ops, we must do so in the PassM monad. -- Normally we would do this pass in a StateT monad, but to slip inside -- PassM, I've used an IORef instead. - findDir :: ( PolyplateM a InferTypeOps () PassM - , PolyplateM a () InferTypeOps PassM - ) => A.Name -> a -> PassM [A.Direction] + findDir :: ( PolyplateM a InferTypeOps () InferTypeM + , PolyplateM a () InferTypeOps InferTypeM + ) => A.Name -> a -> InferTypeM [A.Direction] findDir n x = do r <- liftIO $ newIORef [] makeRecurseM (makeOps r) x @@ -637,12 +673,12 @@ inferTypes = occamOnlyPass "Infer types" `extOpM` descend `extOpM` (doVariable r) `extOpM` descend - descend :: DescendM PassM InferTypeOps + descend :: DescendM InferTypeM InferTypeOps descend = makeDescendM ops -- This will cover everything, since we will have inferred the direction -- specifiers before applying this function. - doVariable :: IORef [A.Direction] -> A.Variable -> PassM A.Variable + doVariable :: IORef [A.Direction] -> Infer A.Variable doVariable r v@(A.DirectedVariable _ dir (A.Variable _ n')) | n == n' = liftIO $ modifyIORef r (dir:) >> return v doVariable r v@(A.DirectedVariable _ dir @@ -650,7 +686,7 @@ inferTypes = occamOnlyPass "Infer types" = liftIO $ modifyIORef r (dir:) >> return v doVariable r v = makeDescendM (makeOps r) v - doProcess :: Transform A.Process + doProcess :: Infer A.Process doProcess p = case p of A.Assign m vs el -> @@ -696,9 +732,9 @@ inferTypes = occamOnlyPass "Infer types" A.While _ _ _ -> inTypeContext (Just A.Bool) $ descend p A.Processor _ _ _ -> inTypeContext (Just A.Int) $ descend p A.ProcCall m n as -> - do fs <- checkProc m n + do fs <- lift $ checkProc m n as' <- doActuals m n fs - (\m dir (A.ActualVariable v) -> liftM A.ActualVariable $ makeEnd m dir v + (\m dir (A.ActualVariable v) -> lift $ liftM A.ActualVariable $ makeEnd m dir v ,\t a -> case a of A.ActualVariable v -> derefVariableIfNeeded (Just t) v >>* A.ActualVariable _ -> return a @@ -724,9 +760,9 @@ inferTypes = occamOnlyPass "Infer types" _ -> descend p where -- | Does a channel carry a tagged protocol? - isTagged :: A.Variable -> PassM Bool + isTagged :: A.Variable -> InferTypeM Bool isTagged c - = do protoT <- checkChannel A.DirOutput c + = do protoT <- lift $ checkChannel A.DirOutput c case protoT of A.UserProtocol n -> do st <- specTypeOfName n @@ -736,13 +772,13 @@ inferTypes = occamOnlyPass "Infer types" _ -> return False doOutputItems :: Meta -> A.Variable -> Maybe A.Name - -> Transform [A.OutputItem] + -> Infer [A.OutputItem] doOutputItems m v tag ois - = do chanT <- checkChannel A.DirOutput v - ts <- protocolTypes m chanT tag + = do chanT <- lift $ checkChannel A.DirOutput v + ts <- lift $ protocolTypes m chanT tag sequence [doOutputItem t oi | (t, oi) <- zip ts ois] - doOutputItem :: A.Type -> Transform A.OutputItem + doOutputItem :: A.Type -> Infer A.OutputItem doOutputItem (A.Counted ct at) (A.OutCounted m ce ae) = do ce' <- inTypeContext (Just ct) $ recurse ce ae' <- inTypeContext (Just at) $ recurse ae @@ -752,7 +788,7 @@ inferTypes = occamOnlyPass "Infer types" = inTypeContext (Just t) (recurse e >>= inferAllocMobile m t) >>* A.OutExpression m - doVariable :: Transform A.Variable + doVariable :: Infer A.Variable doVariable (A.SubscriptedVariable m s v) = do v' <- noTypeContext (recurse v) >>= derefVariableIfNeeded Nothing t <- astTypeOf v' @@ -760,7 +796,7 @@ inferTypes = occamOnlyPass "Infer types" return $ A.SubscriptedVariable m s' v' doVariable v = descend v - derefVariableIfNeeded :: Maybe (A.Type) -> A.Variable -> PassM A.Variable + derefVariableIfNeeded :: Maybe (A.Type) -> Infer A.Variable derefVariableIfNeeded ctxOrig v = do ctx <- (T.sequence . fmap (resolveUserType (findMeta v))) ctxOrig underT <- astTypeOf v >>= resolveUserType (findMeta v) @@ -772,7 +808,7 @@ inferTypes = occamOnlyPass "Infer types" -- | Resolve the @v[s]@ ambiguity: this takes the type that @v@ is, and -- returns the correct 'Subscript'. - fixSubscript :: A.Type -> A.Subscript -> PassM A.Subscript + fixSubscript :: A.Type -> Infer A.Subscript fixSubscript t s@(A.Subscript m _ (A.ExprVariable _ (A.Variable _ wrong))) = do underT <- resolveUserType m t case underT of @@ -786,14 +822,14 @@ inferTypes = occamOnlyPass "Infer types" fixSubscript _ s = return s -- | Given a name that should really have been a tag, make it one. - nameToUnscoped :: A.Name -> PassM A.Name + nameToUnscoped :: A.Name -> InferTypeM A.Name nameToUnscoped n@(A.Name m _) = do nd <- lookupName n findUnscopedName (A.Name m (A.ndOrigName nd)) -- | Process a 'LiteralRepr', taking the type it's meant to represent or -- 'Infer', and returning the type it really is. - doLiteral :: Transform (A.Type, A.LiteralRepr) + doLiteral :: Infer (A.Type, A.LiteralRepr) doLiteral (wantT, lr) = case lr of A.ArrayListLiteral m aes -> @@ -820,7 +856,7 @@ inferTypes = occamOnlyPass "Infer types" where m = findMeta lr - doArrayElem :: A.Type -> A.Structured A.Expression -> PassM (A.Type, A.Structured A.Expression) + doArrayElem :: A.Type -> A.Structured A.Expression -> InferTypeM (A.Type, A.Structured A.Expression) doArrayElem wantT (A.Spec m spec body) -- A replicator: strip off a subscript and keep going = do underT <- resolveUserType m wantT @@ -856,7 +892,7 @@ inferTypes = occamOnlyPass "Infer types" A.Several m aes') _ -> diePC m $ formatCode "Table literal is not valid for type %" wantT where - doElems :: A.Type -> [A.Structured A.Expression] -> PassM (A.Type, [A.Structured A.Expression]) + doElems :: A.Type -> [A.Structured A.Expression] -> InferTypeM (A.Type, [A.Structured A.Expression]) doElems t aes = do ts <- mapM (\ae -> doArrayElem t ae >>* fst) aes let bestT = foldl betterType t ts @@ -866,12 +902,12 @@ inferTypes = occamOnlyPass "Infer types" doArrayElem wantT (A.Only m e) = do e' <- inTypeContext (Just wantT) $ doExpression e t <- astTypeOf e' - checkType (findMeta e') wantT t + lift $ checkType (findMeta e') wantT t return (t, A.Only m e') -- | Turn a raw table literal into the appropriate combination of -- arrays and records. - buildTable :: A.Type -> [A.Structured A.Expression] -> PassM A.LiteralRepr + buildTable :: A.Type -> [A.Structured A.Expression] -> InferTypeM A.LiteralRepr buildTable t aes = do underT <- resolveUserType m t case underT of @@ -885,13 +921,13 @@ inferTypes = occamOnlyPass "Infer types" | ((_, elemT), ae) <- zip nts aes] return $ A.RecordLiteral m aes' where - buildExpr :: A.Type -> A.Structured A.Expression -> PassM A.Expression + buildExpr :: A.Type -> A.Structured A.Expression -> InferTypeM A.Expression buildExpr t (A.Several _ aes) = do lr <- buildTable t aes return $ A.Literal m t lr buildExpr _ (A.Only _ e) = return e - buildElem :: A.Type -> A.Structured A.Expression -> PassM (A.Structured A.Expression) + buildElem :: A.Type -> Infer (A.Structured A.Expression) buildElem t ae = do underT <- resolveUserType m t case (underT, ae) of