-- Source: tock-mirror/checks/ArrayUsageCheck.hs
-- (1074 lines, 56 KiB, Haskell)

{-
Tock: a compiler for parallel languages
Copyright (C) 2007--2009 University of Kent
This program is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation, either version 2 of the License, or (at your
option) any later version.
This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program. If not, see <http://www.gnu.org/licenses/>.
-}
module ArrayUsageCheck (
BackgroundKnowledge(..),
BK,
canonicalise,
checkArrayUsage,
findRepSolutions,
FlattenedExp(..),
fmapFlattenedExp,
makeEquations,
makeExpSet,
ModuloCase(..),
onlyConst,
showFlattenedExp,
VarMap) where
import Control.Monad.Error
import Control.Monad.Reader
import Control.Monad.State
import Data.Array.IArray
import qualified Data.Foldable as F
import Data.Generics hiding (GT)
import Data.Int
import Data.List
import qualified Data.Map as Map
import Data.Maybe
import qualified Data.Set as Set
import qualified Data.Traversable as T
import qualified AST as A
import CompState
import Errors
import Metadata
import Omega
import OrdAST()
import Pass
import ShowCode
import Types
import UsageCheckUtils
import Utils
-- | Background knowledge for an access.  Each element of the list is one
-- possible set of background knowledge, mapping variables to a list of
-- constraints.  So the whole list is a disjunction of maps from variables to
-- conjunctions of constraints.
type BK = [Map.Map Var [BackgroundKnowledge]]
-- | The same shape as 'BK', except that each variable maps either to an error
-- message or to an equality\/inequality problem.
type BK' = [Map.Map Var (Either String (EqualityProblem, InequalityProblem))]
-- | Given a list of replicators, and a set of background knowledge for each
-- access inside the replicator, checks if there are any solutions for a
-- combination of the normal replicator constraints, and the given background
-- knowledge (pairing each set against each other, applying one set to the replicator,
-- and the other to the mirror of the replicator).
--
-- Returns Nothing if no solutions, a String with a counter-example if there
-- are solutions.  If the equations cannot be built at all, this calls
-- 'error' with the reason that 'makeEquations' reported.
findRepSolutions :: (CSMR m, MonadIO m) => [(A.Name, A.Replicator)] -> [BK] -> m (Maybe String)
findRepSolutions reps bks
  -- To get the right comparison, we create a SeqItems with all the accesses
  -- Because they are inside a PAR replicator, they will all get compared to each
  -- other with one set of BK applied to i and one applied to i', but they will
  -- never be compared to each other just with the constraints on i (which is not
  -- what we are checking here). We set the dummy array accesses to all be zero,
  -- which means they can overlap -- but only if there is also a solution to the
  -- replicator background knowledge, which is what this function is trying to
  -- determine.
  = do cs <- getCompState
       case runReaderT (makeEquations (addReps $ SeqItems [(bk, [makeConstant emptyMeta 0], []) | bk <- bks]) maxInt) cs of
         -- Previously this case was swallowed with a misleading message and
         -- the actual reason was thrown away.
         Left err -> error $ "findRepSolutions: could not build equations: " ++ err
         Right problems -> do
           probs <- formatProblems [(vm, prob) | (_,vm,prob) <- problems]
           debug $ "Problems in findRepSolutions:\n" ++ probs
           case catMaybes [fmap ((,) i) $ solve p | (i, p) <- zip [(0 :: Integer) ..] problems] of
             [] -> return Nothing -- No solutions, safe
             xs -> liftM (Just . unlines) $ mapM format xs
  where
    -- Largest INT value (INT assumed to be 32-bit), used as the dummy array bound.
    maxInt = makeConstant emptyMeta $ fromInteger $ toInteger (maxBound :: Int32)
    -- Formats one solved problem, prefixed with its problem number.
    format (i, (_, varMapping, vm, _))
      = formatSolution varMapping vm >>* (("#" ++ show i ++ ": ") ++)
    -- Wraps the items in the replicators, the first replicator innermost.
    addReps = flip (foldl $ flip RepParItem) reps
-- | A check-pass that checks the given ParItems (usually generated from a control-flow graph)
-- for any overlapping array indices.
--
-- Accesses are grouped by (array name, channel direction).  Every array that
-- has at least one index expression is passed to 'checkIndexes', which dies
-- (with a counter-example) if two of its indexes could overlap.  Arrays whose
-- name has the given shared attribute, and arrays declared inside the checked
-- items, are skipped.
checkArrayUsage :: forall m. (Die m, CSMR m, MonadIO m) => NameAttr -> (Meta, ParItems (BK, UsageLabel)) -> m ()
checkArrayUsage sharedAttr (m,p)
  = do debug $ "checkArrayUsage: " ++ show m
       indexes <- groupArrayIndexes $ fmap (transformPair id nodeVars) p
       -- Skip only arrays that have no index expressions at all.  Counting the
       -- ParItems elements instead would give the same answer for every
       -- array, because each value has the same shape as p.
       mapM_ (checkIndexes m) $ Map.toList $ Map.filter
         (not . null . concatMap (\(_,w,r) -> w ++ r) . F.toList) indexes
  where
    -- The name declared by a node, if it is a ScopeIn node.
    getDecl :: UsageLabel -> Maybe String
    getDecl = join . fmap getScopeIn . nodeDecl
      where
        getScopeIn (ScopeIn _ n) = Just n
        getScopeIn _ = Nothing

    -- Takes a ParItems of Vars, and returns a map from (array-name, direction)
    -- to a ParItems of (background knowledge, written-to indexes, read-from
    -- indexes) for that array.  "Used" variables are counted as writes.
    groupArrayIndexes :: ParItems (BK, Vars) -> m (Map.Map (String, Maybe A.Direction) (ParItems (BK, [A.Expression], [A.Expression])))
    groupArrayIndexes = liftM filterByKey . T.mapM groupItem
      where
        groupItem :: (BK, Vars) -> m (Map.Map (String, Maybe A.Direction) (BK, [A.Expression], [A.Expression]))
        groupItem (bk, vs)
          = do w <- makeList $ (Map.keysSet $ writtenVars vs) `Set.union` (usedVars vs)
               r <- makeList $ readVars vs
               return $ zipMap (withBK bk) w r

        -- Pairs the writes and reads of one array with the item's background
        -- knowledge, treating a missing side as having no indexes.
        withBK :: b -> Maybe [a] -> Maybe [a] -> Maybe (b, [a],[a])
        withBK k x y = Just (k, fromMaybe [] x, fromMaybe [] y)

        -- Turns a set of variables into a map (from array-name to list of index-expressions)
        makeList :: Set.Set Var -> m (Map.Map (String, Maybe A.Direction) [A.Expression])
        makeList vs = do indexes <- concatMapM getArrayIndex $ Set.toList vs
                         return $ Map.fromListWith (++) indexes

        -- Lifts a map (from array-name to expression-lists) inside a ParItems to being a map (from array-name to ParItems of expression lists)
        filterByKey :: ParItems (Map.Map (String, Maybe A.Direction) (BK, [A.Expression], [A.Expression]))
                    -> Map.Map (String, Maybe A.Direction) (ParItems (BK, [A.Expression], [A.Expression]))
        filterByKey items = Map.fromList $ map trans keys
          where
            keys :: [(String, Maybe A.Direction)]
            keys = concatMap Map.keys $ flattenParItems items
            trans :: (String, Maybe A.Direction) -> ((String, Maybe A.Direction), ParItems (BK, [A.Expression], [A.Expression]))
            trans k = (k, fmap (Map.findWithDefault ([], [], []) k) items)

    -- Gets the (array-name, indexes) from a Var.
    -- TODO this is quite hacky, and doesn't yet deal with slices and so on:
    getArrayIndex :: Var -> m [((String, Maybe A.Direction), [A.Expression])]
    getArrayIndex (Var v@(A.SubscriptedVariable _ (A.Subscript _ _ e) (A.Variable _ n)))
      = do t <- astTypeOf v
           -- An undirected channel element may be used in either direction.
           let dirs = case t of
                 A.Chan {} -> [Just A.DirInput, Just A.DirOutput]
                 _ -> [Nothing]
           return [((A.nameName n, d), [e]) | d <- dirs]
    getArrayIndex (Var (A.SubscriptedVariable _ (A.Subscript _ _ e)
                          (A.DirectedVariable _ dir (A.Variable _ n))))
      = return [((A.nameName n, Just dir), [e])]
    getArrayIndex (Var (A.DirectedVariable _ dir (A.SubscriptedVariable _ (A.Subscript _ _ e)
                          (A.Variable _ n))))
      = return [((A.nameName n, Just dir), [e])]
    getArrayIndex _ = return []

    -- Checks the given ParItems of writes and reads against each other. The
    -- String (array-name) and Meta are only used for printing out error messages
    checkIndexes :: Meta -> ((String, Maybe A.Direction), ParItems (BK, [A.Expression], [A.Expression])) -> m ()
    checkIndexes m ((arrName, _), indexes) = do
      sharedNames <- getCompState >>* csNameAttr
      let declNames = [x | Just x <- fmap (getDecl . snd) $ flattenParItems p]
      when (fmap (Set.member sharedAttr) (Map.lookup arrName sharedNames) /= Just True && arrName `notElem` declNames) $
        do userArrName <- getRealName (A.Name undefined arrName)
           arrType <- astTypeOf (A.Name undefined arrName) >>= resolveUserType m
           arrLength <- case arrType of
             A.Array (A.Dimension d:_) _ -> return d
             -- Unknown dimension, use the maximum value for a (assumed 32-bit for INT) integer:
             A.Array (A.UnknownDimension:_) _ -> return $ makeConstant m $ fromInteger $ toInteger (maxBound :: Int32)
             -- It's not an array:
             _ -> dieP m $ "Cannot usage check array \"" ++ userArrName ++ "\"; found to be of type: " ++ show arrType
           cs <- getCompState
           case runReaderT (makeEquations indexes arrLength) cs of
             Left err -> dieP m $ "Could not work with array indexes for array \"" ++ userArrName ++ "\": " ++ err
             Right [] -> return () -- No problems to work with
             Right problems -> do
               probs <- formatProblems [(vm, prob) | (_,vm,prob) <- problems]
               debug $ "Problems in checkArrayUsage" ++ show m ++ ":\n" ++ probs
               case mapMaybe solve problems of
                 -- No solutions; no worries!
                 [] -> return ()
                 (((lx,ly),varMapping,vm,_):_) ->
                   do sol <- formatSolution varMapping vm
                      cx <- showCode (fst lx)
                      cy <- showCode (fst ly)
                      dieP m $ "Indexes of array \"" ++ userArrName ++ "\" "
                        ++ "(\"" ++ cx ++ "\" and \"" ++ cy ++ "\") could overlap"
                        ++ if sol /= "" then " when: " ++ sol else ""

    -- TODO this is surely defined elsewhere already?
    getRealName :: A.Name -> m String
    getRealName n = lookupName n >>* A.ndOrigName
