Added proper support for sequential items in non-replicated PARs in the array usage checking

This commit is contained in:
Neil Brown 2008-01-27 01:43:42 +00:00
parent d37253d2af
commit 7276c3cc4a
3 changed files with 58 additions and 54 deletions

View File

@ -50,28 +50,28 @@ usageCheckPass t = do g' <- buildFlowGraph labelFunctions t
checkArrayUsage :: forall m. (Die m, CSM m) => FlowGraph m (Maybe Decl, Vars) -> m () checkArrayUsage :: forall m. (Die m, CSM m) => FlowGraph m (Maybe Decl, Vars) -> m ()
checkArrayUsage graph = sequence_ $ checkPar checkArrayUsage' graph checkArrayUsage graph = sequence_ $ checkPar checkArrayUsage' graph
where where
-- TODO take proper account of replication!
flatten :: ParItems a -> [a]
flatten (ParItem x) = [x]
flatten (ParItems xs) = concatMap flatten xs
flatten (RepParItem _ x) = flatten x --TODO
checkArrayUsage' :: (Meta, ParItems (Maybe Decl, Vars)) -> m () checkArrayUsage' :: (Meta, ParItems (Maybe Decl, Vars)) -> m ()
checkArrayUsage' (m,p) = mapM_ (checkIndexes m) $ Map.toList $ checkArrayUsage' (m,p) = mapM_ (checkIndexes m) $ Map.toList $
foldl (Map.unionWith (\(a,b) (c,d) -> (a ++ c, b ++ d))) Map.empty $ map (groupArrayIndexes . snd) $ flatten p groupArrayIndexes $ transformParItems snd p
-- Returns (array name, list of written-to indexes, list of read-from indexes) -- Returns (array name, list of written-to indexes, list of read-from indexes)
groupArrayIndexes :: Vars -> Map.Map String ([A.Expression], [A.Expression]) groupArrayIndexes :: ParItems Vars -> Map.Map String (ParItems ([A.Expression], [A.Expression]))
groupArrayIndexes vs = zipMap join (makeList (writtenVars vs)) (makeList (readVars vs)) groupArrayIndexes vs = filterByKey $ transformParItems (uncurry (zipMap join) . (transformPair (makeList . writtenVars) (makeList . readVars)) . mkPair) vs
where where
join :: Maybe [a] -> Maybe [a] -> Maybe ([a],[a]) join :: Maybe [a] -> Maybe [a] -> Maybe ([a],[a])
join x y = Just (maybe [] id x, maybe [] id y) join x y = Just (maybe [] id x, maybe [] id y)
flattenParItems :: ParItems a -> [a]
flattenParItems (SeqItems xs) = xs
flattenParItems (ParItems ps) = concatMap flattenParItems ps
flattenParItems (RepParItem _ p) = flattenParItems p
makeList :: Set.Set Var -> Map.Map String [A.Expression] makeList :: Set.Set Var -> Map.Map String [A.Expression]
makeList = Set.fold (maybe id (uncurry $ Map.insertWith (++)) . getArrayIndex) Map.empty makeList = Set.fold (maybe id (uncurry $ Map.insertWith (++)) . getArrayIndex) Map.empty
-- sortAndGroupBy :: (a -> a -> Ordering) -> [a] -> [[a]] filterByKey :: ParItems (Map.Map String ([A.Expression], [A.Expression])) -> Map.Map String (ParItems ([A.Expression], [A.Expression]))
-- sortAndGroupBy f = groupBy ((== EQ) . f) . sortBy f filterByKey p = Map.fromList $ map (\k -> (k, transformParItems (Map.findWithDefault ([],[]) k) p)) (concatMap Map.keys $ flattenParItems p)
-- TODO this is quite hacky: -- TODO this is quite hacky:
getArrayIndex :: Var -> Maybe (String, [A.Expression]) getArrayIndex :: Var -> Maybe (String, [A.Expression])
@ -79,7 +79,7 @@ checkArrayUsage graph = sequence_ $ checkPar checkArrayUsage' graph
= Just (A.nameName n, [e]) = Just (A.nameName n, [e])
getArrayIndex _ = Nothing getArrayIndex _ = Nothing
checkIndexes :: Meta -> (String,([A.Expression],[A.Expression])) -> m () checkIndexes :: Meta -> (String,ParItems ([A.Expression],[A.Expression])) -> m ()
checkIndexes m (arrName, indexes) checkIndexes m (arrName, indexes)
= do userArrName <- getRealName (A.Name undefined undefined arrName) = do userArrName <- getRealName (A.Name undefined undefined arrName)
arrType <- typeOfName (A.Name undefined undefined arrName) arrType <- typeOfName (A.Name undefined undefined arrName)
@ -182,12 +182,16 @@ onlyConst _ = Nothing
-- Each item in the left branch can be paired with each other, and each item in the left branch can -- Each item in the left branch can be paired with each other, and each item in the left branch can
-- be paired with all other items. -- be paired with all other items.
data ArrayAccess label = data ArrayAccess label =
Single (label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem)) Group [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
| Group [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
| Replicated [ArrayAccess label] [ArrayAccess label] | Replicated [ArrayAccess label] [ArrayAccess label]
data ArrayAccessType = AAWrite | AARead data ArrayAccessType = AAWrite | AARead
parItemToArrayAccessM :: Monad m => ([A.Replicator] -> a -> m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]) -> ParItems a -> m [ArrayAccess label]
parItemToArrayAccessM f (SeqItems xs) = sequence [concatMapM (f []) xs >>* Group]
parItemToArrayAccessM f (ParItems ps) = concatMapM (parItemToArrayAccessM f) ps
parItemToArrayAccessM f (RepParItem rep p) = parItemToArrayAccessM (\reps -> f (rep:reps)) p
makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp) makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp)
makeExpSet = foldM makeExpSet' Set.empty makeExpSet = foldM makeExpSet' Set.empty
where where
@ -308,7 +312,7 @@ makeReplicatedEquations repVars accesses bound
= do acc0' <- mapM (flip addPossibleRepBound' v) acc0 = do acc0' <- mapM (flip addPossibleRepBound' v) acc0
acc1' <- mapM (flip addPossibleRepBound' v) acc1 acc1' <- mapM (flip addPossibleRepBound' v) acc1
return $ Replicated acc0' acc1' return $ Replicated acc0' acc1'
addPossibleRepBound' (Single (l,t,acc)) v = addPossibleRepBound acc v >>* (\x -> Single (l,t,x)) -- addPossibleRepBound' (Single (l,t,acc)) v = addPossibleRepBound acc v >>* (\x -> Group [(l,t,x)])
addPossibleRepBound :: (EqualityConstraintEquation, EqualityProblem, InequalityProblem) -> addPossibleRepBound :: (EqualityConstraintEquation, EqualityProblem, InequalityProblem) ->
(A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) -> (A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) ->
@ -449,24 +453,10 @@ squareAndPair repVars s v lh
-- prime >= plain + 1 (prime - plain - 1 >= 0) -- prime >= plain + 1 (prime - plain - 1 >= 0)
extraIneq = [simpleArray [(prime,1), (plain,-1), (0, -1)]] extraIneq = [simpleArray [(prime,1), (plain,-1), (0, -1)]]
{-
getSingles :: String -> [ArrayAccess] -> Either String [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)]
getSingles err = mapM getSingle
where
getSingle (Single acc) = return acc
getSingle _ = throwError err
-}
getSingleAccessItem :: MonadTrans m => String -> ArrayAccess label -> m (Either String) EqualityConstraintEquation getSingleAccessItem :: MonadTrans m => String -> ArrayAccess label -> m (Either String) EqualityConstraintEquation
getSingleAccessItem _ (Single (_,_,(acc,_,_))) = lift $ return acc getSingleAccessItem _ (Group [(_,_,(acc,_,_))]) = lift $ return acc
getSingleAccessItem err _ = lift $ throwError err getSingleAccessItem err _ = lift $ throwError err
{-
getSingleAccess :: MonadTrans m => String -> ArrayAccess -> m (Either String) (EqualityConstraintEquation, EqualityProblem, InequalityProblem)
getSingleAccess _ (Single acc) = lift $ return acc
getSingleAccess err _ = lift $ throwError err
-}
-- | Odd helper function for getting/asserting the first item of a triple from a singleton list inside a monad transformer (!) -- | Odd helper function for getting/asserting the first item of a triple from a singleton list inside a monad transformer (!)
getSingleItem :: MonadTrans m => String -> [(a,b,c)] -> m (Either String) a getSingleItem :: MonadTrans m => String -> [(a,b,c)] -> m (Either String) a
getSingleItem _ [(item,_,_)] = lift $ return item getSingleItem _ [(item,_,_)] = lift $ return item
@ -478,24 +468,31 @@ getSingleItem err _ = lift $ throwError err
-- (unique, munged) variable name to variable-index in the equations. -- (unique, munged) variable name to variable-index in the equations.
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message -- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
-- TODO allow "background knowledge" in the form of other equalities and inequalities -- TODO allow "background knowledge" in the form of other equalities and inequalities
makeEquations :: ([A.Expression],[A.Expression]) -> A.Expression -> Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))] makeEquations :: ParItems ([A.Expression],[A.Expression]) -> A.Expression -> Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))]
makeEquations (esW,esR) high = makeEquations' >>* uncurry3 (squareAndPair []) makeEquations es high = makeEquations' >>* uncurry3 (squareAndPair [])
where where
-- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair -- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair
-- representing the upper and lower bounds of the array (inclusive). -- representing the upper and lower bounds of the array (inclusive).
makeEquations' :: Either String (VarMap, [ArrayAccess A.Expression], (EqualityConstraintEquation, EqualityConstraintEquation)) makeEquations' :: Either String (VarMap, [ArrayAccess A.Expression], (EqualityConstraintEquation, EqualityConstraintEquation))
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $ makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
do eqsW <- mapM (makeEquationForItem AAWrite) esW do eqs <- parItemToArrayAccessM mkEq es
eqsR <- mapM (makeEquationForItem AARead) esR
high' <- (lift $ flatten high) >>= makeEquation high (error "Type irrelevant for upper bound") high' <- (lift $ flatten high) >>= makeEquation high (error "Type irrelevant for upper bound")
>>= getSingleAccessItem "Multiple possible upper bounds not supported" >>= getSingleAccessItem "Multiple possible upper bounds not supported"
return (eqsW ++ eqsR,high') return (eqs, high')
return (s,v,(amap (const 0) h, addConstant (-1) h)) return (s,v,(amap (const 0) h, addConstant (-1) h))
mkEq :: [A.Replicator] -> ([A.Expression],[A.Expression]) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq [] (ws,rs) = concatMapM mkEq' (ws' ++ rs')
where
ws' = zip (repeat AAWrite) ws
rs' = zip (repeat AARead) rs
makeEquationForItem :: ArrayAccessType -> A.Expression -> StateT VarMap (Either String) (ArrayAccess A.Expression) mkEq' :: (ArrayAccessType, A.Expression) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
makeEquationForItem t e = lift (flatten e) >>= makeEquation e t mkEq' (aat, e) = do g <- makeEquation e aat =<< lift (flatten e)
case g of
Group g' -> return g'
_ -> throwError "Replicated group found unexpectedly"
-- | Finds the index associated with a particular variable; either by finding an existing index -- | Finds the index associated with a particular variable; either by finding an existing index
-- or allocating a new one. -- or allocating a new one.
@ -524,9 +521,6 @@ pairEqsAndBounds items bounds = (concatMap (uncurry pairEqs) . allPairs) items +
pairEqs :: ArrayAccess label pairEqs :: ArrayAccess label
-> ArrayAccess label -> ArrayAccess label
-> [((label,label),EqualityProblem, InequalityProblem)] -> [((label,label),EqualityProblem, InequalityProblem)]
pairEqs (Single acc) (Single acc') = maybeToList $ pairEqs'' acc acc'
pairEqs (Single acc) (Group accs) = mapMaybe (pairEqs'' acc) accs
pairEqs (Group accs) (Single acc) = mapMaybe (pairEqs'' acc) accs
pairEqs (Group accs) (Group accs') = mapMaybe (uncurry pairEqs'') $ product2 (accs,accs') pairEqs (Group accs) (Group accs') = mapMaybe (uncurry pairEqs'') $ product2 (accs,accs')
pairEqs (Replicated rA rB) lacc pairEqs (Replicated rA rB) lacc
= concatMap (pairEqs lacc) rA = concatMap (pairEqs lacc) rA
@ -573,7 +567,7 @@ makeEquation l t summedItems
= do eqs <- process summedItems = do eqs <- process summedItems
let eqs' = map (transformTriple mapToArray (map mapToArray) (map mapToArray)) eqs :: [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)] let eqs' = map (transformTriple mapToArray (map mapToArray) (map mapToArray)) eqs :: [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)]
return $ case eqs' of return $ case eqs' of
[acc] -> Single (l,t,acc) -- [acc] -> Single (l,t,acc)
_ -> Group [(l,t,e) | e <- eqs'] _ -> Group [(l,t,e) | e <- eqs']
where where
process :: [FlattenedExp] -> StateT VarMap (Either String) [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] process :: [FlattenedExp] -> StateT VarMap (Either String) [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]

