From 5c353f10b7862af5701fafdb8fddfb7272669e93 Mon Sep 17 00:00:00 2001 From: Neil Brown Date: Wed, 8 Apr 2009 15:44:35 +0000 Subject: [PATCH] A big patch to get ArrayUsageCheck compiling again, with the new operators --- checks/ArrayUsageCheck.hs | 155 +++++++++++++++++++++----------------- common/Types.hs | 13 +++- 2 files changed, 96 insertions(+), 72 deletions(-) diff --git a/checks/ArrayUsageCheck.hs b/checks/ArrayUsageCheck.hs index 3bf3718..c513445 100644 --- a/checks/ArrayUsageCheck.hs +++ b/checks/ArrayUsageCheck.hs @@ -32,6 +32,7 @@ module ArrayUsageCheck ( VarMap) where import Control.Monad.Error +import Control.Monad.Reader import Control.Monad.State import Data.Array.IArray import qualified Data.Foldable as F @@ -78,7 +79,7 @@ findRepSolutions reps bks -- which means they can overlap -- but only if there is also a solution to the -- replicator background knowledge, which is what this function is trying to -- determine. - = case makeEquations (addReps $ SeqItems [(bk, [makeConstant emptyMeta 0], []) + = getCompState >>= \cs -> case flip runReaderT cs $ makeEquations (addReps $ SeqItems [(bk, [makeConstant emptyMeta 0], []) | bk <- bks]) maxInt of Right problems -> do probs <- formatProblems [(vm, prob) | (_,vm,prob) <- problems] @@ -170,7 +171,8 @@ checkArrayUsage sharedAttr (m,p) A.Array (A.UnknownDimension:_) _ -> return $ makeConstant m $ fromInteger $ toInteger (maxBound :: Int32) -- It's not an array: _ -> dieP m $ "Cannot usage check array \"" ++ userArrName ++ "\"; found to be of type: " ++ show arrType - case makeEquations indexes arrLength of + cs <- getCompState + case runReaderT (makeEquations indexes arrLength) cs of Left err -> dieP m $ "Could not work with array indexes for array \"" ++ userArrName ++ "\": " ++ err Right [] -> return () -- No problems to work with Right problems -> do @@ -407,10 +409,10 @@ parItemToArrayAccessM f (RepParItem rep p) -- into a set of expressions with at most one constant term, and at most one appearance -- of a any variable, or distinct modulo\/division of variables. -- If there is any problem (specifically, nested modulo or divisions) an error will be returned instead -makeExpSet :: [FlattenedExp] -> Either String (Set.Set FlattenedExp) +makeExpSet :: forall m. MonadError String m => [FlattenedExp] -> m (Set.Set FlattenedExp) makeExpSet = foldM makeExpSet' Set.empty where - makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> Either String (Set.Set FlattenedExp) + makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> m (Set.Set FlattenedExp) makeExpSet' accum (Const n) = return $ insert (addConst n) (Const n) accum makeExpSet' accum (Scale n v) = return $ insert (addScale n v) (Scale n v) accum makeExpSet' accum m@(Modulo {}) | Set.member m accum = throwError "Cannot have repeated REM items in an expression" @@ -465,9 +467,12 @@ data ModuloCase = | XNegYNegAZero | XNegYNegANonZero deriving (Show, Eq, Ord) +type BKM = StateT VarMap (ReaderT CompState (Either String)) + -- | Transforms background knowledge into problems -- TODO allow modulo in background knowledge -transformBK :: ([FlattenedExp] -> [FlattenedExp]) -> BackgroundKnowledge -> StateT VarMap (Either String) (EqualityProblem,InequalityProblem) +transformBK :: ([FlattenedExp] -> [FlattenedExp]) -> BackgroundKnowledge -> + BKM (EqualityProblem,InequalityProblem) transformBK f (Equal eL eR) = do eL' <- makeSingleEq f eL "background knowledge" eR' <- makeSingleEq f eR "background knowledge" let e = addEq eL' (amap negate eR') @@ -493,12 +498,12 @@ transformBK f (RepBoundsIncl v low high) , addEq (amap negate eLow) ev' ]) -transformBKList :: ([FlattenedExp] -> [FlattenedExp]) -> [BackgroundKnowledge] -> StateT VarMap (Either String) (EqualityProblem,InequalityProblem) +transformBKList :: ([FlattenedExp] -> [FlattenedExp]) -> [BackgroundKnowledge] -> BKM (EqualityProblem,InequalityProblem) transformBKList f bk = mapM (transformBK f) bk >>* foldl accumProblem ([],[]) -- | Turns a single expression into an equation-item. An error is given if the resulting -- expression is anything complicated (for example, modulo or divide) -makeSingleEq :: ([FlattenedExp] -> [FlattenedExp]) -> A.Expression -> String -> StateT VarMap (Either String) EqualityConstraintEquation +makeSingleEq :: ([FlattenedExp] -> [FlattenedExp]) -> A.Expression -> String -> BKM EqualityConstraintEquation makeSingleEq f e desc = (lift (flatten e) >>* f) >>= makeEquation e ([{-TODO-}], f) (error $ "Type is irrelevant for " ++ desc) >>= getSingleAccessItem ("Modulo or Divide not allowed in " ++ desc @@ -536,13 +541,13 @@ accumProblem = concatPair -- -- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message makeEquations :: ParItems (BK, [A.Expression], [A.Expression]) -> A.Expression -> - Either String [(((A.Expression, [ModuloCase]), (A.Expression, [ModuloCase])), VarMap, (EqualityProblem, InequalityProblem))] + ReaderT CompState (Either String) [(((A.Expression, [ModuloCase]), (A.Expression, [ModuloCase])), VarMap, (EqualityProblem, InequalityProblem))] makeEquations accesses bound = do ((v,h,repVarIndexes, allReps),s) <- (flip runStateT) Map.empty $ do ((accesses', allReps),repVars) <- flip runStateT [] $ parItemToArrayAccessM mkEq accesses high <- makeSingleEq id bound "upper bound" return (accesses', high, nub repVars, allReps) - squareAndPair (lookupBK allReps) (\(x,y,_) -> (x,y)) repVarIndexes s v (amap (const 0) h, addConstant (-1) h) + lift $ squareAndPair (lookupBK allReps) (\(x,y,_) -> (x,y)) repVarIndexes s v (amap (const 0) h, addConstant (-1) h) where lookupBK :: [A.Name] -> (A.Expression, [ModuloCase], BK') -> Either String @@ -582,7 +587,7 @@ makeEquations accesses bound mkEq :: [((A.Name, A.Replicator), Bool)] -> (BK, [A.Expression], [A.Expression]) -> StateT [(CoeffIndex, CoeffIndex)] - (StateT VarMap (Either String)) + BKM [((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] mkEq reps (bk, ws, rs) = do repVarEqs <- mapM (liftF makeRepVarEq) reps concatMapM (mkEq' repVarEqs) (ws' ++ rs') @@ -590,16 +595,16 @@ makeEquations accesses bound ws' = zip (repeat AAWrite) ws rs' = zip (repeat AARead) rs - makeRepVarEq :: ((A.Name, A.Replicator), Bool) -> StateT VarMap (Either String) (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation) + makeRepVarEq :: ((A.Name, A.Replicator), Bool) -> BKM (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation) makeRepVarEq ((varName, A.For m from for _), _) = do from' <- makeSingleEq id from "replication start" - upper <- makeSingleEq id (A.Dyadic m A.Subtr (A.Dyadic m A.Add for from) (makeConstant m 1)) "replication count" + upper <- makeSingleEq id (subExprsInt (addExprsInt for from) (makeConstant m 1)) "replication count" return (A.Variable m varName, from', upper) mkEq' :: [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] -> (ArrayAccessType, A.Expression) -> StateT [(CoeffIndex,CoeffIndex)] - (StateT VarMap (Either String)) + BKM [((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))] mkEq' repVarEqs (aat, e) = do f <- lift . lift $ flatten e @@ -610,7 +615,7 @@ makeEquations accesses bound _ -> throwError "Replicated group found unexpectedly" -- | Turns all instances of the variable from the given replicator into their primed version in the given expression - mirrorFlaggedVar :: ((A.Name, A.Replicator),Bool) -> StateT [(CoeffIndex,CoeffIndex)] (StateT VarMap (Either String)) ([FlattenedExp] -> [FlattenedExp]) + mirrorFlaggedVar :: ((A.Name, A.Replicator),Bool) -> StateT [(CoeffIndex,CoeffIndex)] BKM ([FlattenedExp] -> [FlattenedExp]) mirrorFlaggedVar (_,False) = return id mirrorFlaggedVar ((varName, A.For m from for _), True) = do varIndexes <- lift $ seqPair (varIndex (Scale 1 (A.ExprVariable emptyMeta var,0)), varIndex (Scale 1 (A.ExprVariable emptyMeta var,1))) @@ -619,93 +624,104 @@ makeEquations accesses bound where var = A.Variable m varName +instance Die (ReaderT CompState (Either String)) where + dieReport (_, s) = throwError s + -- Note that in all these functions, the divisor should always be positive! -canonicalise :: A.Expression -> A.Expression -canonicalise e@(A.Dyadic m op _ _) | op == A.Add || op == A.Mul - = foldl1 (A.Dyadic m op) $ sort $ gatherTerms op e +canonicalise :: forall m. (CSMR m, Die m) => A.Expression -> m A.Expression +canonicalise e@(A.FunctionCall m n es) + = do mOp <- functionOperator n + ts <- mapM astTypeOf es + case (mOp, fmap (\op -> A.nameName n == occamDefaultOperator op ts) mOp) of + (Just op, Just True) | op == "+" || op == "*" + -> liftM (foldl1 (\a b -> A.FunctionCall m n [a, b]) . sort) $ gatherTerms n e + _ -> mapM canonicalise es >>* A.FunctionCall m n where - gatherTerms :: A.DyadicOp -> A.Expression -> [A.Expression] - gatherTerms op (A.Dyadic _ op' lhs rhs) | op == op' - = gatherTerms op lhs ++ gatherTerms op rhs - gatherTerms _ e = [canonicalise e] -canonicalise (A.Dyadic m op lhs rhs) - = A.Dyadic m op (canonicalise lhs) (canonicalise rhs) -canonicalise (A.Monadic m op rhs) - = A.Monadic m op (canonicalise rhs) -canonicalise e = e + gatherTerms :: A.Name -> A.Expression -> m [A.Expression] + gatherTerms n (A.FunctionCall _ n' es) | n == n' + = concatMapM (gatherTerms n) es + gatherTerms _ e = canonicalise e >>* singleton +canonicalise e = return e -flatten :: A.Expression -> Either String [FlattenedExp] +flatten :: A.Expression -> ReaderT CompState (Either String) [FlattenedExp] flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)] -flatten e@(A.Dyadic m op lhs rhs) - | op == A.Add = combine' (flatten lhs) (flatten rhs) - | op == A.Subtr = combine' (flatten lhs) (mapM (scale (-1)) =<< flatten rhs) - | op == A.Mul = multiplyOut' (flatten lhs) (flatten rhs) - | op == A.Rem = liftM2L (Modulo 1) (flatten lhs) (flatten rhs) - | op == A.Div = do rhs' <- flatten rhs +flatten e@(A.FunctionCall m fn [lhs, rhs]) + = do mOp <- builtInOperator fn + case mOp of + Just "+" -> combine' (flatten lhs) (flatten rhs) + Just "-" -> combine' (flatten lhs) (mapM (scale (-1)) =<< flatten rhs) + Just "*" -> multiplyOut' (flatten lhs) (flatten rhs) + Just "\\" -> liftM2L (Modulo 1) (flatten lhs) (flatten rhs) + Just "/" ->do rhs' <- flatten rhs case onlyConst rhs' of Just _ -> liftM2L (Divide 1) (flatten lhs) (return rhs') -- Can't deal with variable divisors, leave expression as-is: - Nothing -> return [Scale 1 (canonicalise e,0)] + Nothing -> do e' <- canonicalise e + return [Scale 1 (e',0)] + _ -> do e' <- canonicalise e + return [Scale 1 (e',0)] where --- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c] + liftM2L :: MonadError String m => (Set.Set FlattenedExp -> Set.Set FlattenedExp -> c) + -> m [FlattenedExp] -> m [FlattenedExp] -> m [c] liftM2L f x y = liftM singleton $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet) - multiplyOut' :: Either String [FlattenedExp] -> Either String [FlattenedExp] -> Either String [FlattenedExp] - multiplyOut' x y = do {x' <- x; y' <- y; multiplyOut x' y'} + multiplyOut' :: (Die m, CSMR m, MonadError String m ) => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp] + multiplyOut' x y = join $ liftM2 multiplyOut x y - multiplyOut :: [FlattenedExp] -> [FlattenedExp] -> Either String [FlattenedExp] + multiplyOut :: forall m. (Die m, CSMR m, MonadError String m) => [FlattenedExp] -> [FlattenedExp] -> m [FlattenedExp] multiplyOut lhs rhs = mapM (uncurry mult) pairs where pairs = product2 (lhs,rhs) - mult :: FlattenedExp -> FlattenedExp -> Either String FlattenedExp + mult :: FlattenedExp -> FlattenedExp -> m FlattenedExp mult (Const x) e = scale x e mult e (Const x) = scale x e mult lhs rhs = do lhs' <- backToEq lhs rhs' <- backToEq rhs - return $ (Scale 1 (canonicalise $ A.Dyadic emptyMeta A.Mul lhs' rhs', 0)) + e <- mulExprs lhs' rhs' >>= canonicalise + return $ (Scale 1 (e, 0)) - backScale :: Integer -> A.Expression -> A.Expression - backScale 1 = id - backScale n = canonicalise . A.Dyadic emptyMeta A.Mul (makeConstant emptyMeta (fromInteger n)) + backScale :: Integer -> A.Expression -> m A.Expression + backScale 1 e = return e + backScale n e = do t <- astTypeOf e + mulExprs (makeConstant' emptyMeta t n) e >>= canonicalise - backToEq :: FlattenedExp -> Either String A.Expression + backToEq :: FlattenedExp -> m A.Expression backToEq (Const c) = return $ makeConstant emptyMeta (fromInteger c) - backToEq (Scale n (e,0)) = return $ backScale n e + backToEq (Scale n (e,0)) = backScale n e backToEq (Modulo n t b) | Set.null t || Set.null b = throwError "Modulo had empty top or bottom" | otherwise = do t' <- mapM backToEq $ Set.toList t b' <- mapM backToEq $ Set.toList b - return $ - (backScale n $ A.Dyadic emptyMeta A.Rem - (foldl1 (A.Dyadic emptyMeta A.Add) t') - (foldl1 (A.Dyadic emptyMeta A.Add) b')) + t'' <- foldM1 addExprs t' + b'' <- foldM1 addExprs b' + remExprs t'' b'' >>= backScale n backToEq (Divide n t b) | Set.null t || Set.null b = throwError "Divide had empty top or bottom" | otherwise = do t' <- mapM backToEq $ Set.toList t b' <- mapM backToEq $ Set.toList b - return $ - (backScale n $ A.Dyadic emptyMeta A.Div - (foldl1 (A.Dyadic emptyMeta A.Add) t') - (foldl1 (A.Dyadic emptyMeta A.Add) b')) + t'' <- foldM1 addExprs t' + b'' <- foldM1 addExprs b' + divExprs t'' b'' >>= backScale n -- | Scales a flattened expression by the given integer scaling. - scale :: Integer -> FlattenedExp -> Either String FlattenedExp + scale :: Monad m => Integer -> FlattenedExp -> m FlattenedExp scale sc (Const n) = return $ Const (n * sc) scale sc (Scale n v) = return $ Scale (n * sc) v scale sc (Modulo n t b) = return $ Modulo (n * sc) t b scale sc (Divide n t b) = return $ Divide (n * sc) t b -- | An easy way of applying combine to two monadic returns - combine' :: Either String [FlattenedExp] -> Either String [FlattenedExp] -> Either String [FlattenedExp] + combine' :: Monad m => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp] combine' = liftM2 combine -- | Combines (adds) two flattened expressions. combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp] combine = (++) -flatten e = return [Scale 1 (canonicalise e,0)] +flatten e = do e' <- canonicalise e + return [Scale 1 (e',0)] -- | The "square" refers to making all equations the length of the longest -- one, and the pair refers to pairing each in a list of array accesses (e.g. @@ -779,18 +795,18 @@ squareAndPair lookupBK strip repVars s v lh -- prime >= plain + 1 (prime - plain - 1 >= 0) = [mapToArray $ Map.fromList [(prime,1), (plain,-1), (0, -1)]] -getSingleAccessItem :: MonadTrans m => String -> ArrayAccess label -> m (Either String) EqualityConstraintEquation -getSingleAccessItem _ (Group [(_,_,(acc,_,_))]) = lift $ return acc -getSingleAccessItem err _ = lift $ throwError err +getSingleAccessItem :: MonadError String m => String -> ArrayAccess label -> m EqualityConstraintEquation +getSingleAccessItem _ (Group [(_,_,(acc,_,_))]) = return acc +getSingleAccessItem err _ = throwError err -- | Odd helper function for getting\/asserting the first item of a triple from a singleton list inside a monad transformer (!) -getSingleItem :: MonadTrans m => String -> [(a,b,c)] -> m (Either String) a -getSingleItem _ [(item,_,_)] = lift $ return item -getSingleItem err _ = lift $ throwError err +getSingleItem :: MonadError String m => String -> [(a,b,c)] -> m a +getSingleItem _ [(item,_,_)] = return item +getSingleItem err _ = throwError err -- | Finds the index associated with a particular variable; either by finding an existing index -- or allocating a new one. -varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int +varIndex :: FlattenedExp -> BKM Int varIndex (Scale _ (e,vi)) = do st <- get let (st',ind) = case Map.lookup (Scale 1 (e,vi)) st of @@ -864,29 +880,30 @@ getIneqs (low, high) = concatMap getLH getLH :: EqualityConstraintEquation -> [InequalityConstraintEquation] getLH eq = [eq `addEq` (amap negate low),high `addEq` amap negate eq] -justState :: Error e => StateT s (Either e) a -> StateT s (Either e) (Either e a) +justState :: Error e => StateT s (ReaderT r (Either e)) a -> StateT s (ReaderT r (Either e)) (Either e a) justState m = do st <- get - let (x, st') = case runStateT m st of + r <- ask + let (x, st') = case runReaderT (runStateT m st) r of Left err -> (Left err, st) Right (x, st') -> (Right x, st') put st' return x -- | Given an expression, forms equations (and accompanying additional equation-sets) and returns it -makeEquation :: label -> (BK, [FlattenedExp] -> [FlattenedExp]) -> ArrayAccessType -> [FlattenedExp] -> StateT VarMap (Either String) (ArrayAccess (label,[ModuloCase], - BK')) +makeEquation :: label -> (BK, [FlattenedExp] -> [FlattenedExp]) -> ArrayAccessType -> [FlattenedExp] + -> BKM (ArrayAccess (label,[ModuloCase], BK')) makeEquation l (bk, bkF) t summedItems = do eqs <- process summedItems bk' <- mapM (mapMapM (justState . transformBKList bkF)) bk let eqs' = map (transformQuad id mapToArray (map mapToArray) (map mapToArray)) eqs :: [([ModuloCase], EqualityConstraintEquation, EqualityProblem, InequalityProblem)] return $ Group [((l,c,bk'),t,(e0,e1,e2)) | (c,e0,e1,e2) <- eqs'] where - process :: [FlattenedExp] -> StateT VarMap (Either String) [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] + process :: [FlattenedExp] -> BKM [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] process = foldM makeEquation' empty makeEquation' :: [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] -> FlattenedExp -> - StateT (VarMap) (Either String) + BKM [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] makeEquation' m (Const n) = return $ add (0,n) m makeEquation' m sc@(Scale n v) = varIndex sc >>* (\ind -> add (ind, n) m) diff --git a/common/Types.hs b/common/Types.hs index f12b993..040e9f9 100644 --- a/common/Types.hs +++ b/common/Types.hs @@ -26,8 +26,8 @@ module Types , returnTypesOfFunction , BytesInResult(..), bytesInType, countReplicator, countStructured, computeStructured - , makeAbbrevAM, makeConstant, makeDimension, specificDimSize - , addOne, subOne, addExprs, subExprs, mulExprs, divExprs + , makeAbbrevAM, makeConstant, makeConstant', makeDimension, specificDimSize + , addOne, subOne, addExprs, subExprs, mulExprs, divExprs, remExprs , addOneInt, subOneInt, addExprsInt, subExprsInt, mulExprsInt, divExprsInt , addDimensions, applyDimension, removeFixedDimensions, trivialSubscriptType, subscriptType, unsubscriptType , applyDirection @@ -397,7 +397,10 @@ makeAbbrevAM am = am -- | Generate a constant expression from an integer -- for array sizes and the -- like. makeConstant :: Meta -> Int -> A.Expression -makeConstant m n = A.Literal m A.Int $ A.IntLiteral m (show n) +makeConstant m = makeConstant' m A.Int . toInteger + +makeConstant' :: Meta -> A.Type -> Integer -> A.Expression +makeConstant' m t n = A.Literal m t $ A.IntLiteral m (show n) -- | Generate a constant dimension from an integer. makeDimension :: Meta -> Int -> A.Dimension @@ -858,6 +861,10 @@ mulExprs = dyadicExpr "*" divExprs :: DyadicExprM divExprs = dyadicExpr "/" +-- | Divide two expressions. +remExprs :: DyadicExprM +remExprs = dyadicExpr "\\" + -- | Add two expressions. addExprsInt :: DyadicExpr addExprsInt = dyadicExpr' (A.Int,A.Int) "+"