From 759bc06a5305980c6bdcf3b02fda6c756a179b0b Mon Sep 17 00:00:00 2001 From: Neil Brown Date: Fri, 17 Apr 2009 18:01:17 +0000 Subject: [PATCH] Fixed the type inference for operators The previous system followed Jim Moores' description, which turns out to not be how KRoC does it, and different enough that lots of existing occam code was failing. The new system makes two changes. First, if there is only one suitable operator, given the types (taking Infer to be a wild-card) then it uses that operator, and can usually settle any remaining Infers. Second, it uses the type context to narrow down the operators to only those that have a matching return type. These changes seem enough to compile a lot of the trickier types. --- data/CompState.hs | 5 +- frontends/OccamInferTypes.hs | 114 ++++++++++++++++++++++++++--------- frontends/ParseOccam.hs | 2 +- 3 files changed, 90 insertions(+), 31 deletions(-) diff --git a/data/CompState.hs b/data/CompState.hs index 945ae6d..d057bc4 100644 --- a/data/CompState.hs +++ b/data/CompState.hs @@ -160,8 +160,9 @@ data CompState = CompState { csAdditionalArgs :: Map String [A.Actual], csParProcs :: Map A.Name ParOrFork, csUnifyId :: Int, - -- The string is the operator, the name is the munged function name - csOperators :: [(String, A.Name, [A.Type])], + -- The string is the operator, the name is the munged function name, the single + -- type is the return type + csOperators :: [(String, A.Name, A.Type, [A.Type])], csWarnings :: [WarningReport] } deriving (Data, Typeable, Show) diff --git a/frontends/OccamInferTypes.hs b/frontends/OccamInferTypes.hs index 3a319b5..4bad805 100644 --- a/frontends/OccamInferTypes.hs +++ b/frontends/OccamInferTypes.hs @@ -248,41 +248,94 @@ inferTypes = occamOnlyPass "Infer types" 2 -> "binary" n -> show n ++ "-ary" - es' <- noTypeContext $ mapM recurse es - tes <- sequence [underlyingTypeOf m e `catchError` (const $ return A.Infer) | e <- es'] - cs <- getCompState resolvedOps <- sequence [ do ts' <- mapM (underlyingType m) ts - return (op, n, ts') - | (op, n, ts) <- csOperators cs + rt' <- underlyingType m rt + return (op, n, rt', ts') + | (op, n, rt, ts) <- csOperators cs ] - - -- The nubBy will ensure that only one definition remains for each - -- set of type-arguments, and will keep the first definition in the - -- list (which will be the most recent) - possibles <- return - [ ((opFuncName, es'), ts) - | (raw, opFuncName, ts) <- nubBy opsMatch resolvedOps - -- Must be right operator: - , raw == A.nameName n - -- Must be right arity: - , length ts == length es - -- Must have right types: - , ts `typesEqForOp` tes - ] - case possibles of - [] -> diePC m $ formatCode "No matching % operator definition found for types: %" opDescrip tes - [poss] -> return $ fst poss + -- The nubBy will ensure that only one definition remains for each + -- set of type-arguments, and will keep the first definition in the + -- list (which will be the most recent) + >>* nubBy opsMatch + + ctx <- (getTypeContext >>* fromMaybe A.Infer) >>= underlyingType m + let findOperatorTs :: [Maybe A.Type] -> [(A.Name, [A.Type])] + findOperatorTs mts + = [ (opFuncName, ts) + | (raw, opFuncName, rt, ts) <- resolvedOps + -- Must be right operator: + , raw == A.nameName n + -- Must be right arity: + , length ts == length mts + -- Must have right argument types: + , and $ zipWith (maybe True) (map typeEqForOp ts) mts + -- Must have right return type: + , ctx == A.Infer || ctx == rt + ] + + pickTypes :: [Maybe A.Type] -> [Maybe A.Type] + pickTypes mts = case findOperatorTs mts of + -- Exactly one match; use it: + [(_, ts)] -> map Just ts + -- No match, or multiple matches, no change: + _ -> mts + + -- We have to catch errors here because if it's an operator inside + -- an operator, the type-getting will fail because we haven't resolved + -- the inner operator yet. + origTs <- sequence [astTypeOf e `catchError` (const $ return A.Infer) | e <- es] + + es' <- mapM (uncurry inTypeContext) + $ zip (pickTypes [case t of + A.Infer -> Nothing + _ -> Just t + | t <- origTs]) + (map recurse es) +{- + es' <- if all (== A.Infer) origTs + && length (findOperatorTs [Nothing | _t <- origTs]) /= 1 + -- No operand has a definite type, and we can't determine it from + -- the return type, so we'll just use what recursing + -- with no type context brings back: + then noTypeContext $ mapM recurse es + -- We do have at least one definite (non-Infer) type, or there + -- is only one operator to pick: + else case origTs of + -- For binary operators, we'll use the definite type to find + -- an operator and if there's only one possibility, we'll use + -- that as the type contexts: + [A.Infer, t] -> mapM (uncurry inTypeContext) + $ zip (pickTypes [Nothing, Just t]) (map recurse es) + [t, A.Infer] -> mapM (uncurry inTypeContext) + $ zip (pickTypes [Just t, Nothing]) (map recurse es) + -- For unary operators, we have a definite type, so don't do + -- anything different; + [_] -> noTypeContext $ mapM recurse es + -- If we have two definite types for binary operators, we again + -- don't do anything different + [_,_] -> noTypeContext $ mapM recurse es + -- Anything else must be some different arity operator: + _ -> dieP m "Operator with strange arity" +-} + tes <- sequence [underlyingTypeOf m e `catchError` (const $ return A.Infer) | e <- es'] + + case findOperatorTs $ map Just tes of + [] -> case ctx of + A.Infer -> diePC m $ formatCode "No matching % operator definition found for types: %" opDescrip tes + _ -> diePC m $ formatCode ("No matching % operator definition found for types: %" + ++ " that returns type: % (original types: %)") opDescrip tes ctx origTs + [poss] -> return (fst poss, es') posss -> dieP m $ "Ambigious " ++ opDescrip ++ " operator, matches definitions: " - ++ show (map (transformPair (A.nameMeta . fst) showOccam) posss) + ++ show (map (transformPair A.nameMeta showOccam) posss) else do (_, fs) <- checkFunction m n doActuals m n fs (direct, const return) es >>* (,) n where direct = error "Cannot direct channels passed to FUNCTIONs" - opsMatch (opA, _, tsA) (opB, _, tsB) = (opA == opB) && (tsA `typesEqForOp` tsB) + opsMatch (opA, _, _, tsA) (opB, _, _, tsB) = (opA == opB) && (tsA `typesEqForOp` tsB) typesEqForOp :: [A.Type] -> [A.Type] -> Bool typesEqForOp tsA tsB = (length tsA == length tsB) && (and $ zipWith typeEqForOp tsA tsB) @@ -500,8 +553,11 @@ inferTypes = occamOnlyPass "Infer types" case mOp of Just raw -> do ts <- mapM astTypeOf fs + rt <- case ts' of + [t] -> return t + _ -> dieP m "Operator must have exactly one return" let before, after :: PassM () - before = modify $ \cs -> cs { csOperators = (raw, n, ts) : csOperators cs } + before = modify $ \cs -> cs { csOperators = (raw, n, rt, ts) : csOperators cs } after = modify $ \cs -> cs { csOperators = tail (csOperators cs)} return (func ,\m -> do before @@ -509,9 +565,11 @@ inferTypes = occamOnlyPass "Infer types" after return x) _ -> return func >>* addId - A.Retypes m am t v -> lift $ inTypeContext (Just t) $ - (recurse v >>= derefVariableIfNeeded (Just t)) >>* - (addId . A.Retypes m am t) + A.Retypes m am t v -> + do t' <- lift $ recurse t + lift $ inTypeContext (Just t') $ + (recurse v >>= derefVariableIfNeeded (Just t')) >>* + (addId . A.Retypes m am t') A.RetypesExpr _ _ _ _ -> lift $ noTypeContext $ descend st >>* addId -- For PROCs that take any channels without direction, -- we must determine if we can infer a specific direction diff --git a/frontends/ParseOccam.hs b/frontends/ParseOccam.hs index 0b0e235..4eeb04e 100644 --- a/frontends/ParseOccam.hs +++ b/frontends/ParseOccam.hs @@ -712,7 +712,7 @@ stringLiteral :: OccParser (A.Type, A.LiteralRepr) stringLiteral = do m <- md cs <- stringCont <|> stringLit - let aes = A.Several m [A.Only m $ A.Literal m' A.Infer c + let aes = A.Several m [A.Only m $ A.Literal m' A.Byte c | c@(A.ByteLiteral m' _) <- cs] return (A.Array [A.UnknownDimension] A.Byte, A.ArrayListLiteral m aes) "string literal"