-- | Renders a list of problems as one debugging string.  Each problem is laid
-- out by 'formatProblem'.  Its first line is prefixed with "#i :" (one-digit
-- index) or "#i:" (longer index), its remaining lines are indented by one
-- space, and a problem that renders to no lines contributes nothing.
formatProblems :: CSMR m => [(VarMap, (EqualityProblem, InequalityProblem))] -> m String
formatProblems probs
  = do rendered <- mapM (uncurry formatProblem) probs
       return $ concat $ zipWith number [0..] rendered
  where
    number :: Int -> String -> String
    number i text = case lines text of
      [] -> ""
      (first:rest) -> unlines $ (label i ++ first) : map (" " ++) rest

    label :: Int -> String
    label i = "#" ++ show i ++ (if length (show i) == 1 then " :" else ":")
-- | Formats an entire problem ready to print it out half-legibly for debugging purposes.
--
-- Equalities are printed as "terms = c".  An inequality is printed as
-- "terms >= c", except that when none of its variable coefficients is
-- positive it is negated and printed as "terms <= c".  Here c is the negated
-- constant term (index 0).  Terms with a zero coefficient are omitted, an
-- empty left-hand side is printed as "0", and a coefficient index with no
-- entry in the VarMap is printed as "<unknown>".
formatProblem :: forall m. CSMR m => VarMap -> (EqualityProblem, InequalityProblem) -> m String
formatProblem varToIndex (eq, ineq)
  = do eqLines <- mapM (render "=") eq
       ineqLines <- mapM renderIneq ineq
       return $ unlines (eqLines ++ ineqLines)
  where
    renderIneq :: Array CoeffIndex Integer -> m String
    renderIneq item
      | all (<= 0) (tail (elems item)) = render "<=" (amap negate item)
      | otherwise = render ">=" item

    render :: String -> Array CoeffIndex Integer -> m String
    render op item
      = do terms <- liftM (joinWith " + ") $
                      mapM showTerm [t | t@(_, c) <- tail (assocs item), c /= 0]
           let lhs = if null terms then "0" else terms
           return $ lhs ++ " " ++ op ++ " " ++ show (negate (item ! 0))

    showTerm :: (CoeffIndex, Integer) -> m String
    showTerm (ix, coeff)
      = case find ((== ix) . snd) (Map.assocs varToIndex) of
          Just (fe, _) -> liftM (prefix ++) (showFlattenedExp showCode fe)
          Nothing -> return "<unknown>"
      where
        prefix
          | coeff == 1 = ""
          | coeff == -1 = "-"
          | otherwise = show coeff ++ "*"
-- | Runs 'solveProblem' on one problem.  When it finds a solution, the labels
-- and variable map are returned with the solution mapping inserted and the
-- problem kept alongside.  Nothing means the problem has no solution.
solve :: (labels,vm,(EqualityProblem,InequalityProblem)) ->
  Maybe (labels,vm,VariableMapping,(EqualityProblem,InequalityProblem))
solve (ls, vm, (eq, ineq))
  = fmap (\solution -> (ls, vm, solution, (eq, ineq))) (solveProblem eq ineq)
-- | Formats a solution (not a problem, just the solution) ready to print it out for the user.
--
-- For every (item, coefficient-index) pair in the VarMap, the index is looked
-- up in @getCounterEqs vm@:
--
-- * no entry: the item is left out;
--
-- * @Right val@: printed as "item = val";
--
-- * @Left (n, low, high)@: printed as "lows <= n*item <= highs".  Each bound
--   is a " + "-joined sum of terms.  Coefficient index 0 is the constant term,
--   and zero coefficients or indexes missing from the VarMap are dropped.
--   Several bounds on one side are wrapped as "(b1,b2,...)".
--
-- The entries are joined with " , ".
formatSolution :: (CSMR m, Monad m) => VarMap -> VariableMapping -> m String
formatSolution varToIndex vm
  = do names <- mapM valOfVar $ Map.assocs varToIndex
       return $ joinWith " , " $ catMaybes names
  where
    -- Reverse lookup: coefficient index to item.
    indexToVar = flip lookup $ map revPair $ Map.assocs varToIndex
    -- Index 0 is the constant term (no variable); zero coefficients, and
    -- indexes with no item, are dropped.
    indexToVar' (0, x) = Just (Nothing, x)
    indexToVar' (_, 0) = Nothing
    indexToVar' (i, x) = case indexToVar i of
      Just v -> Just (Just v, x)
      Nothing -> Nothing
    indexToConst = getCounterEqs vm
    showWithCoeff' (Nothing, n) = return $ show n
    showWithCoeff' (Just v, n) = liftM (mult ++) $ showFlattenedExp showCode v
      where
        mult = case n of
          1 -> ""
          -1 -> "-"
          n -> show n ++ "*"
    showWithCoeff xs = liftM (joinWith " + ") $ mapM showWithCoeff' xs
    valOfVar (varExp,k) = case Map.lookup k indexToConst of
      Nothing -> return Nothing
      Just (Left (n, low, high)) ->
        do varExp' <- showWithCoeff' (Just varExp, n)
           low' <- mapM showWithCoeff $ map (mapMaybe indexToVar') low
           high' <- mapM showWithCoeff $ map (mapMaybe indexToVar') high
           return $ Just $ formatBounds (++ " <= ") low'
             ++ varExp' ++ formatBounds (" <= " ++) high'
      Just (Right val) ->
        do varExp' <- showFlattenedExp showCode varExp
           return $ Just $ varExp' ++ " = " ++ show val
    -- No bounds print nothing; several bounds are grouped in parentheses.
    formatBounds _ [] = ""
    formatBounds f [b] = f b
    formatBounds f bs = f $ "(" ++ joinWith "," bs ++ ")"
