diff --git a/checks/ArrayUsageCheck.hs b/checks/ArrayUsageCheck.hs index 7ef466f..2fdf7e2 100644 --- a/checks/ArrayUsageCheck.hs +++ b/checks/ArrayUsageCheck.hs @@ -16,7 +16,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . -} -module ArrayUsageCheck (BackgroundKnowledge(..), BK, checkArrayUsage, FlattenedExp(..), makeEquations, makeExpSet, ModuloCase(..), onlyConst, showFlattenedExp, VarMap) where +module ArrayUsageCheck (BackgroundKnowledge(..), BK, checkArrayUsage, FlattenedExp(..), makeEquations, makeExpSet, ModuloCase(..), onlyConst, showFlattenedExp, VarMap, canonicalise, fmapFlattenedExp) where import Control.Monad.Error import Control.Monad.State @@ -235,6 +235,14 @@ onlyConst [] = Just 0 onlyConst ((Const n):es) = liftM2 (+) (return n) $ onlyConst es onlyConst _ = Nothing +fmapFlattenedExp :: (A.Expression -> A.Expression) -> FlattenedExp -> FlattenedExp +fmapFlattenedExp f x@(Const _) = x +fmapFlattenedExp f (Scale n (e, i)) = Scale n (f e, i) +fmapFlattenedExp f (Modulo n top bottom) + = Modulo n (Set.map (fmapFlattenedExp f) top) (Set.map (fmapFlattenedExp f) bottom) +fmapFlattenedExp f (Divide n top bottom) + = Divide n (Set.map (fmapFlattenedExp f) top) (Set.map (fmapFlattenedExp f) bottom) + -- | A data type representing an array access. Each triple is (index, extra-equalities, extra-inequalities). -- A Single item can be paired with every other access. -- Each item of a Group cannot be paired with each other, but can be paired with each other access. @@ -468,7 +476,21 @@ makeEquations accesses bound var = A.Variable m varName -- Note that in all these functions, the divisor should always be positive! - + +canonicalise :: A.Expression -> A.Expression +canonicalise e@(A.Dyadic m op _ _) | op == A.Add || op == A.Mul + = foldl1 (A.Dyadic m op) $ sort $ gatherTerms op e + where + gatherTerms :: A.DyadicOp -> A.Expression -> [A.Expression] + gatherTerms op (A.Dyadic _ op' lhs rhs) | op == op' + = gatherTerms op lhs ++ gatherTerms op rhs + gatherTerms _ e = [canonicalise e] +canonicalise (A.Dyadic m op lhs rhs) + = A.Dyadic m op (canonicalise lhs) (canonicalise rhs) +canonicalise (A.Monadic m op rhs) + = A.Monadic m op (canonicalise rhs) +canonicalise e = e + flatten :: A.Expression -> Either String [FlattenedExp] flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)] flatten e@(A.Dyadic m op lhs rhs) @@ -480,7 +502,7 @@ flatten e@(A.Dyadic m op lhs rhs) case onlyConst rhs' of Just _ -> liftM2L (Divide 1) (flatten lhs) (return rhs') -- Can't deal with variable divisors, leave expression as-is: - Nothing -> return [Scale 1 (e,0)] + Nothing -> return [Scale 1 (canonicalise e,0)] where -- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c] liftM2L f x y = liftM singleton $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet) @@ -499,12 +521,11 @@ flatten e@(A.Dyadic m op lhs rhs) mult lhs rhs = do lhs' <- backToEq lhs rhs' <- backToEq rhs - let (onLeft, onRight) = if lhs' <= rhs' then (lhs', rhs') else (rhs', lhs') - return $ (Scale 1 (A.Dyadic emptyMeta A.Mul onLeft onRight, 0)) + return $ (Scale 1 (canonicalise $ A.Dyadic emptyMeta A.Mul lhs' rhs', 0)) backScale :: Integer -> A.Expression -> A.Expression backScale 1 = id - backScale n = A.Dyadic emptyMeta A.Mul (makeConstant emptyMeta (fromInteger n)) + backScale n = canonicalise . A.Dyadic emptyMeta A.Mul (makeConstant emptyMeta (fromInteger n)) backToEq :: FlattenedExp -> Either String A.Expression backToEq (Const c) = return $ makeConstant emptyMeta (fromInteger c) @@ -540,7 +561,7 @@ flatten e@(A.Dyadic m op lhs rhs) -- | Combines (adds) two flattened expressions. combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp] combine = (++) -flatten e = return [Scale 1 (e,0)] +flatten e = return [Scale 1 (canonicalise e,0)] -- | The "square" refers to making all equations the length of the longest -- one, and the pair refers to pairing each in a list of array accesses (e.g. diff --git a/checks/ArrayUsageCheckTest.hs b/checks/ArrayUsageCheckTest.hs index 4f21451..af38a00 100644 --- a/checks/ArrayUsageCheckTest.hs +++ b/checks/ArrayUsageCheckTest.hs @@ -869,8 +869,11 @@ testIndexes = TestList generateMapping :: TestMonad m r => String -> VarMap -> VarMap -> m [(CoeffIndex,CoeffIndex)] generateMapping msg m0 m1 - = do testEqual ("Keys in variable mapping " ++ msg) (Map.keys m0) (Map.keys m1) - return $ Map.elems $ zipMap mergeMaybe m0 m1 + = do testEqual ("Keys in variable mapping " ++ msg) (Map.keys m0') (Map.keys m1') + return $ Map.elems $ zipMap mergeMaybe m0' m1' + where + m0' = Map.mapKeys (fmapFlattenedExp canonicalise) m0 + m1' = Map.mapKeys (fmapFlattenedExp canonicalise) m1 -- | Given a forward mapping list, translates equations across translateEquations :: forall m r. TestMonad m r =>