View File

@ -35,6 +35,7 @@ import qualified AST as A
import Omega import Omega
import TestHarness import TestHarness
import TestUtils hiding (m) import TestUtils hiding (m)
import UsageCheck hiding (Var)
import Utils import Utils
testArrayCheck :: Test testArrayCheck :: Test
@ -372,10 +373,13 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
| otherwise = [(m,n') | n' <- [(m + 1) .. n]] ++ labelNums (m + 1) n | otherwise = [(m,n') | n' <- [(m + 1) .. n]] ++ labelNums (m + 1) n
makeParItems :: [A.Expression] -> ParItems ([A.Expression],[A.Expression])
makeParItems es = ParItems $ map (\e -> SeqItems [([e],[])]) es
test' :: (Integer,[((Int,Int),VarMap,[HandyEq],[HandyIneq])],[A.Expression],A.Expression) -> Test test' :: (Integer,[((Int,Int),VarMap,[HandyEq],[HandyIneq])],[A.Expression],A.Expression) -> Test
test' (ind, problems, exprs, upperBound) = test' (ind, problems, exprs, upperBound) =
TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs) TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs)
(map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations (exprs,[]) upperBound) (map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations (makeParItems exprs) upperBound)
testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],[(A.Variable, A.Expression, A.Expression)],[A.Expression],A.Expression) -> Test testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],[(A.Variable, A.Expression, A.Expression)],[A.Expression],A.Expression) -> Test
testRep' (ind, problems, reps, exprs, upperBound) = testRep' (ind, problems, reps, exprs, upperBound) =