-- | Shows each term of a set (in the set's order) with 'showFlattenedExp',
-- joining the results with " + ".
showFlattenedExpSet :: Monad m => (A.Expression -> m String) -> Set.Set FlattenedExp -> m String
showFlattenedExpSet showExp s
  = do shown <- mapM (showFlattenedExp showExp) (Set.toList s)
       return $ intercalate " + " shown
-- | Shows a FlattenedExp using the given function to show expressions.
--
-- A primed variable gets one quote per sub-index level.  A Divide is shown as
-- "(top / bottom)".  A Modulo with a constant divisor is shown as
-- "(top / bottom)" scaled by the negated coefficient; otherwise it is shown as
-- "((top REM bottom) - top)".  The output for modulo items may therefore look
-- a bit odd, but there isn't really anything much that can be done about that.
showFlattenedExp :: Monad m => (A.Expression -> m String) -> FlattenedExp -> m String
showFlattenedExp showExp fe = case fe of
    Const n -> return (show n)
    Scale n (e, primes) -> do
      name <- showExp e
      return $ showScale (name ++ replicate primes '\'') n
    Modulo n top bottom -> do
      (top', bottom') <- showSides top bottom
      return $ case onlyConst (Set.toList bottom) of
        Just _ -> showScale ("(" ++ top' ++ " / " ++ bottom' ++ ")") (-n)
        Nothing -> showScale ("((" ++ top' ++ " REM " ++ bottom' ++ ") - " ++ top' ++ ")") n
    Divide n top bottom -> do
      (top', bottom') <- showSides top bottom
      return $ showScale ("(" ++ top' ++ " / " ++ bottom' ++ ")") n
  where
    -- Shows the top and bottom sets, in that order.
    showSides top bottom = do
      t <- showFlattenedExpSet showExp top
      b <- showFlattenedExpSet showExp bottom
      return (t, b)
-- | Prefixes a shown term with its coefficient: nothing for 1, "-" for -1,
-- otherwise "n*".
showScale :: String -> Integer -> String
showScale s n
  | n == 1 = s
  | n == -1 = '-' : s
  | otherwise = show n ++ "*" ++ s
-- | A type for inside makeEquations: one term of a flattened index
-- expression.  A flattened expression is a list (or set) of these terms,
-- which are summed.  Note that the hand-written Eq and Ord instances ignore
-- all coefficients.
data FlattenedExp
  = Const Integer
    -- ^ A constant
  | Scale Integer (A.Expression, Int)
    -- ^ A variable and coefficient. The first argument is the coefficient
    -- The second part of the pair is for sub-indexing (or "priming") variables.
    -- For example, replication is done by checking the replicated variable "i"
    -- against a sub-indexed (with "1") version (denoted "i'"). The sub-index
    -- is what differentiates i from i', given that they are technically the
    -- same A.Variable
  | Modulo Integer (Set.Set FlattenedExp) (Set.Set FlattenedExp)
    -- ^ A modulo, with a coefficient\/scale and given top and bottom (in that order)
  | Divide Integer (Set.Set FlattenedExp) (Set.Set FlattenedExp)
    -- ^ An integer division, with a coefficient\/scale and the given top and bottom (in that order)
-- | Equality is defined through the 'Ord' instance, so it also ignores
-- coefficients.
instance Eq FlattenedExp where
  x == y = compare x y == EQ
-- | Orders terms by constructor first (Const < Scale < Modulo < Divide) and
-- then by contents, ignoring every coefficient.  So @(Const 3 == Const 5)@ and
-- @(Scale 1 (v,0)) == (Scale 3 (v,0))@, although @(Scale 1 (v,0)) \/= (Scale 1 (v,1))@
-- because the sub-index is compared.
instance Ord FlattenedExp where
  compare x y = case (x, y) of
      (Const _, Const _) -> EQ
      (Scale _ (lv, li), Scale _ (rv, ri)) -> combineCompare [compare lv rv, compare li ri]
      (Modulo _ lt lb, Modulo _ rt rb) -> combineCompare [compare lt rt, compare lb rb]
      (Divide _ lt lb, Divide _ rt rb) -> combineCompare [compare lt rt, compare lb rb]
      _ -> compare (rank x) (rank y)
    where
      rank :: FlattenedExp -> Int
      rank (Const {}) = 0
      rank (Scale {}) = 1
      rank (Modulo {}) = 2
      rank (Divide {}) = 3
-- | If every term in the list is a constant, returns Just their sum (Just 0
-- for an empty list); otherwise returns Nothing.
onlyConst :: [FlattenedExp] -> Maybe Integer
onlyConst = foldr addTerm (Just 0)
  where
    addTerm (Const n) total = fmap (n +) total
    addTerm _ _ = Nothing
-- | Applies a function to every expression inside a term, recursing into the
-- top and bottom sets of modulo and division terms.  Coefficients and
-- sub-indexes are left alone.
fmapFlattenedExp :: (A.Expression -> A.Expression) -> FlattenedExp -> FlattenedExp
fmapFlattenedExp f = go
  where
    go fe = case fe of
      Const _ -> fe
      Scale n (e, i) -> Scale n (f e, i)
      Modulo n top bottom -> Modulo n (Set.map go top) (Set.map go bottom)
      Divide n top bottom -> Divide n (Set.map go top) (Set.map go bottom)
-- | A data type representing an array access.  Each item of a Group is
-- (label, read\/write, (index, extra-equalities, extra-inequalities)).
-- Each item of a Group cannot be paired with each other, but can be paired with each other access.
-- With a Replicated, the left list holds the normal accesses and the right list
-- the mirrored (primed) versions of the same accesses; each item in the left
-- branch is intended to be paired with each item in the right branch.
-- NOTE(review): older docs referred to a "Single" constructor, which does not
-- exist; a lone access is represented as a Group.
data ArrayAccess label =
  Group [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
  | Replicated [ArrayAccess label] [ArrayAccess label]
-- | A simple data type for denoting whether an array access is a read or a write
data ArrayAccessType = AAWrite | AARead
-- | Transforms the ParItems (from the control-flow graph) into the more suitable ArrayAccess
-- data type used by this array usage checker.
--
-- The conversion function receives the enclosing replicators (outermost
-- first), each flagged True when the item is being converted for the mirror
-- side.  A replicated item's body is converted twice, as normal and as mirror.
-- Also returns the names of all replicators met on the normal side.
parItemToArrayAccessM :: Monad m =>
  ( [((A.Name, A.Replicator), Bool)] ->
    a ->
    m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
  ) ->
  ParItems a ->
  m ([ArrayAccess label], [A.Name])
parItemToArrayAccessM f items = case items of
    SeqItems xs -> do
      -- A run of sequential items forms one group (never paired internally).
      accesses <- concatMapM (f []) xs
      return ([Group accesses], [])
    ParItems ps -> do
      results <- mapM (parItemToArrayAccessM f) ps
      let (groups, names) = unzip results
      return (concat groups, concat names)
    RepParItem rep p -> do
      (normal, innerReps) <- parItemToArrayAccessM (f . ((rep, False) :)) p
      (mirror, _) <- parItemToArrayAccessM (f . ((rep, True) :)) p
      return ([Replicated normal mirror], fst rep : innerReps)
-- | Turns a list of expressions (which may contain many constants, or duplicated variables)
-- into a set of expressions with at most one constant term, and at most one appearance
-- of any variable, or distinct modulo\/division of variables.  Repeated
-- constants are summed, and repeated (variable, sub-index) terms have their
-- coefficients summed.
-- If there is any problem (specifically, repeated modulo or division items) an error will be returned instead
makeExpSet :: forall m. MonadError String m => [FlattenedExp] -> m (Set.Set FlattenedExp)
makeExpSet = foldM makeExpSet' Set.empty
  where
    makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> m (Set.Set FlattenedExp)
    makeExpSet' accum (Const n) = return $ insert (addConst n) (Const n) accum
    makeExpSet' accum (Scale n v) = return $ insert (addScale n v) (Scale n v) accum
    makeExpSet' accum m@(Modulo {})
      | Set.member m accum = throwError "Cannot have repeated REM items in an expression"
      | otherwise = return $ Set.insert m accum
    makeExpSet' accum d@(Divide {})
      | Set.member d accum = throwError "Cannot have repeated (/) items in an expression"
      | otherwise = return $ Set.insert d accum

    -- Rebuilds the set, letting the merge function @f@ combine the new item
    -- with a matching existing one.  If nothing merged, the item is simply
    -- inserted.
    insert :: (FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)) -> FlattenedExp -> Set.Set FlattenedExp -> Set.Set FlattenedExp
    insert f e s = case Set.fold insert' (Set.empty, False) s of
                     (s', True) -> s'
                     _ -> Set.insert e s
      where
        insert' :: FlattenedExp -> (Set.Set FlattenedExp, Bool) -> (Set.Set FlattenedExp, Bool)
        insert' x (acc, merged) = case f x acc of
          Just acc' -> (acc', True)
          -- Keep the flag: a merge already done on an earlier-folded element
          -- must not be forgotten.  Otherwise the fallback Set.insert above
          -- would replace that term (Ord ignores coefficients) instead of
          -- summing it.
          Nothing -> (Set.insert x acc, merged)

    addConst :: Integer -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)
    addConst x (Const n) s = Just $ Set.insert (Const (n + x)) s
    addConst _ _ _ = Nothing

    addScale :: Integer -> (A.Expression,Int) -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)
    addScale x (lv,li) (Scale n (rv,ri)) s
      | (EQ == compare lv rv) && (li == ri) = Just $ Set.insert (Scale (x + n) (rv,ri)) s
      | otherwise = Nothing
    addScale _ _ _ _ = Nothing
