diff --git a/common/EvalConstants.hs b/common/EvalConstants.hs
index fb8bdec..8b43e05 100644
--- a/common/EvalConstants.hs
+++ b/common/EvalConstants.hs
@@ -17,7 +17,7 @@ with this program. If not, see .
-}
-- | Evaluate constant expressions.
-module EvalConstants (constantFold, isConstantName) where
+module EvalConstants (constantFold, maybeEvalIntExpression, isConstantName) where
import Control.Monad.Error
import Control.Monad.State
@@ -34,6 +34,7 @@ import EvalLiterals
import Metadata
import ShowCode
import Types
+import Utils
-- | Simplify an expression by constant folding, and also return whether it's a
-- constant after that.
@@ -45,6 +46,15 @@ constantFold e
Right val -> (val, (Nothing, "already folded"))
return (e', isConstant e', msg)
+-- | Try to fold and evaluate an integer expression.
+-- If it's not a constant, return 'Nothing'.
+maybeEvalIntExpression :: (CSMR m, Die m) => A.Expression -> m (Maybe Int)
+maybeEvalIntExpression e
+ = do (e', isConst, _) <- constantFold e
+ if isConst
+ then evalIntExpression e' >>* Just
+ else return Nothing
+
-- | Is a name defined as a constant expression? If so, return its definition.
getConstantName :: (CSMR m, Die m) => A.Name -> m (Maybe A.Expression)
getConstantName n
diff --git a/common/TestUtils.hs b/common/TestUtils.hs
index db65c26..446faf5 100644
--- a/common/TestUtils.hs
+++ b/common/TestUtils.hs
@@ -168,8 +168,16 @@ exprVariablePattern :: String -> Pattern
exprVariablePattern e = tag2 A.ExprVariable DontCare $ variablePattern e
-- | Creates an integer literal 'A.Expression' with the given integer.
+integerLiteral :: A.Type -> Integer -> A.Expression
+integerLiteral t n = A.Literal emptyMeta t $ A.IntLiteral emptyMeta (show n)
+
+-- | Creates an 'A.Int' literal with the given integer.
intLiteral :: Integer -> A.Expression
-intLiteral n = A.Literal emptyMeta A.Int $ A.IntLiteral emptyMeta (show n)
+intLiteral n = integerLiteral A.Int n
+
+-- | Creates an 'A.Byte' literal with the given integer.
+byteLiteral :: Integer -> A.Expression
+byteLiteral n = integerLiteral A.Byte n
-- | Creates a 'Pattern' to match an 'A.Expression' instance.
-- @'assertPatternMatch' ('intLiteralPattern' x) ('intLiteral' x)@ will always succeed.
@@ -276,6 +284,19 @@ simpleDefDecl n t = simpleDef n (A.Declaration emptyMeta t)
simpleDefPattern :: String -> A.AbbrevMode -> Pattern -> Pattern
simpleDefPattern n am sp = tag7 A.NameDef DontCare n n A.VariableName sp am A.Unplaced
+-- | Define a @VAL IS@ constant.
+defineConst :: String -> A.Type -> A.Expression -> State CompState ()
+defineConst s t e = defineName (simpleName s) $
+ A.NameDef {
+ A.ndMeta = emptyMeta,
+ A.ndName = s,
+ A.ndOrigName = s,
+ A.ndNameType = A.VariableName,
+ A.ndType = A.IsExpr emptyMeta A.ValAbbrev t e,
+ A.ndAbbrevMode = A.ValAbbrev,
+ A.ndPlacement = A.Unplaced
+ }
+
--}}}
--{{{ custom assertions
diff --git a/frontends/OccamPasses.hs b/frontends/OccamPasses.hs
index bae4e6a..d270e34 100644
--- a/frontends/OccamPasses.hs
+++ b/frontends/OccamPasses.hs
@@ -17,19 +17,23 @@ with this program. If not, see .
-}
-- | The occam-specific frontend passes.
-module OccamPasses (occamPasses, foldConstants, checkConstants) where
+module OccamPasses (occamPasses, foldConstants, checkConstants,
+ checkRetypes) where
-import Control.Monad
+import Control.Monad.State
import Data.Generics
+import System.IO
import qualified AST as A
import CompState
+import Errors
import EvalConstants
import EvalLiterals
import Metadata
import Pass
import qualified Properties as Prop
import ShowCode
+import Types
-- | Occam-specific frontend passes.
occamPasses :: [Pass]
@@ -40,6 +44,9 @@ occamPasses = makePassesDep' ((== FrontendOccam) . csFrontend)
, ("Check mandatory constants", checkConstants,
[Prop.constantsFolded],
[Prop.constantsChecked])
+ , ("Check retyping", checkRetypes,
+ [],
+ [Prop.retypesChecked])
, ("Dummy occam pass", dummyOccamPass,
[],
Prop.agg_namesDone ++ [Prop.expressionTypesChecked,
@@ -97,6 +104,49 @@ checkConstants = doGeneric `extM` doDimension `extM` doOption
doGeneric o
doOption o = doGeneric o
+-- | Check that retyping is safe.
+checkRetypes :: Data t => t -> PassM t
+checkRetypes = everywhereASTM doSpecType
+ where
+ doSpecType :: A.SpecType -> PassM A.SpecType
+ doSpecType st@(A.Retypes m _ t v)
+ = do fromT <- typeOfVariable v
+ checkRetypes m fromT t
+ return st
+ doSpecType st@(A.RetypesExpr m _ t e)
+ = do fromT <- typeOfExpression e
+ checkRetypes m fromT t
+ return st
+ doSpecType st = return st
+
+ checkRetypes :: Meta -> A.Type -> A.Type -> PassM ()
+ checkRetypes m fromT toT
+ = do (fromBI, fromN) <- evalBytesInType fromT
+ (toBI, toN) <- evalBytesInType toT
+ case (fromBI, toBI, fromN, toN) of
+ (_, BIManyFree, _, _) ->
+ dieP m "Multiple free dimensions in retype destination type"
+ (BIJust _, BIJust _, Just a, Just b) ->
+ when (a /= b) $
+ dieP m "Sizes do not match in retype"
+ (BIJust _, BIOneFree _ _, Just a, Just b) ->
+ when (not ((b <= a) && (a `mod` b == 0))) $
+ dieP m "Sizes do not match in retype"
+ (BIOneFree _ _, BIJust _, Just a, Just b) ->
+ when (not ((a <= b) && (b `mod` a == 0))) $
+ dieP m "Sizes do not match in retype"
+ -- Otherwise we must do a runtime check.
+ _ -> return ()
+
+ evalBytesInType :: A.Type -> PassM (BytesInResult, Maybe Int)
+ evalBytesInType t
+ = do bi <- bytesInType t
+ n <- case bi of
+ BIJust e -> maybeEvalIntExpression e
+ BIOneFree e _ -> maybeEvalIntExpression e
+ _ -> return Nothing
+ return (bi, n)
+
-- | A dummy pass for things that haven't been separated out into passes yet.
dummyOccamPass :: Data t => t -> PassM t
dummyOccamPass = return
diff --git a/frontends/OccamPassesTest.hs b/frontends/OccamPassesTest.hs
index 7ea5c0a..17e24b9 100644
--- a/frontends/OccamPassesTest.hs
+++ b/frontends/OccamPassesTest.hs
@@ -35,6 +35,17 @@ import TestUtils
m :: Meta
m = emptyMeta
+-- | Initial state for the tests.
+startState :: State CompState ()
+startState
+ = do defineConst "const" A.Int (intLiteral 2)
+ defineConst "someInt" A.Int (intLiteral 42)
+ defineConst "someByte" A.Byte (byteLiteral 24)
+ defineConst "someInts" (A.Array [A.UnknownDimension] A.Int)
+ undefined
+ defineConst "someBytes" (A.Array [A.UnknownDimension] A.Byte)
+ undefined
+
-- | Test 'OccamPasses.foldConstants'.
testFoldConstants :: Test
testFoldConstants = TestList
@@ -81,21 +92,6 @@ testFoldConstants = TestList
exp (OccamPasses.foldConstants orig)
startState
- startState :: State CompState ()
- startState = defineConst "const" A.Int two
-
- defineConst :: String -> A.Type -> A.Expression -> State CompState ()
- defineConst s t e = defineName (simpleName s) $
- A.NameDef {
- A.ndMeta = m,
- A.ndName = "const",
- A.ndOrigName = "const",
- A.ndNameType = A.VariableName,
- A.ndType = A.IsExpr m A.ValAbbrev t e,
- A.ndAbbrevMode = A.ValAbbrev,
- A.ndPlacement = A.Unplaced
- }
-
testSame :: Int -> A.Expression -> Test
testSame n orig = test n orig orig
@@ -158,8 +154,64 @@ testCheckConstants = TestList
var = exprVariable "var"
skip = A.Skip m
+-- | Test 'OccamPasses.checkRetypes'.
+testCheckRetypes :: Test
+testCheckRetypes = TestList
+ [
+ -- Definitely OK at compile time
+ testOK 0 $ retypesV A.Int intV
+ , testOK 1 $ retypesE A.Int intE
+ , testOK 2 $ retypesV A.Byte byteV
+ , testOK 3 $ retypesE A.Byte byteE
+ , testOK 4 $ retypesV known1 intV
+ , testOK 5 $ retypesV known2 intV
+ , testOK 6 $ retypesV both intV
+ , testOK 7 $ retypesV unknown1 intV
+
+ -- Definitely wrong at compile time
+ , testFail 100 $ retypesV A.Byte intV
+ , testFail 101 $ retypesV A.Int byteV
+ , testFail 102 $ retypesV unknown2 intV
+ , testFail 103 $ retypesV unknown2 intsV
+ , testFail 104 $ retypesV A.Byte intsV
+
+ -- Can't tell; need a runtime check
+ , testOK 200 $ retypesV unknown1 intsV
+ , testOK 201 $ retypesV A.Int intsV
+ , testOK 202 $ retypesV known2 intsV
+ , testOK 203 $ retypesV unknown1 bytesV
+ ]
+ where
+ testOK :: (Show a, Data a) => Int -> a -> Test
+ testOK n orig
+ = TestCase $ testPass ("testCheckRetypes" ++ show n)
+ orig (OccamPasses.checkRetypes orig)
+ startState
+
+ testFail :: (Show a, Data a) => Int -> a -> Test
+ testFail n orig
+ = TestCase $ testPassShouldFail ("testCheckRetypes" ++ show n)
+ (OccamPasses.checkRetypes orig)
+ startState
+
+ retypesV = A.Retypes m A.ValAbbrev
+ retypesE = A.RetypesExpr m A.ValAbbrev
+
+ intV = variable "someInt"
+ intE = intLiteral 42
+ byteV = variable "someByte"
+ byteE = byteLiteral 42
+ intsV = variable "someInts"
+ bytesV = variable "someBytes"
+ known1 = A.Array [dimension 4] A.Byte
+ known2 = A.Array [dimension 2, dimension 2] A.Byte
+ both = A.Array [dimension 2, A.UnknownDimension] A.Byte
+ unknown1 = A.Array [A.UnknownDimension] A.Int
+ unknown2 = A.Array [A.UnknownDimension, A.UnknownDimension] A.Int
+
tests :: Test
tests = TestLabel "OccamPassesTest" $ TestList
[ testFoldConstants
, testCheckConstants
+ , testCheckRetypes
]
diff --git a/frontends/ParseOccam.hs b/frontends/ParseOccam.hs
index e38c14c..81e154e 100644
--- a/frontends/ParseOccam.hs
+++ b/frontends/ParseOccam.hs
@@ -1428,7 +1428,6 @@ retypesAbbrev
sColon
eol
origT <- typeOfVariable v
- --checkRetypes m origT s
return $ A.Specification m n $ A.Retypes m A.Abbrev s v
<|> do m <- md
(s, n) <- tryVVX channelSpecifier newChannelName retypesReshapes
@@ -1436,7 +1435,6 @@ retypesAbbrev
sColon
eol
origT <- typeOfVariable c
- --checkRetypes m origT s
return $ A.Specification m n $ A.Retypes m A.Abbrev s c
<|> do m <- md
(s, n) <- tryXVVX sVAL dataSpecifier newVariableName retypesReshapes
@@ -1444,29 +1442,9 @@ retypesAbbrev
sColon
eol
origT <- typeOfExpression e
- --checkRetypes m origT s
return $ A.Specification m n $ A.RetypesExpr m A.ValAbbrev s e
> "RETYPES/RESHAPES abbreviation"
-{-
--- | Check that a RETYPES\/RESHAPES is safe.
-checkRetypes :: Meta -> A.Type -> A.Type -> OccParser ()
--- Retyping channels is always "safe".
-checkRetypes _ (A.Chan {}) (A.Chan {}) = return ()
-checkRetypes m fromT toT
- = do bf <- bytesInType fromT
- bt <- bytesInType toT
- case (bf, bt) of
- (BIJust a, BIJust b) ->
- when (a /= b) $ dieP m "size mismatch in RETYPES"
- (BIJust a, BIOneFree b _) ->
- when (not ((b <= a) && (a `mod` b == 0))) $ dieP m "size mismatch in RETYPES"
- (_, BIManyFree) ->
- dieP m "multiple free dimensions in RETYPES/RESHAPES type"
- -- Otherwise we have to do a runtime check.
- _ -> return ()
--}
-
dataSpecifier :: OccParser A.Type
dataSpecifier
= dataType
diff --git a/pass/Properties.hs b/pass/Properties.hs
index fa51c18..7653c2e 100644
--- a/pass/Properties.hs
+++ b/pass/Properties.hs
@@ -50,6 +50,7 @@ module Properties
, processTypesChecked
, rainParDeclarationsPulledUp
, rangeTransformed
+ , retypesChecked
, seqInputsFlattened
, slicesSimplified
, subscriptsPulledUp
@@ -78,13 +79,30 @@ import Types
import Utils
agg_namesDone :: [Property]
-agg_namesDone = [declarationsUnique, declarationTypesRecorded, inferredTypesRecorded, declaredNamesResolved]
+agg_namesDone =
+ [ declarationTypesRecorded
+ , declarationsUnique
+ , declaredNamesResolved
+ , inferredTypesRecorded
+ ]
agg_typesDone :: [Property]
-agg_typesDone = [expressionTypesChecked, inferredTypesRecorded, processTypesChecked, typesResolvedInAST, typesResolvedInState, constantsFolded, constantsChecked]
+agg_typesDone =
+ [ constantsChecked
+ , constantsFolded
+ , expressionTypesChecked
+ , inferredTypesRecorded
+ , processTypesChecked
+ , retypesChecked
+ , typesResolvedInAST
+ , typesResolvedInState
+ ]
agg_functionsGone :: [Property]
-agg_functionsGone = [functionCallsRemoved, functionsRemoved]
+agg_functionsGone =
+ [ functionCallsRemoved
+ , functionsRemoved
+ ]
-- Mark out all the checks I still need to implement:
checkTODO :: Monad m => A.AST -> m ()
@@ -142,10 +160,13 @@ declarationsUnique = Property "declarationsUnique" $
checkDupes (n':ns)
constantsChecked :: Property
-constantsChecked = Property "constantsChecked" checkTODO
+constantsChecked = Property "constantsChecked" nocheck
constantsFolded :: Property
-constantsFolded = Property "constantsFolded" checkTODO
+constantsFolded = Property "constantsFolded" nocheck
+
+retypesChecked :: Property
+retypesChecked = Property "retypesChecked" nocheck
intLiteralsInBounds :: Property
intLiteralsInBounds = Property "intLiteralsInBounds" $