diff --git a/frontends/OccamTypes.hs b/frontends/OccamTypes.hs index 0e143a2..cf02d45 100644 --- a/frontends/OccamTypes.hs +++ b/frontends/OccamTypes.hs @@ -181,10 +181,11 @@ checkRecordField m t n when (not $ n `elem` validNames) $ diePC m $ formatCode "Invalid field name % in record type %" n t --- | Check that a subscript is being applied to an appropriate type. -checkSubscriptType :: Meta -> A.Subscript -> A.Type -> PassM () -checkSubscriptType m s rawT - = do t <- underlyingType m rawT +-- | Check a subscript. +checkSubscript :: Meta -> A.Subscript -> A.Type -> PassM () +checkSubscript m s rawT + = do -- Check the type of the thing being subscripted. + t <- underlyingType m rawT case s of -- A record subscript. A.SubscriptField m n -> @@ -194,6 +195,15 @@ checkSubscriptType m s rawT -- An array slice. _ -> checkArray m t + -- Check the subscript itself. + case s of + A.Subscript m _ e -> checkExpressionInt m e + A.SubscriptFromFor m e f -> + checkExpressionInt m e >> checkExpressionInt m f + A.SubscriptFrom m e -> checkExpressionInt m e + A.SubscriptFor m e -> checkExpressionInt m e + _ -> ok + -- | Classes of operators. data OpClass = NumericOp | IntegerOp | ShiftOp | BooleanOp | ComparisonOp | ListOp @@ -328,58 +338,20 @@ checkWritable v -- inside the AST, but it doesn't really make sense to split it up. checkTypes :: Data t => t -> PassM t checkTypes t = - checkSubscripts t >>= - checkLiterals >>= - checkVariables >>= + checkVariables t >>= checkExpressions >>= checkInputItems >>= checkOutputItems >>= checkReplicators >>= checkChoices -checkSubscripts :: Data t => t -> PassM t -checkSubscripts = checkDepthM doSubscript - where - doSubscript :: A.Subscript -> PassM () - doSubscript (A.Subscript m _ e) = checkExpressionInt m e - doSubscript (A.SubscriptFromFor m e f) - = checkExpressionInt m e >> checkExpressionInt m f - doSubscript (A.SubscriptFrom m e) = checkExpressionInt m e - doSubscript (A.SubscriptFor m e) = checkExpressionInt m e - doSubscript _ = ok - -checkLiterals :: Data t => t -> PassM t -checkLiterals = checkDepthM doExpression - where - doExpression :: A.Expression -> PassM () - doExpression (A.Literal m t lr) = doLiteralRepr t lr - doExpression _ = ok - - doLiteralRepr :: A.Type -> A.LiteralRepr -> PassM () - doLiteralRepr t (A.ArrayLiteral m aes) - = doArrayElem m t (A.ArrayElemArray aes) - doLiteralRepr t (A.RecordLiteral m es) - = do rfs <- underlyingType m t >>= recordFields m - when (length es /= length rfs) $ - dieP m $ "Record literal has wrong number of fields: found " ++ (show $ length es) ++ ", expected " ++ (show $ length rfs) - sequence_ [checkExpressionType (findMeta fe) ft fe - | ((_, ft), fe) <- zip rfs es] - doLiteralRepr _ _ = ok - - doArrayElem :: Meta -> A.Type -> A.ArrayElem -> PassM () - doArrayElem m t (A.ArrayElemArray aes) - = do checkArraySize m t (length aes) - t' <- subscriptType (A.Subscript m A.NoCheck undefined) t - sequence_ $ map (doArrayElem m t') aes - doArrayElem _ t (A.ArrayElemExpr e) = checkExpressionType (findMeta e) t e - checkVariables :: Data t => t -> PassM t checkVariables = checkDepthM doVariable where doVariable :: A.Variable -> PassM () doVariable (A.SubscriptedVariable m s v) = do t <- typeOfVariable v - checkSubscriptType m s t + checkSubscript m s t doVariable (A.DirectedVariable m _ v) = do t <- typeOfVariable v >>= underlyingType m case t of @@ -410,19 +382,38 @@ checkExpressions = checkDepthM doExpression doExpression (A.Conversion m _ t e) = do et <- typeOfExpression e checkScalar m t >> checkScalar (findMeta e) et + doExpression (A.Literal m t lr) = doLiteralRepr t lr doExpression (A.FunctionCall m n es) = checkFunctionCall m n es True doExpression (A.IntrinsicFunctionCall m s es) = checkIntrinsicFunctionCall m s es True doExpression (A.SubscriptedExpr m s e) = do t <- typeOfExpression e - checkSubscriptType m s t + checkSubscript m s t doExpression (A.OffsetOf m rawT n) = do t <- underlyingType m rawT checkRecordField m t n doExpression (A.AllocMobile m t me) = checkAllocMobile m t me doExpression _ = ok + doLiteralRepr :: A.Type -> A.LiteralRepr -> PassM () + doLiteralRepr t (A.ArrayLiteral m aes) + = doArrayElem m t (A.ArrayElemArray aes) + doLiteralRepr t (A.RecordLiteral m es) + = do rfs <- underlyingType m t >>= recordFields m + when (length es /= length rfs) $ + dieP m $ "Record literal has wrong number of fields: found " ++ (show $ length es) ++ ", expected " ++ (show $ length rfs) + sequence_ [checkExpressionType (findMeta fe) ft fe + | ((_, ft), fe) <- zip rfs es] + doLiteralRepr _ _ = ok + + doArrayElem :: Meta -> A.Type -> A.ArrayElem -> PassM () + doArrayElem m t (A.ArrayElemArray aes) + = do checkArraySize m t (length aes) + t' <- subscriptType (A.Subscript m A.NoCheck undefined) t + sequence_ $ map (doArrayElem m t') aes + doArrayElem _ t (A.ArrayElemExpr e) = checkExpressionType (findMeta e) t e + checkInputItems :: Data t => t -> PassM t checkInputItems = checkDepthM doInputItem where diff --git a/frontends/OccamTypesTest.hs b/frontends/OccamTypesTest.hs index 7e7df4b..7b7ec1f 100644 --- a/frontends/OccamTypesTest.hs +++ b/frontends/OccamTypesTest.hs @@ -68,14 +68,14 @@ testOccamTypes :: Test testOccamTypes = TestList [ -- Subscript expressions - testOK 0 $ A.Subscript m A.NoCheck intE - , testFail 1 $ A.Subscript m A.NoCheck byteE - , testOK 2 $ A.SubscriptFromFor m intE intE - , testFail 3 $ A.SubscriptFromFor m byteE byteE - , testOK 4 $ A.SubscriptFrom m intE - , testFail 5 $ A.SubscriptFrom m byteE - , testOK 6 $ A.SubscriptFor m intE - , testFail 7 $ A.SubscriptFor m byteE + testOK 0 $ subex $ A.Subscript m A.NoCheck intE + , testFail 1 $ subex $ A.Subscript m A.NoCheck byteE + , testOK 2 $ subex $ A.SubscriptFromFor m intE intE + , testFail 3 $ subex $ A.SubscriptFromFor m byteE byteE + , testOK 4 $ subex $ A.SubscriptFrom m intE + , testFail 5 $ subex $ A.SubscriptFrom m byteE + , testOK 6 $ subex $ A.SubscriptFor m intE + , testFail 7 $ subex $ A.SubscriptFor m byteE -- Trivial literals , testOK 20 $ intE @@ -232,6 +232,7 @@ testOccamTypes = TestList (OccamTypes.checkTypes orig) startState + subex sub = A.SubscriptedExpr m sub twoIntsE intV = variable "varInt" intE = intLiteral 42 realV = variable "varReal"