diff --git a/frontends/OccamTypes.hs b/frontends/OccamTypes.hs index 330dd2f..04ff06b 100644 --- a/frontends/OccamTypes.hs +++ b/frontends/OccamTypes.hs @@ -837,28 +837,20 @@ inferTypes = occamOnlyPass "Infer types" p' <- recurse p return $ A.Variant m n iis' p' - doStructured :: Data a => Transform (A.Structured a) doStructured (A.Spec mspec s@(A.Specification m n st) body) - = do st' <- runReaderT (doSpecType n st) body + = do (st', wrap) <- runReaderT (doSpecType n st) body -- Update the definition of each name after we handle it. modifyName n (\nd -> nd { A.ndSpecType = st' }) - let doBody = recurse body >>* A.Spec mspec (A.Specification m n st') - mOp <- functionOperator n - case (st, mOp) of - (A.Function _ _ _ fs _, Just raw) -> do - ts <- mapM astTypeOf fs - modify $ \cs -> cs { csOperators = (raw, n, ts) : csOperators cs } - x <- doBody - modify $ \cs -> cs { csOperators = tail (csOperators cs)} - return x - _ -> doBody + wrap (recurse body) >>* A.Spec mspec (A.Specification m n st') doStructured s = descend s - doSpecType :: Data a => A.Name -> A.SpecType -> ReaderT (A.Structured a) PassM A.SpecType + -- The second parameter is a modifier (wrapper) for the descent into the body + doSpecType :: Data a => A.Name -> A.SpecType -> ReaderT (A.Structured a) PassM + (A.SpecType, PassM (A.Structured a) -> PassM (A.Structured a)) doSpecType n st = case st of - A.Place _ _ -> lift $ inTypeContext (Just A.Int) $ descend st + A.Place _ _ -> lift $ inTypeContext (Just A.Int) $ descend st >>* addId A.Is m am t (A.ActualVariable v) -> do am' <- lift $ recurse am t' <- lift $ recurse t @@ -892,7 +884,7 @@ inferTypes = occamOnlyPass "Infer types" return (tEnd, A.DirectedVariable m dir v') _ -> return (t', v') -- no direction, or two _ -> return (t', v') - return $ A.Is m am' t'' $ A.ActualVariable v'' + return $ addId $ A.Is m am' t'' $ A.ActualVariable v'' A.Is m am t (A.ActualExpression e) -> lift $ do am' <- recurse am t' <- recurse t @@ -901,7 +893,7 @@ inferTypes = occamOnlyPass "Infer types" A.Infer -> astTypeOf e' A.Array ds _ | A.UnknownDimension `elem` ds -> astTypeOf e' _ -> return t' - return $ A.Is m am' t'' (A.ActualExpression e') + return $ addId $ A.Is m am' t'' (A.ActualExpression e') A.Is m am t (A.ActualClaim v) -> lift $ do am' <- recurse am t' <- recurse t @@ -909,7 +901,7 @@ inferTypes = occamOnlyPass "Infer types" t'' <- case t' of A.Infer -> astTypeOf (A.ActualClaim v') _ -> return t' - return $ A.Is m am' t'' (A.ActualClaim v') + return $ addId $ A.Is m am' t'' (A.ActualClaim v') A.Is m am t (A.ActualChannelArray vs) -> -- No expressions in this -- but we may need to infer the type -- of the variable if it's something like "cs IS [c]:". @@ -935,21 +927,35 @@ inferTypes = occamOnlyPass "Infer types" ,A.DirectedVariable m dir) _ -> return (t'', id) _ -> return (t'', id) - return $ A.Is m am t''' $ A.ActualChannelArray $ map f vs' - A.Function m sm ts fs (Just (Left sel)) -> lift $ + return $ addId $ A.Is m am t''' $ A.ActualChannelArray $ map f vs' + A.Function m sm ts fs mbody -> lift $ do sm' <- recurse sm ts' <- recurse ts fs' <- recurse fs - sel' <- doFuncDef ts sel - return $ A.Function m sm' ts' fs' $ Just (Left sel') - A.RetypesExpr _ _ _ _ -> lift $ noTypeContext $ descend st + sel' <- case mbody of + Just (Left sel) -> doFuncDef ts sel >>* (Just . Left) + _ -> return mbody + mOp <- functionOperator n + let func = A.Function m sm' ts' fs' sel' + case mOp of + Just raw -> do + ts <- mapM astTypeOf fs + let before = modify $ \cs -> cs { csOperators = (raw, n, ts) : csOperators cs } + after = modify $ \cs -> cs { csOperators = tail (csOperators cs)} + return (func + ,\m -> do before + x <- m + after + return x) + _ -> return func >>* addId + A.RetypesExpr _ _ _ _ -> lift $ noTypeContext $ descend st >>* addId -- For PROCs that take any channels without direction, -- we must determine if we can infer a specific direction -- for that channel A.Proc m sm fs body -> lift $ do body' <- recurse body fs' <- mapM (processFormal body') fs - return $ A.Proc m sm fs' body' + return $ addId $ A.Proc m sm fs' body' where processFormal body f@(A.Formal am t n) = do t' <- recurse t @@ -967,17 +973,20 @@ inferTypes = occamOnlyPass "Infer types" _ -> do modifyName n (\nd -> nd {A.ndSpecType = A.Declaration m t'}) return $ A.Formal am t' n - _ -> lift $ descend st + _ -> lift $ descend st >>* addId where + addId :: a -> (a, b -> b) + addId a = (a, id) + -- | This is a bit ugly: walk down a Structured to find the single -- ExpressionList that must be in there. -- (This can go away once we represent all functions in the new Process -- form.) doFuncDef :: [A.Type] -> Transform (A.Structured A.ExpressionList) doFuncDef ts (A.Spec m (A.Specification m' n st) s) - = do st' <- runReaderT (doSpecType n st) s + = do (st', wrap) <- runReaderT (doSpecType n st) s modifyName n (\nd -> nd { A.ndSpecType = st' }) - s' <- doFuncDef ts s + s' <- wrap $ doFuncDef ts s return $ A.Spec m (A.Specification m' n st') s' doFuncDef ts (A.ProcThen m p s) = do p' <- recurse p