1074 lines
56 KiB
Haskell
1074 lines
56 KiB
Haskell
{-
|
|
Tock: a compiler for parallel languages
|
|
Copyright (C) 2007--2009 University of Kent
|
|
|
|
This program is free software; you can redistribute it and/or modify it
|
|
under the terms of the GNU General Public License as published by the
|
|
Free Software Foundation, either version 2 of the License, or (at your
|
|
option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful, but
|
|
WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License along
|
|
with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
-}
|
|
|
|
module ArrayUsageCheck (
|
|
BackgroundKnowledge(..),
|
|
BK,
|
|
canonicalise,
|
|
checkArrayUsage,
|
|
findRepSolutions,
|
|
FlattenedExp(..),
|
|
fmapFlattenedExp,
|
|
makeEquations,
|
|
makeExpSet,
|
|
ModuloCase(..),
|
|
onlyConst,
|
|
showFlattenedExp,
|
|
VarMap) where
|
|
|
|
import Control.Monad.Error
|
|
import Control.Monad.Reader
|
|
import Control.Monad.State
|
|
import Data.Array.IArray
|
|
import qualified Data.Foldable as F
|
|
import Data.Generics hiding (GT)
|
|
import Data.Int
|
|
import Data.List
|
|
import qualified Data.Map as Map
|
|
import Data.Maybe
|
|
import qualified Data.Set as Set
|
|
import qualified Data.Traversable as T
|
|
|
|
import qualified AST as A
|
|
import CompState
|
|
import Errors
|
|
import Metadata
|
|
import Omega
|
|
import OrdAST()
|
|
import Pass
|
|
import ShowCode
|
|
import Types
|
|
import UsageCheckUtils
|
|
import Utils
|
|
|
|
-- | Background knowledge attached to an access.  The outer list is a
-- disjunction of alternatives; each alternative maps a variable to the
-- conjunction of constraints known about it.
type BK = [Map.Map Var [BackgroundKnowledge]]

-- | Same disjunction-of-maps shape as 'BK', but each variable maps to either a
-- 'String' or an (equalities, inequalities) pair.
-- NOTE(review): presumably the 'String' is an error produced while turning the
-- knowledge into a problem -- confirm against the producer of this type.
type BK' = [Map.Map Var (Either String (EqualityProblem, InequalityProblem))]
|
|
|
|
-- | Given a list of replicators, and a set of background knowledge for each
-- access inside the replicator, checks whether the normal replicator
-- constraints combined with the given background knowledge have any solution
-- (each set is paired against each other set, one applied to the replicator
-- and the other to the mirror of the replicator).
--
-- Returns 'Nothing' if there are no solutions, or a 'String' holding one
-- formatted counter-example per solvable problem if there are.
findRepSolutions :: (CSMR m, MonadIO m) => [(A.Name, A.Replicator)] -> [BK] -> m (Maybe String)
findRepSolutions reps bks = do
    cs <- getCompState
    case runReaderT (makeEquations dummyAccesses maxInt) cs of
      Right problems -> do
        probs <- formatProblems [(vm, prob) | (_, vm, prob) <- problems]
        debug $ "Problems in findRepSolutions:\n" ++ probs
        -- Number every problem, keeping only those that have a solution.
        case [(i, sol) | (i, Just sol) <- zip [(0 :: Integer) ..] (map solve problems)] of
          [] -> return Nothing -- No solutions, safe
          found -> liftM (Just . unlines) $ mapM format found
      _ -> error "Unexpected reachability result"
  where
    -- To get the right comparison, we create a single SeqItems holding all the
    -- accesses and wrap it in every replicator.  Because they sit inside a PAR
    -- replicator, the accesses all get compared to each other with one set of
    -- BK applied to i and one applied to i', but never just with the
    -- constraints on i (which is not what we are checking here).  The dummy
    -- array index is always zero, so the accesses can overlap -- but only if
    -- the replicator background knowledge also has a solution, which is what
    -- this function is trying to determine.
    dummyAccesses =
      foldl (flip RepParItem)
            (SeqItems [(bk, [makeConstant emptyMeta 0], []) | bk <- bks])
            reps

    -- Upper array bound used for the dummy accesses: the largest Int32.
    maxInt = makeConstant emptyMeta $ fromIntegral (maxBound :: Int32)

    -- Prefixes a formatted solution with the number of its problem.
    format (i, ((_, _), varMapping, vm, _)) =
      liftM (("#" ++ show i ++ ": ") ++) (formatSolution varMapping vm)
|
|
|
|
-- | A check-pass that checks the given ParItems (usually generated from a
-- control-flow graph) for any overlapping array indices.
--
-- The accesses are grouped per (array name, direction); each retained group is
-- turned into Omega-test problems via 'makeEquations', and the pass dies with
-- a description of the first solvable problem (i.e. a possible overlap).
checkArrayUsage :: forall m. (Die m, CSMR m, MonadIO m) => NameAttr -> (Meta, ParItems (BK, UsageLabel)) -> m ()
checkArrayUsage sharedAttr (m, p) = do
    debug $ "checkArrayUsage: " ++ show m
    indexes <- groupArrayIndexes $ fmap (transformPair id nodeVars) p
    -- NOTE(review): this predicate keeps only the arrays whose ParItems
    -- flatten to at most one (BK, writes, reads) entry; the mapped function is
    -- never inspected by 'length'.  Confirm this is the intended selection
    -- (as opposed to e.g. counting the accesses themselves).
    let candidates =
          Map.filter ((<= 1) . length . map (\(_, w, r) -> w ++ r) . F.toList) indexes
    mapM_ (checkIndexes m) (Map.toList candidates)
  where
    -- The name brought into scope by the node's declaration, if the node has a
    -- declaration and it is a ScopeIn.
    getDecl :: UsageLabel -> Maybe String
    getDecl label = nodeDecl label >>= scopedInName
      where
        scopedInName (ScopeIn _ n) = Just n
        scopedInName _ = Nothing

    -- Takes a ParItems Vars, and returns a map from (array name, direction) to
    -- the ParItems of (background knowledge, written-to indexes, read-from
    -- indexes) for that array.
    groupArrayIndexes :: ParItems (BK, Vars) -> m (Map.Map (String, Maybe A.Direction) (ParItems (BK, [A.Expression], [A.Expression])))
    groupArrayIndexes items = liftM filterByKey (T.mapM perItem items)
      where
        perItem (bk, vs) = do
          written <- makeList $ Map.keysSet (writtenVars vs) `Set.union` usedVars vs
          readFrom <- makeList $ readVars vs
          return $ zipMap (bundle bk) written readFrom

        -- A side that is missing from one of the maps becomes an empty list.
        bundle :: b -> Maybe [a] -> Maybe [a] -> Maybe (b, [a], [a])
        bundle k x y = Just (k, fromMaybe [] x, fromMaybe [] y)

    -- Turns a set of variables into a map (from array-name to list of
    -- index-expressions).
    makeList :: Set.Set Var -> m (Map.Map (String, Maybe A.Direction) [A.Expression])
    makeList vs = liftM (Map.fromListWith (++)) (concatMapM getArrayIndex (Set.toList vs))

    -- Lifts a map (from array-name to expression-lists) inside a ParItems to
    -- being a map (from array-name to ParItems of expression lists).  Items
    -- that do not mention a given array contribute an all-empty triple.
    filterByKey :: ParItems (Map.Map (String, Maybe A.Direction) (BK, [A.Expression], [A.Expression]))
      -> Map.Map (String, Maybe A.Direction) (ParItems (BK, [A.Expression], [A.Expression]))
    filterByKey items =
      Map.fromList
        [ (k, fmap (Map.findWithDefault ([], [], []) k) items)
        | k <- concatMap Map.keys (flattenParItems items)
        ]

    -- Gets the (array-name, indexes) from a Var.
    -- TODO this is quite hacky, and doesn't yet deal with slices and so on:
    getArrayIndex :: Var -> m [((String, Maybe A.Direction), [A.Expression])]
    getArrayIndex (Var v@(A.SubscriptedVariable _ (A.Subscript _ _ e) (A.Variable _ n)))
      = do t <- astTypeOf v
           return [((A.nameName n, d), [e]) | d <- directionsFor t]
      where
        -- An undirected channel element is recorded under both directions.
        directionsFor (A.Chan {}) = [Just A.DirInput, Just A.DirOutput]
        directionsFor _ = [Nothing]
    getArrayIndex (Var (A.SubscriptedVariable _ (A.Subscript _ _ e)
                         (A.DirectedVariable _ dir (A.Variable _ n))))
      = return [((A.nameName n, Just dir), [e])]
    getArrayIndex (Var (A.DirectedVariable _ dir (A.SubscriptedVariable _ (A.Subscript _ _ e)
                         (A.Variable _ n))))
      = return [((A.nameName n, Just dir), [e])]
    getArrayIndex _ = return []

    -- Checks the given ParItems of writes and reads against each other.  The
    -- String (array-name) and Meta are only used for printing out error
    -- messages.  Arrays carrying the shared attribute, and arrays declared by
    -- one of the nodes under inspection, are not checked.
    checkIndexes :: Meta -> ((String, Maybe A.Direction), ParItems (BK, [A.Expression], [A.Expression])) -> m ()
    checkIndexes meta ((arrName, _), indexes) = do
      sharedNames <- getCompState >>* csNameAttr
      let declNames = [x | Just x <- fmap (getDecl . snd) $ flattenParItems p]
          isShared = fmap (Set.member sharedAttr) (Map.lookup arrName sharedNames) == Just True
      unless (isShared || arrName `elem` declNames) $ do
        userArrName <- getRealName (A.Name undefined arrName)
        arrType <- astTypeOf (A.Name undefined arrName) >>= resolveUserType meta
        arrLength <- case arrType of
          A.Array (A.Dimension d : _) _ -> return d
          -- Unknown dimension, use the maximum value for a (assumed 32-bit for INT) integer:
          A.Array (A.UnknownDimension : _) _ ->
            return $ makeConstant meta $ fromIntegral (maxBound :: Int32)
          -- It's not an array:
          _ -> dieP meta $ "Cannot usage check array \"" ++ userArrName ++ "\"; found to be of type: " ++ show arrType
        cs <- getCompState
        case runReaderT (makeEquations indexes arrLength) cs of
          Left err -> dieP meta $ "Could not work with array indexes for array \"" ++ userArrName ++ "\": " ++ err
          Right [] -> return () -- No problems to work with
          Right problems -> do
            probs <- formatProblems [(vm, prob) | (_, vm, prob) <- problems]
            debug $ "Problems in checkArrayUsage" ++ show meta ++ ":\n" ++ probs
            case mapMaybe solve problems of
              -- No solutions; no worries!
              [] -> return ()
              (((lx, ly), varMapping, vm, _) : _) -> do
                sol <- formatSolution varMapping vm
                cx <- showCode (fst lx)
                cy <- showCode (fst ly)
                -- liftIO $ putStrLn $ "Found solution for problem: " ++ probs
                --   ++ show p
                -- liftIO $ putStrLn $ "Succeeded on problem: " ++ prob
                -- allProbs <- concatMapM (\(_,_,p) -> formatProblem varMapping p >>* (++ "\n#\n")) problems
                -- svm <- mapM (showFlattenedExp showCode) $ Map.keys varMapping
                -- liftIO $ putStrLn $ "All problems: " ++ allProbs ++ "\n" ++ (concat $ intersperse " ; " $ svm)
                dieP meta $ "Indexes of array \"" ++ userArrName ++ "\" "
                  ++ "(\"" ++ cx ++ "\" and \"" ++ cy ++ "\") could overlap"
                  ++ if null sol then "" else " when: " ++ sol

    -- The original (user-visible) name recorded for the given name.
    -- TODO this is surely defined elsewhere already?
    getRealName :: A.Name -> m String
    getRealName n = liftM A.ndOrigName (lookupName n)
|
|
|
|
-- | Formats a list of problems for debug output.  Each problem is rendered by
-- 'formatProblem'; its first line is prefixed with @#@ and the problem's
-- position (counting from zero), and every following line is indented.
formatProblems :: CSMR m => [(VarMap, (EqualityProblem, InequalityProblem))] -> m String
formatProblems probs =
    liftM (concat . zipWith addNum [0 ..] . map lines) $
      mapM (uncurry formatProblem) probs
  where
    -- Numbers the first line of a rendered problem and indents the rest.
    addNum :: Int -> [String] -> String
    addNum _ [] = ""
    addNum i (firstLine : rest) = unlines (header : map (" " ++) rest)
      where
        idx = show i
        -- Single-digit indexes get an extra space in front of the colon.
        sep = if length idx == 1 then " :" else ":"
        header = "#" ++ idx ++ sep ++ firstLine
|
|
|
|
-- | Formats an entire problem ready to print it out half-legibly for debugging
-- purposes.  Equalities are printed first (one per line, using @=@), followed
-- by the inequalities; the constant term (index 0) is moved to the right-hand
-- side of each line.
formatProblem :: forall m. CSMR m => VarMap -> (EqualityProblem, InequalityProblem) -> m String
formatProblem varToIndex (eq, ineq) = do
    eqLines <- mapM (showWithConst "=") eq
    ineqLines <- mapM showIneq ineq
    return $ unlines (eqLines ++ eqLinesEnd ineqLines)
  where
    eqLinesEnd = id

    -- An inequality with no positive variable coefficient is negated and
    -- shown with "<=" rather than ">=".
    showIneq :: Array CoeffIndex Integer -> m String
    showIneq e
      | allNegative e = showWithConst "<=" (negateAll e)
      | otherwise = showWithConst ">=" e

    -- Returns true if none of the variable coefficients are positive
    -- (ignoring the constant term)
    allNegative :: Array CoeffIndex Integer -> Bool
    allNegative coeffs = all (<= 0) (tail (elems coeffs))

    negateAll :: Array CoeffIndex Integer -> Array CoeffIndex Integer
    negateAll = amap negate

    -- Shows the variable terms, the operator, and the negated constant term.
    showWithConst :: String -> Array CoeffIndex Integer -> m String
    showWithConst op item = do
      text <- showEq item
      let lhs = if text == "" then "0" else text
      return $ lhs ++ " " ++ op ++ " " ++ show (negate $ item ! 0)

    -- Shows every variable term that has a non-zero coefficient.
    showEq :: Array CoeffIndex Integer -> m String
    showEq item =
      liftM (joinWith " + ") $
        mapM showItem [term | term@(_, coeff) <- tail (assocs item), coeff /= 0]

    -- Shows one coefficient together with the expression registered for its
    -- index in the VarMap.
    showItem :: (CoeffIndex, Integer) -> m String
    showItem (n, a) =
      case [flat | (flat, ix) <- Map.assocs varToIndex, ix == n] of
        (flat : _) -> showFlattenedExp showCode flat >>* (coeffPrefix ++)
        [] -> return "<unknown>"
      where
        coeffPrefix
          | a == 1 = ""
          | a == (-1) = "-"
          | otherwise = show a ++ "*"
|
|
|
|
-- | Solves the problem and munges the arguments and results into a useful
-- order: the labels and VarMap are passed through unchanged, and the mapping
-- returned by 'solveProblem' is inserted before the original problem.
-- Gives 'Nothing' when 'solveProblem' does.
solve :: (labels,vm,(EqualityProblem,InequalityProblem)) ->
  Maybe (labels,vm,VariableMapping,(EqualityProblem,InequalityProblem))
solve (ls, vm, problem@(eq, ineq)) =
  fmap (\solution -> (ls, vm, solution, problem)) (solveProblem eq ineq)
|
|
|
|
-- | Formats a solution (not a problem, just the solution) ready to print it
-- out for the user.  Every expression in the VarMap that 'getCounterEqs'
-- reports on is shown either as an exact value or with its lower and upper
-- bounds; the pieces are joined with @" , "@.
formatSolution :: (CSMR m, Monad m) => VarMap -> VariableMapping -> m String
formatSolution varToIndex vm =
    liftM (joinWith " , " . catMaybes) $ mapM valOfVar (Map.assocs varToIndex)
  where
    -- Reverse lookup: coefficient index to the expression it stands for.
    indexToVar i = lookup i (map revPair (Map.assocs varToIndex))

    -- Index 0 is the constant term; other terms with a zero coefficient, or
    -- with an index missing from the VarMap, are dropped.
    indexToVar' (0, x) = Just (Nothing, x)
    indexToVar' (_, 0) = Nothing
    indexToVar' (i, x) = fmap (\v -> (Just v, x)) (indexToVar i)

    indexToConst = getCounterEqs vm

    -- Shows a bare constant, or an expression with its coefficient.
    showWithCoeff' (Nothing, n) = return $ show n
    showWithCoeff' (Just v, n) = liftM (prefix ++) $ showFlattenedExp showCode v
      where
        prefix
          | n == 1 = ""
          | n == (-1) = "-"
          | otherwise = show n ++ "*"

    showWithCoeff xs = liftM (joinWith " + ") $ mapM showWithCoeff' xs

    valOfVar (varExp, k) = case Map.lookup k indexToConst of
      Nothing -> return Nothing
      -- A bounded value: show "low <= coeff*var <= high".
      Just (Left (n, low, high)) -> do
        varExp' <- showWithCoeff' (Just varExp, n)
        low' <- mapM (showWithCoeff . mapMaybe indexToVar') low
        high' <- mapM (showWithCoeff . mapMaybe indexToVar') high
        return $ Just $ formatBounds (++ " <= ") low'
                        ++ varExp' ++ formatBounds (" <= " ++) high'
      -- An exact value.
      Just (Right val) -> do
        varExp' <- showFlattenedExp showCode varExp
        return $ Just $ varExp' ++ " = " ++ show val

    -- No bounds: nothing shown.  One bound: shown bare.  Several bounds:
    -- shown as a parenthesised, comma-separated list.
    formatBounds _ [] = ""
    formatBounds f [b] = f b
    formatBounds f bs = f $ "(" ++ joinWith "," bs ++ ")"
