From f0646f300bfa0aec352be8a045554b2f22897b95 Mon Sep 17 00:00:00 2001 From: Stephen Chang Date: Thu, 8 Dec 2016 16:04:40 -0500 Subject: [PATCH] more synthcl3 code cleanup --- macrotypes/typecheck.rkt | 11 +- turnstile/examples/rosette/synthcl3.rkt | 220 +++++++++--------------- 2 files changed, 90 insertions(+), 141 deletions(-) diff --git a/macrotypes/typecheck.rkt b/macrotypes/typecheck.rkt index e4b0012..a2b57d4 100644 --- a/macrotypes/typecheck.rkt +++ b/macrotypes/typecheck.rkt @@ -798,9 +798,14 @@ (define (mk-tyvar X) (attach X 'tyvar #t)) (define (tyvar? X) (syntax-property X 'tyvar)) - (define type-pat "[A-Za-z]+") - - ;; TODO: remove this? only benefit is single control point for current-promote + (define type-pat "[A-Za-z]+")) + +(define-syntax (⊢m stx) + (syntax-parse stx #:datum-literals (:) + [(_ e : τ) (assign-type #`e #`τ)] + [(_ e τ) (assign-type #`e #`τ)])) + +(begin-for-syntax ;; - infers type of e ;; - checks that type of e matches the specified type ;; - erases types in e diff --git a/turnstile/examples/rosette/synthcl3.rkt b/turnstile/examples/rosette/synthcl3.rkt index 81fa7da..34e6359 100644 --- a/turnstile/examples/rosette/synthcl3.rkt +++ b/turnstile/examples/rosette/synthcl3.rkt @@ -73,7 +73,7 @@ (format "no implicit conversion from ~a to ~a" (type->str from) (type->str to)) expr subexpr))) (define (mk-ptr id) (format-id id "~a*" id)) - (define (mk-mk id) (format-id id "mk-~a" id)) + (define (mk-mk id [ctx id]) (format-id ctx "mk-~a" id)) (define (mk-to id) (format-id id "to-~a" id)) (define (add-construct stx fn) (set-stx-prop/preserved stx 'construct fn)) (define (add-convert stx fn) (set-stx-prop/preserved stx 'convert fn)) @@ -85,12 +85,13 @@ (define split-ty (ty->len t)) (string->number (or (and split-ty (third split-ty)) "1"))) + (define (uneval ty #:ctx [ctx #'here] #:str-fn [str-fn (λ (x) x)]) + (datum->syntax ctx (string->symbol (str-fn (type->str ty))))) (define (get-base/un ty [ctx #'here]) ; returns unexpanded base type - (datum->syntax ctx - (string->symbol (car (regexp-match #px"[a-z]+" (type->str ty)))))) + (uneval ty #:ctx ctx #:str-fn (λ (s) (car (regexp-match #px"[a-z]+" s))))) (define (get-base ty [ctx #'here]) ((current-type-eval) (get-base/un ty ctx))) (define (get-pointer-base ty [ctx #'here]) ; returns unexpanded ptr base - (datum->syntax ctx (string->symbol (string-trim (type->str ty) "*")))) + (uneval ty #:ctx ctx #:str-fn (λ (s) (string-trim s "*")))) (define (vector-type? ty) (define tstr (type->str ty)) (ormap (λ (x) (string=? x tstr)) '("int2" "int3" "int4" "int16" "float2" "float3" "float4" "float16"))) @@ -157,21 +158,15 @@ (define-typed-syntax synth-app [(_ (ty:type) e) ≫ ; cast [⊢ e ≫ e- ⇒ ty-e] - #:fail-unless (cast-ok? #'ty-e #'ty.norm #'e) - (format "cannot cast ~a to ~a" - (type->str #'ty-e) (type->str #'ty.norm)) + #:when (cast-ok? #'ty-e #'ty.norm #'e) ; raises exn #:with convert (get-convert #'ty.norm) - #:fail-unless (syntax-e #'convert) - (format "cannot cast ~a to ~a: conversion fn not found" - (type->str #'ty-e) (type->str #'ty.norm)) -------- [⊢ (ro:#%app convert e-) ⇒ ty.norm]] [(_ ty:type e ...) ≫ ; construct [⊢ e ≫ e- ⇒ ty-e] ... #:with construct (get-construct #'ty.norm) #:fail-unless (syntax-e #'construct) - (format "no constructor found for ~a type" - (type->str #'ty.norm)) + (format "no constructor found for ~a type" (type->str #'ty.norm)) -------- [⊢ (ro:#%app construct e- ...) ⇒ ty.norm]] [(_ p _) ≫ ; applying ptr to one arg is selector @@ -194,15 +189,13 @@ #:do [(define base-str (cadr split-ty)) (define len-str (caddr split-ty))] #:do [(define sels (length (stx->list #'selector)))] - #:with e-out (if (= sels 1) - #'(ro:vector-ref vec- (ro:car 'selector)) - #'(for/list ([idx 'selector]) - (ro:vector-ref vec- idx))) + #:with e-out (if (= sels 1) #'(ro:vector-ref vec- (ro:car 'selector)) + #'(for/list ([idx 'selector]) + (ro:vector-ref vec- idx))) #:with ty-out ((current-type-eval) - (if (= sels 1) - (format-id #'here "~a" base-str) - (format-id #'here "~a~a" - base-str (length (stx->list #'selector))))) + (if (= sels 1) (format-id #'here "~a" base-str) + (format-id #'here "~a~a" + base-str (length (stx->list #'selector))))) #:with convert (get-convert #'ty-out) -------- [⊢ (ro:#%app convert e-out) ⇒ ty-out]] @@ -212,14 +205,7 @@ #:when (stx-andmap cast-ok? #'(ty-e ...) #'(ty-in ...)) -------- [⊢ (ro:#%app f- e- ...) ⇒ ty-out]] - [(_ . es) ≫ - -------- - [≻ (rosette3:#%app . es)]]) - -(define-syntax (⊢m stx) - (syntax-parse stx #:datum-literals (:) - [(_ e : τ) (assign-type #`e #`τ)] - [(_ e τ) (assign-type #`e #`τ)])) + [(_ . es) ≫ --- [≻ (rosette3:#%app . es)]]) ;; top-level fns -------------------------------------------------- (define-typed-syntax procedure @@ -239,11 +225,12 @@ (rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...) (⊢m (ro:let () e ... (rosette3:ann e-body : ty-out)) ty-out)))) (provide- f))]]) -(define-typed-syntax kernel - [(_ ty-out:type (f [ty:type x:id] ...) e ...) ≫ - #:fail-unless (void? #'ty-out.norm) - (format "expected void, given ~a" (type->str #'ty-out.norm)) - --- [≻ (procedure void (f [ty x] ...) e ...)]]) + +(define-typed-syntax (kernel ty-out:type (f [ty:type x:id] ...) e ...) ≫ + #:fail-unless (void? #'ty-out.norm) + (format "expected void, given ~a" (type->str #'ty-out.norm)) + --- [≻ (procedure void (f [ty x] ...) e ...)]) + (define-typed-syntax grammar [(_ ty-out:type (f [ty:type x:id] ... [ty-depth k]) #:base be #:else ee) ≫ #:with f- (generate-temporary #'f) @@ -255,8 +242,7 @@ (define-typed-syntax f [(ff a ... j) ≫ [⊢ a ≫ _ ⇐ ty] ... - [⊢ j ≫ _ ⇐ ty-depth] - ;; j will be eval'ed, so strip its context + [⊢ j ≫ _ ⇐ ty-depth] ; j will be eval'ed, so strip its context #:with j- (assign-type (datum->syntax #'H (stx->datum #'j)) #'int) #:with f-- (replace-stx-loc #'f- #'ff) ----------- @@ -281,7 +267,7 @@ [(_ e-test {e1 ...} {e2 ...}) ≫ -------- [⊢ (ro:if (to-bool e-test) - (ro:let () e1 ... (ro:void)) + (ro:let () e1 ... (ro:void)) (ro:let () e2 ... (ro:void))) ⇒ void]] [(_ e-test es) ≫ --- [≻ (if e-test es {})]]) @@ -298,53 +284,44 @@ (with-ctx ([x x- ty] ...) (⊢m (ro:let () e ... (ro:void)) void)))) ⇒ void]]) - -;; need to redefine #%datum because rosette3:#%datum is too precise -(define-typed-syntax #%datum +(define-typed-syntax #%datum ; redefine bc rosette3:#%datum is too precise [(_ . b:boolean) ≫ --- [⊢ (ro:#%datum . b) ⇒ bool]] [(_ . s:str) ≫ --- [⊢ (ro:#%datum . s) ⇒ char*]] [(_ . n:integer) ≫ --- [⊢ (ro:#%datum . n) ⇒ int]] - [(#%datum . n:number) ≫ - #:when (real? (syntax-e #'n)) + [(_ . n:number) ≫ + #:when (real? (syntax-e #'n)) -------- [⊢ (ro:#%datum . n) ⇒ float]] [(_ . x) ≫ -------- [_ #:error (type-error #:src #'x #:msg "Unsupported literal: ~v" #'x)]]) - -;; : -------------------------------------------------- +;; : (var declaration) -------------------------------------------------- (define-typed-syntax : [(_ ty:type x:id ...) ≫ ; special String case #:when (rosette3:CString? #'ty.norm) #:with (x- ...) (generate-temporaries #'(x ...)) -------- - [≻ (begin- - (define-syntax- x - (make-rename-transformer (assign-type #'x- #'ty.norm))) ... - (ro:define x- (ro:#%datum . "")) ...)]] - ;; TODO: vector types need a better representation - ;; currently dissecting the identifier - ;; TODO: combine vector and scalar cases + [≻ (begin- (define-syntax- x + (make-rename-transformer (assign-type #'x- #'ty.norm))) ... + (ro:define x- (ro:#%datum . "")) ...)]] [(_ ty:type x:id ...) ≫ #:when (real-type? #'ty.norm) #:do [(define split-ty (ty->len #'ty))] #:when (and split-ty (= 3 (length split-ty))) - #:do [(define base-str (cadr split-ty)) + #:do [(define base-str (cadr split-ty)) (define len-str (caddr split-ty))] #:with ty-base (datum->syntax #'ty (string->symbol base-str)) #:with pred (get-pred ((current-type-eval) #'ty-base)) - #:fail-unless (syntax-e #'pred) - (format "no pred for ~a" (type->str #'ty)) + #:fail-unless (syntax-e #'pred) (format "no pred for ~a" (type->str #'ty)) #:with (x- ...) (generate-temporaries #'(x ...)) #:with (x-- ...) (generate-temporaries #'(x ...)) - #:with mk-ty (format-id #'here "mk-~a" #'ty) + #:with mk-ty (mk-mk #'ty #'here) -------- - [≻ (begin- - (ro:define-symbolic* x-- pred [#,(string->number len-str)]) ... - (ro:define x- (ro:apply mk-ty x--)) ... - (define-syntax- x - (make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]] + [≻ (begin- (ro:define-symbolic* x-- pred [#,(string->number len-str)]) ... + (ro:define x- (ro:apply mk-ty x--)) ... + (define-syntax- x + (make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]] [(_ ty:type [len] x:id ...) ≫ ; array of vector types #:when (real-type? #'ty.norm) [⊢ len ≫ len- ⇐ int] @@ -353,42 +330,37 @@ #:with ty* (format-id #'ty "~a*" #'ty) #:with to-ty* (format-id #'here "to-~a" #'ty*) #:with pred (get-pred ((current-type-eval) #'ty-base)) - #:fail-unless (syntax-e #'pred) - (format "no pred for ~a" (type->str #'ty)) + #:fail-unless (syntax-e #'pred) (format "no pred for ~a" (type->str #'ty)) #:with (x- ...) (generate-temporaries #'(x ...)) #:with (*x ...) (generate-temporaries #'(x ...)) #:with (x-- ...) (generate-temporaries #'(x ...)) #:with mk-ty (format-id #'here "mk-~a" #'ty) -------- - [≻ (begin- - (ro:define-symbolic* x-- pred [len base-len]) ... - (ro:define x- - (ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))]) - (ro:for ([i len][v x--]) + [≻ (begin- (ro:define-symbolic* x-- pred [len base-len]) ... + (ro:define x- + (ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))]) + (ro:for ([i len][v x--]) (cl:pointer-set! *x i (ro:apply mk-ty v))) *x)) ... - (define-syntax- x - (make-rename-transformer (assign-type #'x- #'ty*))) ...)]] + (define-syntax- x + (make-rename-transformer (assign-type #'x- #'ty*))) ...)]] ;; real, scalar (ie non-vector) types [(_ ty:type x:id ...) ≫ #:when (real-type? #'ty.norm) #:with pred (get-pred #'ty.norm) - #:fail-unless (syntax-e #'pred) - (format "no pred for ~a" (type->str #'ty)) + #:fail-unless (syntax-e #'pred) (format "no pred for ~a" (type->str #'ty)) #:with (x- ...) (generate-temporaries #'(x ...)) -------- - [≻ (begin- - (define-syntax- x - (make-rename-transformer (assign-type #'x- #'ty.norm))) ... - (ro:define-symbolic* x- pred) ...)]] + [≻ (begin- (define-syntax- x + (make-rename-transformer (assign-type #'x- #'ty.norm))) ... + (ro:define-symbolic* x- pred) ...)]] ;; else init to NULLs [(_ ty:type x:id ...) ≫ #:with (x- ...) (generate-temporaries #'(x ...)) -------- - [≻ (begin- - (define-syntax- x - (make-rename-transformer (assign-type #'x- #'ty.norm))) ... - (ro:define x- cl:NULL) ...)]]) + [≻ (begin- (define-syntax- x + (make-rename-transformer (assign-type #'x- #'ty.norm))) ... + (ro:define x- cl:NULL) ...)]]) ;; ?: -------------------------------------------------- (define-typed-syntax ?: @@ -406,20 +378,16 @@ #:with ty-base ((current-type-eval) (datum->syntax #'e (string->symbol out-base-str))) #:with base-convert (get-convert #'ty-base) ------- - [⊢ (convert - (ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)]) - (ro:for/list ([idx #,(string->number out-len-str)]) - (ro:if (ro:< (ro:vector-ref a idx) 0) - (base-convert (ro:vector-ref b idx)) - (base-convert (ro:vector-ref c idx)))))) - ⇒ ty-out]] + [⊢ (convert (ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)]) + (ro:for/list ([idx #,(string->number out-len-str)]) + (ro:if (ro:< (ro:vector-ref a idx) 0) + (base-convert (ro:vector-ref b idx)) + (base-convert (ro:vector-ref c idx)))))) ⇒ ty-out]] [(_ ~! e e1 e2) ≫ ; should be scalar and real [⊢ e ≫ e- ⇒ ty] - #:fail-unless (real-type? #'ty) - (format "not a real type: ~s has type ~a" - (syntax->datum #'e) (type->str #'ty)) - #:fail-unless (cast-ok? #'ty #'bool #'e) - (format "cannot cast ~a to bool" (type->str #'ty)) + #:fail-unless (real-type? #'ty) (format "not a real type: ~s has type ~a" + (syntax->datum #'e) (type->str #'ty)) + #:when (cast-ok? #'ty #'bool #'e) [⊢ e1 ≫ e1- ⇒ ty1] [⊢ e2 ≫ e2- ⇒ ty2] #:with ty-out ((current-join) #'ty1 #'ty2) @@ -465,19 +433,15 @@ -------- [⊢ (ro:#%app cl:! (to-bool e-)) ⇒ bool]]) -;; TODO: comparison ops need to support vec types (and result) -(define-simple-macro (mk-cmp cmp-op) +;TODO: cmps should produce vec int result with same length as comm-real-ty +(define-simple-macro (mk-cmp cmp-op) (define-typed-syntax cmp-op [(o e1 e2) ≫ [⊢ e1 ≫ e1- ⇒ ty1] [⊢ e2 ≫ e2- ⇒ ty2] - #:when (real-type? #'ty1) - #:when (real-type? #'ty2) - #:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len - #:with conv (get-convert #'ty-out) - #:with o- (mk-cl #'o) + #:with conv (get-convert ((current-join) #'ty1 #'ty2)) -------- - [⊢ (to-int (o- (conv e1-) (conv e2-))) ⇒ int]])) + [⊢ (to-int (#,(mk-cl #'o) (conv e1-) (conv e2-))) ⇒ int]])) (define-simple-macro (mk-cmps o ...) (begin- (mk-cmp o) ...)) (mk-cmps == < <= > >= !=) @@ -490,17 +454,11 @@ [⊢ e2 ≫ e2- ⇐ bool] -------- [⊢ (name- e1- e2-) ⇒ bool]] - [(_ e1 e2) ≫ ; else try to coerce - --- [⊢ (name- (to-bool e1) (to-bool e2)) ⇒ bool]])) - -#;(define- (cl:/ x y) - (cond- [(zero?- y) 0] - [(integer?- x) (quotient- x y)] - [else (/- x y)])) + [(_ e1 e2) ≫ --- [⊢ (name- (to-bool e1) (to-bool e2)) ⇒ bool]])) ; coerce (define-simple-macro (define-real-ops o ...) (ro:begin (define-real-op o) ...)) (define-simple-macro (define-real-op name (~optional (~seq #:extra-check p?) - #:defaults ([p? #'(λ _ #t)]))) + #:defaults ([p? #'(λ _ #t)]))) #:with name- (mk-cl #'name) #:with name= (format-id #'name "~a=" #'name) ; assignment form (begin- @@ -521,14 +479,13 @@ #'(convert (ro:let ([x (convert e-)] (... ...)) (ro:for/list ([x x] (... ...)) (base-convert (name- x (... ...))))))) ⇒ ty-out]) - (define-typed-syntax (name= x e) ≫ - --- [≻ (= x (name x e))]))) + (define-typed-syntax (name= x e) ≫ --- [≻ (= x (name x e))]))) (define-for-syntax (int? t givens) (or (typecheck/un? t #'int) - (raise-syntax-error #f - (format "no common integer type for operands; given ~a" - (types->str givens))))) + (raise-syntax-error #f + (format "no common integer type for operands; given ~a" + (types->str givens))))) (define-simple-macro (define-int-op o) (define-real-op o #:extra-check int?)) (define-simple-macro (define-int-ops o ...) (ro:begin (define-int-op o) ...)) @@ -536,41 +493,34 @@ (define-real-ops + * - / sqrt) (define-int-ops % << $ &) -(define-typerule (sizeof t:type) >> ---[⊢ #,(real-type-length #'t.norm) ⇒ int]) -(define-typerule (print e ...) >> ---[⊢ (ro:begin (display e) ...) ⇒ void]) +(define-typerule (sizeof t:type) ≫ --- [⊢ #,(real-type-length #'t.norm) ⇒ int]) +(define-typerule (print e ...) ≫ --- [⊢ (ro:begin (display e) ...) ⇒ void]) +(define-typerule (assert e) ≫ --- [⊢ (ro:assert (to-bool e)) ⇒ void]) +(define-typerule (clCreateProgramWithSource ctx f) ≫ + --- [⊢ (cl:clCreateProgramWithSource ctx f) ⇒ cl_program]) (define-typed-syntax choose [(ch e ...+) ≫ #:with (e- ...) (stx-map expand/ro #'(e ...)) #:with (ty ...) (stx-map typeof #'(e- ...)) #:when (same-types? #'(ty ...)) - #:with (e/disarmed ...) (stx-map replace-stx-loc #'(e- ...) #'(e ...)) - ;; the #'choose identifier itself must have the location of its use - ;; see define-synthax implementation, specifically syntax/source in utils - #:with ch/disarmed (replace-stx-loc #'ro:choose #'ch) + #:with ch- (replace-stx-loc #'ro:choose #'ch) -------- - [⊢ (ch/disarmed e/disarmed ...) ⇒ #,(stx-car #'(ty ...))]]) + [⊢ (ch- e- ...) ⇒ #,(stx-car #'(ty ...))]]) (define-typed-syntax ?? - [(qq) ≫ - #:with ??/progsrc (replace-stx-loc #'cl:?? #'qq) - -------- - [⊢ (??/progsrc) ⇒ int]] [(qq ty:type) ≫ - #:with ??/progsrc (replace-stx-loc #'cl:?? #'qq) - #:with t (datum->syntax #'here (string->symbol (type->str #'ty.norm))) - #:with cl-t (mk-cl #'t) - ;; #:with ty-base (get-base #'ty.norm) - ;; #:with pred (get-pred ((current-type-eval) #'ty-base)) + #:with qq- (replace-stx-loc #'cl:?? #'qq) + #:with cl-t (mk-cl (uneval #'ty.norm)) -------- - [⊢ (??/progsrc cl-t) ⇒ ty]]) + [⊢ (qq- cl-t) ⇒ ty.norm]] + [(qq) ≫ --- [⊢ (#,(replace-stx-loc #'cl:?? #'qq)) ⇒ int]]) (define-typed-syntax (@ x:id) ≫ [⊢ x ≫ x- ⇒ ty+] ;; TODO: check ty = real, non-ptr type - #:with ty (datum->syntax #'x (string->symbol (type->str #'ty+))) - #:with cl-ty (mk-cl #'ty) + #:with ty (uneval #'ty+) --------- - [⊢ (cl:address-of x- cl-ty) ⇒ #,(mk-ptr #'ty)]) + [⊢ (cl:address-of x- #,(mk-cl #'ty)) ⇒ #,(mk-ptr #'ty)]) (define-typed-syntax locally-scoped [(_ e ...) ⇐ ty ≫ --- [⊢ (ro:let () e ...)]] @@ -603,7 +553,7 @@ (ro:synthesize #:forall (ro:append tmp ...) #:guarantee (ro:for ([id- tmp] ...) - (with-ctx ([id id- ty] ...) e)))))) ⇒ void]] + (with-ctx ([id id- ty] ...) e)))))) ⇒ void]] [(_ #:forall [decl ...] #:ensure e) ≫ --- [≻ (synth #:forall [decl ...] #:bitwidth 8 #:ensure e)]]) @@ -624,9 +574,3 @@ (printf "~a = ~a\n" i (ro:evaluate i- cex))) cex)) (begin (displayln "no counterexample found") (ro:unsat))))) ⇒ void]]) - -(define-typed-syntax (assert e) ≫ - --- [⊢ (ro:assert (to-bool #,(expand/ro #'e))) ⇒ void]) - -(define-typed-syntax (clCreateProgramWithSource ctx f) ≫ - --- [⊢ (cl:clCreateProgramWithSource ctx f) ⇒ cl_program])