Added support for the canonicalisation of simple Expressions

This seems to have fixed a couple more normal tests, and goes towards fixing the randomly failing quickcheck test (but that test is not fixed, yet)
This commit is contained in:
Neil Brown 2009-01-18 00:20:14 +00:00
parent 536d0b19a6
commit 949c88bb75
2 changed files with 33 additions and 9 deletions

View File

@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along
with this program. If not, see <http://www.gnu.org/licenses/>. with this program. If not, see <http://www.gnu.org/licenses/>.
-} -}
module ArrayUsageCheck (BackgroundKnowledge(..), BK, checkArrayUsage, FlattenedExp(..), makeEquations, makeExpSet, ModuloCase(..), onlyConst, showFlattenedExp, VarMap) where module ArrayUsageCheck (BackgroundKnowledge(..), BK, checkArrayUsage, FlattenedExp(..), makeEquations, makeExpSet, ModuloCase(..), onlyConst, showFlattenedExp, VarMap, canonicalise, fmapFlattenedExp) where
import Control.Monad.Error import Control.Monad.Error
import Control.Monad.State import Control.Monad.State
@ -235,6 +235,14 @@ onlyConst [] = Just 0
onlyConst ((Const n):es) = liftM2 (+) (return n) $ onlyConst es onlyConst ((Const n):es) = liftM2 (+) (return n) $ onlyConst es
onlyConst _ = Nothing onlyConst _ = Nothing
fmapFlattenedExp :: (A.Expression -> A.Expression) -> FlattenedExp -> FlattenedExp
fmapFlattenedExp f x@(Const _) = x
fmapFlattenedExp f (Scale n (e, i)) = Scale n (f e, i)
fmapFlattenedExp f (Modulo n top bottom)
= Modulo n (Set.map (fmapFlattenedExp f) top) (Set.map (fmapFlattenedExp f) bottom)
fmapFlattenedExp f (Divide n top bottom)
= Divide n (Set.map (fmapFlattenedExp f) top) (Set.map (fmapFlattenedExp f) bottom)
-- | A data type representing an array access. Each triple is (index, extra-equalities, extra-inequalities). -- | A data type representing an array access. Each triple is (index, extra-equalities, extra-inequalities).
-- A Single item can be paired with every other access. -- A Single item can be paired with every other access.
-- Each item of a Group cannot be paired with each other, but can be paired with each other access. -- Each item of a Group cannot be paired with each other, but can be paired with each other access.
@ -469,6 +477,20 @@ makeEquations accesses bound
-- Note that in all these functions, the divisor should always be positive! -- Note that in all these functions, the divisor should always be positive!
canonicalise :: A.Expression -> A.Expression
canonicalise e@(A.Dyadic m op _ _) | op == A.Add || op == A.Mul
= foldl1 (A.Dyadic m op) $ sort $ gatherTerms op e
where
gatherTerms :: A.DyadicOp -> A.Expression -> [A.Expression]
gatherTerms op (A.Dyadic _ op' lhs rhs) | op == op'
= gatherTerms op lhs ++ gatherTerms op rhs
gatherTerms _ e = [canonicalise e]
canonicalise (A.Dyadic m op lhs rhs)
= A.Dyadic m op (canonicalise lhs) (canonicalise rhs)
canonicalise (A.Monadic m op rhs)
= A.Monadic m op (canonicalise rhs)
canonicalise e = e
flatten :: A.Expression -> Either String [FlattenedExp] flatten :: A.Expression -> Either String [FlattenedExp]
flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)] flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)]
flatten e@(A.Dyadic m op lhs rhs) flatten e@(A.Dyadic m op lhs rhs)
@ -480,7 +502,7 @@ flatten e@(A.Dyadic m op lhs rhs)
case onlyConst rhs' of case onlyConst rhs' of
Just _ -> liftM2L (Divide 1) (flatten lhs) (return rhs') Just _ -> liftM2L (Divide 1) (flatten lhs) (return rhs')
-- Can't deal with variable divisors, leave expression as-is: -- Can't deal with variable divisors, leave expression as-is:
Nothing -> return [Scale 1 (e,0)] Nothing -> return [Scale 1 (canonicalise e,0)]
where where
-- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c] -- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c]
liftM2L f x y = liftM singleton $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet) liftM2L f x y = liftM singleton $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet)
@ -499,12 +521,11 @@ flatten e@(A.Dyadic m op lhs rhs)
mult lhs rhs mult lhs rhs
= do lhs' <- backToEq lhs = do lhs' <- backToEq lhs
rhs' <- backToEq rhs rhs' <- backToEq rhs
let (onLeft, onRight) = if lhs' <= rhs' then (lhs', rhs') else (rhs', lhs') return $ (Scale 1 (canonicalise $ A.Dyadic emptyMeta A.Mul lhs' rhs', 0))
return $ (Scale 1 (A.Dyadic emptyMeta A.Mul onLeft onRight, 0))
backScale :: Integer -> A.Expression -> A.Expression backScale :: Integer -> A.Expression -> A.Expression
backScale 1 = id backScale 1 = id
backScale n = A.Dyadic emptyMeta A.Mul (makeConstant emptyMeta (fromInteger n)) backScale n = canonicalise . A.Dyadic emptyMeta A.Mul (makeConstant emptyMeta (fromInteger n))
backToEq :: FlattenedExp -> Either String A.Expression backToEq :: FlattenedExp -> Either String A.Expression
backToEq (Const c) = return $ makeConstant emptyMeta (fromInteger c) backToEq (Const c) = return $ makeConstant emptyMeta (fromInteger c)
@ -540,7 +561,7 @@ flatten e@(A.Dyadic m op lhs rhs)
-- | Combines (adds) two flattened expressions. -- | Combines (adds) two flattened expressions.
combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp] combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp]
combine = (++) combine = (++)
flatten e = return [Scale 1 (e,0)] flatten e = return [Scale 1 (canonicalise e,0)]
-- | The "square" refers to making all equations the length of the longest -- | The "square" refers to making all equations the length of the longest
-- one, and the pair refers to pairing each in a list of array accesses (e.g. -- one, and the pair refers to pairing each in a list of array accesses (e.g.

View File

@ -869,8 +869,11 @@ testIndexes = TestList
generateMapping :: TestMonad m r => String -> VarMap -> VarMap -> m [(CoeffIndex,CoeffIndex)] generateMapping :: TestMonad m r => String -> VarMap -> VarMap -> m [(CoeffIndex,CoeffIndex)]
generateMapping msg m0 m1 generateMapping msg m0 m1
= do testEqual ("Keys in variable mapping " ++ msg) (Map.keys m0) (Map.keys m1) = do testEqual ("Keys in variable mapping " ++ msg) (Map.keys m0') (Map.keys m1')
return $ Map.elems $ zipMap mergeMaybe m0 m1 return $ Map.elems $ zipMap mergeMaybe m0' m1'
where
m0' = Map.mapKeys (fmapFlattenedExp canonicalise) m0
m1' = Map.mapKeys (fmapFlattenedExp canonicalise) m1
-- | Given a forward mapping list, translates equations across -- | Given a forward mapping list, translates equations across
translateEquations :: forall m r. TestMonad m r => translateEquations :: forall m r. TestMonad m r =>