diff --git a/checks/ArrayUsageCheck.hs b/checks/ArrayUsageCheck.hs index 6ba35c9..4e962e8 100644 --- a/checks/ArrayUsageCheck.hs +++ b/checks/ArrayUsageCheck.hs @@ -40,6 +40,7 @@ import Data.List import qualified Data.Map as Map import Data.Maybe import qualified Data.Set as Set +import qualified Data.Traversable as T import qualified AST as A import CompState @@ -96,8 +97,9 @@ findRepSolutions reps bks -- | A check-pass that checks the given ParItems (usually generated from a control-flow graph) -- for any overlapping array indices. checkArrayUsage :: forall m. (Die m, CSMR m, MonadIO m) => (Meta, ParItems (BK, UsageLabel)) -> m () -checkArrayUsage (m,p) = mapM_ (checkIndexes m) $ Map.toList $ - groupArrayIndexes $ fmap (transformPair id nodeVars) p +checkArrayUsage (m,p) + = do indexes <- groupArrayIndexes $ fmap (transformPair id nodeVars) p + mapM_ (checkIndexes m) $ Map.toList indexes where getDecl :: UsageLabel -> Maybe String getDecl = join . fmap getScopeIn . nodeDecl @@ -108,40 +110,52 @@ checkArrayUsage (m,p) = mapM_ (checkIndexes m) $ Map.toList $ -- Takes a ParItems Vars, and returns a map from array-variable-name to a list of writes and a list of reads for that array. -- Returns (array name, list of written-to indexes, list of read-from indexes) - groupArrayIndexes :: ParItems (BK, Vars) -> Map.Map String (ParItems (BK, [A.Expression], [A.Expression])) - groupArrayIndexes = filterByKey . fmap - (\(bk,vs) -> zipMap (join bk) (makeList $ (Map.keysSet $ writtenVars vs) - `Set.union` (usedVars vs)) (makeList $ readVars vs)) + groupArrayIndexes :: ParItems (BK, Vars) -> m (Map.Map (String, Maybe A.Direction) (ParItems (BK, [A.Expression], [A.Expression]))) + groupArrayIndexes = liftM filterByKey . T.mapM + (\(bk,vs) -> do w <- makeList $ (Map.keysSet $ writtenVars vs) `Set.union` (usedVars vs) + r <- makeList $ readVars vs + return $ zipMap (join bk) w r) where join :: b -> Maybe [a] -> Maybe [a] -> Maybe (b, [a],[a]) join k x y = Just (k, fromMaybe [] x, fromMaybe [] y) -- Turns a set of variables into a map (from array-name to list of index-expressions) - makeList :: Set.Set Var -> Map.Map String [A.Expression] - makeList = Set.fold (maybe id (uncurry $ Map.insertWith (++)) . getArrayIndex) Map.empty + makeList :: Set.Set Var -> m (Map.Map (String, Maybe A.Direction) [A.Expression]) + makeList vs = do indexes <- concatMapM getArrayIndex $ Set.toList vs + return $ Map.fromListWith (++) indexes -- Lifts a map (from array-name to expression-lists) inside a ParItems to being a map (from array-name to ParItems of expression lists) - filterByKey :: ParItems (Map.Map String (BK, [A.Expression], [A.Expression])) -> Map.Map String (ParItems (BK, - [A.Expression], [A.Expression])) + filterByKey :: ParItems (Map.Map (String, Maybe A.Direction) (BK, [A.Expression], [A.Expression])) + -> Map.Map (String, Maybe A.Direction) (ParItems (BK, [A.Expression], [A.Expression])) filterByKey p = Map.fromList $ map trans keys where - keys :: [String] + keys :: [(String, Maybe A.Direction)] keys = concatMap Map.keys $ flattenParItems p - trans :: String -> (String, ParItems (BK, [A.Expression], [A.Expression])) + trans :: (String, Maybe A.Direction) -> ((String, Maybe A.Direction), ParItems (BK, [A.Expression], [A.Expression])) trans k = (k, fmap (Map.findWithDefault ([], [], []) k) p) -- Gets the (array-name, indexes) from a Var. -- TODO this is quite hacky, and doesn't yet deal with slices and so on: - getArrayIndex :: Var -> Maybe (String, [A.Expression]) - getArrayIndex (Var (A.SubscriptedVariable _ (A.Subscript _ _ e) (A.Variable _ n))) - = Just (A.nameName n, [e]) - getArrayIndex _ = Nothing + getArrayIndex :: Var -> m [((String, Maybe A.Direction), [A.Expression])] + getArrayIndex (Var v@(A.SubscriptedVariable _ (A.Subscript _ _ e) (A.Variable _ n))) + = do t <- astTypeOf v + let dirs = case t of + A.Chan {} -> [Just A.DirInput, Just A.DirOutput] + _ -> [Nothing] + return [((A.nameName n, d), [e]) | d <- dirs] + getArrayIndex (Var (A.SubscriptedVariable _ (A.Subscript _ _ e) + (A.DirectedVariable _ dir (A.Variable _ n)))) + = return [((A.nameName n, Just dir), [e])] + getArrayIndex (Var (A.DirectedVariable _ dir (A.SubscriptedVariable _ (A.Subscript _ _ e) + (A.Variable _ n)))) + = return [((A.nameName n, Just dir), [e])] + getArrayIndex _ = return [] -- Checks the given ParItems of writes and reads against each other. The -- String (array-name) and Meta are only used for printing out error messages - checkIndexes :: Meta -> (String, ParItems (BK, [A.Expression], [A.Expression])) -> m () - checkIndexes m (arrName, indexes) = do + checkIndexes :: Meta -> ((String, Maybe A.Direction), ParItems (BK, [A.Expression], [A.Expression])) -> m () + checkIndexes m ((arrName, arrDir), indexes) = do sharedNames <- getCompState >>* csNameAttr let declNames = [x | Just x <- fmap (getDecl . snd) $ flattenParItems p] when (Map.lookup arrName sharedNames /= Just NameShared && arrName `notElem` declNames) $