diff --git a/backends/BackendPassesTest.hs b/backends/BackendPassesTest.hs index 21d9ea8..b3bab67 100644 --- a/backends/BackendPassesTest.hs +++ b/backends/BackendPassesTest.hs @@ -22,6 +22,7 @@ with this program. If not, see . module BackendPassesTest (tests) where import Control.Monad.State +import Data.Generics import qualified Data.Map as Map import Test.HUnit hiding (State) @@ -33,6 +34,7 @@ import Pattern import TagAST import TestUtils import TreeUtils +import Utils m :: Meta m = emptyMeta @@ -154,6 +156,14 @@ testDeclareSizes = TestList ,testFoo 11 $ isChanArrFoo 2 ,testFoo 12 $ isChanArrFoo 3 + ,testRecordFoo 20 [] + ,testRecordFoo 21 [A.Int] + ,testRecordFoo 22 [A.Array [A.Dimension 3] A.Int] + ,testRecordFoo 23 [A.Array (map A.Dimension [3,4,5,6]) A.Int] + + ,testRecordFoo 24 [A.Int, A.Array [A.Dimension 3] A.Int] + ,testRecordFoo 25 [A.Byte, A.Int, A.Array [A.Dimension 3] A.Int, A.Array (map A.Dimension [3,4,5,6]) A.Int, A.Array (map A.Dimension [1,2]) A.Int] + {- ,testFooDecl 10 [Nothing] ,testFooDecl 11 [Just 4, Nothing] @@ -163,7 +173,6 @@ testDeclareSizes = TestList ,testFooDecl 15 [Just 4, Nothing, Just 5, Nothing, Nothing] -} --TODO test that arrays that are abbreviations (Is and IsExpr) also get _sizes arrays, and that they are initialised correctly - --TODO test records getting sizes arrays --TODO test reshapes/retypes abbreviations ] where @@ -178,27 +187,52 @@ testDeclareSizes = TestList isChanArrFoo :: Int -> (A.SpecType, A.SpecType, State CompState ()) isChanArrFoo n = (A.IsChannelArray emptyMeta (A.Array [A.Dimension n] $ A.Chan A.DirUnknown (A.ChanAttributes False False) A.Byte) (replicate n $ variable "c") - ,valFoo [n], return ()) + ,valSize [n], return ()) + + testRecordFoo :: Int -> [A.Type] -> Test + -- Give fields arbitrary names (for testing), then check that all ones that are array types + -- do get _sizes array (concat of array name, field name and _sizes) + testRecordFoo n ts = test n + (declRecord fields $ flip (foldr declSizeItems) (reverse fields) term) + (declRecord fields term) (return ()) (sequence_ . flip applyAll (map checkSizeItems fields)) + where + fields = (zip ["x_" ++ show n | n <- [(0::Integer)..]] ts) + + declRecord :: Data a => [(String, A.Type)] -> A.Structured a -> A.Structured a + declRecord fields = A.Spec emptyMeta (A.Specification emptyMeta (simpleName "foo") fooSpec) + where + fooSpec = A.RecordType emptyMeta False (map (\(n,t) -> (simpleName n, t)) fields) + + declSizeItems :: Data a => (String, A.Type) -> A.Structured a -> A.Structured a + declSizeItems (n, A.Array ds _) = A.Spec emptyMeta (A.Specification emptyMeta (simpleName $ "foo" ++ n) $ + valSize $ map (\(A.Dimension n) -> n) ds) + declSizeItems _ = id + + checkSizeItems :: (String, A.Type) -> CompState -> Assertion + checkSizeItems (n, A.Array ds _) = checkSizes ("foo" ++ n) (valSize $ map (\(A.Dimension n) -> n) ds) + checkSizeItems _ = const (return ()) declFoo :: [Int] -> (A.SpecType, A.SpecType, State CompState ()) - declFoo ns = (A.Declaration emptyMeta t Nothing, valFoo ns, return ()) + declFoo ns = (A.Declaration emptyMeta t Nothing, valSize ns, return ()) where t = A.Array (map A.Dimension ns) A.Byte - valFoo :: [Int] -> A.SpecType - valFoo ds = A.IsExpr emptyMeta A.ValAbbrev (A.Array [A.Dimension $ length ds] A.Int) $ makeSizesLiteral ds + valSize :: [Int] -> A.SpecType + valSize ds = A.IsExpr emptyMeta A.ValAbbrev (A.Array [A.Dimension $ length ds] A.Int) $ makeSizesLiteral ds makeSizesLiteral :: [Int] -> A.Expression makeSizesLiteral xs = A.Literal emptyMeta (A.Array [A.Dimension $ length xs] A.Int) $ A.ArrayLiteral emptyMeta $ map (A.ArrayElemExpr . A.Literal emptyMeta A.Int . A.IntLiteral emptyMeta . show) xs - checkFooSizes :: A.SpecType -> CompState -> Assertion - checkFooSizes spec cs - = do nd <- case Map.lookup "foo_sizes" (csNames cs) of + checkFooSizes = checkSizes "foo_sizes" + + checkSizes :: String -> A.SpecType -> CompState -> Assertion + checkSizes n spec cs + = do nd <- case Map.lookup n (csNames cs) of Just nd -> return nd - Nothing -> assertFailure "Could not find foo_sizes" >> return undefined - assertEqual "ndName" "foo_sizes" (A.ndName nd) - assertEqual "ndOrigName" "foo_sizes" (A.ndOrigName nd) + Nothing -> assertFailure ("Could not find " ++ n) >> return undefined + assertEqual "ndName" n (A.ndName nd) + assertEqual "ndOrigName" n (A.ndOrigName nd) assertEqual "ndType" spec (A.ndType nd) assertEqual "ndAbbrevMode" A.ValAbbrev (A.ndAbbrevMode nd)