diff --git a/transformations/ArrayUsageCheck.hs b/transformations/ArrayUsageCheck.hs index ab3b628..7200d94 100644 --- a/transformations/ArrayUsageCheck.hs +++ b/transformations/ArrayUsageCheck.hs @@ -87,8 +87,8 @@ checkArrayUsage tree = (mapM_ checkPar $ listify (const True) tree) >> return tr showFlattenedExp :: FlattenedExp -> PassM String showFlattenedExp (Const n) = return $ show n - showFlattenedExp (Scale n (A.Variable _ vn)) - = do vn' <- getRealName vn + showFlattenedExp (Scale n ((A.Variable _ vn),vi)) + = do vn' <- getRealName vn >>* (\s -> if vi == 0 then s else s ++ replicate vi '\'' ) case n of 1 -> return vn' -1 -> return $ "-" ++ vn' @@ -110,7 +110,7 @@ checkArrayUsage tree = (mapM_ checkPar $ listify (const True) tree) >> return tr -- | A type for inside makeEquations: data FlattenedExp = Const Integer - | Scale Integer A.Variable + | Scale Integer (A.Variable, Int) | Modulo (Set.Set FlattenedExp) (Set.Set FlattenedExp) | Divide (Set.Set FlattenedExp) (Set.Set FlattenedExp) @@ -121,7 +121,7 @@ instance Ord FlattenedExp where compare (Const _) (Const _) = EQ compare (Const _) _ = LT compare _ (Const _) = GT - compare (Scale _ lv) (Scale _ rv) = customVarCompare lv rv + compare (Scale _ (lv,li)) (Scale _ (rv,ri)) = combineCompare [customVarCompare lv rv, compare li ri] compare (Scale {}) _ = LT compare _ (Scale {}) = GT compare (Modulo ltop lbottom) (Modulo rtop rbottom) @@ -166,9 +166,10 @@ makeExpSet = foldM makeExpSet' Set.empty addConst x (Const n) s = Just $ Set.insert (Const (n + x)) s addConst _ _ _ = Nothing - addScale :: Integer -> A.Variable -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp) - addScale x lv (Scale n rv) s | EQ == customVarCompare lv rv = Just $ Set.insert (Scale (x + n) rv) s - | otherwise = Nothing + addScale :: Integer -> (A.Variable,Int) -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp) + addScale x (lv,li) (Scale n (rv,ri)) s + | (EQ == customVarCompare lv rv) && (li == ri) = Just $ Set.insert (Scale (x + n) (rv,ri)) s + | otherwise = Nothing addScale _ _ _ _ = Nothing type VarMap = Map.Map FlattenedExp Int @@ -209,7 +210,7 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> [(s,squareEquations eqI | op == A.Rem = liftM2L Modulo (flatten lhs) (flatten rhs) | op == A.Div = liftM2L Divide (flatten lhs) (flatten rhs) | otherwise = throwError ("Unhandleable operator found in expression: " ++ show op) - flatten (A.ExprVariable _ v) = return [Scale 1 v] + flatten (A.ExprVariable _ v) = return [Scale 1 (v,0)] flatten other = throwError ("Unhandleable item found in expression: " ++ show other) -- liftM2L :: (Ord a, Ord b, Monad m) => (Set.Set a -> Set.Set b -> c) -> m [a] -> m [b] -> m [c] @@ -252,12 +253,12 @@ makeEquations es high = makeEquations' >>* (\(s,v,lh) -> [(s,squareEquations eqI -- | Finds the index associated with a particular variable; either by finding an existing index -- or allocating a new one. varIndex :: FlattenedExp -> StateT (VarMap) (Either String) Int - varIndex (Scale _ var@(A.Variable _ (A.Name _ _ varName))) + varIndex (Scale _ (var@(A.Variable _ (A.Name _ _ varName)),vi)) = do st <- get - let (st',ind) = case Map.lookup (Scale 1 var) st of + let (st',ind) = case Map.lookup (Scale 1 (var,vi)) st of Just val -> (st,val) Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in - (Map.insert (Scale 1 var) newId st, newId) + (Map.insert (Scale 1 (var,vi)) newId st, newId) put st' return ind varIndex mod@(Modulo top bottom) diff --git a/transformations/ArrayUsageCheckTest.hs b/transformations/ArrayUsageCheckTest.hs index 224928f..8d6167c 100644 --- a/transformations/ArrayUsageCheckTest.hs +++ b/transformations/ArrayUsageCheckTest.hs @@ -338,17 +338,17 @@ testMakeEquations = TestList pairLatterTwo (a,b,c) = (a,(b,c)) - i_mapping = Map.singleton (Scale 1 $ variable "i") 1 - ij_mapping = Map.fromList [(Scale 1 $ variable "i",1),(Scale 1 $ variable "j",2)] - i_mod_mapping n = Map.fromList [(Scale 1 $ variable "i",1),(Modulo (Set.singleton $ Scale 1 $ variable "i") (Set.singleton $ Const n),2)] - i_mod_j_mapping = Map.fromList [(Scale 1 $ variable "i",1),(Scale 1 $ variable "j",2), - (Modulo (Set.singleton $ Scale 1 $ variable "i") (Set.singleton $ Scale 1 $ variable "j"),3)] - _3i_2j_mod_mapping n = Map.fromList [(Scale 1 $ variable "i",1),(Scale 1 $ variable "j",2), - (Modulo (Set.fromList [(Scale 3 $ variable "i"),(Scale (-2) $ variable "j")]) (Set.singleton $ Const n),3)] + i_mapping = Map.singleton (Scale 1 $ (variable "i",0)) 1 + ij_mapping = Map.fromList [(Scale 1 $ (variable "i",0),1),(Scale 1 $ (variable "j",0),2)] + i_mod_mapping n = Map.fromList [(Scale 1 $ (variable "i",0),1),(Modulo (Set.singleton $ Scale 1 $ (variable "i",0)) (Set.singleton $ Const n),2)] + i_mod_j_mapping = Map.fromList [(Scale 1 $ (variable "i",0),1),(Scale 1 $ (variable "j",0),2), + (Modulo (Set.singleton $ Scale 1 $ (variable "i",0)) (Set.singleton $ Scale 1 $ (variable "j",0)),3)] + _3i_2j_mod_mapping n = Map.fromList [(Scale 1 $ (variable "i",0),1),(Scale 1 $ (variable "j",0),2), + (Modulo (Set.fromList [(Scale 3 $ (variable "i",0)),(Scale (-2) $ (variable "j",0))]) (Set.singleton $ Const n),3)] -- i REM m, i + 1 REM n - i_ip1_mod_mapping m n = Map.fromList [(Scale 1 $ variable "i",1) - ,(Modulo (Set.singleton $ Scale 1 $ variable "i") (Set.singleton $ Const m),2) - ,(Modulo (Set.fromList [Scale 1 $ variable "i", Const 1]) (Set.singleton $ Const n),3) + i_ip1_mod_mapping m n = Map.fromList [(Scale 1 $ (variable "i",0),1) + ,(Modulo (Set.singleton $ Scale 1 $ (variable "i",0)) (Set.singleton $ Const m),2) + ,(Modulo (Set.fromList [Scale 1 $ (variable "i",0), Const 1]) (Set.singleton $ Const n),3) ] -- Helper functions for i REM 2 vs (i + 1) REM 4. Each one is a pair of equalities, inequalities