Added comments to the makeEquations function

This commit is contained in:
Neil Brown 2007-12-17 11:43:18 +00:00
parent 218a1bd22c
commit 337d339641

View File

@ -35,7 +35,6 @@ import Pass
import Types
import Utils
-- TODO we should probably calculate this from the CFG
checkArrayUsage :: Data a => a -> PassM a
checkArrayUsage tree = (mapM_ checkPar $ listify (const True) tree) >> return tree
@ -80,12 +79,18 @@ checkArrayUsage tree = (mapM_ checkPar $ listify (const True) tree) >> return tr
getRealName s = lookupName (A.Name undefined undefined s) >>* A.ndOrigName
-- | A type for inside makeEquations:
data FlattenedExp = Const Integer | Scale Integer A.Variable deriving (Eq,Show)
-- TODO probably want to take this into the PassM monad at some point
-- | Given a list of expressions, an expression representing the upper array bound, returns either an error
-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
-- (unique, munged) variable name to variable-index in the equations.
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
makeEquations :: [A.Expression] -> A.Expression -> Either String (Map.Map String Int, (EqualityProblem, InequalityProblem))
makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pairEqs v, getIneqs lh v)))
where
-- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair
-- representing the upper and lower bounds of the array (inclusive).
makeEquations' :: Either String (Map.Map String Int, [(Integer,EqualityConstraintEquation)], (EqualityConstraintEquation, EqualityConstraintEquation))
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
do flattened <- lift (mapM flatten es)
@ -94,6 +99,7 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai
return (eqs,high')
return (s,v,(amap (const 0) h, h))
-- Note that in all these functions, the divisor should always be positive!
-- Takes an expression, and transforms it into an expression like:
-- (e_0 + e_1 + e_2) / d
@ -108,20 +114,29 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai
| otherwise = throwError ("Unhandleable operator found in expression: " ++ show op)
flatten (A.ExprVariable _ v) = return (1,[Scale 1 v])
flatten other = throwError ("Unhandleable item found in expression: " ++ show other)
--TODO we need to handle lots more different expression types in future.
-- | Scales a flattened expression by the given integer scaling.
scale :: Integer -> [FlattenedExp] -> [FlattenedExp]
scale sc = map scale'
where
scale' (Const n) = Const (n * sc)
scale' (Scale n v) = Scale (n * sc) v
-- | An easy way of applying combine to two monadic returns
combine' :: Either String (Integer,[FlattenedExp]) -> Either String (Integer,[FlattenedExp]) -> Either String (Integer,[FlattenedExp])
combine' x y = do {x' <- x; y' <- y; combine x' y'}
-- | Combines (adds) two flattened expressions with a divisor.
-- Given (nx,ex) and (ny,ey), representing ex/nx and ey/ny, this becomes
-- ((ny*ex)+(nx*ey)/nx*ny (i.e. standard mathematics!).
combine :: (Integer,[FlattenedExp]) -> (Integer,[FlattenedExp]) -> Either String (Integer,[FlattenedExp])
combine (nx, ex) (ny, ey) = return $ (nx * ny, scale ny ex ++ scale nx ey)
--TODO we need to handle lots more different expression types in future.
-- For now we just handle dyadic +,-
-- | Finds the index associated with a particular variable; either by finding an existing index
-- or allocating a new one.
varIndex :: A.Variable -> StateT (Map.Map String Int) (Either String) Int
varIndex (A.Variable _ (A.Name _ _ varName))
= do st <- get
@ -132,12 +147,17 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai
put st'
return ind
-- Pairs all possible combinations
-- | Pairs all possible combinations of the list of divided equations. That is for all pairs
-- in the list ((nx,ex),(ny,ey)) (representing ex/nx and ey/ny), forms the equation ny*ex = nx*ey
pairEqs :: [(Integer,EqualityConstraintEquation)] -> [EqualityConstraintEquation]
pairEqs = filter (any (/= 0) . elems) . map (uncurry pairEqs') . allPairs
where
pairEqs' (nx,ex) (ny,ey) = arrayZipWith' 0 (-) (amap (* ny) ex) (amap (* nx) ey)
-- | Given a (low,high) bound (typically: array dimensions), and a list of equations (nx,ex) representing (ex/nx),
-- forms the possible inequalities:
-- * ex/nx >= low (=> ex >= low * nx)
-- * ex/nx <= high (=> ex <= high * nx)
getIneqs :: (EqualityConstraintEquation, EqualityConstraintEquation) -> [(Integer,EqualityConstraintEquation)] -> [InequalityConstraintEquation]
getIneqs (low, high) = concatMap getLH
where
@ -148,12 +168,12 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai
getLH (sc, eq) = [eq `addEq` (scaleEq (-sc) low),(scaleEq sc high) `addEq` amap negate eq]
addEq = arrayZipWith' 0 (+)
-- | Given a pair (nx,ex) representing ex/nx, forms an equation (e) from the latter part, and returns (nx,e)
makeEquation :: (Integer,[FlattenedExp]) -> StateT (Map.Map String Int) (Either String) (Integer,EqualityConstraintEquation)
makeEquation (divisor, summedItems)
= do eqs <- foldM makeEquation' Map.empty summedItems
max <- maxVar
return (divisor, mapToArray max eqs)
return (divisor, mapToArray eqs)
where
makeEquation' :: Map.Map Int Integer -> FlattenedExp -> StateT (Map.Map String Int) (Either String) (Map.Map Int Integer)
makeEquation' m (Const n) = return $ add (0,n) m
@ -161,14 +181,16 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,squareEquations (pai
add :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer
add = uncurry (Map.insertWith (+))
maxVar = get >>* (maximum . (0 :) . Map.elems)
mapToArray :: (IArray a v, Num v, Num k, Ord k, Ix k) => k -> Map.Map k v -> a k v
mapToArray highest m = accumArray (+) 0 (0, highest') . Map.assocs $ m
-- | Converts a map to an array. Any missing elements in the middle of the bounds are given the value zero.
-- Could probably be moved to Utils
mapToArray :: (IArray a v, Num v, Num k, Ord k, Ix k) => Map.Map k v -> a k v
mapToArray m = accumArray (+) 0 (0, highest') . Map.assocs $ m
where
highest' = maximum $ Map.keys m
-- | Given a pair of equation sets, makes all the equations in the lists be the length
-- of the longest equation. All missing elements are of course given value zero.
squareEquations :: ([Array CoeffIndex Integer],[Array CoeffIndex Integer]) -> ([Array CoeffIndex Integer],[Array CoeffIndex Integer])
squareEquations (eqs,ineqs) = uncurry transformPair (mkPair $ map $ makeSize (0,highest) 0) (eqs,ineqs)