Added some (fairly messy) code for taking a list of A.Expression and generating a list of equations

This commit is contained in:
Neil Brown 2007-12-14 23:15:39 +00:00
parent d674a2fdd0
commit 74490c005e
2 changed files with 102 additions and 0 deletions

View File

@ -18,9 +18,11 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
module ArrayUsageCheck where module ArrayUsageCheck where
import Control.Monad.Error
import Control.Monad.State import Control.Monad.State
import Data.Array.IArray import Data.Array.IArray
import Data.List import Data.List
import qualified Data.Map as Map
import Data.Maybe import Data.Maybe
import qualified AST as A import qualified AST as A
@ -59,6 +61,94 @@ makeProblems indexLists = map checkEq zippedPairs
makeProblem1Dim :: [CoeffExpr] -> [Problem] makeProblem1Dim :: [CoeffExpr] -> [Problem]
makeProblem1Dim ces = makeProblems [[c] | c <- ces] makeProblem1Dim ces = makeProblems [[c] | c <- ces]
data FlattenedExp = Const Integer | Scale Integer A.Variable deriving (Eq,Show)
-- TODO probably want to take this into the PassM monad at some point
makeEquations :: [A.Expression] -> A.Expression -> Either String (Map.Map String Int, (EqualityProblem, InequalityProblem))
makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,(pairEqs v, getIneqs lh v)))
where
makeEquations' :: Either String (Map.Map String Int, [(Integer,EqualityConstraintEquation)], (EqualityConstraintEquation, EqualityConstraintEquation))
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
do flattened <- lift (mapM flatten es)
eqs <- mapM makeEquation flattened
(1,high') <- (lift $ flatten high) >>= makeEquation
return (eqs,high')
return (s,v,(amap (const 0) h, h))
-- Takes an expression, and transforms it into an expression like:
-- (e_0 + e_1 + e_2) / d
-- where d is a constant (non-zero!) integer, and each e_k
-- is either a const, a var, const * var, or (const * var) % const [TODO].
-- If the expression cannot be transformed into such a format, an error is returned
flatten :: A.Expression -> Either String (Integer,[FlattenedExp])
flatten (A.Literal _ _ (A.IntLiteral _ n)) = return (1,[Const (read n)])
flatten (A.Dyadic m op lhs rhs) | op == A.Add = combine' (flatten lhs) (flatten rhs)
| op == A.Subtr = combine' (flatten lhs) (liftM (transformPair id (scale (-1))) $ flatten rhs)
-- TODO Mul and Div
| otherwise = throwError ("Unhandleable operator found in expression: " ++ show op)
flatten (A.ExprVariable _ v) = return (1,[Scale 1 v])
flatten other = throwError ("Unhandleable item found in expression: " ++ show other)
scale :: Integer -> [FlattenedExp] -> [FlattenedExp]
scale sc = map scale'
where
scale' (Const n) = Const (n * sc)
scale' (Scale n v) = Scale (n * sc) v
combine' x y = do {x' <- x; y' <- y; combine x' y'}
combine :: (Integer,[FlattenedExp]) -> (Integer,[FlattenedExp]) -> Either String (Integer,[FlattenedExp])
combine (nx, ex) (ny, ey) = return $ (nx * ny, scale ny ex ++ scale nx ey)
--TODO we need to handle lots more different expression types in future.
-- For now we just handle dyadic +,-
varIndex :: A.Variable -> StateT (Map.Map String Int) (Either String) Int
varIndex (A.Variable _ (A.Name _ _ varName))
= do st <- get
let (st',ind) = case Map.lookup varName st of
Just val -> (st,val)
Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in
(Map.insert varName newId st, newId)
put st'
return ind
-- Pairs all possible combinations
pairEqs :: [(Integer,EqualityConstraintEquation)] -> [EqualityConstraintEquation]
pairEqs = filter (any (/= 0) . elems) . map (uncurry pairEqs') . product2 . mkPair
where
pairEqs' (nx,ex) (ny,ey) = arrayZipWith (-) (amap (* ny) ex) (amap (* nx) ey)
getIneqs :: (EqualityConstraintEquation, EqualityConstraintEquation) -> [(Integer,EqualityConstraintEquation)] -> [InequalityConstraintEquation]
getIneqs (low, high) = concatMap getLH
where
-- eq / sc >= low => eq - (sc * low) >= 0
-- eq / sc <= high => (high * sc) - eq >= 0
getLH :: (Integer,EqualityConstraintEquation) -> [InequalityConstraintEquation]
getLH (sc, eq) = [eq `addEq` (scaleEq (-sc) low),(scaleEq sc high) `addEq` amap negate eq]
addEq = arrayZipWith (+)
scaleEq n = amap (* n)
makeEquation :: (Integer,[FlattenedExp]) -> StateT (Map.Map String Int) (Either String) (Integer,EqualityConstraintEquation)
makeEquation (divisor, summedItems)
= do eqs <- foldM makeEquation' Map.empty summedItems
max <- maxVar
return (divisor, mapToArray max eqs)
where
makeEquation' :: Map.Map Int Integer -> FlattenedExp -> StateT (Map.Map String Int) (Either String) (Map.Map Int Integer)
makeEquation' m (Const n) = return $ add (0,n) m
makeEquation' m (Scale n v) = varIndex v >>* (\ind -> add (ind, n) m)
add :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer
add = uncurry (Map.insertWith (+))
maxVar = get >>* (maximum . (0 :) . Map.elems)
mapToArray :: (IArray a v, Num v, Num k, Ord k, Ix k) => k -> Map.Map k v -> a k v
mapToArray highest = (\arr -> accumArray (+) 0 (0, highest) arr) . Map.assocs
type CoeffIndex = Int type CoeffIndex = Int
type EqualityConstraintEquation = Array CoeffIndex Integer type EqualityConstraintEquation = Array CoeffIndex Integer
type EqualityProblem = [EqualityConstraintEquation] type EqualityProblem = [EqualityConstraintEquation]

View File

@ -378,8 +378,20 @@ testIndexes = TestList
,safeParTest 120 False (0,10) [i,i ++ con 1] ,safeParTest 120 False (0,10) [i,i ++ con 1]
,safeParTest 140 True (0,10) [2 ** i, 2 ** i ++ con 1] ,safeParTest 140 True (0,10) [2 ** i, 2 ** i ++ con 1]
,TestCase $ assertStuff "testIndexes makeEq"
(Right (Map.empty,(uncurry makeConsistent) (doubleEq [con 0 === con 1],leq [con 0,con 0,con 7] &&& leq [con 0,con 1,con 7]))) $
makeEquations [intLiteral 0, intLiteral 1] (intLiteral 7)
,TestCase $ assertStuff "testIndexes makeEq 2"
(Right (Map.singleton "i" 1,(uncurry makeConsistent) (doubleEq [i === con 3],leq [con 0,con 3,con 7] &&& leq [con 0,i,con 7]))) $
makeEquations [exprVariable "i",intLiteral 3] (intLiteral 7)
] ]
where where
doubleEq = concatMap (\(Eq e) -> [Eq e,Eq $ negateVars e])
assertStuff title x y = assertEqual title (munge x) (munge y)
where
munge = transformEither id (transformPair id (transformPair sort sort))
-- Given some indexes using "i", this function checks whether these can -- Given some indexes using "i", this function checks whether these can
-- ever overlap within the bounds given, and matches this against -- ever overlap within the bounds given, and matches this against
-- the expected value; True for safe, False for unsafe. -- the expected value; True for safe, False for unsafe.