Changed the array usage checking and all the tests to support modulo and division
This patch is unavoidably large (no easy way to split it down). The code compiles, but the modulo test (which is currently wrong anyway) fails at the moment
This commit is contained in:
parent
8cfa9e3cb0
commit
918b9e7af7
|
@ -21,10 +21,11 @@ module ArrayUsageCheck where
|
|||
import Control.Monad.Error
|
||||
import Control.Monad.State
|
||||
import Data.Array.IArray
|
||||
import Data.Generics
|
||||
import Data.Generics hiding (GT)
|
||||
import Data.List
|
||||
import qualified Data.Map as Map
|
||||
import Data.Maybe
|
||||
import qualified Data.Set as Set
|
||||
|
||||
import qualified AST as A
|
||||
import CompState
|
||||
|
@ -56,46 +57,129 @@ checkArrayUsage tree = (mapM_ checkPar $ listify (const True) tree) >> return tr
|
|||
= -- liftIO (putStr $ "Checking: " ++ show (arrName, indexes)) >>
|
||||
case makeEquations indexes (makeConstant emptyMeta 1000000) of
|
||||
Left err -> die $ "Could not work with array indexes for array \"" ++ arrName ++ "\": " ++ err
|
||||
Right (varMapping, problem) ->
|
||||
case uncurry solveProblem problem of
|
||||
Right problems ->
|
||||
case mapM (\(vm,p) -> seqPair (return vm,uncurry solveProblem p)) problems of
|
||||
-- No solutions; no worries!
|
||||
Nothing -> return ()
|
||||
Just vm -> do sol <- formatSolution varMapping (getCounterEqs vm)
|
||||
arrName' <- getRealName arrName
|
||||
dieP m $ "Overlapping indexes of array \"" ++ arrName' ++ "\" when: " ++ sol
|
||||
Just ((varMapping,vm):_) -> do sol <- formatSolution varMapping (getCounterEqs vm)
|
||||
arrName' <- getRealName (A.Name undefined undefined arrName)
|
||||
dieP m $ "Overlapping indexes of array \"" ++ arrName' ++ "\" when: " ++ sol
|
||||
|
||||
formatSolution :: Map.Map String CoeffIndex -> Map.Map CoeffIndex Integer -> PassM String
|
||||
formatSolution :: VarMap -> Map.Map CoeffIndex Integer -> PassM String
|
||||
formatSolution varToIndex indexToConst = do names <- mapM valOfVar $ Map.assocs varToIndex
|
||||
return $ concat $ intersperse " , " $ catMaybes names
|
||||
where
|
||||
valOfVar (varName,k) = case Map.lookup k indexToConst of
|
||||
valOfVar (varExp,k) = case Map.lookup k indexToConst of
|
||||
Nothing -> return Nothing
|
||||
Just val -> do varName' <- getRealName varName
|
||||
return $ Just $ varName' ++ " = " ++ show val
|
||||
Just val -> do varExp' <- showFlattenedExp varExp
|
||||
return $ Just $ varExp' ++ " = " ++ show val
|
||||
|
||||
-- TODO this is surely defined elsewhere already?
|
||||
getRealName :: String -> PassM String
|
||||
getRealName s = lookupName (A.Name undefined undefined s) >>* A.ndOrigName
|
||||
|
||||
getRealName :: A.Name -> PassM String
|
||||
getRealName n = lookupName n >>* A.ndOrigName
|
||||
|
||||
showFlattenedExp :: FlattenedExp -> PassM String
|
||||
showFlattenedExp (Const n) = return $ show n
|
||||
showFlattenedExp (Scale n (A.Variable _ vn)) = do vn' <- getRealName vn
|
||||
return $ (show n) ++ "*" ++ vn'
|
||||
showFlattenedExp (Modulo top bottom)
|
||||
= do top' <- showFlattenedExpSet top
|
||||
bottom' <- showFlattenedExpSet bottom
|
||||
return $ "(" ++ top' ++ " REM " ++ bottom' ++ ")"
|
||||
showFlattenedExp (Divide top bottom)
|
||||
= do top' <- showFlattenedExpSet top
|
||||
bottom' <- showFlattenedExpSet bottom
|
||||
return $ "(" ++ top' ++ " / " ++ bottom' ++ ")"
|
||||
|
||||
showFlattenedExpSet :: Set.Set FlattenedExp -> PassM String
|
||||
showFlattenedExpSet s = liftM concat $ sequence $ intersperse (return " + ") $ map showFlattenedExp $ Set.toList s
|
||||
|
||||
-- | A type for inside makeEquations:
|
||||
data FlattenedExp = Const Integer | Scale Integer A.Variable deriving (Eq,Show)
|
||||
data FlattenedExp
|
||||
= Const Integer
|
||||
| Scale Integer A.Variable
|
||||
| Modulo (Set.Set FlattenedExp) (Set.Set FlattenedExp)
|
||||
| Divide (Set.Set FlattenedExp) (Set.Set FlattenedExp)
|
||||
|
||||
instance Eq FlattenedExp where
|
||||
a == b = EQ == compare a b
|
||||
|
||||
instance Ord FlattenedExp where
|
||||
compare (Const _) (Const _) = EQ
|
||||
compare (Const _) _ = LT
|
||||
compare _ (Const _) = GT
|
||||
compare (Scale _ lv) (Scale _ rv) = customVarCompare lv rv
|
||||
compare (Scale {}) _ = LT
|
||||
compare _ (Scale {}) = GT
|
||||
compare (Modulo ltop lbottom) (Modulo rtop rbottom)
|
||||
= combineCompare [compare ltop lbottom, compare lbottom rbottom]
|
||||
compare (Modulo {}) _ = LT
|
||||
compare _ (Modulo {}) = GT
|
||||
compare (Divide ltop lbottom) (Divide rtop rbottom)
|
||||
= combineCompare [compare ltop lbottom, compare lbottom rbottom]
|
||||
|
||||
customVarCompare :: A.Variable -> A.Variable -> Ordering
|
||||
customVarCompare (A.Variable _ (A.Name _ _ lname)) (A.Variable _ (A.Name _ _ rname)) = compare lname rname
|
||||
-- TODO the rest
|
||||
|
||||
makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp)
|
||||
makeExpSet = foldM makeExpSet' Set.empty
|
||||
where
|
||||
makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> Either String (Set.Set FlattenedExp)
|
||||
makeExpSet' accum (Const n) = return $ insert (addConst n) (Const n) accum
|
||||
makeExpSet' accum (Scale n v) = return $ insert (addScale n v) (Scale n v) accum
|
||||
makeExpSet' accum m@(Modulo {}) | Set.member m accum = throwError "Cannot have repeated REM items in an expression"
|
||||
| otherwise = return $ Set.insert m accum
|
||||
makeExpSet' accum d@(Divide {}) | Set.member d accum = throwError "Cannot have repeated (/) items in an expression"
|
||||
| otherwise = return $ Set.insert d accum
|
||||
|
||||
insert :: (FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)) -> FlattenedExp -> Set.Set FlattenedExp -> Set.Set FlattenedExp
|
||||
insert f e s = case Set.fold insert' (Set.empty,False) s of
|
||||
(s',True) -> s'
|
||||
_ -> Set.insert e s
|
||||
where
|
||||
insert' :: FlattenedExp -> (Set.Set FlattenedExp, Bool) -> (Set.Set FlattenedExp, Bool)
|
||||
insert' e (s,b) = case f e s of
|
||||
Just s' -> (s', True)
|
||||
Nothing -> (Set.insert e s, False)
|
||||
|
||||
addConst :: Integer -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)
|
||||
addConst x (Const n) s = Just $ Set.insert (Const (n + x)) s
|
||||
addConst _ _ _ = Nothing
|
||||
|
||||
addScale :: Integer -> A.Variable -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)
|
||||
addScale x lv (Scale n rv) s | EQ == customVarCompare lv rv = Just $ Set.insert (Scale (x + n) rv) s
|
||||
| otherwise = Nothing
|
||||
addScale _ _ _ _ = Nothing
|
||||
|
||||
type VarMap = Map.Map FlattenedExp Int
|
||||
|
||||
-- | Given a list of expressions, an expression representing the upper array bound, returns either an error
|
||||
-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
|
||||
-- (unique, munged) variable name to variable-index in the equations.
|
||||
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
|
||||
makeEquations :: [A.Expression] -> A.Expression -> Either String (Map.Map String Int, (EqualityProblem, InequalityProblem))
|
||||
makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pairEqs v, getIneqs lh v)))
|
||||
makeEquations :: [A.Expression] -> A.Expression -> Either String [(VarMap, (EqualityProblem, InequalityProblem))]
|
||||
makeEquations es high = makeEquations' >>* (\(s,v,lh) -> [(s,squareEquations eqIneq) | eqIneq <- pairEqsAndBounds v lh])
|
||||
where
|
||||
{-
|
||||
makeProblem :: Map.Map String Int
|
||||
-> [(EqualityConstraintEquation,EqualityProblem, InequalityProblem)]
|
||||
-> (EqualityConstraintEquation,EqualityConstraintEquation)
|
||||
-> [(Map.Map String Int, (EqualityProblem, InequalityProblem))]
|
||||
makeProblem varMap problems lowHigh = [(varMap, (eq,ineq)) | (eq,ineq) <- pairEqsAndBounds problems lowHigh]
|
||||
-}
|
||||
|
||||
-- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair
|
||||
-- representing the upper and lower bounds of the array (inclusive).
|
||||
makeEquations' :: Either String (Map.Map String Int, [EqualityConstraintEquation], (EqualityConstraintEquation, EqualityConstraintEquation))
|
||||
makeEquations' :: Either String (VarMap, [(EqualityConstraintEquation,EqualityProblem,InequalityProblem)], (EqualityConstraintEquation, EqualityConstraintEquation))
|
||||
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
|
||||
do flattened <- lift (mapM flatten es)
|
||||
eqs <- mapM makeEquation flattened
|
||||
high' <- (lift $ flatten high) >>= makeEquation
|
||||
return (eqs,high')
|
||||
high'' <- case high' of
|
||||
[(h,_,_)] -> return h
|
||||
_ -> throwError "Multiple possible upper bounds not supported"
|
||||
return (concat eqs,high'')
|
||||
return (s,v,(amap (const 0) h, h))
|
||||
|
||||
-- Note that in all these functions, the divisor should always be positive!
|
||||
|
@ -110,11 +194,15 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai
|
|||
flatten (A.Dyadic m op lhs rhs) | op == A.Add = combine' (flatten lhs) (flatten rhs)
|
||||
| op == A.Subtr = combine' (flatten lhs) (liftM (scale (-1)) $ flatten rhs)
|
||||
| op == A.Mul = multiplyOut' (flatten lhs) (flatten rhs)
|
||||
-- TODO Div (either constant on bottom, or common (variable) factor(s) with top)
|
||||
| op == A.Rem = liftM2L Modulo (flatten lhs) (flatten rhs)
|
||||
| op == A.Div = liftM2L Divide (flatten lhs) (flatten rhs)
|
||||
| otherwise = throwError ("Unhandleable operator found in expression: " ++ show op)
|
||||
flatten (A.ExprVariable _ v) = return [Scale 1 v]
|
||||
flatten other = throwError ("Unhandleable item found in expression: " ++ show other)
|
||||
|
||||
-- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c]
|
||||
liftM2L f x y = liftM (:[]) $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet)
|
||||
|
||||
--TODO we need to handle lots more different expression types in future.
|
||||
|
||||
multiplyOut' :: Either String [FlattenedExp] -> Either String [FlattenedExp] -> Either String [FlattenedExp]
|
||||
|
@ -151,21 +239,35 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai
|
|||
|
||||
-- | Finds the index associated with a particular variable; either by finding an existing index
|
||||
-- or allocating a new one.
|
||||
varIndex :: A.Variable -> StateT (Map.Map String Int) (Either String) Int
|
||||
varIndex (A.Variable _ (A.Name _ _ varName))
|
||||
varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int
|
||||
varIndex (Scale _ var@(A.Variable _ (A.Name _ _ varName)))
|
||||
= do st <- get
|
||||
let (st',ind) = case Map.lookup varName st of
|
||||
let (st',ind) = case Map.lookup (Scale 1 var) st of
|
||||
Just val -> (st,val)
|
||||
Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in
|
||||
(Map.insert varName newId st, newId)
|
||||
(Map.insert (Scale 1 var) newId st, newId)
|
||||
put st'
|
||||
return ind
|
||||
varIndex mod@(Modulo top bottom)
|
||||
= do st <- get
|
||||
let (st',ind) = case Map.lookup mod st of
|
||||
Just val -> (st,val)
|
||||
Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in
|
||||
(Map.insert mod newId st, newId)
|
||||
put st'
|
||||
return ind
|
||||
|
||||
-- | Pairs all possible combinations of the list of equations.
|
||||
pairEqs :: [EqualityConstraintEquation] -> [EqualityConstraintEquation]
|
||||
pairEqs = filter (any (/= 0) . elems) . map (uncurry pairEqs') . allPairs
|
||||
pairEqsAndBounds :: [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)] -> (EqualityConstraintEquation, EqualityConstraintEquation) -> [(EqualityProblem, InequalityProblem)]
|
||||
pairEqsAndBounds items bounds = (map (filterProblems . uncurry pairEqs') . allPairs) items
|
||||
where
|
||||
pairEqs' ex ey = arrayZipWith' 0 (-) ex ey
|
||||
pairEqs' :: (EqualityConstraintEquation, EqualityProblem, InequalityProblem)
|
||||
-> (EqualityConstraintEquation, EqualityProblem, InequalityProblem)
|
||||
-> (EqualityProblem, InequalityProblem)
|
||||
pairEqs' (ex,eqX,ineqX) (ey,eqY,ineqY) = ([arrayZipWith' 0 (-) ex ey] ++ eqX ++ eqY, ineqX ++ ineqY ++ getIneqs bounds [ex,ey])
|
||||
|
||||
filterProblems :: (EqualityProblem, InequalityProblem) -> (EqualityProblem, InequalityProblem)
|
||||
filterProblems = transformPair (filter (any (/= 0) . elems)) (filter (any (/= 0) . elems))
|
||||
|
||||
-- | Given a (low,high) bound (typically: array dimensions), and a list of equations ex,
|
||||
-- forms the possible inequalities:
|
||||
|
@ -182,18 +284,73 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai
|
|||
|
||||
addEq = arrayZipWith' 0 (+)
|
||||
|
||||
-- | Given ex, forms an equation (e) from the latter part, and returns it
|
||||
makeEquation :: [FlattenedExp] -> StateT (Map.Map String Int) (Either String) EqualityConstraintEquation
|
||||
-- | Given an expression, forms equations (and accompanying additional equation-sets) and returns it
|
||||
makeEquation :: [FlattenedExp] -> StateT (VarMap) (Either String) [(EqualityConstraintEquation, EqualityProblem, InequalityProblem)]
|
||||
makeEquation summedItems
|
||||
= do eqs <- foldM makeEquation' Map.empty summedItems
|
||||
return $ mapToArray eqs
|
||||
= do eqs <- process summedItems
|
||||
return $ map (transformTriple mapToArray (map mapToArray) (map mapToArray)) eqs
|
||||
where
|
||||
makeEquation' :: Map.Map Int Integer -> FlattenedExp -> StateT (Map.Map String Int) (Either String) (Map.Map Int Integer)
|
||||
process = foldM makeEquation' empty
|
||||
|
||||
makeEquation' :: [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] -> FlattenedExp -> StateT (VarMap) (Either String) [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
|
||||
makeEquation' m (Const n) = return $ add (0,n) m
|
||||
makeEquation' m (Scale n v) = varIndex v >>* (\ind -> add (ind, n) m)
|
||||
makeEquation' m sc@(Scale n v) = varIndex sc >>* (\ind -> add (ind, n) m)
|
||||
makeEquation' m mod@(Modulo top bottom)
|
||||
= do top' <- process $ Set.toList top
|
||||
top'' <- case top' of
|
||||
[(t,_,_)] -> return t
|
||||
_ -> throwError "Modulo or divide not allowed in the numerator of Modulo"
|
||||
bottom' <- process $ Set.toList bottom
|
||||
topIndex <- varIndex mod
|
||||
case onlyConst (Set.toList bottom) of
|
||||
Just bottomConst ->
|
||||
let add_x_plus_my = zipMap plus top'' . zipMap plus (Map.fromList [(topIndex,bottomConst)]) in
|
||||
return $
|
||||
-- The zero option (x = 0, x REM y = 0):
|
||||
( map (transformTriple id (++ [top'']) id) m)
|
||||
++
|
||||
-- The top-is-positive option:
|
||||
( map (transformTriple add_x_plus_my id (++
|
||||
-- x >= 1
|
||||
[zipMap plus (Map.fromList [(0,-1)]) top''
|
||||
-- m <= 0
|
||||
,Map.fromList [(topIndex,-1)]
|
||||
-- x + my + 1 - |y| <= 0
|
||||
,Map.map negate $ add_x_plus_my $ Map.fromList [(0,1 - bottomConst)]
|
||||
-- x + my >= 0
|
||||
,add_x_plus_my $ Map.empty])
|
||||
) m) ++
|
||||
-- The top-is-negative option:
|
||||
( map (transformTriple add_x_plus_my id (++
|
||||
-- x <= -1
|
||||
[add' (0,-1) $ Map.map negate top''
|
||||
-- m >= 0
|
||||
,Map.fromList [(topIndex,1)]
|
||||
-- x + my - 1 + |y| >= 0
|
||||
,add_x_plus_my $ Map.fromList [(0,bottomConst - 1)]
|
||||
-- x + my <= 0
|
||||
,Map.map negate $ add_x_plus_my Map.empty])
|
||||
) m)
|
||||
_ -> throwError "TODO - Variable divisor for modulo"
|
||||
|
||||
makeEquation' m (Divide top bottom) = throwError "TODO Divide"
|
||||
|
||||
add :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer
|
||||
add = uncurry (Map.insertWith (+))
|
||||
onlyConst :: [FlattenedExp] -> Maybe Integer
|
||||
onlyConst [] = Just 0
|
||||
onlyConst ((Const n):es) = liftM2 (+) (return n) $ onlyConst es
|
||||
onlyConst _ = Nothing
|
||||
|
||||
empty :: [(Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
|
||||
empty = [(Map.empty,[],[])]
|
||||
|
||||
plus :: Num n => Maybe n -> Maybe n -> Maybe n
|
||||
plus x y = Just $ (fromMaybe 0 x) + (fromMaybe 0 y)
|
||||
|
||||
add' :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer
|
||||
add' (m,n) = Map.insertWith (+) m n
|
||||
|
||||
add :: (Int,Integer) -> [(Map.Map Int Integer,a,b)] -> [(Map.Map Int Integer,a,b)]
|
||||
add (m,n) = map $ transformTriple (Map.insertWith (+) m n) id id
|
||||
|
||||
-- | Converts a map to an array. Any missing elements in the middle of the bounds are given the value zero.
|
||||
-- Could probably be moved to Utils
|
||||
|
|
|
@ -23,6 +23,7 @@ import Data.Array.IArray
|
|||
import Data.List
|
||||
import qualified Data.Map as Map
|
||||
import Data.Maybe
|
||||
import qualified Data.Set as Set
|
||||
import Prelude hiding ((**),fail)
|
||||
import Test.HUnit
|
||||
import Test.QuickCheck hiding (check)
|
||||
|
@ -232,29 +233,35 @@ showMaybe _ Nothing = "Nothing"
|
|||
testMakeEquations :: Test
|
||||
testMakeEquations = TestList
|
||||
[
|
||||
test (0,Map.empty,[con 0 === con 1],leq [con 0,con 0,con 7] &&& leq [con 0,con 1,con 7],
|
||||
[intLiteral 0, intLiteral 1],intLiteral 7)
|
||||
test (0,[(Map.empty,[con 0 === con 1],leq [con 0,con 1,con 7] &&& leq [con 0,con 2,con 7])],
|
||||
[intLiteral 1, intLiteral 2],intLiteral 7)
|
||||
|
||||
,test (1,i_mapping,[i === con 3],leq [con 0,con 3,con 7] &&& leq [con 0,i,con 7],
|
||||
,test (1,[(i_mapping,[i === con 3],leq [con 0,con 3,con 7] &&& leq [con 0,i,con 7])],
|
||||
[exprVariable "i",intLiteral 3],intLiteral 7)
|
||||
|
||||
,test (2,ij_mapping,[i === j],leq [con 0,i,con 7] &&& leq [con 0,j,con 7],
|
||||
,test (2,[(ij_mapping,[i === j],leq [con 0,i,con 7] &&& leq [con 0,j,con 7])],
|
||||
[exprVariable "i",exprVariable "j"],intLiteral 7)
|
||||
|
||||
,test (3,ij_mapping,[i ++ con 3 === j],leq [con 0,i ++ con 3,con 7] &&& leq [con 0,j,con 7],
|
||||
,test (3,[(ij_mapping,[i ++ con 3 === j],leq [con 0,i ++ con 3,con 7] &&& leq [con 0,j,con 7])],
|
||||
[buildExpr $ Dy (Var "i") A.Add (Lit $ intLiteral 3),exprVariable "j"],intLiteral 7)
|
||||
|
||||
,test (4,ij_mapping,[2 ** i === j],leq [con 0,2 ** i,con 7] &&& leq [con 0,j,con 7],
|
||||
,test (4,[(ij_mapping,[2 ** i === j],leq [con 0,2 ** i,con 7] &&& leq [con 0,j,con 7])],
|
||||
[buildExpr $ Dy (Var "i") A.Mul (Lit $ intLiteral 2),exprVariable "j"],intLiteral 7)
|
||||
|
||||
,test (10,[(i_mod_mapping 3,[i ++ 3 ** j === con 4], leq [con 0,i ++ 3 ** j,con 7])],
|
||||
[buildExpr $ Dy (Var "i") A.Rem (Lit $ intLiteral 3),intLiteral 4],intLiteral 7)
|
||||
]
|
||||
where
|
||||
test :: (Integer,Map.Map String CoeffIndex,[HandyEq],[HandyIneq],[A.Expression],A.Expression) -> Test
|
||||
test (ind, mpping, eqs, ineqs, exprs, upperBound) =
|
||||
TestCase $ assertEquivalentProblems ("testMakeEquations" ++ show ind)
|
||||
(mpping,makeConsistent eqs ineqs) =<< (checkRight $ makeEquations exprs upperBound)
|
||||
test :: (Integer,[(VarMap,[HandyEq],[HandyIneq])],[A.Expression],A.Expression) -> Test
|
||||
test (ind, problems, exprs, upperBound) =
|
||||
TestCase $ assertEquivalentProblems ("testMakeEquations " ++ show ind)
|
||||
(map (transformPair id (uncurry makeConsistent)) $ map pairLatterTwo problems) =<< (checkRight $ makeEquations exprs upperBound)
|
||||
|
||||
i_mapping = Map.singleton "i" 1
|
||||
ij_mapping = Map.fromList [("i",1),("j",2)]
|
||||
pairLatterTwo (a,b,c) = (a,(b,c))
|
||||
|
||||
i_mapping = Map.singleton (Scale 1 $ variable "i") 1
|
||||
ij_mapping = Map.fromList [(Scale 1 $ variable "i",1),(Scale 1 $ variable "j",2)]
|
||||
i_mod_mapping n = Map.fromList [(Scale 1 $ variable "i",1),(Modulo (Set.singleton $ Scale 1 $ variable "i") (Set.singleton $ Const n),2)]
|
||||
|
||||
testIndexes :: Test
|
||||
testIndexes = TestList
|
||||
|
@ -369,7 +376,7 @@ testIndexes = TestList
|
|||
-- | Given one mapping and a second mapping, gives a function that converts the indexes
|
||||
-- from one to the indexes of the next. If any of the keys in the map don't match
|
||||
-- (i.e. if (keys m0 /= keys m1)) Nothing will be returned
|
||||
generateMapping :: Map.Map String CoeffIndex -> Map.Map String CoeffIndex -> Maybe [(CoeffIndex,CoeffIndex)]
|
||||
generateMapping :: VarMap -> VarMap -> Maybe [(CoeffIndex,CoeffIndex)]
|
||||
generateMapping m0 m1 = if Map.keys m0 /= Map.keys m1 then Nothing else Just (Map.elems $ zipMap f m0 m1)
|
||||
where
|
||||
f (Just x) (Just y) = Just (x,y)
|
||||
|
@ -388,12 +395,17 @@ translateEquations mp = seqPair . transformPair (mapM swapColumns) (mapM swapCol
|
|||
swapColumns' (x,v) = transformMaybe (\y -> (y,v)) $ transformMaybe fst $ find ((== x) . snd) mp
|
||||
|
||||
-- | Asserts that the two problems are equivalent, once you take into account the potentially different variable mappings
|
||||
assertEquivalentProblems :: String -> (Map.Map String CoeffIndex, (EqualityProblem, InequalityProblem)) -> (Map.Map String CoeffIndex, (EqualityProblem, InequalityProblem)) -> Assertion
|
||||
assertEquivalentProblems title exp act = assertEqual title translatedExp (Just $ sortP $ snd act)
|
||||
assertEquivalentProblems :: String -> [(VarMap, (EqualityProblem, InequalityProblem))] -> [(VarMap, (EqualityProblem, InequalityProblem))] -> Assertion
|
||||
assertEquivalentProblems title exp act = mapM_ (uncurry $ assertEqual title) $ map (uncurry transform) $ zip exp act
|
||||
where
|
||||
sortP (eq,ineq) = (sort $ map normaliseEquality eq, sort ineq)
|
||||
transform :: (VarMap, (EqualityProblem, InequalityProblem)) -> (VarMap, (EqualityProblem, InequalityProblem)) ->
|
||||
( Maybe (EqualityProblem, InequalityProblem), Maybe (EqualityProblem, InequalityProblem) )
|
||||
transform exp act = (translatedExp, Just $ sortP $ snd act)
|
||||
where
|
||||
sortP :: (EqualityProblem, InequalityProblem) -> (EqualityProblem, InequalityProblem)
|
||||
sortP (eq,ineq) = (sort $ map normaliseEquality eq, sort ineq)
|
||||
|
||||
translatedExp = ( generateMapping (fst exp) (fst act) >>= flip translateEquations (snd exp)) >>* sortP
|
||||
translatedExp = ( generateMapping (fst exp) (fst act) >>= flip translateEquations (snd exp)) >>* sortP
|
||||
|
||||
checkRight :: Show a => Either a b -> IO b
|
||||
checkRight (Left err) = assertFailure ("Not Right: " ++ show err) >> return undefined
|
||||
|
|
Loading…
Reference in New Issue
Block a user