diff --git a/data/CompState.hs b/data/CompState.hs index 945ae6d..d057bc4 100644 --- a/data/CompState.hs +++ b/data/CompState.hs @@ -160,8 +160,9 @@ data CompState = CompState { csAdditionalArgs :: Map String [A.Actual], csParProcs :: Map A.Name ParOrFork, csUnifyId :: Int, - -- The string is the operator, the name is the munged function name - csOperators :: [(String, A.Name, [A.Type])], + -- The string is the operator, the name is the munged function name, the single + -- type is the return type + csOperators :: [(String, A.Name, A.Type, [A.Type])], csWarnings :: [WarningReport] } deriving (Data, Typeable, Show) diff --git a/frontends/OccamInferTypes.hs b/frontends/OccamInferTypes.hs index 3a319b5..4bad805 100644 --- a/frontends/OccamInferTypes.hs +++ b/frontends/OccamInferTypes.hs @@ -248,41 +248,94 @@ inferTypes = occamOnlyPass "Infer types" 2 -> "binary" n -> show n ++ "-ary" - es' <- noTypeContext $ mapM recurse es - tes <- sequence [underlyingTypeOf m e `catchError` (const $ return A.Infer) | e <- es'] - cs <- getCompState resolvedOps <- sequence [ do ts' <- mapM (underlyingType m) ts - return (op, n, ts') - | (op, n, ts) <- csOperators cs + rt' <- underlyingType m rt + return (op, n, rt', ts') + | (op, n, rt, ts) <- csOperators cs ] - - -- The nubBy will ensure that only one definition remains for each - -- set of type-arguments, and will keep the first definition in the - -- list (which will be the most recent) - possibles <- return - [ ((opFuncName, es'), ts) - | (raw, opFuncName, ts) <- nubBy opsMatch resolvedOps - -- Must be right operator: - , raw == A.nameName n - -- Must be right arity: - , length ts == length es - -- Must have right types: - , ts `typesEqForOp` tes - ] - case possibles of - [] -> diePC m $ formatCode "No matching % operator definition found for types: %" opDescrip tes - [poss] -> return $ fst poss + -- The nubBy will ensure that only one definition remains for each + -- set of type-arguments, and will keep the first definition in the + -- list (which will be the most recent) + >>* nubBy opsMatch + + ctx <- (getTypeContext >>* fromMaybe A.Infer) >>= underlyingType m + let findOperatorTs :: [Maybe A.Type] -> [(A.Name, [A.Type])] + findOperatorTs mts + = [ (opFuncName, ts) + | (raw, opFuncName, rt, ts) <- resolvedOps + -- Must be right operator: + , raw == A.nameName n + -- Must be right arity: + , length ts == length mts + -- Must have right argument types: + , and $ zipWith (maybe True) (map typeEqForOp ts) mts + -- Must have right return type: + , ctx == A.Infer || ctx == rt + ] + + pickTypes :: [Maybe A.Type] -> [Maybe A.Type] + pickTypes mts = case findOperatorTs mts of + -- Exactly one match; use it: + [(_, ts)] -> map Just ts + -- No match, or multiple matches, no change: + _ -> mts + + -- We have to catch errors here because if it's an operator inside + -- an operator, the type-getting will fail because we haven't resolved + -- the inner operator yet. + origTs <- sequence [astTypeOf e `catchError` (const $ return A.Infer) | e <- es] + + es' <- mapM (uncurry inTypeContext) + $ zip (pickTypes [case t of + A.Infer -> Nothing + _ -> Just t + | t <- origTs]) + (map recurse es) +{- + es' <- if all (== A.Infer) origTs + && length (findOperatorTs [Nothing | _t <- origTs]) /= 1 + -- No operand has a definite type, and we can't determine it from + -- the return type, so we'll just use what recursing + -- with no type context brings back: + then noTypeContext $ mapM recurse es + -- We do have at least one definite (non-Infer) type, or there + -- is only one operator to pick: + else case origTs of + -- For binary operators, we'll use the definite type to find + -- an operator and if there's only one possibility, we'll use + -- that as the type contexts: + [A.Infer, t] -> mapM (uncurry inTypeContext) + $ zip (pickTypes [Nothing, Just t]) (map recurse es) + [t, A.Infer] -> mapM (uncurry inTypeContext) + $ zip (pickTypes [Just t, Nothing]) (map recurse es) + -- For unary operators, we have a definite type, so don't do + -- anything different; + [_] -> noTypeContext $ mapM recurse es + -- If we have two definite types for binary operators, we again + -- don't do anything different + [_,_] -> noTypeContext $ mapM recurse es + -- Anything else must be some different arity operator: + _ -> dieP m "Operator with strange arity" +-} + tes <- sequence [underlyingTypeOf m e `catchError` (const $ return A.Infer) | e <- es'] + + case findOperatorTs $ map Just tes of + [] -> case ctx of + A.Infer -> diePC m $ formatCode "No matching % operator definition found for types: %" opDescrip tes + _ -> diePC m $ formatCode ("No matching % operator definition found for types: %" + ++ " that returns type: % (original types: %)") opDescrip tes ctx origTs + [poss] -> return (fst poss, es') posss -> dieP m $ "Ambigious " ++ opDescrip ++ " operator, matches definitions: " - ++ show (map (transformPair (A.nameMeta . fst) showOccam) posss) + ++ show (map (transformPair A.nameMeta showOccam) posss) else do (_, fs) <- checkFunction m n doActuals m n fs (direct, const return) es >>* (,) n where direct = error "Cannot direct channels passed to FUNCTIONs" - opsMatch (opA, _, tsA) (opB, _, tsB) = (opA == opB) && (tsA `typesEqForOp` tsB) + opsMatch (opA, _, _, tsA) (opB, _, _, tsB) = (opA == opB) && (tsA `typesEqForOp` tsB) typesEqForOp :: [A.Type] -> [A.Type] -> Bool typesEqForOp tsA tsB = (length tsA == length tsB) && (and $ zipWith typeEqForOp tsA tsB) @@ -500,8 +553,11 @@ inferTypes = occamOnlyPass "Infer types" case mOp of Just raw -> do ts <- mapM astTypeOf fs + rt <- case ts' of + [t] -> return t + _ -> dieP m "Operator must have exactly one return" let before, after :: PassM () - before = modify $ \cs -> cs { csOperators = (raw, n, ts) : csOperators cs } + before = modify $ \cs -> cs { csOperators = (raw, n, rt, ts) : csOperators cs } after = modify $ \cs -> cs { csOperators = tail (csOperators cs)} return (func ,\m -> do before @@ -509,9 +565,11 @@ inferTypes = occamOnlyPass "Infer types" after return x) _ -> return func >>* addId - A.Retypes m am t v -> lift $ inTypeContext (Just t) $ - (recurse v >>= derefVariableIfNeeded (Just t)) >>* - (addId . A.Retypes m am t) + A.Retypes m am t v -> + do t' <- lift $ recurse t + lift $ inTypeContext (Just t') $ + (recurse v >>= derefVariableIfNeeded (Just t')) >>* + (addId . A.Retypes m am t') A.RetypesExpr _ _ _ _ -> lift $ noTypeContext $ descend st >>* addId -- For PROCs that take any channels without direction, -- we must determine if we can infer a specific direction diff --git a/frontends/ParseOccam.hs b/frontends/ParseOccam.hs index 0b0e235..4eeb04e 100644 --- a/frontends/ParseOccam.hs +++ b/frontends/ParseOccam.hs @@ -712,7 +712,7 @@ stringLiteral :: OccParser (A.Type, A.LiteralRepr) stringLiteral = do m <- md cs <- stringCont <|> stringLit - let aes = A.Several m [A.Only m $ A.Literal m' A.Infer c + let aes = A.Several m [A.Only m $ A.Literal m' A.Byte c | c@(A.ByteLiteral m' _) <- cs] return (A.Array [A.UnknownDimension] A.Byte, A.ArrayListLiteral m aes) "string literal"