-- | A map from an item (a FlattenedExp, which may be a variable, or modulo\/divide item) to its coefficient in the problem.
-- The CoeffIndex is the position of that item's coefficient in the
-- equation arrays; position 0 holds the constant term.
type VarMap = Map.Map FlattenedExp CoeffIndex
-- | Background knowledge about a problem: an equality, an inequality, or
-- inclusive replicator bounds.
data BackgroundKnowledge
  = Equal A.Expression A.Expression
    -- ^ The two expressions are equal.
  | LessThanOrEqual A.Expression A.Expression
    -- ^ The first expression is at most the second.
  | RepBoundsIncl A.Variable A.Expression A.Expression
    -- ^ @low <= v <= high@, both bounds inclusive.  When transformed, the
    -- bounds are applied to both v and its primed (mirror) copy.
  deriving (Typeable, Data)
-- | Renders background knowledge as occam-style relations, e.g. "a <= b" or
-- "low <= v <= high".
instance Show BackgroundKnowledge where
  show bk = case bk of
      Equal lhs rhs -> relation lhs " = " rhs
      LessThanOrEqual lhs rhs -> relation lhs " <= " rhs
      RepBoundsIncl v lo hi -> showOccam lo ++ " <= " ++ showOccam v ++ " <= " ++ showOccam hi
    where
      relation a op b = showOccam a ++ op ++ showOccam b
-- | The names relate to the equations given in my Omega Test presentation.
-- X is the top, Y is the bottom, A is the other var (x REM y = x + a)
data ModuloCase =
  XZero
  | XPos | XNeg -- these two are for constant divisor, all the ones below are for variable divisor
  | XPosYPosAZero | XPosYPosANonZero
  | XPosYNegAZero | XPosYNegANonZero
  | XNegYPosAZero | XNegYPosANonZero
  | XNegYNegAZero | XNegYNegANonZero
  deriving (Show, Eq, Ord)
-- | Monad used while building equations: the state is the 'VarMap'
-- (presumably extended as new items are met -- TODO confirm in makeEquation),
-- and it reads the CompState and can fail with a String error.
type BKM = StateT VarMap (ReaderT CompState (Either String))
-- | Transforms one piece of background knowledge into an equality\/inequality
-- problem.  The function @f@ is applied to the flattened terms of each
-- expression (via 'makeSingleEq').  For 'RepBoundsIncl', the bounds are
-- emitted for both the plain variable and its primed (sub-index 1) copy; @f@
-- is not applied to those two variable equations.
-- TODO allow modulo in background knowledge
transformBK :: ([FlattenedExp] -> [FlattenedExp]) -> BackgroundKnowledge ->
  BKM (EqualityProblem,InequalityProblem)
transformBK f bk = case bk of
    Equal eL eR -> do
      lhs <- single eL "background knowledge"
      rhs <- single eR "background knowledge"
      -- eL = eR becomes eL - eR = 0
      return ([addEq lhs (amap negate rhs)], [])
    LessThanOrEqual eL eR -> do
      lhs <- single eL "background knowledge"
      rhs <- single eR "background knowledge"
      return ([], [atMost lhs rhs])
    RepBoundsIncl v low high -> do
      lo <- single low "background knowledge, lower bound"
      hi <- single high "background knowledge, upper bound"
      plain <- repVarEq v 0
      primed <- repVarEq v 1
      return ([], [ atMost plain hi
                  , atMost primed hi
                  , atMost lo plain
                  , atMost lo primed
                  ])
  where
    single e desc = makeSingleEq f e desc
    -- a <= b is encoded as the inequality b - a >= 0
    atMost a b = addEq (amap negate a) b
    -- The equation for just the variable v, with the given sub-index.
    repVarEq v sub = makeEquation v ([], id) (error "Irrelevant type") [Scale 1 (A.ExprVariable emptyMeta v, sub)]
                       >>= getSingleAccessItem ("Modulo or divide impossible")
-- | Transforms every piece of background knowledge (in order) and joins the
-- resulting problems with 'accumProblem'.
transformBKList :: ([FlattenedExp] -> [FlattenedExp]) -> [BackgroundKnowledge] -> BKM (EqualityProblem,InequalityProblem)
transformBKList f bk
  = do problems <- mapM (transformBK f) bk
       return $ foldl accumProblem ([],[]) problems
-- | Flattens a single expression, applies the given transformation to its
-- terms, and turns it into one equation-item.  An error mentioning @desc@ is
-- given if the result is anything complicated (for example, modulo or divide).
makeSingleEq :: ([FlattenedExp] -> [FlattenedExp]) -> A.Expression -> String -> BKM EqualityConstraintEquation
makeSingleEq f e desc
  = do terms <- lift (flatten e)
       -- TODO: no background knowledge is passed here.
       access <- makeEquation e ([], f) (error $ "Type is irrelevant for " ++ desc) (f terms)
       getSingleAccessItem ("Modulo or Divide not allowed in " ++ desc
                              ++ "(while processing: " ++ showOccam e ++ ")") access
