Move Retypes checking from the occam parser into a pass.
This also fixes a bug in the original algorithm: it used to let you retype []INT to BYTE.
This commit is contained in:
parent
b8caf7c3b6
commit
e08aac59d3
|
@ -17,7 +17,7 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
-}
|
||||
|
||||
-- | Evaluate constant expressions.
|
||||
module EvalConstants (constantFold, isConstantName) where
|
||||
module EvalConstants (constantFold, maybeEvalIntExpression, isConstantName) where
|
||||
|
||||
import Control.Monad.Error
|
||||
import Control.Monad.State
|
||||
|
@ -34,6 +34,7 @@ import EvalLiterals
|
|||
import Metadata
|
||||
import ShowCode
|
||||
import Types
|
||||
import Utils
|
||||
|
||||
-- | Simplify an expression by constant folding, and also return whether it's a
|
||||
-- constant after that.
|
||||
|
@ -45,6 +46,15 @@ constantFold e
|
|||
Right val -> (val, (Nothing, "already folded"))
|
||||
return (e', isConstant e', msg)
|
||||
|
||||
-- | Try to fold and evaluate an integer expression.
|
||||
-- If it's not a constant, return 'Nothing'.
|
||||
maybeEvalIntExpression :: (CSMR m, Die m) => A.Expression -> m (Maybe Int)
|
||||
maybeEvalIntExpression e
|
||||
= do (e', isConst, _) <- constantFold e
|
||||
if isConst
|
||||
then evalIntExpression e' >>* Just
|
||||
else return Nothing
|
||||
|
||||
-- | Is a name defined as a constant expression? If so, return its definition.
|
||||
getConstantName :: (CSMR m, Die m) => A.Name -> m (Maybe A.Expression)
|
||||
getConstantName n
|
||||
|
|
|
@ -168,8 +168,16 @@ exprVariablePattern :: String -> Pattern
|
|||
exprVariablePattern e = tag2 A.ExprVariable DontCare $ variablePattern e
|
||||
|
||||
-- | Creates an integer literal 'A.Expression' with the given integer.
|
||||
integerLiteral :: A.Type -> Integer -> A.Expression
|
||||
integerLiteral t n = A.Literal emptyMeta t $ A.IntLiteral emptyMeta (show n)
|
||||
|
||||
-- | Creates an 'A.Int' literal with the given integer.
|
||||
intLiteral :: Integer -> A.Expression
|
||||
intLiteral n = A.Literal emptyMeta A.Int $ A.IntLiteral emptyMeta (show n)
|
||||
intLiteral n = integerLiteral A.Int n
|
||||
|
||||
-- | Creates an 'A.Byte' literal with the given integer.
|
||||
byteLiteral :: Integer -> A.Expression
|
||||
byteLiteral n = integerLiteral A.Byte n
|
||||
|
||||
-- | Creates a 'Pattern' to match an 'A.Expression' instance.
|
||||
-- @'assertPatternMatch' ('intLiteralPattern' x) ('intLiteral' x)@ will always succeed.
|
||||
|
@ -276,6 +284,19 @@ simpleDefDecl n t = simpleDef n (A.Declaration emptyMeta t)
|
|||
simpleDefPattern :: String -> A.AbbrevMode -> Pattern -> Pattern
|
||||
simpleDefPattern n am sp = tag7 A.NameDef DontCare n n A.VariableName sp am A.Unplaced
|
||||
|
||||
-- | Define a @VAL IS@ constant.
|
||||
defineConst :: String -> A.Type -> A.Expression -> State CompState ()
|
||||
defineConst s t e = defineName (simpleName s) $
|
||||
A.NameDef {
|
||||
A.ndMeta = emptyMeta,
|
||||
A.ndName = s,
|
||||
A.ndOrigName = s,
|
||||
A.ndNameType = A.VariableName,
|
||||
A.ndType = A.IsExpr emptyMeta A.ValAbbrev t e,
|
||||
A.ndAbbrevMode = A.ValAbbrev,
|
||||
A.ndPlacement = A.Unplaced
|
||||
}
|
||||
|
||||
--}}}
|
||||
--{{{ custom assertions
|
||||
|
||||
|
|
|
@ -17,19 +17,23 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
-}
|
||||
|
||||
-- | The occam-specific frontend passes.
|
||||
module OccamPasses (occamPasses, foldConstants, checkConstants) where
|
||||
module OccamPasses (occamPasses, foldConstants, checkConstants,
|
||||
checkRetypes) where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.State
|
||||
import Data.Generics
|
||||
import System.IO
|
||||
|
||||
import qualified AST as A
|
||||
import CompState
|
||||
import Errors
|
||||
import EvalConstants
|
||||
import EvalLiterals
|
||||
import Metadata
|
||||
import Pass
|
||||
import qualified Properties as Prop
|
||||
import ShowCode
|
||||
import Types
|
||||
|
||||
-- | Occam-specific frontend passes.
|
||||
occamPasses :: [Pass]
|
||||
|
@ -40,6 +44,9 @@ occamPasses = makePassesDep' ((== FrontendOccam) . csFrontend)
|
|||
, ("Check mandatory constants", checkConstants,
|
||||
[Prop.constantsFolded],
|
||||
[Prop.constantsChecked])
|
||||
, ("Check retyping", checkRetypes,
|
||||
[],
|
||||
[Prop.retypesChecked])
|
||||
, ("Dummy occam pass", dummyOccamPass,
|
||||
[],
|
||||
Prop.agg_namesDone ++ [Prop.expressionTypesChecked,
|
||||
|
@ -97,6 +104,49 @@ checkConstants = doGeneric `extM` doDimension `extM` doOption
|
|||
doGeneric o
|
||||
doOption o = doGeneric o
|
||||
|
||||
-- | Check that retyping is safe.
|
||||
checkRetypes :: Data t => t -> PassM t
|
||||
checkRetypes = everywhereASTM doSpecType
|
||||
where
|
||||
doSpecType :: A.SpecType -> PassM A.SpecType
|
||||
doSpecType st@(A.Retypes m _ t v)
|
||||
= do fromT <- typeOfVariable v
|
||||
checkRetypes m fromT t
|
||||
return st
|
||||
doSpecType st@(A.RetypesExpr m _ t e)
|
||||
= do fromT <- typeOfExpression e
|
||||
checkRetypes m fromT t
|
||||
return st
|
||||
doSpecType st = return st
|
||||
|
||||
checkRetypes :: Meta -> A.Type -> A.Type -> PassM ()
|
||||
checkRetypes m fromT toT
|
||||
= do (fromBI, fromN) <- evalBytesInType fromT
|
||||
(toBI, toN) <- evalBytesInType toT
|
||||
case (fromBI, toBI, fromN, toN) of
|
||||
(_, BIManyFree, _, _) ->
|
||||
dieP m "Multiple free dimensions in retype destination type"
|
||||
(BIJust _, BIJust _, Just a, Just b) ->
|
||||
when (a /= b) $
|
||||
dieP m "Sizes do not match in retype"
|
||||
(BIJust _, BIOneFree _ _, Just a, Just b) ->
|
||||
when (not ((b <= a) && (a `mod` b == 0))) $
|
||||
dieP m "Sizes do not match in retype"
|
||||
(BIOneFree _ _, BIJust _, Just a, Just b) ->
|
||||
when (not ((a <= b) && (b `mod` a == 0))) $
|
||||
dieP m "Sizes do not match in retype"
|
||||
-- Otherwise we must do a runtime check.
|
||||
_ -> return ()
|
||||
|
||||
evalBytesInType :: A.Type -> PassM (BytesInResult, Maybe Int)
|
||||
evalBytesInType t
|
||||
= do bi <- bytesInType t
|
||||
n <- case bi of
|
||||
BIJust e -> maybeEvalIntExpression e
|
||||
BIOneFree e _ -> maybeEvalIntExpression e
|
||||
_ -> return Nothing
|
||||
return (bi, n)
|
||||
|
||||
-- | A dummy pass for things that haven't been separated out into passes yet.
|
||||
dummyOccamPass :: Data t => t -> PassM t
|
||||
dummyOccamPass = return
|
||||
|
|
|
@ -35,6 +35,17 @@ import TestUtils
|
|||
m :: Meta
|
||||
m = emptyMeta
|
||||
|
||||
-- | Initial state for the tests.
|
||||
startState :: State CompState ()
|
||||
startState
|
||||
= do defineConst "const" A.Int (intLiteral 2)
|
||||
defineConst "someInt" A.Int (intLiteral 42)
|
||||
defineConst "someByte" A.Byte (byteLiteral 24)
|
||||
defineConst "someInts" (A.Array [A.UnknownDimension] A.Int)
|
||||
undefined
|
||||
defineConst "someBytes" (A.Array [A.UnknownDimension] A.Byte)
|
||||
undefined
|
||||
|
||||
-- | Test 'OccamPasses.foldConstants'.
|
||||
testFoldConstants :: Test
|
||||
testFoldConstants = TestList
|
||||
|
@ -81,21 +92,6 @@ testFoldConstants = TestList
|
|||
exp (OccamPasses.foldConstants orig)
|
||||
startState
|
||||
|
||||
startState :: State CompState ()
|
||||
startState = defineConst "const" A.Int two
|
||||
|
||||
defineConst :: String -> A.Type -> A.Expression -> State CompState ()
|
||||
defineConst s t e = defineName (simpleName s) $
|
||||
A.NameDef {
|
||||
A.ndMeta = m,
|
||||
A.ndName = "const",
|
||||
A.ndOrigName = "const",
|
||||
A.ndNameType = A.VariableName,
|
||||
A.ndType = A.IsExpr m A.ValAbbrev t e,
|
||||
A.ndAbbrevMode = A.ValAbbrev,
|
||||
A.ndPlacement = A.Unplaced
|
||||
}
|
||||
|
||||
testSame :: Int -> A.Expression -> Test
|
||||
testSame n orig = test n orig orig
|
||||
|
||||
|
@ -158,8 +154,64 @@ testCheckConstants = TestList
|
|||
var = exprVariable "var"
|
||||
skip = A.Skip m
|
||||
|
||||
-- | Test 'OccamPasses.checkRetypes'.
|
||||
testCheckRetypes :: Test
|
||||
testCheckRetypes = TestList
|
||||
[
|
||||
-- Definitely OK at compile time
|
||||
testOK 0 $ retypesV A.Int intV
|
||||
, testOK 1 $ retypesE A.Int intE
|
||||
, testOK 2 $ retypesV A.Byte byteV
|
||||
, testOK 3 $ retypesE A.Byte byteE
|
||||
, testOK 4 $ retypesV known1 intV
|
||||
, testOK 5 $ retypesV known2 intV
|
||||
, testOK 6 $ retypesV both intV
|
||||
, testOK 7 $ retypesV unknown1 intV
|
||||
|
||||
-- Definitely wrong at compile time
|
||||
, testFail 100 $ retypesV A.Byte intV
|
||||
, testFail 101 $ retypesV A.Int byteV
|
||||
, testFail 102 $ retypesV unknown2 intV
|
||||
, testFail 103 $ retypesV unknown2 intsV
|
||||
, testFail 104 $ retypesV A.Byte intsV
|
||||
|
||||
-- Can't tell; need a runtime check
|
||||
, testOK 200 $ retypesV unknown1 intsV
|
||||
, testOK 201 $ retypesV A.Int intsV
|
||||
, testOK 202 $ retypesV known2 intsV
|
||||
, testOK 203 $ retypesV unknown1 bytesV
|
||||
]
|
||||
where
|
||||
testOK :: (Show a, Data a) => Int -> a -> Test
|
||||
testOK n orig
|
||||
= TestCase $ testPass ("testCheckRetypes" ++ show n)
|
||||
orig (OccamPasses.checkRetypes orig)
|
||||
startState
|
||||
|
||||
testFail :: (Show a, Data a) => Int -> a -> Test
|
||||
testFail n orig
|
||||
= TestCase $ testPassShouldFail ("testCheckRetypes" ++ show n)
|
||||
(OccamPasses.checkRetypes orig)
|
||||
startState
|
||||
|
||||
retypesV = A.Retypes m A.ValAbbrev
|
||||
retypesE = A.RetypesExpr m A.ValAbbrev
|
||||
|
||||
intV = variable "someInt"
|
||||
intE = intLiteral 42
|
||||
byteV = variable "someByte"
|
||||
byteE = byteLiteral 42
|
||||
intsV = variable "someInts"
|
||||
bytesV = variable "someBytes"
|
||||
known1 = A.Array [dimension 4] A.Byte
|
||||
known2 = A.Array [dimension 2, dimension 2] A.Byte
|
||||
both = A.Array [dimension 2, A.UnknownDimension] A.Byte
|
||||
unknown1 = A.Array [A.UnknownDimension] A.Int
|
||||
unknown2 = A.Array [A.UnknownDimension, A.UnknownDimension] A.Int
|
||||
|
||||
tests :: Test
|
||||
tests = TestLabel "OccamPassesTest" $ TestList
|
||||
[ testFoldConstants
|
||||
, testCheckConstants
|
||||
, testCheckRetypes
|
||||
]
|
||||
|
|
|
@ -1428,7 +1428,6 @@ retypesAbbrev
|
|||
sColon
|
||||
eol
|
||||
origT <- typeOfVariable v
|
||||
--checkRetypes m origT s
|
||||
return $ A.Specification m n $ A.Retypes m A.Abbrev s v
|
||||
<|> do m <- md
|
||||
(s, n) <- tryVVX channelSpecifier newChannelName retypesReshapes
|
||||
|
@ -1436,7 +1435,6 @@ retypesAbbrev
|
|||
sColon
|
||||
eol
|
||||
origT <- typeOfVariable c
|
||||
--checkRetypes m origT s
|
||||
return $ A.Specification m n $ A.Retypes m A.Abbrev s c
|
||||
<|> do m <- md
|
||||
(s, n) <- tryXVVX sVAL dataSpecifier newVariableName retypesReshapes
|
||||
|
@ -1444,29 +1442,9 @@ retypesAbbrev
|
|||
sColon
|
||||
eol
|
||||
origT <- typeOfExpression e
|
||||
--checkRetypes m origT s
|
||||
return $ A.Specification m n $ A.RetypesExpr m A.ValAbbrev s e
|
||||
<?> "RETYPES/RESHAPES abbreviation"
|
||||
|
||||
{-
|
||||
-- | Check that a RETYPES\/RESHAPES is safe.
|
||||
checkRetypes :: Meta -> A.Type -> A.Type -> OccParser ()
|
||||
-- Retyping channels is always "safe".
|
||||
checkRetypes _ (A.Chan {}) (A.Chan {}) = return ()
|
||||
checkRetypes m fromT toT
|
||||
= do bf <- bytesInType fromT
|
||||
bt <- bytesInType toT
|
||||
case (bf, bt) of
|
||||
(BIJust a, BIJust b) ->
|
||||
when (a /= b) $ dieP m "size mismatch in RETYPES"
|
||||
(BIJust a, BIOneFree b _) ->
|
||||
when (not ((b <= a) && (a `mod` b == 0))) $ dieP m "size mismatch in RETYPES"
|
||||
(_, BIManyFree) ->
|
||||
dieP m "multiple free dimensions in RETYPES/RESHAPES type"
|
||||
-- Otherwise we have to do a runtime check.
|
||||
_ -> return ()
|
||||
-}
|
||||
|
||||
dataSpecifier :: OccParser A.Type
|
||||
dataSpecifier
|
||||
= dataType
|
||||
|
|
|
@ -50,6 +50,7 @@ module Properties
|
|||
, processTypesChecked
|
||||
, rainParDeclarationsPulledUp
|
||||
, rangeTransformed
|
||||
, retypesChecked
|
||||
, seqInputsFlattened
|
||||
, slicesSimplified
|
||||
, subscriptsPulledUp
|
||||
|
@ -78,13 +79,30 @@ import Types
|
|||
import Utils
|
||||
|
||||
agg_namesDone :: [Property]
|
||||
agg_namesDone = [declarationsUnique, declarationTypesRecorded, inferredTypesRecorded, declaredNamesResolved]
|
||||
agg_namesDone =
|
||||
[ declarationTypesRecorded
|
||||
, declarationsUnique
|
||||
, declaredNamesResolved
|
||||
, inferredTypesRecorded
|
||||
]
|
||||
|
||||
agg_typesDone :: [Property]
|
||||
agg_typesDone = [expressionTypesChecked, inferredTypesRecorded, processTypesChecked, typesResolvedInAST, typesResolvedInState, constantsFolded, constantsChecked]
|
||||
agg_typesDone =
|
||||
[ constantsChecked
|
||||
, constantsFolded
|
||||
, expressionTypesChecked
|
||||
, inferredTypesRecorded
|
||||
, processTypesChecked
|
||||
, retypesChecked
|
||||
, typesResolvedInAST
|
||||
, typesResolvedInState
|
||||
]
|
||||
|
||||
agg_functionsGone :: [Property]
|
||||
agg_functionsGone = [functionCallsRemoved, functionsRemoved]
|
||||
agg_functionsGone =
|
||||
[ functionCallsRemoved
|
||||
, functionsRemoved
|
||||
]
|
||||
|
||||
-- Mark out all the checks I still need to implement:
|
||||
checkTODO :: Monad m => A.AST -> m ()
|
||||
|
@ -142,10 +160,13 @@ declarationsUnique = Property "declarationsUnique" $
|
|||
checkDupes (n':ns)
|
||||
|
||||
constantsChecked :: Property
|
||||
constantsChecked = Property "constantsChecked" checkTODO
|
||||
constantsChecked = Property "constantsChecked" nocheck
|
||||
|
||||
constantsFolded :: Property
|
||||
constantsFolded = Property "constantsFolded" checkTODO
|
||||
constantsFolded = Property "constantsFolded" nocheck
|
||||
|
||||
retypesChecked :: Property
|
||||
retypesChecked = Property "retypesChecked" nocheck
|
||||
|
||||
intLiteralsInBounds :: Property
|
||||
intLiteralsInBounds = Property "intLiteralsInBounds" $
|
||||
|
|
Loading…
Reference in New Issue
Block a user