diff --git a/common/EvalConstants.hs b/common/EvalConstants.hs index fb8bdec..8b43e05 100644 --- a/common/EvalConstants.hs +++ b/common/EvalConstants.hs @@ -17,7 +17,7 @@ with this program. If not, see . -} -- | Evaluate constant expressions. -module EvalConstants (constantFold, isConstantName) where +module EvalConstants (constantFold, maybeEvalIntExpression, isConstantName) where import Control.Monad.Error import Control.Monad.State @@ -34,6 +34,7 @@ import EvalLiterals import Metadata import ShowCode import Types +import Utils -- | Simplify an expression by constant folding, and also return whether it's a -- constant after that. @@ -45,6 +46,15 @@ constantFold e Right val -> (val, (Nothing, "already folded")) return (e', isConstant e', msg) +-- | Try to fold and evaluate an integer expression. +-- If it's not a constant, return 'Nothing'. +maybeEvalIntExpression :: (CSMR m, Die m) => A.Expression -> m (Maybe Int) +maybeEvalIntExpression e + = do (e', isConst, _) <- constantFold e + if isConst + then evalIntExpression e' >>* Just + else return Nothing + -- | Is a name defined as a constant expression? If so, return its definition. getConstantName :: (CSMR m, Die m) => A.Name -> m (Maybe A.Expression) getConstantName n diff --git a/common/TestUtils.hs b/common/TestUtils.hs index db65c26..446faf5 100644 --- a/common/TestUtils.hs +++ b/common/TestUtils.hs @@ -168,8 +168,16 @@ exprVariablePattern :: String -> Pattern exprVariablePattern e = tag2 A.ExprVariable DontCare $ variablePattern e -- | Creates an integer literal 'A.Expression' with the given integer. +integerLiteral :: A.Type -> Integer -> A.Expression +integerLiteral t n = A.Literal emptyMeta t $ A.IntLiteral emptyMeta (show n) + +-- | Creates an 'A.Int' literal with the given integer. intLiteral :: Integer -> A.Expression -intLiteral n = A.Literal emptyMeta A.Int $ A.IntLiteral emptyMeta (show n) +intLiteral n = integerLiteral A.Int n + +-- | Creates an 'A.Byte' literal with the given integer. +byteLiteral :: Integer -> A.Expression +byteLiteral n = integerLiteral A.Byte n -- | Creates a 'Pattern' to match an 'A.Expression' instance. -- @'assertPatternMatch' ('intLiteralPattern' x) ('intLiteral' x)@ will always succeed. @@ -276,6 +284,19 @@ simpleDefDecl n t = simpleDef n (A.Declaration emptyMeta t) simpleDefPattern :: String -> A.AbbrevMode -> Pattern -> Pattern simpleDefPattern n am sp = tag7 A.NameDef DontCare n n A.VariableName sp am A.Unplaced +-- | Define a @VAL IS@ constant. +defineConst :: String -> A.Type -> A.Expression -> State CompState () +defineConst s t e = defineName (simpleName s) $ + A.NameDef { + A.ndMeta = emptyMeta, + A.ndName = s, + A.ndOrigName = s, + A.ndNameType = A.VariableName, + A.ndType = A.IsExpr emptyMeta A.ValAbbrev t e, + A.ndAbbrevMode = A.ValAbbrev, + A.ndPlacement = A.Unplaced + } + --}}} --{{{ custom assertions diff --git a/frontends/OccamPasses.hs b/frontends/OccamPasses.hs index bae4e6a..d270e34 100644 --- a/frontends/OccamPasses.hs +++ b/frontends/OccamPasses.hs @@ -17,19 +17,23 @@ with this program. If not, see . -} -- | The occam-specific frontend passes. -module OccamPasses (occamPasses, foldConstants, checkConstants) where +module OccamPasses (occamPasses, foldConstants, checkConstants, + checkRetypes) where -import Control.Monad +import Control.Monad.State import Data.Generics +import System.IO import qualified AST as A import CompState +import Errors import EvalConstants import EvalLiterals import Metadata import Pass import qualified Properties as Prop import ShowCode +import Types -- | Occam-specific frontend passes. occamPasses :: [Pass] @@ -40,6 +44,9 @@ occamPasses = makePassesDep' ((== FrontendOccam) . csFrontend) , ("Check mandatory constants", checkConstants, [Prop.constantsFolded], [Prop.constantsChecked]) + , ("Check retyping", checkRetypes, + [], + [Prop.retypesChecked]) , ("Dummy occam pass", dummyOccamPass, [], Prop.agg_namesDone ++ [Prop.expressionTypesChecked, @@ -97,6 +104,49 @@ checkConstants = doGeneric `extM` doDimension `extM` doOption doGeneric o doOption o = doGeneric o +-- | Check that retyping is safe. +checkRetypes :: Data t => t -> PassM t +checkRetypes = everywhereASTM doSpecType + where + doSpecType :: A.SpecType -> PassM A.SpecType + doSpecType st@(A.Retypes m _ t v) + = do fromT <- typeOfVariable v + checkRetypes m fromT t + return st + doSpecType st@(A.RetypesExpr m _ t e) + = do fromT <- typeOfExpression e + checkRetypes m fromT t + return st + doSpecType st = return st + + checkRetypes :: Meta -> A.Type -> A.Type -> PassM () + checkRetypes m fromT toT + = do (fromBI, fromN) <- evalBytesInType fromT + (toBI, toN) <- evalBytesInType toT + case (fromBI, toBI, fromN, toN) of + (_, BIManyFree, _, _) -> + dieP m "Multiple free dimensions in retype destination type" + (BIJust _, BIJust _, Just a, Just b) -> + when (a /= b) $ + dieP m "Sizes do not match in retype" + (BIJust _, BIOneFree _ _, Just a, Just b) -> + when (not ((b <= a) && (a `mod` b == 0))) $ + dieP m "Sizes do not match in retype" + (BIOneFree _ _, BIJust _, Just a, Just b) -> + when (not ((a <= b) && (b `mod` a == 0))) $ + dieP m "Sizes do not match in retype" + -- Otherwise we must do a runtime check. + _ -> return () + + evalBytesInType :: A.Type -> PassM (BytesInResult, Maybe Int) + evalBytesInType t + = do bi <- bytesInType t + n <- case bi of + BIJust e -> maybeEvalIntExpression e + BIOneFree e _ -> maybeEvalIntExpression e + _ -> return Nothing + return (bi, n) + -- | A dummy pass for things that haven't been separated out into passes yet. dummyOccamPass :: Data t => t -> PassM t dummyOccamPass = return diff --git a/frontends/OccamPassesTest.hs b/frontends/OccamPassesTest.hs index 7ea5c0a..17e24b9 100644 --- a/frontends/OccamPassesTest.hs +++ b/frontends/OccamPassesTest.hs @@ -35,6 +35,17 @@ import TestUtils m :: Meta m = emptyMeta +-- | Initial state for the tests. +startState :: State CompState () +startState + = do defineConst "const" A.Int (intLiteral 2) + defineConst "someInt" A.Int (intLiteral 42) + defineConst "someByte" A.Byte (byteLiteral 24) + defineConst "someInts" (A.Array [A.UnknownDimension] A.Int) + undefined + defineConst "someBytes" (A.Array [A.UnknownDimension] A.Byte) + undefined + -- | Test 'OccamPasses.foldConstants'. testFoldConstants :: Test testFoldConstants = TestList @@ -81,21 +92,6 @@ testFoldConstants = TestList exp (OccamPasses.foldConstants orig) startState - startState :: State CompState () - startState = defineConst "const" A.Int two - - defineConst :: String -> A.Type -> A.Expression -> State CompState () - defineConst s t e = defineName (simpleName s) $ - A.NameDef { - A.ndMeta = m, - A.ndName = "const", - A.ndOrigName = "const", - A.ndNameType = A.VariableName, - A.ndType = A.IsExpr m A.ValAbbrev t e, - A.ndAbbrevMode = A.ValAbbrev, - A.ndPlacement = A.Unplaced - } - testSame :: Int -> A.Expression -> Test testSame n orig = test n orig orig @@ -158,8 +154,64 @@ testCheckConstants = TestList var = exprVariable "var" skip = A.Skip m +-- | Test 'OccamPasses.checkRetypes'. +testCheckRetypes :: Test +testCheckRetypes = TestList + [ + -- Definitely OK at compile time + testOK 0 $ retypesV A.Int intV + , testOK 1 $ retypesE A.Int intE + , testOK 2 $ retypesV A.Byte byteV + , testOK 3 $ retypesE A.Byte byteE + , testOK 4 $ retypesV known1 intV + , testOK 5 $ retypesV known2 intV + , testOK 6 $ retypesV both intV + , testOK 7 $ retypesV unknown1 intV + + -- Definitely wrong at compile time + , testFail 100 $ retypesV A.Byte intV + , testFail 101 $ retypesV A.Int byteV + , testFail 102 $ retypesV unknown2 intV + , testFail 103 $ retypesV unknown2 intsV + , testFail 104 $ retypesV A.Byte intsV + + -- Can't tell; need a runtime check + , testOK 200 $ retypesV unknown1 intsV + , testOK 201 $ retypesV A.Int intsV + , testOK 202 $ retypesV known2 intsV + , testOK 203 $ retypesV unknown1 bytesV + ] + where + testOK :: (Show a, Data a) => Int -> a -> Test + testOK n orig + = TestCase $ testPass ("testCheckRetypes" ++ show n) + orig (OccamPasses.checkRetypes orig) + startState + + testFail :: (Show a, Data a) => Int -> a -> Test + testFail n orig + = TestCase $ testPassShouldFail ("testCheckRetypes" ++ show n) + (OccamPasses.checkRetypes orig) + startState + + retypesV = A.Retypes m A.ValAbbrev + retypesE = A.RetypesExpr m A.ValAbbrev + + intV = variable "someInt" + intE = intLiteral 42 + byteV = variable "someByte" + byteE = byteLiteral 42 + intsV = variable "someInts" + bytesV = variable "someBytes" + known1 = A.Array [dimension 4] A.Byte + known2 = A.Array [dimension 2, dimension 2] A.Byte + both = A.Array [dimension 2, A.UnknownDimension] A.Byte + unknown1 = A.Array [A.UnknownDimension] A.Int + unknown2 = A.Array [A.UnknownDimension, A.UnknownDimension] A.Int + tests :: Test tests = TestLabel "OccamPassesTest" $ TestList [ testFoldConstants , testCheckConstants + , testCheckRetypes ] diff --git a/frontends/ParseOccam.hs b/frontends/ParseOccam.hs index e38c14c..81e154e 100644 --- a/frontends/ParseOccam.hs +++ b/frontends/ParseOccam.hs @@ -1428,7 +1428,6 @@ retypesAbbrev sColon eol origT <- typeOfVariable v - --checkRetypes m origT s return $ A.Specification m n $ A.Retypes m A.Abbrev s v <|> do m <- md (s, n) <- tryVVX channelSpecifier newChannelName retypesReshapes @@ -1436,7 +1435,6 @@ retypesAbbrev sColon eol origT <- typeOfVariable c - --checkRetypes m origT s return $ A.Specification m n $ A.Retypes m A.Abbrev s c <|> do m <- md (s, n) <- tryXVVX sVAL dataSpecifier newVariableName retypesReshapes @@ -1444,29 +1442,9 @@ retypesAbbrev sColon eol origT <- typeOfExpression e - --checkRetypes m origT s return $ A.Specification m n $ A.RetypesExpr m A.ValAbbrev s e "RETYPES/RESHAPES abbreviation" -{- --- | Check that a RETYPES\/RESHAPES is safe. -checkRetypes :: Meta -> A.Type -> A.Type -> OccParser () --- Retyping channels is always "safe". -checkRetypes _ (A.Chan {}) (A.Chan {}) = return () -checkRetypes m fromT toT - = do bf <- bytesInType fromT - bt <- bytesInType toT - case (bf, bt) of - (BIJust a, BIJust b) -> - when (a /= b) $ dieP m "size mismatch in RETYPES" - (BIJust a, BIOneFree b _) -> - when (not ((b <= a) && (a `mod` b == 0))) $ dieP m "size mismatch in RETYPES" - (_, BIManyFree) -> - dieP m "multiple free dimensions in RETYPES/RESHAPES type" - -- Otherwise we have to do a runtime check. - _ -> return () --} - dataSpecifier :: OccParser A.Type dataSpecifier = dataType diff --git a/pass/Properties.hs b/pass/Properties.hs index fa51c18..7653c2e 100644 --- a/pass/Properties.hs +++ b/pass/Properties.hs @@ -50,6 +50,7 @@ module Properties , processTypesChecked , rainParDeclarationsPulledUp , rangeTransformed + , retypesChecked , seqInputsFlattened , slicesSimplified , subscriptsPulledUp @@ -78,13 +79,30 @@ import Types import Utils agg_namesDone :: [Property] -agg_namesDone = [declarationsUnique, declarationTypesRecorded, inferredTypesRecorded, declaredNamesResolved] +agg_namesDone = + [ declarationTypesRecorded + , declarationsUnique + , declaredNamesResolved + , inferredTypesRecorded + ] agg_typesDone :: [Property] -agg_typesDone = [expressionTypesChecked, inferredTypesRecorded, processTypesChecked, typesResolvedInAST, typesResolvedInState, constantsFolded, constantsChecked] +agg_typesDone = + [ constantsChecked + , constantsFolded + , expressionTypesChecked + , inferredTypesRecorded + , processTypesChecked + , retypesChecked + , typesResolvedInAST + , typesResolvedInState + ] agg_functionsGone :: [Property] -agg_functionsGone = [functionCallsRemoved, functionsRemoved] +agg_functionsGone = + [ functionCallsRemoved + , functionsRemoved + ] -- Mark out all the checks I still need to implement: checkTODO :: Monad m => A.AST -> m () @@ -142,10 +160,13 @@ declarationsUnique = Property "declarationsUnique" $ checkDupes (n':ns) constantsChecked :: Property -constantsChecked = Property "constantsChecked" checkTODO +constantsChecked = Property "constantsChecked" nocheck constantsFolded :: Property -constantsFolded = Property "constantsFolded" checkTODO +constantsFolded = Property "constantsFolded" nocheck + +retypesChecked :: Property +retypesChecked = Property "retypesChecked" nocheck intLiteralsInBounds :: Property intLiteralsInBounds = Property "intLiteralsInBounds" $