A big patch to get ArrayUsageCheck compiling again, with the new operators

This commit is contained in:
Neil Brown 2009-04-08 15:44:35 +00:00
parent d2fb80c516
commit 5c353f10b7
2 changed files with 96 additions and 72 deletions

View File

@ -32,6 +32,7 @@ module ArrayUsageCheck (
VarMap) where VarMap) where
import Control.Monad.Error import Control.Monad.Error
import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Array.IArray import Data.Array.IArray
import qualified Data.Foldable as F import qualified Data.Foldable as F
@ -78,7 +79,7 @@ findRepSolutions reps bks
-- which means they can overlap -- but only if there is also a solution to the -- which means they can overlap -- but only if there is also a solution to the
-- replicator background knowledge, which is what this function is trying to -- replicator background knowledge, which is what this function is trying to
-- determine. -- determine.
= case makeEquations (addReps $ SeqItems [(bk, [makeConstant emptyMeta 0], []) = getCompState >>= \cs -> case flip runReaderT cs $ makeEquations (addReps $ SeqItems [(bk, [makeConstant emptyMeta 0], [])
| bk <- bks]) maxInt of | bk <- bks]) maxInt of
Right problems -> do Right problems -> do
probs <- formatProblems [(vm, prob) | (_,vm,prob) <- problems] probs <- formatProblems [(vm, prob) | (_,vm,prob) <- problems]
@ -170,7 +171,8 @@ checkArrayUsage sharedAttr (m,p)
A.Array (A.UnknownDimension:_) _ -> return $ makeConstant m $ fromInteger $ toInteger (maxBound :: Int32) A.Array (A.UnknownDimension:_) _ -> return $ makeConstant m $ fromInteger $ toInteger (maxBound :: Int32)
-- It's not an array: -- It's not an array:
_ -> dieP m $ "Cannot usage check array \"" ++ userArrName ++ "\"; found to be of type: " ++ show arrType _ -> dieP m $ "Cannot usage check array \"" ++ userArrName ++ "\"; found to be of type: " ++ show arrType
case makeEquations indexes arrLength of cs <- getCompState
case runReaderT (makeEquations indexes arrLength) cs of
Left err -> dieP m $ "Could not work with array indexes for array \"" ++ userArrName ++ "\": " ++ err Left err -> dieP m $ "Could not work with array indexes for array \"" ++ userArrName ++ "\": " ++ err
Right [] -> return () -- No problems to work with Right [] -> return () -- No problems to work with
Right problems -> do Right problems -> do
@ -407,10 +409,10 @@ parItemToArrayAccessM f (RepParItem rep p)
-- into a set of expressions with at most one constant term, and at most one appearance -- into a set of expressions with at most one constant term, and at most one appearance
-- of a any variable, or distinct modulo\/division of variables. -- of a any variable, or distinct modulo\/division of variables.
-- If there is any problem (specifically, nested modulo or divisions) an error will be returned instead -- If there is any problem (specifically, nested modulo or divisions) an error will be returned instead
makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp) makeExpSet :: forall m. MonadError String m => [FlattenedExp] -> m (Set.Set FlattenedExp)
makeExpSet = foldM makeExpSet' Set.empty makeExpSet = foldM makeExpSet' Set.empty
where where
makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> Either String (Set.Set FlattenedExp) makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> m (Set.Set FlattenedExp)
makeExpSet' accum (Const n) = return $ insert (addConst n) (Const n) accum makeExpSet' accum (Const n) = return $ insert (addConst n) (Const n) accum
makeExpSet' accum (Scale n v) = return $ insert (addScale n v) (Scale n v) accum makeExpSet' accum (Scale n v) = return $ insert (addScale n v) (Scale n v) accum
makeExpSet' accum m@(Modulo {}) | Set.member m accum = throwError "Cannot have repeated REM items in an expression" makeExpSet' accum m@(Modulo {}) | Set.member m accum = throwError "Cannot have repeated REM items in an expression"
@ -465,9 +467,12 @@ data ModuloCase =
| XNegYNegAZero | XNegYNegANonZero | XNegYNegAZero | XNegYNegANonZero
deriving (Show, Eq, Ord) deriving (Show, Eq, Ord)
type BKM = StateT VarMap (ReaderT CompState (Either String))
-- | Transforms background knowledge into problems -- | Transforms background knowledge into problems
-- TODO allow modulo in background knowledge -- TODO allow modulo in background knowledge
transformBK :: ([FlattenedExp] -> [FlattenedExp]) -> BackgroundKnowledge -> StateT VarMap (Either String) (EqualityProblem,InequalityProblem) transformBK :: ([FlattenedExp] -> [FlattenedExp]) -> BackgroundKnowledge ->
BKM (EqualityProblem,InequalityProblem)
transformBK f (Equal eL eR) = do eL' <- makeSingleEq f eL "background knowledge" transformBK f (Equal eL eR) = do eL' <- makeSingleEq f eL "background knowledge"
eR' <- makeSingleEq f eR "background knowledge" eR' <- makeSingleEq f eR "background knowledge"
let e = addEq eL' (amap negate eR') let e = addEq eL' (amap negate eR')
@ -493,12 +498,12 @@ transformBK f (RepBoundsIncl v low high)
, addEq (amap negate eLow) ev' , addEq (amap negate eLow) ev'
]) ])
transformBKList :: ([FlattenedExp] -> [FlattenedExp]) -> [BackgroundKnowledge] -> StateT VarMap (Either String) (EqualityProblem,InequalityProblem) transformBKList :: ([FlattenedExp] -> [FlattenedExp]) -> [BackgroundKnowledge] -> BKM (EqualityProblem,InequalityProblem)
transformBKList f bk = mapM (transformBK f) bk >>* foldl accumProblem ([],[]) transformBKList f bk = mapM (transformBK f) bk >>* foldl accumProblem ([],[])
-- | Turns a single expression into an equation-item. An error is given if the resulting -- | Turns a single expression into an equation-item. An error is given if the resulting
-- expression is anything complicated (for example, modulo or divide) -- expression is anything complicated (for example, modulo or divide)
makeSingleEq :: ([FlattenedExp] -> [FlattenedExp]) -> A.Expression -> String -> StateT VarMap (Either String) EqualityConstraintEquation makeSingleEq :: ([FlattenedExp] -> [FlattenedExp]) -> A.Expression -> String -> BKM EqualityConstraintEquation
makeSingleEq f e desc = (lift (flatten e) >>* f) >>= makeEquation e ([{-TODO-}], makeSingleEq f e desc = (lift (flatten e) >>* f) >>= makeEquation e ([{-TODO-}],
f) (error $ "Type is irrelevant for " ++ desc) f) (error $ "Type is irrelevant for " ++ desc)
>>= getSingleAccessItem ("Modulo or Divide not allowed in " ++ desc >>= getSingleAccessItem ("Modulo or Divide not allowed in " ++ desc
@ -536,13 +541,13 @@ accumProblem = concatPair
-- --
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message -- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
makeEquations :: ParItems (BK, [A.Expression], [A.Expression]) -> A.Expression -> makeEquations :: ParItems (BK, [A.Expression], [A.Expression]) -> A.Expression ->
Either String [(((A.Expression, [ModuloCase]), (A.Expression, [ModuloCase])), VarMap, (EqualityProblem, InequalityProblem))] ReaderT CompState (Either String) [(((A.Expression, [ModuloCase]), (A.Expression, [ModuloCase])), VarMap, (EqualityProblem, InequalityProblem))]
makeEquations accesses bound makeEquations accesses bound
= do ((v,h,repVarIndexes, allReps),s) <- (flip runStateT) Map.empty $ = do ((v,h,repVarIndexes, allReps),s) <- (flip runStateT) Map.empty $
do ((accesses', allReps),repVars) <- flip runStateT [] $ parItemToArrayAccessM mkEq accesses do ((accesses', allReps),repVars) <- flip runStateT [] $ parItemToArrayAccessM mkEq accesses
high <- makeSingleEq id bound "upper bound" high <- makeSingleEq id bound "upper bound"
return (accesses', high, nub repVars, allReps) return (accesses', high, nub repVars, allReps)
squareAndPair (lookupBK allReps) (\(x,y,_) -> (x,y)) repVarIndexes s v (amap (const 0) h, addConstant (-1) h) lift $ squareAndPair (lookupBK allReps) (\(x,y,_) -> (x,y)) repVarIndexes s v (amap (const 0) h, addConstant (-1) h)
where where
lookupBK :: [A.Name] -> (A.Expression, [ModuloCase], BK') -> Either String lookupBK :: [A.Name] -> (A.Expression, [ModuloCase], BK') -> Either String
@ -582,7 +587,7 @@ makeEquations accesses bound
mkEq :: [((A.Name, A.Replicator), Bool)] -> mkEq :: [((A.Name, A.Replicator), Bool)] ->
(BK, [A.Expression], [A.Expression]) -> (BK, [A.Expression], [A.Expression]) ->
StateT [(CoeffIndex, CoeffIndex)] StateT [(CoeffIndex, CoeffIndex)]
(StateT VarMap (Either String)) BKM
[((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] [((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq reps (bk, ws, rs) = do repVarEqs <- mapM (liftF makeRepVarEq) reps mkEq reps (bk, ws, rs) = do repVarEqs <- mapM (liftF makeRepVarEq) reps
concatMapM (mkEq' repVarEqs) (ws' ++ rs') concatMapM (mkEq' repVarEqs) (ws' ++ rs')
@ -590,16 +595,16 @@ makeEquations accesses bound
ws' = zip (repeat AAWrite) ws ws' = zip (repeat AAWrite) ws
rs' = zip (repeat AARead) rs rs' = zip (repeat AARead) rs
makeRepVarEq :: ((A.Name, A.Replicator), Bool) -> StateT VarMap (Either String) (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation) makeRepVarEq :: ((A.Name, A.Replicator), Bool) -> BKM (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)
makeRepVarEq ((varName, A.For m from for _), _) makeRepVarEq ((varName, A.For m from for _), _)
= do from' <- makeSingleEq id from "replication start" = do from' <- makeSingleEq id from "replication start"
upper <- makeSingleEq id (A.Dyadic m A.Subtr (A.Dyadic m A.Add for from) (makeConstant m 1)) "replication count" upper <- makeSingleEq id (subExprsInt (addExprsInt for from) (makeConstant m 1)) "replication count"
return (A.Variable m varName, from', upper) return (A.Variable m varName, from', upper)
mkEq' :: [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] -> mkEq' :: [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] ->
(ArrayAccessType, A.Expression) -> (ArrayAccessType, A.Expression) ->
StateT [(CoeffIndex,CoeffIndex)] StateT [(CoeffIndex,CoeffIndex)]
(StateT VarMap (Either String)) BKM
[((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] [((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
mkEq' repVarEqs (aat, e) mkEq' repVarEqs (aat, e)
= do f <- lift . lift $ flatten e = do f <- lift . lift $ flatten e
@ -610,7 +615,7 @@ makeEquations accesses bound
_ -> throwError "Replicated group found unexpectedly" _ -> throwError "Replicated group found unexpectedly"
-- | Turns all instances of the variable from the given replicator into their primed version in the given expression -- | Turns all instances of the variable from the given replicator into their primed version in the given expression
mirrorFlaggedVar :: ((A.Name, A.Replicator),Bool) -> StateT [(CoeffIndex,CoeffIndex)] (StateT VarMap (Either String)) ([FlattenedExp] -> [FlattenedExp]) mirrorFlaggedVar :: ((A.Name, A.Replicator),Bool) -> StateT [(CoeffIndex,CoeffIndex)] BKM ([FlattenedExp] -> [FlattenedExp])
mirrorFlaggedVar (_,False) = return id mirrorFlaggedVar (_,False) = return id
mirrorFlaggedVar ((varName, A.For m from for _), True) mirrorFlaggedVar ((varName, A.For m from for _), True)
= do varIndexes <- lift $ seqPair (varIndex (Scale 1 (A.ExprVariable emptyMeta var,0)), varIndex (Scale 1 (A.ExprVariable emptyMeta var,1))) = do varIndexes <- lift $ seqPair (varIndex (Scale 1 (A.ExprVariable emptyMeta var,0)), varIndex (Scale 1 (A.ExprVariable emptyMeta var,1)))
@ -619,93 +624,104 @@ makeEquations accesses bound
where where
var = A.Variable m varName var = A.Variable m varName
instance Die (ReaderT CompState (Either String)) where
dieReport (_, s) = throwError s
-- Note that in all these functions, the divisor should always be positive! -- Note that in all these functions, the divisor should always be positive!
canonicalise :: A.Expression -> A.Expression canonicalise :: forall m. (CSMR m, Die m) => A.Expression -> m A.Expression
canonicalise e@(A.Dyadic m op _ _) | op == A.Add || op == A.Mul canonicalise e@(A.FunctionCall m n es)
= foldl1 (A.Dyadic m op) $ sort $ gatherTerms op e = do mOp <- functionOperator n
ts <- mapM astTypeOf es
case (mOp, fmap (\op -> A.nameName n == occamDefaultOperator op ts) mOp) of
(Just op, Just True) | op == "+" || op == "*"
-> liftM (foldl1 (\a b -> A.FunctionCall m n [a, b]) . sort) $ gatherTerms n e
_ -> mapM canonicalise es >>* A.FunctionCall m n
where where
gatherTerms :: A.DyadicOp -> A.Expression -> [A.Expression] gatherTerms :: A.Name -> A.Expression -> m [A.Expression]
gatherTerms op (A.Dyadic _ op' lhs rhs) | op == op' gatherTerms n (A.FunctionCall _ n' es) | n == n'
= gatherTerms op lhs ++ gatherTerms op rhs = concatMapM (gatherTerms n) es
gatherTerms _ e = [canonicalise e] gatherTerms _ e = canonicalise e >>* singleton
canonicalise (A.Dyadic m op lhs rhs) canonicalise e = return e
= A.Dyadic m op (canonicalise lhs) (canonicalise rhs)
canonicalise (A.Monadic m op rhs)
= A.Monadic m op (canonicalise rhs)
canonicalise e = e
flatten :: A.Expression -> Either String [FlattenedExp] flatten :: A.Expression -> ReaderT CompState (Either String) [FlattenedExp]
flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)] flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)]
flatten e@(A.Dyadic m op lhs rhs) flatten e@(A.FunctionCall m fn [lhs, rhs])
| op == A.Add = combine' (flatten lhs) (flatten rhs) = do mOp <- builtInOperator fn
| op == A.Subtr = combine' (flatten lhs) (mapM (scale (-1)) =<< flatten rhs) case mOp of
| op == A.Mul = multiplyOut' (flatten lhs) (flatten rhs) Just "+" -> combine' (flatten lhs) (flatten rhs)
| op == A.Rem = liftM2L (Modulo 1) (flatten lhs) (flatten rhs) Just "-" -> combine' (flatten lhs) (mapM (scale (-1)) =<< flatten rhs)
| op == A.Div = do rhs' <- flatten rhs Just "*" -> multiplyOut' (flatten lhs) (flatten rhs)
Just "\\" -> liftM2L (Modulo 1) (flatten lhs) (flatten rhs)
Just "/" ->do rhs' <- flatten rhs
case onlyConst rhs' of case onlyConst rhs' of
Just _ -> liftM2L (Divide 1) (flatten lhs) (return rhs') Just _ -> liftM2L (Divide 1) (flatten lhs) (return rhs')
-- Can't deal with variable divisors, leave expression as-is: -- Can't deal with variable divisors, leave expression as-is:
Nothing -> return [Scale 1 (canonicalise e,0)] Nothing -> do e' <- canonicalise e
return [Scale 1 (e',0)]
_ -> do e' <- canonicalise e
return [Scale 1 (e',0)]
where where
-- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c] liftM2L :: MonadError String m => (Set.Set FlattenedExp -> Set.Set FlattenedExp -> c)
-> m [FlattenedExp] -> m [FlattenedExp] -> m [c]
liftM2L f x y = liftM singleton $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet) liftM2L f x y = liftM singleton $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet)
multiplyOut' :: Either String [FlattenedExp] -> Either String [FlattenedExp] -> Either String [FlattenedExp] multiplyOut' :: (Die m, CSMR m, MonadError String m ) => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp]
multiplyOut' x y = do {x' <- x; y' <- y; multiplyOut x' y'} multiplyOut' x y = join $ liftM2 multiplyOut x y
multiplyOut :: [FlattenedExp] -> [FlattenedExp] -> Either String [FlattenedExp] multiplyOut :: forall m. (Die m, CSMR m, MonadError String m) => [FlattenedExp] -> [FlattenedExp] -> m [FlattenedExp]
multiplyOut lhs rhs = mapM (uncurry mult) pairs multiplyOut lhs rhs = mapM (uncurry mult) pairs
where where
pairs = product2 (lhs,rhs) pairs = product2 (lhs,rhs)
mult :: FlattenedExp -> FlattenedExp -> Either String FlattenedExp mult :: FlattenedExp -> FlattenedExp -> m FlattenedExp
mult (Const x) e = scale x e mult (Const x) e = scale x e
mult e (Const x) = scale x e mult e (Const x) = scale x e
mult lhs rhs mult lhs rhs
= do lhs' <- backToEq lhs = do lhs' <- backToEq lhs
rhs' <- backToEq rhs rhs' <- backToEq rhs
return $ (Scale 1 (canonicalise $ A.Dyadic emptyMeta A.Mul lhs' rhs', 0)) e <- mulExprs lhs' rhs' >>= canonicalise
return $ (Scale 1 (e, 0))
backScale :: Integer -> A.Expression -> A.Expression backScale :: Integer -> A.Expression -> m A.Expression
backScale 1 = id backScale 1 e = return e
backScale n = canonicalise . A.Dyadic emptyMeta A.Mul (makeConstant emptyMeta (fromInteger n)) backScale n e = do t <- astTypeOf e
mulExprs (makeConstant' emptyMeta t n) e >>= canonicalise
backToEq :: FlattenedExp -> Either String A.Expression backToEq :: FlattenedExp -> m A.Expression
backToEq (Const c) = return $ makeConstant emptyMeta (fromInteger c) backToEq (Const c) = return $ makeConstant emptyMeta (fromInteger c)
backToEq (Scale n (e,0)) = return $ backScale n e backToEq (Scale n (e,0)) = backScale n e
backToEq (Modulo n t b) backToEq (Modulo n t b)
| Set.null t || Set.null b = throwError "Modulo had empty top or bottom" | Set.null t || Set.null b = throwError "Modulo had empty top or bottom"
| otherwise = do t' <- mapM backToEq $ Set.toList t | otherwise = do t' <- mapM backToEq $ Set.toList t
b' <- mapM backToEq $ Set.toList b b' <- mapM backToEq $ Set.toList b
return $ t'' <- foldM1 addExprs t'
(backScale n $ A.Dyadic emptyMeta A.Rem b'' <- foldM1 addExprs b'
(foldl1 (A.Dyadic emptyMeta A.Add) t') remExprs t'' b'' >>= backScale n
(foldl1 (A.Dyadic emptyMeta A.Add) b'))
backToEq (Divide n t b) backToEq (Divide n t b)
| Set.null t || Set.null b = throwError "Divide had empty top or bottom" | Set.null t || Set.null b = throwError "Divide had empty top or bottom"
| otherwise = do t' <- mapM backToEq $ Set.toList t | otherwise = do t' <- mapM backToEq $ Set.toList t
b' <- mapM backToEq $ Set.toList b b' <- mapM backToEq $ Set.toList b
return $ t'' <- foldM1 addExprs t'
(backScale n $ A.Dyadic emptyMeta A.Div b'' <- foldM1 addExprs b'
(foldl1 (A.Dyadic emptyMeta A.Add) t') divExprs t'' b'' >>= backScale n
(foldl1 (A.Dyadic emptyMeta A.Add) b'))
-- | Scales a flattened expression by the given integer scaling. -- | Scales a flattened expression by the given integer scaling.
scale :: Integer -> FlattenedExp -> Either String FlattenedExp scale :: Monad m => Integer -> FlattenedExp -> m FlattenedExp
scale sc (Const n) = return $ Const (n * sc) scale sc (Const n) = return $ Const (n * sc)
scale sc (Scale n v) = return $ Scale (n * sc) v scale sc (Scale n v) = return $ Scale (n * sc) v
scale sc (Modulo n t b) = return $ Modulo (n * sc) t b scale sc (Modulo n t b) = return $ Modulo (n * sc) t b
scale sc (Divide n t b) = return $ Divide (n * sc) t b scale sc (Divide n t b) = return $ Divide (n * sc) t b
-- | An easy way of applying combine to two monadic returns -- | An easy way of applying combine to two monadic returns
combine' :: Either String [FlattenedExp] -> Either String [FlattenedExp] -> Either String [FlattenedExp] combine' :: Monad m => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp]
combine' = liftM2 combine combine' = liftM2 combine
-- | Combines (adds) two flattened expressions. -- | Combines (adds) two flattened expressions.
combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp] combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp]
combine = (++) combine = (++)
flatten e = return [Scale 1 (canonicalise e,0)] flatten e = do e' <- canonicalise e
return [Scale 1 (e',0)]
-- | The "square" refers to making all equations the length of the longest -- | The "square" refers to making all equations the length of the longest
-- one, and the pair refers to pairing each in a list of array accesses (e.g. -- one, and the pair refers to pairing each in a list of array accesses (e.g.
@ -779,18 +795,18 @@ squareAndPair lookupBK strip repVars s v lh
-- prime >= plain + 1 (prime - plain - 1 >= 0) -- prime >= plain + 1 (prime - plain - 1 >= 0)
= [mapToArray $ Map.fromList [(prime,1), (plain,-1), (0, -1)]] = [mapToArray $ Map.fromList [(prime,1), (plain,-1), (0, -1)]]
getSingleAccessItem :: MonadTrans m => String -> ArrayAccess label -> m (Either String) EqualityConstraintEquation getSingleAccessItem :: MonadError String m => String -> ArrayAccess label -> m EqualityConstraintEquation
getSingleAccessItem _ (Group [(_,_,(acc,_,_))]) = lift $ return acc getSingleAccessItem _ (Group [(_,_,(acc,_,_))]) = return acc
getSingleAccessItem err _ = lift $ throwError err getSingleAccessItem err _ = throwError err
-- | Odd helper function for getting\/asserting the first item of a triple from a singleton list inside a monad transformer (!) -- | Odd helper function for getting\/asserting the first item of a triple from a singleton list inside a monad transformer (!)
getSingleItem :: MonadTrans m => String -> [(a,b,c)] -> m (Either String) a getSingleItem :: MonadError String m => String -> [(a,b,c)] -> m a
getSingleItem _ [(item,_,_)] = lift $ return item getSingleItem _ [(item,_,_)] = return item
getSingleItem err _ = lift $ throwError err getSingleItem err _ = throwError err
-- | Finds the index associated with a particular variable; either by finding an existing index -- | Finds the index associated with a particular variable; either by finding an existing index
-- or allocating a new one. -- or allocating a new one.
varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int varIndex :: FlattenedExp -> BKM Int
varIndex (Scale _ (e,vi)) varIndex (Scale _ (e,vi))
= do st <- get = do st <- get
let (st',ind) = case Map.lookup (Scale 1 (e,vi)) st of let (st',ind) = case Map.lookup (Scale 1 (e,vi)) st of
@ -864,29 +880,30 @@ getIneqs (low, high) = concatMap getLH
getLH :: EqualityConstraintEquation -> [InequalityConstraintEquation] getLH :: EqualityConstraintEquation -> [InequalityConstraintEquation]
getLH eq = [eq `addEq` (amap negate low),high `addEq` amap negate eq] getLH eq = [eq `addEq` (amap negate low),high `addEq` amap negate eq]
justState :: Error e => StateT s (Either e) a -> StateT s (Either e) (Either e a) justState :: Error e => StateT s (ReaderT r (Either e)) a -> StateT s (ReaderT r (Either e)) (Either e a)
justState m = do st <- get justState m = do st <- get
let (x, st') = case runStateT m st of r <- ask
let (x, st') = case runReaderT (runStateT m st) r of
Left err -> (Left err, st) Left err -> (Left err, st)
Right (x, st') -> (Right x, st') Right (x, st') -> (Right x, st')
put st' put st'
return x return x
-- | Given an expression, forms equations (and accompanying additional equation-sets) and returns it -- | Given an expression, forms equations (and accompanying additional equation-sets) and returns it
makeEquation :: label -> (BK, [FlattenedExp] -> [FlattenedExp]) -> ArrayAccessType -> [FlattenedExp] -> StateT VarMap (Either String) (ArrayAccess (label,[ModuloCase], makeEquation :: label -> (BK, [FlattenedExp] -> [FlattenedExp]) -> ArrayAccessType -> [FlattenedExp]
BK')) -> BKM (ArrayAccess (label,[ModuloCase], BK'))
makeEquation l (bk, bkF) t summedItems makeEquation l (bk, bkF) t summedItems
= do eqs <- process summedItems = do eqs <- process summedItems
bk' <- mapM (mapMapM (justState . transformBKList bkF)) bk bk' <- mapM (mapMapM (justState . transformBKList bkF)) bk
let eqs' = map (transformQuad id mapToArray (map mapToArray) (map mapToArray)) eqs :: [([ModuloCase], EqualityConstraintEquation, EqualityProblem, InequalityProblem)] let eqs' = map (transformQuad id mapToArray (map mapToArray) (map mapToArray)) eqs :: [([ModuloCase], EqualityConstraintEquation, EqualityProblem, InequalityProblem)]
return $ Group [((l,c,bk'),t,(e0,e1,e2)) | (c,e0,e1,e2) <- eqs'] return $ Group [((l,c,bk'),t,(e0,e1,e2)) | (c,e0,e1,e2) <- eqs']
where where
process :: [FlattenedExp] -> StateT VarMap (Either String) [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] process :: [FlattenedExp] -> BKM [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
process = foldM makeEquation' empty process = foldM makeEquation' empty
makeEquation' :: [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] -> makeEquation' :: [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] ->
FlattenedExp -> FlattenedExp ->
StateT (VarMap) (Either String) BKM
[([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])]
makeEquation' m (Const n) = return $ add (0,n) m makeEquation' m (Const n) = return $ add (0,n) m
makeEquation' m sc@(Scale n v) = varIndex sc >>* (\ind -> add (ind, n) m) makeEquation' m sc@(Scale n v) = varIndex sc >>* (\ind -> add (ind, n) m)

View File

@ -26,8 +26,8 @@ module Types
, returnTypesOfFunction , returnTypesOfFunction
, BytesInResult(..), bytesInType, countReplicator, countStructured, computeStructured , BytesInResult(..), bytesInType, countReplicator, countStructured, computeStructured
, makeAbbrevAM, makeConstant, makeDimension, specificDimSize , makeAbbrevAM, makeConstant, makeConstant', makeDimension, specificDimSize
, addOne, subOne, addExprs, subExprs, mulExprs, divExprs , addOne, subOne, addExprs, subExprs, mulExprs, divExprs, remExprs
, addOneInt, subOneInt, addExprsInt, subExprsInt, mulExprsInt, divExprsInt , addOneInt, subOneInt, addExprsInt, subExprsInt, mulExprsInt, divExprsInt
, addDimensions, applyDimension, removeFixedDimensions, trivialSubscriptType, subscriptType, unsubscriptType , addDimensions, applyDimension, removeFixedDimensions, trivialSubscriptType, subscriptType, unsubscriptType
, applyDirection , applyDirection
@ -397,7 +397,10 @@ makeAbbrevAM am = am
-- | Generate a constant expression from an integer -- for array sizes and the -- | Generate a constant expression from an integer -- for array sizes and the
-- like. -- like.
makeConstant :: Meta -> Int -> A.Expression makeConstant :: Meta -> Int -> A.Expression
makeConstant m n = A.Literal m A.Int $ A.IntLiteral m (show n) makeConstant m = makeConstant' m A.Int . toInteger
makeConstant' :: Meta -> A.Type -> Integer -> A.Expression
makeConstant' m t n = A.Literal m t $ A.IntLiteral m (show n)
-- | Generate a constant dimension from an integer. -- | Generate a constant dimension from an integer.
makeDimension :: Meta -> Int -> A.Dimension makeDimension :: Meta -> Int -> A.Dimension
@ -858,6 +861,10 @@ mulExprs = dyadicExpr "*"
divExprs :: DyadicExprM divExprs :: DyadicExprM
divExprs = dyadicExpr "/" divExprs = dyadicExpr "/"
-- | Divide two expressions.
remExprs :: DyadicExprM
remExprs = dyadicExpr "\\"
-- | Add two expressions. -- | Add two expressions.
addExprsInt :: DyadicExpr addExprsInt :: DyadicExpr
addExprsInt = dyadicExpr' (A.Int,A.Int) "+" addExprsInt = dyadicExpr' (A.Int,A.Int) "+"