A big patch to get ArrayUsageCheck compiling again, with the new operators

This commit is contained in:
Neil Brown 2009-04-08 15:44:35 +00:00
parent d2fb80c516
commit 5c353f10b7
2 changed files with 96 additions and 72 deletions

View File

@ -32,6 +32,7 @@ module ArrayUsageCheck (
VarMap) where
import Control.Monad.Error
import Control.Monad.Reader
import Control.Monad.State
import Data.Array.IArray
import qualified Data.Foldable as F
@ -78,7 +79,7 @@ findRepSolutions reps bks
-- which means they can overlap -- but only if there is also a solution to the
-- replicator background knowledge, which is what this function is trying to
-- determine.
= case makeEquations (addReps $ SeqItems [(bk, [makeConstant emptyMeta 0], [])
= getCompState >>= \cs -> case flip runReaderT cs $ makeEquations (addReps $ SeqItems [(bk, [makeConstant emptyMeta 0], [])
| bk <- bks]) maxInt of
Right problems -> do
probs <- formatProblems [(vm, prob) | (_,vm,prob) <- problems]
@ -170,7 +171,8 @@ checkArrayUsage sharedAttr (m,p)
A.Array (A.UnknownDimension:_) _ -> return $ makeConstant m $ fromInteger $ toInteger (maxBound :: Int32)
-- It's not an array:
_ -> dieP m $ "Cannot usage check array \"" ++ userArrName ++ "\"; found to be of type: " ++ show arrType
case makeEquations indexes arrLength of
cs <- getCompState
case runReaderT (makeEquations indexes arrLength) cs of
Left err -> dieP m $ "Could not work with array indexes for array \"" ++ userArrName ++ "\": " ++ err
Right [] -> return () -- No problems to work with
Right problems -> do
@ -407,10 +409,10 @@ parItemToArrayAccessM f (RepParItem rep p)
-- into a set of expressions with at most one constant term, and at most one appearance
-- of a any variable, or distinct modulo\/division of variables.
-- If there is any problem (specifically, nested modulo or divisions) an error will be returned instead
makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp)
makeExpSet :: forall m. MonadError String m => [FlattenedExp] -> m (Set.Set FlattenedExp)
makeExpSet = foldM makeExpSet' Set.empty
where
makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> Either String (Set.Set FlattenedExp)
makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> m (Set.Set FlattenedExp)
makeExpSet' accum (Const n) = return $ insert (addConst n) (Const n) accum
makeExpSet' accum (Scale n v) = return $ insert (addScale n v) (Scale n v) accum
makeExpSet' accum m@(Modulo {}) | Set.member m accum = throwError "Cannot have repeated REM items in an expression"
@ -465,9 +467,12 @@ data ModuloCase =
| XNegYNegAZero | XNegYNegANonZero
deriving (Show, Eq, Ord)
type BKM = StateT VarMap (ReaderT CompState (Either String))
-- | Transforms background knowledge into problems
-- TODO allow modulo in background knowledge
transformBK :: ([FlattenedExp] -> [FlattenedExp]) -> BackgroundKnowledge -> StateT VarMap (Either String) (EqualityProblem,InequalityProblem)
transformBK :: ([FlattenedExp] -> [FlattenedExp]) -> BackgroundKnowledge ->
BKM (EqualityProblem,InequalityProblem)
transformBK f (Equal eL eR) = do eL' <- makeSingleEq f eL "background knowledge"
eR' <- makeSingleEq f eR "background knowledge"
let e = addEq eL' (amap negate eR')
@ -493,12 +498,12 @@ transformBK f (RepBoundsIncl v low high)
, addEq (amap negate eLow) ev'
])
transformBKList :: ([FlattenedExp] -> [FlattenedExp]) -> [BackgroundKnowledge] -> StateT VarMap (Either String) (EqualityProblem,InequalityProblem)
transformBKList :: ([FlattenedExp] -> [FlattenedExp]) -> [BackgroundKnowledge] -> BKM (EqualityProblem,InequalityProblem)
transformBKList f bk = mapM (transformBK f) bk >>* foldl accumProblem ([],[])
-- | Turns a single expression into an equation-item. An error is given if the resulting
-- expression is anything complicated (for example, modulo or divide)
makeSingleEq :: ([FlattenedExp] -> [FlattenedExp]) -> A.Expression -> String -> StateT VarMap (Either String) EqualityConstraintEquation
makeSingleEq :: ([FlattenedExp] -> [FlattenedExp]) -> A.Expression -> String -> BKM EqualityConstraintEquation
makeSingleEq f e desc = (lift (flatten e) >>* f) >>= makeEquation e ([{-TODO-}],
f) (error $ "Type is irrelevant for " ++ desc)
>>= getSingleAccessItem ("Modulo or Divide not allowed in " ++ desc
@ -536,13 +541,13 @@ accumProblem = concatPair
--
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
makeEquations :: ParItems (BK, [A.Expression], [A.Expression]) -> A.Expression ->
Either String [(((A.Expression, [ModuloCase]), (A.Expression, [ModuloCase])), VarMap, (EqualityProblem, InequalityProblem))]
ReaderT CompState (Either String) [(((A.Expression, [ModuloCase]), (A.Expression, [ModuloCase])), VarMap, (EqualityProblem, InequalityProblem))]
makeEquations accesses bound
= do ((v,h,repVarIndexes, allReps),s) <- (flip runStateT) Map.empty $
do ((accesses', allReps),repVars) <- flip runStateT [] $ parItemToArrayAccessM mkEq accesses
high <- makeSingleEq id bound "upper bound"
return (accesses', high, nub repVars, allReps)
squareAndPair (lookupBK allReps) (\(x,y,_) -> (x,y)) repVarIndexes s v (amap (const 0) h, addConstant (-1) h)
lift $ squareAndPair (lookupBK allReps) (\(x,y,_) -> (x,y)) repVarIndexes s v (amap (const 0) h, addConstant (-1) h)
where
lookupBK :: [A.Name] -> (A.Expression, [ModuloCase], BK') -> Either String
@ -582,7 +587,7 @@ makeEquations accesses bound
mkEq :: [((A.Name, A.Replicator), Bool)] ->
(BK, [A.Expression], [A.Expression]) ->
StateT [(CoeffIndex, CoeffIndex)]
(StateT VarMap (Either String))
BKM
[((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq reps (bk, ws, rs) = do repVarEqs <- mapM (liftF makeRepVarEq) reps
concatMapM (mkEq' repVarEqs) (ws' ++ rs')
@ -590,16 +595,16 @@ makeEquations accesses bound
ws' = zip (repeat AAWrite) ws
rs' = zip (repeat AARead) rs
makeRepVarEq :: ((A.Name, A.Replicator), Bool) -> StateT VarMap (Either String) (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)
makeRepVarEq :: ((A.Name, A.Replicator), Bool) -> BKM (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)
makeRepVarEq ((varName, A.For m from for _), _)
= do from' <- makeSingleEq id from "replication start"
upper <- makeSingleEq id (A.Dyadic m A.Subtr (A.Dyadic m A.Add for from) (makeConstant m 1)) "replication count"
upper <- makeSingleEq id (subExprsInt (addExprsInt for from) (makeConstant m 1)) "replication count"
return (A.Variable m varName, from', upper)
mkEq' :: [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] ->
(ArrayAccessType, A.Expression) ->
StateT [(CoeffIndex,CoeffIndex)]
(StateT VarMap (Either String))
BKM
[((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq' repVarEqs (aat, e)
= do f <- lift . lift $ flatten e
@ -610,7 +615,7 @@ makeEquations accesses bound
_ -> throwError "Replicated group found unexpectedly"
-- | Turns all instances of the variable from the given replicator into their primed version in the given expression
mirrorFlaggedVar :: ((A.Name, A.Replicator),Bool) -> StateT [(CoeffIndex,CoeffIndex)] (StateT VarMap (Either String)) ([FlattenedExp] -> [FlattenedExp])
mirrorFlaggedVar :: ((A.Name, A.Replicator),Bool) -> StateT [(CoeffIndex,CoeffIndex)] BKM ([FlattenedExp] -> [FlattenedExp])
mirrorFlaggedVar (_,False) = return id
mirrorFlaggedVar ((varName, A.For m from for _), True)
= do varIndexes <- lift $ seqPair (varIndex (Scale 1 (A.ExprVariable emptyMeta var,0)), varIndex (Scale 1 (A.ExprVariable emptyMeta var,1)))
@ -619,93 +624,104 @@ makeEquations accesses bound
where
var = A.Variable m varName
instance Die (ReaderT CompState (Either String)) where
dieReport (_, s) = throwError s
-- Note that in all these functions, the divisor should always be positive!
canonicalise :: A.Expression -> A.Expression
canonicalise e@(A.Dyadic m op _ _) | op == A.Add || op == A.Mul
= foldl1 (A.Dyadic m op) $ sort $ gatherTerms op e
canonicalise :: forall m. (CSMR m, Die m) => A.Expression -> m A.Expression
canonicalise e@(A.FunctionCall m n es)
= do mOp <- functionOperator n
ts <- mapM astTypeOf es
case (mOp, fmap (\op -> A.nameName n == occamDefaultOperator op ts) mOp) of
(Just op, Just True) | op == "+" || op == "*"
-> liftM (foldl1 (\a b -> A.FunctionCall m n [a, b]) . sort) $ gatherTerms n e
_ -> mapM canonicalise es >>* A.FunctionCall m n
where
gatherTerms :: A.DyadicOp -> A.Expression -> [A.Expression]
gatherTerms op (A.Dyadic _ op' lhs rhs) | op == op'
= gatherTerms op lhs ++ gatherTerms op rhs
gatherTerms _ e = [canonicalise e]
canonicalise (A.Dyadic m op lhs rhs)
= A.Dyadic m op (canonicalise lhs) (canonicalise rhs)
canonicalise (A.Monadic m op rhs)
= A.Monadic m op (canonicalise rhs)
canonicalise e = e
gatherTerms :: A.Name -> A.Expression -> m [A.Expression]
gatherTerms n (A.FunctionCall _ n' es) | n == n'
= concatMapM (gatherTerms n) es
gatherTerms _ e = canonicalise e >>* singleton
canonicalise e = return e
flatten :: A.Expression -> Either String [FlattenedExp]
flatten :: A.Expression -> ReaderT CompState (Either String) [FlattenedExp]
flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)]
flatten e@(A.Dyadic m op lhs rhs)
| op == A.Add = combine' (flatten lhs) (flatten rhs)
| op == A.Subtr = combine' (flatten lhs) (mapM (scale (-1)) =<< flatten rhs)
| op == A.Mul = multiplyOut' (flatten lhs) (flatten rhs)
| op == A.Rem = liftM2L (Modulo 1) (flatten lhs) (flatten rhs)
| op == A.Div = do rhs' <- flatten rhs
flatten e@(A.FunctionCall m fn [lhs, rhs])
= do mOp <- builtInOperator fn
case mOp of
Just "+" -> combine' (flatten lhs) (flatten rhs)
Just "-" -> combine' (flatten lhs) (mapM (scale (-1)) =<< flatten rhs)
Just "*" -> multiplyOut' (flatten lhs) (flatten rhs)
Just "\\" -> liftM2L (Modulo 1) (flatten lhs) (flatten rhs)
Just "/" ->do rhs' <- flatten rhs
case onlyConst rhs' of
Just _ -> liftM2L (Divide 1) (flatten lhs) (return rhs')
-- Can't deal with variable divisors, leave expression as-is:
Nothing -> return [Scale 1 (canonicalise e,0)]
Nothing -> do e' <- canonicalise e
return [Scale 1 (e',0)]
_ -> do e' <- canonicalise e
return [Scale 1 (e',0)]
where
-- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c]
liftM2L :: MonadError String m => (Set.Set FlattenedExp -> Set.Set FlattenedExp -> c)
-> m [FlattenedExp] -> m [FlattenedExp] -> m [c]
liftM2L f x y = liftM singleton $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet)
multiplyOut' :: Either String [FlattenedExp] -> Either String [FlattenedExp] -> Either String [FlattenedExp]
multiplyOut' x y = do {x' <- x; y' <- y; multiplyOut x' y'}
multiplyOut' :: (Die m, CSMR m, MonadError String m ) => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp]
multiplyOut' x y = join $ liftM2 multiplyOut x y
multiplyOut :: [FlattenedExp] -> [FlattenedExp] -> Either String [FlattenedExp]
multiplyOut :: forall m. (Die m, CSMR m, MonadError String m) => [FlattenedExp] -> [FlattenedExp] -> m [FlattenedExp]
multiplyOut lhs rhs = mapM (uncurry mult) pairs
where
pairs = product2 (lhs,rhs)
mult :: FlattenedExp -> FlattenedExp -> Either String FlattenedExp
mult :: FlattenedExp -> FlattenedExp -> m FlattenedExp
mult (Const x) e = scale x e
mult e (Const x) = scale x e
mult lhs rhs
= do lhs' <- backToEq lhs
rhs' <- backToEq rhs
return $ (Scale 1 (canonicalise $ A.Dyadic emptyMeta A.Mul lhs' rhs', 0))
e <- mulExprs lhs' rhs' >>= canonicalise
return $ (Scale 1 (e, 0))
backScale :: Integer -> A.Expression -> A.Expression
backScale 1 = id
backScale n = canonicalise . A.Dyadic emptyMeta A.Mul (makeConstant emptyMeta (fromInteger n))
backScale :: Integer -> A.Expression -> m A.Expression
backScale 1 e = return e
backScale n e = do t <- astTypeOf e
mulExprs (makeConstant' emptyMeta t n) e >>= canonicalise
backToEq :: FlattenedExp -> Either String A.Expression
backToEq :: FlattenedExp -> m A.Expression
backToEq (Const c) = return $ makeConstant emptyMeta (fromInteger c)
backToEq (Scale n (e,0)) = return $ backScale n e
backToEq (Scale n (e,0)) = backScale n e
backToEq (Modulo n t b)
| Set.null t || Set.null b = throwError "Modulo had empty top or bottom"
| otherwise = do t' <- mapM backToEq $ Set.toList t
b' <- mapM backToEq $ Set.toList b
return $
(backScale n $ A.Dyadic emptyMeta A.Rem
(foldl1 (A.Dyadic emptyMeta A.Add) t')
(foldl1 (A.Dyadic emptyMeta A.Add) b'))
t'' <- foldM1 addExprs t'
b'' <- foldM1 addExprs b'
remExprs t'' b'' >>= backScale n
backToEq (Divide n t b)
| Set.null t || Set.null b = throwError "Divide had empty top or bottom"
| otherwise = do t' <- mapM backToEq $ Set.toList t
b' <- mapM backToEq $ Set.toList b
return $
(backScale n $ A.Dyadic emptyMeta A.Div
(foldl1 (A.Dyadic emptyMeta A.Add) t')
(foldl1 (A.Dyadic emptyMeta A.Add) b'))
t'' <- foldM1 addExprs t'
b'' <- foldM1 addExprs b'
divExprs t'' b'' >>= backScale n
-- | Scales a flattened expression by the given integer scaling.
scale :: Integer -> FlattenedExp -> Either String FlattenedExp
scale :: Monad m => Integer -> FlattenedExp -> m FlattenedExp
scale sc (Const n) = return $ Const (n * sc)
scale sc (Scale n v) = return $ Scale (n * sc) v
scale sc (Modulo n t b) = return $ Modulo (n * sc) t b
scale sc (Divide n t b) = return $ Divide (n * sc) t b
-- | An easy way of applying combine to two monadic returns
combine' :: Either String [FlattenedExp] -> Either String [FlattenedExp] -> Either String [FlattenedExp]
combine' :: Monad m => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp]
combine' = liftM2 combine
-- | Combines (adds) two flattened expressions.
combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp]
combine = (++)
flatten e = return [Scale 1 (canonicalise e,0)]
flatten e = do e' <- canonicalise e
return [Scale 1 (e',0)]
-- | The "square" refers to making all equations the length of the longest
-- one, and the pair refers to pairing each in a list of array accesses (e.g.
@ -779,18 +795,18 @@ squareAndPair lookupBK strip repVars s v lh
-- prime >= plain + 1 (prime - plain - 1 >= 0)
= [mapToArray $ Map.fromList [(prime,1), (plain,-1), (0, -1)]]
getSingleAccessItem :: MonadTrans m => String -> ArrayAccess label -> m (Either String) EqualityConstraintEquation
getSingleAccessItem _ (Group [(_,_,(acc,_,_))]) = lift $ return acc
getSingleAccessItem err _ = lift $ throwError err
getSingleAccessItem :: MonadError String m => String -> ArrayAccess label -> m EqualityConstraintEquation
getSingleAccessItem _ (Group [(_,_,(acc,_,_))]) = return acc
getSingleAccessItem err _ = throwError err
-- | Odd helper function for getting\/asserting the first item of a triple from a singleton list inside a monad transformer (!)
getSingleItem :: MonadTrans m => String -> [(a,b,c)] -> m (Either String) a
getSingleItem _ [(item,_,_)] = lift $ return item
getSingleItem err _ = lift $ throwError err
getSingleItem :: MonadError String m => String -> [(a,b,c)] -> m a
getSingleItem _ [(item,_,_)] = return item
getSingleItem err _ = throwError err
-- | Finds the index associated with a particular variable; either by finding an existing index
-- or allocating a new one.
varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int
varIndex :: FlattenedExp -> BKM Int
varIndex (Scale _ (e,vi))
= do st <- get
let (st',ind) = case Map.lookup (Scale 1 (e,vi)) st of
@ -864,29 +880,30 @@ getIneqs (low, high) = concatMap getLH
getLH :: EqualityConstraintEquation -> [InequalityConstraintEquation]
getLH eq = [eq `addEq` (amap negate low),high `addEq` amap negate eq]
justState :: Error e => StateT s (Either e) a -> StateT s (Either e) (Either e a)
justState :: Error e => StateT s (ReaderT r (Either e)) a -> StateT s (ReaderT r (Either e)) (Either e a)
justState m = do st <- get
let (x, st') = case runStateT m st of
r <- ask
let (x, st') = case runReaderT (runStateT m st) r of
Left err -> (Left err, st)
Right (x, st') -> (Right x, st')
put st'
return x
-- | Given an expression, forms equations (and accompanying additional equation-sets) and returns it
makeEquation :: label -> (BK, [FlattenedExp] -> [FlattenedExp]) -> ArrayAccessType -> [FlattenedExp] -> StateT VarMap (Either String) (ArrayAccess (label,[ModuloCase],
BK'))
makeEquation :: label -> (BK, [FlattenedExp] -> [FlattenedExp]) -> ArrayAccessType -> [FlattenedExp]
-> BKM (ArrayAccess (label,[ModuloCase], BK'))
makeEquation l (bk, bkF) t summedItems
= do eqs <- process summedItems
bk' <- mapM (mapMapM (justState . transformBKList bkF)) bk
let eqs' = map (transformQuad id mapToArray (map mapToArray) (map mapToArray)) eqs :: [([ModuloCase], EqualityConstraintEquation, EqualityProblem, InequalityProblem)]
return $ Group [((l,c,bk'),t,(e0,e1,e2)) | (c,e0,e1,e2) <- eqs']
where
process :: [FlattenedExp] -> StateT VarMap (Either String) [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
process :: [FlattenedExp] -> BKM [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
process = foldM makeEquation' empty
makeEquation' :: [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] ->
FlattenedExp ->
StateT (VarMap) (Either String)
BKM
[([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
makeEquation' m (Const n) = return $ add (0,n) m
makeEquation' m sc@(Scale n v) = varIndex sc >>* (\ind -> add (ind, n) m)

View File

@ -26,8 +26,8 @@ module Types
, returnTypesOfFunction
, BytesInResult(..), bytesInType, countReplicator, countStructured, computeStructured
, makeAbbrevAM, makeConstant, makeDimension, specificDimSize
, addOne, subOne, addExprs, subExprs, mulExprs, divExprs
, makeAbbrevAM, makeConstant, makeConstant', makeDimension, specificDimSize
, addOne, subOne, addExprs, subExprs, mulExprs, divExprs, remExprs
, addOneInt, subOneInt, addExprsInt, subExprsInt, mulExprsInt, divExprsInt
, addDimensions, applyDimension, removeFixedDimensions, trivialSubscriptType, subscriptType, unsubscriptType
, applyDirection
@ -397,7 +397,10 @@ makeAbbrevAM am = am
-- | Generate a constant expression from an integer -- for array sizes and the
-- like.
makeConstant :: Meta -> Int -> A.Expression
makeConstant m n = A.Literal m A.Int $ A.IntLiteral m (show n)
makeConstant m = makeConstant' m A.Int . toInteger
makeConstant' :: Meta -> A.Type -> Integer -> A.Expression
makeConstant' m t n = A.Literal m t $ A.IntLiteral m (show n)
-- | Generate a constant dimension from an integer.
makeDimension :: Meta -> Int -> A.Dimension
@ -858,6 +861,10 @@ mulExprs = dyadicExpr "*"
divExprs :: DyadicExprM
divExprs = dyadicExpr "/"
-- | Divide two expressions.
remExprs :: DyadicExprM
remExprs = dyadicExpr "\\"
-- | Add two expressions.
addExprsInt :: DyadicExpr
addExprsInt = dyadicExpr' (A.Int,A.Int) "+"