diff --git a/frontends/RainTypes.hs b/frontends/RainTypes.hs index 33b098f..1e1a726 100644 --- a/frontends/RainTypes.hs +++ b/frontends/RainTypes.hs @@ -89,9 +89,41 @@ annnotateIntLiteralTypes = applyDepthM doExpression -- | Annotates all list literals and list ranges with their type annotateListLiteralTypes :: Data t => t -> PassM t -annotateListLiteralTypes = return - --- | A pass that finds all the 'A.ProcCall' and 'A.FunctionCall' in the AST, and checks that the actual parameters are valid inputs, given the 'A.Formal' parameters in the process's type +annotateListLiteralTypes = everywhereASTM doExpression + where + doExpression :: A.Expression -> PassM A.Expression + doExpression (A.Literal m _ (A.ListLiteral m' es)) + = do ts <- mapM typeOfExpression es + sharedT <- case leastGeneralSharedTypeRain ts of + Just t -> return t + Nothing -> diePC m' + $ formatCode + "Can't determine a common type for the list literal from: %" + ts + es' <- mapM (coerceIfNecessary sharedT) (zip ts es) + return $ A.Literal m (A.List sharedT) $ A.ListLiteral m' es' + doExpression (A.ExprConstr m (A.RangeConstr m' t b e)) + = do bt <- typeOfExpression b + et <- typeOfExpression e + sharedT <- case leastGeneralSharedTypeRain [bt, et] of + Just t -> return t + Nothing -> diePC m' + $ formatCode + "Can't determine a common type for the range from: % %" + bt et + b' <- coerceIfNecessary sharedT (bt, b) + e' <- coerceIfNecessary sharedT (et, e) + return $ A.ExprConstr m $ A.RangeConstr m' (A.List sharedT) b' e' + doExpression e = return e + + coerceIfNecessary :: A.Type -> (A.Type, A.Expression) -> PassM A.Expression + coerceIfNecessary to (from, e) + | to == from = return e + | otherwise = coerceType " in list literal" to from e + +-- | A pass that finds all the 'A.ProcCall' and 'A.FunctionCall' in the +-- AST, and checks that the actual parameters are valid inputs, given +-- the 'A.Formal' parameters in the process's type matchParamPass :: Data t => t -> PassM t matchParamPass = everywhereM ((mkM matchParamPassProc) `extM` matchParamPassFunc) where