Added more tests for type unification, fixed some bugs and added cleaner error handling

This commit is contained in:
Neil Brown 2008-05-14 09:52:16 +00:00
parent e3fa3df623
commit 9f1d65f4a0
2 changed files with 56 additions and 28 deletions

View File

@ -483,20 +483,30 @@ checkExpressionTest = TestList
testUnify :: Test testUnify :: Test
testUnify = TestList testUnify = TestList
[test [] [] $ Just [] [pass [] [] []
,test' [("a",A.Int)] [] ,pass' [("a",A.Int)] []
,test' [("a",A.Int)] [("a","a")] ,pass' [("a",A.Int)] [("a","a")]
,test [("a", A.Int), ("b", A.Infer)] [("a","b")] $ ,pass [("a", A.Int), ("b", A.Infer)] [("a","b")]
Just [("a", A.Int), ("b", A.Int)] [("a", A.Int), ("b", A.Int)]
,pass [("a", A.List A.Int), ("b", A.List A.Infer)] [("a","b")]
[("a", A.List A.Int), ("b", A.List A.Int)]
,fail [("a", A.Int), ("b", A.List A.Infer)] [("a","b")]
,fail [("a", A.Infer)] []
,fail [("a", A.Infer), ("b", A.Infer)] [("a","b")]
] ]
where where
test :: [(String, A.Type)] -> [(String, String)] -> Maybe [(String, A.Type)] pass :: [(String, A.Type)] -> [(String, String)] -> [(String, A.Type)]
-> Test -> Test
test im u om = TestCase $ assertEqual "testUnify" (fmap Map.fromList om) $ Just $ unifyRainTypes (Map.fromList pass im u om = TestCase $ assertEqual "testUnify" (Right $ Map.fromList om)
im) u $ unifyRainTypes (Map.fromList im) u
test' :: [(String, A.Type)] -> [(String, String)] -> Test fail :: [(String, A.Type)] -> [(String, String)] -> Test
test' x y = test x y (Just x) fail im u = TestCase $ case unifyRainTypes (Map.fromList im) u of
Left _ -> return ()
Right om -> assertEqual "testUnify" Nothing $ Just om
pass' :: [(String, A.Type)] -> [(String, String)] -> Test
pass' x y = pass x y x
tests :: Test tests :: Test

View File