|
|
|
|
-- | Shows every item of the set (in the set's ascending order) with
-- 'showFlattenedExp', separating the results with @" + "@.
showFlattenedExpSet :: Monad m => (A.Expression -> m String) -> Set.Set FlattenedExp -> m String
showFlattenedExpSet showExp s =
  liftM (concat . intersperse " + ") $
    mapM (showFlattenedExp showExp) (Set.toList s)
|
|
|
|
-- Shows a FlattenedExp legibly, using the supplied function to show the
-- expressions it contains.  A primed (sub-indexed) variable gets one
-- apostrophe appended per unit of its sub-index.
-- The output for things involving modulo might be a bit odd, but there isn't
-- really anything much that can be done about that
showFlattenedExp :: Monad m => (A.Expression -> m String) -> FlattenedExp -> m String
showFlattenedExp showExp flat = case flat of
  Const n -> return $ show n
  Scale n (e, vi) ->
    liftM (\name -> showScale (name ++ replicate vi '\'') n) (showExp e)
  Modulo n top bottom -> do
    top' <- showFlattenedExpSet showExp top
    bottom' <- showFlattenedExpSet showExp bottom
    return $ case onlyConst (Set.toList bottom) of
      -- Constant divisor: shown as a division with the negated coefficient.
      Just _ -> showScale ("(" ++ top' ++ " / " ++ bottom' ++ ")") (-n)
      -- Variable divisor: shown as (top REM bottom) - top.
      Nothing -> showScale ("((" ++ top' ++ " REM " ++ bottom' ++ ") - " ++ top' ++ ")") n
  Divide n top bottom -> do
    top' <- showFlattenedExpSet showExp top
    bottom' <- showFlattenedExpSet showExp bottom
    return $ showScale ("(" ++ top' ++ " / " ++ bottom' ++ ")") n
|
|
|
|
-- | Prefixes an already-shown term with its coefficient: a coefficient of 1
-- is omitted, -1 becomes a leading minus sign, and anything else is shown as
-- @coefficient*term@.
showScale :: String -> Integer -> String
showScale s n
  | n == 1 = s
  | n == (-1) = "-" ++ s
  | otherwise = show n ++ "*" ++ s
|
|
|
|
-- | A type for inside makeEquations: a single term of a flattened expression.
data FlattenedExp
  = Const Integer
    -- ^ A constant
  | Scale Integer (A.Expression, Int)
    -- ^ A variable and coefficient.  The first argument is the coefficient.
    -- The second part of the pair is for sub-indexing (or "priming")
    -- variables.  For example, replication is done by checking the replicated
    -- variable "i" against a sub-indexed (with "1") version (denoted "i'").
    -- The sub-index is what differentiates i from i', given that they are
    -- technically the same A.Variable
  | Modulo Integer (Set.Set FlattenedExp) (Set.Set FlattenedExp)
    -- ^ A modulo, with a coefficient\/scale and given top and bottom (in that
    -- order)
  | Divide Integer (Set.Set FlattenedExp) (Set.Set FlattenedExp)
    -- ^ An integer division, with a coefficient\/scale and the given top and
    -- bottom (in that order)
|
|
|
|
-- | Equality is defined through the 'Ord' instance, so it ignores exactly the
-- same parts of a term that 'compare' ignores.
instance Eq FlattenedExp where
  lhs == rhs = compare lhs rhs == EQ
|
|
|
|
-- | A straightforward comparison for FlattenedExp that compares while ignoring
-- the value of a const @(Const 3 == Const 5)@ and the value of a scale
-- @(Scale 1 (v,0)) == (Scale 3 (v,0))@, although note that
-- @(Scale 1 (v,0)) \/= (Scale 1 (v,1))@.  The coefficient of a modulo or
-- divide is ignored in the same way.
--
-- Terms of different kinds are ordered Const, Scale, Modulo, Divide.
instance Ord FlattenedExp where
  compare lhs rhs = case (lhs, rhs) of
    (Const _, Const _) -> EQ
    (Const _, _) -> LT
    (_, Const _) -> GT
    (Scale _ (lv, li), Scale _ (rv, ri)) ->
      combineCompare [compare lv rv, compare li ri]
    (Scale {}, _) -> LT
    (_, Scale {}) -> GT
    (Modulo _ ltop lbottom, Modulo _ rtop rbottom) ->
      combineCompare [compare ltop rtop, compare lbottom rbottom]
    (Modulo {}, _) -> LT
    (_, Modulo {}) -> GT
    -- Only a pair of divisions can remain at this point.
    (Divide _ ltop lbottom, Divide _ rtop rbottom) ->
      combineCompare [compare ltop rtop, compare lbottom rbottom]
|
|
|
|
-- | Checks if an expression list contains only constants.  Returns Just (the
-- aggregate constant, i.e. the sum, which is 0 for an empty list) if so,
-- otherwise returns Nothing.
onlyConst :: [FlattenedExp] -> Maybe Integer
onlyConst = foldr step (Just 0)
  where
    -- Any non-constant term makes the whole result Nothing.
    step (Const n) rest = fmap (n +) rest
    step _ _ = Nothing
|
|
|
|
-- | Applies the given function to every expression held inside a
-- 'FlattenedExp', descending into the top and bottom sets of modulo and
-- divide items.  Constants are returned unchanged.
fmapFlattenedExp :: (A.Expression -> A.Expression) -> FlattenedExp -> FlattenedExp
fmapFlattenedExp f flat = case flat of
    Const _ -> flat
    Scale n (e, i) -> Scale n (f e, i)
    Modulo n top bottom -> Modulo n (overSet top) (overSet bottom)
    Divide n top bottom -> Divide n (overSet top) (overSet bottom)
  where
    overSet = Set.map (fmapFlattenedExp f)
|
|
|
|
-- | A data type representing an array access.  Each triple is (index,
-- extra-equalities, extra-inequalities).
--
-- Each item of a Group cannot be paired with each other, but can be paired
-- with each other access.
-- With a Replicated, each item in the left branch can be paired with each item
-- in the right branch.  Each item in the left branch can be paired with each
-- other, and each item in the left branch can be paired with all other items.
--
-- NOTE(review): the pairing rules above are carried over from the previous
-- comment (which also described a "Single" item that has no constructor
-- here); confirm them against the code that consumes this type.
data ArrayAccess label
  = Group [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
  | Replicated [ArrayAccess label] [ArrayAccess label]
|
|
|
|
-- | Marks whether an array access writes to the array ('AAWrite') or reads
-- from it ('AARead').
data ArrayAccessType
  = AAWrite
  | AARead
|
|
|
|
-- | Transforms the ParItems (from the control-flow graph) into the more
-- suitable ArrayAccess data type used by this array usage checker.
--
-- The supplied function turns one item into its accesses; it is given the
-- enclosing replicators, each flagged 'False' when building the normal side
-- and 'True' when building the mirror side.  The result pairs the accesses
-- with the names of all replicators encountered.
parItemToArrayAccessM :: Monad m =>
  ( [((A.Name, A.Replicator), Bool)] ->
    a ->
    m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
  ) ->
  ParItems a ->
  m ([ArrayAccess label], [A.Name])
parItemToArrayAccessM f items = case items of
  -- All the sequential items together form a single group:
  SeqItems xs -> do
    accs <- concatMapM (f []) xs
    return ([Group accs], [])
  -- Parallel branches are converted independently and concatenated.
  ParItems ps -> do
    (accs, names) <- liftM unzip $ mapM (parItemToArrayAccessM f) ps
    return (concat accs, concat names)
  -- A replicator: convert the body twice, once as the normal side and once as
  -- the mirror side.  Only the first pass's replicator names are kept.
  RepParItem rep body -> do
    (normal, otherReps) <- parItemToArrayAccessM (\reps -> f ((rep, False) : reps)) body
    (mirror, _) <- parItemToArrayAccessM (\reps -> f ((rep, True) : reps)) body
    return ([Replicated normal mirror], fst rep : otherReps)
|
|
|
|
-- | Turns a list of expressions (which may contain many constants, or
-- duplicated variables) into a set of expressions with at most one constant
-- term, and at most one appearance of any variable, or distinct
-- modulo\/division of variables.  Constants are summed together, and the
-- coefficients of repeated (identically sub-indexed) variables are summed.
-- If a modulo or division item is repeated, an error is returned instead.
makeExpSet :: forall m. MonadError String m => [FlattenedExp] -> m (Set.Set FlattenedExp)
makeExpSet = foldM makeExpSet' Set.empty
  where
    -- Adds one term to the set accumulated so far.
    makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> m (Set.Set FlattenedExp)
    makeExpSet' accum (Const n) = return $ insert (addConst n) (Const n) accum
    makeExpSet' accum (Scale n v) = return $ insert (addScale n v) (Scale n v) accum
    makeExpSet' accum m@(Modulo {})
      | Set.member m accum = throwError "Cannot have repeated REM items in an expression"
      | otherwise = return $ Set.insert m accum
    makeExpSet' accum d@(Divide {})
      | Set.member d accum = throwError "Cannot have repeated (/) items in an expression"
      | otherwise = return $ Set.insert d accum

    -- Rebuilds the set, giving the merge function a chance to combine the new
    -- term with each existing element.  If no element was merged, the new term
    -- is simply inserted.
    insert :: (FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)) -> FlattenedExp -> Set.Set FlattenedExp -> Set.Set FlattenedExp
    insert f e s = case Set.fold insert' (Set.empty, False) s of
        (s', True) -> s'
        _ -> Set.insert e s
      where
        insert' :: FlattenedExp -> (Set.Set FlattenedExp, Bool) -> (Set.Set FlattenedExp, Bool)
        insert' e (s, b) = case f e s of
          Just s' -> (s', True)
          -- Keep the flag from the elements already folded: resetting it here
          -- would forget an earlier successful merge, and the fallback
          -- Set.insert above would then overwrite the coefficient of the
          -- matching element (which compares equal) instead of summing it.
          Nothing -> (Set.insert e s, b)

    -- Merges the new constant into an existing constant.
    addConst :: Integer -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)
    addConst x (Const n) s = Just $ Set.insert (Const (n + x)) s
    addConst _ _ _ = Nothing

    -- Merges the new coefficient into an existing term for the same variable
    -- with the same sub-index.
    addScale :: Integer -> (A.Expression,Int) -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)
    addScale x (lv, li) (Scale n (rv, ri)) s
      | (EQ == compare lv rv) && (li == ri) = Just $ Set.insert (Scale (x + n) (rv, ri)) s
      | otherwise = Nothing
    addScale _ _ _ _ = Nothing
