Added support for the canonicalisation of simple Expressions

This seems to have fixed a couple more normal tests, and goes towards fixing the randomly failing quickcheck test (but that test is not fixed, yet)
This commit is contained in:
Neil Brown 2009-01-18 00:20:14 +00:00
parent 536d0b19a6
commit 949c88bb75
2 changed files with 33 additions and 9 deletions

View File

@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along
with this program. If not, see <http://www.gnu.org/licenses/>.
-}
module ArrayUsageCheck (BackgroundKnowledge(..), BK, checkArrayUsage, FlattenedExp(..), makeEquations, makeExpSet, ModuloCase(..), onlyConst, showFlattenedExp, VarMap) where
module ArrayUsageCheck (BackgroundKnowledge(..), BK, checkArrayUsage, FlattenedExp(..), makeEquations, makeExpSet, ModuloCase(..), onlyConst, showFlattenedExp, VarMap, canonicalise, fmapFlattenedExp) where
import Control.Monad.Error
import Control.Monad.State
@ -235,6 +235,14 @@ onlyConst [] = Just 0
onlyConst ((Const n):es) = liftM2 (+) (return n) $ onlyConst es
onlyConst _ = Nothing
fmapFlattenedExp :: (A.Expression -> A.Expression) -> FlattenedExp -> FlattenedExp
fmapFlattenedExp f x@(Const _) = x
fmapFlattenedExp f (Scale n (e, i)) = Scale n (f e, i)
fmapFlattenedExp f (Modulo n top bottom)
= Modulo n (Set.map (fmapFlattenedExp f) top) (Set.map (fmapFlattenedExp f) bottom)
fmapFlattenedExp f (Divide n top bottom)
= Divide n (Set.map (fmapFlattenedExp f) top) (Set.map (fmapFlattenedExp f) bottom)
-- | A data type representing an array access. Each triple is (index, extra-equalities, extra-inequalities).
-- A Single item can be paired with every other access.
-- Each item of a Group cannot be paired with each other, but can be paired with each other access.
@ -468,7 +476,21 @@ makeEquations accesses bound
var = A.Variable m varName
-- Note that in all these functions, the divisor should always be positive!
canonicalise :: A.Expression -> A.Expression
canonicalise e@(A.Dyadic m op _ _) | op == A.Add || op == A.Mul
= foldl1 (A.Dyadic m op) $ sort $ gatherTerms op e
where
gatherTerms :: A.DyadicOp -> A.Expression -> [A.Expression]
gatherTerms op (A.Dyadic _ op' lhs rhs) | op == op'
= gatherTerms op lhs ++ gatherTerms op rhs
gatherTerms _ e = [canonicalise e]
canonicalise (A.Dyadic m op lhs rhs)
= A.Dyadic m op (canonicalise lhs) (canonicalise rhs)
canonicalise (A.Monadic m op rhs)
= A.Monadic m op (canonicalise rhs)
canonicalise e = e
flatten :: A.Expression -> Either String [FlattenedExp]
flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)]
flatten e@(A.Dyadic m op lhs rhs)
@ -480,7 +502,7 @@ flatten e@(A.Dyadic m op lhs rhs)
case onlyConst rhs' of
Just _ -> liftM2L (Divide 1) (flatten lhs) (return rhs')
-- Can't deal with variable divisors, leave expression as-is:
Nothing -> return [Scale 1 (e,0)]
Nothing -> return [Scale 1 (canonicalise e,0)]
where
-- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c]
liftM2L f x y = liftM singleton $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet)
@ -499,12 +521,11 @@ flatten e@(A.Dyadic m op lhs rhs)
mult lhs rhs
= do lhs' <- backToEq lhs
rhs' <- backToEq rhs
let (onLeft, onRight) = if lhs' <= rhs' then (lhs', rhs') else (rhs', lhs')
return $ (Scale 1 (A.Dyadic emptyMeta A.Mul onLeft onRight, 0))
return $ (Scale 1 (canonicalise $ A.Dyadic emptyMeta A.Mul lhs' rhs', 0))
backScale :: Integer -> A.Expression -> A.Expression
backScale 1 = id
backScale n = A.Dyadic emptyMeta A.Mul (makeConstant emptyMeta (fromInteger n))
backScale n = canonicalise . A.Dyadic emptyMeta A.Mul (makeConstant emptyMeta (fromInteger n))
backToEq :: FlattenedExp -> Either String A.Expression
backToEq (Const c) = return $ makeConstant emptyMeta (fromInteger c)
@ -540,7 +561,7 @@ flatten e@(A.Dyadic m op lhs rhs)
-- | Combines (adds) two flattened expressions.
combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp]
combine = (++)
flatten e = return [Scale 1 (e,0)]
flatten e = return [Scale 1 (canonicalise e,0)]
-- | The "square" refers to making all equations the length of the longest
-- one, and the pair refers to pairing each in a list of array accesses (e.g.

View File

@ -869,8 +869,11 @@ testIndexes = TestList
generateMapping :: TestMonad m r => String -> VarMap -> VarMap -> m [(CoeffIndex,CoeffIndex)]
generateMapping msg m0 m1
= do testEqual ("Keys in variable mapping " ++ msg) (Map.keys m0) (Map.keys m1)
return $ Map.elems $ zipMap mergeMaybe m0 m1
= do testEqual ("Keys in variable mapping " ++ msg) (Map.keys m0') (Map.keys m1')
return $ Map.elems $ zipMap mergeMaybe m0' m1'
where
m0' = Map.mapKeys (fmapFlattenedExp canonicalise) m0
m1' = Map.mapKeys (fmapFlattenedExp canonicalise) m1
-- | Given a forward mapping list, translates equations across
translateEquations :: forall m r. TestMonad m r =>