Move Retypes checking from the occam parser into a pass.
This also fixes a bug in the original algorithm: it used to let you retype []INT to BYTE.
This commit is contained in:
parent
b8caf7c3b6
commit
e08aac59d3
|
@ -17,7 +17,7 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
-}
|
-}
|
||||||
|
|
||||||
-- | Evaluate constant expressions.
|
-- | Evaluate constant expressions.
|
||||||
module EvalConstants (constantFold, isConstantName) where
|
module EvalConstants (constantFold, maybeEvalIntExpression, isConstantName) where
|
||||||
|
|
||||||
import Control.Monad.Error
|
import Control.Monad.Error
|
||||||
import Control.Monad.State
|
import Control.Monad.State
|
||||||
|
@ -34,6 +34,7 @@ import EvalLiterals
|
||||||
import Metadata
|
import Metadata
|
||||||
import ShowCode
|
import ShowCode
|
||||||
import Types
|
import Types
|
||||||
|
import Utils
|
||||||
|
|
||||||
-- | Simplify an expression by constant folding, and also return whether it's a
|
-- | Simplify an expression by constant folding, and also return whether it's a
|
||||||
-- constant after that.
|
-- constant after that.
|
||||||
|
@ -45,6 +46,15 @@ constantFold e
|
||||||
Right val -> (val, (Nothing, "already folded"))
|
Right val -> (val, (Nothing, "already folded"))
|
||||||
return (e', isConstant e', msg)
|
return (e', isConstant e', msg)
|
||||||
|
|
||||||
|
-- | Try to fold and evaluate an integer expression.
|
||||||
|
-- If it's not a constant, return 'Nothing'.
|
||||||
|
maybeEvalIntExpression :: (CSMR m, Die m) => A.Expression -> m (Maybe Int)
|
||||||
|
maybeEvalIntExpression e
|
||||||
|
= do (e', isConst, _) <- constantFold e
|
||||||
|
if isConst
|
||||||
|
then evalIntExpression e' >>* Just
|
||||||
|
else return Nothing
|
||||||
|
|
||||||
-- | Is a name defined as a constant expression? If so, return its definition.
|
-- | Is a name defined as a constant expression? If so, return its definition.
|
||||||
getConstantName :: (CSMR m, Die m) => A.Name -> m (Maybe A.Expression)
|
getConstantName :: (CSMR m, Die m) => A.Name -> m (Maybe A.Expression)
|
||||||
getConstantName n
|
getConstantName n
|
||||||
|
|
|
@ -168,8 +168,16 @@ exprVariablePattern :: String -> Pattern
|
||||||
exprVariablePattern e = tag2 A.ExprVariable DontCare $ variablePattern e
|
exprVariablePattern e = tag2 A.ExprVariable DontCare $ variablePattern e
|
||||||
|
|
||||||
-- | Creates an integer literal 'A.Expression' with the given integer.
|
-- | Creates an integer literal 'A.Expression' with the given integer.
|
||||||
|
integerLiteral :: A.Type -> Integer -> A.Expression
|
||||||
|
integerLiteral t n = A.Literal emptyMeta t $ A.IntLiteral emptyMeta (show n)
|
||||||
|
|
||||||
|
-- | Creates an 'A.Int' literal with the given integer.
|
||||||
intLiteral :: Integer -> A.Expression
|
intLiteral :: Integer -> A.Expression
|
||||||
intLiteral n = A.Literal emptyMeta A.Int $ A.IntLiteral emptyMeta (show n)
|
intLiteral n = integerLiteral A.Int n
|
||||||
|
|
||||||
|
-- | Creates an 'A.Byte' literal with the given integer.
|
||||||
|
byteLiteral :: Integer -> A.Expression
|
||||||
|
byteLiteral n = integerLiteral A.Byte n
|
||||||
|
|
||||||
-- | Creates a 'Pattern' to match an 'A.Expression' instance.
|
-- | Creates a 'Pattern' to match an 'A.Expression' instance.
|
||||||
-- @'assertPatternMatch' ('intLiteralPattern' x) ('intLiteral' x)@ will always succeed.
|
-- @'assertPatternMatch' ('intLiteralPattern' x) ('intLiteral' x)@ will always succeed.
|
||||||
|
@ -276,6 +284,19 @@ simpleDefDecl n t = simpleDef n (A.Declaration emptyMeta t)
|
||||||
simpleDefPattern :: String -> A.AbbrevMode -> Pattern -> Pattern
|
simpleDefPattern :: String -> A.AbbrevMode -> Pattern -> Pattern
|
||||||
simpleDefPattern n am sp = tag7 A.NameDef DontCare n n A.VariableName sp am A.Unplaced
|
simpleDefPattern n am sp = tag7 A.NameDef DontCare n n A.VariableName sp am A.Unplaced
|
||||||
|
|
||||||
|
-- | Define a @VAL IS@ constant.
|
||||||
|
defineConst :: String -> A.Type -> A.Expression -> State CompState ()
|
||||||
|
defineConst s t e = defineName (simpleName s) $
|
||||||
|
A.NameDef {
|
||||||
|
A.ndMeta = emptyMeta,
|
||||||
|
A.ndName = s,
|
||||||
|
A.ndOrigName = s,
|
||||||
|
A.ndNameType = A.VariableName,
|
||||||
|
A.ndType = A.IsExpr emptyMeta A.ValAbbrev t e,
|
||||||
|
A.ndAbbrevMode = A.ValAbbrev,
|
||||||
|
A.ndPlacement = A.Unplaced
|
||||||
|
}
|
||||||
|
|
||||||
--}}}
|
--}}}
|
||||||
--{{{ custom assertions
|
--{{{ custom assertions
|
||||||
|
|
||||||
|
|
|
@ -17,19 +17,23 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
-}
|
-}
|
||||||
|
|
||||||
-- | The occam-specific frontend passes.
|
-- | The occam-specific frontend passes.
|
||||||
module OccamPasses (occamPasses, foldConstants, checkConstants) where
|
module OccamPasses (occamPasses, foldConstants, checkConstants,
|
||||||
|
checkRetypes) where
|
||||||
|
|
||||||
import Control.Monad
|
import Control.Monad.State
|
||||||
import Data.Generics
|
import Data.Generics
|
||||||
|
import System.IO
|
||||||
|
|
||||||
import qualified AST as A
|
import qualified AST as A
|
||||||
import CompState
|
import CompState
|
||||||
|
import Errors
|
||||||
import EvalConstants
|
import EvalConstants
|
||||||
import EvalLiterals
|
import EvalLiterals
|
||||||
import Metadata
|
import Metadata
|
||||||
import Pass
|
import Pass
|
||||||
import qualified Properties as Prop
|
import qualified Properties as Prop
|
||||||
import ShowCode
|
import ShowCode
|
||||||
|
import Types
|
||||||
|
|
||||||
-- | Occam-specific frontend passes.
|
-- | Occam-specific frontend passes.
|
||||||
occamPasses :: [Pass]
|
occamPasses :: [Pass]
|
||||||
|
@ -40,6 +44,9 @@ occamPasses = makePassesDep' ((== FrontendOccam) . csFrontend)
|
||||||
, ("Check mandatory constants", checkConstants,
|
, ("Check mandatory constants", checkConstants,
|
||||||
[Prop.constantsFolded],
|
[Prop.constantsFolded],
|
||||||
[Prop.constantsChecked])
|
[Prop.constantsChecked])
|
||||||
|
, ("Check retyping", checkRetypes,
|
||||||
|
[],
|
||||||
|
[Prop.retypesChecked])
|
||||||
, ("Dummy occam pass", dummyOccamPass,
|
, ("Dummy occam pass", dummyOccamPass,
|
||||||
[],
|
[],
|
||||||
Prop.agg_namesDone ++ [Prop.expressionTypesChecked,
|
Prop.agg_namesDone ++ [Prop.expressionTypesChecked,
|
||||||
|
@ -97,6 +104,49 @@ checkConstants = doGeneric `extM` doDimension `extM` doOption
|
||||||
doGeneric o
|
doGeneric o
|
||||||
doOption o = doGeneric o
|
doOption o = doGeneric o
|
||||||
|
|
||||||
|
-- | Check that retyping is safe.
|
||||||
|
checkRetypes :: Data t => t -> PassM t
|
||||||
|
checkRetypes = everywhereASTM doSpecType
|
||||||
|
where
|
||||||
|
doSpecType :: A.SpecType -> PassM A.SpecType
|
||||||
|
doSpecType st@(A.Retypes m _ t v)
|
||||||
|
= do fromT <- typeOfVariable v
|
||||||
|
checkRetypes m fromT t
|
||||||
|
return st
|
||||||
|
doSpecType st@(A.RetypesExpr m _ t e)
|
||||||
|
= do fromT <- typeOfExpression e
|
||||||
|
checkRetypes m fromT t
|
||||||
|
return st
|
||||||
|
doSpecType st = return st
|
||||||
|
|
||||||
|
checkRetypes :: Meta -> A.Type -> A.Type -> PassM ()
|
||||||
|
checkRetypes m fromT toT
|
||||||
|
= do (fromBI, fromN) <- evalBytesInType fromT
|
||||||
|
(toBI, toN) <- evalBytesInType toT
|
||||||
|
case (fromBI, toBI, fromN, toN) of
|
||||||
|
(_, BIManyFree, _, _) ->
|
||||||
|
dieP m "Multiple free dimensions in retype destination type"
|
||||||
|
(BIJust _, BIJust _, Just a, Just b) ->
|
||||||
|
when (a /= b) $
|
||||||
|
dieP m "Sizes do not match in retype"
|
||||||
|
(BIJust _, BIOneFree _ _, Just a, Just b) ->
|
||||||
|
when (not ((b <= a) && (a `mod` b == 0))) $
|
||||||
|
dieP m "Sizes do not match in retype"
|
||||||
|
(BIOneFree _ _, BIJust _, Just a, Just b) ->
|
||||||
|
when (not ((a <= b) && (b `mod` a == 0))) $
|
||||||
|
dieP m "Sizes do not match in retype"
|
||||||
|
-- Otherwise we must do a runtime check.
|
||||||
|
_ -> return ()
|
||||||
|
|
||||||
|
evalBytesInType :: A.Type -> PassM (BytesInResult, Maybe Int)
|
||||||
|
evalBytesInType t
|
||||||
|
= do bi <- bytesInType t
|
||||||
|
n <- case bi of
|
||||||
|
BIJust e -> maybeEvalIntExpression e
|
||||||
|
BIOneFree e _ -> maybeEvalIntExpression e
|
||||||
|
_ -> return Nothing
|
||||||
|
return (bi, n)
|
||||||
|
|
||||||
-- | A dummy pass for things that haven't been separated out into passes yet.
|
-- | A dummy pass for things that haven't been separated out into passes yet.
|
||||||
dummyOccamPass :: Data t => t -> PassM t
|
dummyOccamPass :: Data t => t -> PassM t
|
||||||
dummyOccamPass = return
|
dummyOccamPass = return
|
||||||
|
|
|
@ -35,6 +35,17 @@ import TestUtils
|
||||||
m :: Meta
|
m :: Meta
|
||||||
m = emptyMeta
|
m = emptyMeta
|
||||||
|
|
||||||
|
-- | Initial state for the tests.
|
||||||
|
startState :: State CompState ()
|
||||||
|
startState
|
||||||
|
= do defineConst "const" A.Int (intLiteral 2)
|
||||||
|
defineConst "someInt" A.Int (intLiteral 42)
|
||||||
|
defineConst "someByte" A.Byte (byteLiteral 24)
|
||||||
|
defineConst "someInts" (A.Array [A.UnknownDimension] A.Int)
|
||||||
|
undefined
|
||||||
|
defineConst "someBytes" (A.Array [A.UnknownDimension] A.Byte)
|
||||||
|
undefined
|
||||||
|
|
||||||
-- | Test 'OccamPasses.foldConstants'.
|
-- | Test 'OccamPasses.foldConstants'.
|
||||||
testFoldConstants :: Test
|
testFoldConstants :: Test
|
||||||
testFoldConstants = TestList
|
testFoldConstants = TestList
|
||||||
|
@ -81,21 +92,6 @@ testFoldConstants = TestList
|
||||||
exp (OccamPasses.foldConstants orig)
|
exp (OccamPasses.foldConstants orig)
|
||||||
startState
|
startState
|
||||||
|
|
||||||
startState :: State CompState ()
|
|
||||||
startState = defineConst "const" A.Int two
|
|
||||||
|
|
||||||
defineConst :: String -> A.Type -> A.Expression -> State CompState ()
|
|
||||||
defineConst s t e = defineName (simpleName s) $
|
|
||||||
A.NameDef {
|
|
||||||
A.ndMeta = m,
|
|
||||||
A.ndName = "const",
|
|
||||||
A.ndOrigName = "const",
|
|
||||||
A.ndNameType = A.VariableName,
|
|
||||||
A.ndType = A.IsExpr m A.ValAbbrev t e,
|
|
||||||
A.ndAbbrevMode = A.ValAbbrev,
|
|
||||||
A.ndPlacement = A.Unplaced
|
|
||||||
}
|
|
||||||
|
|
||||||
testSame :: Int -> A.Expression -> Test
|
testSame :: Int -> A.Expression -> Test
|
||||||
testSame n orig = test n orig orig
|
testSame n orig = test n orig orig
|
||||||
|
|
||||||
|
@ -158,8 +154,64 @@ testCheckConstants = TestList
|
||||||
var = exprVariable "var"
|
var = exprVariable "var"
|
||||||
skip = A.Skip m
|
skip = A.Skip m
|
||||||
|
|
||||||
|
-- | Test 'OccamPasses.checkRetypes'.
|
||||||
|
testCheckRetypes :: Test
|
||||||
|
testCheckRetypes = TestList
|
||||||
|
[
|
||||||
|
-- Definitely OK at compile time
|
||||||
|
testOK 0 $ retypesV A.Int intV
|
||||||
|
, testOK 1 $ retypesE A.Int intE
|
||||||
|
, testOK 2 $ retypesV A.Byte byteV
|
||||||
|
, testOK 3 $ retypesE A.Byte byteE
|
||||||
|
, testOK 4 $ retypesV known1 intV
|
||||||
|
, testOK 5 $ retypesV known2 intV
|
||||||
|
, testOK 6 $ retypesV both intV
|
||||||
|
, testOK 7 $ retypesV unknown1 intV
|
||||||
|
|
||||||
|
-- Definitely wrong at compile time
|
||||||
|
, testFail 100 $ retypesV A.Byte intV
|
||||||
|
, testFail 101 $ retypesV A.Int byteV
|
||||||
|
, testFail 102 $ retypesV unknown2 intV
|
||||||
|
, testFail 103 $ retypesV unknown2 intsV
|
||||||
|
, testFail 104 $ retypesV A.Byte intsV
|
||||||
|
|
||||||
|
-- Can't tell; need a runtime check
|
||||||
|
, testOK 200 $ retypesV unknown1 intsV
|
||||||
|
, testOK 201 $ retypesV A.Int intsV
|
||||||
|
, testOK 202 $ retypesV known2 intsV
|
||||||
|
, testOK 203 $ retypesV unknown1 bytesV
|
||||||
|
]
|
||||||
|
where
|
||||||
|
testOK :: (Show a, Data a) => Int -> a -> Test
|
||||||
|
testOK n orig
|
||||||
|
= TestCase $ testPass ("testCheckRetypes" ++ show n)
|
||||||
|
orig (OccamPasses.checkRetypes orig)
|
||||||
|
startState
|
||||||
|
|
||||||
|
testFail :: (Show a, Data a) => Int -> a -> Test
|
||||||
|
testFail n orig
|
||||||
|
= TestCase $ testPassShouldFail ("testCheckRetypes" ++ show n)
|
||||||
|
(OccamPasses.checkRetypes orig)
|
||||||
|
startState
|
||||||
|
|
||||||
|
retypesV = A.Retypes m A.ValAbbrev
|
||||||
|
retypesE = A.RetypesExpr m A.ValAbbrev
|
||||||
|
|
||||||
|
intV = variable "someInt"
|
||||||
|
intE = intLiteral 42
|
||||||
|
byteV = variable "someByte"
|
||||||
|
byteE = byteLiteral 42
|
||||||
|
intsV = variable "someInts"
|
||||||
|
bytesV = variable "someBytes"
|
||||||
|
known1 = A.Array [dimension 4] A.Byte
|
||||||
|
known2 = A.Array [dimension 2, dimension 2] A.Byte
|
||||||
|
both = A.Array [dimension 2, A.UnknownDimension] A.Byte
|
||||||
|
unknown1 = A.Array [A.UnknownDimension] A.Int
|
||||||
|
unknown2 = A.Array [A.UnknownDimension, A.UnknownDimension] A.Int
|
||||||
|
|
||||||
tests :: Test
|
tests :: Test
|
||||||
tests = TestLabel "OccamPassesTest" $ TestList
|
tests = TestLabel "OccamPassesTest" $ TestList
|
||||||
[ testFoldConstants
|
[ testFoldConstants
|
||||||
, testCheckConstants
|
, testCheckConstants
|
||||||
|
, testCheckRetypes
|
||||||
]
|
]
|
||||||
|
|
|
@ -1428,7 +1428,6 @@ retypesAbbrev
|
||||||
sColon
|
sColon
|
||||||
eol
|
eol
|
||||||
origT <- typeOfVariable v
|
origT <- typeOfVariable v
|
||||||
--checkRetypes m origT s
|
|
||||||
return $ A.Specification m n $ A.Retypes m A.Abbrev s v
|
return $ A.Specification m n $ A.Retypes m A.Abbrev s v
|
||||||
<|> do m <- md
|
<|> do m <- md
|
||||||
(s, n) <- tryVVX channelSpecifier newChannelName retypesReshapes
|
(s, n) <- tryVVX channelSpecifier newChannelName retypesReshapes
|
||||||
|
@ -1436,7 +1435,6 @@ retypesAbbrev
|
||||||
sColon
|
sColon
|
||||||
eol
|
eol
|
||||||
origT <- typeOfVariable c
|
origT <- typeOfVariable c
|
||||||
--checkRetypes m origT s
|
|
||||||
return $ A.Specification m n $ A.Retypes m A.Abbrev s c
|
return $ A.Specification m n $ A.Retypes m A.Abbrev s c
|
||||||
<|> do m <- md
|
<|> do m <- md
|
||||||
(s, n) <- tryXVVX sVAL dataSpecifier newVariableName retypesReshapes
|
(s, n) <- tryXVVX sVAL dataSpecifier newVariableName retypesReshapes
|
||||||
|
@ -1444,29 +1442,9 @@ retypesAbbrev
|
||||||
sColon
|
sColon
|
||||||
eol
|
eol
|
||||||
origT <- typeOfExpression e
|
origT <- typeOfExpression e
|
||||||
--checkRetypes m origT s
|
|
||||||
return $ A.Specification m n $ A.RetypesExpr m A.ValAbbrev s e
|
return $ A.Specification m n $ A.RetypesExpr m A.ValAbbrev s e
|
||||||
<?> "RETYPES/RESHAPES abbreviation"
|
<?> "RETYPES/RESHAPES abbreviation"
|
||||||
|
|
||||||
{-
|
|
||||||
-- | Check that a RETYPES\/RESHAPES is safe.
|
|
||||||
checkRetypes :: Meta -> A.Type -> A.Type -> OccParser ()
|
|
||||||
-- Retyping channels is always "safe".
|
|
||||||
checkRetypes _ (A.Chan {}) (A.Chan {}) = return ()
|
|
||||||
checkRetypes m fromT toT
|
|
||||||
= do bf <- bytesInType fromT
|
|
||||||
bt <- bytesInType toT
|
|
||||||
case (bf, bt) of
|
|
||||||
(BIJust a, BIJust b) ->
|
|
||||||
when (a /= b) $ dieP m "size mismatch in RETYPES"
|
|
||||||
(BIJust a, BIOneFree b _) ->
|
|
||||||
when (not ((b <= a) && (a `mod` b == 0))) $ dieP m "size mismatch in RETYPES"
|
|
||||||
(_, BIManyFree) ->
|
|
||||||
dieP m "multiple free dimensions in RETYPES/RESHAPES type"
|
|
||||||
-- Otherwise we have to do a runtime check.
|
|
||||||
_ -> return ()
|
|
||||||
-}
|
|
||||||
|
|
||||||
dataSpecifier :: OccParser A.Type
|
dataSpecifier :: OccParser A.Type
|
||||||
dataSpecifier
|
dataSpecifier
|
||||||
= dataType
|
= dataType
|
||||||
|
|
|
@ -50,6 +50,7 @@ module Properties
|
||||||
, processTypesChecked
|
, processTypesChecked
|
||||||
, rainParDeclarationsPulledUp
|
, rainParDeclarationsPulledUp
|
||||||
, rangeTransformed
|
, rangeTransformed
|
||||||
|
, retypesChecked
|
||||||
, seqInputsFlattened
|
, seqInputsFlattened
|
||||||
, slicesSimplified
|
, slicesSimplified
|
||||||
, subscriptsPulledUp
|
, subscriptsPulledUp
|
||||||
|
@ -78,13 +79,30 @@ import Types
|
||||||
import Utils
|
import Utils
|
||||||
|
|
||||||
agg_namesDone :: [Property]
|
agg_namesDone :: [Property]
|
||||||
agg_namesDone = [declarationsUnique, declarationTypesRecorded, inferredTypesRecorded, declaredNamesResolved]
|
agg_namesDone =
|
||||||
|
[ declarationTypesRecorded
|
||||||
|
, declarationsUnique
|
||||||
|
, declaredNamesResolved
|
||||||
|
, inferredTypesRecorded
|
||||||
|
]
|
||||||
|
|
||||||
agg_typesDone :: [Property]
|
agg_typesDone :: [Property]
|
||||||
agg_typesDone = [expressionTypesChecked, inferredTypesRecorded, processTypesChecked, typesResolvedInAST, typesResolvedInState, constantsFolded, constantsChecked]
|
agg_typesDone =
|
||||||
|
[ constantsChecked
|
||||||
|
, constantsFolded
|
||||||
|
, expressionTypesChecked
|
||||||
|
, inferredTypesRecorded
|
||||||
|
, processTypesChecked
|
||||||
|
, retypesChecked
|
||||||
|
, typesResolvedInAST
|
||||||
|
, typesResolvedInState
|
||||||
|
]
|
||||||
|
|
||||||
agg_functionsGone :: [Property]
|
agg_functionsGone :: [Property]
|
||||||
agg_functionsGone = [functionCallsRemoved, functionsRemoved]
|
agg_functionsGone =
|
||||||
|
[ functionCallsRemoved
|
||||||
|
, functionsRemoved
|
||||||
|
]
|
||||||
|
|
||||||
-- Mark out all the checks I still need to implement:
|
-- Mark out all the checks I still need to implement:
|
||||||
checkTODO :: Monad m => A.AST -> m ()
|
checkTODO :: Monad m => A.AST -> m ()
|
||||||
|
@ -142,10 +160,13 @@ declarationsUnique = Property "declarationsUnique" $
|
||||||
checkDupes (n':ns)
|
checkDupes (n':ns)
|
||||||
|
|
||||||
constantsChecked :: Property
|
constantsChecked :: Property
|
||||||
constantsChecked = Property "constantsChecked" checkTODO
|
constantsChecked = Property "constantsChecked" nocheck
|
||||||
|
|
||||||
constantsFolded :: Property
|
constantsFolded :: Property
|
||||||
constantsFolded = Property "constantsFolded" checkTODO
|
constantsFolded = Property "constantsFolded" nocheck
|
||||||
|
|
||||||
|
retypesChecked :: Property
|
||||||
|
retypesChecked = Property "retypesChecked" nocheck
|
||||||
|
|
||||||
intLiteralsInBounds :: Property
|
intLiteralsInBounds :: Property
|
||||||
intLiteralsInBounds = Property "intLiteralsInBounds" $
|
intLiteralsInBounds = Property "intLiteralsInBounds" $
|
||||||
|
|
Loading…
Reference in New Issue
Block a user