diff --git a/frontends/OccamTypes.hs b/frontends/OccamTypes.hs index 6b41d23..c39414e 100644 --- a/frontends/OccamTypes.hs +++ b/frontends/OccamTypes.hs @@ -172,6 +172,23 @@ checkExpressionInt e = checkExpressionType A.Int e checkExpressionBool :: Check A.Expression checkExpressionBool e = checkExpressionType A.Bool e +-- | Pick the more specific of a pair of types. +betterType :: A.Type -> A.Type -> A.Type +betterType t1 t2 + = case betterType' t1 t2 of + Left () -> t1 + Right () -> t2 + where + betterType' :: A.Type -> A.Type -> Either () () + betterType' A.Infer t = Right () + betterType' t A.Infer = Left () + betterType' t@(A.UserDataType _) _ = Left () + betterType' _ t@(A.UserDataType _) = Right () + betterType' t1@(A.Array ds1 et1) t2@(A.Array ds2 et2) + | length ds1 == length ds2 = betterType' et1 et2 + | length ds1 < length ds2 = Left () + betterType' t _ = Left () + --}}} --{{{ more complex checks @@ -624,25 +641,21 @@ inferTypes = applyExplicitM9 doExpression doDimension doSubscript -- Expressions that aren't literals, but that modify the type -- context. A.Dyadic m op le re -> - case classifyOp op of - -- No info about the LHS; infer the RHS type from the LHS. - ComparisonOp -> - do le' <- noTypeContext $ inferTypes le - t <- typeOfExpression le' - re' <- inTypeContext (Just t) $ inferTypes re - return $ A.Dyadic m op le' re' - -- The RHS type is always A.Int. - ShiftOp -> - do le' <- inferTypes le - re' <- inTypeContext (Just A.Int) $ inferTypes re - return $ A.Dyadic m op le' re' - -- Otherwise infer the LHS from the current context, - -- then the RHS from that. - _ -> - do le' <- inferTypes le - t <- typeOfExpression le' - re' <- inTypeContext (Just t) $ inferTypes re - return $ A.Dyadic m op le' re' + let -- Both types are the same. + bothSame + = do lt <- inferTypes le >>= typeOfExpression + rt <- inferTypes re >>= typeOfExpression + inTypeContext (Just $ betterType lt rt) $ + descend outer + -- The RHS type is always A.Int. + intOnRight + = do le' <- inferTypes le + re' <- inTypeContext (Just A.Int) $ inferTypes re + return $ A.Dyadic m op le' re' + in case classifyOp op of + ComparisonOp -> noTypeContext $ bothSame + ShiftOp -> intOnRight + _ -> bothSame A.SizeExpr _ _ -> noTypeContext $ descend outer A.Conversion _ _ _ _ -> noTypeContext $ descend outer A.FunctionCall m n es -> @@ -910,15 +923,12 @@ inferTypes = applyExplicitM9 doExpression doDimension doSubscript A.ArrayElemArray aes') _ -> diePC m $ formatCode "Table literal is not valid for type %" wantT where - -- | When walking along an array literal, use the type of the - -- first element as the default for the rest. doElems :: A.Type -> [A.ArrayElem] -> PassM (A.Type, [A.ArrayElem]) - doElems t [] = return (t, []) - doElems t (ae:aes) - = do (t', ae') <- doArrayElem t ae - aes' <- sequence [doArrayElem t' ae >>* snd - | ae <- aes] - return (t', ae':aes') + doElems t aes + = do ts <- mapM (\ae -> doArrayElem t ae >>* fst) aes + let bestT = foldl betterType t ts + aes' <- mapM (\ae -> doArrayElem bestT ae >>* snd) aes + return (bestT, aes') -- An expression: descend into it with the right context. doArrayElem wantT (A.ArrayElemExpr e) = do e' <- inTypeContext (Just wantT) $ doExpression descend e