Move Retypes checking from the occam parser into a pass.

This also fixes a bug in the original algorithm: it used to let you retype
[]INT to BYTE.
This commit is contained in:
Adam Sampson 2008-03-19 19:38:56 +00:00
parent b8caf7c3b6
commit e08aac59d3
6 changed files with 178 additions and 46 deletions

View File

@ -17,7 +17,7 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
-} -}
-- | Evaluate constant expressions. -- | Evaluate constant expressions.
module EvalConstants (constantFold, isConstantName) where module EvalConstants (constantFold, maybeEvalIntExpression, isConstantName) where
import Control.Monad.Error import Control.Monad.Error
import Control.Monad.State import Control.Monad.State
@ -34,6 +34,7 @@ import EvalLiterals
import Metadata import Metadata
import ShowCode import ShowCode
import Types import Types
import Utils
-- | Simplify an expression by constant folding, and also return whether it's a -- | Simplify an expression by constant folding, and also return whether it's a
-- constant after that. -- constant after that.
@ -45,6 +46,15 @@ constantFold e
Right val -> (val, (Nothing, "already folded")) Right val -> (val, (Nothing, "already folded"))
return (e', isConstant e', msg) return (e', isConstant e', msg)
-- | Try to fold and evaluate an integer expression.
-- If it's not a constant, return 'Nothing'.
maybeEvalIntExpression :: (CSMR m, Die m) => A.Expression -> m (Maybe Int)
maybeEvalIntExpression e
= do (e', isConst, _) <- constantFold e
if isConst
then evalIntExpression e' >>* Just
else return Nothing
-- | Is a name defined as a constant expression? If so, return its definition. -- | Is a name defined as a constant expression? If so, return its definition.
getConstantName :: (CSMR m, Die m) => A.Name -> m (Maybe A.Expression) getConstantName :: (CSMR m, Die m) => A.Name -> m (Maybe A.Expression)
getConstantName n getConstantName n

View File

@ -168,8 +168,16 @@ exprVariablePattern :: String -> Pattern
exprVariablePattern e = tag2 A.ExprVariable DontCare $ variablePattern e exprVariablePattern e = tag2 A.ExprVariable DontCare $ variablePattern e
-- | Creates an integer literal 'A.Expression' with the given integer. -- | Creates an integer literal 'A.Expression' with the given integer.
integerLiteral :: A.Type -> Integer -> A.Expression
integerLiteral t n = A.Literal emptyMeta t $ A.IntLiteral emptyMeta (show n)
-- | Creates an 'A.Int' literal with the given integer.
intLiteral :: Integer -> A.Expression intLiteral :: Integer -> A.Expression
intLiteral n = A.Literal emptyMeta A.Int $ A.IntLiteral emptyMeta (show n) intLiteral n = integerLiteral A.Int n
-- | Creates an 'A.Byte' literal with the given integer.
byteLiteral :: Integer -> A.Expression
byteLiteral n = integerLiteral A.Byte n
-- | Creates a 'Pattern' to match an 'A.Expression' instance. -- | Creates a 'Pattern' to match an 'A.Expression' instance.
-- @'assertPatternMatch' ('intLiteralPattern' x) ('intLiteral' x)@ will always succeed. -- @'assertPatternMatch' ('intLiteralPattern' x) ('intLiteral' x)@ will always succeed.
@ -276,6 +284,19 @@ simpleDefDecl n t = simpleDef n (A.Declaration emptyMeta t)
simpleDefPattern :: String -> A.AbbrevMode -> Pattern -> Pattern simpleDefPattern :: String -> A.AbbrevMode -> Pattern -> Pattern
simpleDefPattern n am sp = tag7 A.NameDef DontCare n n A.VariableName sp am A.Unplaced simpleDefPattern n am sp = tag7 A.NameDef DontCare n n A.VariableName sp am A.Unplaced
-- | Define a @VAL IS@ constant.
defineConst :: String -> A.Type -> A.Expression -> State CompState ()
defineConst s t e = defineName (simpleName s) $
A.NameDef {
A.ndMeta = emptyMeta,
A.ndName = s,
A.ndOrigName = s,
A.ndNameType = A.VariableName,
A.ndType = A.IsExpr emptyMeta A.ValAbbrev t e,
A.ndAbbrevMode = A.ValAbbrev,
A.ndPlacement = A.Unplaced
}
--}}} --}}}
--{{{ custom assertions --{{{ custom assertions

View File

@ -17,19 +17,23 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
-} -}
-- | The occam-specific frontend passes. -- | The occam-specific frontend passes.
module OccamPasses (occamPasses, foldConstants, checkConstants) where module OccamPasses (occamPasses, foldConstants, checkConstants,
checkRetypes) where
import Control.Monad import Control.Monad.State
import Data.Generics import Data.Generics
import System.IO
import qualified AST as A import qualified AST as A
import CompState import CompState
import Errors
import EvalConstants import EvalConstants
import EvalLiterals import EvalLiterals
import Metadata import Metadata
import Pass import Pass
import qualified Properties as Prop import qualified Properties as Prop
import ShowCode import ShowCode
import Types
-- | Occam-specific frontend passes. -- | Occam-specific frontend passes.
occamPasses :: [Pass] occamPasses :: [Pass]
@ -40,6 +44,9 @@ occamPasses = makePassesDep' ((== FrontendOccam) . csFrontend)
, ("Check mandatory constants", checkConstants, , ("Check mandatory constants", checkConstants,
[Prop.constantsFolded], [Prop.constantsFolded],
[Prop.constantsChecked]) [Prop.constantsChecked])
, ("Check retyping", checkRetypes,
[],
[Prop.retypesChecked])
, ("Dummy occam pass", dummyOccamPass, , ("Dummy occam pass", dummyOccamPass,
[], [],
Prop.agg_namesDone ++ [Prop.expressionTypesChecked, Prop.agg_namesDone ++ [Prop.expressionTypesChecked,
@ -97,6 +104,49 @@ checkConstants = doGeneric `extM` doDimension `extM` doOption
doGeneric o doGeneric o
doOption o = doGeneric o doOption o = doGeneric o
-- | Check that retyping is safe.
checkRetypes :: Data t => t -> PassM t
checkRetypes = everywhereASTM doSpecType
where
doSpecType :: A.SpecType -> PassM A.SpecType
doSpecType st@(A.Retypes m _ t v)
= do fromT <- typeOfVariable v
checkRetypes m fromT t
return st
doSpecType st@(A.RetypesExpr m _ t e)
= do fromT <- typeOfExpression e
checkRetypes m fromT t
return st
doSpecType st = return st
checkRetypes :: Meta -> A.Type -> A.Type -> PassM ()
checkRetypes m fromT toT
= do (fromBI, fromN) <- evalBytesInType fromT
(toBI, toN) <- evalBytesInType toT
case (fromBI, toBI, fromN, toN) of
(_, BIManyFree, _, _) ->
dieP m "Multiple free dimensions in retype destination type"
(BIJust _, BIJust _, Just a, Just b) ->
when (a /= b) $
dieP m "Sizes do not match in retype"
(BIJust _, BIOneFree _ _, Just a, Just b) ->
when (not ((b <= a) && (a `mod` b == 0))) $
dieP m "Sizes do not match in retype"
(BIOneFree _ _, BIJust _, Just a, Just b) ->
when (not ((a <= b) && (b `mod` a == 0))) $
dieP m "Sizes do not match in retype"
-- Otherwise we must do a runtime check.
_ -> return ()
evalBytesInType :: A.Type -> PassM (BytesInResult, Maybe Int)
evalBytesInType t
= do bi <- bytesInType t
n <- case bi of
BIJust e -> maybeEvalIntExpression e
BIOneFree e _ -> maybeEvalIntExpression e
_ -> return Nothing
return (bi, n)
-- | A dummy pass for things that haven't been separated out into passes yet. -- | A dummy pass for things that haven't been separated out into passes yet.
dummyOccamPass :: Data t => t -> PassM t dummyOccamPass :: Data t => t -> PassM t
dummyOccamPass = return dummyOccamPass = return

View File

@ -35,6 +35,17 @@ import TestUtils
m :: Meta m :: Meta
m = emptyMeta m = emptyMeta
-- | Initial state for the tests.
startState :: State CompState ()
startState
= do defineConst "const" A.Int (intLiteral 2)
defineConst "someInt" A.Int (intLiteral 42)
defineConst "someByte" A.Byte (byteLiteral 24)
defineConst "someInts" (A.Array [A.UnknownDimension] A.Int)
undefined
defineConst "someBytes" (A.Array [A.UnknownDimension] A.Byte)
undefined
-- | Test 'OccamPasses.foldConstants'. -- | Test 'OccamPasses.foldConstants'.
testFoldConstants :: Test testFoldConstants :: Test
testFoldConstants = TestList testFoldConstants = TestList
@ -81,21 +92,6 @@ testFoldConstants = TestList
exp (OccamPasses.foldConstants orig) exp (OccamPasses.foldConstants orig)
startState startState
startState :: State CompState ()
startState = defineConst "const" A.Int two
defineConst :: String -> A.Type -> A.Expression -> State CompState ()
defineConst s t e = defineName (simpleName s) $
A.NameDef {
A.ndMeta = m,
A.ndName = "const",
A.ndOrigName = "const",
A.ndNameType = A.VariableName,
A.ndType = A.IsExpr m A.ValAbbrev t e,
A.ndAbbrevMode = A.ValAbbrev,
A.ndPlacement = A.Unplaced
}
testSame :: Int -> A.Expression -> Test testSame :: Int -> A.Expression -> Test
testSame n orig = test n orig orig testSame n orig = test n orig orig
@ -158,8 +154,64 @@ testCheckConstants = TestList
var = exprVariable "var" var = exprVariable "var"
skip = A.Skip m skip = A.Skip m
-- | Test 'OccamPasses.checkRetypes'.
testCheckRetypes :: Test
testCheckRetypes = TestList
[
-- Definitely OK at compile time
testOK 0 $ retypesV A.Int intV
, testOK 1 $ retypesE A.Int intE
, testOK 2 $ retypesV A.Byte byteV
, testOK 3 $ retypesE A.Byte byteE
, testOK 4 $ retypesV known1 intV
, testOK 5 $ retypesV known2 intV
, testOK 6 $ retypesV both intV
, testOK 7 $ retypesV unknown1 intV
-- Definitely wrong at compile time
, testFail 100 $ retypesV A.Byte intV
, testFail 101 $ retypesV A.Int byteV
, testFail 102 $ retypesV unknown2 intV
, testFail 103 $ retypesV unknown2 intsV
, testFail 104 $ retypesV A.Byte intsV
-- Can't tell; need a runtime check
, testOK 200 $ retypesV unknown1 intsV
, testOK 201 $ retypesV A.Int intsV
, testOK 202 $ retypesV known2 intsV
, testOK 203 $ retypesV unknown1 bytesV
]
where
testOK :: (Show a, Data a) => Int -> a -> Test
testOK n orig
= TestCase $ testPass ("testCheckRetypes" ++ show n)
orig (OccamPasses.checkRetypes orig)
startState
testFail :: (Show a, Data a) => Int -> a -> Test
testFail n orig
= TestCase $ testPassShouldFail ("testCheckRetypes" ++ show n)
(OccamPasses.checkRetypes orig)
startState
retypesV = A.Retypes m A.ValAbbrev
retypesE = A.RetypesExpr m A.ValAbbrev
intV = variable "someInt"
intE = intLiteral 42
byteV = variable "someByte"
byteE = byteLiteral 42
intsV = variable "someInts"
bytesV = variable "someBytes"
known1 = A.Array [dimension 4] A.Byte
known2 = A.Array [dimension 2, dimension 2] A.Byte
both = A.Array [dimension 2, A.UnknownDimension] A.Byte
unknown1 = A.Array [A.UnknownDimension] A.Int
unknown2 = A.Array [A.UnknownDimension, A.UnknownDimension] A.Int
tests :: Test tests :: Test
tests = TestLabel "OccamPassesTest" $ TestList tests = TestLabel "OccamPassesTest" $ TestList
[ testFoldConstants [ testFoldConstants
, testCheckConstants , testCheckConstants
, testCheckRetypes
] ]

View File

@ -1428,7 +1428,6 @@ retypesAbbrev
sColon sColon
eol eol
origT <- typeOfVariable v origT <- typeOfVariable v
--checkRetypes m origT s
return $ A.Specification m n $ A.Retypes m A.Abbrev s v return $ A.Specification m n $ A.Retypes m A.Abbrev s v
<|> do m <- md <|> do m <- md
(s, n) <- tryVVX channelSpecifier newChannelName retypesReshapes (s, n) <- tryVVX channelSpecifier newChannelName retypesReshapes
@ -1436,7 +1435,6 @@ retypesAbbrev
sColon sColon
eol eol
origT <- typeOfVariable c origT <- typeOfVariable c
--checkRetypes m origT s
return $ A.Specification m n $ A.Retypes m A.Abbrev s c return $ A.Specification m n $ A.Retypes m A.Abbrev s c
<|> do m <- md <|> do m <- md
(s, n) <- tryXVVX sVAL dataSpecifier newVariableName retypesReshapes (s, n) <- tryXVVX sVAL dataSpecifier newVariableName retypesReshapes
@ -1444,29 +1442,9 @@ retypesAbbrev
sColon sColon
eol eol
origT <- typeOfExpression e origT <- typeOfExpression e
--checkRetypes m origT s
return $ A.Specification m n $ A.RetypesExpr m A.ValAbbrev s e return $ A.Specification m n $ A.RetypesExpr m A.ValAbbrev s e
<?> "RETYPES/RESHAPES abbreviation" <?> "RETYPES/RESHAPES abbreviation"
{-
-- | Check that a RETYPES\/RESHAPES is safe.
checkRetypes :: Meta -> A.Type -> A.Type -> OccParser ()
-- Retyping channels is always "safe".
checkRetypes _ (A.Chan {}) (A.Chan {}) = return ()
checkRetypes m fromT toT
= do bf <- bytesInType fromT
bt <- bytesInType toT
case (bf, bt) of
(BIJust a, BIJust b) ->
when (a /= b) $ dieP m "size mismatch in RETYPES"
(BIJust a, BIOneFree b _) ->
when (not ((b <= a) && (a `mod` b == 0))) $ dieP m "size mismatch in RETYPES"
(_, BIManyFree) ->
dieP m "multiple free dimensions in RETYPES/RESHAPES type"
-- Otherwise we have to do a runtime check.
_ -> return ()
-}
dataSpecifier :: OccParser A.Type dataSpecifier :: OccParser A.Type
dataSpecifier dataSpecifier
= dataType = dataType

View File

@ -50,6 +50,7 @@ module Properties
, processTypesChecked , processTypesChecked
, rainParDeclarationsPulledUp , rainParDeclarationsPulledUp
, rangeTransformed , rangeTransformed
, retypesChecked
, seqInputsFlattened , seqInputsFlattened
, slicesSimplified , slicesSimplified
, subscriptsPulledUp , subscriptsPulledUp
@ -78,13 +79,30 @@ import Types
import Utils import Utils
agg_namesDone :: [Property] agg_namesDone :: [Property]
agg_namesDone = [declarationsUnique, declarationTypesRecorded, inferredTypesRecorded, declaredNamesResolved] agg_namesDone =
[ declarationTypesRecorded
, declarationsUnique
, declaredNamesResolved
, inferredTypesRecorded
]
agg_typesDone :: [Property] agg_typesDone :: [Property]
agg_typesDone = [expressionTypesChecked, inferredTypesRecorded, processTypesChecked, typesResolvedInAST, typesResolvedInState, constantsFolded, constantsChecked] agg_typesDone =
[ constantsChecked
, constantsFolded
, expressionTypesChecked
, inferredTypesRecorded
, processTypesChecked
, retypesChecked
, typesResolvedInAST
, typesResolvedInState
]
agg_functionsGone :: [Property] agg_functionsGone :: [Property]
agg_functionsGone = [functionCallsRemoved, functionsRemoved] agg_functionsGone =
[ functionCallsRemoved
, functionsRemoved
]
-- Mark out all the checks I still need to implement: -- Mark out all the checks I still need to implement:
checkTODO :: Monad m => A.AST -> m () checkTODO :: Monad m => A.AST -> m ()
@ -142,10 +160,13 @@ declarationsUnique = Property "declarationsUnique" $
checkDupes (n':ns) checkDupes (n':ns)
constantsChecked :: Property constantsChecked :: Property
constantsChecked = Property "constantsChecked" checkTODO constantsChecked = Property "constantsChecked" nocheck
constantsFolded :: Property constantsFolded :: Property
constantsFolded = Property "constantsFolded" checkTODO constantsFolded = Property "constantsFolded" nocheck
retypesChecked :: Property
retypesChecked = Property "retypesChecked" nocheck
intLiteralsInBounds :: Property intLiteralsInBounds :: Property
intLiteralsInBounds = Property "intLiteralsInBounds" $ intLiteralsInBounds = Property "intLiteralsInBounds" $