Merged makeEquations with makeReplicatedEquations and adjusted the tests accordingly
This commit is contained in:
parent
5a69459668
commit
349d3c5811
|
@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along
|
|||
with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
-}
|
||||
|
||||
module ArrayUsageCheck (checkArrayUsage, FlattenedExp(..), makeEquations, makeReplicatedEquations, usageCheckPass, VarMap) where
|
||||
module ArrayUsageCheck (checkArrayUsage, FlattenedExp(..), makeEquations, usageCheckPass, VarMap) where
|
||||
|
||||
import Control.Monad.Error
|
||||
import Control.Monad.State
|
||||
|
@ -187,10 +187,13 @@ data ArrayAccess label =
|
|||
|
||||
data ArrayAccessType = AAWrite | AARead
|
||||
|
||||
parItemToArrayAccessM :: Monad m => ([A.Replicator] -> a -> m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]) -> ParItems a -> m [ArrayAccess label]
|
||||
parItemToArrayAccessM :: Monad m => ([(A.Replicator, Bool)] -> a -> m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]) -> ParItems a -> m [ArrayAccess label]
|
||||
parItemToArrayAccessM f (SeqItems xs) = sequence [concatMapM (f []) xs >>* Group]
|
||||
parItemToArrayAccessM f (ParItems ps) = concatMapM (parItemToArrayAccessM f) ps
|
||||
parItemToArrayAccessM f (RepParItem rep p) = parItemToArrayAccessM (\reps -> f (rep:reps)) p
|
||||
parItemToArrayAccessM f (RepParItem rep p)
|
||||
= do normal <- parItemToArrayAccessM (\reps -> f ((rep,False):reps)) p
|
||||
mirror <- parItemToArrayAccessM (\reps -> f ((rep,True):reps)) p
|
||||
return [Replicated normal mirror]
|
||||
|
||||
makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp)
|
||||
makeExpSet = foldM makeExpSet' Set.empty
|
||||
|
@ -225,13 +228,14 @@ makeExpSet = foldM makeExpSet' Set.empty
|
|||
|
||||
type VarMap = Map.Map FlattenedExp Int
|
||||
|
||||
-- | Given a list of (replicated variable, start, count), a list of (written,read) parallel array accesses,
|
||||
-- the length of the array, returns the problems.
|
||||
-- | Given a list of (written,read) expressions, an expression representing the upper array bound, returns either an error
|
||||
-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
|
||||
-- (unique, munged) variable name to variable-index in the equations.
|
||||
--
|
||||
-- The general strategy is as follows.
|
||||
-- For every array index (here termed an "access"), we transform it into
|
||||
-- the usual [FlattenedExp] using the flatten function. Then we also transform
|
||||
-- any access that features a replicated variable into its mirrored version
|
||||
-- any access that is in the mirror-side of a Replicated item into its mirrored version
|
||||
-- where each i is changed into i'. This is done by using vi=(variable "i",0)
|
||||
-- (in Scale _ vi) for the plain (normal) version, and vi=(variable "i",1)
|
||||
-- for the prime (mirror) version.
|
||||
|
@ -245,45 +249,21 @@ type VarMap = Map.Map FlattenedExp Int
|
|||
--
|
||||
-- The remainder of the work (correctly pairing equations) is done by
|
||||
-- squareAndPair.
|
||||
makeReplicatedEquations :: [(A.Variable, A.Expression, A.Expression)] -> ([A.Expression],[A.Expression]) -> A.Expression ->
|
||||
--
|
||||
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
|
||||
-- TODO allow "background knowledge" in the form of other equalities and inequalities
|
||||
makeEquations :: ParItems ([A.Expression],[A.Expression]) -> A.Expression ->
|
||||
Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))]
|
||||
makeReplicatedEquations repVars accesses bound
|
||||
= do flattenedAccesses <- applyPairM (mapM copyAndFlatten) accesses
|
||||
let flattenedAccessesMirror = applyPair (map mirrorAllVars) flattenedAccesses
|
||||
bound' <- flatten bound
|
||||
((v,h,repVars',repVarIndexes),s) <- (flip runStateT) Map.empty $
|
||||
do repVars' <- mapM (\(v,s,c) ->
|
||||
do s' <- lift (flatten s) >>= makeEquation s (error "Type is irrelevant for replication count")
|
||||
>>= getSingleAccessItem "Modulo or Divide not allowed in replication start"
|
||||
c' <- lift (flatten c) >>= makeEquation c (error "Type is irrelevant for replication count")
|
||||
>>= getSingleAccessItem "Modulo or Divide not allowed in replication count"
|
||||
return (v,s',c')) repVars
|
||||
accesses' <- mapM (makeEquationWithPossibleRepBounds repVars') =<< makeEquationsWR flattenedAccesses
|
||||
accesses'' <- mapM (makeEquationWithPossibleRepBounds repVars') =<< makeEquationsWR flattenedAccessesMirror
|
||||
high <- makeEquation bound (error "Type is irrelevant for uppper bound") bound'
|
||||
makeEquations accesses bound
|
||||
= do bound' <- flatten bound
|
||||
((v,h,repVarIndexes),s) <- (flip runStateT) Map.empty $
|
||||
do (accesses',repVars) <- flip runStateT [] $ parItemToArrayAccessM mkEq accesses
|
||||
high <- makeEquation bound (error "Type is irrelevant for upper bound") bound'
|
||||
>>= getSingleAccessItem "Multiple possible upper bounds not supported"
|
||||
repVarIndexes <- mapM (\(v,_,_) -> seqPair (varIndex (Scale 1 (v,0)), varIndex (Scale 1 (v,1)))) repVars
|
||||
return (Replicated accesses' accesses'',high, repVars',repVarIndexes)
|
||||
return $ squareAndPair repVarIndexes s [v] (amap (const 0) h, addConstant (-1) h)
|
||||
return (accesses', high, nub repVars)
|
||||
return $ squareAndPair repVarIndexes s v (amap (const 0) h, addConstant (-1) h)
|
||||
|
||||
where
|
||||
copyAndFlatten :: A.Expression -> Either String (A.Expression, [FlattenedExp])
|
||||
copyAndFlatten e = do f <- flatten e
|
||||
return (e, f)
|
||||
|
||||
-- Mirrors all of repVars in the given equation
|
||||
mirrorAllVars :: (A.Expression, [FlattenedExp]) -> (A.Expression, [FlattenedExp])
|
||||
mirrorAllVars (e, f) = (e, foldl mirror f repVars)
|
||||
where
|
||||
mirror :: [FlattenedExp] -> (A.Variable, A.Expression, A.Expression) -> [FlattenedExp]
|
||||
mirror exp (v,_,_) = setIndexVar v 1 exp
|
||||
|
||||
makeEquationsWR :: ([(A.Expression, [FlattenedExp])],[(A.Expression, [FlattenedExp])]) -> StateT (VarMap) (Either String) [ArrayAccess A.Expression]
|
||||
makeEquationsWR (w,r) = do w' <- mapM (\(e,f) -> makeEquation e AAWrite f) w
|
||||
r' <- mapM (\(e,f) -> makeEquation e AARead f) r
|
||||
return (w' ++ r')
|
||||
|
||||
|
||||
setIndexVar :: A.Variable -> Int -> [FlattenedExp] -> [FlattenedExp]
|
||||
setIndexVar tv ti es = case mapAccumL (setIndexVar' tv ti) False es of
|
||||
(_, es') -> es'
|
||||
|
@ -306,13 +286,11 @@ makeReplicatedEquations repVars accesses bound
|
|||
addPossibleRepBound' :: ArrayAccess label ->
|
||||
(A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) ->
|
||||
StateT (VarMap) (Either String) (ArrayAccess label)
|
||||
-- addPossibleRepBound' (Group accesses) v = mapM (seqPair . transformPair return (flip addPossibleRepBound v)) accesses >>* Group
|
||||
addPossibleRepBound' (Group accesses) v = sequence [addPossibleRepBound acc v >>* (\acc' -> (l,t,acc')) | (l,t,acc) <- accesses ] >>* Group
|
||||
addPossibleRepBound' (Replicated acc0 acc1) v
|
||||
= do acc0' <- mapM (flip addPossibleRepBound' v) acc0
|
||||
acc1' <- mapM (flip addPossibleRepBound' v) acc1
|
||||
return $ Replicated acc0' acc1'
|
||||
-- addPossibleRepBound' (Single (l,t,acc)) v = addPossibleRepBound acc v >>* (\x -> Group [(l,t,x)])
|
||||
|
||||
addPossibleRepBound :: (EqualityConstraintEquation, EqualityProblem, InequalityProblem) ->
|
||||
(A.Variable, Int, EqualityConstraintEquation, EqualityConstraintEquation) ->
|
||||
|
@ -331,7 +309,40 @@ makeReplicatedEquations repVars accesses bound
|
|||
where
|
||||
newMin = minimum [fst $ bounds a, ind]
|
||||
newMax = maximum [snd $ bounds a, ind]
|
||||
|
||||
|
||||
mkEq :: [(A.Replicator, Bool)] -> ([A.Expression], [A.Expression]) -> StateT [(CoeffIndex, CoeffIndex)] (StateT VarMap (Either String)) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
|
||||
mkEq reps (ws,rs) = do repVarEqs <- mapM (liftF makeRepVarEq) reps
|
||||
concatMapM (mkEq' repVarEqs) (ws' ++ rs')
|
||||
where
|
||||
ws' = zip (repeat AAWrite) ws
|
||||
rs' = zip (repeat AARead) rs
|
||||
|
||||
makeRepVarEq :: (A.Replicator, Bool) -> StateT VarMap (Either String) (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)
|
||||
makeRepVarEq (A.For m varName from for, _)
|
||||
= do from' <- lift (flatten from) >>= makeEquation from (error "Type is irrelevant for replication start")
|
||||
>>= getSingleAccessItem "Modulo or Divide not allowed in replication start"
|
||||
for' <- lift (flatten for) >>= makeEquation for (error "Type is irrelevant for replication count")
|
||||
>>= getSingleAccessItem "Modulo or Divide not allowed in replication count"
|
||||
return (A.Variable m varName, from', for')
|
||||
|
||||
mkEq' :: [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] -> (ArrayAccessType, A.Expression) -> StateT [(CoeffIndex,CoeffIndex)] (StateT VarMap (Either String)) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
|
||||
mkEq' repVarEqs (aat, e)
|
||||
= do f <- lift . lift $ flatten e
|
||||
f' <- foldM mirrorFlaggedVars f reps
|
||||
g <- lift $ makeEquationWithPossibleRepBounds repVarEqs =<< makeEquation e aat f'
|
||||
case g of
|
||||
Group g' -> return g'
|
||||
_ -> throwError "Replicated group found unexpectedly"
|
||||
|
||||
mirrorFlaggedVars :: [FlattenedExp] -> (A.Replicator,Bool) -> StateT [(CoeffIndex,CoeffIndex)] (StateT VarMap (Either String)) [FlattenedExp]
|
||||
mirrorFlaggedVars exp (_,False) = return exp
|
||||
mirrorFlaggedVars exp (A.For m varName from for, True)
|
||||
= do varIndexes <- lift $ seqPair (varIndex (Scale 1 (var,0)), varIndex (Scale 1 (var,1)))
|
||||
modify (varIndexes :)
|
||||
return $ setIndexVar var 1 exp
|
||||
where
|
||||
var = A.Variable m varName
|
||||
|
||||
-- Note that in all these functions, the divisor should always be positive!
|
||||
|
||||
-- Takes an expression, and transforms it into an expression like:
|
||||
|
@ -462,38 +473,6 @@ getSingleItem :: MonadTrans m => String -> [(a,b,c)] -> m (Either String) a
|
|||
getSingleItem _ [(item,_,_)] = lift $ return item
|
||||
getSingleItem err _ = lift $ throwError err
|
||||
|
||||
|
||||
-- | Given a list of (written,read) expressions, an expression representing the upper array bound, returns either an error
|
||||
-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
|
||||
-- (unique, munged) variable name to variable-index in the equations.
|
||||
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
|
||||
-- TODO allow "background knowledge" in the form of other equalities and inequalities
|
||||
makeEquations :: ParItems ([A.Expression],[A.Expression]) -> A.Expression -> Either String [((A.Expression, A.Expression), VarMap, (EqualityProblem, InequalityProblem))]
|
||||
makeEquations es high = makeEquations' >>* uncurry3 (squareAndPair [])
|
||||
where
|
||||
|
||||
-- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair
|
||||
-- representing the upper and lower bounds of the array (inclusive).
|
||||
makeEquations' :: Either String (VarMap, [ArrayAccess A.Expression], (EqualityConstraintEquation, EqualityConstraintEquation))
|
||||
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
|
||||
do eqs <- parItemToArrayAccessM mkEq es
|
||||
high' <- (lift $ flatten high) >>= makeEquation high (error "Type irrelevant for upper bound")
|
||||
>>= getSingleAccessItem "Multiple possible upper bounds not supported"
|
||||
return (eqs, high')
|
||||
return (s,v,(amap (const 0) h, addConstant (-1) h))
|
||||
|
||||
mkEq :: [A.Replicator] -> ([A.Expression],[A.Expression]) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
|
||||
mkEq [] (ws,rs) = concatMapM mkEq' (ws' ++ rs')
|
||||
where
|
||||
ws' = zip (repeat AAWrite) ws
|
||||
rs' = zip (repeat AARead) rs
|
||||
|
||||
mkEq' :: (ArrayAccessType, A.Expression) -> StateT VarMap (Either String) [(A.Expression, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
|
||||
mkEq' (aat, e) = do g <- makeEquation e aat =<< lift (flatten e)
|
||||
case g of
|
||||
Group g' -> return g'
|
||||
_ -> throwError "Replicated group found unexpectedly"
|
||||
|
||||
-- | Finds the index associated with a particular variable; either by finding an existing index
|
||||
-- or allocating a new one.
|
||||
varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int
|
||||
|
|
|
@ -32,6 +32,7 @@ import Test.QuickCheck hiding (check)
|
|||
|
||||
import ArrayUsageCheck
|
||||
import qualified AST as A
|
||||
import Metadata
|
||||
import Omega
|
||||
import TestHarness
|
||||
import TestUtils hiding (m)
|
||||
|
@ -336,7 +337,7 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
|
|||
,testRep' (200,[((0,0),rep_i_mapping, [i === j],
|
||||
ij_16 &&& [i <== j ++ con (-1)]
|
||||
&&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7])],
|
||||
[(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i"],intLiteral 8)
|
||||
("i", intLiteral 1, intLiteral 6),[exprVariable "i"],intLiteral 8)
|
||||
|
||||
,testRep' (201,
|
||||
[((0,0),rep_i_mapping, [i === j],
|
||||
|
@ -344,7 +345,7 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
|
|||
&&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7])]
|
||||
++ replicate 2 ((0,1),rep_i_mapping,[i === con 3], leq [con 1,i, con 6] &&& leq [con 0, i, con 7] &&& leq [con 0, con 3, con 7])
|
||||
++ [((1,1),rep_i_mapping,[con 3 === con 3],concat $ replicate 2 (leq [con 0, con 3, con 7]))]
|
||||
,[(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i", intLiteral 3],intLiteral 8)
|
||||
,("i", intLiteral 1, intLiteral 6),[exprVariable "i", intLiteral 3],intLiteral 8)
|
||||
|
||||
,testRep' (202,[
|
||||
((0,1),rep_i_mapping,[i === j ++ con 1],ij_16 &&& [i <== j ++ con (-1)] &&& leq [con 0, i, con 7] &&& leq [con 0, j ++ con 1, con 7])
|
||||
|
@ -353,12 +354,13 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
|
|||
,((1,1),rep_i_mapping,[i === j],ij_16 &&& [i <== j ++ con (-1)] &&& leq [con 0, i ++ con 1, con 7] &&& leq [con 0, j ++ con 1, con 7])]
|
||||
++ [((0,1),rep_i_mapping, [i === i ++ con 1], leq [con 1, i, con 6] &&& leq [con 1, i, con 6] &&& -- deliberate repeat
|
||||
leq [con 0, i, con 7] &&& leq [con 0,i ++ con 1, con 7])]
|
||||
,[(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i", buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 1)],intLiteral 8)
|
||||
,("i", intLiteral 1, intLiteral 6),[exprVariable "i", buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 1)],intLiteral 8)
|
||||
|
||||
-- Only a constant:
|
||||
,testRep' (210,[((0,0),rep_i_mapping,[con 4 === con 4],concat $ replicate 2 $ leq [con 0, con 4, con 7])]
|
||||
,[(variable "i", intLiteral 1, intLiteral 6)],[intLiteral 4],intLiteral 8)
|
||||
,("i", intLiteral 1, intLiteral 6),[intLiteral 4],intLiteral 8)
|
||||
|
||||
-- TODO test reads and writes are paired properly
|
||||
]
|
||||
where
|
||||
-- These functions assume that you pair each list [x,y,z] as (x,y) (x,z) (y,z) in that order.
|
||||
|
@ -381,11 +383,11 @@ testMakeEquations = TestLabel "testMakeEquations" $ TestList
|
|||
TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs)
|
||||
(map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations (makeParItems exprs) upperBound)
|
||||
|
||||
testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],[(A.Variable, A.Expression, A.Expression)],[A.Expression],A.Expression) -> Test
|
||||
testRep' (ind, problems, reps, exprs, upperBound) =
|
||||
testRep' :: (Integer,[((Int, Int), VarMap,[HandyEq],[HandyIneq])],(String, A.Expression, A.Expression),[A.Expression],A.Expression) -> Test
|
||||
testRep' (ind, problems, (repName, repFrom, repFor), exprs, upperBound) =
|
||||
TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (zip [0..] exprs)
|
||||
(map (transformTriple (applyPair (exprs !!)) id (uncurry makeConsistent)) $ map pairLatterTwo problems)
|
||||
=<< (checkRight $ makeReplicatedEquations reps (exprs,[]) upperBound)
|
||||
=<< (checkRight $ makeEquations (RepParItem (A.For emptyMeta (simpleName repName) repFrom repFor) $ makeParItems exprs) upperBound)
|
||||
|
||||
pairLatterTwo (l,a,b,c) = (l,a,(b,c))
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along
|
|||
with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
-}
|
||||
|
||||
module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), transformParItems, Var(..), Vars(..)) where
|
||||
module UsageCheck (checkPar, customVarCompare, Decl, labelFunctions, ParItems(..), transformParItems, Var(..), Vars(..), vars) where
|
||||
|
||||
import Data.Generics
|
||||
import Data.Graph.Inductive
|
||||
|
|
Loading…
Reference in New Issue
Block a user