{- Tock: a compiler for parallel languages Copyright (C) 2007 University of Kent This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . -} module RainTypes (constantFoldPass,performTypeUnification,recordInfNameTypes) where import Control.Monad.State import Data.Generics import qualified Data.Map as Map import Data.Maybe import Data.IORef import qualified AST as A import CompState import Errors import EvalConstants import Metadata import Pass import ShowCode import Traversal import Types import TypeUnification import UnifyType import Utils lookupMapElseMutVar :: UnifyIndex -> PassM (TypeExp A.Type) lookupMapElseMutVar k = do st <- get let m = csUnifyLookup st case Map.lookup k m of Just v -> return v Nothing -> do r <- liftIO $ newIORef Nothing let v = MutVar r m' = Map.insert k v m put st {csUnifyLookup = m'} return v ttte :: String -> (A.Type -> A.Type) -> A.Type -> PassM (TypeExp A.Type) ttte c f t = typeToTypeExp t >>= \t' -> return $ OperType c (\[x] -> f x) [t'] -- Transforms the given type into a typeexp, such that the only inner types -- left will be the primitive types (integer types, float types, bool, time). Arrays -- (which would require unification of dimensions and such) are not supported, -- neither are records. -- User data types should not be present in the input. typeToTypeExp :: A.Type -> PassM (TypeExp A.Type) typeToTypeExp (A.List t) = ttte "[]" A.List t typeToTypeExp (A.Chan A.DirInput at t) = ttte "?" (A.Chan A.DirInput at) t typeToTypeExp (A.Chan A.DirOutput at t) = ttte "!" (A.Chan A.DirOutput at) t typeToTypeExp (A.Chan A.DirUnknown at t) = ttte "channel" (A.Chan A.DirUnknown at) t typeToTypeExp (A.Mobile t) = ttte "MOBILE" A.Mobile t typeToTypeExp (A.UnknownVarType en) = case en of Left n -> lookupMapElseMutVar (UnifyIndex (A.nameMeta n, Right n)) Right (m, i) -> lookupMapElseMutVar (UnifyIndex (m, Left i)) typeToTypeExp (A.UnknownNumLitType m id n) = do r <- liftIO . newIORef $ Left [n] let v = NumLit r st <- get let mp = csUnifyLookup st put st {csUnifyLookup = Map.insert (UnifyIndex (m,Left id)) v mp} return v typeToTypeExp t = return $ OperType (show t) (const t) [] markUnify :: (Typed a, Typed b) => a -> b -> PassM () markUnify x y = do tx <- astTypeOf x ty <- astTypeOf y tex <- typeToTypeExp tx tey <- typeToTypeExp ty modify $ \st -> st {csUnifyPairs = (tex,tey) : csUnifyPairs st} performTypeUnification :: Data t => t -> PassM t performTypeUnification x = do -- First, we copy the known types into the unify map: st <- get ul <- shift $ csNames st put st {csUnifyPairs = [], csUnifyLookup = ul} -- Then we markup all the types in the tree: x' <- markConditionalTypes <.< markParamPass <.< markAssignmentTypes <.< markCommTypes <.< markReplicators <.< markExpressionTypes $ x -- Then, we do the unification: prs <- get >>* csUnifyPairs res <- liftIO $ mapM (uncurry unifyType) prs mapM (diePC emptyMeta) (fst $ splitEither res) -- Now put the types back in a map, and replace them through the tree: l <- get >>* csUnifyLookup ts <- mapMapWithKeyM (\(UnifyIndex(m,_)) v -> fromTypeExp m v) l get >>= substituteUnknownTypes ts >>= put substituteUnknownTypes ts x' where shift :: Map.Map String A.NameDef -> PassM (Map.Map UnifyIndex UnifyValue) shift = liftM (Map.fromList . catMaybes) . mapM shift' . Map.toList where shift' :: (String, A.NameDef) -> PassM (Maybe (UnifyIndex, UnifyValue)) shift' (rawName, d) = do mt <- typeOfSpec (A.ndType d) case mt of Nothing -> return Nothing Just t -> do te <- typeToTypeExp t return $ Just (UnifyIndex (A.ndMeta d, Right name), te) where name = A.Name {A.nameName = rawName, A.nameMeta = A.ndMeta d, A.nameType = A.ndNameType d} substituteUnknownTypes :: Data t => Map.Map UnifyIndex A.Type -> t -> PassM t substituteUnknownTypes mt = applyDepthM sub where sub :: A.Type -> PassM A.Type sub (A.UnknownVarType (Left n)) = lookup $ UnifyIndex (A.nameMeta n, Right n) sub (A.UnknownNumLitType m i _) = lookup $ UnifyIndex (m, Left i) sub t = return t lookup :: UnifyIndex -> PassM A.Type lookup u@(UnifyIndex(m,_)) = case Map.lookup u mt of Just t -> return t Nothing -> dieP m "Could not deduce type" -- | A pass that records inferred types. Currently the only place where types are inferred is in seqeach\/pareach loops. recordInfNameTypes :: Data t => t -> PassM t recordInfNameTypes = everywhereM (mkM recordInfNameTypes') where recordInfNameTypes' :: A.Replicator -> PassM A.Replicator recordInfNameTypes' input@(A.ForEach m n e) = do let innerT = A.UnknownVarType $ Left n defineName n A.NameDef {A.ndMeta = m, A.ndName = A.nameName n, A.ndOrigName = A.nameName n, A.ndNameType = A.VariableName, A.ndType = (A.Declaration m innerT), A.ndAbbrevMode = A.Abbrev, A.ndPlacement = A.Unplaced} return input recordInfNameTypes' r = return r markReplicators :: Data t => t -> PassM t markReplicators = checkDepthM mark where mark :: Check A.Replicator mark (A.ForEach _m n e) = astTypeOf n >>= \t -> markUnify (A.List t) e -- | Folds all constants. constantFoldPass :: Data t => t -> PassM t constantFoldPass = applyDepthM doExpression where doExpression :: A.Expression -> PassM A.Expression doExpression = (liftM (\(x,_,_) -> x)) . constantFold -- | A pass that finds all the 'A.ProcCall' and 'A.FunctionCall' in the -- AST, and checks that the actual parameters are valid inputs, given -- the 'A.Formal' parameters in the process's type markParamPass :: Data t => t -> PassM t markParamPass = checkDepthM2 matchParamPassProc matchParamPassFunc where --Picks out the parameters of a process call, checks the number is correct, and maps doParam over them matchParamPassProc :: Check A.Process matchParamPassProc (A.ProcCall m n actualParams) = do def <- lookupNameOrError n $ dieP m ("Process name is unknown: \"" ++ (show $ A.nameName n) ++ "\"") case A.ndType def of A.Proc _ _ expectedParams _ -> if (length expectedParams) == (length actualParams) then mapM_ (uncurry markUnify) (zip expectedParams actualParams) else dieP m $ "Wrong number of parameters given to process call; expected: " ++ show (length expectedParams) ++ " but found: " ++ show (length actualParams) _ -> dieP m $ "You cannot run things that are not processes, such as: \"" ++ (show $ A.nameName n) ++ "\"" matchParamPassProc _ = return () --Picks out the parameters of a function call, checks the number is correct, and maps doExpParam over them matchParamPassFunc :: Check A.Expression matchParamPassFunc (A.FunctionCall m n actualParams) = do def <- lookupNameOrError n $ dieP m ("Function name is unknown: \"" ++ (show $ A.nameName n) ++ "\"") case A.ndType def of A.Function _ _ _ expectedParams _ -> if (length expectedParams) == (length actualParams) then mapM_ (uncurry markUnify) (zip expectedParams actualParams) else dieP m $ "Wrong number of parameters given to function call; expected: " ++ show (length expectedParams) ++ " but found: " ++ show (length actualParams) _ -> dieP m $ "Attempt to make a function call with something" ++ " that is not a function: \"" ++ A.nameName n ++ "\"; is actually: " ++ showConstr (toConstr $ A.ndType def) matchParamPassFunc _ = return () -- | Checks the types in expressions markExpressionTypes :: Data t => t -> PassM t markExpressionTypes = checkDepthM checkExpression where -- TODO also check in a later pass that the op is valid checkExpression :: Check A.Expression checkExpression (A.Dyadic _ _ lhs rhs) = markUnify lhs rhs checkExpression _ = return () -- | Checks the types in assignments markAssignmentTypes :: Data t => t -> PassM t markAssignmentTypes = checkDepthM checkAssignment where checkAssignment :: Check A.Process checkAssignment (A.Assign m [v] (A.ExpressionList _ [e])) = do am <- abbrevModeOfVariable v when (am == A.ValAbbrev) $ diePC m $ formatCode "Cannot assign to a constant variable: %" v markUnify v e checkAssignment (A.Assign m _ _) = dieInternal (Just m,"Rain checker found occam-style assignment") checkAssignment st = return () -- | Checks the types in if and while conditionals markConditionalTypes :: Data t => t -> PassM t markConditionalTypes = checkDepthM2 checkWhile checkIf where checkWhile :: Check A.Process checkWhile w@(A.While m exp _) = markUnify exp A.Bool checkWhile _ = return () checkIf :: Check A.Choice checkIf c@(A.Choice m exp _) = markUnify exp A.Bool -- | Checks the types in inputs and outputs, including inputs in alts markCommTypes :: Data t => t -> PassM t markCommTypes = checkDepthM2 checkInputOutput checkAltInput where checkInput :: A.Variable -> A.Variable -> Meta -> a -> PassM () checkInput chanVar destVar m p = astTypeOf destVar >>= markUnify chanVar . A.Chan A.DirInput (A.ChanAttributes False False) checkWait :: Check A.InputMode checkWait (A.InputTimerFor m exp) = markUnify A.Time exp checkWait (A.InputTimerAfter m exp) = markUnify A.Time exp checkWait (A.InputTimerRead m (A.InVariable _ v)) = markUnify A.Time v checkWait _ = return () checkInputOutput :: Check A.Process checkInputOutput p@(A.Input m chanVar (A.InputSimple _ [A.InVariable _ destVar])) = checkInput chanVar destVar m p checkInputOutput (A.Input _ _ im@(A.InputTimerFor {})) = checkWait im checkInputOutput (A.Input _ _ im@(A.InputTimerAfter {})) = checkWait im checkInputOutput (A.Input _ _ im@(A.InputTimerRead {})) = checkWait im checkInputOutput p@(A.Output m chanVar [A.OutExpression m' srcExp]) = astTypeOf srcExp >>= markUnify chanVar . A.Chan A.DirOutput (A.ChanAttributes False False) checkInputOutput _ = return () checkAltInput :: Check A.Alternative checkAltInput a@(A.Alternative m _ chanVar (A.InputSimple _ [A.InVariable _ destVar]) body) = checkInput chanVar destVar m a checkAltInput (A.Alternative m _ _ im@(A.InputTimerFor {}) _) = checkWait im checkAltInput (A.Alternative m _ _ im@(A.InputTimerAfter {}) _) = checkWait im checkAltInput _ = return ()