|
|
|
|
-- | Maps each item of a problem (a FlattenedExp, which may be a variable, or a
-- modulo\/divide item) to the index of its coefficient in the equations.
type VarMap = Map.Map FlattenedExp CoeffIndex
|
|
|
|
-- | Background knowledge about a problem: an equality, an inequality, or the
-- inclusive bounds of a replicated variable.
data BackgroundKnowledge
  = Equal A.Expression A.Expression
    -- ^ The two expressions are equal
  | LessThanOrEqual A.Expression A.Expression
    -- ^ The first expression is less than or equal to the second
  | RepBoundsIncl A.Variable A.Expression A.Expression
    -- ^ The variable lies between the two expressions (lower bound first),
    -- inclusive at both ends
  deriving (Typeable, Data)
|
|
|
|
-- | Shows the knowledge as a relation, rendering each side with 'showOccam'.
instance Show BackgroundKnowledge where
  show bk = case bk of
    Equal e e' -> concat [showOccam e, " = ", showOccam e']
    LessThanOrEqual e e' -> concat [showOccam e, " <= ", showOccam e']
    RepBoundsIncl v e e' ->
      concat [showOccam e, " <= ", showOccam v, " <= ", showOccam e']
|
|
|
|
-- | The cases considered when processing a modulo.  The names relate to the
-- equations given in my Omega Test presentation.
-- X is the top, Y is the bottom, A is the other var (x REM y = x + a)
data ModuloCase
  = XZero
  -- The next two are for a constant divisor:
  | XPos
  | XNeg
  -- All the ones below are for a variable divisor:
  | XPosYPosAZero
  | XPosYPosANonZero
  | XPosYNegAZero
  | XPosYNegANonZero
  | XNegYPosAZero
  | XNegYPosANonZero
  | XNegYNegAZero
  | XNegYNegANonZero
  deriving (Show, Eq, Ord)
|
|
|
|
-- | The monad used while building equations: the 'VarMap' is threaded as
-- state, the 'CompState' is readable, and failures are 'String' errors.
type BKM = StateT VarMap (ReaderT CompState (Either String))
|
|
|
|
-- | Transforms one piece of background knowledge into a problem.  The given
-- function is applied to each flattened expression before it is turned into
-- an equation.
--
-- An 'Equal' gives one equality, a 'LessThanOrEqual' gives one inequality, and
-- a 'RepBoundsIncl' gives four inequalities: the upper and lower bounds for
-- both the plain variable and its primed (sub-index 1) version.
-- TODO allow modulo in background knowledge
transformBK :: ([FlattenedExp] -> [FlattenedExp]) -> BackgroundKnowledge ->
  BKM (EqualityProblem,InequalityProblem)
transformBK f knowledge = case knowledge of
    Equal eL eR -> do
      eL' <- makeSingleEq f eL "background knowledge"
      eR' <- makeSingleEq f eR "background knowledge"
      return ([addEq eL' (negated eR')], [])
    LessThanOrEqual eL eR -> do
      eL' <- makeSingleEq f eL "background knowledge"
      eR' <- makeSingleEq f eR "background knowledge"
      -- eL <= eR implies eR - eL >= 0
      return ([], [addEq (negated eL') eR'])
    RepBoundsIncl v low high -> do
      eLow <- makeSingleEq f low "background knowledge, lower bound"
      eHigh <- makeSingleEq f high "background knowledge, upper bound"
      -- The equation for the variable with the given sub-index:
      let repVarEq sub =
            makeEquation v ([], id) (error "Irrelevant type")
                         [Scale 1 (A.ExprVariable emptyMeta v, sub)]
              >>= getSingleAccessItem ("Modulo or divide impossible")
      ev <- repVarEq 0
      ev' <- repVarEq 1
      -- v <= eH implies eH - v >= 0
      -- eL <= v implies v - eL >= 0
      return ([], [ addEq (negated ev) eHigh
                  , addEq (negated ev') eHigh
                  , addEq (negated eLow) ev
                  , addEq (negated eLow) ev'
                  ])
  where
    negated = amap negate
|
|
|
|
-- | Transforms every piece of background knowledge with 'transformBK' and
-- joins the resulting problems into one, starting from the empty problem.
transformBKList :: ([FlattenedExp] -> [FlattenedExp]) -> [BackgroundKnowledge] -> BKM (EqualityProblem,InequalityProblem)
transformBKList f bk =
  liftM (foldl accumProblem ([], [])) $ mapM (transformBK f) bk
|
|
|
|
-- | Turns a single expression into an equation-item.  The expression is
-- flattened, passed through the given function, and turned into an equation.
-- An error is given if the resulting expression is anything complicated (for
-- example, modulo or divide); the description is used in that error message.
makeSingleEq :: ([FlattenedExp] -> [FlattenedExp]) -> A.Expression -> String -> BKM EqualityConstraintEquation
makeSingleEq f e desc = do
    flat <- liftM f $ lift (flatten e)
    access <- makeEquation e ([{-TODO-}], f) (error $ "Type is irrelevant for " ++ desc) flat
    getSingleAccessItem complaint access
  where
    complaint = "Modulo or Divide not allowed in " ++ desc
                ++ "(while processing: " ++ showOccam e ++ ")"
|
|
|
|
-- | Joins two problems into one by handing them to 'concatPair'.
accumProblem :: (EqualityProblem,InequalityProblem) -> (EqualityProblem,InequalityProblem) -> (EqualityProblem,InequalityProblem)
accumProblem lhs rhs = concatPair lhs rhs
|
|
|
|
|
|
-- | Given the (background knowledge, written, read) index expressions and an
-- expression representing the upper array bound, returns either an error
-- (because the expressions can't be handled, typically) or a set of
-- equalities, inequalities and mapping from (unique, munged) variable name to
-- variable-index in the equations.
--
-- The general strategy is as follows.
-- For every array index (here termed an "access"), we transform it into
-- the usual @[FlattenedExp]@ using the flatten function.  Then we also
-- transform any access that is in the mirror-side of a Replicated item into
-- its mirrored version where each i is changed into i\'.  This is done by
-- using @vi=(variable "i",0)@ (in @Scale _ vi@) for the plain (normal)
-- version, and @vi=(variable "i",1)@ for the prime (mirror) version.
--
-- Then the equations have bounds added.  The rules are fairly simple; if
-- any of the transformed EqualityConstraintEquation (or related equalities or
-- inequalities) representing an access have a non-zero i (and\/or i\'), the
-- bound for that variable is added.  So for example, an expression like
-- i = i\' + 3 would have the bounds for both i and i\' added (which would be
-- near-identical, e.g. 1 <= i <= 6 and 1 <= i\' <= 6).  We have to check the
-- equalities and inequalities because when processing modulo, for the
-- i REM y == 0 option, i will not appear in the index itself (which will be 0)
-- but will appear in the surrounding constraints, and we still want to add the
-- replication bounds.
--
-- The remainder of the work (correctly pairing equations) is done by
-- squareAndPair.
--
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
makeEquations :: ParItems (BK, [A.Expression], [A.Expression]) -> A.Expression ->
  ReaderT CompState (Either String) [(((A.Expression, [ModuloCase]), (A.Expression, [ModuloCase])), VarMap, (EqualityProblem, InequalityProblem))]
makeEquations accesses bound = do
    ((accessGroups, high, repVarIndexes, allReps), varMap) <- runStateT collect Map.empty
    -- Bounds handed on: an all-zero equation and the upper bound minus one.
    lift $ squareAndPair (lookupBK allReps) dropBK repVarIndexes varMap accessGroups
             (amap (const 0) high, addConstant (-1) high)
  where
    -- Converts all the accesses (collecting the index pairs of mirrored
    -- replicator variables on the way) and then the upper bound.
    collect = do
      ((accesses', reps), repVars) <- runStateT (parItemToArrayAccessM mkEq accesses) []
      high <- makeSingleEq id bound "upper bound"
      return (accesses', high, nub repVars, reps)

    -- Drops the background knowledge from an access label.
    dropBK (x, y, _) = (x, y)

    -- For each alternative of the background knowledge, joins the problems of
    -- the variables that either appear in the access expression or are
    -- replicator variables.  Alternatives that come out empty are dropped.
    lookupBK :: [A.Name] -> (A.Expression, [ModuloCase], BK') -> Either String
      [(EqualityProblem, InequalityProblem)]
    lookupBK reps (e, _, bk) = liftM (filter nonTrivial) (mapM relevantProblem bk)
      where
        nonTrivial (eqs, ineqs) = not (null eqs && null ineqs)

        relevantProblem alternative =
          foldl (liftM2 accumProblem) (return ([], []))
            [prob | (v, prob) <- Map.toList alternative, isRelevant v]

        isRelevant v = v `elem` mentioned || v `elem` repVars

        repVars :: [Var]
        repVars = map (Var . A.Variable emptyMeta) reps

        mentioned :: [Var]
        mentioned = map Var $ listify (const True :: A.Variable -> Bool) e

    -- | A front-end to the setIndexVar' function
    setIndexVar :: A.Variable -> Int -> [FlattenedExp] -> [FlattenedExp]
    setIndexVar tv ti flats = [setIndexVar' tv ti flat | flat <- flats]

    -- | Sets the sub-index of the specified variable throughout the expression
    setIndexVar' :: A.Variable -> Int -> FlattenedExp -> FlattenedExp
    setIndexVar' tv ti flat = case flat of
        Scale n (v, _)
          | EQ == compare (A.ExprVariable emptyMeta tv) v -> Scale n (v, ti)
          | otherwise -> flat
        Modulo n top bottom -> Modulo n (reindex top) (reindex bottom)
        Divide n top bottom -> Divide n (reindex top) (reindex bottom)
        _ -> flat
      where
        reindex = Set.map (setIndexVar' tv ti)

    -- | Given a list of replicators (marked enabled\/disabled by a flag), the
    -- writes and reads, turns them into a single list of accesses with all the
    -- relevant information.  The writes and reads can be grouped together
    -- because they are differentiated by the ArrayAccessType in the result
    mkEq :: [((A.Name, A.Replicator), Bool)] ->
            (BK, [A.Expression], [A.Expression]) ->
            StateT [(CoeffIndex, CoeffIndex)]
              BKM
                [((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
    mkEq reps (bk, ws, rs) = do
        repVarEqs <- mapM (liftF makeRepVarEq) reps
        concatMapM (mkEq' repVarEqs) (writes ++ reads')
      where
        writes = [(AAWrite, w) | w <- ws]
        reads' = [(AARead, r) | r <- rs]

        -- The replicator's variable together with the equations for its first
        -- value (from) and its last value (from + for - 1).
        makeRepVarEq :: ((A.Name, A.Replicator), Bool) -> BKM (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)
        makeRepVarEq ((varName, A.For m from for _), _) = do
          from' <- makeSingleEq id from "replication start"
          upper <- makeSingleEq id (subExprsInt (addExprsInt for from) (makeConstant m 1)) "replication count"
          return (A.Variable m varName, from', upper)

        -- Flattens one access, primes the variables of the flagged
        -- replicators, and turns the result into its list of equations.
        mkEq' :: [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] ->
                 (ArrayAccessType, A.Expression) ->
                 StateT [(CoeffIndex,CoeffIndex)]
                   BKM
                     [((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
        mkEq' _ (aat, e) = do
          flat <- lift . lift $ flatten e
          mirrorFunc <- liftM foldFuncs $ mapM mirrorFlaggedVar reps
          access <- lift $ makeEquation e (bk, mirrorFunc) aat (mirrorFunc flat)
          case access of
            Group items -> return items
            _ -> throwError "Replicated group found unexpectedly"

        -- | Turns all instances of the variable from the given replicator into
        -- their primed version in the given expression.  The indexes of the
        -- plain and primed variable are recorded in the state.
        mirrorFlaggedVar :: ((A.Name, A.Replicator),Bool) -> StateT [(CoeffIndex,CoeffIndex)] BKM ([FlattenedExp] -> [FlattenedExp])
        mirrorFlaggedVar (_, False) = return id
        mirrorFlaggedVar ((varName, A.For m from for _), True) = do
            varIndexes <- lift $ seqPair (varIndex (subIndexed 0), varIndex (subIndexed 1))
            modify (varIndexes :)
            return $ setIndexVar var 1
          where
            var = A.Variable m varName
            subIndexed sub = Scale 1 (A.ExprVariable emptyMeta var, sub)
|
|
|
|
-- | In this pure checking monad a fatal report is raised as an ordinary
-- error carrying only the message; the first component of the report is
-- discarded.
instance Die (ReaderT CompState (Either String)) where
  dieReport report = case report of
                       (_, message) -> throwError message
|
|
|
|
-- Note that in all these functions, the divisor should always be positive!
|
|
|
|
-- | Rewrites an expression into a canonical form.  A call to the default
-- "+" or "*" operator has its nested operands (through any depth of calls
-- to the same function name) gathered, canonicalised, sorted, and rebuilt
-- as a left-nested chain of two-argument calls.  Any other function call
-- just has its arguments canonicalised; every other expression is returned
-- unchanged.
canonicalise :: forall m. (CSMR m, Die m) => A.Expression -> m A.Expression
canonicalise e@(A.FunctionCall m n es)
  = do mOp <- functionOperator n
       ts <- mapM astTypeOf es
       case mOp of
         Just op | A.nameName n == occamDefaultOperator op ts && op `elem` ["+", "*"]
           -> do terms <- gatherTerms n e
                 return $ foldl1 rebuild (sort terms)
         _ -> liftM (A.FunctionCall m n) (mapM canonicalise es)
  where
    -- Re-nests two operands under the original operator.
    rebuild :: A.Expression -> A.Expression -> A.Expression
    rebuild a b = A.FunctionCall m n [a, b]

    -- Collects the operands of nested calls to the same function, each
    -- operand being canonicalised itself.
    gatherTerms :: A.Name -> A.Expression -> m [A.Expression]
    gatherTerms name (A.FunctionCall _ name' args) | name == name'
      = concatMapM (gatherTerms name) args
    gatherTerms _ other = liftM singleton (canonicalise other)
canonicalise e = return e
|
|
|
|
-- | Flattens an expression into a list of terms that are implicitly summed.
--
-- * An integer literal becomes a single 'Const'.
-- * A two-argument call to the built-in "+", "-" or "*" is expanded term by
--   term.
-- * The built-in remainder operator becomes one 'Modulo' term, and "/" with
--   a constant-only divisor becomes one 'Divide' term.
-- * Everything else (including "/" with a non-constant divisor) is
--   canonicalised and kept as a single opaque @Scale 1 (e, 0)@ term.
--
-- A literal that cannot be parsed as an integer is reported through the
-- error monad rather than left as a crashing thunk inside 'Const'.
flatten :: A.Expression -> ReaderT CompState (Either String) [FlattenedExp]
flatten (A.Literal _ _ (A.IntLiteral _ n))
  -- Accept exactly the strings that 'read' would accept (one unambiguous
  -- parse, with only white space left over).
  = case [v | (v, rest) <- reads n, ("", "") <- lex rest] of
      [v] -> return [Const v]
      _ -> throwError $ "Could not parse integer literal: " ++ n
flatten e@(A.FunctionCall _ fn [lhs, rhs])
  = do mOp <- builtInOperator fn
       case mOp of
         Just "+" -> combine' (flatten lhs) (flatten rhs)
         Just "-" -> combine' (flatten lhs) (mapM (scale (-1)) =<< flatten rhs)
         Just "*" -> multiplyOut' (flatten lhs) (flatten rhs)
         Just "\\" -> liftM2L (Modulo 1) (flatten lhs) (flatten rhs)
         Just "/" -> do rhs' <- flatten rhs
                        case onlyConst rhs' of
                          Just _ -> liftM2L (Divide 1) (flatten lhs) (return rhs')
                          -- Can't deal with variable divisors, leave expression as-is:
                          Nothing -> opaque
         _ -> opaque
  where
    -- The whole call treated as one indivisible term.
    opaque :: ReaderT CompState (Either String) [FlattenedExp]
    opaque = do e' <- canonicalise e
                return [Scale 1 (e',0)]

    -- Turns both operand lists into sets and combines them into a single
    -- result wrapped in a list.
    liftM2L :: MonadError String m => (Set.Set FlattenedExp -> Set.Set FlattenedExp -> c)
               -> m [FlattenedExp] -> m [FlattenedExp] -> m [c]
    liftM2L f x y = liftM singleton $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet)

    -- An easy way of applying multiplyOut to two monadic returns.
    multiplyOut' :: (Die m, CSMR m, MonadError String m ) => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp]
    multiplyOut' x y = join $ liftM2 multiplyOut x y

    -- Multiplies two sums by multiplying every pairing of their terms.
    multiplyOut :: forall m. (Die m, CSMR m, MonadError String m) => [FlattenedExp] -> [FlattenedExp] -> m [FlattenedExp]
    multiplyOut lhs rhs = mapM (uncurry mult) pairs
      where
        pairs = product2 (lhs,rhs)

        -- A constant factor just scales the other term; any other product
        -- is rebuilt as an expression and kept as one opaque term.
        mult :: FlattenedExp -> FlattenedExp -> m FlattenedExp
        mult (Const x) e = scale x e
        mult e (Const x) = scale x e
        mult lhs rhs
          = do lhs' <- backToEq lhs
               rhs' <- backToEq rhs
               e <- mulExprs lhs' rhs' >>= canonicalise
               return $ (Scale 1 (e, 0))

        -- Multiplies an expression by an integer factor (a factor of one
        -- leaves the expression untouched).
        backScale :: Integer -> A.Expression -> m A.Expression
        backScale 1 e = return e
        backScale n e = do t <- astTypeOf e
                           mulExprs (makeConstant' emptyMeta t n) e >>= canonicalise

        -- Converts a flattened term back into an AST expression.
        backToEq :: FlattenedExp -> m A.Expression
        backToEq (Const c) = return $ makeConstant emptyMeta (fromInteger c)
        backToEq (Scale n (e,0)) = backScale n e
        -- Previously an unmatched pattern; report it instead of crashing.
        backToEq (Scale _ _)
          = throwError "Cannot convert a term with a non-zero variable index back to an expression"
        backToEq (Modulo n t b)
          | Set.null t || Set.null b = throwError "Modulo had empty top or bottom"
          | otherwise = do t' <- mapM backToEq $ Set.toList t
                           b' <- mapM backToEq $ Set.toList b
                           t'' <- foldM1 addExprs t'
                           b'' <- foldM1 addExprs b'
                           remExprs t'' b'' >>= backScale n
        backToEq (Divide n t b)
          | Set.null t || Set.null b = throwError "Divide had empty top or bottom"
          | otherwise = do t' <- mapM backToEq $ Set.toList t
                           b' <- mapM backToEq $ Set.toList b
                           t'' <- foldM1 addExprs t'
                           b'' <- foldM1 addExprs b'
                           divExprs t'' b'' >>= backScale n

    -- | Scales a flattened expression by the given integer scaling.
    scale :: Monad m => Integer -> FlattenedExp -> m FlattenedExp
    scale sc (Const n) = return $ Const (n * sc)
    scale sc (Scale n v) = return $ Scale (n * sc) v
    scale sc (Modulo n t b) = return $ Modulo (n * sc) t b
    scale sc (Divide n t b) = return $ Divide (n * sc) t b

    -- | An easy way of applying combine to two monadic returns
    combine' :: Monad m => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp]
    combine' = liftM2 combine

    -- | Combines (adds) two flattened expressions.
    combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp]
    combine = (++)
flatten e = do e' <- canonicalise e
               return [Scale 1 (e',0)]
|
|
|
|
-- | Pairs up array accesses and "squares" the resulting problems.
--
-- "Pair" refers to pairing each of a list of array accesses (e.g.
-- [0, 5, i + 2]) into all possible pairings ([0 == 5, 0 == i + 2,
-- 5 == i + 2]); "square" refers to padding every equation of a problem to
-- the length of the longest one.
--
-- The accesses (fifth argument) are not a plain list but a list of groups.
-- For things like modulo there are several possible accesses, for example
-- [0,x,-x], which must be paired against other accesses (e.g. y + 6) but
-- not against each other.  So [[0,x,-x],[y + 6]] pairs group against group,
-- never a group against itself.  Each item carries its own additional
-- equalities and inequalities because the modulo aspect adds constraints.
--
-- The third argument lists (plain, prime) coefficient indexes that label
-- the indexes belonging to replicated variables.  It is used twice:
--
--  1. A pairing is discarded if, for some listed pair, the prime index is
--     present in its equations but the plain index is not.  Given
--     [[i],[i'],[3]] the pairings would be [i == i', i == 3, i' == 3]; the
--     last two are in effect identical, so the one with i' but no i goes.
--
--  2. Every retained problem receives, for each listed pair, the inequality
--     prime - plain - 1 >= 0 (that is "i <= i' - 1").  Because all
--     combinations of accesses are examined, [i, i + 1, i', i' + 1] yields
--     both "i = i' + 1" and "i + 1 = i'", so the inequality itself never
--     needs to vary.
--
-- The first argument looks up background knowledge for a label.  Every
-- combination of the two labels' alternatives produces its own problem, or
-- a single problem without background knowledge when there are no
-- combinations.  A failed lookup fails the whole result.  The second
-- argument strips the labels for the output, and the fourth argument is
-- copied unchanged into every result.
squareAndPair ::
  (label -> Either String [(EqualityProblem, InequalityProblem)]) ->
  (label -> labelStripped) ->
  [(CoeffIndex, CoeffIndex)] ->
  VarMap ->
  [ArrayAccess label] ->
  (EqualityConstraintEquation, EqualityConstraintEquation) ->
  Either String [((labelStripped, labelStripped), VarMap, (EqualityProblem, InequalityProblem))]
squareAndPair lookupBK strip repVars s v lh
  = concatMapM problemsFor retained
  where
    -- Pairings that survive the "prime implies plain" filter.
    retained = [ pairing
               | pairing@(_, eq, ineq) <- pairEqsAndBounds v lh
               , all (primeImpliesPlain (eq, ineq)) repVars
               ]

    -- One problem per combination of background-knowledge alternatives.
    problemsFor (labels, eq, ineq) = liftM (map build) background
      where
        background = case liftM2 (curry product2) (lookupBK (fst labels)) (lookupBK (snd labels)) of
                       Right [] -> Right [(([],[]),([],[]))] -- No BK
                       other -> other

        replicatorIneqs = concat (applyAll (eq, ineq) (map addExtra repVars))

        build ((bkEqA, bkIneqA), (bkEqB, bkIneqB))
          = ( transformPair strip strip labels
            , s
            , squareEquations ( nub (bkEqA ++ bkEqB) ++ eq
                              , nub (bkIneqA ++ bkIneqB) ++ ineq ++ replicatorIneqs )
            )

    -- True if any of the equations has a non-zero coefficient at the index.
    itemPresent :: CoeffIndex -> [Array CoeffIndex Integer] -> Bool
    itemPresent x = any (\a -> arrayLookupWithDefault 0 a x /= 0)

    -- A problem mentioning the prime index must also mention the plain one.
    primeImpliesPlain :: (EqualityProblem,InequalityProblem) -> (CoeffIndex,CoeffIndex) -> Bool
    primeImpliesPlain (eq, ineq) (plain, prime)
      = not (itemPresent prime everything) || itemPresent plain everything
      where
        everything = eq ++ ineq

    -- prime >= plain + 1 (prime - plain - 1 >= 0); the problem itself is
    -- not inspected.
    addExtra :: (CoeffIndex, CoeffIndex) -> (EqualityProblem,InequalityProblem) -> InequalityProblem
    addExtra (plain, prime) (_, _)
      = [mapToArray $ Map.fromList [(prime,1), (plain,-1), (0, -1)]]
|
|
|
|
-- | Extracts the access equation from a group holding exactly one access.
-- Any other shape (a group of a different size, or a replicated access)
-- raises the supplied error message.
getSingleAccessItem :: MonadError String m => String -> ArrayAccess label -> m EqualityConstraintEquation
getSingleAccessItem err access
  = case access of
      Group [(_, _, (acc, _, _))] -> return acc
      _ -> throwError err
|
|
|
|
-- | Asserts that the list holds exactly one triple and returns the first
-- component of that triple; for any other list length the supplied error
-- message is thrown in the surrounding error monad.
getSingleItem :: MonadError String m => String -> [(a,b,c)] -> m a
getSingleItem err items
  = case items of
      [(item, _, _)] -> return item
      _ -> throwError err
|
|
|
|
-- | Finds the index associated with a particular variable; either by finding
-- an existing index or allocating a new one (one more than the largest index
-- currently in the map, so the first allocation is 1).
--
-- A 'Scale' term is looked up with its scaling normalised to one, so
-- differently scaled uses of the same variable share an index.  'Modulo' and
-- 'Divide' terms are looked up exactly as given, scaling included.
--
-- A 'Const' term has no variable; asking for its index is reported as an
-- error instead of failing with an incomplete pattern match.
varIndex :: FlattenedExp -> BKM Int
varIndex item
  = case item of
      Scale _ v -> lookupOrAllocate (Scale 1 v)
      Modulo _ _ _ -> lookupOrAllocate item
      Divide _ _ _ -> lookupOrAllocate item
      Const _ -> throwError "Internal error: varIndex called on a constant term"
  where
    -- Shared by all three variable-like constructors.
    lookupOrAllocate :: FlattenedExp -> BKM Int
    lookupOrAllocate key
      = do st <- get
           case Map.lookup key st of
             Just ind -> return ind
             Nothing -> do let newId = 1 + maximum (0 : Map.elems st)
                           put (Map.insert key newId st)
                           return newId
|
|
|
|
-- | Pairs all possible combinations of the list of equations.  Distinct
-- items are paired with each other, and each replicated item is in addition
-- paired within itself.  A pairing of two reads is dropped.  Every pairing
-- yields the equality "first access = second access" plus both accesses'
-- own constraints and the bounds inequalities for both accesses.
pairEqsAndBounds :: [ArrayAccess label] -> (EqualityConstraintEquation, EqualityConstraintEquation) -> [((label,label),EqualityProblem, InequalityProblem)]
pairEqsAndBounds items bounds = betweenItems ++ withinReplicated
  where
    betweenItems = concatMap (uncurry pairEqs) (allPairs items)
    withinReplicated = concatMap pairRep items

    pairEqs :: ArrayAccess label
      -> ArrayAccess label
      -> [((label,label),EqualityProblem, InequalityProblem)]
    pairEqs (Group accs) (Group accs')
      = [ problem
        | (a, b) <- product2 (accs, accs')
        , Just problem <- [pairEqs'' a b]
        ]
    -- A replicated item is paired through its first list only; note that
    -- the other access becomes the first argument of the recursive call in
    -- both equations.
    pairEqs (Replicated rA _) lacc = concatMap (pairEqs lacc) rA
    pairEqs lacc (Replicated rA _) = concatMap (pairEqs lacc) rA

    -- Used to pair the items of a single instance of PAR replication with each other
    pairRep :: ArrayAccess label -> [((label,label),EqualityProblem, InequalityProblem)]
    pairRep (Replicated rA rB)
      = concatMap (uncurry pairEqs) (product2 (rA, rB) ++ allPairs rA)
    pairRep _ = []

    -- Attaches the two labels to a successful pairing.
    pairEqs'' :: (label, ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem))
      -> (label, ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem))
      -> Maybe ((label,label), EqualityProblem, InequalityProblem)
    pairEqs'' (lx, x, x') (ly, y, y')
      = fmap (\(eq, ineq) -> ((lx, ly), eq, ineq)) (pairEqs' (x, x') (y, y'))

    -- Two reads never conflict; anything else must be checked.
    pairEqs' :: (ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem))
      -> (ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem))
      -> Maybe (EqualityProblem, InequalityProblem)
    pairEqs' (AARead, _) (AARead, _) = Nothing
    pairEqs' (_, (ex, eqX, ineqX)) (_, (ey, eqY, ineqY))
      = Just ( arrayZipWith' 0 (-) ex ey : (eqX ++ eqY)
             , ineqX ++ ineqY ++ getIneqs bounds [ex, ey] )
|
|
|
|
-- | Adds two equations coefficient by coefficient, using zero for any
-- coefficient that one of the equations lacks.
addEq :: EqualityConstraintEquation -> EqualityConstraintEquation -> EqualityConstraintEquation
addEq lhs rhs = arrayZipWith' 0 (+) lhs rhs
|
|
|
|
-- | Given a (low,high) bound (typically: array dimensions), and a list of
-- equations, forms two inequalities per equation ex, in this order:
--
--   * ex >= low   (expressed as ex - low >= 0)
--
--   * ex <= high  (expressed as high - ex >= 0)
getIneqs :: (EqualityConstraintEquation, EqualityConstraintEquation) -> [EqualityConstraintEquation] -> [InequalityConstraintEquation]
getIneqs (low, high) eqs = concatMap boundsFor eqs
  where
    negatedLow = amap negate low

    boundsFor :: EqualityConstraintEquation -> [InequalityConstraintEquation]
    boundsFor eq = [ eq `addEq` negatedLow
                   , high `addEq` amap negate eq
                   ]
|
|
|
|
-- | Runs the given action so that a failure is returned as a value instead
-- of aborting the surrounding computation.  On success the action's final
-- state is adopted; on failure the state from before the action is kept.
justState :: Error e => StateT s (ReaderT r (Either e)) a -> StateT s (ReaderT r (Either e)) (Either e a)
justState action
  = do before <- get
       env <- ask
       -- Kept as a lazy binding so the inner action is only run on demand.
       let outcome = runReaderT (runStateT action before) env
           (result, after) = either (\err -> (Left err, before))
                                    (\(x, st') -> (Right x, st'))
                                    outcome
       put after
       return result
|
|
|
|
-- | Given an expression, forms equations (and accompanying additional equation-sets) and returns it
--
-- The flattened terms in summedItems are folded, one at a time, into a list
-- of alternatives.  Each alternative is a quadruple of: the modulo cases
-- chosen so far, the access equation as a map from coefficient index to
-- coefficient, a list of extra equalities, and a list of extra inequalities.
-- Key 0 of such a map holds the constant term; other keys come from varIndex.
--
-- The background knowledge is passed through transformBKList with bkF, each
-- entry wrapped in justState so a failure is stored as a value.  The maps are
-- then converted to arrays and the alternatives are returned as one Group,
-- each labelled with (l, its modulo cases, the transformed background
-- knowledge) and the access type t.
|
|
makeEquation :: label -> (BK, [FlattenedExp] -> [FlattenedExp]) -> ArrayAccessType -> [FlattenedExp]
|
|
-> BKM (ArrayAccess (label,[ModuloCase], BK'))
|
|
makeEquation l (bk, bkF) t summedItems
|
|
= do eqs <- process summedItems
|
|
bk' <- mapM (mapMapM (justState . transformBKList bkF)) bk
|
|
let eqs' = map (transformQuad id mapToArray (map mapToArray) (map mapToArray)) eqs :: [([ModuloCase], EqualityConstraintEquation, EqualityProblem, InequalityProblem)]
|
|
return $ Group [((l,c,bk'),t,(e0,e1,e2)) | (c,e0,e1,e2) <- eqs']
|
|
where
|
|
-- Folds a list of summed terms into the alternatives, starting from the
-- single blank alternative in empty.
process :: [FlattenedExp] -> BKM [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
|
|
process = foldM makeEquation' empty
|
|
|
|
-- Adds one term to every alternative built so far.  Constants and scaled
-- variables keep the number of alternatives; a Modulo or Divide term
-- replaces each alternative with one alternative per sign case.
makeEquation' :: [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] ->
|
|
FlattenedExp ->
|
|
BKM
|
|
[([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
|
|
-- A constant is accumulated at key 0.
makeEquation' m (Const n) = return $ add (0,n) m
|
|
-- A scaled variable is accumulated at that variable's index.
makeEquation' m sc@(Scale n v) = varIndex sc >>* (\ind -> add (ind, n) m)
|
|
makeEquation' m mod@(Modulo n top bottom)
|
|
-- The numerator must flatten to exactly one alternative (top''); its
-- modulo-case component is discarded.
= do top' <- process (Set.toList top) >>* map (\(_,a,b,c) -> (a,b,c))
|
|
top'' <- getSingleItem "Modulo or divide not allowed in the numerator of Modulo" top'
|
|
bottom' <- process (Set.toList bottom) >>* map (\(_,a,b,c) -> (a,b,c))
|
|
-- modIndex is the index of the extra variable written as m (constant
-- divisor) or a (variable divisor) in the comments below.
modIndex <- varIndex mod
|
|
case onlyConst (Set.toList bottom) of
|
|
-- Constant divisor: three alternatives per existing alternative.
Just bottomConst ->
|
|
-- Adds x + m|y| to the given map (used to build the inequalities).
let add_x_plus_my = zipMap plus top'' . zipMap plus (Map.fromList [(modIndex, abs bottomConst)]) in
|
|
-- Adds n*(x + my)
|
|
let add_n_x_plus_my = zipMap plus (Map.map (*n) top'') . zipMap plus (Map.fromList [(modIndex, n * abs bottomConst)]) in
|
|
return $
|
|
-- The zero option (x = 0, x REM y = 0):
|
|
( map (transformQuad (++ [XZero]) id (++ [top'']) id) m)
|
|
++
|
|
-- The top-is-positive option:
|
|
( map (transformQuad (++ [XPos]) add_n_x_plus_my id (++
|
|
-- x >= 1
|
|
[zipMap plus (Map.fromList [(0,-1)]) top''
|
|
-- m <= 0
|
|
,Map.fromList [(modIndex,-1)]
|
|
-- x + my + 1 - |y| <= 0
|
|
,Map.map negate $ add_x_plus_my $ Map.fromList [(0,1 - abs bottomConst)]
|
|
-- x + my >= 0
|
|
,add_x_plus_my $ Map.empty])
|
|
) m) ++
|
|
-- The top-is-negative option:
|
|
( map (transformQuad (++ [XNeg]) add_n_x_plus_my id (++
|
|
-- x <= -1
|
|
[add' (0,-1) $ Map.map negate top''
|
|
-- m >= 0
|
|
,Map.fromList [(modIndex,1)]
|
|
-- x + my - 1 + |y| >= 0
|
|
,add_x_plus_my $ Map.fromList [(0,abs bottomConst - 1)]
|
|
-- x + my <= 0
|
|
,Map.map negate $ add_x_plus_my Map.empty])
|
|
) m)
|
|
-- Variable divisor: the zero option plus two alternatives for each of the
-- four sign combinations of x and y.  The divisor must flatten to exactly
-- one alternative.
_ ->
|
|
do bottom'' <- getSingleItem "Modulo or divide not allowed in the divisor of Modulo" bottom'
|
|
return $
|
|
-- The zero option (x = 0, x REM y = 0):
|
|
(map (transformQuad (++ [XZero]) id (++ [top'']) id) m)
|
|
-- The rest:
|
|
++ twinItems True True n (top'', modIndex) bottom''
|
|
++ twinItems True False n (top'', modIndex) bottom''
|
|
++ twinItems False True n (top'', modIndex) bottom''
|
|
++ twinItems False False n (top'', modIndex) bottom''
|
|
where
|
|
-- Each pair for modulo (variable divisor) depending on signs of x and y (in x REM y):
-- The first of the pair leaves a out of the equation (the "AZero" case), the
-- second adds n*a to it (the "ANonZero" case).
|
|
twinItems :: Bool -> Bool -> Integer -> (Map.Map Int Integer,Int) -> Map.Map Int Integer ->
|
|
[([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
|
|
twinItems xPos yPos n (top,modIndex) bottom
|
|
= (map (transformQuad (++ [findCase xPos yPos False]) (zipMap plus $ Map.map (*n) top) id
|
|
(++ [xEquation]
|
|
++ [xLowerBound False]
|
|
++ [xUpperBound False])) m)
|
|
++ (map (transformQuad (++ [findCase xPos yPos True]) (zipMap plus (Map.map (*n) top) . add' (modIndex, n)) id
|
|
(++ [xEquation]
|
|
++ [xLowerBound True]
|
|
++ [xUpperBound True]
|
|
-- We want to add the bounds for a and y as follows:
|
|
-- xPos yPos | Equation
|
|
-- T T | -y - a >= 0
|
|
-- T F | y - a >= 0
|
|
-- F T | a - y >= 0
|
|
-- F F | a + y >= 0
|
|
-- Therefore the sign of a is (not xPos), the sign of y is (not yPos)
|
|
++ [add' (modIndex,if xPos then -1 else 1) (signEq (not yPos) bottom)])) m)
|
|
where
|
|
-- x >= 1 or x <= -1 (rearranged: -1 + x >= 0 or -1 - x >= 0)
|
|
xEquation = add' (0,-1) (signEq xPos top)
|
|
-- We include (x [+ a] >= 0 or x [+ a] <= 0) even though they are redundant in some cases (addA = False):
|
|
xLowerBound addA = signEq xPos $ (if addA then add' (modIndex,1) else id) top
|
|
-- We want to add the bounds as follows:
|
|
-- xPos yPos | Equation
|
|
-- T T | y - 1 - x - a >= 0
|
|
-- T F | -y - 1 - x - a >= 0
|
|
-- F T | x + a - 1 + y >= 0
|
|
-- F F | x + a - y - 1 >= 0
|
|
-- Therefore the sign of y in the equation is yPos, the sign of x and a is (not xPos)
|
|
xUpperBound addA = add' (0,-1) $ zipMap plus (signEq (not xPos) ((if addA then add' (modIndex,1) else id) top)) (signEq yPos bottom)
|
|
|
|
-- Keeps the equation when the flag is True, negates every coefficient otherwise.
signEq sign eq = if sign then eq else Map.map negate eq
|
|
|
|
-- Maps the three flags onto the matching ModuloCase constructor.
findCase xPos yPos aNonZero = case (xPos, yPos, aNonZero) of
|
|
(True , True , True ) -> XPosYPosANonZero
|
|
(True , True , False) -> XPosYPosAZero
|
|
(True , False, True ) -> XPosYNegANonZero
|
|
(True , False, False) -> XPosYNegAZero
|
|
(False, True , True ) -> XNegYPosANonZero
|
|
(False, True , False) -> XNegYPosAZero
|
|
(False, False, True ) -> XNegYNegANonZero
|
|
(False, False, False) -> XNegYNegAZero
|
|
|
|
makeEquation' m div@(Divide n top bottom)
|
|
-- As for Modulo, the numerator must flatten to exactly one alternative.
= do top' <- process (Set.toList top) >>* map (\(_,a,b,c) -> (a,b,c))
|
|
top'' <- getSingleItem "Modulo or Divide not allowed in the numerator of Divide" top'
|
|
-- NOTE(review): bottom' is computed but not used below; only the
-- constant-divisor case is supported.
bottom' <- process (Set.toList bottom) >>* map (\(_,a,b,c) -> (a,b,c))
|
|
-- divIndex is the index of the extra variable written as m below.
divIndex <- varIndex div
|
|
case onlyConst (Set.toList bottom) of
|
|
Just bottomConst ->
|
|
-- Adds n*m to the access equation.
let add_m :: Map.Map Int Integer -> Map.Map Int Integer
|
|
add_m = zipMap plus (Map.fromList [(divIndex,n)])
|
|
-- Adds x - my to the given map (used to build the inequalities).
add_x_minus_my = zipMap plus top'' . zipMap plus (Map.fromList [(divIndex,-bottomConst)]) in
|
|
return $
|
|
-- The zero option (x = 0, x / y = 0):
|
|
( map (transformQuad (++ [XZero]) id (++ [top'']) id) m)
|
|
++
|
|
-- The top-is-positive option:
|
|
( map (transformQuad (++ [XPos]) add_m id (++
|
|
-- x >= 1
|
|
[zipMap plus (Map.fromList [(0,-1)]) top''
|
|
-- m >= 0 if y positive
|
|
-- m <= 0 (i.e. -m >= 0) if y negative
|
|
,Map.fromList [(divIndex, signum bottomConst)]
|
|
-- x - my + 1 - y <= 0 if y positive
|
|
-- x - my - 1 - y >= 0 if y negative
-- NOTE(review): for negative y this bound and the next one confine x - my
-- to [y + 1, 0]; with a division that truncates towards zero a positive x
-- would presumably leave x - my in [0, -y - 1] instead.  The top-is-negative
-- option below mirrors this.  Confirm whether negative constant divisors are
-- ever expected here.
|
|
,(if (bottomConst > 0) then Map.map negate else id) $ add_x_minus_my $ Map.fromList [(0,signum bottomConst - bottomConst)]
|
|
-- x - my >= 0 if y positive
|
|
-- x - my <= 0 if negative
|
|
,(if (bottomConst > 0) then id else Map.map negate) $ add_x_minus_my $ Map.empty])
|
|
) m) ++
|
|
-- The top-is-negative option:
|
|
( map (transformQuad (++ [XNeg]) add_m id (++
|
|
-- x <= -1
|
|
[add' (0,-1) $ Map.map negate top''
|
|
-- m <= 0 if y positive
|
|
-- m >= 0 if y negative
|
|
,Map.fromList [(divIndex, - signum bottomConst)]
|
|
-- x - my - 1 + y >= 0 if y positive
|
|
-- x - my + 1 + y <= 0 if y negative
|
|
,(if (bottomConst > 0) then id else Map.map negate) $ add_x_minus_my $ Map.fromList [(0,bottomConst - signum bottomConst)]
|
|
-- x - my <= 0 if y positive
|
|
-- x - my >= 0 if y negative
|
|
,(if (bottomConst > 0) then Map.map negate else id) $ add_x_minus_my Map.empty])
|
|
) m)
|
|
_ -> throwError "Variables in divisor not supported by usage checker"
|
|
|
|
-- The starting point of the fold: one alternative with no modulo cases, an
-- all-zero equation and no extra constraints.
empty :: [([ModuloCase],Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
|
|
empty = [([],Map.empty,[],[])]
|
|
|
|
-- Combining function for zipMap: a missing value counts as zero and the
-- result is always present.
plus :: Num n => Maybe n -> Maybe n -> Maybe n
|
|
plus x y = Just $ (fromMaybe 0 x) + (fromMaybe 0 y)
|
|
|
|
-- Adds n to the coefficient stored at key m (inserting it if absent).
add' :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer
|
|
add' (m,n) = Map.insertWith (+) m n
|
|
|
|
-- Does the same as add' to the equation (second component) of every alternative.
add :: (Int,Integer) -> [(z,Map.Map Int Integer,a,b)] -> [(z,Map.Map Int Integer,a,b)]
|
|
add (m,n) = map $ (\(a,b,c,d) -> (a,(Map.insertWith (+) m n) b,c,d))
|
|
|
|
-- | Converts a map to an array whose bounds run from zero to the largest key
-- in the map (or to zero for a map with no positive key).  Indexes that have
-- no entry in the map are given the value zero.
-- Could probably be moved to Utils
mapToArray :: (IArray a v, Num v, Num k, Ord k, Ix k) => Map.Map k v -> a k v
mapToArray m = accumArray (+) 0 (0, upper) (Map.toList m)
  where
    upper = foldr max 0 (Map.keys m)
|
|
|
|
-- | Given a pair of equation sets, resizes every equation in both lists to
-- span indexes zero up to the largest index found in any of them, so they
-- all have the length of the longest equation.  All missing elements are
-- given the value zero.
squareEquations :: ([Array CoeffIndex Integer],[Array CoeffIndex Integer]) -> ([Array CoeffIndex Integer],[Array CoeffIndex Integer])
squareEquations (eqs, ineqs) = transformPair resize resize (eqs, ineqs)
  where
    widest = maximum (0 : concatMap indices (eqs ++ ineqs))
    resize = map (makeArraySize (0, widest) 0)
|
|
|