diff --git a/data/CompState.hs b/data/CompState.hs index 58c27c3..f55a0ac 100644 --- a/data/CompState.hs +++ b/data/CompState.hs @@ -149,6 +149,8 @@ data CompState = CompState { csAdditionalArgs :: Map String [A.Actual], csParProcs :: Set A.Name, csUnifyId :: Int, + -- The string is the operator, the name is the munged function name + csOperators :: [(String, A.Name, [A.Type])], csWarnings :: [WarningReport] } deriving (Data, Typeable, Show) @@ -205,6 +207,7 @@ emptyState = CompState { csAdditionalArgs = Map.empty, csParProcs = Set.empty, csUnifyId = 0, + csOperators = [], csWarnings = [] } diff --git a/frontends/OccamTypes.hs b/frontends/OccamTypes.hs index bf1f724..bfdb505 100644 --- a/frontends/OccamTypes.hs +++ b/frontends/OccamTypes.hs @@ -19,10 +19,14 @@ with this program. If not, see . -- | The occam typechecker. module OccamTypes (inferTypes, checkTypes, addDirections) where +import Control.Monad.Error import Control.Monad.Reader import Control.Monad.State +import Data.Function (on) import Data.Generics import Data.List +import qualified Data.Map as Map +import Data.Maybe import qualified AST as A import CompState @@ -250,74 +254,6 @@ checkSubscript m s rawT A.SubscriptFor m _ e -> checkExpressionInt e _ -> ok --- | Classes of operators. -data OpClass = NumericOp | IntegerOp | ShiftOp | BooleanOp | ComparisonOp - | ListOp - --- | Figure out the class of a monadic operator. -classifyMOp :: A.MonadicOp -> OpClass -classifyMOp A.MonadicSubtr = NumericOp -classifyMOp A.MonadicMinus = NumericOp -classifyMOp A.MonadicBitNot = IntegerOp -classifyMOp A.MonadicNot = BooleanOp - --- | Figure out the class of a dyadic operator. -classifyOp :: A.DyadicOp -> OpClass -classifyOp A.Add = NumericOp -classifyOp A.Subtr = NumericOp -classifyOp A.Mul = NumericOp -classifyOp A.Div = NumericOp -classifyOp A.Rem = NumericOp -classifyOp A.Plus = NumericOp -classifyOp A.Minus = NumericOp -classifyOp A.Times = NumericOp -classifyOp A.BitAnd = IntegerOp -classifyOp A.BitOr = IntegerOp -classifyOp A.BitXor = IntegerOp -classifyOp A.LeftShift = ShiftOp -classifyOp A.RightShift = ShiftOp -classifyOp A.And = BooleanOp -classifyOp A.Or = BooleanOp -classifyOp A.Eq = ComparisonOp -classifyOp A.NotEq = ComparisonOp -classifyOp A.Less = ComparisonOp -classifyOp A.More = ComparisonOp -classifyOp A.LessEq = ComparisonOp -classifyOp A.MoreEq = ComparisonOp -classifyOp A.After = ComparisonOp -classifyOp A.Concat = ListOp - --- | Check a monadic operator. -checkMonadicOp :: A.MonadicOp -> A.Expression -> PassM () -checkMonadicOp op e - = do t <- astTypeOf e - let m = findMeta e - case classifyMOp op of - NumericOp -> checkNumeric m t - IntegerOp -> checkInteger m t - BooleanOp -> checkType m A.Bool t - --- | Check a dyadic operator. -checkDyadicOp :: A.DyadicOp -> A.Expression -> A.Expression -> PassM () -checkDyadicOp op l r - = do lt <- astTypeOf l - let lm = findMeta l - rt <- astTypeOf r - let rm = findMeta r - case classifyOp op of - NumericOp -> - checkNumeric lm lt >> checkNumeric rm rt >> checkType rm lt rt - IntegerOp -> - checkInteger lm lt >> checkInteger rm rt >> checkType rm lt rt - ShiftOp -> - checkNumeric lm lt >> checkType rm A.Int rt - BooleanOp -> - checkType lm A.Bool lt >> checkType rm A.Bool rt - ComparisonOp -> - checkScalar lm lt >> checkScalar rm rt >> checkType rm lt rt - ListOp -> - checkList lm lt >> checkList rm rt >> checkType rm lt rt - -- | Check an abbreviation. -- Is the second abbrev mode a valid abbreviation of the first? checkAbbrev :: Meta -> A.AbbrevMode -> A.AbbrevMode -> PassM () @@ -367,7 +303,7 @@ checkActual (A.Formal newAM et _) a -- | Check a function exists. checkFunction :: Meta -> A.Name -> PassM ([A.Type], [A.Formal]) checkFunction m n - = do st <- specTypeOfName n + = do st <- lookupNameOrError n (diePC m $ formatCode "Could not find function %" n) >>* A.ndSpecType case st of A.Function _ _ rs fs _ -> return (rs, fs) _ -> diePC m $ formatCode "% is not a function" n @@ -742,27 +678,11 @@ inferTypes = occamOnlyPass "Infer types" -- Expressions that aren't literals, but that modify the type -- context. - A.Dyadic m op le re -> - let -- Both types are the same. - bothSame - = do lt <- recurse le >>= astTypeOf - rt <- recurse re >>= astTypeOf - inTypeContext (Just $ betterType lt rt) $ - descend outer - -- The RHS type is always A.Int. - intOnRight - = do le' <- recurse le - re' <- inTypeContext (Just A.Int) $ recurse re - return $ A.Dyadic m op le' re' - in scrubMobile $ case classifyOp op of - ComparisonOp -> noTypeContext $ bothSame - ShiftOp -> intOnRight - _ -> bothSame A.SizeExpr _ _ -> noTypeContext $ descend outer A.Conversion _ _ _ _ -> noTypeContext $ descend outer A.FunctionCall m n es -> - do es' <- doFunctionCall m n es - return $ A.FunctionCall m n es' + do (n', es') <- doFunctionCall m (n, es) + return $ A.FunctionCall m n' es' A.IntrinsicFunctionCall _ _ _ -> noTypeContext $ descend outer A.SubscriptedExpr m s e -> do ctx <- getTypeContext @@ -789,19 +709,63 @@ inferTypes = occamOnlyPass "Infer types" -- Other expressions don't modify the type context. _ -> descend outer - doFunctionCall :: Meta -> A.Name -> Transform [A.Expression] - doFunctionCall m n es - = do (_, fs) <- checkFunction m n - doActuals m n fs (error "Cannot direct channels passed to FUNCTIONs") es + doFunctionCall :: Meta -> Transform (A.Name, [A.Expression]) + doFunctionCall m (n, es) = do + if isOperator (A.nameName n) + then + -- for operators, resolve the function name, based on the type + do let opDescrip = "\"" ++ (A.nameName n) ++ "\" " + ++ case length es of + 1 -> "unary" + 2 -> "binary" + n -> show n ++ "-ary" + + cs <- getCompState + + -- The nubBy will ensure that only one definition remains for each + -- set of type-arguments, and will keep the first definition in the + -- list (which will be the most recent) + possibles <- sequence + [ (do es' <- sequence + [do e' <- doActual m direct t e + checkActual (A.Formal A.ValAbbrev t (A.Name m "x")) + (A.ActualExpression e') + return e' + | (t, e) <- zip ts es] + return $ Right ((opFuncName, es'), ts) + ) `catchError` (return . Left) + | (raw, opFuncName, ts) <- nubBy ((==) `on` (\(op,_,ts) -> (op,ts))) $ csOperators cs + -- Must be right operator: + , raw == A.nameName n + -- Must be right arity: + , length ts == length es] + case splitEither possibles of + -- We want to be helpful and give the user an idea + -- of what we thought the types were, but we must + -- also be careful not to die while getting the + -- types (and thus missing the real error!) + (errs,[]) -> do tes <- sequence [astTypeOf e `catchError` (const $ return A.Infer) | e <- es] + diePC m $ formatCode ("No matching " ++ opDescrip ++ " operator definition found for types: %" + ++ " errors were: " ++ show errs) tes + (_, [poss]) -> return $ fst poss + (_, posss) -> dieP m $ "Ambigious " ++ opDescrip ++ " operator, matches definitions: " + ++ show (map (transformPair (A.nameMeta . fst) showOccam) posss) + else + do (_, fs) <- checkFunction m n + doActuals m n fs direct es >>* (,) n + where + direct = error "Cannot direct channels passed to FUNCTIONs" doActuals :: Data a => Meta -> A.Name -> [A.Formal] -> (Meta -> A.Direction -> Transform a) -> Transform [a] doActuals m n fs applyDir as = do checkActualCount m n fs as - sequence [case t of - A.ChanEnd dir _ _ -> recurse a >>= applyDir m dir - _ -> inTypeContext (Just t) $ recurse a - | (A.Formal _ t _, a) <- zip fs as] + sequence [doActual m applyDir t a | (A.Formal _ t _, a) <- zip fs as] + + doActual :: Data a => Meta -> (Meta -> A.Direction -> Transform a) -> A.Type -> Transform a + doActual m applyDir (A.ChanEnd dir _ _) a = recurse a >>= applyDir m dir + doActual m _ t a = inTypeContext (Just t) $ recurse a + doDimension :: Transform A.Dimension doDimension dim = inTypeContext (Just A.Int) $ descend dim @@ -813,8 +777,8 @@ inferTypes = occamOnlyPass "Infer types" doExpressionList ts el = case el of A.FunctionCallList m n es -> - do es' <- doFunctionCall m n es - return $ A.FunctionCallList m n es' + do (n', es') <- doFunctionCall m (n, es) + return $ A.FunctionCallList m n' es' A.ExpressionList m es -> do es' <- sequence [inTypeContext (Just t) $ recurse e | (t, e) <- zip ts es] @@ -853,7 +817,16 @@ inferTypes = occamOnlyPass "Infer types" = do st' <- runReaderT (doSpecType n st) body -- Update the definition of each name after we handle it. modifyName n (\nd -> nd { A.ndSpecType = st' }) - recurse body >>* A.Spec mspec (A.Specification m n st') + let doBody = recurse body >>* A.Spec mspec (A.Specification m n st') + mOp <- functionOperator n + case (st, mOp) of + (A.Function _ _ _ fs _, Just raw) -> do + ts <- mapM astTypeOf fs + modify $ \cs -> cs { csOperators = (raw, n, ts) : csOperators cs } + x <- doBody + modify $ \cs -> cs { csOperators = tail (csOperators cs)} + return x + _ -> doBody doStructured s = descend s doSpecType :: Data a => A.Name -> A.SpecType -> ReaderT (A.Structured a) PassM A.SpecType @@ -1290,8 +1263,6 @@ checkExpressions :: PassType checkExpressions = checkDepthM doExpression where doExpression :: Check A.Expression - doExpression (A.Monadic _ op e) = checkMonadicOp op e - doExpression (A.Dyadic _ op le re) = checkDyadicOp op le re doExpression (A.MostPos m t) = checkNumeric m t doExpression (A.MostNeg m t) = checkNumeric m t doExpression (A.SizeType m t) = checkSequence True m t