242 lines
10 KiB
Haskell
242 lines
10 KiB
Haskell
{-
|
|
Tock: a compiler for parallel languages
|
|
Copyright (C) 2008 University of Kent
|
|
|
|
This program is free software; you can redistribute it and/or modify it
|
|
under the terms of the GNU General Public License as published by the
|
|
Free Software Foundation, either version 2 of the License, or (at your
|
|
option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful, but
|
|
WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License along
|
|
with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
-}
|
|
|
|
module TypeUnification where
|
|
|
|
import Control.Monad
|
|
import Control.Monad.State
|
|
import Control.Monad.Trans
|
|
import Data.Generics
|
|
import qualified Data.Map as Map
|
|
import Data.Maybe
|
|
import Data.IORef
|
|
|
|
import qualified AST as A
|
|
import Errors
|
|
import Metadata
|
|
import Pass
|
|
import ShowCode
|
|
import UnifyType
|
|
import Utils
|
|
|
|
foldCon :: ([A.Type] -> A.Type) -> [Either String A.Type] -> Either String A.Type
|
|
foldCon con es = case splitEither es of
|
|
([],ts) -> Right $ con ts
|
|
((e:_),_) -> Left e
|
|
|
|
-- Much of the code in this module is taken from or based on Tim Sheard's Haskell
|
|
-- listing of a simple type unification algorithm at the beginning of his
|
|
-- paper "Generic Unification via Two-Level Types and Parameterized Modules Functional
|
|
-- Pearl (2001)", citeseer: http://citeseer.ist.psu.edu/451401.html
|
|
-- This in turn was taken from Luca Cardelli's "Basic Polymorphic Type Checking"
|
|
|
|
-- | Given a map from keys to non-unified types and a list of pairs of types to unify,
|
|
-- gives back the resulting map from keys to actual types.
|
|
unifyRainTypes :: forall k. (Ord k, Show k) => (Map.Map k (TypeExp A.Type)) -> [(k, k)] ->
|
|
PassM (Map.Map k A.Type)
|
|
unifyRainTypes m' prs
|
|
= do mapM_ (\(x,y) -> unifyType (lookupStartType x m') (lookupStartType y m')) prs
|
|
stToMap m'
|
|
where
|
|
lookupStartType :: k -> Map.Map k (TypeExp A.Type) -> TypeExp A.Type
|
|
lookupStartType s m = case Map.lookup s m of
|
|
Just x -> x
|
|
Nothing -> error $ "Could not find type for variable in map before unification: "
|
|
++ show s
|
|
|
|
-- | Given a map containing simplified types (no mutvar/pointers or numlits
|
|
-- remaining, just actual types), turns them into the actual type representations
|
|
stToMap :: Map.Map k (TypeExp A.Type) -> PassM (Map.Map k A.Type)
|
|
stToMap m = do m' <- mapMapWithKeyM (\k v -> prune Nothing v >>= liftIO . read k) m
|
|
let (mapOfErrs, mapOfRes) = Map.mapEitherWithKey (const id) m'
|
|
case Map.elems mapOfErrs of
|
|
((m,e):_) -> dieP m e
|
|
[] -> return mapOfRes
|
|
where
|
|
read :: k -> TypeExp A.Type -> IO (Either (Meta, String) A.Type)
|
|
read k (OperType m _ con vals)
|
|
= do vals' <- mapM (read k) vals
|
|
case foldCon con (map (either (Left . snd) Right) vals') of
|
|
Left e -> return $ Left (m, e)
|
|
Right x -> return $ Right x
|
|
read k (MutVar m v) = readIORef v >>= \(_,t) -> case t of
|
|
Nothing -> return $ Left (m, "Type error in unification, "
|
|
++ "ambigious type remains for: " ++ show k)
|
|
Just t' -> read k t'
|
|
read k (NumLit m v) = readIORef v >>= \x -> case x of
|
|
Left _ -> return $ Left (m, "Type error in unification, "
|
|
++ "ambigious type remains for numeric literal: " ++ show k)
|
|
Right t -> return $ Right t
|
|
|
|
fromTypeExp :: TypeExp A.Type -> PassM A.Type
|
|
fromTypeExp x = fromTypeExp' =<< prune Nothing x
|
|
where
|
|
fromTypeExp' :: TypeExp A.Type -> PassM A.Type
|
|
fromTypeExp' (MutVar m _) = dieP m "Unresolved type"
|
|
fromTypeExp' (GenVar m _) = dieP m "Template vars not yet supported"
|
|
fromTypeExp' (NumLit m v) = liftIO (readIORef v) >>= \x -> case x of
|
|
Left (n:_) -> dieP m $ "Ambigiously typed numeric literal: " ++ show n
|
|
Right t -> return t
|
|
fromTypeExp' (OperType _ _ f ts) = mapM fromTypeExp ts >>* f
|
|
|
|
-- For debugging:
|
|
showInErr :: TypeExp A.Type -> PassM String
|
|
showInErr (MutVar {}) = return "MutVar"
|
|
showInErr (GenVar {}) = return "GenVar"
|
|
showInErr (NumLit {}) = return "NumLit"
|
|
showInErr t@(OperType {}) = showCode =<< fromTypeExp t
|
|
|
|
giveErr :: Meta -> String -> TypeExp A.Type -> TypeExp A.Type -> PassM a
|
|
giveErr m msg tx ty
|
|
= do x <- showInErr tx
|
|
y <- showInErr ty
|
|
dieP m $ msg ++ x ++ " and " ++ y
|
|
|
|
-- | Merges two lots of attributes into a union of the requirements
|
|
mergeAttr :: A.TypeRequirements -> A.TypeRequirements -> A.TypeRequirements
|
|
mergeAttr (A.TypeRequirements p) (A.TypeRequirements p') = A.TypeRequirements (p || p')
|
|
|
|
-- | Checks the attributes match a non-mutvar variable
|
|
checkAttrMatch :: A.TypeRequirements -> TypeExp A.Type -> PassM ()
|
|
checkAttrMatch (A.TypeRequirements False) _ = return () -- no need to check
|
|
checkAttrMatch (A.TypeRequirements True) (NumLit m _)
|
|
= dieP m "Numeric literal can never be poisonable"
|
|
checkAttrMatch (A.TypeRequirements True) t@(OperType m str _ _)
|
|
= case str of
|
|
"?" -> return ()
|
|
"!" -> return ()
|
|
_ -> do err <- showInErr t
|
|
dieP m $ "Type cannot be poisoned: " ++ err
|
|
|
|
-- | Reduces chains of MutVars down to just a single pointer.
|
|
prune :: Maybe A.TypeRequirements -> TypeExp A.Type -> PassM (TypeExp A.Type)
|
|
prune attr mv@(MutVar m r)
|
|
= do (attr', x) <- liftIO $ readIORef r
|
|
let merged = maybe attr' (mergeAttr attr') attr
|
|
case x of
|
|
Nothing -> do liftIO $ writeIORef r (merged, Nothing)
|
|
return mv
|
|
Just t2 ->
|
|
do t' <- prune (Just merged) t2
|
|
liftIO $ writeIORef r (merged, Just t')
|
|
return t'
|
|
prune Nothing t = return t
|
|
prune (Just attr) t = checkAttrMatch attr t >> return t
|
|
|
|
-- | Checks if the given pointer occurs in the given type, returning True if so.
|
|
-- Used to stop the type checker performing infinite loops around a type-cycle.
|
|
occursInType :: Ptr A.Type -> TypeExp A.Type -> PassM Bool
|
|
occursInType r t =
|
|
do t' <- prune Nothing t
|
|
case t' of
|
|
MutVar _ r2 -> return $ r == r2
|
|
GenVar _ n -> return False
|
|
OperType _ _ _ ts -> mapM (occursInType r) ts >>* or
|
|
|
|
-- | Unifies two types, giving an error if it's not possible
|
|
unifyType :: TypeExp A.Type -> TypeExp A.Type -> PassM ()
|
|
unifyType te1 te2
|
|
= do t1' <- prune Nothing te1
|
|
t2' <- prune Nothing te2
|
|
case (t1',t2') of
|
|
(MutVar _ r1, MutVar _ r2) ->
|
|
if r1 == r2
|
|
then return ()
|
|
else do attr <- liftIO $ readIORef r1 >>* fst
|
|
attr' <- liftIO $ readIORef r2 >>* fst
|
|
liftIO $ writeIORef r1 (mergeAttr attr attr', Just t2')
|
|
(MutVar m r1, _) ->
|
|
do b <- occursInType r1 t2'
|
|
if b
|
|
then dieP m "Infinitely recursive type formed"
|
|
else do attr <- liftIO $ readIORef r1 >>* fst
|
|
liftIO $ writeIORef r1 (attr, Just t2')
|
|
(_,MutVar {}) -> unifyType t2' t1'
|
|
(GenVar m x,GenVar _ y) ->
|
|
if x == y then return () else dieP m $ "different template variables"
|
|
++ " cannot be assumed to be equal"
|
|
(OperType m1 n1 _ ts1,OperType m2 n2 _ ts2) ->
|
|
if n1 == n2
|
|
then unifyArgs ts1 ts2
|
|
else giveErr m1 "Type cannot be matched: " t1' t2'
|
|
(NumLit m1 vns1, NumLit m2 vns2) ->
|
|
do nst1 <- liftIO $ readIORef vns1
|
|
nst2 <- liftIO $ readIORef vns2
|
|
case (nst1, nst2) of
|
|
(Right t1, Right t2) ->
|
|
if t1 /= t2
|
|
then dieP m1 "Numeric literals bound to different types"
|
|
else return ()
|
|
(Left ns1, Left ns2) ->
|
|
do liftIO $ writeIORef vns1 $ Left (ns1 ++ ns2)
|
|
liftIO $ writeIORef vns2 $ Left (ns2 ++ ns1)
|
|
(Right {}, Left {}) -> unifyType t2' t1'
|
|
(Left ns1, Right t2) ->
|
|
if all (willFit t2) (map snd ns1)
|
|
then liftIO $ writeIORef vns1 (Right t2)
|
|
else dieP m1 "Numeric literals will not fit in concrete type"
|
|
(OperType {}, NumLit {}) -> unifyType t2' t1'
|
|
(NumLit m1 vns1, OperType m2 n2 f ts2) ->
|
|
do nst1 <- liftIO $ readIORef vns1
|
|
case nst1 of
|
|
Right t ->
|
|
if null ts2 && t == f []
|
|
then return ()
|
|
else dieP m1 $ "numeric literal cannot be unified"
|
|
++ " with two different types"
|
|
Left ns ->
|
|
if null ts2
|
|
then if all (willFit $ f []) (map snd ns)
|
|
then liftIO $ writeIORef vns1 $ Right (f [])
|
|
else dieP m1 "Numeric literals will not fit in concrete type"
|
|
else dieP m1 $ "Numeric literal cannot be unified"
|
|
++ " with non-numeric type"
|
|
(t,_) -> dieP (findMeta t) "different types"
|
|
where
|
|
unifyArgs (x:xs) (y:ys) = unifyType x y >> unifyArgs xs ys
|
|
unifyArgs [] [] = return ()
|
|
unifyArgs xs ys = dieP (findMeta (xs,ys)) "different lengths"
|
|
|
|
instantiate :: Typeable a => [TypeExp a] -> TypeExp a -> TypeExp a
|
|
instantiate ts x = case x of
|
|
MutVar _ _ -> x
|
|
OperType m nm f xs -> OperType m nm f (map (instantiate ts) xs)
|
|
GenVar _ n -> ts !! n
|
|
|
|
-- | Checks if the given number will fit in the given type
|
|
willFit :: A.Type -> Integer -> Bool
|
|
willFit t n = case bounds t of
|
|
Just (l,h) -> l <= n && n <= h
|
|
_ -> False
|
|
where
|
|
unsigned, signed :: Int -> Maybe (Integer, Integer)
|
|
signed n = Just (negate $ 2 ^ (n - 1), (2 ^ (n - 1)) - 1)
|
|
unsigned n = Just (0, (2 ^ n) - 1)
|
|
|
|
bounds :: A.Type -> Maybe (Integer, Integer)
|
|
bounds A.Int8 = signed 8
|
|
bounds A.Int16 = signed 16
|
|
bounds A.Int32 = signed 32
|
|
bounds A.Int64 = signed 64
|
|
bounds A.Byte = unsigned 8
|
|
bounds A.UInt16 = unsigned 16
|
|
bounds A.UInt32 = unsigned 32
|
|
bounds A.UInt64 = unsigned 64
|
|
bounds _ = Nothing
|
|
|