Added some (fairly messy) code for taking a list of A.Expression and generating a list of equations
This commit is contained in:
parent
d674a2fdd0
commit
74490c005e
|
@ -18,9 +18,11 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
module ArrayUsageCheck where
|
||||
|
||||
import Control.Monad.Error
|
||||
import Control.Monad.State
|
||||
import Data.Array.IArray
|
||||
import Data.List
|
||||
import qualified Data.Map as Map
|
||||
import Data.Maybe
|
||||
|
||||
import qualified AST as A
|
||||
|
@ -59,6 +61,94 @@ makeProblems indexLists = map checkEq zippedPairs
|
|||
makeProblem1Dim :: [CoeffExpr] -> [Problem]
|
||||
makeProblem1Dim ces = makeProblems [[c] | c <- ces]
|
||||
|
||||
data FlattenedExp = Const Integer | Scale Integer A.Variable deriving (Eq,Show)
|
||||
|
||||
-- TODO probably want to take this into the PassM monad at some point
|
||||
makeEquations :: [A.Expression] -> A.Expression -> Either String (Map.Map String Int, (EqualityProblem, InequalityProblem))
|
||||
makeEquations es high = makeEquations' >>* (\(s,v,lh) -> (s,(pairEqs v, getIneqs lh v)))
|
||||
where
|
||||
makeEquations' :: Either String (Map.Map String Int, [(Integer,EqualityConstraintEquation)], (EqualityConstraintEquation, EqualityConstraintEquation))
|
||||
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
|
||||
do flattened <- lift (mapM flatten es)
|
||||
eqs <- mapM makeEquation flattened
|
||||
(1,high') <- (lift $ flatten high) >>= makeEquation
|
||||
return (eqs,high')
|
||||
return (s,v,(amap (const 0) h, h))
|
||||
|
||||
|
||||
-- Takes an expression, and transforms it into an expression like:
|
||||
-- (e_0 + e_1 + e_2) / d
|
||||
-- where d is a constant (non-zero!) integer, and each e_k
|
||||
-- is either a const, a var, const * var, or (const * var) % const [TODO].
|
||||
-- If the expression cannot be transformed into such a format, an error is returned
|
||||
flatten :: A.Expression -> Either String (Integer,[FlattenedExp])
|
||||
flatten (A.Literal _ _ (A.IntLiteral _ n)) = return (1,[Const (read n)])
|
||||
flatten (A.Dyadic m op lhs rhs) | op == A.Add = combine' (flatten lhs) (flatten rhs)
|
||||
| op == A.Subtr = combine' (flatten lhs) (liftM (transformPair id (scale (-1))) $ flatten rhs)
|
||||
-- TODO Mul and Div
|
||||
| otherwise = throwError ("Unhandleable operator found in expression: " ++ show op)
|
||||
flatten (A.ExprVariable _ v) = return (1,[Scale 1 v])
|
||||
flatten other = throwError ("Unhandleable item found in expression: " ++ show other)
|
||||
|
||||
scale :: Integer -> [FlattenedExp] -> [FlattenedExp]
|
||||
scale sc = map scale'
|
||||
where
|
||||
scale' (Const n) = Const (n * sc)
|
||||
scale' (Scale n v) = Scale (n * sc) v
|
||||
|
||||
combine' x y = do {x' <- x; y' <- y; combine x' y'}
|
||||
combine :: (Integer,[FlattenedExp]) -> (Integer,[FlattenedExp]) -> Either String (Integer,[FlattenedExp])
|
||||
combine (nx, ex) (ny, ey) = return $ (nx * ny, scale ny ex ++ scale nx ey)
|
||||
|
||||
--TODO we need to handle lots more different expression types in future.
|
||||
-- For now we just handle dyadic +,-
|
||||
|
||||
varIndex :: A.Variable -> StateT (Map.Map String Int) (Either String) Int
|
||||
varIndex (A.Variable _ (A.Name _ _ varName))
|
||||
= do st <- get
|
||||
let (st',ind) = case Map.lookup varName st of
|
||||
Just val -> (st,val)
|
||||
Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in
|
||||
(Map.insert varName newId st, newId)
|
||||
put st'
|
||||
return ind
|
||||
|
||||
-- Pairs all possible combinations
|
||||
pairEqs :: [(Integer,EqualityConstraintEquation)] -> [EqualityConstraintEquation]
|
||||
pairEqs = filter (any (/= 0) . elems) . map (uncurry pairEqs') . product2 . mkPair
|
||||
where
|
||||
pairEqs' (nx,ex) (ny,ey) = arrayZipWith (-) (amap (* ny) ex) (amap (* nx) ey)
|
||||
|
||||
getIneqs :: (EqualityConstraintEquation, EqualityConstraintEquation) -> [(Integer,EqualityConstraintEquation)] -> [InequalityConstraintEquation]
|
||||
getIneqs (low, high) = concatMap getLH
|
||||
where
|
||||
-- eq / sc >= low => eq - (sc * low) >= 0
|
||||
-- eq / sc <= high => (high * sc) - eq >= 0
|
||||
|
||||
getLH :: (Integer,EqualityConstraintEquation) -> [InequalityConstraintEquation]
|
||||
getLH (sc, eq) = [eq `addEq` (scaleEq (-sc) low),(scaleEq sc high) `addEq` amap negate eq]
|
||||
|
||||
addEq = arrayZipWith (+)
|
||||
scaleEq n = amap (* n)
|
||||
|
||||
makeEquation :: (Integer,[FlattenedExp]) -> StateT (Map.Map String Int) (Either String) (Integer,EqualityConstraintEquation)
|
||||
makeEquation (divisor, summedItems)
|
||||
= do eqs <- foldM makeEquation' Map.empty summedItems
|
||||
max <- maxVar
|
||||
return (divisor, mapToArray max eqs)
|
||||
where
|
||||
makeEquation' :: Map.Map Int Integer -> FlattenedExp -> StateT (Map.Map String Int) (Either String) (Map.Map Int Integer)
|
||||
makeEquation' m (Const n) = return $ add (0,n) m
|
||||
makeEquation' m (Scale n v) = varIndex v >>* (\ind -> add (ind, n) m)
|
||||
|
||||
add :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer
|
||||
add = uncurry (Map.insertWith (+))
|
||||
|
||||
maxVar = get >>* (maximum . (0 :) . Map.elems)
|
||||
|
||||
mapToArray :: (IArray a v, Num v, Num k, Ord k, Ix k) => k -> Map.Map k v -> a k v
|
||||
mapToArray highest = (\arr -> accumArray (+) 0 (0, highest) arr) . Map.assocs
|
||||
|
||||
type CoeffIndex = Int
|
||||
type EqualityConstraintEquation = Array CoeffIndex Integer
|
||||
type EqualityProblem = [EqualityConstraintEquation]
|
||||
|
|
|
@ -378,8 +378,20 @@ testIndexes = TestList
|
|||
,safeParTest 120 False (0,10) [i,i ++ con 1]
|
||||
,safeParTest 140 True (0,10) [2 ** i, 2 ** i ++ con 1]
|
||||
|
||||
|
||||
,TestCase $ assertStuff "testIndexes makeEq"
|
||||
(Right (Map.empty,(uncurry makeConsistent) (doubleEq [con 0 === con 1],leq [con 0,con 0,con 7] &&& leq [con 0,con 1,con 7]))) $
|
||||
makeEquations [intLiteral 0, intLiteral 1] (intLiteral 7)
|
||||
,TestCase $ assertStuff "testIndexes makeEq 2"
|
||||
(Right (Map.singleton "i" 1,(uncurry makeConsistent) (doubleEq [i === con 3],leq [con 0,con 3,con 7] &&& leq [con 0,i,con 7]))) $
|
||||
makeEquations [exprVariable "i",intLiteral 3] (intLiteral 7)
|
||||
]
|
||||
where
|
||||
doubleEq = concatMap (\(Eq e) -> [Eq e,Eq $ negateVars e])
|
||||
assertStuff title x y = assertEqual title (munge x) (munge y)
|
||||
where
|
||||
munge = transformEither id (transformPair id (transformPair sort sort))
|
||||
|
||||
-- Given some indexes using "i", this function checks whether these can
|
||||
-- ever overlap within the bounds given, and matches this against
|
||||
-- the expected value; True for safe, False for unsafe.
|
||||
|
|
Loading…
Reference in New Issue
Block a user