tock-mirror/frontends/OccamInferTypes.hs

865 lines
39 KiB
Haskell

{-
Tock: a compiler for parallel languages
Copyright (C) 2008, 2009 University of Kent
This program is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation, either version 2 of the License, or (at your
option) any later version.
This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program. If not, see <http://www.gnu.org/licenses/>.
-}
-- | The occam typechecker.
module OccamInferTypes (inferTypes, addDirections) where
import Control.Monad.Error
import Control.Monad.Reader
import Control.Monad.State
import Data.Generics (Data)
import Data.IORef
import Data.List
import Data.Maybe
import qualified Data.Traversable as T
import qualified AST as A
import CompState
import Errors
import Intrinsics
import Metadata
import OccamCheckTypes
import Pass
import qualified Properties as Prop
import ShowCode
import Traversal
import Types
import Utils
-- | Pick the more specific of a pair of types.
betterType :: A.Type -> A.Type -> A.Type
betterType A.Infer t = t
betterType t A.Infer = t
betterType t@(A.UserDataType _) _ = t
betterType _ t@(A.UserDataType _) = t
betterType t1@(A.Array ds1 et1) t2@(A.Array ds2 et2)
| length ds1 == length ds2
= A.Array (zipWith betterDim ds1 ds2) $ betterType et1 et2
| length ds1 < length ds2 = t1
where
betterDim A.UnknownDimension d@(A.Dimension _) = d
-- All other cases (both unknown, right is unknown, both known), use left:
betterDim d _ = d
betterType t _ = t
--}}}
--{{{ type context management
-- | Run an operation in a given type context.
inTypeContext :: Maybe A.Type -> PassM a -> PassM a
inTypeContext ctx body
= do pushTypeContext (case ctx of
Just A.Infer -> Nothing
_ -> ctx)
v <- body
popTypeContext
return v
-- | Run an operation in the type context 'Nothing'.
noTypeContext :: PassM a -> PassM a
noTypeContext = inTypeContext Nothing
-- | Run an operation in the type context that results from subscripting
-- the current type context.
-- If the current type context is 'Nothing', the resulting one will be too.
inSubscriptedContext :: Meta -> PassM a -> PassM a
inSubscriptedContext m body
= do ctx <- getTypeContext
subCtx <- case ctx of
Just t@(A.Array _ _) ->
trivialSubscriptType m t >>* Just
Just t -> diePC m $ formatCode "Attempting to subscript non-array type %" t
Nothing -> return Nothing
inTypeContext subCtx body
--}}}
addDirections :: PassOn2 A.Process A.Alternative
addDirections = occamOnlyPass "Add direction specifiers to inputs and outputs"
[] []
(applyBottomUpM2 doProcess doAlternative)
where
doProcess :: Transform A.Process
doProcess (A.Output m v os)
= do v' <- makeEnd m A.DirOutput v
return $ A.Output m v' os
doProcess (A.OutputCase m v n os)
= do v' <- makeEnd m A.DirOutput v
return $ A.OutputCase m v' n os
doProcess (A.Input m v im@(A.InputSimple {}))
= do v' <- makeEnd m A.DirInput v
return $ A.Input m v' im
doProcess (A.Input m v im@(A.InputCase {}))
= do v' <- makeEnd m A.DirInput v
return $ A.Input m v' im
doProcess p = return p
doAlternative :: Transform A.Alternative
doAlternative (A.Alternative m pre v im p)
= do v' <- case im of
A.InputSimple {} -> makeEnd m A.DirInput v
A.InputCase {} -> makeEnd m A.DirInput v
_ -> return v
return $ A.Alternative m pre v' im p
doAlternative a = return a
makeEnd :: Meta -> A.Direction -> Transform A.Variable
makeEnd m dir v
= case v of
A.SubscriptedVariable _ _ innerV
-> do t <- astTypeOf innerV
case t of
A.ChanDataType {} -> return v
_ -> makeEnd'
_ -> makeEnd'
where
makeEnd' :: PassM A.Variable
makeEnd'
= do t <- astTypeOf v
case t of
A.ChanEnd {} -> return v
A.Chan {} -> return $ A.DirectedVariable m dir v
A.Array _ (A.ChanEnd {}) -> return v
A.Array _ (A.Chan {}) -> return $ A.DirectedVariable m dir v
-- If unsure (e.g. Infer), just shove a direction on it to be sure:
_ -> return $ A.DirectedVariable m dir v
scrubMobile :: PassM a -> PassM a
scrubMobile m
= do ctx <- getTypeContext
case ctx of
(Just (A.Mobile t)) -> inTypeContext (Just t) m
_ -> m
inferAllocMobile :: Meta -> A.Type -> A.Expression -> PassM A.Expression
inferAllocMobile m (A.Mobile {}) e
= do t <- astTypeOf e >>= underlyingType m
case t of
A.Mobile {} -> return e
_ -> return $ A.AllocMobile m (A.Mobile t) (Just e)
inferAllocMobile _ _ e = return e
--{{{ inferTypes
-- I can't put this in the where clause of inferTypes, so it has to be out
-- here. It should be the type of ops inside the inferTypes function below.
type InferTypeOps
= ExtOpMSP BaseOp
`ExtOpMP` A.Expression
`ExtOpMP` A.Dimension
`ExtOpMP` A.Subscript
`ExtOpMP` A.Replicator
`ExtOpMP` A.Alternative
`ExtOpMP` A.Process
`ExtOpMP` A.Variable
`ExtOpMP` A.Variant
-- | Infer types.
inferTypes :: Pass A.AST
inferTypes = occamOnlyPass "Infer types"
[]
[Prop.inferredTypesRecorded]
recurse
where
ops :: InferTypeOps
ops = baseOp
`extOpMS` (ops, doStructured)
`extOpM` doExpression
`extOpM` doDimension
`extOpM` doSubscript
`extOpM` doReplicator
`extOpM` doAlternative
`extOpM` doProcess
`extOpM` doVariable
`extOpM` doVariant
recurse :: RecurseM PassM InferTypeOps
recurse = makeRecurseM ops
descend :: DescendM PassM InferTypeOps
descend = makeDescendM ops
doExpression :: Transform A.Expression
doExpression outer
= case outer of
-- Literals are what we're really looking for here.
A.Literal m t lr ->
do t' <- recurse t
scrubMobile $ do
ctx <- getTypeContext
let wantT = case (ctx, t') of
-- No type specified on the literal,
-- but there's a context, so use that.
(Just ct, A.Infer) -> ct
-- Use the explicit type of the literal, or the
-- default.
_ -> t'
(realT, realLR) <- doLiteral (wantT, lr)
return $ A.Literal m realT realLR
-- Expressions that aren't literals, but that modify the type
-- context.
A.SizeExpr _ _ -> noTypeContext $ descend outer
A.Conversion _ _ _ _ -> noTypeContext $ descend outer
A.FunctionCall m n es ->
do (n', es') <- doFunctionCall m (n, es)
return $ A.FunctionCall m n' es'
A.IntrinsicFunctionCall _ _ _ -> noTypeContext $ descend outer
A.SubscriptedExpr m s e ->
do ctx <- getTypeContext
e' <- inTypeContext (ctx >>= unsubscriptType s) $ recurse e
t <- astTypeOf e'
s' <- recurse s >>= fixSubscript t
return $ A.SubscriptedExpr m s' e'
A.BytesInExpr _ _ -> noTypeContext $ descend outer
-- FIXME: ExprConstr
-- FIXME: AllocMobile
A.ExprVariable m v ->
do ctx <- getTypeContext
v' <- recurse v
derefVariableIfNeeded ctx v' >>* A.ExprVariable m
-- Other expressions don't modify the type context.
_ -> descend outer
doFunctionCall :: Meta -> Transform (A.Name, [A.Expression])
doFunctionCall m (n, es) = do
if isOperator (A.nameName n)
then
-- for operators, resolve the function name, based on the type
do let opDescrip = "\"" ++ (A.nameName n) ++ "\" "
++ case length es of
1 -> "unary"
2 -> "binary"
n -> show n ++ "-ary"
es' <- noTypeContext $ mapM recurse es
tes <- sequence [underlyingTypeOf m e `catchError` (const $ return A.Infer) | e <- es']
cs <- getCompState
resolvedOps <- sequence [ do ts' <- mapM (underlyingType m) ts
return (op, n, ts')
| (op, n, ts) <- csOperators cs
]
-- The nubBy will ensure that only one definition remains for each
-- set of type-arguments, and will keep the first definition in the
-- list (which will be the most recent)
possibles <- return
[ ((opFuncName, es'), ts)
| (raw, opFuncName, ts) <- nubBy opsMatch resolvedOps
-- Must be right operator:
, raw == A.nameName n
-- Must be right arity:
, length ts == length es
-- Must have right types:
, ts `typesEqForOp` tes
]
case possibles of
[] -> diePC m $ formatCode "No matching % operator definition found for types: %" opDescrip tes
[poss] -> return $ fst poss
posss -> dieP m $ "Ambigious " ++ opDescrip ++ " operator, matches definitions: "
++ show (map (transformPair (A.nameMeta . fst) showOccam) posss)
else
do (_, fs) <- checkFunction m n
doActuals m n fs (direct, const return) es >>* (,) n
where
direct = error "Cannot direct channels passed to FUNCTIONs"
opsMatch (opA, _, tsA) (opB, _, tsB) = (opA == opB) && (tsA `typesEqForOp` tsB)
typesEqForOp :: [A.Type] -> [A.Type] -> Bool
typesEqForOp tsA tsB = (length tsA == length tsB) && (and $ zipWith typeEqForOp tsA tsB)
typeEqForOp :: A.Type -> A.Type -> Bool
typeEqForOp (A.Array ds t) (A.Array ds' t')
= (length ds == length ds') && typeEqForOp t t'
typeEqForOp t t' = t == t'
doActuals :: (PolyplateM a InferTypeOps () PassM, Data a) => Meta -> A.Name -> [A.Formal] ->
(Meta -> A.Direction -> Transform a, A.Type -> Transform a) -> Transform [a]
doActuals m n fs applyDir_Deref as
= do checkActualCount m n fs as
sequence [doActual m applyDir_Deref t a | (A.Formal _ t _, a) <- zip fs as]
-- First function directs, second function dereferences if needed
doActual :: (PolyplateM a InferTypeOps () PassM, Data a) =>
Meta -> (Meta -> A.Direction -> Transform a, A.Type -> Transform a) -> A.Type -> Transform a
doActual m (applyDir, _) (A.ChanEnd dir _ _) a = recurse a >>= applyDir m dir
doActual m (_, deref) t a = inTypeContext (Just t) $ recurse a >>= deref t
doDimension :: Transform A.Dimension
doDimension dim = inTypeContext (Just A.Int) $ descend dim
doSubscript :: Transform A.Subscript
doSubscript s = inTypeContext (Just A.Int) $ descend s
doExpressionList :: [A.Type] -> Transform A.ExpressionList
doExpressionList ts el
= case el of
A.FunctionCallList m n es ->
do (n', es') <- doFunctionCall m (n, es)
return $ A.FunctionCallList m n' es'
A.ExpressionList m es ->
do es' <- sequence [inTypeContext (Just t) $ recurse e
| (t, e) <- zip ts es]
es'' <- mapM (uncurry $ inferAllocMobile m) $ zip ts es'
return $ A.ExpressionList m es''
A.AllocChannelBundle {} -> return el
doReplicator :: Transform A.Replicator
doReplicator rep
= case rep of
A.For _ _ _ _ -> inTypeContext (Just A.Int) $ descend rep
A.ForEach _ _ -> noTypeContext $ descend rep
doAlternative :: Transform A.Alternative
doAlternative (A.Alternative m pre v im p)
= do pre' <- inTypeContext (Just A.Bool) $ recurse pre
v' <- recurse v >>= derefVariableIfNeeded Nothing
im' <- doInputMode v' im
p' <- recurse p
return $ A.Alternative m pre' v' im' p'
doAlternative (A.AlternativeSkip m pre p)
= do pre' <- inTypeContext (Just A.Bool) $ recurse pre
p' <- recurse p
return $ A.AlternativeSkip m pre' p'
doInputMode :: A.Variable -> Transform A.InputMode
doInputMode v (A.InputSimple m iis)
= do ts <- protocolItems m v >>* either id (const [])
iis' <- sequence [doInputItem t ii
| (t, ii) <- zip ts iis]
return $ A.InputSimple m iis'
doInputMode v (A.InputCase m sv)
= do ct <- astTypeOf v
inTypeContext (Just ct) (recurse sv) >>* A.InputCase m
doInputMode _ (A.InputTimerRead m ii)
= doInputItem A.Int ii >>* A.InputTimerRead m
doInputMode _ im = inTypeContext (Just A.Int) $ descend im
doInputItem :: A.Type -> Transform A.InputItem
doInputItem t (A.InVariable m v)
= (inTypeContext (Just t) (recurse v)
>>= derefVariableIfNeeded (Just t)
) >>* A.InVariable m
doInputItem t (A.InCounted m cv av)
= do cv' <- inTypeContext (Just A.Int) (recurse cv)
>>= derefVariableIfNeeded (Just A.Int)
av' <- inTypeContext (Just t) (recurse av)
>>= derefVariableIfNeeded (Just t)
return $ A.InCounted m cv' av'
doVariant :: Transform A.Variant
doVariant (A.Variant m n iis p)
= do ctx <- getTypeContext
ets <- case ctx of
Just x -> protocolItems m x
Nothing -> dieP m "Could not deduce protocol"
case ets of
Left {} -> dieP m "Simple protocol expected during input CASE"
Right ps -> case lookup n ps of
Nothing -> diePC m $ formatCode "Name % is not part of protocol %"
n (fromJust ctx)
Just ts -> do iis' <- sequence [doInputItem t ii
| (t, ii) <- zip ts iis]
p' <- recurse p
return $ A.Variant m n iis' p'
doStructured :: ( PolyplateM (A.Structured t) InferTypeOps () PassM
, PolyplateM (A.Structured t) () InferTypeOps PassM
, Data t) => Transform (A.Structured t)
doStructured (A.Spec mspec s@(A.Specification m n st) body)
= do (st', wrap) <- runReaderT (doSpecType n st) body
-- Update the definition of each name after we handle it.
modifyName n (\nd -> nd { A.ndSpecType = st' })
wrap (recurse body) >>* A.Spec mspec (A.Specification m n st')
doStructured s = descend s
-- The second parameter is a modifier (wrapper) for the descent into the body
doSpecType :: ( PolyplateM (A.Structured t) InferTypeOps () PassM
, PolyplateM (A.Structured t) () InferTypeOps PassM
, Data t) => A.Name -> A.SpecType -> ReaderT (A.Structured t) PassM
(A.SpecType, PassM (A.Structured a) -> PassM (A.Structured a))
doSpecType n st
= case st of
A.Place _ _ -> lift $ inTypeContext (Just A.Int) $ descend st >>* addId
A.Is m am t (A.ActualVariable v) ->
do am' <- lift $ recurse am
t' <- lift $ recurse t
v' <- lift $ inTypeContext (Just t') $ recurse v
>>= derefVariableIfNeeded (Just t')
vt <- lift $ astTypeOf v'
(t'', v'') <- case (t', vt) of
(A.Infer, A.Chan attr innerT) ->
do dirs <- ask >>= (lift . findDir n)
case nub dirs of
[dir] ->
do let tEnd = A.ChanEnd dir (dirAttr dir attr) innerT
return (tEnd, A.DirectedVariable m dir v')
_ -> return (vt, v') -- no direction, or two
(A.Infer, _) -> return (vt, v')
(A.ChanEnd dir _ _, _) -> do v'' <- lift $ makeEnd m dir v'
return (t', v'')
(A.Array _ (A.ChanEnd dir _ _), _) ->
do v'' <- lift $ makeEnd m dir v'
return (t', v'')
(A.Chan cattr cinnerT, A.ChanEnd dir _ einnerT)
-> do cinnerT' <- lift $ recurse cinnerT
einnerT' <- lift $ recurse einnerT
if cinnerT' /= einnerT'
then lift $ diePC m $ formatCode "Inner types of channels do not match in type inference: % %" cinnerT' einnerT'
else return (vt, v')
(A.Chan attr innerT, A.Chan {}) ->
do dirs <- ask >>= (lift . findDir n)
case nub dirs of
[dir] ->
do let tEnd = A.ChanEnd dir (dirAttr dir attr) innerT
return (tEnd, A.DirectedVariable m dir v')
_ -> return (t', v') -- no direction, or two
_ -> return (t', v')
return $ addId $ A.Is m am' t'' $ A.ActualVariable v''
A.Is m am t (A.ActualExpression e) -> lift $
do am' <- recurse am
t' <- recurse t
e' <- inTypeContext (Just t') $ recurse e
t'' <- case t' of
A.Infer -> astTypeOf e'
A.Array ds _ | A.UnknownDimension `elem` ds -> astTypeOf e'
_ -> return t'
return $ addId $ A.Is m am' t'' (A.ActualExpression e')
A.Is m am t (A.ActualClaim v) -> lift $
do am' <- recurse am
t' <- recurse t
v' <- inTypeContext (Just t') $ recurse v
t'' <- case t' of
A.Infer -> astTypeOf (A.ActualClaim v')
_ -> return t'
am'' <- case t'' of
-- CLAIMed channel bundles are ValAbbrev, as they may
-- not be altered:
A.ChanDataType {} ->
do modifyName n $ \nd -> nd { A.ndAbbrevMode = A.ValAbbrev }
return A.ValAbbrev
-- CLAIMed normal channels are as before:
_ -> return am'
return $ addId $ A.Is m am'' t'' (A.ActualClaim v')
A.Is m am t (A.ActualChannelArray vs) ->
-- No expressions in this -- but we may need to infer the type
-- of the variable if it's something like "cs IS [c]:".
do t' <- lift $ recurse t
vs' <- lift $ mapM recurse vs >>= case t' of
A.Infer -> return
A.Array _ (A.Chan {}) -> return
A.Array _ (A.ChanEnd dir _ _) -> mapM (makeEnd m dir)
_ -> const $ dieP m "Cannot coerce non-channels into channels"
let dim = makeDimension m $ length vs'
t'' <- lift $ case (t', vs') of
(A.Infer, (v:_)) ->
do elemT <- astTypeOf v
return $ addDimensions [dim] elemT
(A.Infer, []) ->
dieP m "Cannot infer type of empty channel array"
_ -> return $ applyDimension dim t'
(t''', f) <- case t'' of
A.Array ds (A.Chan attr innerT) -> do
dirs <- ask >>= (lift . findDir n)
case nub dirs of
[dir] -> return (A.Array ds $ A.ChanEnd dir (dirAttr dir attr) innerT
,A.DirectedVariable m dir)
_ -> return (t'', id)
_ -> return (t'', id)
return $ addId $ A.Is m am t''' $ A.ActualChannelArray $ map f vs'
A.Function m sm ts fs mbody -> lift $
do sm' <- recurse sm
ts' <- recurse ts
fs' <- recurse fs
sel' <- case mbody of
Just (Left sel) -> doFuncDef ts sel >>* (Just . Left)
_ -> return mbody
mOp <- functionOperator n
let func = A.Function m sm' ts' fs' sel'
case mOp of
Just raw -> do
ts <- mapM astTypeOf fs
let before, after :: PassM ()
before = modify $ \cs -> cs { csOperators = (raw, n, ts) : csOperators cs }
after = modify $ \cs -> cs { csOperators = tail (csOperators cs)}
return (func
,\m -> do before
x <- m
after
return x)
_ -> return func >>* addId
A.Retypes m am t v -> lift $ inTypeContext (Just t) $
(recurse v >>= derefVariableIfNeeded (Just t)) >>*
(addId . A.Retypes m am t)
A.RetypesExpr _ _ _ _ -> lift $ noTypeContext $ descend st >>* addId
-- For PROCs that take any channels without direction,
-- we must determine if we can infer a specific direction
-- for that channel
A.Proc m sm fs body -> lift $
do body' <- recurse body
fs' <- mapM (processFormal body') fs
return $ addId $ A.Proc m sm fs' body'
where
processFormal body f@(A.Formal am t n)
= do t' <- recurse t
case t' of
A.Chan attr innerT ->
do dirs <- findDir n body
case nub dirs of
[dir] ->
do let t' = A.ChanEnd dir (dirAttr dir attr) innerT
f' = A.Formal am t' n
modifyName n (\nd -> nd {A.ndSpecType =
A.Declaration m t'})
return f'
_ -> return $ A.Formal am t' n -- no direction, or two
_ -> do modifyName n (\nd -> nd {A.ndSpecType =
A.Declaration m t'})
return $ A.Formal am t' n
_ -> lift $ descend st >>* addId
where
addId :: a -> (a, b -> b)
addId a = (a, id)
-- | This is a bit ugly: walk down a Structured to find the single
-- ExpressionList that must be in there.
-- (This can go away once we represent all functions in the new Process
-- form.)
doFuncDef :: [A.Type] -> Transform (A.Structured A.ExpressionList)
doFuncDef ts (A.Spec m (A.Specification m' n st) s)
= do (st', wrap) <- runReaderT (doSpecType n st) s
modifyName n (\nd -> nd { A.ndSpecType = st' })
s' <- wrap $ doFuncDef ts s
return $ A.Spec m (A.Specification m' n st') s'
doFuncDef ts (A.ProcThen m p s)
= do p' <- recurse p
s' <- doFuncDef ts s
return $ A.ProcThen m p' s'
doFuncDef ts (A.Only m el)
= do el' <- doExpressionList ts el
return $ A.Only m el'
-- findDir only really needs to descend operating on Variables
-- But since this is called by doStructured, that would require doStructured
-- to have an extra constraint that the Structured supports descent into
-- Variables. But that constraint, in turn, is not satisfied when we build
-- our ops using extOpMS. Rather than fix all the constraints, I've decided
-- to adopt a slightly sneaky approach, and build a set of ops for findDir
-- with the same type as the one for infer types (thus the constraints
-- don't change), but where everything apart from the Variable operation
-- is a call to descend.
--
-- Also, to fit with the normal ops, we must do so in the PassM monad.
-- Normally we would do this pass in a StateT monad, but to slip inside
-- PassM, I've used an IORef instead.
findDir :: ( PolyplateM a InferTypeOps () PassM
, PolyplateM a () InferTypeOps PassM
) => A.Name -> a -> PassM [A.Direction]
findDir n x
= do r <- liftIO $ newIORef []
makeRecurseM (makeOps r) x
liftIO $ readIORef r
where
makeOps :: IORef [A.Direction] -> InferTypeOps
makeOps r = ops
where
ops :: InferTypeOps
ops = baseOp
`extOpMS` (ops, descend)
`extOpM` descend
`extOpM` descend
`extOpM` descend
`extOpM` descend
`extOpM` descend
`extOpM` descend
`extOpM` (doVariable r)
`extOpM` descend
descend :: DescendM PassM InferTypeOps
descend = makeDescendM ops
-- This will cover everything, since we will have inferred the direction
-- specifiers before applying this function.
doVariable :: IORef [A.Direction] -> A.Variable -> PassM A.Variable
doVariable r v@(A.DirectedVariable _ dir (A.Variable _ n')) | n == n'
= liftIO $ modifyIORef r (dir:) >> return v
doVariable r v@(A.DirectedVariable _ dir
(A.SubscriptedVariable _ _ (A.Variable _ n'))) | n == n'
= liftIO $ modifyIORef r (dir:) >> return v
doVariable r v = makeDescendM (makeOps r) v
doProcess :: Transform A.Process
doProcess p
= case p of
A.Assign m vs el ->
-- We do not dereference variables on the LHS of an assignment,
-- instead we promote the things on the RHS to allocations if
-- needed. After all, if the user does something like:
-- xs := "flibble"
-- where xs is a mobile array, we definitely want to allocate
-- the RHS, rather than dereference the possibly undefined LHS.
do vs' <- noTypeContext $ recurse vs
ts <- mapM astTypeOf vs'
el' <- doExpressionList ts el
return $ A.Assign m vs' el'
-- We don't dereference any of the channel variables, the backend can
-- handle that.
A.Output m v ois ->
do v' <- recurse v
-- At this point we must resolve the "c ! x" ambiguity:
-- we definitely know what c is, and we must know what x is
-- before trying to infer its type.
tagged <- isTagged v'
if tagged
-- Tagged protocol -- convert (wrong) variable to tag.
then case ois of
((A.OutExpression _ (A.ExprVariable _ (A.Variable _ wrong))):ois) ->
do tag <- nameToUnscoped wrong
ois' <- doOutputItems m v' (Just tag) ois
return $ A.OutputCase m v' tag ois'
_ -> diePC m $ formatCode "This channel carries a variant protocol; expected a list starting with a tag, but found %" ois
-- Regular protocol -- proceed as before.
else do ois' <- doOutputItems m v' Nothing ois
return $ A.Output m v' ois'
A.OutputCase m v tag ois ->
do v' <- recurse v
ois' <- doOutputItems m v' (Just tag) ois
return $ A.OutputCase m v' tag ois'
A.If _ _ -> inTypeContext (Just A.Bool) $ descend p
A.Case m e so ->
do e' <- recurse e
t <- astTypeOf e'
so' <- inTypeContext (Just t) $ recurse so
return $ A.Case m e' so'
A.While _ _ _ -> inTypeContext (Just A.Bool) $ descend p
A.Processor _ _ _ -> inTypeContext (Just A.Int) $ descend p
A.ProcCall m n as ->
do fs <- checkProc m n
as' <- doActuals m n fs
(\m dir (A.ActualVariable v) -> liftM A.ActualVariable $ makeEnd m dir v
,\t a -> case a of
A.ActualVariable v -> derefVariableIfNeeded (Just t) v >>* A.ActualVariable
_ -> return a
) as
return $ A.ProcCall m n as'
p@(A.IntrinsicProcCall m n as) ->
case lookup n intrinsicProcs of
Nothing -> descend p -- Will fail type-checking anyway
Just params -> sequence [inTypeContext (Just t) $
case a of
A.ActualVariable v ->
(recurse v >>= derefVariableIfNeeded (Just t)) >>* A.ActualVariable
_ -> descend a
| (a, (_,t,_)) <- zip as params] >>* A.IntrinsicProcCall m n
A.Input m v im@(A.InputSimple {})
-> do v' <- recurse v
im' <- doInputMode v' im
return $ A.Input m v' im'
A.Input m v im@(A.InputCase {})
-> do v' <- recurse v
im' <- doInputMode v' im
return $ A.Input m v' im'
_ -> descend p
where
-- | Does a channel carry a tagged protocol?
isTagged :: A.Variable -> PassM Bool
isTagged c
= do protoT <- checkChannel A.DirOutput c
case protoT of
A.UserProtocol n ->
do st <- specTypeOfName n
case st of
A.ProtocolCase _ _ -> return True
_ -> return False
_ -> return False
doOutputItems :: Meta -> A.Variable -> Maybe A.Name
-> Transform [A.OutputItem]
doOutputItems m v tag ois
= do chanT <- checkChannel A.DirOutput v
ts <- protocolTypes m chanT tag
sequence [doOutputItem t oi | (t, oi) <- zip ts ois]
doOutputItem :: A.Type -> Transform A.OutputItem
doOutputItem (A.Counted ct at) (A.OutCounted m ce ae)
= do ce' <- inTypeContext (Just ct) $ recurse ce
ae' <- inTypeContext (Just at) $ recurse ae
return $ A.OutCounted m ce' ae'
doOutputItem A.Any o = noTypeContext $ recurse o
doOutputItem t (A.OutExpression m e)
= inTypeContext (Just t) (recurse e >>= inferAllocMobile m t)
>>* A.OutExpression m
doVariable :: Transform A.Variable
doVariable (A.SubscriptedVariable m s v)
= do v' <- noTypeContext (recurse v) >>= derefVariableIfNeeded Nothing
t <- astTypeOf v'
s' <- recurse s >>= fixSubscript t
return $ A.SubscriptedVariable m s' v'
doVariable v = descend v
derefVariableIfNeeded :: Maybe (A.Type) -> A.Variable -> PassM A.Variable
derefVariableIfNeeded ctxOrig v
= do ctx <- (T.sequence . fmap (resolveUserType (findMeta v))) ctxOrig
underT <- astTypeOf v >>= resolveUserType (findMeta v)
case (ctx, underT) of
(Just (A.Mobile {}), A.Mobile {}) -> return v
(_, A.Mobile {}) -> return $ A.DerefVariable (findMeta v) v
_ -> return v
-- | Resolve the @v[s]@ ambiguity: this takes the type that @v@ is, and
-- returns the correct 'Subscript'.
fixSubscript :: A.Type -> A.Subscript -> PassM A.Subscript
fixSubscript t s@(A.Subscript m _ (A.ExprVariable _ (A.Variable _ wrong)))
= do underT <- resolveUserType m t
case underT of
A.Record _ ->
do n <- nameToUnscoped wrong
return $ A.SubscriptField m n
A.ChanDataType {} ->
do n <- nameToUnscoped wrong
return $ A.SubscriptField m n
_ -> return s
fixSubscript _ s = return s
-- | Given a name that should really have been a tag, make it one.
nameToUnscoped :: A.Name -> PassM A.Name
nameToUnscoped n@(A.Name m _)
= do nd <- lookupName n
findUnscopedName (A.Name m (A.ndOrigName nd))
-- | Process a 'LiteralRepr', taking the type it's meant to represent or
-- 'Infer', and returning the type it really is.
doLiteral :: Transform (A.Type, A.LiteralRepr)
doLiteral (wantT, lr)
= case lr of
A.ArrayListLiteral m aes ->
do (t, aes') <-
doArrayElem wantT aes
lr' <- case aes' of
A.Several _ ss -> buildTable t ss
_ -> return $ A.ArrayListLiteral m aes'
return (t, lr')
_ ->
do lr' <- descend lr
(defT, isT) <-
case lr' of
A.RealLiteral _ _ -> return (A.Real32, isRealType)
A.IntLiteral _ _ -> return (A.Int, isIntegerType)
A.HexLiteral _ _ -> return (A.Int, isIntegerType)
A.ByteLiteral _ _ -> return (A.Byte, isIntegerType)
_ -> dieP m $ "Unexpected LiteralRepr: " ++ show lr'
underT <- resolveUserType m wantT
case (wantT, isT underT) of
(A.Infer, _) -> return (defT, lr')
(_, True) -> return (wantT, lr')
(_, False) -> diePC m $ formatCode "Literal of default type % is not valid for type %" defT wantT
where
m = findMeta lr
doArrayElem :: A.Type -> A.Structured A.Expression -> PassM (A.Type, A.Structured A.Expression)
doArrayElem wantT (A.Spec m spec body)
-- A replicator: strip off a subscript and keep going
= do underT <- resolveUserType m wantT
subT <- trivialSubscriptType m underT
dim <- case underT of
A.Array (dim:_) _ -> return dim
A.Infer -> return A.UnknownDimension
_ -> diePC m $ formatCode "Unexpected type in array constructor: %" underT
(t, body') <- doArrayElem subT body
specAndBody' <- doStructured $ A.Spec m spec body'
return (applyDimension dim wantT, specAndBody')
-- A table: this could be an array or a record.
doArrayElem wantT (A.Several m aes)
= do underT <- resolveUserType m wantT
case underT of
A.Array _ _ ->
do subT <- trivialSubscriptType m underT
(elemT, aes') <- doElems subT aes
let dim = makeDimension m (length aes)
return (addDimensions [dim] elemT, A.Several m aes')
A.Record _ ->
do nts <- recordFields m underT
aes <- sequence [doArrayElem t ae >>* snd
| ((_, t), ae) <- zip nts aes]
return (wantT, A.Several m aes)
-- If we don't know, assume it's an array.
A.Infer ->
do (elemT, aes') <- doElems A.Infer aes
when (elemT == A.Infer) $
dieP m "Cannot infer type of (empty?) array"
let dims = [makeDimension m (length aes)]
return (addDimensions dims elemT,
A.Several m aes')
_ -> diePC m $ formatCode "Table literal is not valid for type %" wantT
where
doElems :: A.Type -> [A.Structured A.Expression] -> PassM (A.Type, [A.Structured A.Expression])
doElems t aes
= do ts <- mapM (\ae -> doArrayElem t ae >>* fst) aes
let bestT = foldl betterType t ts
aes' <- mapM (\ae -> doArrayElem bestT ae >>* snd) aes
return (bestT, aes')
-- An expression: descend into it with the right context.
doArrayElem wantT (A.Only m e)
= do e' <- inTypeContext (Just wantT) $ doExpression e
t <- astTypeOf e'
checkType (findMeta e') wantT t
return (t, A.Only m e')
-- | Turn a raw table literal into the appropriate combination of
-- arrays and records.
buildTable :: A.Type -> [A.Structured A.Expression] -> PassM A.LiteralRepr
buildTable t aes
= do underT <- resolveUserType m t
case underT of
A.Array _ _ ->
do elemT <- trivialSubscriptType m t
aes' <- mapM (buildElem elemT) aes
return $ A.ArrayListLiteral m $ A.Several m aes'
A.Record _ ->
do nts <- recordFields m underT
aes' <- sequence [buildExpr elemT ae
| ((_, elemT), ae) <- zip nts aes]
return $ A.RecordLiteral m aes'
where
buildExpr :: A.Type -> A.Structured A.Expression -> PassM A.Expression
buildExpr t (A.Several _ aes)
= do lr <- buildTable t aes
return $ A.Literal m t lr
buildExpr _ (A.Only _ e) = return e
buildElem :: A.Type -> A.Structured A.Expression -> PassM (A.Structured A.Expression)
buildElem t ae
= do underT <- resolveUserType m t
case (underT, ae) of
(A.Array _ _, A.Several _ aes) ->
do A.ArrayListLiteral _ aes' <- buildTable t aes
return aes'
(A.Record _, A.Several {}) ->
do e <- buildExpr t ae
return $ A.Only m e
(_, A.Only {}) -> return ae
--}}}