diff --git a/transformations/ArrayUsageCheck.hs b/transformations/ArrayUsageCheck.hs index 8268b03..ddec1c2 100644 --- a/transformations/ArrayUsageCheck.hs +++ b/transformations/ArrayUsageCheck.hs @@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . -} -module ArrayUsageCheck (checkArrayUsage, FlattenedExp(..), makeEquations, makeReplicatedEquations, usageCheckPass, VarMap) where +module ArrayUsageCheck (checkArrayUsage, FlattenedExp(..), makeEquations, usageCheckPass, VarMap) where import Control.Monad.Error import Control.Monad.State @@ -187,10 +187,13 @@ data ArrayAccess label = data ArrayAccessType = AAWrite | AARead -parItemToArrayAccessM :: Monad m => ([A.Replicator] -> a -> m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]) -> ParItems a -> m [ArrayAccess label] +parItemToArrayAccessM :: Monad m => ([(A.Replicator, Bool)] -> a -> m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]) -> ParItems a -> m [ArrayAccess label] parItemToArrayAccessM f (SeqItems xs) = sequence [concatMapM (f []) xs >>* Group] parItemToArrayAccessM f (ParItems ps) = concatMapM (parItemToArrayAccessM f) ps -parItemToArrayAccessM f (RepParItem rep p) = parItemToArrayAccessM (\reps -> f (rep:reps)) p +parItemToArrayAccessM f (RepParItem rep p) + = do normal <- parItemToArrayAccessM (\reps -> f ((rep,False):reps)) p + mirror <- parItemToArrayAccessM (\reps -> f ((rep,True):reps)) p + return [Replicated normal mirror] makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp) makeExpSet = foldM makeExpSet' Set.empty @@ -225,13 +228,14 @@ makeExpSet = foldM makeExpSet' Set.empty type VarMap = Map.Map FlattenedExp Int --- | Given a list of (replicated variable, start, count), a list of (written,read) parallel array accesses, --- the length of the array, returns the problems. +-- | Given a list of (written,read) expressions, an expression representing the upper array bound, returns either an error +-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from +-- (unique, munged) variable name to variable-index in the equations. -- -- The general strategy is as follows. -- For every array index (here termed an "access"), we transform it into -- the usual [FlattenedExp] using the flatten function. Then we also transform --- any access that features a replicated variable into its mirrored version +-- any access that is in the mirror-side of a Replicated item into its mirrored version -- where each i is changed into i'. This is done by using vi=(variable "i",0) -- (in Scale _ vi) for the plain (normal) version, and vi=(variable "i",1) -- for the prime (mirror) version. @@ -245,45 +249,21 @@ type VarMap = Map.Map FlattenedExp Int -- -- The remainder of the work (correctly pairing equations) is done by -- squareAndPair. -makeReplicatedEquations :: [(A.Variable, A.Expression, A.Expression)] -> ([A.Expression],[A.Expression]) -> A.Expression -> +-- +-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message +-- TODO allow "background knowledge" in the form of other equalities and inequalities +makeEquations :: ParItems ([A.Expression],[A.Expression]) -> A.Expression -> Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))] -makeReplicatedEquations repVars accesses bound - = do flattenedAccesses <- applyPairM (mapM copyAndFlatten) accesses - let flattenedAccessesMirror = applyPair (map mirrorAllVars) flattenedAccesses - bound' <- flatten bound - ((v,h,repVars',repVarIndexes),s) <- (flip runStateT) Map.empty $ - do repVars' <- mapM (\(v,s,c) -> - do s' <- lift (flatten s) >>= makeEquation s (error "Type is irrelevant for replication count") - >>= getSingleAccessItem "Modulo or Divide not allowed in replication start" - c' <- lift (flatten c) >>= makeEquation c (error "Type is irrelevant for replication count") - >>= getSingleAccessItem "Modulo or Divide not allowed in replication count" - return (v,s',c')) repVars - accesses' <- mapM (makeEquationWithPossibleRepBounds repVars') =<< makeEquationsWR flattenedAccesses - accesses'' <- mapM (makeEquationWithPossibleRepBounds repVars') =<< makeEquationsWR flattenedAccessesMirror - high <- makeEquation bound (error "Type is irrelevant for uppper bound") bound' +makeEquations accesses bound + = do bound' <- flatten bound + ((v,h,repVarIndexes),s) <- (flip runStateT) Map.empty $ + do (accesses',repVars) <- flip runStateT [] $ parItemToArrayAccessM mkEq accesses + high <- makeEquation bound (error "Type is irrelevant for upper bound") bound' >>= getSingleAccessItem "Multiple possible upper bounds not supported" - repVarIndexes <- mapM (\(v,_,_) -> seqPair (varIndex (Scale 1 (v,0)), varIndex (Scale 1 (v,1)))) repVars - return (Replicated accesses' accesses'',high, repVars',repVarIndexes) - return $ squareAndPair repVarIndexes s [v] (amap (const 0) h, addConstant (-1) h) + return (accesses', high, nub repVars) + return $ squareAndPair repVarIndexes s v (amap (const 0) h, addConstant (-1) h) where - copyAndFlatten :: A.Expression -> Either String (A.Expression, [FlattenedExp]) - copyAndFlatten e = do f <- flatten e - return (e, f) - - -- Mirrors all of repVars in the given equation - mirrorAllVars :: (A.Expression, [FlattenedExp]) -> (A.Expression, [FlattenedExp]) - mirrorAllVars (e, f) = (e, foldl mirror f repVars) - where - mirror :: [FlattenedExp] -> (A.Variable, A.Expression, A.Expression) -> [FlattenedExp] - mirror exp (v,_,_) = setIndexVar v 1 exp - - makeEquationsWR :: ([(A.Expression, [FlattenedExp])],[(A.Expression, [FlattenedExp])]) -> StateT (VarMap) (Either String) [ArrayAccess A.Expression] - makeEquationsWR (w,r) = do w' <- mapM (\(e,f) -> makeEquation e AAWrite f) w - r' <- mapM (\(e,f) -> makeEquation e AARead f) r - return (w' ++ r') - - setIndexVar :: A.Variable -> Int -> [FlattenedExp] -> [FlattenedExp] setIndexVar tv ti es = case mapAccumL (setIndexVar' tv ti) False es of (_, es') -> es' @@ -306,13 +286,11 @@ makeReplicatedEquations repVars accesses bound addPossibleRepBound' :: ArrayAccess label -> (A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) -> StateT (VarMap) (Either String) (ArrayAccess label) --- addPossibleRepBound' (Group accesses) v = mapM (seqPair . transformPair return (flip addPossibleRepBound v)) accesses >>* Group addPossibleRepBound' (Group accesses) v = sequence [addPossibleRepBound acc v >>* (\acc' -> (l,t,acc')) | (l,t,acc) <- accesses ] >>* Group addPossibleRepBound' (Replicated acc0 acc1) v = do acc0' <- mapM (flip addPossibleRepBound' v) acc0 acc1' <- mapM (flip addPossibleRepBound' v) acc1 return $ Replicated acc0' acc1' --- addPossibleRepBound' (Single (l,t,acc)) v = addPossibleRepBound acc v >>* (\x -> Group [(l,t,x)]) addPossibleRepBound :: (EqualityConstraintEquation, EqualityProblem, InequalityProblem) -> (A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) -> @@ -331,7 +309,40 @@ makeReplicatedEquations repVars accesses bound where newMin = minimum [fst $ bounds a, ind] newMax = maximum [snd $ bounds a, ind] - + + mkEq :: [(A.Replicator, Bool)] -> ([A.Expression], [A.Expression]) -> StateT [(CoeffIndex, CoeffIndex)] (StateT VarMap (Either String)) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] + mkEq reps (ws,rs) = do repVarEqs <- mapM (liftF makeRepVarEq) reps + concatMapM (mkEq' repVarEqs) (ws' ++ rs') + where + ws' = zip (repeat AAWrite) ws + rs' = zip (repeat AARead) rs + + makeRepVarEq :: (A.Replicator, Bool) -> StateT VarMap (Either String) (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation) + makeRepVarEq (A.For m varName from for, _) + = do from' <- lift (flatten from) >>= makeEquation from (error "Type is irrelevant for replication start") + >>= getSingleAccessItem "Modulo or Divide not allowed in replication start" + for' <- lift (flatten for) >>= makeEquation for (error "Type is irrelevant for replication count") + >>= getSingleAccessItem "Modulo or Divide not allowed in replication count" + return (A.Variable m varName, from', for') + + mkEq' :: [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] -> (ArrayAccessType, A.Expression) -> StateT [(CoeffIndex,CoeffIndex)] (StateT VarMap (Either String)) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] + mkEq' repVarEqs (aat, e) + = do f <- lift . lift $ flatten e + f' <- foldM mirrorFlaggedVars f reps + g <- lift $ makeEquationWithPossibleRepBounds repVarEqs =<< makeEquation e aat f' + case g of + Group g' -> return g' + _ -> throwError "Replicated group found unexpectedly" + + mirrorFlaggedVars :: [FlattenedExp] -> (A.Replicator,Bool) -> StateT [(CoeffIndex,CoeffIndex)] (StateT VarMap (Either String)) [FlattenedExp] + mirrorFlaggedVars exp (_,False) = return exp + mirrorFlaggedVars exp (A.For m varName from for, True) + = do varIndexes <- lift $ seqPair (varIndex (Scale 1 (var,0)), varIndex (Scale 1 (var,1))) + modify (varIndexes :) + return $ setIndexVar var 1 exp + where + var = A.Variable m varName + -- Note that in all these functions, the divisor should always be positive! -- Takes an expression, and transforms it into an expression like: @@ -462,38 +473,6 @@ getSingleItem :: MonadTrans m => String -> [(a,b,c)] -> m (Either String) a getSingleItem _ [(item,_,_)] = lift $ return item getSingleItem err _ = lift $ throwError err - --- | Given a list of (written,read) expressions, an expression representing the upper array bound, returns either an error --- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from --- (unique, munged) variable name to variable-index in the equations. --- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message --- TODO allow "background knowledge" in the form of other equalities and inequalities -makeEquations :: ParItems ([A.Expression],[A.Expression]) -> A.Expression -> Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))] -makeEquations es high = makeEquations' >>* uncurry3 (squareAndPair []) - where - - -- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair - -- representing the upper and lower bounds of the array (inclusive). - makeEquations' :: Either String (VarMap, [ArrayAccess A.Expression], (EqualityConstraintEquation, EqualityConstraintEquation)) - makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $ - do eqs <- parItemToArrayAccessM mkEq es - high' <- (lift $ flatten high) >>= makeEquation high (error "Type irrelevant for upper bound") - >>= getSingleAccessItem "Multiple possible upper bounds not supported" - return (eqs, high') - return (s,v,(amap (const 0) h, addConstant (-1) h)) - - mkEq :: [A.Replicator] -> ([A.Expression],[A.Expression]) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] - mkEq [] (ws,rs) = concatMapM mkEq' (ws' ++ rs') - where - ws' = zip (repeat AAWrite) ws - rs' = zip (repeat AARead) rs - - mkEq' :: (ArrayAccessType, A.Expression) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] - mkEq' (aat, e) = do g <- makeEquation e aat =<< lift (flatten e) - case g of - Group g' -> return g' - _ -> throwError "Replicated group found unexpectedly" - -- | Finds the index associated with a particular variable; either by finding an existing index -- or allocating a new one. varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int diff --git a/transformations/ArrayUsageCheckTest.hs b/transformations/ArrayUsageCheckTest.hs index 419e021..65662a5 100644 --- a/transformations/ArrayUsageCheckTest.hs +++ b/transformations/ArrayUsageCheckTest.hs @@ -32,6 +32,7 @@ import Test.QuickCheck hiding (check) import ArrayUsageCheck import qualified AST as A +import Metadata import Omega import TestHarness import TestUtils hiding (m) @@ -336,7 +337,7 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList ,testRep' (200,[((0,0),rep_i_mapping, [i === j], ij_16 &&& [i <== j ++ con (-1)] &&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7])], - [(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i"],intLiteral 8) + ("i", intLiteral 1, intLiteral 6),[exprVariable "i"],intLiteral 8) ,testRep' (201, [((0,0),rep_i_mapping, [i === j], @@ -344,7 +345,7 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList &&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7])] ++ replicate 2 ((0,1),rep_i_mapping,[i === con 3], leq [con 1,i, con 6] &&& leq [con 0, i, con 7] &&& leq [con 0, con 3, con 7]) ++ [((1,1),rep_i_mapping,[con 3 === con 3],concat $ replicate 2 (leq [con 0, con 3, con 7]))] - ,[(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i", intLiteral 3],intLiteral 8) + ,("i", intLiteral 1, intLiteral 6),[exprVariable "i", intLiteral 3],intLiteral 8) ,testRep' (202,[ ((0,1),rep_i_mapping,[i === j ++ con 1],ij_16 &&& [i <== j ++ con (-1)] &&& leq [con 0, i, con 7] &&& leq [con 0, j ++ con 1, con 7]) @@ -353,12 +354,13 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList ,((1,1),rep_i_mapping,[i === j],ij_16 &&& [i <== j ++ con (-1)] &&& leq [con 0, i ++ con 1, con 7] &&& leq [con 0, j ++ con 1, con 7])] ++ [((0,1),rep_i_mapping, [i === i ++ con 1], leq [con 1, i, con 6] &&& leq [con 1, i, con 6] &&& -- deliberate repeat leq [con 0, i, con 7] &&& leq [con 0,i ++ con 1, con 7])] - ,[(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i", buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 1)],intLiteral 8) + ,("i", intLiteral 1, intLiteral 6),[exprVariable "i", buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 1)],intLiteral 8) -- Only a constant: ,testRep' (210,[((0,0),rep_i_mapping,[con 4 === con 4],concat $ replicate 2 $ leq [con 0, con 4, con 7])] - ,[(variable "i", intLiteral 1, intLiteral 6)],[intLiteral 4],intLiteral 8) + ,("i", intLiteral 1, intLiteral 6),[intLiteral 4],intLiteral 8) + -- TODO test reads and writes are paired properly ] where -- These functions assume that you pair each list [x,y,z] as (x,y) (x,z) (y,z) in that order. @@ -381,11 +383,11 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs) (map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations (makeParItems exprs) upperBound) - testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],[(A.Variable, A.Expression, A.Expression)],[A.Expression],A.Expression) -> Test - testRep' (ind, problems, reps, exprs, upperBound) = + testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],(String, A.Expression, A.Expression),[A.Expression],A.Expression) -> Test + testRep' (ind, problems, (repName, repFrom, repFor), exprs, upperBound) = TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs) (map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems) - =<< (checkRight $ makeReplicatedEquations reps (exprs,[]) upperBound) + =<< (checkRight $ makeEquations (RepParItem (A.For emptyMeta (simpleName repName) repFrom repFor) $ makeParItems exprs) upperBound) pairLatterTwo (l,a,b,c) = (l,a,(b,c)) diff --git a/transformations/UsageCheck.hs b/transformations/UsageCheck.hs index eadaa63..c06c713 100644 --- a/transformations/UsageCheck.hs +++ b/transformations/UsageCheck.hs @@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . -} -module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), transformParItems, Var(..), Vars(..)) where +module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), transformParItems, Var(..), Vars(..), vars) where import Data.Generics import Data.Graph.Inductive