@ -18,21 +18,31 @@ with this program. If not, see <http://www.gnu.org/licenses/>.
module TypeUnification where module TypeUnification where
import Control.Monad
import Control.Monad.ST import Control.Monad.ST
import Data.Generics import Data.Generics
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Maybe
import Data.STRef import Data.STRef
import qualified AST as A import qualified AST as A
import Utils import Utils
foldCon :: Constr -> [Either String A.Type] -> Either String A.Type
foldCon con [] = Right $ fromConstr con
foldCon con [Left e] = Left e
foldCon con [Right t] = Right $ fromConstrB (fromJust $ cast t) con
foldCon con _ = Left "foldCon: too many arguments given"
-- Much of the code in this module is taken from or based on Tim Sheard's Haskell -- Much of the code in this module is taken from or based on Tim Sheard's Haskell
-- listing of a simple type unification algorithm at the beginning of his -- listing of a simple type unification algorithm at the beginning of his
-- paper "Generic Unification via Two-Level Types and Parameterized Modules Functional -- paper "Generic Unification via Two-Level Types and Parameterized Modules Functional
-- Pearl (2001)", citeseer: http://citeseer.ist.psu.edu/451401.html -- Pearl (2001)", citeseer: http://citeseer.ist.psu.edu/451401.html
-- This in turn was taken from Luca Cardelli's "Basic Polymorphic Type Checking" -- This in turn was taken from Luca Cardelli's "Basic Polymorphic Type Checking"
unifyRainTypes :: Map.Map String A.Type -> [(String, String)] -> Map.Map String A.Type unifyRainTypes :: Map.Map String A.Type -> [(String, String)] -> Either String
(Map.Map String A.Type)
unifyRainTypes m prs unifyRainTypes m prs
= runST $ do m' <- mapToST m = runST $ do m' <- mapToST m
mapM_ (\(x,y) -> unifyType (lookupStartType x m') (lookupStartType y mapM_ (\(x,y) -> unifyType (lookupStartType x m') (lookupStartType y
@ -49,14 +59,22 @@ unifyRainTypes m prs
mapToST :: Map.Map String A.Type -> ST s (Map.Map String (TypeExp s A.Type)) mapToST :: Map.Map String A.Type -> ST s (Map.Map String (TypeExp s A.Type))
mapToST = mapMapM typeToTypeExp mapToST = mapMapM typeToTypeExp
stToMap :: Map.Map String (TypeExp s A.Type) -> ST s (Map.Map String stToMap :: Map.Map String (TypeExp s A.Type) -> ST s (Either String (Map.Map String
A.Type) A.Type))
stToMap = mapMapM (read <.< prune) stToMap m = do m' <- mapMapM (read <.< prune) m
let (mapOfErrs, mapOfRes) = Map.mapEitherWithKey (const id) m'
case Map.elems mapOfErrs of
(e:_) -> return $ Left e
[] -> return $ Right mapOfRes
where where
read :: TypeExp s A.Type -> ST s A.Type read :: TypeExp s A.Type -> ST s (Either String A.Type)
read (OperType con vals) = do vals' <- mapM read vals read (OperType con vals) = do vals' <- mapM read vals
return $ fromConstr con -- vals' return $ foldCon con vals'
read x = error $ "Type error in unification, found: " ++ show x read (MutVar v) = readSTRef v >>= \t -> case t of
Nothing -> return $ Left $ "Type error in unification, found non-unified type"
Just t' -> read t'
read x = return $ Left $ "Type error in unification, found: " ++ show x
++ " in: " ++ show m
ttte :: Data b => b -> A.Type -> ST s (TypeExp s A.Type) ttte :: Data b => b -> A.Type -> ST s (TypeExp s A.Type)
ttte c t = typeToTypeExp t >>= \t' -> return $ OperType (toConstr c) [t'] ttte c t = typeToTypeExp t >>= \t' -> return $ OperType (toConstr c) [t']
@ -88,7 +106,7 @@ data TypeExp s a
instance Show (TypeExp s a) where instance Show (TypeExp s a) where
show (MutVar {}) = "MutVar" show (MutVar {}) = "MutVar"
show (GenVar {}) = "GenVar" show (GenVar {}) = "GenVar"
show (OperType {}) = "OperType" show (OperType _ ts) = "OperType " ++ show ts
prune :: TypeExp s a -> ST s (TypeExp s a) prune :: TypeExp s a -> ST s (TypeExp s a)
prune t = prune t =
@ -113,33 +131,33 @@ occursInType r t =
do bs <- mapM (occursInType r) ts do bs <- mapM (occursInType r) ts
return (or bs) return (or bs)
unifyType :: TypeExp s a -> TypeExp s a -> ST s () unifyType :: TypeExp s a -> TypeExp s a -> ST s (Either String ())
unifyType t1 t2 unifyType t1 t2
= do t1' <- prune t1 = do t1' <- prune t1
t2' <- prune t2 t2' <- prune t2
case (t1',t2') of case (t1',t2') of
(MutVar r1, MutVar r2) -> (MutVar r1, MutVar r2) ->
if r1 == r2 if r1 == r2
then return () then return $ Right ()
else writeSTRef r1 (Just t2') else liftM Right $ writeSTRef r1 (Just t2')
(MutVar r1, _) -> (MutVar r1, _) ->
do b <- occursInType r1 t2' do b <- occursInType r1 t2'
if b if b
then error "occurs in" then return $ Left "occurs in"
else writeSTRef r1 (Just t2') else liftM Right $ writeSTRef r1 (Just t2')
(_,MutVar _) -> unifyType t2' t1' (_,MutVar _) -> unifyType t2' t1'
(GenVar n,GenVar m) -> (GenVar n,GenVar m) ->
if n == m then return () else error "different genvars" if n == m then return $ Right () else return $ Left "different genvars"
(OperType n1 ts1,OperType n2 ts2) -> (OperType n1 ts1,OperType n2 ts2) ->
if n1 == n2 if n1 == n2
then unifyArgs ts1 ts2 then unifyArgs ts1 ts2
else error "different constructors" else return $ Left "different constructors"
(_,_) -> error "different types" (_,_) -> return $ Left "different types"
where where
unifyArgs (x:xs) (y:ys) = do unifyType x y unifyArgs (x:xs) (y:ys) = do unifyType x y
unifyArgs xs ys unifyArgs xs ys
unifyArgs [] [] = return () unifyArgs [] [] = return $ Right ()
unifyArgs _ _ = error "different lengths" unifyArgs _ _ = return $ Left "different lengths"
instantiate :: [TypeExp s a] -> TypeExp s a -> TypeExp s a instantiate :: [TypeExp s a] -> TypeExp s a -> TypeExp s a
instantiate ts x = case x of instantiate ts x = case x of