Merged makeEquations with makeReplicatedEquations and adjusted the tests accordingly

This commit is contained in:
Neil Brown 2008-01-27 16:53:07 +00:00
parent 5a69459668
commit 349d3c5811
3 changed files with 65 additions and 84 deletions

View File

@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along
with this program. If not, see <http://www.gnu.org/licenses/>.
-}
module ArrayUsageCheck (checkArrayUsage, FlattenedExp(..), makeEquations, makeReplicatedEquations, usageCheckPass, VarMap) where
module ArrayUsageCheck (checkArrayUsage, FlattenedExp(..), makeEquations, usageCheckPass, VarMap) where
import Control.Monad.Error
import Control.Monad.State
@ -187,10 +187,13 @@ data ArrayAccess label =
data ArrayAccessType = AAWrite | AARead
parItemToArrayAccessM :: Monad m => ([A.Replicator] -> a -> m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]) -> ParItems a -> m [ArrayAccess label]
parItemToArrayAccessM :: Monad m => ([(A.Replicator, Bool)] -> a -> m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]) -> ParItems a -> m [ArrayAccess label]
parItemToArrayAccessM f (SeqItems xs) = sequence [concatMapM (f []) xs >>* Group]
parItemToArrayAccessM f (ParItems ps) = concatMapM (parItemToArrayAccessM f) ps
parItemToArrayAccessM f (RepParItem rep p) = parItemToArrayAccessM (\reps -> f (rep:reps)) p
parItemToArrayAccessM f (RepParItem rep p)
= do normal <- parItemToArrayAccessM (\reps -> f ((rep,False):reps)) p
mirror <- parItemToArrayAccessM (\reps -> f ((rep,True):reps)) p
return [Replicated normal mirror]
makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp)
makeExpSet = foldM makeExpSet' Set.empty
@ -225,13 +228,14 @@ makeExpSet = foldM makeExpSet' Set.empty
type VarMap = Map.Map FlattenedExp Int
-- | Given a list of (replicated variable, start, count), a list of (written,read) parallel array accesses,
-- the length of the array, returns the problems.
-- | Given a list of (written,read) expressions, an expression representing the upper array bound, returns either an error
-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
-- (unique, munged) variable name to variable-index in the equations.
--
-- The general strategy is as follows.
-- For every array index (here termed an "access"), we transform it into
-- the usual [FlattenedExp] using the flatten function. Then we also transform
-- any access that features a replicated variable into its mirrored version
-- any access that is in the mirror-side of a Replicated item into its mirrored version
-- where each i is changed into i'. This is done by using vi=(variable "i",0)
-- (in Scale _ vi) for the plain (normal) version, and vi=(variable "i",1)
-- for the prime (mirror) version.
@ -245,45 +249,21 @@ type VarMap = Map.Map FlattenedExp Int
--
-- The remainder of the work (correctly pairing equations) is done by
-- squareAndPair.
makeReplicatedEquations :: [(A.Variable, A.Expression, A.Expression)] -> ([A.Expression],[A.Expression]) -> A.Expression ->
--
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
-- TODO allow "background knowledge" in the form of other equalities and inequalities
makeEquations :: ParItems ([A.Expression],[A.Expression]) -> A.Expression ->
Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))]
makeReplicatedEquations repVars accesses bound
= do flattenedAccesses <- applyPairM (mapM copyAndFlatten) accesses
let flattenedAccessesMirror = applyPair (map mirrorAllVars) flattenedAccesses
bound' <- flatten bound
((v,h,repVars',repVarIndexes),s) <- (flip runStateT) Map.empty $
do repVars' <- mapM (\(v,s,c) ->
do s' <- lift (flatten s) >>= makeEquation s (error "Type is irrelevant for replication count")
>>= getSingleAccessItem "Modulo or Divide not allowed in replication start"
c' <- lift (flatten c) >>= makeEquation c (error "Type is irrelevant for replication count")
>>= getSingleAccessItem "Modulo or Divide not allowed in replication count"
return (v,s',c')) repVars
accesses' <- mapM (makeEquationWithPossibleRepBounds repVars') =<< makeEquationsWR flattenedAccesses
accesses'' <- mapM (makeEquationWithPossibleRepBounds repVars') =<< makeEquationsWR flattenedAccessesMirror
high <- makeEquation bound (error "Type is irrelevant for uppper bound") bound'
makeEquations accesses bound
= do bound' <- flatten bound
((v,h,repVarIndexes),s) <- (flip runStateT) Map.empty $
do (accesses',repVars) <- flip runStateT [] $ parItemToArrayAccessM mkEq accesses
high <- makeEquation bound (error "Type is irrelevant for upper bound") bound'
>>= getSingleAccessItem "Multiple possible upper bounds not supported"
repVarIndexes <- mapM (\(v,_,_) -> seqPair (varIndex (Scale 1 (v,0)), varIndex (Scale 1 (v,1)))) repVars
return (Replicated accesses' accesses'',high, repVars',repVarIndexes)
return $ squareAndPair repVarIndexes s [v] (amap (const 0) h, addConstant (-1) h)
return (accesses', high, nub repVars)
return $ squareAndPair repVarIndexes s v (amap (const 0) h, addConstant (-1) h)
where
copyAndFlatten :: A.Expression -> Either String (A.Expression, [FlattenedExp])
copyAndFlatten e = do f <- flatten e
return (e, f)
-- Mirrors all of repVars in the given equation
mirrorAllVars :: (A.Expression, [FlattenedExp]) -> (A.Expression, [FlattenedExp])
mirrorAllVars (e, f) = (e, foldl mirror f repVars)
where
mirror :: [FlattenedExp] -> (A.Variable, A.Expression, A.Expression) -> [FlattenedExp]
mirror exp (v,_,_) = setIndexVar v 1 exp
makeEquationsWR :: ([(A.Expression, [FlattenedExp])],[(A.Expression, [FlattenedExp])]) -> StateT (VarMap) (Either String) [ArrayAccess A.Expression]
makeEquationsWR (w,r) = do w' <- mapM (\(e,f) -> makeEquation e AAWrite f) w
r' <- mapM (\(e,f) -> makeEquation e AARead f) r
return (w' ++ r')
setIndexVar :: A.Variable -> Int -> [FlattenedExp] -> [FlattenedExp]
setIndexVar tv ti es = case mapAccumL (setIndexVar' tv ti) False es of
(_, es') -> es'
@ -306,13 +286,11 @@ makeReplicatedEquations repVars accesses bound
addPossibleRepBound' :: ArrayAccess label ->
(A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) ->
StateT (VarMap) (Either String) (ArrayAccess label)
-- addPossibleRepBound' (Group accesses) v = mapM (seqPair . transformPair return (flip addPossibleRepBound v)) accesses >>* Group
addPossibleRepBound' (Group accesses) v = sequence [addPossibleRepBound acc v >>* (\acc' -> (l,t,acc')) | (l,t,acc) <- accesses ] >>* Group
addPossibleRepBound' (Replicated acc0 acc1) v
= do acc0' <- mapM (flip addPossibleRepBound' v) acc0
acc1' <- mapM (flip addPossibleRepBound' v) acc1
return $ Replicated acc0' acc1'
-- addPossibleRepBound' (Single (l,t,acc)) v = addPossibleRepBound acc v >>* (\x -> Group [(l,t,x)])
addPossibleRepBound :: (EqualityConstraintEquation, EqualityProblem, InequalityProblem) ->
(A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) ->
@ -331,7 +309,40 @@ makeReplicatedEquations repVars accesses bound
where
newMin = minimum [fst $ bounds a, ind]
newMax = maximum [snd $ bounds a, ind]
mkEq :: [(A.Replicator, Bool)] -> ([A.Expression], [A.Expression]) -> StateT [(CoeffIndex, CoeffIndex)] (StateT VarMap (Either String)) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq reps (ws,rs) = do repVarEqs <- mapM (liftF makeRepVarEq) reps
concatMapM (mkEq' repVarEqs) (ws' ++ rs')
where
ws' = zip (repeat AAWrite) ws
rs' = zip (repeat AARead) rs
makeRepVarEq :: (A.Replicator, Bool) -> StateT VarMap (Either String) (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)
makeRepVarEq (A.For m varName from for, _)
= do from' <- lift (flatten from) >>= makeEquation from (error "Type is irrelevant for replication start")
>>= getSingleAccessItem "Modulo or Divide not allowed in replication start"
for' <- lift (flatten for) >>= makeEquation for (error "Type is irrelevant for replication count")
>>= getSingleAccessItem "Modulo or Divide not allowed in replication count"
return (A.Variable m varName, from', for')
mkEq' :: [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] -> (ArrayAccessType, A.Expression) -> StateT [(CoeffIndex,CoeffIndex)] (StateT VarMap (Either String)) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq' repVarEqs (aat, e)
= do f <- lift . lift $ flatten e
f' <- foldM mirrorFlaggedVars f reps
g <- lift $ makeEquationWithPossibleRepBounds repVarEqs =<< makeEquation e aat f'
case g of
Group g' -> return g'
_ -> throwError "Replicated group found unexpectedly"
mirrorFlaggedVars :: [FlattenedExp] -> (A.Replicator,Bool) -> StateT [(CoeffIndex,CoeffIndex)] (StateT VarMap (Either String)) [FlattenedExp]
mirrorFlaggedVars exp (_,False) = return exp
mirrorFlaggedVars exp (A.For m varName from for, True)
= do varIndexes <- lift $ seqPair (varIndex (Scale 1 (var,0)), varIndex (Scale 1 (var,1)))
modify (varIndexes :)
return $ setIndexVar var 1 exp
where
var = A.Variable m varName
-- Note that in all these functions, the divisor should always be positive!
-- Takes an expression, and transforms it into an expression like:
@ -462,38 +473,6 @@ getSingleItem :: MonadTrans m => String -> [(a,b,c)] -> m (Either String) a
getSingleItem _ [(item,_,_)] = lift $ return item
getSingleItem err _ = lift $ throwError err
-- | Given a list of (written,read) expressions, an expression representing the upper array bound, returns either an error
-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
-- (unique, munged) variable name to variable-index in the equations.
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
-- TODO allow "background knowledge" in the form of other equalities and inequalities
makeEquations :: ParItems ([A.Expression],[A.Expression]) -> A.Expression -> Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))]
makeEquations es high = makeEquations' >>* uncurry3 (squareAndPair [])
where
-- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair
-- representing the upper and lower bounds of the array (inclusive).
makeEquations' :: Either String (VarMap, [ArrayAccess A.Expression], (EqualityConstraintEquation, EqualityConstraintEquation))
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
do eqs <- parItemToArrayAccessM mkEq es
high' <- (lift $ flatten high) >>= makeEquation high (error "Type irrelevant for upper bound")
>>= getSingleAccessItem "Multiple possible upper bounds not supported"
return (eqs, high')
return (s,v,(amap (const 0) h, addConstant (-1) h))
mkEq :: [A.Replicator] -> ([A.Expression],[A.Expression]) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq [] (ws,rs) = concatMapM mkEq' (ws' ++ rs')
where
ws' = zip (repeat AAWrite) ws
rs' = zip (repeat AARead) rs
mkEq' :: (ArrayAccessType, A.Expression) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq' (aat, e) = do g <- makeEquation e aat =<< lift (flatten e)
case g of
Group g' -> return g'
_ -> throwError "Replicated group found unexpectedly"
-- | Finds the index associated with a particular variable; either by finding an existing index
-- or allocating a new one.
varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int

View File

@ -32,6 +32,7 @@ import Test.QuickCheck hiding (check)
import ArrayUsageCheck
import qualified AST as A
import Metadata
import Omega
import TestHarness
import TestUtils hiding (m)
@ -336,7 +337,7 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
,testRep' (200,[((0,0),rep_i_mapping, [i === j],
ij_16 &&& [i <== j ++ con (-1)]
&&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7])],
[(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i"],intLiteral 8)
("i", intLiteral 1, intLiteral 6),[exprVariable "i"],intLiteral 8)
,testRep' (201,
[((0,0),rep_i_mapping, [i === j],
@ -344,7 +345,7 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
&&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7])]
++ replicate 2 ((0,1),rep_i_mapping,[i === con 3], leq [con 1,i, con 6] &&& leq [con 0, i, con 7] &&& leq [con 0, con 3, con 7])
++ [((1,1),rep_i_mapping,[con 3 === con 3],concat $ replicate 2 (leq [con 0, con 3, con 7]))]
,[(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i", intLiteral 3],intLiteral 8)
,("i", intLiteral 1, intLiteral 6),[exprVariable "i", intLiteral 3],intLiteral 8)
,testRep' (202,[
((0,1),rep_i_mapping,[i === j ++ con 1],ij_16 &&& [i <== j ++ con (-1)] &&& leq [con 0, i, con 7] &&& leq [con 0, j ++ con 1, con 7])
@ -353,12 +354,13 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
,((1,1),rep_i_mapping,[i === j],ij_16 &&& [i <== j ++ con (-1)] &&& leq [con 0, i ++ con 1, con 7] &&& leq [con 0, j ++ con 1, con 7])]
++ [((0,1),rep_i_mapping, [i === i ++ con 1], leq [con 1, i, con 6] &&& leq [con 1, i, con 6] &&& -- deliberate repeat
leq [con 0, i, con 7] &&& leq [con 0,i ++ con 1, con 7])]
,[(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i", buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 1)],intLiteral 8)
,("i", intLiteral 1, intLiteral 6),[exprVariable "i", buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 1)],intLiteral 8)
-- Only a constant:
,testRep' (210,[((0,0),rep_i_mapping,[con 4 === con 4],concat $ replicate 2 $ leq [con 0, con 4, con 7])]
,[(variable "i", intLiteral 1, intLiteral 6)],[intLiteral 4],intLiteral 8)
,("i", intLiteral 1, intLiteral 6),[intLiteral 4],intLiteral 8)
-- TODO test reads and writes are paired properly
]
where
-- These functions assume that you pair each list [x,y,z] as (x,y) (x,z) (y,z) in that order.
@ -381,11 +383,11 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs)
(map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations (makeParItems exprs) upperBound)
testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],[(A.Variable, A.Expression, A.Expression)],[A.Expression],A.Expression) -> Test
testRep' (ind, problems, reps, exprs, upperBound) =
testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],(String, A.Expression, A.Expression),[A.Expression],A.Expression) -> Test
testRep' (ind, problems, (repName, repFrom, repFor), exprs, upperBound) =
TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs)
(map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems)
=<< (checkRight $ makeReplicatedEquations reps (exprs,[]) upperBound)
=<< (checkRight $ makeEquations (RepParItem (A.For emptyMeta (simpleName repName) repFrom repFor) $ makeParItems exprs) upperBound)
pairLatterTwo (l,a,b,c) = (l,a,(b,c))

View File

@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along
with this program. If not, see <http://www.gnu.org/licenses/>.
-}
module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), transformParItems, Var(..), Vars(..)) where
module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), transformParItems, Var(..), Vars(..), vars) where
import Data.Generics
import Data.Graph.Inductive