Changed inferTypes to resolve operators to the right definition
This patch follows on from the previous change to the parser. When it spots a function-call, it looks for operators and treats them differently. It keeps a stack of operators in scope (csOperators in CompState), and when an operator is used, it searches the stack (with all old definitions masked out) for operator definitions to resolve to. The way it chooses which operator to use in the presence of overloadings (e.g. + on INT vs + on INT32) is simply to try them all. If one matches, it uses that. If none, or more than one match, it gives an error. This makes the code simple and seems logical, but I'm not totally confident if this is the required behaviour for resolving overloaded operators.
This commit is contained in:
parent
f7e114f2fd
commit
e1c18cc082
|
@ -149,6 +149,8 @@ data CompState = CompState {
|
|||
csAdditionalArgs :: Map String [A.Actual],
|
||||
csParProcs :: Set A.Name,
|
||||
csUnifyId :: Int,
|
||||
-- The string is the operator, the name is the munged function name
|
||||
csOperators :: [(String, A.Name, [A.Type])],
|
||||
csWarnings :: [WarningReport]
|
||||
}
|
||||
deriving (Data, Typeable, Show)
|
||||
|
@ -205,6 +207,7 @@ emptyState = CompState {
|
|||
csAdditionalArgs = Map.empty,
|
||||
csParProcs = Set.empty,
|
||||
csUnifyId = 0,
|
||||
csOperators = [],
|
||||
csWarnings = []
|
||||
}
|
||||
|
||||
|
|
|
@ -19,10 +19,14 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
-- | The occam typechecker.
|
||||
module OccamTypes (inferTypes, checkTypes, addDirections) where
|
||||
|
||||
import Control.Monad.Error
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Function (on)
|
||||
import Data.Generics
|
||||
import Data.List
|
||||
import qualified Data.Map as Map
|
||||
import Data.Maybe
|
||||
|
||||
import qualified AST as A
|
||||
import CompState
|
||||
|
@ -250,74 +254,6 @@ checkSubscript m s rawT
|
|||
A.SubscriptFor m _ e -> checkExpressionInt e
|
||||
_ -> ok
|
||||
|
||||
-- | Classes of operators.
|
||||
data OpClass = NumericOp | IntegerOp | ShiftOp | BooleanOp | ComparisonOp
|
||||
| ListOp
|
||||
|
||||
-- | Figure out the class of a monadic operator.
|
||||
classifyMOp :: A.MonadicOp -> OpClass
|
||||
classifyMOp A.MonadicSubtr = NumericOp
|
||||
classifyMOp A.MonadicMinus = NumericOp
|
||||
classifyMOp A.MonadicBitNot = IntegerOp
|
||||
classifyMOp A.MonadicNot = BooleanOp
|
||||
|
||||
-- | Figure out the class of a dyadic operator.
|
||||
classifyOp :: A.DyadicOp -> OpClass
|
||||
classifyOp A.Add = NumericOp
|
||||
classifyOp A.Subtr = NumericOp
|
||||
classifyOp A.Mul = NumericOp
|
||||
classifyOp A.Div = NumericOp
|
||||
classifyOp A.Rem = NumericOp
|
||||
classifyOp A.Plus = NumericOp
|
||||
classifyOp A.Minus = NumericOp
|
||||
classifyOp A.Times = NumericOp
|
||||
classifyOp A.BitAnd = IntegerOp
|
||||
classifyOp A.BitOr = IntegerOp
|
||||
classifyOp A.BitXor = IntegerOp
|
||||
classifyOp A.LeftShift = ShiftOp
|
||||
classifyOp A.RightShift = ShiftOp
|
||||
classifyOp A.And = BooleanOp
|
||||
classifyOp A.Or = BooleanOp
|
||||
classifyOp A.Eq = ComparisonOp
|
||||
classifyOp A.NotEq = ComparisonOp
|
||||
classifyOp A.Less = ComparisonOp
|
||||
classifyOp A.More = ComparisonOp
|
||||
classifyOp A.LessEq = ComparisonOp
|
||||
classifyOp A.MoreEq = ComparisonOp
|
||||
classifyOp A.After = ComparisonOp
|
||||
classifyOp A.Concat = ListOp
|
||||
|
||||
-- | Check a monadic operator.
|
||||
checkMonadicOp :: A.MonadicOp -> A.Expression -> PassM ()
|
||||
checkMonadicOp op e
|
||||
= do t <- astTypeOf e
|
||||
let m = findMeta e
|
||||
case classifyMOp op of
|
||||
NumericOp -> checkNumeric m t
|
||||
IntegerOp -> checkInteger m t
|
||||
BooleanOp -> checkType m A.Bool t
|
||||
|
||||
-- | Check a dyadic operator.
|
||||
checkDyadicOp :: A.DyadicOp -> A.Expression -> A.Expression -> PassM ()
|
||||
checkDyadicOp op l r
|
||||
= do lt <- astTypeOf l
|
||||
let lm = findMeta l
|
||||
rt <- astTypeOf r
|
||||
let rm = findMeta r
|
||||
case classifyOp op of
|
||||
NumericOp ->
|
||||
checkNumeric lm lt >> checkNumeric rm rt >> checkType rm lt rt
|
||||
IntegerOp ->
|
||||
checkInteger lm lt >> checkInteger rm rt >> checkType rm lt rt
|
||||
ShiftOp ->
|
||||
checkNumeric lm lt >> checkType rm A.Int rt
|
||||
BooleanOp ->
|
||||
checkType lm A.Bool lt >> checkType rm A.Bool rt
|
||||
ComparisonOp ->
|
||||
checkScalar lm lt >> checkScalar rm rt >> checkType rm lt rt
|
||||
ListOp ->
|
||||
checkList lm lt >> checkList rm rt >> checkType rm lt rt
|
||||
|
||||
-- | Check an abbreviation.
|
||||
-- Is the second abbrev mode a valid abbreviation of the first?
|
||||
checkAbbrev :: Meta -> A.AbbrevMode -> A.AbbrevMode -> PassM ()
|
||||
|
@ -367,7 +303,7 @@ checkActual (A.Formal newAM et _) a
|
|||
-- | Check a function exists.
|
||||
checkFunction :: Meta -> A.Name -> PassM ([A.Type], [A.Formal])
|
||||
checkFunction m n
|
||||
= do st <- specTypeOfName n
|
||||
= do st <- lookupNameOrError n (diePC m $ formatCode "Could not find function %" n) >>* A.ndSpecType
|
||||
case st of
|
||||
A.Function _ _ rs fs _ -> return (rs, fs)
|
||||
_ -> diePC m $ formatCode "% is not a function" n
|
||||
|
@ -742,27 +678,11 @@ inferTypes = occamOnlyPass "Infer types"
|
|||
|
||||
-- Expressions that aren't literals, but that modify the type
|
||||
-- context.
|
||||
A.Dyadic m op le re ->
|
||||
let -- Both types are the same.
|
||||
bothSame
|
||||
= do lt <- recurse le >>= astTypeOf
|
||||
rt <- recurse re >>= astTypeOf
|
||||
inTypeContext (Just $ betterType lt rt) $
|
||||
descend outer
|
||||
-- The RHS type is always A.Int.
|
||||
intOnRight
|
||||
= do le' <- recurse le
|
||||
re' <- inTypeContext (Just A.Int) $ recurse re
|
||||
return $ A.Dyadic m op le' re'
|
||||
in scrubMobile $ case classifyOp op of
|
||||
ComparisonOp -> noTypeContext $ bothSame
|
||||
ShiftOp -> intOnRight
|
||||
_ -> bothSame
|
||||
A.SizeExpr _ _ -> noTypeContext $ descend outer
|
||||
A.Conversion _ _ _ _ -> noTypeContext $ descend outer
|
||||
A.FunctionCall m n es ->
|
||||
do es' <- doFunctionCall m n es
|
||||
return $ A.FunctionCall m n es'
|
||||
do (n', es') <- doFunctionCall m (n, es)
|
||||
return $ A.FunctionCall m n' es'
|
||||
A.IntrinsicFunctionCall _ _ _ -> noTypeContext $ descend outer
|
||||
A.SubscriptedExpr m s e ->
|
||||
do ctx <- getTypeContext
|
||||
|
@ -789,19 +709,63 @@ inferTypes = occamOnlyPass "Infer types"
|
|||
-- Other expressions don't modify the type context.
|
||||
_ -> descend outer
|
||||
|
||||
doFunctionCall :: Meta -> A.Name -> Transform [A.Expression]
|
||||
doFunctionCall m n es
|
||||
= do (_, fs) <- checkFunction m n
|
||||
doActuals m n fs (error "Cannot direct channels passed to FUNCTIONs") es
|
||||
doFunctionCall :: Meta -> Transform (A.Name, [A.Expression])
|
||||
doFunctionCall m (n, es) = do
|
||||
if isOperator (A.nameName n)
|
||||
then
|
||||
-- for operators, resolve the function name, based on the type
|
||||
do let opDescrip = "\"" ++ (A.nameName n) ++ "\" "
|
||||
++ case length es of
|
||||
1 -> "unary"
|
||||
2 -> "binary"
|
||||
n -> show n ++ "-ary"
|
||||
|
||||
cs <- getCompState
|
||||
|
||||
-- The nubBy will ensure that only one definition remains for each
|
||||
-- set of type-arguments, and will keep the first definition in the
|
||||
-- list (which will be the most recent)
|
||||
possibles <- sequence
|
||||
[ (do es' <- sequence
|
||||
[do e' <- doActual m direct t e
|
||||
checkActual (A.Formal A.ValAbbrev t (A.Name m "x"))
|
||||
(A.ActualExpression e')
|
||||
return e'
|
||||
| (t, e) <- zip ts es]
|
||||
return $ Right ((opFuncName, es'), ts)
|
||||
) `catchError` (return . Left)
|
||||
| (raw, opFuncName, ts) <- nubBy ((==) `on` (\(op,_,ts) -> (op,ts))) $ csOperators cs
|
||||
-- Must be right operator:
|
||||
, raw == A.nameName n
|
||||
-- Must be right arity:
|
||||
, length ts == length es]
|
||||
case splitEither possibles of
|
||||
-- We want to be helpful and give the user an idea
|
||||
-- of what we thought the types were, but we must
|
||||
-- also be careful not to die while getting the
|
||||
-- types (and thus missing the real error!)
|
||||
(errs,[]) -> do tes <- sequence [astTypeOf e `catchError` (const $ return A.Infer) | e <- es]
|
||||
diePC m $ formatCode ("No matching " ++ opDescrip ++ " operator definition found for types: %"
|
||||
++ " errors were: " ++ show errs) tes
|
||||
(_, [poss]) -> return $ fst poss
|
||||
(_, posss) -> dieP m $ "Ambigious " ++ opDescrip ++ " operator, matches definitions: "
|
||||
++ show (map (transformPair (A.nameMeta . fst) showOccam) posss)
|
||||
else
|
||||
do (_, fs) <- checkFunction m n
|
||||
doActuals m n fs direct es >>* (,) n
|
||||
where
|
||||
direct = error "Cannot direct channels passed to FUNCTIONs"
|
||||
|
||||
doActuals :: Data a => Meta -> A.Name -> [A.Formal] -> (Meta -> A.Direction -> Transform a)
|
||||
-> Transform [a]
|
||||
doActuals m n fs applyDir as
|
||||
= do checkActualCount m n fs as
|
||||
sequence [case t of
|
||||
A.ChanEnd dir _ _ -> recurse a >>= applyDir m dir
|
||||
_ -> inTypeContext (Just t) $ recurse a
|
||||
| (A.Formal _ t _, a) <- zip fs as]
|
||||
sequence [doActual m applyDir t a | (A.Formal _ t _, a) <- zip fs as]
|
||||
|
||||
doActual :: Data a => Meta -> (Meta -> A.Direction -> Transform a) -> A.Type -> Transform a
|
||||
doActual m applyDir (A.ChanEnd dir _ _) a = recurse a >>= applyDir m dir
|
||||
doActual m _ t a = inTypeContext (Just t) $ recurse a
|
||||
|
||||
|
||||
doDimension :: Transform A.Dimension
|
||||
doDimension dim = inTypeContext (Just A.Int) $ descend dim
|
||||
|
@ -813,8 +777,8 @@ inferTypes = occamOnlyPass "Infer types"
|
|||
doExpressionList ts el
|
||||
= case el of
|
||||
A.FunctionCallList m n es ->
|
||||
do es' <- doFunctionCall m n es
|
||||
return $ A.FunctionCallList m n es'
|
||||
do (n', es') <- doFunctionCall m (n, es)
|
||||
return $ A.FunctionCallList m n' es'
|
||||
A.ExpressionList m es ->
|
||||
do es' <- sequence [inTypeContext (Just t) $ recurse e
|
||||
| (t, e) <- zip ts es]
|
||||
|
@ -853,7 +817,16 @@ inferTypes = occamOnlyPass "Infer types"
|
|||
= do st' <- runReaderT (doSpecType n st) body
|
||||
-- Update the definition of each name after we handle it.
|
||||
modifyName n (\nd -> nd { A.ndSpecType = st' })
|
||||
recurse body >>* A.Spec mspec (A.Specification m n st')
|
||||
let doBody = recurse body >>* A.Spec mspec (A.Specification m n st')
|
||||
mOp <- functionOperator n
|
||||
case (st, mOp) of
|
||||
(A.Function _ _ _ fs _, Just raw) -> do
|
||||
ts <- mapM astTypeOf fs
|
||||
modify $ \cs -> cs { csOperators = (raw, n, ts) : csOperators cs }
|
||||
x <- doBody
|
||||
modify $ \cs -> cs { csOperators = tail (csOperators cs)}
|
||||
return x
|
||||
_ -> doBody
|
||||
doStructured s = descend s
|
||||
|
||||
doSpecType :: Data a => A.Name -> A.SpecType -> ReaderT (A.Structured a) PassM A.SpecType
|
||||
|
@ -1290,8 +1263,6 @@ checkExpressions :: PassType
|
|||
checkExpressions = checkDepthM doExpression
|
||||
where
|
||||
doExpression :: Check A.Expression
|
||||
doExpression (A.Monadic _ op e) = checkMonadicOp op e
|
||||
doExpression (A.Dyadic _ op le re) = checkDyadicOp op le re
|
||||
doExpression (A.MostPos m t) = checkNumeric m t
|
||||
doExpression (A.MostNeg m t) = checkNumeric m t
|
||||
doExpression (A.SizeType m t) = checkSequence True m t
|
||||
|
|
Loading…
Reference in New Issue
Block a user