Fixed the type inference for operators

The previous system followed Jim Moores' description, which turns out to not be how KRoC does it, and different enough that lots of existing occam code was failing.

The new system makes two changes.  First, if there is only one suitable operator, given the types (taking Infer to be a wild-card) then it uses that operator, and can usually settle any remaining Infers.  Second, it uses the type context to narrow down the operators to only those that have a matching return type.  These changes seem enough to compile a lot of the trickier types.
This commit is contained in:
Neil Brown 2009-04-17 18:01:17 +00:00
parent 4add133d5b
commit 759bc06a53
3 changed files with 90 additions and 31 deletions

View File

@ -160,8 +160,9 @@ data CompState = CompState {
csAdditionalArgs :: Map String [A.Actual],
csParProcs :: Map A.Name ParOrFork,
csUnifyId :: Int,
-- The string is the operator, the name is the munged function name
csOperators :: [(String, A.Name, [A.Type])],
-- The string is the operator, the name is the munged function name, the single
-- type is the return type
csOperators :: [(String, A.Name, A.Type, [A.Type])],
csWarnings :: [WarningReport]
}
deriving (Data, Typeable, Show)

View File

@ -248,41 +248,94 @@ inferTypes = occamOnlyPass "Infer types"
2 -> "binary"
n -> show n ++ "-ary"
es' <- noTypeContext $ mapM recurse es
tes <- sequence [underlyingTypeOf m e `catchError` (const $ return A.Infer) | e <- es']
cs <- getCompState
resolvedOps <- sequence [ do ts' <- mapM (underlyingType m) ts
return (op, n, ts')
| (op, n, ts) <- csOperators cs
rt' <- underlyingType m rt
return (op, n, rt', ts')
| (op, n, rt, ts) <- csOperators cs
]
-- The nubBy will ensure that only one definition remains for each
-- set of type-arguments, and will keep the first definition in the
-- list (which will be the most recent)
possibles <- return
[ ((opFuncName, es'), ts)
| (raw, opFuncName, ts) <- nubBy opsMatch resolvedOps
-- Must be right operator:
, raw == A.nameName n
-- Must be right arity:
, length ts == length es
-- Must have right types:
, ts `typesEqForOp` tes
]
case possibles of
[] -> diePC m $ formatCode "No matching % operator definition found for types: %" opDescrip tes
[poss] -> return $ fst poss
-- The nubBy will ensure that only one definition remains for each
-- set of type-arguments, and will keep the first definition in the
-- list (which will be the most recent)
>>* nubBy opsMatch
ctx <- (getTypeContext >>* fromMaybe A.Infer) >>= underlyingType m
let findOperatorTs :: [Maybe A.Type] -> [(A.Name, [A.Type])]
findOperatorTs mts
= [ (opFuncName, ts)
| (raw, opFuncName, rt, ts) <- resolvedOps
-- Must be right operator:
, raw == A.nameName n
-- Must be right arity:
, length ts == length mts
-- Must have right argument types:
, and $ zipWith (maybe True) (map typeEqForOp ts) mts
-- Must have right return type:
, ctx == A.Infer || ctx == rt
]
pickTypes :: [Maybe A.Type] -> [Maybe A.Type]
pickTypes mts = case findOperatorTs mts of
-- Exactly one match; use it:
[(_, ts)] -> map Just ts
-- No match, or multiple matches, no change:
_ -> mts
-- We have to catch errors here because if it's an operator inside
-- an operator, the type-getting will fail because we haven't resolved
-- the inner operator yet.
origTs <- sequence [astTypeOf e `catchError` (const $ return A.Infer) | e <- es]
es' <- mapM (uncurry inTypeContext)
$ zip (pickTypes [case t of
A.Infer -> Nothing
_ -> Just t
| t <- origTs])
(map recurse es)
{-
es' <- if all (== A.Infer) origTs
&& length (findOperatorTs [Nothing | _t <- origTs]) /= 1
-- No operand has a definite type, and we can't determine it from
-- the return type, so we'll just use what recursing
-- with no type context brings back:
then noTypeContext $ mapM recurse es
-- We do have at least one definite (non-Infer) type, or there
-- is only one operator to pick:
else case origTs of
-- For binary operators, we'll use the definite type to find
-- an operator and if there's only one possibility, we'll use
-- that as the type contexts:
[A.Infer, t] -> mapM (uncurry inTypeContext)
$ zip (pickTypes [Nothing, Just t]) (map recurse es)
[t, A.Infer] -> mapM (uncurry inTypeContext)
$ zip (pickTypes [Just t, Nothing]) (map recurse es)
-- For unary operators, we have a definite type, so don't do
-- anything different;
[_] -> noTypeContext $ mapM recurse es
-- If we have two definite types for binary operators, we again
-- don't do anything different
[_,_] -> noTypeContext $ mapM recurse es
-- Anything else must be some different arity operator:
_ -> dieP m "Operator with strange arity"
-}
tes <- sequence [underlyingTypeOf m e `catchError` (const $ return A.Infer) | e <- es']
case findOperatorTs $ map Just tes of
[] -> case ctx of
A.Infer -> diePC m $ formatCode "No matching % operator definition found for types: %" opDescrip tes
_ -> diePC m $ formatCode ("No matching % operator definition found for types: %"
++ " that returns type: % (original types: %)") opDescrip tes ctx origTs
[poss] -> return (fst poss, es')
posss -> dieP m $ "Ambigious " ++ opDescrip ++ " operator, matches definitions: "
++ show (map (transformPair (A.nameMeta . fst) showOccam) posss)
++ show (map (transformPair A.nameMeta showOccam) posss)
else
do (_, fs) <- checkFunction m n
doActuals m n fs (direct, const return) es >>* (,) n
where
direct = error "Cannot direct channels passed to FUNCTIONs"
opsMatch (opA, _, tsA) (opB, _, tsB) = (opA == opB) && (tsA `typesEqForOp` tsB)
opsMatch (opA, _, _, tsA) (opB, _, _, tsB) = (opA == opB) && (tsA `typesEqForOp` tsB)
typesEqForOp :: [A.Type] -> [A.Type] -> Bool
typesEqForOp tsA tsB = (length tsA == length tsB) && (and $ zipWith typeEqForOp tsA tsB)
@ -500,8 +553,11 @@ inferTypes = occamOnlyPass "Infer types"
case mOp of
Just raw -> do
ts <- mapM astTypeOf fs
rt <- case ts' of
[t] -> return t
_ -> dieP m "Operator must have exactly one return"
let before, after :: PassM ()
before = modify $ \cs -> cs { csOperators = (raw, n, ts) : csOperators cs }
before = modify $ \cs -> cs { csOperators = (raw, n, rt, ts) : csOperators cs }
after = modify $ \cs -> cs { csOperators = tail (csOperators cs)}
return (func
,\m -> do before
@ -509,9 +565,11 @@ inferTypes = occamOnlyPass "Infer types"
after
return x)
_ -> return func >>* addId
A.Retypes m am t v -> lift $ inTypeContext (Just t) $
(recurse v >>= derefVariableIfNeeded (Just t)) >>*
(addId . A.Retypes m am t)
A.Retypes m am t v ->
do t' <- lift $ recurse t
lift $ inTypeContext (Just t') $
(recurse v >>= derefVariableIfNeeded (Just t')) >>*
(addId . A.Retypes m am t')
A.RetypesExpr _ _ _ _ -> lift $ noTypeContext $ descend st >>* addId
-- For PROCs that take any channels without direction,
-- we must determine if we can infer a specific direction

View File

@ -712,7 +712,7 @@ stringLiteral :: OccParser (A.Type, A.LiteralRepr)
stringLiteral
= do m <- md
cs <- stringCont <|> stringLit
let aes = A.Several m [A.Only m $ A.Literal m' A.Infer c
let aes = A.Several m [A.Only m $ A.Literal m' A.Byte c
| c@(A.ByteLiteral m' _) <- cs]
return (A.Array [A.UnknownDimension] A.Byte, A.ArrayListLiteral m aes)
<?> "string literal"