diff --git a/frontends/RainTypesTest.hs b/frontends/RainTypesTest.hs index f01eb57..29008ac 100644 --- a/frontends/RainTypesTest.hs +++ b/frontends/RainTypesTest.hs @@ -483,20 +483,30 @@ checkExpressionTest = TestList testUnify :: Test testUnify = TestList - [test [] [] $ Just [] - ,test' [("a",A.Int)] [] - ,test' [("a",A.Int)] [("a","a")] - ,test [("a", A.Int), ("b", A.Infer)] [("a","b")] $ - Just [("a", A.Int), ("b", A.Int)] + [pass [] [] [] + ,pass' [("a",A.Int)] [] + ,pass' [("a",A.Int)] [("a","a")] + ,pass [("a", A.Int), ("b", A.Infer)] [("a","b")] + [("a", A.Int), ("b", A.Int)] + ,pass [("a", A.List A.Int), ("b", A.List A.Infer)] [("a","b")] + [("a", A.List A.Int), ("b", A.List A.Int)] + ,fail [("a", A.Int), ("b", A.List A.Infer)] [("a","b")] + ,fail [("a", A.Infer)] [] + ,fail [("a", A.Infer), ("b", A.Infer)] [("a","b")] ] where - test :: [(String, A.Type)] -> [(String, String)] -> Maybe [(String, A.Type)] + pass :: [(String, A.Type)] -> [(String, String)] -> [(String, A.Type)] -> Test - test im u om = TestCase $ assertEqual "testUnify" (fmap Map.fromList om) $ Just $ unifyRainTypes (Map.fromList - im) u + pass im u om = TestCase $ assertEqual "testUnify" (Right $ Map.fromList om) + $ unifyRainTypes (Map.fromList im) u - test' :: [(String, A.Type)] -> [(String, String)] -> Test - test' x y = test x y (Just x) + fail :: [(String, A.Type)] -> [(String, String)] -> Test + fail im u = TestCase $ case unifyRainTypes (Map.fromList im) u of + Left _ -> return () + Right om -> assertEqual "testUnify" Nothing $ Just om + + pass' :: [(String, A.Type)] -> [(String, String)] -> Test + pass' x y = pass x y x tests :: Test diff --git a/frontends/TypeUnification.hs b/frontends/TypeUnification.hs index 7f269c4..c985dd5 100644 --- a/frontends/TypeUnification.hs +++ b/frontends/TypeUnification.hs @@ -18,21 +18,31 @@ with this program. If not, see . module TypeUnification where +import Control.Monad import Control.Monad.ST import Data.Generics import qualified Data.Map as Map +import Data.Maybe import Data.STRef import qualified AST as A import Utils +foldCon :: Constr -> [Either String A.Type] -> Either String A.Type +foldCon con [] = Right $ fromConstr con +foldCon con [Left e] = Left e +foldCon con [Right t] = Right $ fromConstrB (fromJust $ cast t) con +foldCon con _ = Left "foldCon: too many arguments given" + + -- Much of the code in this module is taken from or based on Tim Sheard's Haskell -- listing of a simple type unification algorithm at the beginning of his -- paper "Generic Unification via Two-Level Types and Parameterized Modules Functional -- Pearl (2001)", citeseer: http://citeseer.ist.psu.edu/451401.html -- This in turn was taken from Luca Cardelli's "Basic Polymorphic Type Checking" -unifyRainTypes :: Map.Map String A.Type -> [(String, String)] -> Map.Map String A.Type +unifyRainTypes :: Map.Map String A.Type -> [(String, String)] -> Either String + (Map.Map String A.Type) unifyRainTypes m prs = runST $ do m' <- mapToST m mapM_ (\(x,y) -> unifyType (lookupStartType x m') (lookupStartType y @@ -49,14 +59,22 @@ unifyRainTypes m prs mapToST :: Map.Map String A.Type -> ST s (Map.Map String (TypeExp s A.Type)) mapToST = mapMapM typeToTypeExp - stToMap :: Map.Map String (TypeExp s A.Type) -> ST s (Map.Map String - A.Type) - stToMap = mapMapM (read <.< prune) + stToMap :: Map.Map String (TypeExp s A.Type) -> ST s (Either String (Map.Map String + A.Type)) + stToMap m = do m' <- mapMapM (read <.< prune) m + let (mapOfErrs, mapOfRes) = Map.mapEitherWithKey (const id) m' + case Map.elems mapOfErrs of + (e:_) -> return $ Left e + [] -> return $ Right mapOfRes where - read :: TypeExp s A.Type -> ST s A.Type + read :: TypeExp s A.Type -> ST s (Either String A.Type) read (OperType con vals) = do vals' <- mapM read vals - return $ fromConstr con -- vals' - read x = error $ "Type error in unification, found: " ++ show x + return $ foldCon con vals' + read (MutVar v) = readSTRef v >>= \t -> case t of + Nothing -> return $ Left $ "Type error in unification, found non-unified type" + Just t' -> read t' + read x = return $ Left $ "Type error in unification, found: " ++ show x + ++ " in: " ++ show m ttte :: Data b => b -> A.Type -> ST s (TypeExp s A.Type) ttte c t = typeToTypeExp t >>= \t' -> return $ OperType (toConstr c) [t'] @@ -88,7 +106,7 @@ data TypeExp s a instance Show (TypeExp s a) where show (MutVar {}) = "MutVar" show (GenVar {}) = "GenVar" - show (OperType {}) = "OperType" + show (OperType _ ts) = "OperType " ++ show ts prune :: TypeExp s a -> ST s (TypeExp s a) prune t = @@ -113,33 +131,33 @@ occursInType r t = do bs <- mapM (occursInType r) ts return (or bs) -unifyType :: TypeExp s a -> TypeExp s a -> ST s () +unifyType :: TypeExp s a -> TypeExp s a -> ST s (Either String ()) unifyType t1 t2 = do t1' <- prune t1 t2' <- prune t2 case (t1',t2') of (MutVar r1, MutVar r2) -> if r1 == r2 - then return () - else writeSTRef r1 (Just t2') + then return $ Right () + else liftM Right $ writeSTRef r1 (Just t2') (MutVar r1, _) -> do b <- occursInType r1 t2' if b - then error "occurs in" - else writeSTRef r1 (Just t2') + then return $ Left "occurs in" + else liftM Right $ writeSTRef r1 (Just t2') (_,MutVar _) -> unifyType t2' t1' (GenVar n,GenVar m) -> - if n == m then return () else error "different genvars" + if n == m then return $ Right () else return $ Left "different genvars" (OperType n1 ts1,OperType n2 ts2) -> if n1 == n2 then unifyArgs ts1 ts2 - else error "different constructors" - (_,_) -> error "different types" + else return $ Left "different constructors" + (_,_) -> return $ Left "different types" where unifyArgs (x:xs) (y:ys) = do unifyType x y unifyArgs xs ys - unifyArgs [] [] = return () - unifyArgs _ _ = error "different lengths" + unifyArgs [] [] = return $ Right () + unifyArgs _ _ = return $ Left "different lengths" instantiate :: [TypeExp s a] -> TypeExp s a -> TypeExp s a instantiate ts x = case x of