Moved the flatten function to the top level of the ArrayUsageCheck module

This commit is contained in:
Neil Brown 2008-01-16 15:09:26 +00:00
parent 5f2158531b
commit 6e28d3e3db

View File

@ -175,45 +175,24 @@ makeExpSet = foldM makeExpSet' Set.empty
type VarMap = Map.Map FlattenedExp Int
-- | Given a list of expressions, an expression representing the upper array bound, returns either an error
-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
-- (unique, munged) variable name to variable-index in the equations.
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
makeEquations :: [A.Expression] -> A.Expression -> Either String [(VarMap, (EqualityProblem, InequalityProblem))]
makeEquations es high = makeEquations' >>* (\(s,v,lh) -> [(s,squareEquations eqIneq) | eqIneq <- pairEqsAndBounds v lh])
where
-- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair
-- representing the upper and lower bounds of the array (inclusive).
makeEquations' :: Either String (VarMap, [[(EqualityConstraintEquation,EqualityProblem,InequalityProblem)]], (EqualityConstraintEquation, EqualityConstraintEquation))
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
do flattened <- lift (mapM flatten es)
eqs <- mapM makeEquation flattened
high' <- (lift $ flatten high) >>= makeEquation
high'' <- case high' of
[(h,_,_)] -> return h
_ -> throwError "Multiple possible upper bounds not supported"
return (eqs,high'')
return (s,v,(amap (const 0) h, addConstant (-1) h))
-- Note that in all these functions, the divisor should always be positive!
-- Takes an expression, and transforms it into an expression like:
-- (e_0 + e_1 + e_2) / d
-- where d is a constant (non-zero!) integer, and each e_k
-- is either a const, a var, const * var, or (const * var) % const [TODO].
-- If the expression cannot be transformed into such a format, an error is returned
flatten :: A.Expression -> Either String [FlattenedExp]
flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)]
flatten (A.Dyadic m op lhs rhs) | op == A.Add = combine' (flatten lhs) (flatten rhs)
| op == A.Subtr = combine' (flatten lhs) (liftM (scale (-1)) $ flatten rhs)
| op == A.Mul = multiplyOut' (flatten lhs) (flatten rhs)
| op == A.Rem = liftM2L Modulo (flatten lhs) (flatten rhs)
| op == A.Div = liftM2L Divide (flatten lhs) (flatten rhs)
| otherwise = throwError ("Unhandleable operator found in expression: " ++ show op)
flatten (A.ExprVariable _ v) = return [Scale 1 (v,0)]
flatten other = throwError ("Unhandleable item found in expression: " ++ show other)
-- Note that in all these functions, the divisor should always be positive!
-- Takes an expression, and transforms it into an expression like:
-- (e_0 + e_1 + e_2) / d
-- where d is a constant (non-zero!) integer, and each e_k
-- is either a const, a var, const * var, or (const * var) % const [TODO].
-- If the expression cannot be transformed into such a format, an error is returned
flatten :: A.Expression -> Either String [FlattenedExp]
flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)]
flatten (A.ExprVariable _ v) = return [Scale 1 (v,0)]
flatten (A.Dyadic m op lhs rhs) | op == A.Add = combine' (flatten lhs) (flatten rhs)
| op == A.Subtr = combine' (flatten lhs) (liftM (scale (-1)) $ flatten rhs)
| op == A.Mul = multiplyOut' (flatten lhs) (flatten rhs)
| op == A.Rem = liftM2L Modulo (flatten lhs) (flatten rhs)
| op == A.Div = liftM2L Divide (flatten lhs) (flatten rhs)
| otherwise = throwError ("Unhandleable operator found in expression: " ++ show op)
where
-- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c]
liftM2L f x y = liftM (:[]) $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet)
@ -249,6 +228,31 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> [(s,squareEquations eqI
-- | Combines (adds) two flattened expressions.
combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp]
combine = (++)
flatten other = throwError ("Unhandleable item found in expression: " ++ show other)
-- | Given a list of expressions, an expression representing the upper array bound, returns either an error
-- (because the expressions can't be handled, typically) or a set of equalities, inequalities and mapping from
-- (unique, munged) variable name to variable-index in the equations.
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
-- TODO allow "background knowledge" in the form of other equalities and inequalities
makeEquations :: [A.Expression] -> A.Expression -> Either String [(VarMap, (EqualityProblem, InequalityProblem))]
makeEquations es high = makeEquations' >>* (\(s,v,lh) -> [(s,squareEquations eqIneq) | eqIneq <- pairEqsAndBounds v lh])
where
-- | The body of makeEquations; returns the variable mapping, the list of (nx,ex) pairs and a pair
-- representing the upper and lower bounds of the array (inclusive).
makeEquations' :: Either String (VarMap, [[(EqualityConstraintEquation,EqualityProblem,InequalityProblem)]], (EqualityConstraintEquation, EqualityConstraintEquation))
makeEquations' = do ((v,h),s) <- (flip runStateT) Map.empty $
do flattened <- lift (mapM flatten es)
eqs <- mapM makeEquation flattened
high' <- (lift $ flatten high) >>= makeEquation
high'' <- case high' of
[(h,_,_)] -> return h
_ -> throwError "Multiple possible upper bounds not supported"
return (eqs,high'')
return (s,v,(amap (const 0) h, addConstant (-1) h))
-- | Finds the index associated with a particular variable; either by finding an existing index