diff --git a/transformations/ArrayUsageCheck.hs b/transformations/ArrayUsageCheck.hs index df5882f..8268b03 100644 --- a/transformations/ArrayUsageCheck.hs +++ b/transformations/ArrayUsageCheck.hs @@ -50,36 +50,36 @@ usageCheckPass t = do g' <- buildFlowGraph labelFunctions t checkArrayUsage :: forall m. (Die m, CSM m) => FlowGraph m (Maybe Decl, Vars) -> m () checkArrayUsage graph = sequence_ $ checkPar checkArrayUsage' graph where - -- TODO take proper account of replication! - flatten :: ParItems a -> [a] - flatten (ParItem x) = [x] - flatten (ParItems xs) = concatMap flatten xs - flatten (RepParItem _ x) = flatten x --TODO - checkArrayUsage' :: (Meta, ParItems (Maybe Decl, Vars)) -> m () checkArrayUsage' (m,p) = mapM_ (checkIndexes m) $ Map.toList $ - foldl (Map.unionWith (\(a,b) (c,d) -> (a ++ c, b ++ d))) Map.empty $ map (groupArrayIndexes . snd) $ flatten p + groupArrayIndexes $ transformParItems snd p -- Returns (array name, list of written-to indexes, list of read-from indexes) - groupArrayIndexes :: Vars -> Map.Map String ([A.Expression], [A.Expression]) - groupArrayIndexes vs = zipMap join (makeList (writtenVars vs)) (makeList (readVars vs)) + groupArrayIndexes :: ParItems Vars -> Map.Map String (ParItems ([A.Expression], [A.Expression])) + groupArrayIndexes vs = filterByKey $ transformParItems (uncurry (zipMap join) . (transformPair (makeList . writtenVars) (makeList . readVars)) . mkPair) vs where join :: Maybe [a] -> Maybe [a] -> Maybe ([a],[a]) join x y = Just (maybe [] id x, maybe [] id y) + + flattenParItems :: ParItems a -> [a] + flattenParItems (SeqItems xs) = xs + flattenParItems (ParItems ps) = concatMap flattenParItems ps + flattenParItems (RepParItem _ p) = flattenParItems p + makeList :: Set.Set Var -> Map.Map String [A.Expression] makeList = Set.fold (maybe id (uncurry $ Map.insertWith (++)) . getArrayIndex) Map.empty - --- sortAndGroupBy :: (a -> a -> Ordering) -> [a] -> [[a]] --- sortAndGroupBy f = groupBy ((== EQ) . f) . sortBy f + + filterByKey :: ParItems (Map.Map String ([A.Expression], [A.Expression])) -> Map.Map String (ParItems ([A.Expression], [A.Expression])) + filterByKey p = Map.fromList $ map (\k -> (k, transformParItems (Map.findWithDefault ([],[]) k) p)) (concatMap Map.keys $ flattenParItems p) -- TODO this is quite hacky: getArrayIndex :: Var -> Maybe (String, [A.Expression]) getArrayIndex (Var (A.SubscriptedVariable _ (A.Subscript _ e) (A.Variable _ n))) = Just (A.nameName n, [e]) getArrayIndex _ = Nothing - - checkIndexes :: Meta -> (String,([A.Expression],[A.Expression])) -> m () + + checkIndexes :: Meta -> (String,ParItems ([A.Expression],[A.Expression])) -> m () checkIndexes m (arrName, indexes) = do userArrName <- getRealName (A.Name undefined undefined arrName) arrType <- typeOfName (A.Name undefined undefined arrName) @@ -182,12 +182,16 @@ onlyConst _ = Nothing -- Each item in the left branch can be paired with each other, and each item in the left branch can -- be paired with all other items. data ArrayAccess label = - Single (label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem)) - | Group [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] + Group [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] | Replicated [ArrayAccess label] [ArrayAccess label] data ArrayAccessType = AAWrite | AARead +parItemToArrayAccessM :: Monad m => ([A.Replicator] -> a -> m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]) -> ParItems a -> m [ArrayAccess label] +parItemToArrayAccessM f (SeqItems xs) = sequence [concatMapM (f []) xs >>* Group] +parItemToArrayAccessM f (ParItems ps) = concatMapM (parItemToArrayAccessM f) ps +parItemToArrayAccessM f (RepParItem rep p) = parItemToArrayAccessM (\reps -> f (rep:reps)) p + makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp) makeExpSet = foldM makeExpSet' Set.empty where @@ -308,7 +312,7 @@ makeReplicatedEquations repVars accesses bound = do acc0' <- mapM (flip addPossibleRepBound' v) acc0 acc1' <- mapM (flip addPossibleRepBound' v) acc1 return $ Replicated acc0' acc1' - addPossibleRepBound' (Single (l,t,acc)) v = addPossibleRepBound acc v >>* (\x -> Single (l,t,x)) +-- addPossibleRepBound' (Single (l,t,acc)) v = addPossibleRepBound acc v >>* (\x -> Group [(l,t,x)]) addPossibleRepBound :: (EqualityConstraintEquation, EqualityProblem, InequalityProblem) -> (A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) -> @@ -449,24 +453,10 @@ squareAndPair repVars s v lh -- prime >= plain + 1 (prime - plain - 1 >= 0) extraIneq = [simpleArray [(prime,1), (plain,-1), (0, -1)]] -{- -getSingles :: String -> [ArrayAccess] -> Either String [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)] -getSingles err = mapM getSingle - where - getSingle (Single acc) = return acc - getSingle _ = throwError err --} - getSingleAccessItem :: MonadTrans m => String -> ArrayAccess label -> m (Either String) EqualityConstraintEquation -getSingleAccessItem _ (Single (_,_,(acc,_,_))) = lift $ return acc +getSingleAccessItem _ (Group [(_,_,(acc,_,_))]) = lift $ return acc getSingleAccessItem err _ = lift $ throwError err -{- -getSingleAccess :: MonadTrans m => String -> ArrayAccess -> m (Either String) (EqualityConstraintEquation, EqualityProblem, InequalityProblem) -getSingleAccess _ (Single acc) = lift $ return acc -getSingleAccess err _ = lift $ throwError err --} - -- | Odd helper function for getting/asserting the first item of a triple from a singleton list inside a monad transformer (!) getSingleItem :: MonadTrans m => String -> [(a,b,c)] -> m (Either String) a getSingleItem _ [(item,_,_)] = lift $ return item @@ -478,25 +468,32 @@ getSingleItem err _ = lift $ throwError err -- (unique, munged) variable name to variable-index in the equations. -- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message -- TODO allow "background knowledge" in the form of other equalities and inequalities -makeEquations :: ([A.Expression],[A.Expression]) -> A.Expression -> Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))] -makeEquations (esW,esR) high = makeEquations' >>* uncurry3 (squareAndPair []) +makeEquations :: ParItems ([A.Expression],[A.Expression]) -> A.Expression -> Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))] +makeEquations es high = makeEquations' >>* uncurry3 (squareAndPair []) where -- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair -- representing the upper and lower bounds of the array (inclusive). makeEquations' :: Either String (VarMap, [ArrayAccess A.Expression], (EqualityConstraintEquation, EqualityConstraintEquation)) makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $ - do eqsW <- mapM (makeEquationForItem AAWrite) esW - eqsR <- mapM (makeEquationForItem AARead) esR + do eqs <- parItemToArrayAccessM mkEq es high' <- (lift $ flatten high) >>= makeEquation high (error "Type irrelevant for upper bound") >>= getSingleAccessItem "Multiple possible upper bounds not supported" - return (eqsW ++ eqsR,high') + return (eqs, high') return (s,v,(amap (const 0) h, addConstant (-1) h)) - - makeEquationForItem :: ArrayAccessType -> A.Expression -> StateT VarMap (Either String) (ArrayAccess A.Expression) - makeEquationForItem t e = lift (flatten e) >>= makeEquation e t - + mkEq :: [A.Replicator] -> ([A.Expression],[A.Expression]) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] + mkEq [] (ws,rs) = concatMapM mkEq' (ws' ++ rs') + where + ws' = zip (repeat AAWrite) ws + rs' = zip (repeat AARead) rs + + mkEq' :: (ArrayAccessType, A.Expression) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] + mkEq' (aat, e) = do g <- makeEquation e aat =<< lift (flatten e) + case g of + Group g' -> return g' + _ -> throwError "Replicated group found unexpectedly" + -- | Finds the index associated with a particular variable; either by finding an existing index -- or allocating a new one. varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int @@ -524,9 +521,6 @@ pairEqsAndBounds items bounds = (concatMap (uncurry pairEqs) . allPairs) items + pairEqs :: ArrayAccess label -> ArrayAccess label -> [((label,label),EqualityProblem, InequalityProblem)] - pairEqs (Single acc) (Single acc') = maybeToList $ pairEqs'' acc acc' - pairEqs (Single acc) (Group accs) = mapMaybe (pairEqs'' acc) accs - pairEqs (Group accs) (Single acc) = mapMaybe (pairEqs'' acc) accs pairEqs (Group accs) (Group accs') = mapMaybe (uncurry pairEqs'') $ product2 (accs,accs') pairEqs (Replicated rA rB) lacc = concatMap (pairEqs lacc) rA @@ -573,7 +567,7 @@ makeEquation l t summedItems = do eqs <- process summedItems let eqs' = map (transformTriple mapToArray (map mapToArray) (map mapToArray)) eqs :: [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)] return $ case eqs' of - [acc] -> Single (l,t,acc) +-- [acc] -> Single (l,t,acc) _ -> Group [(l,t,e) | e <- eqs'] where process :: [FlattenedExp] -> StateT VarMap (Either String) [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] diff --git a/transformations/ArrayUsageCheckTest.hs b/transformations/ArrayUsageCheckTest.hs index 816cc91..419e021 100644 --- a/transformations/ArrayUsageCheckTest.hs +++ b/transformations/ArrayUsageCheckTest.hs @@ -35,6 +35,7 @@ import qualified AST as A import Omega import TestHarness import TestUtils hiding (m) +import UsageCheck hiding (Var) import Utils testArrayCheck :: Test @@ -371,11 +372,14 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList labelNums m n | m >= n = [] | otherwise = [(m,n') | n' <- [(m + 1) .. n]] ++ labelNums (m + 1) n + + makeParItems :: [A.Expression] -> ParItems ([A.Expression],[A.Expression]) + makeParItems es = ParItems $ map (\e -> SeqItems [([e],[])]) es test' :: (Integer,[((Int,Int),VarMap,[HandyEq],[HandyIneq])],[A.Expression],A.Expression) -> Test test' (ind, problems, exprs, upperBound) = TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs) - (map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations (exprs,[]) upperBound) + (map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations (makeParItems exprs) upperBound) testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],[(A.Variable, A.Expression, A.Expression)],[A.Expression],A.Expression) -> Test testRep' (ind, problems, reps, exprs, upperBound) = diff --git a/transformations/UsageCheck.hs b/transformations/UsageCheck.hs index 3612a3b..eadaa63 100644 --- a/transformations/UsageCheck.hs +++ b/transformations/UsageCheck.hs @@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . -} -module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), Var(..), Vars(..)) where +module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), transformParItems, Var(..), Vars(..)) where import Data.Generics import Data.Graph.Inductive @@ -50,10 +50,16 @@ data Vars = Vars { data Decl = ScopeIn String | ScopeOut String deriving (Show, Eq) +-- | A data type representing things that happen in parallel. data ParItems a - = ParItem a - | ParItems [ParItems a] - | RepParItem A.Replicator (ParItems a) + = SeqItems [a] -- ^ A list of items that happen only in sequence (i.e. none are in parallel with each other) + | ParItems [ParItems a] -- ^ A list of items that are all in parallel with each other + | RepParItem A.Replicator (ParItems a) -- ^ A list of replicated items that happen in parallel + +transformParItems :: (a -> b) -> ParItems a -> ParItems b +transformParItems f (SeqItems xs) = SeqItems $ map f xs +transformParItems f (ParItems ps) = ParItems $ map (transformParItems f) ps +transformParItems f (RepParItem r p) = RepParItem r (transformParItems f p) emptyVars :: Vars emptyVars = Vars Set.empty Set.empty Set.empty @@ -96,11 +102,11 @@ checkPar f g = map f allParItems allParItems :: [(Meta, ParItems a)] allParItems = map makeEntry $ map findNodes $ Map.toList allStartParEdges where - findNodes :: (Int,[(Node,Node)]) -> (Node,[a]) - findNodes (n,ses) = (undefined, concat [followUntilEdge e (EEndPar n) | (_,e) <- ses]) + findNodes :: (Int,[(Node,Node)]) -> (Node,[ParItems a]) + findNodes (n,ses) = (undefined, [SeqItems (followUntilEdge e (EEndPar n)) | (_,e) <- ses]) - makeEntry :: (Node,[a]) -> (Meta, ParItems a) - makeEntry (_,x) = (emptyMeta {- TODO fix this again -} , ParItems $ map ParItem x) + makeEntry :: (Node,[ParItems a]) -> (Meta, ParItems a) + makeEntry (_,xs) = (emptyMeta {- TODO fix this again -} , ParItems xs) -- | We need to follow all edges out of a particular node until we reach -- an edge that matches the given edge. So what we effectively need