factor out solve function in infer.rkt

This commit is contained in:
AlexKnauth 2016-06-13 10:29:03 -04:00
parent 34f969efba
commit c4ab4510ee

View File

@ -29,6 +29,35 @@
(define-primop not : ( Bool Bool))
(define-primop abs : ( Int Int))
(begin-for-syntax
;; solve : (Stx-Listof Id) (Stx-Listof Stx) (Stx-Listof Type-Stx)
;; -> (List Constraints (Listof (Stx-List Stx Type-Stx)))
;; Solves for the Xs by inferring the type of each arg and unifying it against
;; each corresponding expected-τ (which could have free Xs in them).
;; It returns list of 2 values if successful, else throws a type error
;; - the constraints for substituting the types
;; - a list containing of all the arguments paired with their types
(define (solve Xs args expected-τs)
(let-values
([(cs e+τs)
(for/fold ([cs #'()] [e+τs #'()])
([e_arg (syntax->list args)]
[τ_inX (syntax->list expected-τs)])
(define/with-syntax τs_solved (stx-map (λ (y) (lookup y cs)) Xs))
(cond
[(andmap syntax-e (syntax->list #'τs_solved)) ; all tyvars X have mapping
; TODO: substs is not properly transferring #%type property
; (stx-map displayln #'τs_solved)
(define e+τ (infer+erase #`(add-expected #,e_arg #,(substs #'τs_solved Xs τ_inX))))
; (displayln e+τ)
(values cs (cons e+τ e+τs))]
[else
(define/with-syntax [e τ] (infer+erase e_arg))
; (displayln #'(e τ))
(define cs* (add-constraints Xs cs #`([#,τ_inX τ])))
(values cs* (cons #'[e τ] e+τs))]))])
(list cs (reverse (stx->list e+τs))))))
(define-typed-syntax define
[(_ x:id e)
#:with (e- τ) (infer+erase #'e)
@ -146,25 +175,7 @@
(string-join (map ~a (syntax->datum #'(e_arg ...))) ", ")))
; #:with ([e_arg- τ_arg] ...) #'(infers+erase #'(e_arg ...))
#:with (cs ([e_arg- τ_arg] ...))
(let-values ([(cs e+τs)
(for/fold ([cs #'()] [e+τs #'()])
([e_arg (syntax->list #'(e_arg ...))]
[τ_inX (syntax->list #'(τ_inX ...))])
(define/with-syntax τs_solved (stx-map (λ (y) (lookup y cs)) #'(X ...)))
(cond
[(andmap syntax-e (syntax->list #'τs_solved)) ; all tyvars X have mapping
; TODO: substs is not properly transferring #%type property
; (stx-map displayln #'τs_solved)
(define e+τ (infer+erase #`(add-expected #,e_arg #,(substs #'τs_solved #'(X ...) τ_inX))))
; (displayln e+τ)
(values cs (cons e+τ e+τs))]
[else
(define/with-syntax [e τ] (infer+erase e_arg))
; (displayln #'(e τ))
(define cs* (add-constraints #'(X ...) cs #`([#,τ_inX τ])))
(values cs* (cons #'[e τ] e+τs))]))])
(define/with-syntax e+τs/stx e+τs)
(list cs (reverse (syntax->list #'e+τs/stx))))
(solve #'(X ...) #'(e_arg ...) #'(τ_inX ...))
#:with env (stx-flatten (filter (λ (x) x) (stx-map get-env #'(e_arg- ...))))
#:with (τ_in ... τ_out) (inst-types/cs #'(X ...) #'cs #'(τ_inX ... τ_outX))
; some code duplication