Added proper support for sequential items in non-replicated PARs in the array usage checking

This commit is contained in:
Neil Brown 2008-01-27 01:43:42 +00:00
parent d37253d2af
commit 7276c3cc4a
3 changed files with 58 additions and 54 deletions

View File

@ -50,36 +50,36 @@ usageCheckPass t = do g' <- buildFlowGraph labelFunctions t
checkArrayUsage :: forall m. (Die m, CSM m) => FlowGraph m (Maybe Decl, Vars) -> m ()
checkArrayUsage graph = sequence_ $ checkPar checkArrayUsage' graph
where
-- TODO take proper account of replication!
flatten :: ParItems a -> [a]
flatten (ParItem x) = [x]
flatten (ParItems xs) = concatMap flatten xs
flatten (RepParItem _ x) = flatten x --TODO
checkArrayUsage' :: (Meta, ParItems (Maybe Decl, Vars)) -> m ()
checkArrayUsage' (m,p) = mapM_ (checkIndexes m) $ Map.toList $
foldl (Map.unionWith (\(a,b) (c,d) -> (a ++ c, b ++ d))) Map.empty $ map (groupArrayIndexes . snd) $ flatten p
groupArrayIndexes $ transformParItems snd p
-- Returns (array name, list of written-to indexes, list of read-from indexes)
groupArrayIndexes :: Vars -> Map.Map String ([A.Expression], [A.Expression])
groupArrayIndexes vs = zipMap join (makeList (writtenVars vs)) (makeList (readVars vs))
groupArrayIndexes :: ParItems Vars -> Map.Map String (ParItems ([A.Expression], [A.Expression]))
groupArrayIndexes vs = filterByKey $ transformParItems (uncurry (zipMap join) . (transformPair (makeList . writtenVars) (makeList . readVars)) . mkPair) vs
where
join :: Maybe [a] -> Maybe [a] -> Maybe ([a],[a])
join x y = Just (maybe [] id x, maybe [] id y)
flattenParItems :: ParItems a -> [a]
flattenParItems (SeqItems xs) = xs
flattenParItems (ParItems ps) = concatMap flattenParItems ps
flattenParItems (RepParItem _ p) = flattenParItems p
makeList :: Set.Set Var -> Map.Map String [A.Expression]
makeList = Set.fold (maybe id (uncurry $ Map.insertWith (++)) . getArrayIndex) Map.empty
-- sortAndGroupBy :: (a -> a -> Ordering) -> [a] -> [[a]]
-- sortAndGroupBy f = groupBy ((== EQ) . f) . sortBy f
filterByKey :: ParItems (Map.Map String ([A.Expression], [A.Expression])) -> Map.Map String (ParItems ([A.Expression], [A.Expression]))
filterByKey p = Map.fromList $ map (\k -> (k, transformParItems (Map.findWithDefault ([],[]) k) p)) (concatMap Map.keys $ flattenParItems p)
-- TODO this is quite hacky:
getArrayIndex :: Var -> Maybe (String, [A.Expression])
getArrayIndex (Var (A.SubscriptedVariable _ (A.Subscript _ e) (A.Variable _ n)))
= Just (A.nameName n, [e])
getArrayIndex _ = Nothing
checkIndexes :: Meta -> (String,([A.Expression],[A.Expression])) -> m ()
checkIndexes :: Meta -> (String,ParItems ([A.Expression],[A.Expression])) -> m ()
checkIndexes m (arrName, indexes)
= do userArrName <- getRealName (A.Name undefined undefined arrName)
arrType <- typeOfName (A.Name undefined undefined arrName)
@ -182,12 +182,16 @@ onlyConst _ = Nothing
-- Each item in the left branch can be paired with each other, and each item in the left branch can
-- be paired with all other items.
data ArrayAccess label =
Single (label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))
| Group [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
Group [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
| Replicated [ArrayAccess label] [ArrayAccess label]
data ArrayAccessType = AAWrite | AARead
parItemToArrayAccessM :: Monad m => ([A.Replicator] -> a -> m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]) -> ParItems a -> m [ArrayAccess label]
parItemToArrayAccessM f (SeqItems xs) = sequence [concatMapM (f []) xs >>* Group]
parItemToArrayAccessM f (ParItems ps) = concatMapM (parItemToArrayAccessM f) ps
parItemToArrayAccessM f (RepParItem rep p) = parItemToArrayAccessM (\reps -> f (rep:reps)) p
makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp)
makeExpSet = foldM makeExpSet' Set.empty
where
@ -308,7 +312,7 @@ makeReplicatedEquations repVars accesses bound
= do acc0' <- mapM (flip addPossibleRepBound' v) acc0
acc1' <- mapM (flip addPossibleRepBound' v) acc1
return $ Replicated acc0' acc1'
addPossibleRepBound' (Single (l,t,acc)) v = addPossibleRepBound acc v >>* (\x -> Single (l,t,x))
-- addPossibleRepBound' (Single (l,t,acc)) v = addPossibleRepBound acc v >>* (\x -> Group [(l,t,x)])
addPossibleRepBound :: (EqualityConstraintEquation, EqualityProblem, InequalityProblem) ->
(A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) ->
@ -449,24 +453,10 @@ squareAndPair repVars s v lh
-- prime >= plain + 1 (prime - plain - 1 >= 0)
extraIneq = [simpleArray [(prime,1), (plain,-1), (0, -1)]]
{-
getSingles :: String -> [ArrayAccess] -> Either String [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)]
getSingles err = mapM getSingle
where
getSingle (Single acc) = return acc
getSingle _ = throwError err
-}
getSingleAccessItem :: MonadTrans m => String -> ArrayAccess label -> m (Either String) EqualityConstraintEquation
getSingleAccessItem _ (Single (_,_,(acc,_,_))) = lift $ return acc
getSingleAccessItem _ (Group [(_,_,(acc,_,_))]) = lift $ return acc
getSingleAccessItem err _ = lift $ throwError err
{-
getSingleAccess :: MonadTrans m => String -> ArrayAccess -> m (Either String) (EqualityConstraintEquation, EqualityProblem, InequalityProblem)
getSingleAccess _ (Single acc) = lift $ return acc
getSingleAccess err _ = lift $ throwError err
-}
-- | Odd helper function for getting/asserting the first item of a triple from a singleton list inside a monad transformer (!)
getSingleItem :: MonadTrans m => String -> [(a,b,c)] -> m (Either String) a
getSingleItem _ [(item,_,_)] = lift $ return item
@ -478,25 +468,32 @@ getSingleItem err _ = lift $ throwError err
-- (unique, munged) variable name to variable-index in the equations.
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
-- TODO allow "background knowledge" in the form of other equalities and inequalities
makeEquations :: ([A.Expression],[A.Expression]) -> A.Expression -> Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))]
makeEquations (esW,esR) high = makeEquations' >>* uncurry3 (squareAndPair [])
makeEquations :: ParItems ([A.Expression],[A.Expression]) -> A.Expression -> Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))]
makeEquations es high = makeEquations' >>* uncurry3 (squareAndPair [])
where
-- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair
-- representing the upper and lower bounds of the array (inclusive).
makeEquations' :: Either String (VarMap, [ArrayAccess A.Expression], (EqualityConstraintEquation, EqualityConstraintEquation))
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
do eqsW <- mapM (makeEquationForItem AAWrite) esW
eqsR <- mapM (makeEquationForItem AARead) esR
do eqs <- parItemToArrayAccessM mkEq es
high' <- (lift $ flatten high) >>= makeEquation high (error "Type irrelevant for upper bound")
>>= getSingleAccessItem "Multiple possible upper bounds not supported"
return (eqsW ++ eqsR,high')
return (eqs, high')
return (s,v,(amap (const 0) h, addConstant (-1) h))
makeEquationForItem :: ArrayAccessType -> A.Expression -> StateT VarMap (Either String) (ArrayAccess A.Expression)
makeEquationForItem t e = lift (flatten e) >>= makeEquation e t
mkEq :: [A.Replicator] -> ([A.Expression],[A.Expression]) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq [] (ws,rs) = concatMapM mkEq' (ws' ++ rs')
where
ws' = zip (repeat AAWrite) ws
rs' = zip (repeat AARead) rs
mkEq' :: (ArrayAccessType, A.Expression) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq' (aat, e) = do g <- makeEquation e aat =<< lift (flatten e)
case g of
Group g' -> return g'
_ -> throwError "Replicated group found unexpectedly"
-- | Finds the index associated with a particular variable; either by finding an existing index
-- or allocating a new one.
varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int
@ -524,9 +521,6 @@ pairEqsAndBounds items bounds = (concatMap (uncurry pairEqs) . allPairs) items +
pairEqs :: ArrayAccess label
-> ArrayAccess label
-> [((label,label),EqualityProblem, InequalityProblem)]
pairEqs (Single acc) (Single acc') = maybeToList $ pairEqs'' acc acc'
pairEqs (Single acc) (Group accs) = mapMaybe (pairEqs'' acc) accs
pairEqs (Group accs) (Single acc) = mapMaybe (pairEqs'' acc) accs
pairEqs (Group accs) (Group accs') = mapMaybe (uncurry pairEqs'') $ product2 (accs,accs')
pairEqs (Replicated rA rB) lacc
= concatMap (pairEqs lacc) rA
@ -573,7 +567,7 @@ makeEquation l t summedItems
= do eqs <- process summedItems
let eqs' = map (transformTriple mapToArray (map mapToArray) (map mapToArray)) eqs :: [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)]
return $ case eqs' of
[acc] -> Single (l,t,acc)
-- [acc] -> Single (l,t,acc)
_ -> Group [(l,t,e) | e <- eqs']
where
process :: [FlattenedExp] -> StateT VarMap (Either String) [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]

View File

@ -35,6 +35,7 @@ import qualified AST as A
import Omega
import TestHarness
import TestUtils hiding (m)
import UsageCheck hiding (Var)
import Utils
testArrayCheck :: Test
@ -371,11 +372,14 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
labelNums m n | m >= n = []
| otherwise = [(m,n') | n' <- [(m + 1) .. n]] ++ labelNums (m + 1) n
makeParItems :: [A.Expression] -> ParItems ([A.Expression],[A.Expression])
makeParItems es = ParItems $ map (\e -> SeqItems [([e],[])]) es
test' :: (Integer,[((Int,Int),VarMap,[HandyEq],[HandyIneq])],[A.Expression],A.Expression) -> Test
test' (ind, problems, exprs, upperBound) =
TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs)
(map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations (exprs,[]) upperBound)
(map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations (makeParItems exprs) upperBound)
testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],[(A.Variable, A.Expression, A.Expression)],[A.Expression],A.Expression) -> Test
testRep' (ind, problems, reps, exprs, upperBound) =

View File

@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along
with this program. If not, see <http://www.gnu.org/licenses/>.
-}
module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), Var(..), Vars(..)) where
module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), transformParItems, Var(..), Vars(..)) where
import Data.Generics
import Data.Graph.Inductive
@ -50,10 +50,16 @@ data Vars = Vars {
data Decl = ScopeIn String | ScopeOut String deriving (Show, Eq)
-- | A data type representing things that happen in parallel.
data ParItems a
= ParItem a
| ParItems [ParItems a]
| RepParItem A.Replicator (ParItems a)
= SeqItems [a] -- ^ A list of items that happen only in sequence (i.e. none are in parallel with each other)
| ParItems [ParItems a] -- ^ A list of items that are all in parallel with each other
| RepParItem A.Replicator (ParItems a) -- ^ A list of replicated items that happen in parallel
transformParItems :: (a -> b) -> ParItems a -> ParItems b
transformParItems f (SeqItems xs) = SeqItems $ map f xs
transformParItems f (ParItems ps) = ParItems $ map (transformParItems f) ps
transformParItems f (RepParItem r p) = RepParItem r (transformParItems f p)
emptyVars :: Vars
emptyVars = Vars Set.empty Set.empty Set.empty
@ -96,11 +102,11 @@ checkPar f g = map f allParItems
allParItems :: [(Meta, ParItems a)]
allParItems = map makeEntry $ map findNodes $ Map.toList allStartParEdges
where
findNodes :: (Int,[(Node,Node)]) -> (Node,[a])
findNodes (n,ses) = (undefined, concat [followUntilEdge e (EEndPar n) | (_,e) <- ses])
findNodes :: (Int,[(Node,Node)]) -> (Node,[ParItems a])
findNodes (n,ses) = (undefined, [SeqItems (followUntilEdge e (EEndPar n)) | (_,e) <- ses])
makeEntry :: (Node,[a]) -> (Meta, ParItems a)
makeEntry (_,x) = (emptyMeta {- TODO fix this again -} , ParItems $ map ParItem x)
makeEntry :: (Node,[ParItems a]) -> (Meta, ParItems a)
makeEntry (_,xs) = (emptyMeta {- TODO fix this again -} , ParItems xs)
-- | We need to follow all edges out of a particular node until we reach
-- an edge that matches the given edge. So what we effectively need