diff --git a/transformations/ArrayUsageCheck.hs b/transformations/ArrayUsageCheck.hs index efa135b..a302b12 100644 --- a/transformations/ArrayUsageCheck.hs +++ b/transformations/ArrayUsageCheck.hs @@ -21,10 +21,11 @@ module ArrayUsageCheck where import Control.Monad.Error import Control.Monad.State import Data.Array.IArray -import Data.Generics +import Data.Generics hiding (GT) import Data.List import qualified Data.Map as Map import Data.Maybe +import qualified Data.Set as Set import qualified AST as A import CompState @@ -56,46 +57,129 @@ checkArrayUsage tree = (mapM_ checkPar $ listify (const True) tree) >> return tr = -- liftIO (putStr $ "Checking: " ++ show (arrName, indexes)) >> case makeEquations indexes (makeConstant emptyMeta 1000000) of Left err -> die $ "Could not work with array indexes for array \"" ++ arrName ++ "\": " ++ err - Right (varMapping, problem) -> - case uncurry solveProblem problem of + Right problems -> + case mapM (\(vm,p) -> seqPair (return vm,uncurry solveProblem p)) problems of -- No solutions; no worries! Nothing -> return () - Just vm -> do sol <- formatSolution varMapping (getCounterEqs vm) - arrName' <- getRealName arrName - dieP m $ "Overlapping indexes of array \"" ++ arrName' ++ "\" when: " ++ sol + Just ((varMapping,vm):_) -> do sol <- formatSolution varMapping (getCounterEqs vm) + arrName' <- getRealName (A.Name undefined undefined arrName) + dieP m $ "Overlapping indexes of array \"" ++ arrName' ++ "\" when: " ++ sol - formatSolution :: Map.Map String CoeffIndex -> Map.Map CoeffIndex Integer -> PassM String + formatSolution :: VarMap -> Map.Map CoeffIndex Integer -> PassM String formatSolution varToIndex indexToConst = do names <- mapM valOfVar $ Map.assocs varToIndex return $ concat $ intersperse " , " $ catMaybes names where - valOfVar (varName,k) = case Map.lookup k indexToConst of + valOfVar (varExp,k) = case Map.lookup k indexToConst of Nothing -> return Nothing - Just val -> do varName' <- getRealName varName - return $ Just $ varName' ++ " = " ++ show val + Just val -> do varExp' <- showFlattenedExp varExp + return $ Just $ varExp' ++ " = " ++ show val -- TODO this is surely defined elsewhere already? - getRealName :: String -> PassM String - getRealName s = lookupName (A.Name undefined undefined s) >>* A.ndOrigName - + getRealName :: A.Name -> PassM String + getRealName n = lookupName n >>* A.ndOrigName + + showFlattenedExp :: FlattenedExp -> PassM String + showFlattenedExp (Const n) = return $ show n + showFlattenedExp (Scale n (A.Variable _ vn)) = do vn' <- getRealName vn + return $ (show n) ++ "*" ++ vn' + showFlattenedExp (Modulo top bottom) + = do top' <- showFlattenedExpSet top + bottom' <- showFlattenedExpSet bottom + return $ "(" ++ top' ++ " REM " ++ bottom' ++ ")" + showFlattenedExp (Divide top bottom) + = do top' <- showFlattenedExpSet top + bottom' <- showFlattenedExpSet bottom + return $ "(" ++ top' ++ " / " ++ bottom' ++ ")" + + showFlattenedExpSet :: Set.Set FlattenedExp -> PassM String + showFlattenedExpSet s = liftM concat $ sequence $ intersperse (return " + ") $ map showFlattenedExp $ Set.toList s -- | A type for inside makeEquations: -data FlattenedExp = Const Integer | Scale Integer A.Variable deriving (Eq,Show) +data FlattenedExp + = Const Integer + | Scale Integer A.Variable + | Modulo (Set.Set FlattenedExp) (Set.Set FlattenedExp) + | Divide (Set.Set FlattenedExp) (Set.Set FlattenedExp) + +instance Eq FlattenedExp where + a == b = EQ == compare a b + +instance Ord FlattenedExp where + compare (Const _) (Const _) = EQ + compare (Const _) _ = LT + compare _ (Const _) = GT + compare (Scale _ lv) (Scale _ rv) = customVarCompare lv rv + compare (Scale {}) _ = LT + compare _ (Scale {}) = GT + compare (Modulo ltop lbottom) (Modulo rtop rbottom) + = combineCompare [compare ltop lbottom, compare lbottom rbottom] + compare (Modulo {}) _ = LT + compare _ (Modulo {}) = GT + compare (Divide ltop lbottom) (Divide rtop rbottom) + = combineCompare [compare ltop lbottom, compare lbottom rbottom] + +customVarCompare :: A.Variable -> A.Variable -> Ordering +customVarCompare (A.Variable _ (A.Name _ _ lname)) (A.Variable _ (A.Name _ _ rname)) = compare lname rname +-- TODO the rest + +makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp) +makeExpSet = foldM makeExpSet' Set.empty + where + makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> Either String (Set.Set FlattenedExp) + makeExpSet' accum (Const n) = return $ insert (addConst n) (Const n) accum + makeExpSet' accum (Scale n v) = return $ insert (addScale n v) (Scale n v) accum + makeExpSet' accum m@(Modulo {}) | Set.member m accum = throwError "Cannot have repeated REM items in an expression" + | otherwise = return $ Set.insert m accum + makeExpSet' accum d@(Divide {}) | Set.member d accum = throwError "Cannot have repeated (/) items in an expression" + | otherwise = return $ Set.insert d accum + + insert :: (FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)) -> FlattenedExp -> Set.Set FlattenedExp -> Set.Set FlattenedExp + insert f e s = case Set.fold insert' (Set.empty,False) s of + (s',True) -> s' + _ -> Set.insert e s + where + insert' :: FlattenedExp -> (Set.Set FlattenedExp, Bool) -> (Set.Set FlattenedExp, Bool) + insert' e (s,b) = case f e s of + Just s' -> (s', True) + Nothing -> (Set.insert e s, False) + + addConst :: Integer -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp) + addConst x (Const n) s = Just $ Set.insert (Const (n + x)) s + addConst _ _ _ = Nothing + + addScale :: Integer -> A.Variable -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp) + addScale x lv (Scale n rv) s | EQ == customVarCompare lv rv = Just $ Set.insert (Scale (x + n) rv) s + | otherwise = Nothing + addScale _ _ _ _ = Nothing + +type VarMap = Map.Map FlattenedExp Int -- | Given a list of expressions, an expression representing the upper array bound, returns either an error -- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from -- (unique, munged) variable name to variable-index in the equations. -- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message -makeEquations :: [A.Expression] -> A.Expression -> Either String (Map.Map String Int, (EqualityProblem, InequalityProblem)) -makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pairEqs v, getIneqs lh v))) +makeEquations :: [A.Expression] -> A.Expression -> Either String [(VarMap, (EqualityProblem, InequalityProblem))] +makeEquations es high = makeEquations' >>* (\(s,v,lh) -> [(s,squareEquations eqIneq) | eqIneq <- pairEqsAndBounds v lh]) where + {- + makeProblem :: Map.Map String Int + -> [(EqualityConstraintEquation,EqualityProblem, InequalityProblem)] + -> (EqualityConstraintEquation,EqualityConstraintEquation) + -> [(Map.Map String Int, (EqualityProblem, InequalityProblem))] + makeProblem varMap problems lowHigh = [(varMap, (eq,ineq)) | (eq,ineq) <- pairEqsAndBounds problems lowHigh] + -} + -- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair -- representing the upper and lower bounds of the array (inclusive). - makeEquations' :: Either String (Map.Map String Int, [EqualityConstraintEquation], (EqualityConstraintEquation, EqualityConstraintEquation)) + makeEquations' :: Either String (VarMap, [(EqualityConstraintEquation,EqualityProblem,InequalityProblem)], (EqualityConstraintEquation, EqualityConstraintEquation)) makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $ do flattened <- lift (mapM flatten es) eqs <- mapM makeEquation flattened high' <- (lift $ flatten high) >>= makeEquation - return (eqs,high') + high'' <- case high' of + [(h,_,_)] -> return h + _ -> throwError "Multiple possible upper bounds not supported" + return (concat eqs,high'') return (s,v,(amap (const 0) h, h)) -- Note that in all these functions, the divisor should always be positive! @@ -110,11 +194,15 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai flatten (A.Dyadic m op lhs rhs) | op == A.Add = combine' (flatten lhs) (flatten rhs) | op == A.Subtr = combine' (flatten lhs) (liftM (scale (-1)) $ flatten rhs) | op == A.Mul = multiplyOut' (flatten lhs) (flatten rhs) - -- TODO Div (either constant on bottom, or common (variable) factor(s) with top) + | op == A.Rem = liftM2L Modulo (flatten lhs) (flatten rhs) + | op == A.Div = liftM2L Divide (flatten lhs) (flatten rhs) | otherwise = throwError ("Unhandleable operator found in expression: " ++ show op) flatten (A.ExprVariable _ v) = return [Scale 1 v] flatten other = throwError ("Unhandleable item found in expression: " ++ show other) +-- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c] + liftM2L f x y = liftM (:[]) $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet) + --TODO we need to handle lots more different expression types in future. multiplyOut' :: Either String [FlattenedExp] -> Either String [FlattenedExp] -> Either String [FlattenedExp] @@ -151,21 +239,35 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai -- | Finds the index associated with a particular variable; either by finding an existing index -- or allocating a new one. - varIndex :: A.Variable -> StateT (Map.Map String Int) (Either String) Int - varIndex (A.Variable _ (A.Name _ _ varName)) + varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int + varIndex (Scale _ var@(A.Variable _ (A.Name _ _ varName))) = do st <- get - let (st',ind) = case Map.lookup varName st of + let (st',ind) = case Map.lookup (Scale 1 var) st of Just val -> (st,val) Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in - (Map.insert varName newId st, newId) + (Map.insert (Scale 1 var) newId st, newId) put st' return ind + varIndex mod@(Modulo top bottom) + = do st <- get + let (st',ind) = case Map.lookup mod st of + Just val -> (st,val) + Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in + (Map.insert mod newId st, newId) + put st' + return ind -- | Pairs all possible combinations of the list of equations. - pairEqs :: [EqualityConstraintEquation] -> [EqualityConstraintEquation] - pairEqs = filter (any (/= 0) . elems) . map (uncurry pairEqs') . allPairs + pairEqsAndBounds :: [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)] -> (EqualityConstraintEquation, EqualityConstraintEquation) -> [(EqualityProblem, InequalityProblem)] + pairEqsAndBounds items bounds = (map (filterProblems . uncurry pairEqs') . allPairs) items where - pairEqs' ex ey = arrayZipWith' 0 (-) ex ey + pairEqs' :: (EqualityConstraintEquation, EqualityProblem, InequalityProblem) + -> (EqualityConstraintEquation, EqualityProblem, InequalityProblem) + -> (EqualityProblem, InequalityProblem) + pairEqs' (ex,eqX,ineqX) (ey,eqY,ineqY) = ([arrayZipWith' 0 (-) ex ey] ++ eqX ++ eqY, ineqX ++ ineqY ++ getIneqs bounds [ex,ey]) + + filterProblems :: (EqualityProblem, InequalityProblem) -> (EqualityProblem, InequalityProblem) + filterProblems = transformPair (filter (any (/= 0) . elems)) (filter (any (/= 0) . elems)) -- | Given a (low,high) bound (typically: array dimensions), and a list of equations ex, -- forms the possible inequalities: @@ -182,18 +284,73 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai addEq = arrayZipWith' 0 (+) - -- | Given ex, forms an equation (e) from the latter part, and returns it - makeEquation :: [FlattenedExp] -> StateT (Map.Map String Int) (Either String) EqualityConstraintEquation + -- | Given an expression, forms equations (and accompanying additional equation-sets) and returns it + makeEquation :: [FlattenedExp] -> StateT (VarMap) (Either String) [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)] makeEquation summedItems - = do eqs <- foldM makeEquation' Map.empty summedItems - return $ mapToArray eqs + = do eqs <- process summedItems + return $ map (transformTriple mapToArray (map mapToArray) (map mapToArray)) eqs where - makeEquation' :: Map.Map Int Integer -> FlattenedExp -> StateT (Map.Map String Int) (Either String) (Map.Map Int Integer) + process = foldM makeEquation' empty + + makeEquation' :: [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] -> FlattenedExp -> StateT (VarMap) (Either String) [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] makeEquation' m (Const n) = return $ add (0,n) m - makeEquation' m (Scale n v) = varIndex v >>* (\ind -> add (ind, n) m) + makeEquation' m sc@(Scale n v) = varIndex sc >>* (\ind -> add (ind, n) m) + makeEquation' m mod@(Modulo top bottom) + = do top' <- process $ Set.toList top + top'' <- case top' of + [(t,_,_)] -> return t + _ -> throwError "Modulo or divide not allowed in the numerator of Modulo" + bottom' <- process $ Set.toList bottom + topIndex <- varIndex mod + case onlyConst (Set.toList bottom) of + Just bottomConst -> + let add_x_plus_my = zipMap plus top'' . zipMap plus (Map.fromList [(topIndex,bottomConst)]) in + return $ + -- The zero option (x = 0, x REM y = 0): + ( map (transformTriple id (++ [top'']) id) m) + ++ + -- The top-is-positive option: + ( map (transformTriple add_x_plus_my id (++ + -- x >= 1 + [zipMap plus (Map.fromList [(0,-1)]) top'' + -- m <= 0 + ,Map.fromList [(topIndex,-1)] + -- x + my + 1 - |y| <= 0 + ,Map.map negate $ add_x_plus_my $ Map.fromList [(0,1 - bottomConst)] + -- x + my >= 0 + ,add_x_plus_my $ Map.empty]) + ) m) ++ + -- The top-is-negative option: + ( map (transformTriple add_x_plus_my id (++ + -- x <= -1 + [add' (0,-1) $ Map.map negate top'' + -- m >= 0 + ,Map.fromList [(topIndex,1)] + -- x + my - 1 + |y| >= 0 + ,add_x_plus_my $ Map.fromList [(0,bottomConst - 1)] + -- x + my <= 0 + ,Map.map negate $ add_x_plus_my Map.empty]) + ) m) + _ -> throwError "TODO - Variable divisor for modulo" + + makeEquation' m (Divide top bottom) = throwError "TODO Divide" - add :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer - add = uncurry (Map.insertWith (+)) + onlyConst :: [FlattenedExp] -> Maybe Integer + onlyConst [] = Just 0 + onlyConst ((Const n):es) = liftM2 (+) (return n) $ onlyConst es + onlyConst _ = Nothing + + empty :: [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] + empty = [(Map.empty,[],[])] + + plus :: Num n => Maybe n -> Maybe n -> Maybe n + plus x y = Just $ (fromMaybe 0 x) + (fromMaybe 0 y) + + add' :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer + add' (m,n) = Map.insertWith (+) m n + + add :: (Int,Integer) -> [(Map.Map Int Integer,a,b)] -> [(Map.Map Int Integer,a,b)] + add (m,n) = map $ transformTriple (Map.insertWith (+) m n) id id -- | Converts a map to an array. Any missing elements in the middle of the bounds are given the value zero. -- Could probably be moved to Utils diff --git a/transformations/ArrayUsageCheckTest.hs b/transformations/ArrayUsageCheckTest.hs index f2765df..4fb5cf2 100644 --- a/transformations/ArrayUsageCheckTest.hs +++ b/transformations/ArrayUsageCheckTest.hs @@ -23,6 +23,7 @@ import Data.Array.IArray import Data.List import qualified Data.Map as Map import Data.Maybe +import qualified Data.Set as Set import Prelude hiding ((**),fail) import Test.HUnit import Test.QuickCheck hiding (check) @@ -232,29 +233,35 @@ showMaybe _ Nothing = "Nothing" testMakeEquations :: Test testMakeEquations = TestList [ - test (0,Map.empty,[con 0 === con 1],leq [con 0,con 0,con 7] &&& leq [con 0,con 1,con 7], - [intLiteral 0, intLiteral 1],intLiteral 7) + test (0,[(Map.empty,[con 0 === con 1],leq [con 0,con 1,con 7] &&& leq [con 0,con 2,con 7])], + [intLiteral 1, intLiteral 2],intLiteral 7) - ,test (1,i_mapping,[i === con 3],leq [con 0,con 3,con 7] &&& leq [con 0,i,con 7], + ,test (1,[(i_mapping,[i === con 3],leq [con 0,con 3,con 7] &&& leq [con 0,i,con 7])], [exprVariable "i",intLiteral 3],intLiteral 7) - ,test (2,ij_mapping,[i === j],leq [con 0,i,con 7] &&& leq [con 0,j,con 7], + ,test (2,[(ij_mapping,[i === j],leq [con 0,i,con 7] &&& leq [con 0,j,con 7])], [exprVariable "i",exprVariable "j"],intLiteral 7) - ,test (3,ij_mapping,[i ++ con 3 === j],leq [con 0,i ++ con 3,con 7] &&& leq [con 0,j,con 7], + ,test (3,[(ij_mapping,[i ++ con 3 === j],leq [con 0,i ++ con 3,con 7] &&& leq [con 0,j,con 7])], [buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 3),exprVariable "j"],intLiteral 7) - ,test (4,ij_mapping,[2 ** i === j],leq [con 0,2 ** i,con 7] &&& leq [con 0,j,con 7], + ,test (4,[(ij_mapping,[2 ** i === j],leq [con 0,2 ** i,con 7] &&& leq [con 0,j,con 7])], [buildExpr $ Dy (Var "i") A.Mul (Lit $ intLiteral 2),exprVariable "j"],intLiteral 7) + + ,test (10,[(i_mod_mapping 3,[i ++ 3 ** j === con 4], leq [con 0,i ++ 3 ** j,con 7])], + [buildExpr $ Dy (Var "i") A.Rem (Lit $ intLiteral 3),intLiteral 4],intLiteral 7) ] where - test :: (Integer,Map.Map String CoeffIndex,[HandyEq],[HandyIneq],[A.Expression],A.Expression) -> Test - test (ind, mpping, eqs, ineqs, exprs, upperBound) = - TestCase $ assertEquivalentProblems ("testMakeEquations" ++ show ind) - (mpping,makeConsistent eqs ineqs) =<< (checkRight $ makeEquations exprs upperBound) + test :: (Integer,[(VarMap,[HandyEq],[HandyIneq])],[A.Expression],A.Expression) -> Test + test (ind, problems, exprs, upperBound) = + TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind) + (map (transformPair id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations exprs upperBound) - i_mapping = Map.singleton "i" 1 - ij_mapping = Map.fromList [("i",1),("j",2)] + pairLatterTwo (a,b,c) = (a,(b,c)) + + i_mapping = Map.singleton (Scale 1 $ variable "i") 1 + ij_mapping = Map.fromList [(Scale 1 $ variable "i",1),(Scale 1 $ variable "j",2)] + i_mod_mapping n = Map.fromList [(Scale 1 $ variable "i",1),(Modulo (Set.singleton $ Scale 1 $ variable "i") (Set.singleton $ Const n),2)] testIndexes :: Test testIndexes = TestList @@ -369,7 +376,7 @@ testIndexes = TestList -- | Given one mapping and a second mapping, gives a function that converts the indexes -- from one to the indexes of the next. If any of the keys in the map don't match -- (i.e. if (keys m0 /= keys m1)) Nothing will be returned -generateMapping :: Map.Map String CoeffIndex -> Map.Map String CoeffIndex -> Maybe [(CoeffIndex,CoeffIndex)] +generateMapping :: VarMap -> VarMap -> Maybe [(CoeffIndex,CoeffIndex)] generateMapping m0 m1 = if Map.keys m0 /= Map.keys m1 then Nothing else Just (Map.elems $ zipMap f m0 m1) where f (Just x) (Just y) = Just (x,y) @@ -388,12 +395,17 @@ translateEquations mp = seqPair . transformPair (mapM swapColumns) (mapM swapCol swapColumns' (x,v) = transformMaybe (\y -> (y,v)) $ transformMaybe fst $ find ((== x) . snd) mp -- | Asserts that the two problems are equivalent, once you take into account the potentially different variable mappings -assertEquivalentProblems :: String -> (Map.Map String CoeffIndex, (EqualityProblem, InequalityProblem)) -> (Map.Map String CoeffIndex, (EqualityProblem, InequalityProblem)) -> Assertion -assertEquivalentProblems title exp act = assertEqual title translatedExp (Just $ sortP $ snd act) +assertEquivalentProblems :: String -> [(VarMap, (EqualityProblem, InequalityProblem))] -> [(VarMap, (EqualityProblem, InequalityProblem))] -> Assertion +assertEquivalentProblems title exp act = mapM_ (uncurry $ assertEqual title) $ map (uncurry transform) $ zip exp act where - sortP (eq,ineq) = (sort $ map normaliseEquality eq, sort ineq) + transform :: (VarMap, (EqualityProblem, InequalityProblem)) -> (VarMap, (EqualityProblem, InequalityProblem)) -> + ( Maybe (EqualityProblem, InequalityProblem), Maybe (EqualityProblem, InequalityProblem) ) + transform exp act = (translatedExp, Just $ sortP $ snd act) + where + sortP :: (EqualityProblem, InequalityProblem) -> (EqualityProblem, InequalityProblem) + sortP (eq,ineq) = (sort $ map normaliseEquality eq, sort ineq) - translatedExp = ( generateMapping (fst exp) (fst act) >>= flip translateEquations (snd exp)) >>* sortP + translatedExp = ( generateMapping (fst exp) (fst act) >>= flip translateEquations (snd exp)) >>* sortP checkRight :: Show a => Either a b -> IO b checkRight (Left err) = assertFailure ("Not Right: " ++ show err) >> return undefined