From e1c18cc082e01cf1fabc2951f638aa0e61de322e Mon Sep 17 00:00:00 2001 From: Neil Brown Date: Sun, 5 Apr 2009 23:01:26 +0000 Subject: [PATCH] Changed inferTypes to resolve operators to the right definition This patch follows on from the previous change to the parser. When it spots a function-call, it looks for operators and treats them differently. It keeps a stack of operators in scope (csOperators in CompState), and when an operator is used, it searches the stack (with all old definitions masked out) for operator definitions to resolve to. The way it chooses which operator to use in the presence of overloadings (e.g. + on INT vs + on INT32) is simply to try them all. If one matches, it uses that. If none, or more than one match, it gives an error. This makes the code simple and seems logical, but I'm not totally confident if this is the required behaviour for resolving overloaded operators. --- data/CompState.hs | 3 + frontends/OccamTypes.hs | 171 +++++++++++++++++----------------------- 2 files changed, 74 insertions(+), 100 deletions(-) diff --git a/data/CompState.hs b/data/CompState.hs index 58c27c3..f55a0ac 100644 --- a/data/CompState.hs +++ b/data/CompState.hs @@ -149,6 +149,8 @@ data CompState = CompState { csAdditionalArgs :: Map String [A.Actual], csParProcs :: Set A.Name, csUnifyId :: Int, + -- The string is the operator, the name is the munged function name + csOperators :: [(String, A.Name, [A.Type])], csWarnings :: [WarningReport] } deriving (Data, Typeable, Show) @@ -205,6 +207,7 @@ emptyState = CompState { csAdditionalArgs = Map.empty, csParProcs = Set.empty, csUnifyId = 0, + csOperators = [], csWarnings = [] } diff --git a/frontends/OccamTypes.hs b/frontends/OccamTypes.hs index bf1f724..bfdb505 100644 --- a/frontends/OccamTypes.hs +++ b/frontends/OccamTypes.hs @@ -19,10 +19,14 @@ with this program. If not, see . -- | The occam typechecker. module OccamTypes (inferTypes, checkTypes, addDirections) where +import Control.Monad.Error import Control.Monad.Reader import Control.Monad.State +import Data.Function (on) import Data.Generics import Data.List +import qualified Data.Map as Map +import Data.Maybe import qualified AST as A import CompState @@ -250,74 +254,6 @@ checkSubscript m s rawT A.SubscriptFor m _ e -> checkExpressionInt e _ -> ok --- | Classes of operators. -data OpClass = NumericOp | IntegerOp | ShiftOp | BooleanOp | ComparisonOp - | ListOp - --- | Figure out the class of a monadic operator. -classifyMOp :: A.MonadicOp -> OpClass -classifyMOp A.MonadicSubtr = NumericOp -classifyMOp A.MonadicMinus = NumericOp -classifyMOp A.MonadicBitNot = IntegerOp -classifyMOp A.MonadicNot = BooleanOp - --- | Figure out the class of a dyadic operator. -classifyOp :: A.DyadicOp -> OpClass -classifyOp A.Add = NumericOp -classifyOp A.Subtr = NumericOp -classifyOp A.Mul = NumericOp -classifyOp A.Div = NumericOp -classifyOp A.Rem = NumericOp -classifyOp A.Plus = NumericOp -classifyOp A.Minus = NumericOp -classifyOp A.Times = NumericOp -classifyOp A.BitAnd = IntegerOp -classifyOp A.BitOr = IntegerOp -classifyOp A.BitXor = IntegerOp -classifyOp A.LeftShift = ShiftOp -classifyOp A.RightShift = ShiftOp -classifyOp A.And = BooleanOp -classifyOp A.Or = BooleanOp -classifyOp A.Eq = ComparisonOp -classifyOp A.NotEq = ComparisonOp -classifyOp A.Less = ComparisonOp -classifyOp A.More = ComparisonOp -classifyOp A.LessEq = ComparisonOp -classifyOp A.MoreEq = ComparisonOp -classifyOp A.After = ComparisonOp -classifyOp A.Concat = ListOp - --- | Check a monadic operator. -checkMonadicOp :: A.MonadicOp -> A.Expression -> PassM () -checkMonadicOp op e - = do t <- astTypeOf e - let m = findMeta e - case classifyMOp op of - NumericOp -> checkNumeric m t - IntegerOp -> checkInteger m t - BooleanOp -> checkType m A.Bool t - --- | Check a dyadic operator. -checkDyadicOp :: A.DyadicOp -> A.Expression -> A.Expression -> PassM () -checkDyadicOp op l r - = do lt <- astTypeOf l - let lm = findMeta l - rt <- astTypeOf r - let rm = findMeta r - case classifyOp op of - NumericOp -> - checkNumeric lm lt >> checkNumeric rm rt >> checkType rm lt rt - IntegerOp -> - checkInteger lm lt >> checkInteger rm rt >> checkType rm lt rt - ShiftOp -> - checkNumeric lm lt >> checkType rm A.Int rt - BooleanOp -> - checkType lm A.Bool lt >> checkType rm A.Bool rt - ComparisonOp -> - checkScalar lm lt >> checkScalar rm rt >> checkType rm lt rt - ListOp -> - checkList lm lt >> checkList rm rt >> checkType rm lt rt - -- | Check an abbreviation. -- Is the second abbrev mode a valid abbreviation of the first? checkAbbrev :: Meta -> A.AbbrevMode -> A.AbbrevMode -> PassM () @@ -367,7 +303,7 @@ checkActual (A.Formal newAM et _) a -- | Check a function exists. checkFunction :: Meta -> A.Name -> PassM ([A.Type], [A.Formal]) checkFunction m n - = do st <- specTypeOfName n + = do st <- lookupNameOrError n (diePC m $ formatCode "Could not find function %" n) >>* A.ndSpecType case st of A.Function _ _ rs fs _ -> return (rs, fs) _ -> diePC m $ formatCode "% is not a function" n @@ -742,27 +678,11 @@ inferTypes = occamOnlyPass "Infer types" -- Expressions that aren't literals, but that modify the type -- context. - A.Dyadic m op le re -> - let -- Both types are the same. - bothSame - = do lt <- recurse le >>= astTypeOf - rt <- recurse re >>= astTypeOf - inTypeContext (Just $ betterType lt rt) $ - descend outer - -- The RHS type is always A.Int. - intOnRight - = do le' <- recurse le - re' <- inTypeContext (Just A.Int) $ recurse re - return $ A.Dyadic m op le' re' - in scrubMobile $ case classifyOp op of - ComparisonOp -> noTypeContext $ bothSame - ShiftOp -> intOnRight - _ -> bothSame A.SizeExpr _ _ -> noTypeContext $ descend outer A.Conversion _ _ _ _ -> noTypeContext $ descend outer A.FunctionCall m n es -> - do es' <- doFunctionCall m n es - return $ A.FunctionCall m n es' + do (n', es') <- doFunctionCall m (n, es) + return $ A.FunctionCall m n' es' A.IntrinsicFunctionCall _ _ _ -> noTypeContext $ descend outer A.SubscriptedExpr m s e -> do ctx <- getTypeContext @@ -789,19 +709,63 @@ inferTypes = occamOnlyPass "Infer types" -- Other expressions don't modify the type context. _ -> descend outer - doFunctionCall :: Meta -> A.Name -> Transform [A.Expression] - doFunctionCall m n es - = do (_, fs) <- checkFunction m n - doActuals m n fs (error "Cannot direct channels passed to FUNCTIONs") es + doFunctionCall :: Meta -> Transform (A.Name, [A.Expression]) + doFunctionCall m (n, es) = do + if isOperator (A.nameName n) + then + -- for operators, resolve the function name, based on the type + do let opDescrip = "\"" ++ (A.nameName n) ++ "\" " + ++ case length es of + 1 -> "unary" + 2 -> "binary" + n -> show n ++ "-ary" + + cs <- getCompState + + -- The nubBy will ensure that only one definition remains for each + -- set of type-arguments, and will keep the first definition in the + -- list (which will be the most recent) + possibles <- sequence + [ (do es' <- sequence + [do e' <- doActual m direct t e + checkActual (A.Formal A.ValAbbrev t (A.Name m "x")) + (A.ActualExpression e') + return e' + | (t, e) <- zip ts es] + return $ Right ((opFuncName, es'), ts) + ) `catchError` (return . Left) + | (raw, opFuncName, ts) <- nubBy ((==) `on` (\(op,_,ts) -> (op,ts))) $ csOperators cs + -- Must be right operator: + , raw == A.nameName n + -- Must be right arity: + , length ts == length es] + case splitEither possibles of + -- We want to be helpful and give the user an idea + -- of what we thought the types were, but we must + -- also be careful not to die while getting the + -- types (and thus missing the real error!) + (errs,[]) -> do tes <- sequence [astTypeOf e `catchError` (const $ return A.Infer) | e <- es] + diePC m $ formatCode ("No matching " ++ opDescrip ++ " operator definition found for types: %" + ++ " errors were: " ++ show errs) tes + (_, [poss]) -> return $ fst poss + (_, posss) -> dieP m $ "Ambigious " ++ opDescrip ++ " operator, matches definitions: " + ++ show (map (transformPair (A.nameMeta . fst) showOccam) posss) + else + do (_, fs) <- checkFunction m n + doActuals m n fs direct es >>* (,) n + where + direct = error "Cannot direct channels passed to FUNCTIONs" doActuals :: Data a => Meta -> A.Name -> [A.Formal] -> (Meta -> A.Direction -> Transform a) -> Transform [a] doActuals m n fs applyDir as = do checkActualCount m n fs as - sequence [case t of - A.ChanEnd dir _ _ -> recurse a >>= applyDir m dir - _ -> inTypeContext (Just t) $ recurse a - | (A.Formal _ t _, a) <- zip fs as] + sequence [doActual m applyDir t a | (A.Formal _ t _, a) <- zip fs as] + + doActual :: Data a => Meta -> (Meta -> A.Direction -> Transform a) -> A.Type -> Transform a + doActual m applyDir (A.ChanEnd dir _ _) a = recurse a >>= applyDir m dir + doActual m _ t a = inTypeContext (Just t) $ recurse a + doDimension :: Transform A.Dimension doDimension dim = inTypeContext (Just A.Int) $ descend dim @@ -813,8 +777,8 @@ inferTypes = occamOnlyPass "Infer types" doExpressionList ts el = case el of A.FunctionCallList m n es -> - do es' <- doFunctionCall m n es - return $ A.FunctionCallList m n es' + do (n', es') <- doFunctionCall m (n, es) + return $ A.FunctionCallList m n' es' A.ExpressionList m es -> do es' <- sequence [inTypeContext (Just t) $ recurse e | (t, e) <- zip ts es] @@ -853,7 +817,16 @@ inferTypes = occamOnlyPass "Infer types" = do st' <- runReaderT (doSpecType n st) body -- Update the definition of each name after we handle it. modifyName n (\nd -> nd { A.ndSpecType = st' }) - recurse body >>* A.Spec mspec (A.Specification m n st') + let doBody = recurse body >>* A.Spec mspec (A.Specification m n st') + mOp <- functionOperator n + case (st, mOp) of + (A.Function _ _ _ fs _, Just raw) -> do + ts <- mapM astTypeOf fs + modify $ \cs -> cs { csOperators = (raw, n, ts) : csOperators cs } + x <- doBody + modify $ \cs -> cs { csOperators = tail (csOperators cs)} + return x + _ -> doBody doStructured s = descend s doSpecType :: Data a => A.Name -> A.SpecType -> ReaderT (A.Structured a) PassM A.SpecType @@ -1290,8 +1263,6 @@ checkExpressions :: PassType checkExpressions = checkDepthM doExpression where doExpression :: Check A.Expression - doExpression (A.Monadic _ op e) = checkMonadicOp op e - doExpression (A.Dyadic _ op le re) = checkDyadicOp op le re doExpression (A.MostPos m t) = checkNumeric m t doExpression (A.MostNeg m t) = checkNumeric m t doExpression (A.SizeType m t) = checkSequence True m t