Changed the array usage checking and all the tests to support modulo and division

This patch is unavoidably large (no easy way to split it down).  The code compiles, but the modulo test (which is currently wrong anyway) fails at the moment
This commit is contained in:
Neil Brown 2008-01-15 17:08:15 +00:00
parent 8cfa9e3cb0
commit 918b9e7af7
2 changed files with 220 additions and 51 deletions

View File

@ -21,10 +21,11 @@ module ArrayUsageCheck where
import Control.Monad.Error
import Control.Monad.State
import Data.Array.IArray
import Data.Generics
import Data.Generics hiding (GT)
import Data.List
import qualified Data.Map as Map
import Data.Maybe
import qualified Data.Set as Set
import qualified AST as A
import CompState
@ -56,46 +57,129 @@ checkArrayUsage tree = (mapM_ checkPar $ listify (const True) tree) >> return tr
= -- liftIO (putStr $ "Checking: " ++ show (arrName, indexes)) >>
case makeEquations indexes (makeConstant emptyMeta 1000000) of
Left err -> die $ "Could not work with array indexes for array \"" ++ arrName ++ "\": " ++ err
Right (varMapping, problem) ->
case uncurry solveProblem problem of
Right problems ->
case mapM (\(vm,p) -> seqPair (return vm,uncurry solveProblem p)) problems of
-- No solutions; no worries!
Nothing -> return ()
Just vm -> do sol <- formatSolution varMapping (getCounterEqs vm)
arrName' <- getRealName arrName
dieP m $ "Overlapping indexes of array \"" ++ arrName' ++ "\" when: " ++ sol
Just ((varMapping,vm):_) -> do sol <- formatSolution varMapping (getCounterEqs vm)
arrName' <- getRealName (A.Name undefined undefined arrName)
dieP m $ "Overlapping indexes of array \"" ++ arrName' ++ "\" when: " ++ sol
formatSolution :: Map.Map String CoeffIndex -> Map.Map CoeffIndex Integer -> PassM String
formatSolution :: VarMap -> Map.Map CoeffIndex Integer -> PassM String
formatSolution varToIndex indexToConst = do names <- mapM valOfVar $ Map.assocs varToIndex
return $ concat $ intersperse " , " $ catMaybes names
where
valOfVar (varName,k) = case Map.lookup k indexToConst of
valOfVar (varExp,k) = case Map.lookup k indexToConst of
Nothing -> return Nothing
Just val -> do varName' <- getRealName varName
return $ Just $ varName' ++ " = " ++ show val
Just val -> do varExp' <- showFlattenedExp varExp
return $ Just $ varExp' ++ " = " ++ show val
-- TODO this is surely defined elsewhere already?
getRealName :: String -> PassM String
getRealName s = lookupName (A.Name undefined undefined s) >>* A.ndOrigName
getRealName :: A.Name -> PassM String
getRealName n = lookupName n >>* A.ndOrigName
showFlattenedExp :: FlattenedExp -> PassM String
showFlattenedExp (Const n) = return $ show n
showFlattenedExp (Scale n (A.Variable _ vn)) = do vn' <- getRealName vn
return $ (show n) ++ "*" ++ vn'
showFlattenedExp (Modulo top bottom)
= do top' <- showFlattenedExpSet top
bottom' <- showFlattenedExpSet bottom
return $ "(" ++ top' ++ " REM " ++ bottom' ++ ")"
showFlattenedExp (Divide top bottom)
= do top' <- showFlattenedExpSet top
bottom' <- showFlattenedExpSet bottom
return $ "(" ++ top' ++ " / " ++ bottom' ++ ")"
showFlattenedExpSet :: Set.Set FlattenedExp -> PassM String
showFlattenedExpSet s = liftM concat $ sequence $ intersperse (return " + ") $ map showFlattenedExp $ Set.toList s
-- | A type for inside makeEquations:
data FlattenedExp = Const Integer | Scale Integer A.Variable deriving (Eq,Show)
data FlattenedExp
= Const Integer
| Scale Integer A.Variable
| Modulo (Set.Set FlattenedExp) (Set.Set FlattenedExp)
| Divide (Set.Set FlattenedExp) (Set.Set FlattenedExp)
instance Eq FlattenedExp where
a == b = EQ == compare a b
instance Ord FlattenedExp where
compare (Const _) (Const _) = EQ
compare (Const _) _ = LT
compare _ (Const _) = GT
compare (Scale _ lv) (Scale _ rv) = customVarCompare lv rv
compare (Scale {}) _ = LT
compare _ (Scale {}) = GT
compare (Modulo ltop lbottom) (Modulo rtop rbottom)
= combineCompare [compare ltop lbottom, compare lbottom rbottom]
compare (Modulo {}) _ = LT
compare _ (Modulo {}) = GT
compare (Divide ltop lbottom) (Divide rtop rbottom)
= combineCompare [compare ltop lbottom, compare lbottom rbottom]
customVarCompare :: A.Variable -> A.Variable -> Ordering
customVarCompare (A.Variable _ (A.Name _ _ lname)) (A.Variable _ (A.Name _ _ rname)) = compare lname rname
-- TODO the rest
makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp)
makeExpSet = foldM makeExpSet' Set.empty
where
makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> Either String (Set.Set FlattenedExp)
makeExpSet' accum (Const n) = return $ insert (addConst n) (Const n) accum
makeExpSet' accum (Scale n v) = return $ insert (addScale n v) (Scale n v) accum
makeExpSet' accum m@(Modulo {}) | Set.member m accum = throwError "Cannot have repeated REM items in an expression"
| otherwise = return $ Set.insert m accum
makeExpSet' accum d@(Divide {}) | Set.member d accum = throwError "Cannot have repeated (/) items in an expression"
| otherwise = return $ Set.insert d accum
insert :: (FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)) -> FlattenedExp -> Set.Set FlattenedExp -> Set.Set FlattenedExp
insert f e s = case Set.fold insert' (Set.empty,False) s of
(s',True) -> s'
_ -> Set.insert e s
where
insert' :: FlattenedExp -> (Set.Set FlattenedExp, Bool) -> (Set.Set FlattenedExp, Bool)
insert' e (s,b) = case f e s of
Just s' -> (s', True)
Nothing -> (Set.insert e s, False)
addConst :: Integer -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)
addConst x (Const n) s = Just $ Set.insert (Const (n + x)) s
addConst _ _ _ = Nothing
addScale :: Integer -> A.Variable -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)
addScale x lv (Scale n rv) s | EQ == customVarCompare lv rv = Just $ Set.insert (Scale (x + n) rv) s
| otherwise = Nothing
addScale _ _ _ _ = Nothing
type VarMap = Map.Map FlattenedExp Int
-- | Given a list of expressions, an expression representing the upper array bound, returns either an error
-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
-- (unique, munged) variable name to variable-index in the equations.
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
makeEquations :: [A.Expression] -> A.Expression -> Either String (Map.Map String Int, (EqualityProblem, InequalityProblem))
makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pairEqs v, getIneqs lh v)))
makeEquations :: [A.Expression] -> A.Expression -> Either String [(VarMap, (EqualityProblem, InequalityProblem))]
makeEquations es high = makeEquations' >>* (\(s,v,lh) -> [(s,squareEquations eqIneq) | eqIneq <- pairEqsAndBounds v lh])
where
{-
makeProblem :: Map.Map String Int
-> [(EqualityConstraintEquation,EqualityProblem, InequalityProblem)]
-> (EqualityConstraintEquation,EqualityConstraintEquation)
-> [(Map.Map String Int, (EqualityProblem, InequalityProblem))]
makeProblem varMap problems lowHigh = [(varMap, (eq,ineq)) | (eq,ineq) <- pairEqsAndBounds problems lowHigh]
-}
-- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair
-- representing the upper and lower bounds of the array (inclusive).
makeEquations' :: Either String (Map.Map String Int, [EqualityConstraintEquation], (EqualityConstraintEquation, EqualityConstraintEquation))
makeEquations' :: Either String (VarMap, [(EqualityConstraintEquation,EqualityProblem,InequalityProblem)], (EqualityConstraintEquation, EqualityConstraintEquation))
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
do flattened <- lift (mapM flatten es)
eqs <- mapM makeEquation flattened
high' <- (lift $ flatten high) >>= makeEquation
return (eqs,high')
high'' <- case high' of
[(h,_,_)] -> return h
_ -> throwError "Multiple possible upper bounds not supported"
return (concat eqs,high'')
return (s,v,(amap (const 0) h, h))
-- Note that in all these functions, the divisor should always be positive!
@ -110,11 +194,15 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai
flatten (A.Dyadic m op lhs rhs) | op == A.Add = combine' (flatten lhs) (flatten rhs)
| op == A.Subtr = combine' (flatten lhs) (liftM (scale (-1)) $ flatten rhs)
| op == A.Mul = multiplyOut' (flatten lhs) (flatten rhs)
-- TODO Div (either constant on bottom, or common (variable) factor(s) with top)
| op == A.Rem = liftM2L Modulo (flatten lhs) (flatten rhs)
| op == A.Div = liftM2L Divide (flatten lhs) (flatten rhs)
| otherwise = throwError ("Unhandleable operator found in expression: " ++ show op)
flatten (A.ExprVariable _ v) = return [Scale 1 v]
flatten other = throwError ("Unhandleable item found in expression: " ++ show other)
-- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c]
liftM2L f x y = liftM (:[]) $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet)
--TODO we need to handle lots more different expression types in future.
multiplyOut' :: Either String [FlattenedExp] -> Either String [FlattenedExp] -> Either String [FlattenedExp]
@ -151,21 +239,35 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai
-- | Finds the index associated with a particular variable; either by finding an existing index
-- or allocating a new one.
varIndex :: A.Variable -> StateT (Map.Map String Int) (Either String) Int
varIndex (A.Variable _ (A.Name _ _ varName))
varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int
varIndex (Scale _ var@(A.Variable _ (A.Name _ _ varName)))
= do st <- get
let (st',ind) = case Map.lookup varName st of
let (st',ind) = case Map.lookup (Scale 1 var) st of
Just val -> (st,val)
Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in
(Map.insert varName newId st, newId)
(Map.insert (Scale 1 var) newId st, newId)
put st'
return ind
varIndex mod@(Modulo top bottom)
= do st <- get
let (st',ind) = case Map.lookup mod st of
Just val -> (st,val)
Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in
(Map.insert mod newId st, newId)
put st'
return ind
-- | Pairs all possible combinations of the list of equations.
pairEqs :: [EqualityConstraintEquation] -> [EqualityConstraintEquation]
pairEqs = filter (any (/= 0) . elems) . map (uncurry pairEqs') . allPairs
pairEqsAndBounds :: [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)] -> (EqualityConstraintEquation, EqualityConstraintEquation) -> [(EqualityProblem, InequalityProblem)]
pairEqsAndBounds items bounds = (map (filterProblems . uncurry pairEqs') . allPairs) items
where
pairEqs' ex ey = arrayZipWith' 0 (-) ex ey
pairEqs' :: (EqualityConstraintEquation, EqualityProblem, InequalityProblem)
-> (EqualityConstraintEquation, EqualityProblem, InequalityProblem)
-> (EqualityProblem, InequalityProblem)
pairEqs' (ex,eqX,ineqX) (ey,eqY,ineqY) = ([arrayZipWith' 0 (-) ex ey] ++ eqX ++ eqY, ineqX ++ ineqY ++ getIneqs bounds [ex,ey])
filterProblems :: (EqualityProblem, InequalityProblem) -> (EqualityProblem, InequalityProblem)
filterProblems = transformPair (filter (any (/= 0) . elems)) (filter (any (/= 0) . elems))
-- | Given a (low,high) bound (typically: array dimensions), and a list of equations ex,
-- forms the possible inequalities:
@ -182,18 +284,73 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai
addEq = arrayZipWith' 0 (+)
-- | Given ex, forms an equation (e) from the latter part, and returns it
makeEquation :: [FlattenedExp] -> StateT (Map.Map String Int) (Either String) EqualityConstraintEquation
-- | Given an expression, forms equations (and accompanying additional equation-sets) and returns it
makeEquation :: [FlattenedExp] -> StateT (VarMap) (Either String) [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)]
makeEquation summedItems
= do eqs <- foldM makeEquation' Map.empty summedItems
return $ mapToArray eqs
= do eqs <- process summedItems
return $ map (transformTriple mapToArray (map mapToArray) (map mapToArray)) eqs
where
makeEquation' :: Map.Map Int Integer -> FlattenedExp -> StateT (Map.Map String Int) (Either String) (Map.Map Int Integer)
process = foldM makeEquation' empty
makeEquation' :: [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] -> FlattenedExp -> StateT (VarMap) (Either String) [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
makeEquation' m (Const n) = return $ add (0,n) m
makeEquation' m (Scale n v) = varIndex v >>* (\ind -> add (ind, n) m)
makeEquation' m sc@(Scale n v) = varIndex sc >>* (\ind -> add (ind, n) m)
makeEquation' m mod@(Modulo top bottom)
= do top' <- process $ Set.toList top
top'' <- case top' of
[(t,_,_)] -> return t
_ -> throwError "Modulo or divide not allowed in the numerator of Modulo"
bottom' <- process $ Set.toList bottom
topIndex <- varIndex mod
case onlyConst (Set.toList bottom) of
Just bottomConst ->
let add_x_plus_my = zipMap plus top'' . zipMap plus (Map.fromList [(topIndex,bottomConst)]) in
return $
-- The zero option (x = 0, x REM y = 0):
( map (transformTriple id (++ [top'']) id) m)
++
-- The top-is-positive option:
( map (transformTriple add_x_plus_my id (++
-- x >= 1
[zipMap plus (Map.fromList [(0,-1)]) top''
-- m <= 0
,Map.fromList [(topIndex,-1)]
-- x + my + 1 - |y| <= 0
,Map.map negate $ add_x_plus_my $ Map.fromList [(0,1 - bottomConst)]
-- x + my >= 0
,add_x_plus_my $ Map.empty])
) m) ++
-- The top-is-negative option:
( map (transformTriple add_x_plus_my id (++
-- x <= -1
[add' (0,-1) $ Map.map negate top''
-- m >= 0
,Map.fromList [(topIndex,1)]
-- x + my - 1 + |y| >= 0
,add_x_plus_my $ Map.fromList [(0,bottomConst - 1)]
-- x + my <= 0
,Map.map negate $ add_x_plus_my Map.empty])
) m)
_ -> throwError "TODO - Variable divisor for modulo"
makeEquation' m (Divide top bottom) = throwError "TODO Divide"
add :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer
add = uncurry (Map.insertWith (+))
onlyConst :: [FlattenedExp] -> Maybe Integer
onlyConst [] = Just 0
onlyConst ((Const n):es) = liftM2 (+) (return n) $ onlyConst es
onlyConst _ = Nothing
empty :: [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
empty = [(Map.empty,[],[])]
plus :: Num n => Maybe n -> Maybe n -> Maybe n
plus x y = Just $ (fromMaybe 0 x) + (fromMaybe 0 y)
add' :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer
add' (m,n) = Map.insertWith (+) m n
add :: (Int,Integer) -> [(Map.Map Int Integer,a,b)] -> [(Map.Map Int Integer,a,b)]
add (m,n) = map $ transformTriple (Map.insertWith (+) m n) id id
-- | Converts a map to an array. Any missing elements in the middle of the bounds are given the value zero.
-- Could probably be moved to Utils

View File

@ -23,6 +23,7 @@ import Data.Array.IArray
import Data.List
import qualified Data.Map as Map
import Data.Maybe
import qualified Data.Set as Set
import Prelude hiding ((**),fail)
import Test.HUnit
import Test.QuickCheck hiding (check)
@ -232,29 +233,35 @@ showMaybe _ Nothing = "Nothing"
testMakeEquations :: Test
testMakeEquations = TestList
[
test (0,Map.empty,[con 0 === con 1],leq [con 0,con 0,con 7] &&& leq [con 0,con 1,con 7],
[intLiteral 0, intLiteral 1],intLiteral 7)
test (0,[(Map.empty,[con 0 === con 1],leq [con 0,con 1,con 7] &&& leq [con 0,con 2,con 7])],
[intLiteral 1, intLiteral 2],intLiteral 7)
,test (1,i_mapping,[i === con 3],leq [con 0,con 3,con 7] &&& leq [con 0,i,con 7],
,test (1,[(i_mapping,[i === con 3],leq [con 0,con 3,con 7] &&& leq [con 0,i,con 7])],
[exprVariable "i",intLiteral 3],intLiteral 7)
,test (2,ij_mapping,[i === j],leq [con 0,i,con 7] &&& leq [con 0,j,con 7],
,test (2,[(ij_mapping,[i === j],leq [con 0,i,con 7] &&& leq [con 0,j,con 7])],
[exprVariable "i",exprVariable "j"],intLiteral 7)
,test (3,ij_mapping,[i ++ con 3 === j],leq [con 0,i ++ con 3,con 7] &&& leq [con 0,j,con 7],
,test (3,[(ij_mapping,[i ++ con 3 === j],leq [con 0,i ++ con 3,con 7] &&& leq [con 0,j,con 7])],
[buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 3),exprVariable "j"],intLiteral 7)
,test (4,ij_mapping,[2 ** i === j],leq [con 0,2 ** i,con 7] &&& leq [con 0,j,con 7],
,test (4,[(ij_mapping,[2 ** i === j],leq [con 0,2 ** i,con 7] &&& leq [con 0,j,con 7])],
[buildExpr $ Dy (Var "i") A.Mul (Lit $ intLiteral 2),exprVariable "j"],intLiteral 7)
,test (10,[(i_mod_mapping 3,[i ++ 3 ** j === con 4], leq [con 0,i ++ 3 ** j,con 7])],
[buildExpr $ Dy (Var "i") A.Rem (Lit $ intLiteral 3),intLiteral 4],intLiteral 7)
]
where
test :: (Integer,Map.Map String CoeffIndex,[HandyEq],[HandyIneq],[A.Expression],A.Expression) -> Test
test (ind, mpping, eqs, ineqs, exprs, upperBound) =
TestCase $ assertEquivalentProblems ("testMakeEquations" ++ show ind)
(mpping,makeConsistent eqs ineqs) =<< (checkRight $ makeEquations exprs upperBound)
test :: (Integer,[(VarMap,[HandyEq],[HandyIneq])],[A.Expression],A.Expression) -> Test
test (ind, problems, exprs, upperBound) =
TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind)
(map (transformPair id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations exprs upperBound)
i_mapping = Map.singleton "i" 1
ij_mapping = Map.fromList [("i",1),("j",2)]
pairLatterTwo (a,b,c) = (a,(b,c))
i_mapping = Map.singleton (Scale 1 $ variable "i") 1
ij_mapping = Map.fromList [(Scale 1 $ variable "i",1),(Scale 1 $ variable "j",2)]
i_mod_mapping n = Map.fromList [(Scale 1 $ variable "i",1),(Modulo (Set.singleton $ Scale 1 $ variable "i") (Set.singleton $ Const n),2)]
testIndexes :: Test
testIndexes = TestList
@ -369,7 +376,7 @@ testIndexes = TestList
-- | Given one mapping and a second mapping, gives a function that converts the indexes
-- from one to the indexes of the next. If any of the keys in the map don't match
-- (i.e. if (keys m0 /= keys m1)) Nothing will be returned
generateMapping :: Map.Map String CoeffIndex -> Map.Map String CoeffIndex -> Maybe [(CoeffIndex,CoeffIndex)]
generateMapping :: VarMap -> VarMap -> Maybe [(CoeffIndex,CoeffIndex)]
generateMapping m0 m1 = if Map.keys m0 /= Map.keys m1 then Nothing else Just (Map.elems $ zipMap f m0 m1)
where
f (Just x) (Just y) = Just (x,y)
@ -388,12 +395,17 @@ translateEquations mp = seqPair . transformPair (mapM swapColumns) (mapM swapCol
swapColumns' (x,v) = transformMaybe (\y -> (y,v)) $ transformMaybe fst $ find ((== x) . snd) mp
-- | Asserts that the two problems are equivalent, once you take into account the potentially different variable mappings
assertEquivalentProblems :: String -> (Map.Map String CoeffIndex, (EqualityProblem, InequalityProblem)) -> (Map.Map String CoeffIndex, (EqualityProblem, InequalityProblem)) -> Assertion
assertEquivalentProblems title exp act = assertEqual title translatedExp (Just $ sortP $ snd act)
assertEquivalentProblems :: String -> [(VarMap, (EqualityProblem, InequalityProblem))] -> [(VarMap, (EqualityProblem, InequalityProblem))] -> Assertion
assertEquivalentProblems title exp act = mapM_ (uncurry $ assertEqual title) $ map (uncurry transform) $ zip exp act
where
sortP (eq,ineq) = (sort $ map normaliseEquality eq, sort ineq)
transform :: (VarMap, (EqualityProblem, InequalityProblem)) -> (VarMap, (EqualityProblem, InequalityProblem)) ->
( Maybe (EqualityProblem, InequalityProblem), Maybe (EqualityProblem, InequalityProblem) )
transform exp act = (translatedExp, Just $ sortP $ snd act)
where
sortP :: (EqualityProblem, InequalityProblem) -> (EqualityProblem, InequalityProblem)
sortP (eq,ineq) = (sort $ map normaliseEquality eq, sort ineq)
translatedExp = ( generateMapping (fst exp) (fst act) >>= flip translateEquations (snd exp)) >>* sortP
translatedExp = ( generateMapping (fst exp) (fst act) >>= flip translateEquations (snd exp)) >>* sortP
checkRight :: Show a => Either a b -> IO b
checkRight (Left err) = assertFailure ("Not Right: " ++ show err) >> return undefined