From 74490c005ee76e7b2c2989e5ba4ca6ae4af6e385 Mon Sep 17 00:00:00 2001 From: Neil Brown Date: Fri, 14 Dec 2007 23:15:39 +0000 Subject: [PATCH] Added some (fairly messy) code for taking a list of A.Expression and generating a list of equations --- transformations/ArrayUsageCheck.hs | 90 +++++++++++++++++++++++++++ transformations/RainUsageCheckTest.hs | 12 ++++ 2 files changed, 102 insertions(+) diff --git a/transformations/ArrayUsageCheck.hs b/transformations/ArrayUsageCheck.hs index 32b5971..e16f6c2 100644 --- a/transformations/ArrayUsageCheck.hs +++ b/transformations/ArrayUsageCheck.hs @@ -18,9 +18,11 @@ with this program. If not, see . module ArrayUsageCheck where +import Control.Monad.Error import Control.Monad.State import Data.Array.IArray import Data.List +import qualified Data.Map as Map import Data.Maybe import qualified AST as A @@ -59,6 +61,94 @@ makeProblems indexLists = map checkEq zippedPairs makeProblem1Dim :: [CoeffExpr] -> [Problem] makeProblem1Dim ces = makeProblems [[c] | c <- ces] +data FlattenedExp = Const Integer | Scale Integer A.Variable deriving (Eq,Show) + +-- TODO probably want to take this into the PassM monad at some point +makeEquations :: [A.Expression] -> A.Expression -> Either String (Map.Map String Int, (EqualityProblem, InequalityProblem)) +makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,(pairEqs v, getIneqs lh v))) + where + makeEquations' :: Either String (Map.Map String Int, [(Integer,EqualityConstraintEquation)], (EqualityConstraintEquation, EqualityConstraintEquation)) + makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $ + do flattened <- lift (mapM flatten es) + eqs <- mapM makeEquation flattened + (1,high') <- (lift $ flatten high) >>= makeEquation + return (eqs,high') + return (s,v,(amap (const 0) h, h)) + + + -- Takes an expression, and transforms it into an expression like: + -- (e_0 + e_1 + e_2) / d + -- where d is a constant (non-zero!) integer, and each e_k + -- is either a const, a var, const * var, or (const * var) % const [TODO]. + -- If the expression cannot be transformed into such a format, an error is returned + flatten :: A.Expression -> Either String (Integer,[FlattenedExp]) + flatten (A.Literal _ _ (A.IntLiteral _ n)) = return (1,[Const (read n)]) + flatten (A.Dyadic m op lhs rhs) | op == A.Add = combine' (flatten lhs) (flatten rhs) + | op == A.Subtr = combine' (flatten lhs) (liftM (transformPair id (scale (-1))) $ flatten rhs) + -- TODO Mul and Div + | otherwise = throwError ("Unhandleable operator found in expression: " ++ show op) + flatten (A.ExprVariable _ v) = return (1,[Scale 1 v]) + flatten other = throwError ("Unhandleable item found in expression: " ++ show other) + + scale :: Integer -> [FlattenedExp] -> [FlattenedExp] + scale sc = map scale' + where + scale' (Const n) = Const (n * sc) + scale' (Scale n v) = Scale (n * sc) v + + combine' x y = do {x' <- x; y' <- y; combine x' y'} + combine :: (Integer,[FlattenedExp]) -> (Integer,[FlattenedExp]) -> Either String (Integer,[FlattenedExp]) + combine (nx, ex) (ny, ey) = return $ (nx * ny, scale ny ex ++ scale nx ey) + + --TODO we need to handle lots more different expression types in future. + -- For now we just handle dyadic +,- + + varIndex :: A.Variable -> StateT (Map.Map String Int) (Either String) Int + varIndex (A.Variable _ (A.Name _ _ varName)) + = do st <- get + let (st',ind) = case Map.lookup varName st of + Just val -> (st,val) + Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in + (Map.insert varName newId st, newId) + put st' + return ind + + -- Pairs all possible combinations + pairEqs :: [(Integer,EqualityConstraintEquation)] -> [EqualityConstraintEquation] + pairEqs = filter (any (/= 0) . elems) . map (uncurry pairEqs') . product2 . mkPair + where + pairEqs' (nx,ex) (ny,ey) = arrayZipWith (-) (amap (* ny) ex) (amap (* nx) ey) + + getIneqs :: (EqualityConstraintEquation, EqualityConstraintEquation) -> [(Integer,EqualityConstraintEquation)] -> [InequalityConstraintEquation] + getIneqs (low, high) = concatMap getLH + where + -- eq / sc >= low => eq - (sc * low) >= 0 + -- eq / sc <= high => (high * sc) - eq >= 0 + + getLH :: (Integer,EqualityConstraintEquation) -> [InequalityConstraintEquation] + getLH (sc, eq) = [eq `addEq` (scaleEq (-sc) low),(scaleEq sc high) `addEq` amap negate eq] + + addEq = arrayZipWith (+) + scaleEq n = amap (* n) + + makeEquation :: (Integer,[FlattenedExp]) -> StateT (Map.Map String Int) (Either String) (Integer,EqualityConstraintEquation) + makeEquation (divisor, summedItems) + = do eqs <- foldM makeEquation' Map.empty summedItems + max <- maxVar + return (divisor, mapToArray max eqs) + where + makeEquation' :: Map.Map Int Integer -> FlattenedExp -> StateT (Map.Map String Int) (Either String) (Map.Map Int Integer) + makeEquation' m (Const n) = return $ add (0,n) m + makeEquation' m (Scale n v) = varIndex v >>* (\ind -> add (ind, n) m) + + add :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer + add = uncurry (Map.insertWith (+)) + + maxVar = get >>* (maximum . (0 :) . Map.elems) + + mapToArray :: (IArray a v, Num v, Num k, Ord k, Ix k) => k -> Map.Map k v -> a k v + mapToArray highest = (\arr -> accumArray (+) 0 (0, highest) arr) . Map.assocs + type CoeffIndex = Int type EqualityConstraintEquation = Array CoeffIndex Integer type EqualityProblem = [EqualityConstraintEquation] diff --git a/transformations/RainUsageCheckTest.hs b/transformations/RainUsageCheckTest.hs index 6289285..c2ea8c8 100644 --- a/transformations/RainUsageCheckTest.hs +++ b/transformations/RainUsageCheckTest.hs @@ -378,8 +378,20 @@ testIndexes = TestList ,safeParTest 120 False (0,10) [i,i ++ con 1] ,safeParTest 140 True (0,10) [2 ** i, 2 ** i ++ con 1] + + ,TestCase $ assertStuff "testIndexes makeEq" + (Right (Map.empty,(uncurry makeConsistent) (doubleEq [con 0 === con 1],leq [con 0,con 0,con 7] &&& leq [con 0,con 1,con 7]))) $ + makeEquations [intLiteral 0, intLiteral 1] (intLiteral 7) + ,TestCase $ assertStuff "testIndexes makeEq 2" + (Right (Map.singleton "i" 1,(uncurry makeConsistent) (doubleEq [i === con 3],leq [con 0,con 3,con 7] &&& leq [con 0,i,con 7]))) $ + makeEquations [exprVariable "i",intLiteral 3] (intLiteral 7) ] where + doubleEq = concatMap (\(Eq e) -> [Eq e,Eq $ negateVars e]) + assertStuff title x y = assertEqual title (munge x) (munge y) + where + munge = transformEither id (transformPair id (transformPair sort sort)) + -- Given some indexes using "i", this function checks whether these can -- ever overlap within the bounds given, and matches this against -- the expected value; True for safe, False for unsafe.