diff --git a/transformations/ArrayUsageCheck.hs b/transformations/ArrayUsageCheck.hs
index 8268b03..ddec1c2 100644
--- a/transformations/ArrayUsageCheck.hs
+++ b/transformations/ArrayUsageCheck.hs
@@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along
with this program. If not, see .
-}
-module ArrayUsageCheck (checkArrayUsage, FlattenedExp(..), makeEquations, makeReplicatedEquations, usageCheckPass, VarMap) where
+module ArrayUsageCheck (checkArrayUsage, FlattenedExp(..), makeEquations, usageCheckPass, VarMap) where
import Control.Monad.Error
import Control.Monad.State
@@ -187,10 +187,13 @@ data ArrayAccess label =
data ArrayAccessType = AAWrite | AARead
-parItemToArrayAccessM :: Monad m => ([A.Replicator] -> a -> m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]) -> ParItems a -> m [ArrayAccess label]
+parItemToArrayAccessM :: Monad m => ([(A.Replicator, Bool)] -> a -> m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]) -> ParItems a -> m [ArrayAccess label]
parItemToArrayAccessM f (SeqItems xs) = sequence [concatMapM (f []) xs >>* Group]
parItemToArrayAccessM f (ParItems ps) = concatMapM (parItemToArrayAccessM f) ps
-parItemToArrayAccessM f (RepParItem rep p) = parItemToArrayAccessM (\reps -> f (rep:reps)) p
+parItemToArrayAccessM f (RepParItem rep p)
+ = do normal <- parItemToArrayAccessM (\reps -> f ((rep,False):reps)) p
+ mirror <- parItemToArrayAccessM (\reps -> f ((rep,True):reps)) p
+ return [Replicated normal mirror]
makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp)
makeExpSet = foldM makeExpSet' Set.empty
@@ -225,13 +228,14 @@ makeExpSet = foldM makeExpSet' Set.empty
type VarMap = Map.Map FlattenedExp Int
--- | Given a list of (replicated variable, start, count), a list of (written,read) parallel array accesses,
--- the length of the array, returns the problems.
+-- | Given a list of (written,read) expressions, an expression representing the upper array bound, returns either an error
+-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
+-- (unique, munged) variable name to variable-index in the equations.
--
-- The general strategy is as follows.
-- For every array index (here termed an "access"), we transform it into
-- the usual [FlattenedExp] using the flatten function. Then we also transform
--- any access that features a replicated variable into its mirrored version
+-- any access that is in the mirror-side of a Replicated item into its mirrored version
-- where each i is changed into i'. This is done by using vi=(variable "i",0)
-- (in Scale _ vi) for the plain (normal) version, and vi=(variable "i",1)
-- for the prime (mirror) version.
@@ -245,45 +249,21 @@ type VarMap = Map.Map FlattenedExp Int
--
-- The remainder of the work (correctly pairing equations) is done by
-- squareAndPair.
-makeReplicatedEquations :: [(A.Variable, A.Expression, A.Expression)] -> ([A.Expression],[A.Expression]) -> A.Expression ->
+--
+-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
+-- TODO allow "background knowledge" in the form of other equalities and inequalities
+makeEquations :: ParItems ([A.Expression],[A.Expression]) -> A.Expression ->
Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))]
-makeReplicatedEquations repVars accesses bound
- = do flattenedAccesses <- applyPairM (mapM copyAndFlatten) accesses
- let flattenedAccessesMirror = applyPair (map mirrorAllVars) flattenedAccesses
- bound' <- flatten bound
- ((v,h,repVars',repVarIndexes),s) <- (flip runStateT) Map.empty $
- do repVars' <- mapM (\(v,s,c) ->
- do s' <- lift (flatten s) >>= makeEquation s (error "Type is irrelevant for replication count")
- >>= getSingleAccessItem "Modulo or Divide not allowed in replication start"
- c' <- lift (flatten c) >>= makeEquation c (error "Type is irrelevant for replication count")
- >>= getSingleAccessItem "Modulo or Divide not allowed in replication count"
- return (v,s',c')) repVars
- accesses' <- mapM (makeEquationWithPossibleRepBounds repVars') =<< makeEquationsWR flattenedAccesses
- accesses'' <- mapM (makeEquationWithPossibleRepBounds repVars') =<< makeEquationsWR flattenedAccessesMirror
- high <- makeEquation bound (error "Type is irrelevant for uppper bound") bound'
+makeEquations accesses bound
+ = do bound' <- flatten bound
+ ((v,h,repVarIndexes),s) <- (flip runStateT) Map.empty $
+ do (accesses',repVars) <- flip runStateT [] $ parItemToArrayAccessM mkEq accesses
+ high <- makeEquation bound (error "Type is irrelevant for upper bound") bound'
>>= getSingleAccessItem "Multiple possible upper bounds not supported"
- repVarIndexes <- mapM (\(v,_,_) -> seqPair (varIndex (Scale 1 (v,0)), varIndex (Scale 1 (v,1)))) repVars
- return (Replicated accesses' accesses'',high, repVars',repVarIndexes)
- return $ squareAndPair repVarIndexes s [v] (amap (const 0) h, addConstant (-1) h)
+ return (accesses', high, nub repVars)
+ return $ squareAndPair repVarIndexes s v (amap (const 0) h, addConstant (-1) h)
where
- copyAndFlatten :: A.Expression -> Either String (A.Expression, [FlattenedExp])
- copyAndFlatten e = do f <- flatten e
- return (e, f)
-
- -- Mirrors all of repVars in the given equation
- mirrorAllVars :: (A.Expression, [FlattenedExp]) -> (A.Expression, [FlattenedExp])
- mirrorAllVars (e, f) = (e, foldl mirror f repVars)
- where
- mirror :: [FlattenedExp] -> (A.Variable, A.Expression, A.Expression) -> [FlattenedExp]
- mirror exp (v,_,_) = setIndexVar v 1 exp
-
- makeEquationsWR :: ([(A.Expression, [FlattenedExp])],[(A.Expression, [FlattenedExp])]) -> StateT (VarMap) (Either String) [ArrayAccess A.Expression]
- makeEquationsWR (w,r) = do w' <- mapM (\(e,f) -> makeEquation e AAWrite f) w
- r' <- mapM (\(e,f) -> makeEquation e AARead f) r
- return (w' ++ r')
-
-
setIndexVar :: A.Variable -> Int -> [FlattenedExp] -> [FlattenedExp]
setIndexVar tv ti es = case mapAccumL (setIndexVar' tv ti) False es of
(_, es') -> es'
@@ -306,13 +286,11 @@ makeReplicatedEquations repVars accesses bound
addPossibleRepBound' :: ArrayAccess label ->
(A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) ->
StateT (VarMap) (Either String) (ArrayAccess label)
--- addPossibleRepBound' (Group accesses) v = mapM (seqPair . transformPair return (flip addPossibleRepBound v)) accesses >>* Group
addPossibleRepBound' (Group accesses) v = sequence [addPossibleRepBound acc v >>* (\acc' -> (l,t,acc')) | (l,t,acc) <- accesses ] >>* Group
addPossibleRepBound' (Replicated acc0 acc1) v
= do acc0' <- mapM (flip addPossibleRepBound' v) acc0
acc1' <- mapM (flip addPossibleRepBound' v) acc1
return $ Replicated acc0' acc1'
--- addPossibleRepBound' (Single (l,t,acc)) v = addPossibleRepBound acc v >>* (\x -> Group [(l,t,x)])
addPossibleRepBound :: (EqualityConstraintEquation, EqualityProblem, InequalityProblem) ->
(A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) ->
@@ -331,7 +309,40 @@ makeReplicatedEquations repVars accesses bound
where
newMin = minimum [fst $ bounds a, ind]
newMax = maximum [snd $ bounds a, ind]
-
+
+ mkEq :: [(A.Replicator, Bool)] -> ([A.Expression], [A.Expression]) -> StateT [(CoeffIndex, CoeffIndex)] (StateT VarMap (Either String)) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
+ mkEq reps (ws,rs) = do repVarEqs <- mapM (liftF makeRepVarEq) reps
+ concatMapM (mkEq' repVarEqs) (ws' ++ rs')
+ where
+ ws' = zip (repeat AAWrite) ws
+ rs' = zip (repeat AARead) rs
+
+ makeRepVarEq :: (A.Replicator, Bool) -> StateT VarMap (Either String) (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)
+ makeRepVarEq (A.For m varName from for, _)
+ = do from' <- lift (flatten from) >>= makeEquation from (error "Type is irrelevant for replication start")
+ >>= getSingleAccessItem "Modulo or Divide not allowed in replication start"
+ for' <- lift (flatten for) >>= makeEquation for (error "Type is irrelevant for replication count")
+ >>= getSingleAccessItem "Modulo or Divide not allowed in replication count"
+ return (A.Variable m varName, from', for')
+
+ mkEq' :: [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] -> (ArrayAccessType, A.Expression) -> StateT [(CoeffIndex,CoeffIndex)] (StateT VarMap (Either String)) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
+ mkEq' repVarEqs (aat, e)
+ = do f <- lift . lift $ flatten e
+ f' <- foldM mirrorFlaggedVars f reps
+ g <- lift $ makeEquationWithPossibleRepBounds repVarEqs =<< makeEquation e aat f'
+ case g of
+ Group g' -> return g'
+ _ -> throwError "Replicated group found unexpectedly"
+
+ mirrorFlaggedVars :: [FlattenedExp] -> (A.Replicator,Bool) -> StateT [(CoeffIndex,CoeffIndex)] (StateT VarMap (Either String)) [FlattenedExp]
+ mirrorFlaggedVars exp (_,False) = return exp
+ mirrorFlaggedVars exp (A.For m varName from for, True)
+ = do varIndexes <- lift $ seqPair (varIndex (Scale 1 (var,0)), varIndex (Scale 1 (var,1)))
+ modify (varIndexes :)
+ return $ setIndexVar var 1 exp
+ where
+ var = A.Variable m varName
+
-- Note that in all these functions, the divisor should always be positive!
-- Takes an expression, and transforms it into an expression like:
@@ -462,38 +473,6 @@ getSingleItem :: MonadTrans m => String -> [(a,b,c)] -> m (Either String) a
getSingleItem _ [(item,_,_)] = lift $ return item
getSingleItem err _ = lift $ throwError err
-
--- | Given a list of (written,read) expressions, an expression representing the upper array bound, returns either an error
--- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
--- (unique, munged) variable name to variable-index in the equations.
--- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
--- TODO allow "background knowledge" in the form of other equalities and inequalities
-makeEquations :: ParItems ([A.Expression],[A.Expression]) -> A.Expression -> Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))]
-makeEquations es high = makeEquations' >>* uncurry3 (squareAndPair [])
- where
-
- -- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair
- -- representing the upper and lower bounds of the array (inclusive).
- makeEquations' :: Either String (VarMap, [ArrayAccess A.Expression], (EqualityConstraintEquation, EqualityConstraintEquation))
- makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
- do eqs <- parItemToArrayAccessM mkEq es
- high' <- (lift $ flatten high) >>= makeEquation high (error "Type irrelevant for upper bound")
- >>= getSingleAccessItem "Multiple possible upper bounds not supported"
- return (eqs, high')
- return (s,v,(amap (const 0) h, addConstant (-1) h))
-
- mkEq :: [A.Replicator] -> ([A.Expression],[A.Expression]) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
- mkEq [] (ws,rs) = concatMapM mkEq' (ws' ++ rs')
- where
- ws' = zip (repeat AAWrite) ws
- rs' = zip (repeat AARead) rs
-
- mkEq' :: (ArrayAccessType, A.Expression) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
- mkEq' (aat, e) = do g <- makeEquation e aat =<< lift (flatten e)
- case g of
- Group g' -> return g'
- _ -> throwError "Replicated group found unexpectedly"
-
-- | Finds the index associated with a particular variable; either by finding an existing index
-- or allocating a new one.
varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int
diff --git a/transformations/ArrayUsageCheckTest.hs b/transformations/ArrayUsageCheckTest.hs
index 419e021..65662a5 100644
--- a/transformations/ArrayUsageCheckTest.hs
+++ b/transformations/ArrayUsageCheckTest.hs
@@ -32,6 +32,7 @@ import Test.QuickCheck hiding (check)
import ArrayUsageCheck
import qualified AST as A
+import Metadata
import Omega
import TestHarness
import TestUtils hiding (m)
@@ -336,7 +337,7 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
,testRep' (200,[((0,0),rep_i_mapping, [i === j],
ij_16 &&& [i <== j ++ con (-1)]
&&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7])],
- [(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i"],intLiteral 8)
+ ("i", intLiteral 1, intLiteral 6),[exprVariable "i"],intLiteral 8)
,testRep' (201,
[((0,0),rep_i_mapping, [i === j],
@@ -344,7 +345,7 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
&&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7])]
++ replicate 2 ((0,1),rep_i_mapping,[i === con 3], leq [con 1,i, con 6] &&& leq [con 0, i, con 7] &&& leq [con 0, con 3, con 7])
++ [((1,1),rep_i_mapping,[con 3 === con 3],concat $ replicate 2 (leq [con 0, con 3, con 7]))]
- ,[(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i", intLiteral 3],intLiteral 8)
+ ,("i", intLiteral 1, intLiteral 6),[exprVariable "i", intLiteral 3],intLiteral 8)
,testRep' (202,[
((0,1),rep_i_mapping,[i === j ++ con 1],ij_16 &&& [i <== j ++ con (-1)] &&& leq [con 0, i, con 7] &&& leq [con 0, j ++ con 1, con 7])
@@ -353,12 +354,13 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
,((1,1),rep_i_mapping,[i === j],ij_16 &&& [i <== j ++ con (-1)] &&& leq [con 0, i ++ con 1, con 7] &&& leq [con 0, j ++ con 1, con 7])]
++ [((0,1),rep_i_mapping, [i === i ++ con 1], leq [con 1, i, con 6] &&& leq [con 1, i, con 6] &&& -- deliberate repeat
leq [con 0, i, con 7] &&& leq [con 0,i ++ con 1, con 7])]
- ,[(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i", buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 1)],intLiteral 8)
+ ,("i", intLiteral 1, intLiteral 6),[exprVariable "i", buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 1)],intLiteral 8)
-- Only a constant:
,testRep' (210,[((0,0),rep_i_mapping,[con 4 === con 4],concat $ replicate 2 $ leq [con 0, con 4, con 7])]
- ,[(variable "i", intLiteral 1, intLiteral 6)],[intLiteral 4],intLiteral 8)
+ ,("i", intLiteral 1, intLiteral 6),[intLiteral 4],intLiteral 8)
+ -- TODO test reads and writes are paired properly
]
where
-- These functions assume that you pair each list [x,y,z] as (x,y) (x,z) (y,z) in that order.
@@ -381,11 +383,11 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs)
(map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations (makeParItems exprs) upperBound)
- testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],[(A.Variable, A.Expression, A.Expression)],[A.Expression],A.Expression) -> Test
- testRep' (ind, problems, reps, exprs, upperBound) =
+ testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],(String, A.Expression, A.Expression),[A.Expression],A.Expression) -> Test
+ testRep' (ind, problems, (repName, repFrom, repFor), exprs, upperBound) =
TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs)
(map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems)
- =<< (checkRight $ makeReplicatedEquations reps (exprs,[]) upperBound)
+ =<< (checkRight $ makeEquations (RepParItem (A.For emptyMeta (simpleName repName) repFrom repFor) $ makeParItems exprs) upperBound)
pairLatterTwo (l,a,b,c) = (l,a,(b,c))
diff --git a/transformations/UsageCheck.hs b/transformations/UsageCheck.hs
index eadaa63..c06c713 100644
--- a/transformations/UsageCheck.hs
+++ b/transformations/UsageCheck.hs
@@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along
with this program. If not, see .
-}
-module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), transformParItems, Var(..), Vars(..)) where
+module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), transformParItems, Var(..), Vars(..), vars) where
import Data.Generics
import Data.Graph.Inductive