diff --git a/checks/Check.hs b/checks/Check.hs index 94f7fcd..1ee80eb 100644 --- a/checks/Check.hs +++ b/checks/Check.hs @@ -299,19 +299,34 @@ foldUnionVarsBK checkPlainVarUsage :: forall m. (MonadIO m, Die m, CSMR m) => NameAttr -> (Meta, ParItems (BK, UsageLabel)) -> m () checkPlainVarUsage sharedAttr (m, p) = check p where - addBK :: BK -> Vars -> VarsBK - addBK bk vs = VarsBK (Map.fromAscList $ zip (Set.toAscList $ readVars vs) (repeat bk)) - ((Map.map (\me -> (maybeToList me, bk)) $ writtenVars vs) - `Map.union` Map.fromAscList (zip (Set.toAscList $ usedVars - vs) (repeat ([], bk)))) + addBK :: BK -> Vars -> m VarsBK + addBK bk vs + = do let read = Map.fromAscList $ zip (Set.toAscList $ readVars vs) (repeat bk) + splitUsed <- splitEnds' $ Set.toList $ usedVars vs + splitWritten <- concatMapM splitEnds (Map.toList $ writtenVars vs) >>* Map.fromList + let used = Map.fromList (zip splitUsed (repeat ([], bk))) + return $ VarsBK read + ((Map.map (\me -> (maybeToList me, bk)) splitWritten) + `Map.union` used) + + splitEnds' = liftM (map fst) . concatMapM splitEnds . flip zip (repeat ()) + + splitEnds :: (Var, a) -> m [(Var, a)] + splitEnds (Var v, x) + = do t <- astTypeOf v + case t of + A.Chan {} -> return + [(Var $ A.DirectedVariable (findMeta v) dir v, x) + | dir <- [A.DirInput, A.DirOutput]] + _ -> return [(Var v, x)] reps (RepParItem r p) = r : reps p reps (SeqItems _) = [] reps (ParItems ps) = concatMap reps ps - getVars :: ParItems (BK, UsageLabel) -> Map.Map Var (Maybe BK, Maybe BK) - getVars (SeqItems ss) = foldUnionVarsBK $ [addBK bk $ nodeVars u | (bk, u) <- ss] - getVars (ParItems ps) = foldl (Map.unionWith join) Map.empty (map getVars ps) + getVars :: ParItems (BK, UsageLabel) -> m (Map.Map Var (Maybe BK, Maybe BK)) + getVars (SeqItems ss) = liftM foldUnionVarsBK $ sequence [addBK bk $ nodeVars u | (bk, u) <- ss] + getVars (ParItems ps) = liftM (foldl (Map.unionWith join) Map.empty) (mapM getVars ps) where join a b = (f (fst a) (fst b), f (snd a) (snd b)) f Nothing x = x @@ -344,13 +359,21 @@ checkPlainVarUsage sharedAttr (m, p) = check p -- A quick way to do this is to do a fold-union across all the maps, turning -- the values into lists that can then be scanned for any problems. check (ParItems ps) - = do sharedNames <- getCompState >>* csNameAttr >>* Map.filter (Set.member sharedAttr) - >>* Map.keysSet >>* (Set.map $ UsageCheckUtils.Var . A.Variable emptyMeta . A.Name emptyMeta) + = do rawSharedNames <- getCompState >>* csNameAttr >>* Map.filter (Set.member sharedAttr) + >>* Map.keysSet + -- We add in the directed versions of each (channel or not) so that + -- we make sure to ignore c? when c is shared: + let allSharedNames + = Set.fromList $ concatMap (map UsageCheckUtils.Var . + flip applyAll [id, A.DirectedVariable emptyMeta A.DirInput + , A.DirectedVariable emptyMeta A.DirOutput] + . A.Variable emptyMeta . A.Name emptyMeta) $ Set.toList rawSharedNames let decl = concatMap getDecl ps filt <- filterPlain + vars <- mapM getVars ps let examineVars = - map (filterMapByKey filt . (`difference` (Set.fromList decl `Set.union` sharedNames))) - (map getVars ps) + map (filterMapByKey filt . (`difference` (Set.fromList decl `Set.union` allSharedNames))) + vars checkCREW examineVars where difference m s = m `Map.difference` (Map.fromAscList $ zip (Set.toAscList diff --git a/frontends/OccamPasses.hs b/frontends/OccamPasses.hs index 4b0433f..5aa1778 100644 --- a/frontends/OccamPasses.hs +++ b/frontends/OccamPasses.hs @@ -147,6 +147,8 @@ checkConstants = occamOnlyPass "Check mandatory constants" return o doOption o = return o +-- | Turns things like cs[0]? into cs?[0], which helps later on in the usage checking +-- (as we can consider cs? a different array than cs!). pushUpDirections :: Pass pushUpDirections = occamOnlyPass "Push up direction specifiers on arrays" [] [] diff --git a/frontends/OccamTypes.hs b/frontends/OccamTypes.hs index a79bae9..270639f 100644 --- a/frontends/OccamTypes.hs +++ b/frontends/OccamTypes.hs @@ -769,6 +769,8 @@ inferTypes = occamOnlyPass "Infer types" A.Array _ (A.ChanEnd dir _ _) -> do v'' <- makeEnd m dir v' return (t', v'') + -- TODO infer direction of IS channel type + -- We will need the body! _ -> return (t', v') return $ A.Is m am' t'' v'' A.IsExpr m am t e ->