diff --git a/transformations/ArrayUsageCheck.hs b/transformations/ArrayUsageCheck.hs index bcc64c9..ed226b1 100644 --- a/transformations/ArrayUsageCheck.hs +++ b/transformations/ArrayUsageCheck.hs @@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . -} -module ArrayUsageCheck (checkArrayUsage, FlattenedExp(..), makeEquations, VarMap) where +module ArrayUsageCheck (checkArrayUsage, FlattenedExp(..), makeEquations, makeReplicatedEquations, VarMap) where import Control.Monad.Error import Control.Monad.State @@ -175,7 +175,81 @@ makeExpSet = foldM makeExpSet' Set.empty type VarMap = Map.Map FlattenedExp Int +-- | Given a list of (replicated variable, start, count), a list of parallel array accesses, the length of the array, +-- returns the problems +makeReplicatedEquations :: [(A.Variable, A.Expression, A.Expression)] -> [A.Expression] -> A.Expression -> + Either String [(VarMap, (EqualityProblem, InequalityProblem))] +makeReplicatedEquations repVars accesses bound + = do flattenedAccesses <- mapM flatten accesses + let flattenedAccessesMirror = concatMap (\(v,_,_) -> mapMaybe (setIndexVar v 1) flattenedAccesses) repVars + -- TODO only compare with a mirror that involves the same replicated variable (TODO or not?) + bound' <- flatten bound + ((v,h,repVars'),s) <- (flip runStateT) Map.empty $ + do accesses' <- liftM2 (++) (mapM makeEquation flattenedAccesses) (mapM makeEquation flattenedAccessesMirror) + high <- makeEquation bound' >>= getSingleItem "Multiple possible upper bounds not supported" + repVars' <- mapM (\(v,s,c) -> + do s' <- lift (flatten s) >>= makeEquation >>= getSingleItem "Modulo or Divide not allowed in replication start" + c' <- lift (flatten c) >>= makeEquation >>= getSingleItem "Modulo or Divide not allowed in replication count" + return (v,s',c')) repVars + return (accesses',high, repVars') + repBounds <- makeRepBound repVars' s + return $ concatMap (\repBound -> squareAndPair repBound s v (amap (const 0) h, addConstant (-1) h)) repBounds + where + setIndexVar :: A.Variable -> Int -> [FlattenedExp] -> Maybe [FlattenedExp] + setIndexVar tv ti es = case mapAccumL (setIndexVar' tv ti) False es of + (True, es') -> Just es' + _ -> Nothing + + setIndexVar' :: A.Variable -> Int -> Bool -> FlattenedExp -> (Bool,FlattenedExp) + setIndexVar' tv ti b s@(Scale n (v,_)) + | EQ == customVarCompare tv v = (True,Scale n (v,ti)) + | otherwise = (b,s) + setIndexVar' _ _ b e = (b,e) + + makeRepBound :: + [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] -> + VarMap -> + Either String [InequalityProblem] + makeRepBound repVars vm = doPairs $ map (makeBound vm) repVars + where + doPairs :: Monad m => [m (InequalityProblem,InequalityProblem)] -> m [InequalityProblem] + doPairs prs = do prs' <- sequence prs + return $ doPairs' prs' + where + doPairs' :: [([a],[a])] -> [[a]] + doPairs' [] = [[]] + doPairs' ((a,b):abs) = map (a ++) (doPairs' abs) ++ map (b ++) (doPairs' abs) + + makeBound :: VarMap -> (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation) -> Either String (InequalityProblem,InequalityProblem) + makeBound vm (repVar, start, count) + = do plain <- findIndex repVar 0 + prime <- findIndex repVar 1 + return + ( + -- start <= i gives i - start >= 0 + [add (plain,1) (amap negate start) + -- i <= j - 1 gives j - 1 - i >= 0 + ,simpleArray [(0,-1),(prime,1),(plain,-1)] + -- j <= start + count - 1 gives start + count - j - 1 >= 0 + ,add (0,-1) $ add (prime, -1) $ arrayZipWith (+) start count] + , + -- start <= j gives j - start >= 0 + [add (prime,1) (amap negate start) + -- j <= i - 1 gives i - 1 - j >= 0 + ,simpleArray [(0,-1),(plain,1),(prime,-1)] + -- i <= start + count - 1 gives start + count - i - 1 >= 0 + ,add (0,-1) $ add (plain, -1) $ arrayZipWith (+) start count] + ) + where + findIndex v n = Map.lookup (Scale 1 (v,n)) vm + + add :: (Int,Integer) -> Array Int Integer -> Array Int Integer + add (ind,val) a = (makeSize (newMin, newMax) 0 a) // [(ind, (arrayLookupWithDefault 0 a ind) + val)] + where + newMin = minimum [fst $ bounds a, ind] + newMax = maximum [snd $ bounds a, ind] + -- Note that in all these functions, the divisor should always be positive! -- Takes an expression, and transforms it into an expression like: diff --git a/transformations/ArrayUsageCheckTest.hs b/transformations/ArrayUsageCheckTest.hs index ad0fe29..2ab01b9 100644 --- a/transformations/ArrayUsageCheckTest.hs +++ b/transformations/ArrayUsageCheckTest.hs @@ -330,6 +330,19 @@ testMakeEquations = TestList leq [j ++ con 1, i ++ k, con 0] &&& leq [con 0, i ++ k, con 7] &&& leq [con 0, con 3, con 7]) ], [buildExpr $ Dy (Var "i") A.Rem (Var "j"), intLiteral 3], intLiteral 8) + ,testRep (200,both_rep_i ([i === j],leq [con 1, i, j ++ con (-1), con 5] &&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7]), + [(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i"],intLiteral 8) + + ,testRep (201,both_rep_i ([i === j],leq [con 1, i, j ++ con (-1), con 5] &&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7]) + ++ [(rep_i_mapping,[i === con 3], leq [con 1,i, con 6] &&& leq [con 0, i, con 7] &&& leq [con 0, con 3, con 7])], + [(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i", intLiteral 3],intLiteral 8) + + ,testRep (202,[ + (rep_i_mapping,[i === j ++ con 1],leq [con 1, i, j ++ con (-1), con 5] &&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7]) + ,(rep_i_mapping,[i ++ con 1 === j],leq [con 1, i, j ++ con (-1), con 5] &&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7])] + ++ replicate 2 (rep_i_mapping,[i === j],leq [con 1, i, j ++ con (-1), con 5] &&& leq [con 0, i, con 7] &&& leq [con 0, j, con 7]) + ,[(variable "i", intLiteral 1, intLiteral 6)],[exprVariable "i", buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 1)],intLiteral 8) + ] where test :: (Integer,[(VarMap,[HandyEq],[HandyIneq])],[A.Expression],A.Expression) -> Test @@ -337,7 +350,16 @@ testMakeEquations = TestList TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) (map (transformPair id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations exprs upperBound) + testRep :: (Integer,[(VarMap,[HandyEq],[HandyIneq])],[(A.Variable, A.Expression, A.Expression)],[A.Expression],A.Expression) -> Test + testRep (ind, problems, reps, exprs, upperBound) = + TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) + (map (transformPair id (uncurry makeConsistent)) $ map pairLatterTwo problems) + =<< (checkRight $ makeReplicatedEquations reps exprs upperBound) + pairLatterTwo (a,b,c) = (a,(b,c)) + + joinMapping :: [VarMap] -> ([HandyEq],[HandyIneq]) -> [(VarMap,[HandyEq],[HandyIneq])] + joinMapping vms (eq,ineq) = map (\vm -> (vm,eq,ineq)) vms i_mapping = Map.singleton (Scale 1 $ (variable "i",0)) 1 ij_mapping = Map.fromList [(Scale 1 $ (variable "i",0),1),(Scale 1 $ (variable "j",0),2)] @@ -351,7 +373,12 @@ testMakeEquations = TestList ,(Modulo (Set.singleton $ Scale 1 $ (variable "i",0)) (Set.singleton $ Const m),2) ,(Modulo (Set.fromList [Scale 1 $ (variable "i",0), Const 1]) (Set.singleton $ Const n),3) ] - + + rep_i_mapping = Map.fromList [((Scale 1 (variable "i",0)),1), ((Scale 1 (variable "i",1)),2)] + rep_i_mapping' = Map.fromList [((Scale 1 (variable "i",0)),2), ((Scale 1 (variable "i",1)),1)] + + both_rep_i = joinMapping [rep_i_mapping, rep_i_mapping'] + -- Helper functions for i REM 2 vs (i + 1) REM 4. Each one is a pair of equalities, inequalities rr_i_zero = ([i === con 0], leq [con 0,con 0,con 7]) rr_ip1_zero = ([i ++ con 1 === con 0], leq [con 0,con 0,con 7])