-- | Joins two problems with 'concatPair' (presumably appending the two
-- equality lists and the two inequality lists -- see Utils).
accumProblem :: (EqualityProblem,InequalityProblem) -> (EqualityProblem,InequalityProblem) -> (EqualityProblem,InequalityProblem)
accumProblem lhs rhs = concatPair lhs rhs
-- | Given a list of (written,read) expressions, an expression representing the upper array bound, returns either an error
-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
-- (unique, munged) variable name to variable-index in the equations.
--
-- The general strategy is as follows.
-- For every array index (here termed an "access"), we transform it into
-- the usual @[FlattenedExp]@ using the flatten function. Then we also transform
-- any access that is in the mirror-side of a Replicated item into its mirrored version
-- where each i is changed into i\'. This is done by using @vi=(variable "i",0)@
-- (in @Scale _ vi@) for the plain (normal) version, and @vi=(variable "i",1)@
-- for the prime (mirror) version.
--
-- Then the equations have bounds added. The rules are fairly simple; if
-- any of the transformed EqualityConstraintEquation (or related equalities or inequalities) representing an access
-- have a non-zero i (and\/or i\'), the bound for that variable is added.
-- So for example, an expression like i = i\' + 3 would have the bounds for
-- both i and i\' added (which would be near-identical, e.g. 1 <= i <= 6 and
-- 1 <= i\' <= 6). We have to check the equalities and inequalities because
-- when processing modulo, for the i REM y == 0 option, i will not appear in
-- the index itself (which will be 0) but will appear in the surrounding
-- constraints, and we still want to add the replication bounds.
--
-- The remainder of the work (correctly pairing equations) is done by
-- squareAndPair.
--
-- Each result is (the pair of access labels being compared, a VarMap, the
-- problem).  The bound handed to squareAndPair is an all-zero copy of the
-- bound equation plus the bound with -1 added to its constant -- presumably
-- the index range 0 .. bound-1; TODO confirm against squareAndPair.
--
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
makeEquations :: ParItems (BK, [A.Expression], [A.Expression]) -> A.Expression ->
  ReaderT CompState (Either String) [(((A.Expression, [ModuloCase]), (A.Expression, [ModuloCase])), VarMap, (EqualityProblem, InequalityProblem))]
makeEquations accesses bound
  -- Outer StateT: the VarMap shared by all the equations.  Inner StateT: the
  -- (i, i') coefficient-index pairs recorded for mirrored replicator variables.
  = do ((v,h,repVarIndexes, allReps),s) <- (flip runStateT) Map.empty $
         do ((accesses', allReps),repVars) <- flip runStateT [] $ parItemToArrayAccessM mkEq accesses
            high <- makeSingleEq id bound "upper bound"
            return (accesses', high, nub repVars, allReps)
       lift $ squareAndPair (lookupBK allReps) (\(x,y,_) -> (x,y)) repVarIndexes s v (amap (const 0) h, addConstant (-1) h)
  where
    -- For each alternative (disjunct) of an access's BK', keeps only the
    -- entries for variables that occur in the access expression or are
    -- replicator variables, and joins them into one problem.  Alternatives
    -- that end up with no constraints are dropped, and the first Left error
    -- is returned.
    lookupBK :: [A.Name] -> (A.Expression, [ModuloCase], BK') -> Either String
      [(EqualityProblem, InequalityProblem)]
    lookupBK reps (e,_,bk) = liftM (filter (\x ->
        (not $ null $ fst x) || (not $ null $ snd x))) $
      mapM (foldl (liftM2 accumProblem) (return ([],[])) . map snd .
        filter (\(v,_) -> v `elem` vs || v `elem` reps') . Map.toList) bk
      where
        reps' :: [Var]
        reps' = map (Var . A.Variable emptyMeta) reps
        -- Every variable mentioned anywhere in the access expression.
        vs :: [Var]
        vs = map Var $ listify (const True :: A.Variable -> Bool) e
-- | Sets the sub-index of the given variable in every term of a flattened
-- expression (a front-end to setIndexVar').
setIndexVar :: A.Variable -> Int -> [FlattenedExp] -> [FlattenedExp]
setIndexVar tv ti es = [setIndexVar' tv ti e | e <- es]
-- | Sets the sub-index of the specified variable throughout one term,
-- recursing into the top and bottom of modulo and division terms.  Other
-- variables and constants are returned unchanged.
setIndexVar' :: A.Variable -> Int -> FlattenedExp -> FlattenedExp
setIndexVar' tv ti fe = case fe of
    Scale n (v, _)
      | EQ == compare (A.ExprVariable emptyMeta tv) v -> Scale n (v, ti)
      | otherwise -> fe
    Modulo n top bottom -> Modulo n (recur top) (recur bottom)
    Divide n top bottom -> Divide n (recur top) (recur bottom)
    _ -> fe
  where
    recur = Set.map (setIndexVar' tv ti)
-- | Given a list of replicators (marked enabled\/disabled by a flag), the writes and reads,
-- turns them into a single list of accesses with all the relevant information. The writes and reads
-- can be grouped together because they are differentiated by the ArrayAccessType in the result
--
-- A replicator flagged True is on the mirror side.  Every occurrence of its
-- variable in the access is primed (sub-index 1), and the coefficient indexes
-- of the plain and primed variable are pushed onto the outer StateT list.
-- Only 'A.For' replicators are matched below; any other replicator form
-- would cause a pattern-match failure.
mkEq :: [((A.Name, A.Replicator), Bool)] ->
  (BK, [A.Expression], [A.Expression]) ->
  StateT [(CoeffIndex, CoeffIndex)]
    BKM
    [((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq reps (bk, ws, rs)
  -- NOTE(review): repVarEqs is passed to mkEq' but never used there.  It is
  -- still computed, and computing it runs makeSingleEq in BKM (which may add
  -- to the VarMap state), so check that before removing it.
  = do repVarEqs <- mapM (liftF makeRepVarEq) reps
       concatMapM (mkEq' repVarEqs) (ws' ++ rs')
  where
    ws' = zip (repeat AAWrite) ws
    rs' = zip (repeat AARead) rs

    -- The replicator variable with equations for its first value (from) and
    -- its last value (from + for - 1).
    makeRepVarEq :: ((A.Name, A.Replicator), Bool) -> BKM (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)
    makeRepVarEq ((varName, A.For m from for _), _)
      = do from' <- makeSingleEq id from "replication start"
           upper <- makeSingleEq id (subExprsInt (addExprsInt for from) (makeConstant m 1)) "replication count"
           return (A.Variable m varName, from', upper)

    -- Flattens one access, primes the variables of mirrored replicators, and
    -- builds its equation; the result is expected to be a Group.
    mkEq' :: [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] ->
             (ArrayAccessType, A.Expression) ->
             StateT [(CoeffIndex,CoeffIndex)]
               BKM
               [((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
    mkEq' repVarEqs (aat, e)
      = do f <- lift . lift $ flatten e
           -- foldFuncs presumably composes the per-replicator transformations.
           mirrorFunc <- liftM foldFuncs $ mapM mirrorFlaggedVar reps
           g <- lift $ makeEquation e (bk, mirrorFunc) aat (mirrorFunc f)
           case g of
             Group g' -> return g'
             _ -> throwError "Replicated group found unexpectedly"

    -- | Turns all instances of the variable from the given replicator into their primed version in the given expression
    mirrorFlaggedVar :: ((A.Name, A.Replicator),Bool) -> StateT [(CoeffIndex,CoeffIndex)] BKM ([FlattenedExp] -> [FlattenedExp])
    mirrorFlaggedVar (_,False) = return id
    mirrorFlaggedVar ((varName, A.For m from for _), True)
      = do varIndexes <- lift $ seqPair (varIndex (Scale 1 (A.ExprVariable emptyMeta var,0)), varIndex (Scale 1 (A.ExprVariable emptyMeta var,1)))
           modify (varIndexes :)
           return $ setIndexVar var 1
      where
        var = A.Variable m varName
-- | Lets the pure flattening code report failures: the message half of a
-- report becomes the 'Left' of the underlying 'Either'; the first half of the
-- report is discarded.
instance Die (ReaderT CompState (Either String)) where
  dieReport (_meta, msg) = throwError msg
-- Note that in all these functions, the divisor should always be positive!
-- | Puts an expression into a canonical form so that equivalent sums and
-- products compare equal.
--
-- A chain of calls to the default @+@ (or @*@) operator is collapsed into
-- its operands, each operand is canonicalised, the operands are sorted, and
-- the chain is rebuilt left-nested.  Any other function call just has its
-- arguments canonicalised; every other expression is returned unchanged.
canonicalise :: forall m. (CSMR m, Die m) => A.Expression -> m A.Expression
canonicalise e@(A.FunctionCall m n es)
  = do mOp <- functionOperator n
       ts <- mapM astTypeOf es
       let commutativeDefault = case mOp of
             Just op -> A.nameName n == occamDefaultOperator op ts
                          && (op == "+" || op == "*")
             Nothing -> False
       if commutativeDefault
         then do operands <- collect e
                 return $ foldl1 (\a b -> A.FunctionCall m n [a, b]) (sort operands)
         else liftM (A.FunctionCall m n) (mapM canonicalise es)
  where
    -- Gathers the operands of nested calls to the same operator n,
    -- canonicalising every operand that is not itself such a call.
    collect :: A.Expression -> m [A.Expression]
    collect (A.FunctionCall _ n' args) | n == n'
      = liftM concat (mapM collect args)
    collect other = liftM (: []) (canonicalise other)
canonicalise e = return e
-- | Flattens an integer expression into a list of summed terms.
--
-- Integer literals become 'Const' terms.  Calls to the built-in operators
-- +, -, *, the backslash REM operator and / are expanded; / is only expanded
-- when its divisor flattens to a constant.  Anything else (including
-- variable divisors) is canonicalised and kept whole as one @Scale 1 (e, 0)@
-- term.
flatten :: A.Expression -> ReaderT CompState (Either String) [FlattenedExp]
flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)]
flatten e@(A.FunctionCall m fn [lhs, rhs])
  = do mOp <- builtInOperator fn
       case mOp of
         -- lhs + rhs: the two term lists are simply concatenated
         Just "+" -> combine' (flatten lhs) (flatten rhs)
         -- lhs - rhs: lhs plus every rhs term scaled by -1
         Just "-" -> combine' (flatten lhs) (mapM (scale (-1)) =<< flatten rhs)
         -- lhs * rhs: every lhs term times every rhs term
         Just "*" -> multiplyOut' (flatten lhs) (flatten rhs)
         -- REM: a single Modulo term over the two operands' term sets
         Just "\\" -> liftM2L (Modulo 1) (flatten lhs) (flatten rhs)
         Just "/" -> do rhs' <- flatten rhs
                        case onlyConst rhs' of
                          Just _ -> liftM2L (Divide 1) (flatten lhs) (return rhs')
                          -- Can't deal with variable divisors, leave expression as-is:
                          Nothing -> do e' <- canonicalise e
                                        return [Scale 1 (e',0)]
         _ -> do e' <- canonicalise e
                 return [Scale 1 (e',0)]
  where
    -- | Applies a two-argument term constructor (Modulo or Divide) to the
    -- term sets (via makeExpSet) of two flattened operands, giving one term.
    liftM2L :: MonadError String m => (Set.Set FlattenedExp -> Set.Set FlattenedExp -> c)
      -> m [FlattenedExp] -> m [FlattenedExp] -> m [c]
    liftM2L f x y = liftM singleton $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet)

    -- | Runs both flattenings, then 'multiplyOut' on the results.
    multiplyOut' :: (Die m, CSMR m, MonadError String m ) => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp]
    multiplyOut' x y = join $ liftM2 multiplyOut x y

    -- | Multiplies two sums: each term of lhs is multiplied by each term of
    -- rhs, and the products form the resulting sum.
    multiplyOut :: forall m. (Die m, CSMR m, MonadError String m) => [FlattenedExp] -> [FlattenedExp] -> m [FlattenedExp]
    multiplyOut lhs rhs = mapM (uncurry mult) pairs
      where
        pairs = product2 (lhs,rhs)

        -- A constant factor only scales the other term; any other product
        -- is converted back to an expression and kept whole.
        mult :: FlattenedExp -> FlattenedExp -> m FlattenedExp
        mult (Const x) e = scale x e
        mult e (Const x) = scale x e
        mult lhs rhs
          = do lhs' <- backToEq lhs
               rhs' <- backToEq rhs
               e <- mulExprs lhs' rhs' >>= canonicalise
               return $ (Scale 1 (e, 0))

        -- Multiplies an expression by a constant (no-op for 1).
        backScale :: Integer -> A.Expression -> m A.Expression
        backScale 1 e = return e
        backScale n e = do t <- astTypeOf e
                           mulExprs (makeConstant' emptyMeta t n) e >>= canonicalise

        -- Converts a flattened term back into an expression.
        -- NOTE(review): only Scale terms whose index component is 0 are
        -- matched; any other Scale term fails the pattern match.
        backToEq :: FlattenedExp -> m A.Expression
        backToEq (Const c) = return $ makeConstant emptyMeta (fromInteger c)
        backToEq (Scale n (e,0)) = backScale n e
        backToEq (Modulo n t b)
          | Set.null t || Set.null b = throwError "Modulo had empty top or bottom"
          | otherwise = do t' <- mapM backToEq $ Set.toList t
                           b' <- mapM backToEq $ Set.toList b
                           t'' <- foldM1 addExprs t'
                           b'' <- foldM1 addExprs b'
                           remExprs t'' b'' >>= backScale n
        backToEq (Divide n t b)
          | Set.null t || Set.null b = throwError "Divide had empty top or bottom"
          | otherwise = do t' <- mapM backToEq $ Set.toList t
                           b' <- mapM backToEq $ Set.toList b
                           t'' <- foldM1 addExprs t'
                           b'' <- foldM1 addExprs b'
                           divExprs t'' b'' >>= backScale n

    -- | Scales a flattened expression by the given integer scaling.
    scale :: Monad m => Integer -> FlattenedExp -> m FlattenedExp
    scale sc (Const n) = return $ Const (n * sc)
    scale sc (Scale n v) = return $ Scale (n * sc) v
    scale sc (Modulo n t b) = return $ Modulo (n * sc) t b
    scale sc (Divide n t b) = return $ Divide (n * sc) t b

    -- | An easy way of applying combine to two monadic returns
    combine' :: Monad m => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp]
    combine' = liftM2 combine

    -- | Combines (adds) two flattened expressions.
    combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp]
    combine = (++)
-- Fallback: any other expression is canonicalised and kept as a single term.
flatten e = do e' <- canonicalise e
               return [Scale 1 (e',0)]
-- The "square" refers to making all equations the length of the longest
-- one, and the pair refers to pairing each in a list of array accesses (e.g.
-- [0, 5, i + 2]) into all possible pairings ([0 == 5, 0 == i + 2, 5 == i + 2])
--
-- There are two complications to this function.
--
-- Firstly, the array accesses are not actually given in a plain list, but
-- instead a list of lists. This is because for things like modulo, there are
-- groups of possible accesses that should not be paired against each other.
-- For example, you may have something like [0,x,-x] as the three possible
-- options for a modulo. You want to pair the accesses against other accesses
-- (e.g. y + 6), but not against each other. So the arguments are passed in
-- in groups: [[0,x,-x],[y + 6]] and groups are paired against each other,
-- but not against themselves. This all refers to the third argument to the
-- function. Each item is actually a triple of (item, equalities, inequalities)
-- because the modulo aspect adds additional constraints.
--
-- The other complication comes from replicated variables.
-- The first argument is a list of (plain,prime) coefficient indexes
-- that effectively labels the indexes related to replicated variables.
-- squareAndPair does two things with this information:
-- 1. It discards all equations that feature only the prime version of
--    a variable. You might have passed in the accesses as [[i],[i'],[3]].
--    (Altering the grouping would not be able to solve this particular problem)
--    The pairings generated would be [i == i', i == 3, i' == 3]. But the
--    last two are in effect identical. Therefore we drop the i' prime
--    version, because it has i' but not i. In contrast, the first item
--    (i == i') is retained because it features both i and i'.
-- 2. For every equation that features both i and i', it adds
--    the inequality "i <= i' - 1". Because all possible combinations of
--    accesses are examined, in the case of [i,i + 1,i', i' + 1], the pairing
--    will produce both "i = i' + 1" and "i + 1 = i'" so there is no need
--    to vary the inequality itself. Equations that do not feature i' get no
--    such inequality, so i itself is not restricted by it.
--
-- The background knowledge (BK) for each side of a pairing is found with the
-- first argument. Each lookup gives a list of alternatives, and every
-- combination of an alternative for one side with an alternative for the
-- other side yields its own problem. An empty list of alternatives means
-- "no background knowledge" for that side only: the other side's knowledge
-- is still used.
squareAndPair ::
  (label -> Either String [(EqualityProblem, InequalityProblem)]) ->
  (label -> labelStripped) ->
  [(CoeffIndex, CoeffIndex)] ->
  VarMap ->
  [ArrayAccess label] ->
  (EqualityConstraintEquation, EqualityConstraintEquation) ->
  Either String [((labelStripped, labelStripped), VarMap, (EqualityProblem, InequalityProblem))]
squareAndPair lookupBK strip repVars s v lh
  = concatMapM id
      [ let f ((bkEqA, bkIneqA), (bkEqB, bkIneqB))
              = ( transformPair strip strip labels
                , s
                , squareEquations
                    ( nub (bkEqA ++ bkEqB) ++ eq
                    , nub (bkIneqA ++ bkIneqB) ++ ineq
                        ++ concat (applyAll (eq, ineq) (map addExtra repVars))
                    )
                )
            -- Every combination of a BK alternative for each side:
            bk = liftM2 (curry product2) (lookupBKOrNone (fst labels))
                                         (lookupBKOrNone (snd labels))
        in bk >>* map f
      | (labels, eq, ineq) <- pairEqsAndBounds v lh
      , all (primeImpliesPlain (eq, ineq)) repVars
      ]
  where
    -- Looks up the BK for one side of a pairing.  An empty list of
    -- alternatives (no BK) becomes a single empty constraint set, so that the
    -- product with the other side's alternatives keeps that side's BK.
    lookupBKOrNone l = lookupBK l >>* (\alts -> if null alts then [([], [])] else alts)

    -- Whether the coefficient with index x is non-zero in any of the equations.
    itemPresent :: CoeffIndex -> [Array CoeffIndex Integer] -> Bool
    itemPresent x = any (\a -> arrayLookupWithDefault 0 a x /= 0)

    -- A pairing is kept unless it features the prime version of a variable
    -- without also featuring the plain version.
    primeImpliesPlain :: (EqualityProblem,InequalityProblem) -> (CoeffIndex,CoeffIndex) -> Bool
    primeImpliesPlain (eq,ineq) (plain,prime)
      = not (itemPresent prime allEqs) || itemPresent plain allEqs
      where
        allEqs = eq ++ ineq

    -- The ordering inequality prime >= plain + 1 (prime - plain - 1 >= 0),
    -- added only when the pairing features both the plain and the prime
    -- version of the variable.
    addExtra :: (CoeffIndex, CoeffIndex) -> (EqualityProblem,InequalityProblem) -> InequalityProblem
    addExtra (plain,prime) (eq, ineq)
      | itemPresent plain allEqs && itemPresent prime allEqs
        = [mapToArray $ Map.fromList [(prime,1), (plain,-1), (0, -1)]]
      | otherwise = []
      where
        allEqs = eq ++ ineq
-- | Extracts the equation of an access that is a 'Group' holding exactly one
-- item; any other access raises the given error message.
getSingleAccessItem :: MonadError String m => String -> ArrayAccess label -> m EqualityConstraintEquation
getSingleAccessItem errMsg access
  = case access of
      Group [(_, _, (equation, _, _))] -> return equation
      _ -> throwError errMsg
-- | Returns the first component of the only triple in a singleton list.
-- A list of any other length raises the given error message.
getSingleItem :: MonadError String m => String -> [(a,b,c)] -> m a
getSingleItem errMsg triples
  = case triples of
      [(firstPart, _, _)] -> return firstPart
      _ -> throwError errMsg
-- | Finds the index associated with a particular variable; either by finding an existing index
-- or allocating a new one.
--
-- 'Scale' terms are keyed by their unscaled form (@Scale 1 (e, vi)@), so
-- every multiple of the same (expression, vi) shares one index.  'Modulo'
-- and 'Divide' terms are keyed as they are.  A newly allocated index is one
-- more than the largest index currently in the map (1 for an empty map).
-- A 'Const' term has no variable, so asking for its index is an error.
varIndex :: FlattenedExp -> BKM Int
varIndex term
  = case term of
      Scale _ (e,vi) -> indexOf (Scale 1 (e,vi))
      Modulo {} -> indexOf term
      Divide {} -> indexOf term
      Const _ -> throwError "varIndex: a constant term has no variable index"
  where
    -- Looks the key up in the variable map, allocating a fresh index if absent.
    indexOf :: FlattenedExp -> BKM Int
    indexOf key
      = do st <- get
           case Map.lookup key st of
             Just ind -> return ind
             Nothing -> do let newId = 1 + maximum (0 : Map.elems st)
                           put (Map.insert key newId st)
                           return newId
-- | Pairs all possible combinations of the list of equations.
--
-- Every two distinct top-level accesses are paired with each other; a
-- 'Replicated' access takes part in those pairings through its first list of
-- items only.  Each 'Replicated' access is also paired internally: its first
-- list against its second list, and its first list against itself.  Pairing
-- two reads yields nothing.  Each remaining pairing gives the equality
-- "first index - second index = 0" together with both items' own
-- equalities and inequalities, plus the bounds applied to both indexes.
pairEqsAndBounds :: [ArrayAccess label] -> (EqualityConstraintEquation, EqualityConstraintEquation) -> [((label,label),EqualityProblem, InequalityProblem)]
pairEqsAndBounds items bounds
  = concatMap (uncurry pairEqs) (allPairs items) ++ concatMap pairRep items
  where
    pairEqs :: ArrayAccess label
            -> ArrayAccess label
            -> [((label,label),EqualityProblem, InequalityProblem)]
    pairEqs (Group xs) (Group ys) = mapMaybe (uncurry pairLabelled) (product2 (xs, ys))
    pairEqs (Replicated rA _) other = concatMap (pairEqs other) rA
    pairEqs other (Replicated rA _) = concatMap (pairEqs other) rA

    -- Used to pair the items of a single instance of PAR replication with each other
    pairRep :: ArrayAccess label -> [((label,label),EqualityProblem, InequalityProblem)]
    pairRep (Replicated rA rB) = concatMap (uncurry pairEqs) (product2 (rA, rB) ++ allPairs rA)
    pairRep _ = []

    -- Pairs two labelled items, attaching the pair of labels to the result.
    pairLabelled :: (label, ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem))
                 -> (label, ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem))
                 -> Maybe ((label,label), EqualityProblem, InequalityProblem)
    pairLabelled (lx, x, x') (ly, y, y')
      = fmap (\(eqs, ineqs) -> ((lx, ly), eqs, ineqs)) (pairUnlabelled (x, x') (y, y'))

    -- Two reads never conflict; any other combination yields a problem.
    pairUnlabelled :: (ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem))
                   -> (ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem))
                   -> Maybe (EqualityProblem, InequalityProblem)
    pairUnlabelled (AARead, _) (AARead, _) = Nothing
    pairUnlabelled (_, (ex, eqX, ineqX)) (_, (ey, eqY, ineqY))
      = Just ( arrayZipWith' 0 (-) ex ey : eqX ++ eqY
             , ineqX ++ ineqY ++ getIneqs bounds [ex, ey] )
-- | Adds two equations coefficient by coefficient, using @arrayZipWith'@
-- with a default of 0.
addEq :: EqualityConstraintEquation -> EqualityConstraintEquation -> EqualityConstraintEquation
addEq lhs rhs = arrayZipWith' 0 (+) lhs rhs
-- | Given a (low,high) bound (typically: array dimensions), and a list of
-- equations ex, forms two inequalities per equation, in this order:
--
-- * ex >= low, written as ex - low >= 0
--
-- * ex <= high, written as high - ex >= 0
getIneqs :: (EqualityConstraintEquation, EqualityConstraintEquation) -> [EqualityConstraintEquation] -> [InequalityConstraintEquation]
getIneqs (low, high) eqs
  = concat [ [eq `addEq` negLow, high `addEq` amap negate eq] | eq <- eqs ]
  where
    negLow = amap negate low
-- | Runs a computation in the state-and-reader monad, catching its failure.
--
-- On success the result is returned in a 'Right' and the computation's final
-- state is kept.  On failure the error is returned in a 'Left' and the state
-- is left as it was before the computation ran.
justState :: Error e => StateT s (ReaderT r (Either e)) a -> StateT s (ReaderT r (Either e)) (Either e a)
justState m
  = do before <- get
       env <- ask
       let (outcome, after)
             = either (\err -> (Left err, before))
                      (\(res, st) -> (Right res, st))
                      (runReaderT (runStateT m before) env)
       put after
       return outcome
-- | Given an expression, forms equations (and accompanying additional equation-sets) and returns it
--
-- The result is always a 'Group'.  The summed terms are folded into a list of
-- candidate equations, starting from one empty candidate.  Each REM or
-- division term splits every candidate into several, one per sign case
-- ('ModuloCase'), and each case carries the extra equalities and inequalities
-- that describe it.  The background knowledge is translated with the given
-- function; a failure there is kept (as a 'Left') rather than aborting.
makeEquation :: label -> (BK, [FlattenedExp] -> [FlattenedExp]) -> ArrayAccessType -> [FlattenedExp]
  -> BKM (ArrayAccess (label,[ModuloCase], BK'))
makeEquation l (bk, bkF) t summedItems
  = do eqs <- process summedItems
       bk' <- mapM (mapMapM (justState . transformBKList bkF)) bk
       let eqs' = map (transformQuad id mapToArray (map mapToArray) (map mapToArray)) eqs :: [([ModuloCase], EqualityConstraintEquation, EqualityProblem, InequalityProblem)]
       return $ Group [((l,c,bk'),t,(e0,e1,e2)) | (c,e0,e1,e2) <- eqs']
  where
    process :: [FlattenedExp] -> BKM [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
    process = foldM makeEquation' empty

    -- Adds one summed term to every candidate (cases, equation, equalities,
    -- inequalities) in the list.
    makeEquation' :: [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] ->
                     FlattenedExp ->
                     BKM
                       [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
    makeEquation' m (Const n) = return $ add (0,n) m
    makeEquation' m sc@(Scale n _) = varIndex sc >>* (\ind -> add (ind, n) m)
    makeEquation' m mod@(Modulo n top bottom)
      = do top' <- process (Set.toList top) >>* map (\(_,a,b,c) -> (a,b,c))
           top'' <- getSingleItem "Modulo or divide not allowed in the numerator of Modulo" top'
           bottom' <- process (Set.toList bottom) >>* map (\(_,a,b,c) -> (a,b,c))
           modIndex <- varIndex mod
           case onlyConst (Set.toList bottom) of
             Just bottomConst ->
               let add_x_plus_my = zipMap plus top'' . zipMap plus (Map.fromList [(modIndex, abs bottomConst)]) in
               -- Adds n*(x + my)
               let add_n_x_plus_my = zipMap plus (Map.map (*n) top'') . zipMap plus (Map.fromList [(modIndex, n * abs bottomConst)]) in
               return $
                 -- The zero option (x = 0, x REM y = 0):
                 ( map (transformQuad (++ [XZero]) id (++ [top'']) id) m)
                 ++
                 -- The top-is-positive option:
                 ( map (transformQuad (++ [XPos]) add_n_x_plus_my id (++
                     -- x >= 1
                     [zipMap plus (Map.fromList [(0,-1)]) top''
                     -- m <= 0
                     ,Map.fromList [(modIndex,-1)]
                     -- x + my + 1 - |y| <= 0
                     ,Map.map negate $ add_x_plus_my $ Map.fromList [(0,1 - abs bottomConst)]
                     -- x + my >= 0
                     ,add_x_plus_my $ Map.empty])
                   ) m) ++
                 -- The top-is-negative option:
                 ( map (transformQuad (++ [XNeg]) add_n_x_plus_my id (++
                     -- x <= -1
                     [add' (0,-1) $ Map.map negate top''
                     -- m >= 0
                     ,Map.fromList [(modIndex,1)]
                     -- x + my - 1 + |y| >= 0
                     ,add_x_plus_my $ Map.fromList [(0,abs bottomConst - 1)]
                     -- x + my <= 0
                     ,Map.map negate $ add_x_plus_my Map.empty])
                   ) m)
             _ ->
               do bottom'' <- getSingleItem "Modulo or divide not allowed in the divisor of Modulo" bottom'
                  return $
                    -- The zero option (x = 0, x REM y = 0):
                    (map (transformQuad (++ [XZero]) id (++ [top'']) id) m)
                    -- The rest:
                    ++ twinItems True True n (top'', modIndex) bottom''
                    ++ twinItems True False n (top'', modIndex) bottom''
                    ++ twinItems False True n (top'', modIndex) bottom''
                    ++ twinItems False False n (top'', modIndex) bottom''
      where
        -- Each pair for modulo (variable divisor) depending on signs of x and y (in x REM y):
        twinItems :: Bool -> Bool -> Integer -> (Map.Map Int Integer,Int) -> Map.Map Int Integer ->
                     [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
        twinItems xPos yPos n (top,modIndex) bottom
          = (map (transformQuad (++ [findCase xPos yPos False]) (zipMap plus $ Map.map (*n) top) id
                   (++ [xEquation]
                    ++ [xLowerBound False]
                    ++ [xUpperBound False])) m)
            ++ (map (transformQuad (++ [findCase xPos yPos True]) (zipMap plus (Map.map (*n) top) . add' (modIndex, n)) id
                   (++ [xEquation]
                    ++ [xLowerBound True]
                    ++ [xUpperBound True]
                    -- We want to add the bounds for a and y as follows:
                    -- xPos yPos | Equation
                    -- T    T    | -y - a >= 0
                    -- T    F    | y - a >= 0
                    -- F    T    | a - y >= 0
                    -- F    F    | a + y >= 0
                    -- Therefore the sign of a is (not xPos), the sign of y is (not yPos)
                    ++ [add' (modIndex,if xPos then -1 else 1) (signEq (not yPos) bottom)])) m)
          where
            -- x >= 1 or x <= -1 (rearranged: -1 + x >= 0 or -1 - x >= 0)
            xEquation = add' (0,-1) (signEq xPos top)
            -- We include (x [+ a] >= 0 or x [+ a] <= 0) even though they are redundant in some cases (addA = False):
            xLowerBound addA = signEq xPos $ (if addA then add' (modIndex,1) else id) top
            -- We want to add the bounds as follows:
            -- xPos yPos | Equation
            -- T    T    | y - 1 - x - a >= 0
            -- T    F    | -y - 1 - x - a >= 0
            -- F    T    | x + a - 1 + y >= 0
            -- F    F    | x + a - y - 1 >= 0
            -- Therefore the sign of y in the equation is yPos, the sign of x and a is (not xPos)
            xUpperBound addA = add' (0,-1) $ zipMap plus (signEq (not xPos) ((if addA then add' (modIndex,1) else id) top)) (signEq yPos bottom)
            signEq sign eq = if sign then eq else Map.map negate eq
            findCase xPos yPos aNonZero = case (xPos, yPos, aNonZero) of
              (True , True , True ) -> XPosYPosANonZero
              (True , True , False) -> XPosYPosAZero
              (True , False, True ) -> XPosYNegANonZero
              (True , False, False) -> XPosYNegAZero
              (False, True , True ) -> XNegYPosANonZero
              (False, True , False) -> XNegYPosAZero
              (False, False, True ) -> XNegYNegANonZero
              (False, False, False) -> XNegYNegAZero
    makeEquation' m div@(Divide n top bottom)
      = do top' <- process (Set.toList top) >>* map (\(_,a,b,c) -> (a,b,c))
           top'' <- getSingleItem "Modulo or Divide not allowed in the numerator of Divide" top'
           bottom' <- process (Set.toList bottom) >>* map (\(_,a,b,c) -> (a,b,c))
           divIndex <- varIndex div
           case onlyConst (Set.toList bottom) of
             Just bottomConst ->
               -- In the comments below, m stands for the new quotient variable
               -- (x / y) and x - my for the remainder.  occam's division
               -- truncates towards zero, so the remainder always has the sign
               -- of x and a magnitude below |y|, whatever the sign of y.
               let add_m :: Map.Map Int Integer -> Map.Map Int Integer
                   add_m = zipMap plus (Map.fromList [(divIndex,n)])
                   -- Adds x - my (plus its argument):
                   add_x_minus_my = zipMap plus top'' . zipMap plus (Map.fromList [(divIndex,-bottomConst)])
               in return $
                    -- The zero option (x = 0, x / y = 0):
                    ( map (transformQuad (++ [XZero]) id (++ [top'']) id) m)
                    ++
                    -- The top-is-positive option:
                    ( map (transformQuad (++ [XPos]) add_m id (++
                        -- x >= 1
                        [zipMap plus (Map.fromList [(0,-1)]) top''
                        -- m >= 0 if y positive
                        -- m <= 0 (i.e. -m >= 0) if y negative
                        ,Map.fromList [(divIndex, signum bottomConst)]
                        -- x - my + 1 - |y| <= 0 (remainder at most |y| - 1)
                        ,Map.map negate $ add_x_minus_my $ Map.fromList [(0,1 - abs bottomConst)]
                        -- x - my >= 0 (remainder non-negative)
                        ,add_x_minus_my Map.empty])
                      ) m) ++
                    -- The top-is-negative option:
                    ( map (transformQuad (++ [XNeg]) add_m id (++
                        -- x <= -1
                        [add' (0,-1) $ Map.map negate top''
                        -- m <= 0 if y positive
                        -- m >= 0 if y negative
                        ,Map.fromList [(divIndex, - signum bottomConst)]
                        -- x - my - 1 + |y| >= 0 (remainder at least 1 - |y|)
                        ,add_x_minus_my $ Map.fromList [(0,abs bottomConst - 1)]
                        -- x - my <= 0 (remainder non-positive)
                        ,Map.map negate $ add_x_minus_my Map.empty])
                      ) m)
             _ -> throwError "Variables in divisor not supported by usage checker"

    -- The initial list: a single candidate with no cases and no constraints.
    empty :: [([ModuloCase],Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
    empty = [([],Map.empty,[],[])]

    -- Addition for zipMap, treating a missing coefficient as 0.
    plus :: Num n => Maybe n -> Maybe n -> Maybe n
    plus x y = Just $ (fromMaybe 0 x) + (fromMaybe 0 y)

    -- Adds n to the coefficient with index m of a single equation.
    add' :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer
    add' (m,n) = Map.insertWith (+) m n

    -- Adds n to the coefficient with index m of every candidate's equation.
    add :: (Int,Integer) -> [(z,Map.Map Int Integer,a,b)] -> [(z,Map.Map Int Integer,a,b)]
    add (m,n) = map $ (\(a,b,c,d) -> (a,(Map.insertWith (+) m n) b,c,d))
-- | Converts a map to an array indexed from 0 up to the largest key (just
-- index 0 for an empty map).  Any index with no key in the map gets the value
-- zero.  Keys are presumably non-negative; a negative key would fall outside
-- the array bounds.
-- Could probably be moved to Utils
mapToArray :: (IArray a v, Num v, Num k, Ord k, Ix k) => Map.Map k v -> a k v
mapToArray m = accumArray (+) 0 (0, upper) (Map.toList m)
  where
    upper = maximum (0 : Map.keys m)
-- | Given a pair of equation sets, resizes every equation in both lists to span
-- indices 0 up to the highest index used by any of them.  All missing
-- elements are given the value zero.
squareEquations :: ([Array CoeffIndex Integer],[Array CoeffIndex Integer]) -> ([Array CoeffIndex Integer],[Array CoeffIndex Integer])
squareEquations (eqs, ineqs) = (map resize eqs, map resize ineqs)
  where
    resize = makeArraySize (0, highest) 0
    highest = maximum (0 : concatMap indices (eqs ++ ineqs))