From 337d3396415c87a0e5dd229dff54407a58e73f17 Mon Sep 17 00:00:00 2001 From: Neil Brown Date: Mon, 17 Dec 2007 11:43:18 +0000 Subject: [PATCH] Added comments to the makeEquations function --- transformations/ArrayUsageCheck.hs | 48 ++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/transformations/ArrayUsageCheck.hs b/transformations/ArrayUsageCheck.hs index 7b7ca7d..1448e9b 100644 --- a/transformations/ArrayUsageCheck.hs +++ b/transformations/ArrayUsageCheck.hs @@ -35,7 +35,6 @@ import Pass import Types import Utils - -- TODO we should probably calculate this from the CFG checkArrayUsage :: Data a => a -> PassM a checkArrayUsage tree = (mapM_ checkPar $ listify (const True) tree) >> return tree @@ -80,12 +79,18 @@ checkArrayUsage tree = (mapM_ checkPar $ listify (const True) tree) >> return tr getRealName s = lookupName (A.Name undefined undefined s) >>* A.ndOrigName +-- | A type for inside makeEquations: data FlattenedExp = Const Integer | Scale Integer A.Variable deriving (Eq,Show) --- TODO probably want to take this into the PassM monad at some point +-- | Given a list of expressions, an expression representing the upper array bound, returns either an error +-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from +-- (unique, munged) variable name to variable-index in the equations. +-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message makeEquations :: [A.Expression] -> A.Expression -> Either String (Map.Map String Int, (EqualityProblem, InequalityProblem)) makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pairEqs v, getIneqs lh v))) where + -- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair + -- representing the upper and lower bounds of the array (inclusive). makeEquations' :: Either String (Map.Map String Int, [(Integer,EqualityConstraintEquation)], (EqualityConstraintEquation, EqualityConstraintEquation)) makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $ do flattened <- lift (mapM flatten es) @@ -94,6 +99,7 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai return (eqs,high') return (s,v,(amap (const 0) h, h)) + -- Note that in all these functions, the divisor should always be positive! -- Takes an expression, and transforms it into an expression like: -- (e_0 + e_1 + e_2) / d @@ -108,20 +114,29 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai | otherwise = throwError ("Unhandleable operator found in expression: " ++ show op) flatten (A.ExprVariable _ v) = return (1,[Scale 1 v]) flatten other = throwError ("Unhandleable item found in expression: " ++ show other) + + --TODO we need to handle lots more different expression types in future. + -- | Scales a flattened expression by the given integer scaling. scale :: Integer -> [FlattenedExp] -> [FlattenedExp] scale sc = map scale' where scale' (Const n) = Const (n * sc) scale' (Scale n v) = Scale (n * sc) v - + + -- | An easy way of applying combine to two monadic returns + combine' :: Either String (Integer,[FlattenedExp]) -> Either String (Integer,[FlattenedExp]) -> Either String (Integer,[FlattenedExp]) combine' x y = do {x' <- x; y' <- y; combine x' y'} + + -- | Combines (adds) two flattened expressions with a divisor. + -- Given (nx,ex) and (ny,ey), representing ex/nx and ey/ny, this becomes + -- ((ny*ex)+(nx*ey)/nx*ny (i.e. standard mathematics!). combine :: (Integer,[FlattenedExp]) -> (Integer,[FlattenedExp]) -> Either String (Integer,[FlattenedExp]) combine (nx, ex) (ny, ey) = return $ (nx * ny, scale ny ex ++ scale nx ey) - --TODO we need to handle lots more different expression types in future. - -- For now we just handle dyadic +,- + -- | Finds the index associated with a particular variable; either by finding an existing index + -- or allocating a new one. varIndex :: A.Variable -> StateT (Map.Map String Int) (Either String) Int varIndex (A.Variable _ (A.Name _ _ varName)) = do st <- get @@ -132,12 +147,17 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai put st' return ind - -- Pairs all possible combinations + -- | Pairs all possible combinations of the list of divided equations. That is for all pairs + -- in the list ((nx,ex),(ny,ey)) (representing ex/nx and ey/ny), forms the equation ny*ex = nx*ey pairEqs :: [(Integer,EqualityConstraintEquation)] -> [EqualityConstraintEquation] pairEqs = filter (any (/= 0) . elems) . map (uncurry pairEqs') . allPairs where pairEqs' (nx,ex) (ny,ey) = arrayZipWith' 0 (-) (amap (* ny) ex) (amap (* nx) ey) + -- | Given a (low,high) bound (typically: array dimensions), and a list of equations (nx,ex) representing (ex/nx), + -- forms the possible inequalities: + -- * ex/nx >= low (=> ex >= low * nx) + -- * ex/nx <= high (=> ex <= high * nx) getIneqs :: (EqualityConstraintEquation, EqualityConstraintEquation) -> [(Integer,EqualityConstraintEquation)] -> [InequalityConstraintEquation] getIneqs (low, high) = concatMap getLH where @@ -148,12 +168,12 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai getLH (sc, eq) = [eq `addEq` (scaleEq (-sc) low),(scaleEq sc high) `addEq` amap negate eq] addEq = arrayZipWith' 0 (+) - + + -- | Given a pair (nx,ex) representing ex/nx, forms an equation (e) from the latter part, and returns (nx,e) makeEquation :: (Integer,[FlattenedExp]) -> StateT (Map.Map String Int) (Either String) (Integer,EqualityConstraintEquation) makeEquation (divisor, summedItems) = do eqs <- foldM makeEquation' Map.empty summedItems - max <- maxVar - return (divisor, mapToArray max eqs) + return (divisor, mapToArray eqs) where makeEquation' :: Map.Map Int Integer -> FlattenedExp -> StateT (Map.Map String Int) (Either String) (Map.Map Int Integer) makeEquation' m (Const n) = return $ add (0,n) m @@ -161,14 +181,16 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai add :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer add = uncurry (Map.insertWith (+)) - - maxVar = get >>* (maximum . (0 :) . Map.elems) - mapToArray :: (IArray a v, Num v, Num k, Ord k, Ix k) => k -> Map.Map k v -> a k v - mapToArray highest m = accumArray (+) 0 (0, highest') . Map.assocs $ m + -- | Converts a map to an array. Any missing elements in the middle of the bounds are given the value zero. + -- Could probably be moved to Utils + mapToArray :: (IArray a v, Num v, Num k, Ord k, Ix k) => Map.Map k v -> a k v + mapToArray m = accumArray (+) 0 (0, highest') . Map.assocs $ m where highest' = maximum $ Map.keys m + -- | Given a pair of equation sets, makes all the equations in the lists be the length + -- of the longest equation. All missing elements are of course given value zero. squareEquations :: ([Array CoeffIndex Integer],[Array CoeffIndex Integer]) -> ([Array CoeffIndex Integer],[Array CoeffIndex Integer]) squareEquations (eqs,ineqs) = uncurry transformPair (mkPair $ map $ makeSize (0,highest) 0) (eqs,ineqs)