From d20a238961fb251cddcf4f275d4afe97edfeb78c Mon Sep 17 00:00:00 2001 From: Stephen Chang Date: Tue, 15 Nov 2016 15:00:39 -0500 Subject: [PATCH] synthcl3: add macros to abstract definition of common forms and types --- turnstile/examples/rosette/synthcl3.rkt | 333 +++++++++--------------- 1 file changed, 119 insertions(+), 214 deletions(-) diff --git a/turnstile/examples/rosette/synthcl3.rkt b/turnstile/examples/rosette/synthcl3.rkt index be0f1c7..2b402eb 100644 --- a/turnstile/examples/rosette/synthcl3.rkt +++ b/turnstile/examples/rosette/synthcl3.rkt @@ -3,7 +3,6 @@ (require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped racket/stxparam (prefix-in ro: rosette) -; define-symbolic* #%datum define if ! = number? boolean? cond ||)) (prefix-in cl: sdsl/synthcl/lang/forms) (prefix-in cl: sdsl/synthcl/model/reals) (prefix-in cl: sdsl/synthcl/model/operators) @@ -18,41 +17,24 @@ (define (mk-cl id) (format-id id "cl:~a" id)) (current-host-lang mk-cl)) -;(define-base-types) - -(provide (rename-out - [synth-app #%app] -; [rosette3:Bool bool] ; symbolic -; [rosette3:Int int] ; symbolic -; [rosette3:Num float] ; symbolic - #;[rosette3:CString char*]) ; always concrete - procedure kernel - #%datum if range for - bool int int2 int3 int4 float float3 int16 void - void* char* int* int16* - : ! ?: == + % - ;; assignment ops - = += %= +(provide (rename-out [synth-app #%app]) + procedure kernel #%datum if range for + int int2 int3 int4 int16 float float2 float3 float4 float16 + bool void void* char* int* int16* + : ! ?: == + % || && + = += %= ; assignment ops (typed-out - ;; need the concrete cases for literals; - ;; alternative is to redefine #%datum to give literals symbolic type [malloc : (C→ int void*)] [get_work_dim : (C→ int)] - #;[% : (Ccase-> (C→ CInt CInt CInt) - (C→ Int Int Int))] + [!= : (Ccase-> (C→ CNum CNum CBool) (C→ CNum CNum CNum CBool) (C→ Num Num Bool) (C→ Num Num Num Bool))] - [NULL : void*] - #;[== : (Ccase-> (C→ CNum CNum CBool) - (C→ CNum CNum CNum CBool) - (C→ Num Num Bool) - (C→ Num Num Num Bool))])) + [NULL : void*])) (begin-for-syntax ;; TODO: use equality type relation instead of subtype ;; - require reimplementing many more things, eg #%datum, +, etc -; (current-typecheck-relation (current-type=?)) ;; typecheck unexpanded types (define (typecheck/un? t1 t2) (typecheck? ((current-type-eval) t1) @@ -107,10 +89,7 @@ (format "no implicit conversion from ~a to ~a" (type->str from) (type->str to) #;(if (contract? to) (contract-name to) to)) - expr subexpr)) - #;(or (typecheck/un? from to) ; from == to - (and (real-type? from) - (typecheck/un? to #'bool)))) + expr subexpr))) (define (add-convert stx fn) (set-stx-prop/preserved stx 'convert fn)) (define (get-convert stx) @@ -137,6 +116,7 @@ [(_ stx fn) (add-construct #'stx #'fn)]) ;; TODO: reuse impls in model/reals.rkt ? + (ro:define (to-bool v) (ro:cond [(ro:boolean? v) v] @@ -153,65 +133,7 @@ [(ro:fixnum? v) (ro:exact->inexact v)] [(ro:flonum? v) v] [else (ro:type-cast ro:real? v)])) -(ro:define (to-float3 v) - (ro:cond - [(ro:list? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i 3]) (to-float (ro:list-ref v i))))] - [(ro:vector? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i 3]) (to-float (ro:vector-ref v i))))] - [else (ro:apply ro:vector-immutable (ro:make-list 3 (to-float v)))])) -(ro:define (to-int2 v) - (ro:cond - [(ro:list? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i 2]) (to-int (ro:list-ref v i))))] - [(ro:vector? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i 2]) (to-int (ro:vector-ref v i))))] - [else - (ro:apply ro:vector-immutable - (ro:make-list 2 (to-int v)))])) -(ro:define (to-int3 v) - (ro:cond - [(ro:list? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i 3]) (to-int (ro:list-ref v i))))] - [(ro:vector? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i 3]) (to-int (ro:vector-ref v i))))] - [else - (ro:apply ro:vector-immutable - (ro:make-list 3 (to-int v)))])) -(ro:define (to-int4 v) - (ro:cond - [(ro:list? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i 4]) (to-int (ro:list-ref v i))))] - [(ro:vector? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i 4]) (to-int (ro:vector-ref v i))))] - [else - (ro:apply ro:vector-immutable - (ro:make-list 4 (to-int v)))])) -(ro:define (mk-int2 x y) - (ro:#%app ro:vector-immutable (to-int x) (to-int y))) -(ro:define (mk-int3 x y z) - (ro:#%app ro:vector-immutable (to-int x) (to-int y) (to-int z))) -(ro:define (mk-int4 w x y z) - (ro:#%app ro:vector-immutable (to-int w) (to-int x) (to-int y) (to-int z))) -(ro:define (to-int16 v) - (ro:cond - [(ro:list? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i 16]) (to-int (ro:list-ref v i))))] - [(ro:vector? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i 16]) (to-int (ro:vector-ref v i))))] - [else - (ro:apply ro:vector-immutable - (ro:make-list 16 (to-int v)))])) + (ro:define (to-int16* v) (cl:pointer-cast v cl:int16)) @@ -223,35 +145,73 @@ (add-convertm rosette3:Num to-float)) (define-named-type-alias char* rosette3:CString) -(define-named-type-alias float3 - (add-convertm - (rosette3:CVector rosette3:Num rosette3:Num rosette3:Num) - to-float3)) -(define-named-type-alias int2 - (add-constructm - (add-convertm - (rosette3:CVector rosette3:Int rosette3:Int) - to-int2) - mk-int2)) -(define-named-type-alias int3 - (add-constructm - (add-convertm - (rosette3:CVector rosette3:Int rosette3:Int rosette3:Int) - to-int3) - mk-int3)) -(define-named-type-alias int4 - (add-constructm - (add-convertm - (rosette3:CVector rosette3:Int rosette3:Int rosette3:Int rosette3:Int) - to-int4) - mk-int4)) -(define-named-type-alias int16 - (add-convertm - (rosette3:CVector rosette3:Int rosette3:Int rosette3:Int rosette3:Int - rosette3:Int rosette3:Int rosette3:Int rosette3:Int - rosette3:Int rosette3:Int rosette3:Int rosette3:Int - rosette3:Int rosette3:Int rosette3:Int rosette3:Int) - to-int16)) +(define-syntax (define-int stx) + (syntax-parse stx + [(_ n) + #:with intn (format-id #'n "int~a" (syntax->datum #'n)) + #:with to-intn (format-id #'n "to-~a" #'intn) + #:with mk-intn (format-id #'n "mk-~a" #'intn) + #:with (x ...) (generate-temporaries + (build-list (syntax->datum #'n) (lambda (x) x))) + #:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...)) + #'(begin + (define-named-type-alias intn + (add-constructm + (add-convertm + (rosette3:CVector I ...) + to-intn) + mk-intn)) + (ro:define (to-intn v) + (ro:cond + [(ro:list? v) + (ro:apply ro:vector-immutable + (ro:for/list ([i n]) (to-int (ro:list-ref v i))))] + [(ro:vector? v) + (ro:apply ro:vector-immutable + (ro:for/list ([i n]) (to-int (ro:vector-ref v i))))] + [else + (ro:apply ro:vector-immutable + (ro:make-list n (to-int v)))])) + (ro:define (mk-intn x ...) + (ro:#%app ro:vector-immutable (to-int x) ...)) + )])) +(define-simple-macro (define-ints n ...) (begin (define-int n) ...)) +(define-ints 2 3 4 16) + +(define-syntax (define-float stx) + (syntax-parse stx + [(_ n) + #:with floatn (format-id #'n "float~a" (syntax->datum #'n)) + #:with to-floatn (format-id #'n "to-~a" #'floatn) + #:with mk-floatn (format-id #'n "mk-~a" #'floatn) + #:with (x ...) (generate-temporaries + (build-list (syntax->datum #'n) (lambda (x) x))) + #:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...)) + #'(begin + (define-named-type-alias floatn + (add-constructm + (add-convertm + (rosette3:CVector I ...) + to-floatn) + mk-floatn)) + (ro:define (to-floatn v) + (ro:cond + [(ro:list? v) + (ro:apply ro:vector-immutable + (ro:for/list ([i n]) (to-float (ro:list-ref v i))))] + [(ro:vector? v) + (ro:apply ro:vector-immutable + (ro:for/list ([i n]) (to-float (ro:vector-ref v i))))] + [else + (ro:apply ro:vector-immutable + (ro:make-list n (to-float v)))])) + (ro:define (mk-floatn x ...) + (ro:#%app ro:vector-immutable (to-float x) ...)) + )])) +(define-simple-macro (define-floats n ...) (begin (define-float n) ...)) +(define-floats 2 3 4 16) + + (define-type-constructor Pointer #:arity = 1) (define-named-type-alias void rosette3:CUnit) (define-named-type-alias void* (Pointer rosette3:CUnit)) @@ -274,9 +234,6 @@ [⊢ (ro:#%app convert e-) ⇒ ty.norm]] [(_ ty:type e ...) ≫ ; construct [⊢ e ≫ e- ⇒ ty-e] ... - ;; #:fail-unless (cast-ok? #'ty-e #'ty.norm #'e) - ;; (format "cannot cast ~a to ~a" - ;; (type->str #'ty-e) (type->str #'ty.norm)) #:with construct (get-construct #'ty.norm) #:fail-unless (syntax-e #'construct) (format "no constructor found for ~a type" @@ -294,7 +251,6 @@ #:with ty-out ((current-type-eval) (format-id #'here "~a~a" base-str len-str)) -; #:with convert (get-convert #'ty-out) -------- [⊢ e-out ⇒ ty-out]] [(_ vec sel) ≫ ; applying vector to one arg is selector @@ -342,20 +298,18 @@ -------- [≻ #,(if (typecheck/un? #'ty-out #'void) #'(rosette3:define (f [x col ty] ... arr ty-out) - ;; TODO: this is deviating from rosette's impl - ;; but I think it's a bug in rosette - ;; otherwise it's unsound -; (⊢ (ro:set! x (ro:a conv x)) void) ... - (⊢m (ro:let ([x (ro:#%app conv x)] ...) - e ... - (rosette3:#%app rosette3:void)) - ty-out)) + ;; TODO: this is deviating from rosette's impl + ;; (to use let instead of set!) + ;; but I think it's a bug in rosette, otherwise it's unsound +; (⊢ (ro:set! x (ro:a conv x)) void) ... + (⊢m (ro:let ([x (ro:#%app conv x)] ...) + e ... (rosette3:#%app rosette3:void)) + ty-out)) #'(rosette3:define (f [x col ty] ... arr ty-out) -; (⊢ (ro:set! x (ro:a conv x)) void) ... - (⊢m (ro:let ([x (ro:#%app conv x)] ...) - (rosette3:#%app rosette3:void) - e ...) - ty-out)))]]) +; (⊢ (ro:set! x (ro:a conv x)) void) ... + (⊢m (ro:let ([x (ro:#%app conv x)] ...) + (rosette3:#%app rosette3:void) e ...) + ty-out)))]]) (define-typed-syntax kernel [(_ ty-out:type (f [ty:type x:id] ...) e ...) ≫ -------- @@ -372,14 +326,12 @@ -------- [≻ (if test {then ...} {})]]) -;(define-syntax-parameter range (syntax-rules ())) (define-typed-syntax (range e ...) ≫ [⊢ e ≫ e- ⇐ int] ... -------- [⊢ (ro:#%app ro:in-range e- ...) ⇒ int]) (define-typed-syntax for [(_ [((~literal :) ty:type var:id (~datum in) rangeExpr) ...] e ...) ≫ -; [⊢ rangeExpr ≫ rangeExpr- ⇒ _] ... [[var ≫ var- : ty.norm] ... ⊢ [e ≫ e- ⇒ ty-e] ...] -------- [⊢ (ro:for* ([var- rangeExpr] ...) @@ -389,14 +341,9 @@ ;; need to redefine #%datum because rosette3:#%datum is too precise (define-typed-syntax #%datum [(_ . b:boolean) ≫ -; #:with ty_out (if (syntax-e #'b) #'True #'False) -------- [⊢ (ro:#%datum . b) ⇒ bool]] [(_ . n:integer) ≫ - ;; #:with ty_out (let ([m (syntax-e #'n)]) - ;; (cond [(zero? m) #'Zero] - ;; [(> m 0) #'PosInt] - ;; [else #'NegInt])) -------- [⊢ (ro:#%datum . n) ⇒ int]] [(#%datum . n:number) ≫ @@ -565,9 +512,7 @@ [⊢ (name- (synth-app (bool) e1-) (synth-app (bool) e2-)) ⇒ bool]])])) (define-simple-macro (define-coercing-bool-binops o ...+) - (ro:begin - (provide o ...) - (define-coercing-bool-binop o) ...)) + (ro:begin (define-coercing-bool-binop o) ...)) (define-coercing-bool-binops || &&) @@ -582,72 +527,32 @@ -------- [⊢ (to-int (cl:== e1- e2-)) ⇒ int]]) -(define-typed-syntax + - [(_ e1 e2) ≫ - [⊢ e1 ≫ e1- ⇒ ty1] - [⊢ e2 ≫ e2- ⇒ ty2] - ;; #:when (real-type? #'ty1) - ;; #:when (real-type? #'ty2) - #:with ty-out (common-real-type #'ty1 #'ty2) - #:with convert (get-convert #'ty-out) - #:with ty-base (get-base #'ty-out) - #:with base-convert (get-convert #'ty-base) - -------- - [⊢ #,(if (scalar-type? #'ty-out) - #'(convert (cl:+ (synth-app (ty-out) e1-) - (synth-app (ty-out) e2-))) - #'(convert - (ro:let ([a (convert e1-)][b (convert e2-)]) - (ro:for/list ([v1 a][v2 b]) - (base-convert (cl:+ v1 v2)))))) ⇒ ty-out]]) - -(define-typed-syntax += - [(_ x e) ≫ - -------- - [≻ (= x (+ x e))]]) - -(define-typed-syntax % - [(_ e1 e2) ≫ - [⊢ e1 ≫ e1- ⇒ ty1] - [⊢ e2 ≫ e2- ⇒ ty2] - ;; #:when (real-type? #'ty1) - ;; #:when (real-type? #'ty2) - #:with ty-out (common-real-type #'ty1 #'ty2) - #:with convert (get-convert #'ty-out) - #:with ty-base (get-base #'ty-out) - #:with base-convert (get-convert #'ty-base) - -------- - [⊢ #,(if (scalar-type? #'ty-out) - #'(convert (cl:% (synth-app (ty-out) e1-) - (synth-app (ty-out) e2-))) - #'(convert - (ro:let ([a (convert e1-)][b (convert e2-)]) - (ro:for/list ([v1 a][v2 b]) - (base-convert (cl:% v1 v2)))))) ⇒ ty-out]]) - -(define-typed-syntax %= - [(_ x e) ≫ - -------- - [≻ (= x (% x e))]]) -#;(define-typed-syntax %= - [(_ x:id e) ≫ - [⊢ x ≫ x- ⇒ ty-x] - [⊢ e ≫ e- ⇒ ty-e] - #:fail-unless (cast-ok? #'ty-e #'ty-x stx) - (format "cannot cast ~a to ~a" - (type->str #'ty-e) (type->str #'ty-x)) - -------- - [⊢ (ro:set! x- (% x- (synth-app (ty-x) e-))) ⇒ void]]) - -#;(define-typed-syntax && - [(_ e1 e2) ≫ - [⊢ e1 ≫ e1- ⇐ bool] - [⊢ e2 ≫ e2- ⇐ bool] - -------- - [⊢ (cl:&& e1- e2-) ⇒ bool]] - ;; else try to coerce - [(_ e1 e2) ≫ - [⊢ e1 ≫ e1- ⇒ ty1] - [⊢ e2 ≫ e2- ⇒ ty2] - -------- - [⊢ (cl:&& (synth-app (bool) e1-) (synth-app (bool) e2-)) ⇒ bool]]) +(define-syntax (define-coercing-real-binop stx) + (syntax-parse stx + [(_ name) + #:with name- (mk-cl #'name) + #:with name= (format-id #'name "~a=" #'name) ; assignment form + #'(begin- + (define-typed-syntax name + [(_ e1 e2) ≫ + [⊢ e1 ≫ e1- ⇒ ty1] + [⊢ e2 ≫ e2- ⇒ ty2] + #:with ty-out (common-real-type #'ty1 #'ty2) + #:with convert (get-convert #'ty-out) + #:with ty-base (get-base #'ty-out) + #:with base-convert (get-convert #'ty-base) + -------- + [⊢ #,(if (scalar-type? #'ty-out) + #'(convert (name- (synth-app (ty-out) e1-) + (synth-app (ty-out) e2-))) + #'(convert + (ro:let ([a (convert e1-)][b (convert e2-)]) + (ro:for/list ([v1 a][v2 b]) + (base-convert (name- v1 v2)))))) ⇒ ty-out]]) + (define-typed-syntax name= + [(_ x e) ≫ + -------- + [≻ (= x (name x e))]]))])) +(define-simple-macro (define-real-binops o ...) + (ro:begin (define-coercing-real-binop o) ...)) +(define-real-binops + %)