View File

@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along
with this program. If not, see <http://www.gnu.org/licenses/>. with this program. If not, see <http://www.gnu.org/licenses/>.
-} -}
module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), Var(..), Vars(..)) where module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), transformParItems, Var(..), Vars(..)) where
import Data.Generics import Data.Generics
import Data.Graph.Inductive import Data.Graph.Inductive
@ -50,10 +50,16 @@ data Vars = Vars {
data Decl = ScopeIn String | ScopeOut String deriving (Show, Eq) data Decl = ScopeIn String | ScopeOut String deriving (Show, Eq)
-- | A data type representing things that happen in parallel.
data ParItems a data ParItems a
= ParItem a = SeqItems [a] -- ^ A list of items that happen only in sequence (i.e. none are in parallel with each other)
| ParItems [ParItems a] | ParItems [ParItems a] -- ^ A list of items that are all in parallel with each other
| RepParItem A.Replicator (ParItems a) | RepParItem A.Replicator (ParItems a) -- ^ A list of replicated items that happen in parallel
transformParItems :: (a -> b) -> ParItems a -> ParItems b
transformParItems f (SeqItems xs) = SeqItems $ map f xs
transformParItems f (ParItems ps) = ParItems $ map (transformParItems f) ps
transformParItems f (RepParItem r p) = RepParItem r (transformParItems f p)
emptyVars :: Vars emptyVars :: Vars
emptyVars = Vars Set.empty Set.empty Set.empty emptyVars = Vars Set.empty Set.empty Set.empty
@ -96,11 +102,11 @@ checkPar f g = map f allParItems
allParItems :: [(Meta, ParItems a)] allParItems :: [(Meta, ParItems a)]
allParItems = map makeEntry $ map findNodes $ Map.toList allStartParEdges allParItems = map makeEntry $ map findNodes $ Map.toList allStartParEdges
where where
findNodes :: (Int,[(Node,Node)]) -> (Node,[a]) findNodes :: (Int,[(Node,Node)]) -> (Node,[ParItems a])
findNodes (n,ses) = (undefined, concat [followUntilEdge e (EEndPar n) | (_,e) <- ses]) findNodes (n,ses) = (undefined, [SeqItems (followUntilEdge e (EEndPar n)) | (_,e) <- ses])
makeEntry :: (Node,[a]) -> (Meta, ParItems a) makeEntry :: (Node,[ParItems a]) -> (Meta, ParItems a)
makeEntry (_,x) = (emptyMeta {- TODO fix this again -} , ParItems $ map ParItem x) makeEntry (_,xs) = (emptyMeta {- TODO fix this again -} , ParItems xs)
-- | We need to follow all edges out of a particular node until we reach -- | We need to follow all edges out of a particular node until we reach
-- an edge that matches the given edge. So what we effectively need -- an edge that matches the given edge. So what we effectively need