Added some (fairly messy) code for taking a list of A.Expression and generating a list of equations

This commit is contained in:
Neil Brown 2007-12-14 23:15:39 +00:00
parent d674a2fdd0
commit 74490c005e
2 changed files with 102 additions and 0 deletions

View File

@ -18,9 +18,11 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
module ArrayUsageCheck where
import Control.Monad.Error
import Control.Monad.State
import Data.Array.IArray
import Data.List
import qualified Data.Map as Map
import Data.Maybe
import qualified AST as A
@ -59,6 +61,94 @@ makeProblems indexLists = map checkEq zippedPairs
makeProblem1Dim :: [CoeffExpr] -> [Problem]
makeProblem1Dim ces = makeProblems [[c] | c <- ces]
data FlattenedExp = Const Integer | Scale Integer A.Variable deriving (Eq,Show)
-- TODO probably want to take this into the PassM monad at some point
makeEquations :: [A.Expression] -> A.Expression -> Either String (Map.Map String Int, (EqualityProblem, InequalityProblem))
makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,(pairEqs v, getIneqs lh v)))
where
makeEquations' :: Either String (Map.Map String Int, [(Integer,EqualityConstraintEquation)], (EqualityConstraintEquation, EqualityConstraintEquation))
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
do flattened <- lift (mapM flatten es)
eqs <- mapM makeEquation flattened
(1,high') <- (lift $ flatten high) >>= makeEquation
return (eqs,high')
return (s,v,(amap (const 0) h, h))
-- Takes an expression, and transforms it into an expression like:
-- (e_0 + e_1 + e_2) / d
-- where d is a constant (non-zero!) integer, and each e_k
-- is either a const, a var, const * var, or (const * var) % const [TODO].
-- If the expression cannot be transformed into such a format, an error is returned
flatten :: A.Expression -> Either String (Integer,[FlattenedExp])
flatten (A.Literal _ _ (A.IntLiteral _ n)) = return (1,[Const (read n)])
flatten (A.Dyadic m op lhs rhs) | op == A.Add = combine' (flatten lhs) (flatten rhs)
| op == A.Subtr = combine' (flatten lhs) (liftM (transformPair id (scale (-1))) $ flatten rhs)
-- TODO Mul and Div
| otherwise = throwError ("Unhandleable operator found in expression: " ++ show op)
flatten (A.ExprVariable _ v) = return (1,[Scale 1 v])
flatten other = throwError ("Unhandleable item found in expression: " ++ show other)
scale :: Integer -> [FlattenedExp] -> [FlattenedExp]
scale sc = map scale'
where
scale' (Const n) = Const (n * sc)
scale' (Scale n v) = Scale (n * sc) v
combine' x y = do {x' <- x; y' <- y; combine x' y'}
combine :: (Integer,[FlattenedExp]) -> (Integer,[FlattenedExp]) -> Either String (Integer,[FlattenedExp])
combine (nx, ex) (ny, ey) = return $ (nx * ny, scale ny ex ++ scale nx ey)
--TODO we need to handle lots more different expression types in future.
-- For now we just handle dyadic +,-
varIndex :: A.Variable -> StateT (Map.Map String Int) (Either String) Int
varIndex (A.Variable _ (A.Name _ _ varName))
= do st <- get
let (st',ind) = case Map.lookup varName st of
Just val -> (st,val)
Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in
(Map.insert varName newId st, newId)
put st'
return ind
-- Pairs all possible combinations
pairEqs :: [(Integer,EqualityConstraintEquation)] -> [EqualityConstraintEquation]
pairEqs = filter (any (/= 0) . elems) . map (uncurry pairEqs') . product2 . mkPair
where
pairEqs' (nx,ex) (ny,ey) = arrayZipWith (-) (amap (* ny) ex) (amap (* nx) ey)
getIneqs :: (EqualityConstraintEquation, EqualityConstraintEquation) -> [(Integer,EqualityConstraintEquation)] -> [InequalityConstraintEquation]
getIneqs (low, high) = concatMap getLH
where
-- eq / sc >= low => eq - (sc * low) >= 0
-- eq / sc <= high => (high * sc) - eq >= 0
getLH :: (Integer,EqualityConstraintEquation) -> [InequalityConstraintEquation]
getLH (sc, eq) = [eq `addEq` (scaleEq (-sc) low),(scaleEq sc high) `addEq` amap negate eq]
addEq = arrayZipWith (+)
scaleEq n = amap (* n)
makeEquation :: (Integer,[FlattenedExp]) -> StateT (Map.Map String Int) (Either String) (Integer,EqualityConstraintEquation)
makeEquation (divisor, summedItems)
= do eqs <- foldM makeEquation' Map.empty summedItems
max <- maxVar
return (divisor, mapToArray max eqs)
where
makeEquation' :: Map.Map Int Integer -> FlattenedExp -> StateT (Map.Map String Int) (Either String) (Map.Map Int Integer)
makeEquation' m (Const n) = return $ add (0,n) m
makeEquation' m (Scale n v) = varIndex v >>* (\ind -> add (ind, n) m)
add :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer
add = uncurry (Map.insertWith (+))
maxVar = get >>* (maximum . (0 :) . Map.elems)
mapToArray :: (IArray a v, Num v, Num k, Ord k, Ix k) => k -> Map.Map k v -> a k v
mapToArray highest = (\arr -> accumArray (+) 0 (0, highest) arr) . Map.assocs
type CoeffIndex = Int
type EqualityConstraintEquation = Array CoeffIndex Integer
type EqualityProblem = [EqualityConstraintEquation]

View File

@ -378,8 +378,20 @@ testIndexes = TestList
,safeParTest 120 False (0,10) [i,i ++ con 1]
,safeParTest 140 True (0,10) [2 ** i, 2 ** i ++ con 1]
,TestCase $ assertStuff "testIndexes makeEq"
(Right (Map.empty,(uncurry makeConsistent) (doubleEq [con 0 === con 1],leq [con 0,con 0,con 7] &&& leq [con 0,con 1,con 7]))) $
makeEquations [intLiteral 0, intLiteral 1] (intLiteral 7)
,TestCase $ assertStuff "testIndexes makeEq 2"
(Right (Map.singleton "i" 1,(uncurry makeConsistent) (doubleEq [i === con 3],leq [con 0,con 3,con 7] &&& leq [con 0,i,con 7]))) $
makeEquations [exprVariable "i",intLiteral 3] (intLiteral 7)
]
where
doubleEq = concatMap (\(Eq e) -> [Eq e,Eq $ negateVars e])
assertStuff title x y = assertEqual title (munge x) (munge y)
where
munge = transformEither id (transformPair id (transformPair sort sort))
-- Given some indexes using "i", this function checks whether these can
-- ever overlap within the bounds given, and matches this against
-- the expected value; True for safe, False for unsafe.