diff --git a/turnstile/examples/rosette/rosette3.rkt b/turnstile/examples/rosette/rosette3.rkt index 1b4d77f..63b917d 100644 --- a/turnstile/examples/rosette/rosette3.rkt +++ b/turnstile/examples/rosette/rosette3.rkt @@ -56,7 +56,7 @@ (local-expand e 'expression (list #'ro:#%app #'ro:choose #'ro:synthesize #'ro:let #'ro:in-value - #'ro:assert #'ro:if))) + #'ro:assert #'ro:if #'ro:?))) ; (displayln (stx->datum e+)) e+) (define (mk-ro:-id id) (format-id id "ro:~a" id)) diff --git a/turnstile/examples/rosette/synthcl3.rkt b/turnstile/examples/rosette/synthcl3.rkt index 0e8e853..8fbb13f 100644 --- a/turnstile/examples/rosette/synthcl3.rkt +++ b/turnstile/examples/rosette/synthcl3.rkt @@ -1,5 +1,5 @@ #lang turnstile -(extends "rosette3.rkt" #:except ! #%app || && void = * + - #%datum if assert verify) ; typed rosette +(extends "rosette3.rkt" #:except ! #%app || && void = * + - / #%datum if assert verify) ; typed rosette (require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped racket/stxparam (prefix-in ro: (combine-in rosette rosette/lib/synthax)) @@ -8,7 +8,10 @@ sdsl/synthcl/model/operators sdsl/synthcl/model/errors sdsl/synthcl/model/memory sdsl/synthcl/model/runtime sdsl/synthcl/model/work sdsl/synthcl/model/pointers - sdsl/synthcl/lang/queries)) + sdsl/synthcl/lang/queries sdsl/synthcl/model/context + sdsl/synthcl/model/queue sdsl/synthcl/model/buffer + sdsl/synthcl/model/flags sdsl/synthcl/model/program + sdsl/synthcl/model/kernel)) (for-syntax (prefix-in cl: sdsl/synthcl/lang/util))) (begin-for-syntax @@ -19,16 +22,31 @@ procedure kernel grammar #%datum if range for print choose locally-scoped assert synth verify int int2 int3 int4 int16 float float2 float3 float4 float16 - bool void void* char* float* int* int16* int2* + bool void void* char* float* int* int2* int3* int4* int16* + cl_context cl_command_queue cl_program cl_kernel cl_mem : ! ?: == + * - || && % << ; int ops = += -= %= ; assignment ops - sizeof + sizeof clCreateProgramWithSource (typed-out ;[with-output-to-string : (C→ (C→ Any) char*)] + [clCreateContext : (C→ cl_context)] + [clCreateCommandQueue : (C→ cl_context cl_command_queue)] + [clCreateBuffer : (C→ cl_context int int cl_mem)] + [clEnqueueReadBuffer : (C→ cl_command_queue cl_mem int int void* void)] + + [clEnqueueWriteBuffer : (C→ cl_command_queue cl_mem int int void* void)] + [clEnqueueNDRangeKernel : (C→ cl_command_queue cl_kernel int int* int* int* void)] + [clCreateKernel : (C→ cl_program char* cl_kernel)] + [clSetKernelArg : (Ccase-> (C→ cl_kernel int cl_mem void) + (C→ cl_kernel int int void) + (C→ cl_kernel int float void))] + [get_global_id : (C→ int int)] + [CL_MEM_READ_ONLY : int] + [CL_MEM_WRITE_ONLY : int] + [/ : (Ccase-> (C→ int int int) (C→ float float float))] [malloc : (C→ int void*)] [get_work_dim : (C→ int)] - [!= : (Ccase-> (C→ CNum CNum CBool) (C→ CNum CNum CNum CBool) (C→ Num Num Bool) @@ -63,9 +81,8 @@ (define common-real-type (case-lambda [(t) (and (real-type? t) t)] - [(t1 t2) (cond [(real-type<=? t1 t2) t2] - [(real-type<=? t2 t1) t1] - [else #f])] + [(t1 t2) (or (and (real-type<=? t1 t2) t2) + (and (real-type<=? t2 t1) t1))] [ts (common-real-type (car ts) (apply common-real-type (cdr ts)))])) ;; implements common-real-type from model/reals.rkt @@ -96,6 +113,8 @@ (type->str from) (type->str to) #;(if (contract? to) (contract-name to) to)) expr subexpr))) + (define (mk-ptr id) (format-id id "~a*" id)) + (define (mk-mk id) (format-id id "mk-~a" id)) (define (add-convert stx fn) (set-stx-prop/preserved stx 'convert fn)) (define (get-convert stx) @@ -114,6 +133,8 @@ ((current-type-eval) (datum->syntax ctx (string->symbol (car (regexp-match #px"[a-z]+" (type->str ty))))))) + (define (get-pointer-base ty [ctx #'here]) + (datum->syntax ctx (string->symbol (string-trim (type->str ty) "*")))) (define (vector-type? ty) (ty->len ty)) ; TODO: check and not pointer-type? (define (scalar-type? ty) @@ -146,113 +167,74 @@ (ro:define (mk-int v) (ro:#%app cl:int v)) -(ro:define (to-int16* v) - (cl:pointer-cast v cl:int16)) -(ro:define (to-int2* v) - (cl:pointer-cast v cl:int2)) (ro:define (to-int* v) (cl:pointer-cast v cl:int)) (ro:define (to-float* v) (cl:pointer-cast v cl:float)) -(define-named-type-alias bool - (add-convertm rosette3:Bool to-bool)) -(define-named-type-alias int - (add-convertm rosette3:Int to-int)) -(define-named-type-alias float - (add-convertm rosette3:Num to-float)) -(define-named-type-alias char* - (add-convertm rosette3:CString (λ (x) x))) +(define-type-constructor Pointer #:arity = 1) +;(define-named-type-alias void rosette3:CUnit) +(define-base-types void cl_context cl_command_queue cl_program cl_kernel cl_mem) +(define-named-type-alias void* (add-convertm (Pointer void) (λ (x) x))) +(define-named-type-alias bool (add-convertm rosette3:Bool to-bool)) +(define-named-type-alias int (add-convertm rosette3:Int to-int)) +(define-named-type-alias int* (add-convertm (Pointer int) to-int*)) +(define-named-type-alias float (add-convertm rosette3:Num to-float)) +(define-named-type-alias float* (add-convertm (Pointer float) to-float*)) +(define-named-type-alias char* (add-convertm rosette3:CString (λ (x) x))) (define-syntax (define-int stx) - (syntax-parse stx - [(_ n) - #:with intn (format-id #'n "int~a" (syntax->datum #'n)) - #:with to-intn (format-id #'n "to-~a" #'intn) - #:with mk-intn (format-id #'n "mk-~a" #'intn) - #:with cl-mk-intn (mk-cl #'intn) - #:with (x ...) (generate-temporaries - (build-list (syntax->datum #'n) (lambda (x) x))) - #:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...)) - #'(begin - (define-named-type-alias intn - (add-constructm - (add-convertm - (rosette3:CVector I ...) - to-intn) - mk-intn)) - (ro:define (to-intn v) - (ro:cond - [(ro:list? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i n]) (to-int (ro:list-ref v i))))] - [(ro:vector? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i n]) (to-int (ro:vector-ref v i))))] - [else - (ro:apply ro:vector-immutable - (ro:make-list n (to-int v)))])) - (ro:define (mk-intn x ...) - (ro:#%app cl-mk-intn x ...) - #;(ro:#%app ro:vector-immutable (to-int x) ...)) - )])) + (syntax-parse stx + [(_ n) + #:with intn (format-id #'n "int~a" (syntax->datum #'n)) + #:with intn* (mk-ptr #'intn) + #:with to-intn (format-id #'n "to-~a" #'intn) + #:with mk-intn (mk-mk #'intn) + #:with to-intn* (mk-ptr #'to-intn) + #:with mk-intn* (mk-ptr #'mk-intn) + #:with cl-mk-intn (mk-cl #'intn) + #:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values)) + #:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...)) + #'(begin + (define-named-type-alias intn + (add-constructm (add-convertm (rosette3:CVector I ...) to-intn) mk-intn)) + (define-named-type-alias intn* (add-convertm (Pointer intn) to-intn*)) + (ro:define (to-intn v) + (ro:cond + [(ro:list? v) + (ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:list-ref v i))))] + [(ro:vector? v) + (ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:vector-ref v i))))] + [else (ro:apply mk-intn (ro:make-list n (ro:#%app to-int v)))])) + (ro:define (to-intn* v) (cl:pointer-cast v cl-mk-intn)) + (ro:define (mk-intn x ...) (ro:#%app cl-mk-intn x ...)))])) (define-simple-macro (define-ints n ...) (begin (define-int n) ...)) (define-ints 2 3 4 16) (define-syntax (define-float stx) - (syntax-parse stx - [(_ n) - #:with floatn (format-id #'n "float~a" (syntax->datum #'n)) - #:with to-floatn (format-id #'n "to-~a" #'floatn) - #:with mk-floatn (format-id #'n "mk-~a" #'floatn) - #:with cl-mk-floatn (mk-cl #'floatn) - #:with (x ...) (generate-temporaries - (build-list (syntax->datum #'n) (lambda (x) x))) - #:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...)) - #'(begin - (define-named-type-alias floatn - (add-constructm - (add-convertm - (rosette3:CVector I ...) - to-floatn) - mk-floatn)) - (ro:define (to-floatn v) - (ro:cond - [(ro:list? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i n]) (to-float (ro:list-ref v i))))] - [(ro:vector? v) - (ro:apply ro:vector-immutable - (ro:for/list ([i n]) (to-float (ro:vector-ref v i))))] - [else - (ro:apply ro:vector-immutable - (ro:make-list n (to-float v)))])) - (ro:define (mk-floatn x ...) - (ro:#%app cl-mk-floatn x ...) - #;(ro:#%app ro:vector-immutable (to-float x) ...)) - )])) + (syntax-parse stx + [(_ n) + #:with floatn (format-id #'n "float~a" (syntax->datum #'n)) + #:with to-floatn (format-id #'n "to-~a" #'floatn) + #:with mk-floatn (mk-mk #'floatn) + #:with cl-mk-floatn (mk-cl #'floatn) + #:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values)) + #:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...)) + #'(begin + (define-named-type-alias floatn + (add-constructm + (add-convertm (rosette3:CVector I ...) to-floatn) mk-floatn)) + (ro:define (to-floatn v) + (ro:cond + [(ro:list? v) + (ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:list-ref v i))))] + [(ro:vector? v) + (ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:vector-ref v i))))] + [else (ro:apply mk-floatn (ro:make-list n (ro:#%app to-float v)))])) + (ro:define (mk-floatn x ...) (ro:#%app cl-mk-floatn x ...)))])) (define-simple-macro (define-floats n ...) (begin (define-float n) ...)) (define-floats 2 3 4 16) - -(define-type-constructor Pointer #:arity = 1) -;(define-named-type-alias void rosette3:CUnit) -(define-base-type void) -#;(begin-for-syntax - (define-syntax ~void* - (pattern-expander - (make-variable-like-transformer #'(~and t:type (~parse ~void #'t.norm)))))) -(define-named-type-alias void* - (add-convertm (Pointer void) (λ (x) x))) -(define-named-type-alias int* - (add-convertm (Pointer int) to-int*)) -(define-named-type-alias int16* - (add-convertm (Pointer int16) to-int16*)) -(define-named-type-alias int2* - (add-convertm (Pointer int2) to-int2*)) -(define-named-type-alias float* - (add-convertm (Pointer float) to-float*)) - (define-typed-syntax synth-app [(_ (ty:type) e) ≫ ; cast [⊢ e ≫ e- ⇒ ty-e] @@ -282,13 +264,8 @@ [⊢ ptr ≫ ptr- ⇒ ty-ptr] #:when (pointer-type? #'ty-ptr) #:with ~! #'dummy ; commit [⊢ sel ≫ sel- ⇐ int] - #:do [(define split-ty (ty->len #'ty-ptr))] - #:when (and split-ty (= 3 (length split-ty))) - #:do [(define base-str (cadr split-ty)) - (define len-str (caddr split-ty))] - #:with ty-out ((current-type-eval) (format-id #'h "~a~a" base-str len-str)) -------- - [⊢ (cl:pointer-ref ptr- sel-) ⇒ ty-out]] + [⊢ (cl:pointer-ref ptr- sel-) ⇒ #,(get-pointer-base #'ty-ptr)]] [(_ vec sel) ≫ ; applying vector to one arg is selector [⊢ vec ≫ vec- ⇒ ty-vec] #:when (vector-type? #'ty-vec) @@ -342,13 +319,13 @@ (define- f- (lambda- (x ...) (rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...) - (⊢m (let- () e ... (rosette3:ann e-body : ty-out)) ty-out)))))]]) + (⊢m (ro:let () e ... (rosette3:ann e-body : ty-out)) ty-out)))) + (provide- f))]]) (define-typed-syntax kernel [(_ ty-out:type (f [ty:type x:id] ...) e ...) ≫ #:fail-unless (void? #'ty-out.norm) (format "expected void, given ~a" (type->str #'ty-out.norm)) - -------- - [≻ (procedure void (f [ty x] ...) e ...)]]) + --- [≻ (procedure void (f [ty x] ...) e ...)]]) (define-typed-syntax grammar [(_ ty-out:type (f [ty:type x:id] ...) e) ≫ #:with f- (generate-temporary #'f) @@ -369,23 +346,21 @@ (define-typed-syntax if [(_ test {then ...} {else ...}) ≫ -------- - [⊢ (ro:if (to-bool test) + [⊢ (ro:if (ro:#%app to-bool test) (ro:let () then ... (ro:void)) (ro:let () else ... (ro:void))) ⇒ void]] [(_ test {then ...}) ≫ - -------- - [≻ (if test {then ...} {})]]) + --- [≻ (if test {then ...} {})]]) (define-typed-syntax (range e ...) ≫ [⊢ e ≫ e- ⇐ int] ... - -------- - [⊢ (ro:#%app ro:in-range e- ...) ⇒ int]) + --- [⊢ (ro:#%app ro:in-range e- ...) ⇒ int]) (define-typed-syntax for - [(_ [((~literal :) ty:type var:id (~datum in) rangeExpr) ...] e ...) ≫ - [[var ≫ var- : ty.norm] ... ⊢ [e ≫ e- ⇒ ty-e] ...] + [(_ [((~literal :) ty:type x:id (~datum in) rangeExpr) ...] e ...) ≫ -------- - [⊢ (ro:for* ([var- rangeExpr] ...) - e- ... (ro:void)) ⇒ void]]) + [⊢ (ro:for* ([x rangeExpr] ...) + (rosette3:let ([x (⊢m x ty)] ...) + (⊢m (ro:let () e ... (ro:void)) void))) ⇒ void]]) ;; need to redefine #%datum because rosette3:#%datum is too precise @@ -433,10 +408,11 @@ (format "no pred for ~a" (type->str #'ty)) #:with (x- ...) (generate-temporaries #'(x ...)) #:with (x-- ...) (generate-temporaries #'(x ...)) + #:with mk-ty (format-id #'here "mk-~a" #'ty) -------- [≻ (begin- (ro:define-symbolic* x-- pred [#,(string->number len-str)]) ... - (ro:define x- (ro:apply ro:vector-immutable x--)) ... + (ro:define x- (ro:apply mk-ty x--)) ... (define-syntax- x (make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]] [(_ ty:type [len] x:id ...) ≫ ; array of vector types @@ -457,7 +433,7 @@ [≻ (begin- (ro:define-symbolic* x-- pred [len base-len]) ... (ro:define x- - (ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))]) + (ro:let ([*x (ro:#%app to-ty* (cl:malloc (ro:* len base-len)))]) (ro:for ([i len][v x--]) (cl:pointer-set! *x i (ro:apply mk-ty v))) *x)) ... @@ -506,12 +482,12 @@ #:with ty-base ((current-type-eval) (datum->syntax #'e (string->symbol out-base-str))) #:with base-convert (get-convert #'ty-base) ------- - [⊢ (convert - (ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)]) + [⊢ (ro:#%app convert + (ro:let ([a (ro:#%app convert e-)][b (ro:#%app convert e1-)][c (ro:#%app convert e2-)]) (ro:for/list ([idx #,(string->number out-len-str)]) (ro:if (ro:< (ro:vector-ref a idx) 0) - (base-convert (ro:vector-ref b idx)) - (base-convert (ro:vector-ref c idx)))))) + (ro:#%app base-convert (ro:vector-ref b idx)) + (ro:#%app base-convert (ro:vector-ref c idx)))))) ⇒ ty-out]] [(_ ~! e e1 e2) ≫ ; should be scalar and real [⊢ e ≫ e- ⇒ ty] @@ -535,10 +511,10 @@ [⊢ x ≫ x- ⇒ ty-x] [⊢ e ≫ e- ⇒ ty-e] #:fail-unless (cast-ok? #'ty-e #'ty-x stx) - (format "cannot cast ~a to ~a" - (type->str #'ty-e) (type->str #'ty-x)) + (format "cannot cast ~a to ~a" (type->str #'ty-e) (type->str #'ty-x)) + #:with conv (get-convert #'ty-x) -------- - [⊢ (ro:set! x- (synth-app (ty-x) e-)) ⇒ void]] + [⊢ (ro:set! x- #,(if (syntax-e #'conv) #'(ro:#%app conv e-) #'e-)) ⇒ void]] ;; selector can be list of numbers or up to wxyz for vectors of length <=4 [(_ [x:id sel] e) ≫ [⊢ x ≫ x- ⇒ ty-x] @@ -561,11 +537,11 @@ [(_ e) ≫ [⊢ e ≫ e- ⇐ bool] -------- - [⊢ (cl:! e-) ⇒ bool]] + [⊢ (ro:#%app cl:! e-) ⇒ bool]] [(_ e) ≫ ; else try to coerce [⊢ e ≫ e- ⇒ ty] -------- - [⊢ (cl:! (synth-app (bool) e-)) ⇒ bool]]) + [⊢ (ro:#%app cl:! (ro:#%app to-bool e-)) ⇒ bool]]) ;; TODO: this should produce int-vector result? (define-typed-syntax == @@ -576,7 +552,7 @@ #:when (real-type? #'ty2) #:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len -------- - [⊢ (to-int (cl:== e1- e2-)) ⇒ int]]) + [⊢ (ro:#%app to-int (cl:== e1- e2-)) ⇒ int]]) (define-simple-macro (define-bool-ops o ...+) (ro:begin (define-bool-op o) ...)) (define-simple-macro (define-bool-op name) @@ -588,8 +564,7 @@ -------- [⊢ (name- e1- e2-) ⇒ bool]] [(_ e1 e2) ≫ ; else try to coerce - -------- - [⊢ (name- (synth-app (bool) e1) (synth-app (bool) e2)) ⇒ bool]])) + --- [⊢ (name- (ro:#%app to-bool e1) (ro:#%app to-bool e2)) ⇒ bool]])) (define-simple-macro (define-real-ops o ...) (ro:begin (define-real-op o) ...)) (define-simple-macro (define-real-op name (~optional (~seq #:extra-check p?) @@ -597,32 +572,31 @@ #:with name- (mk-cl #'name) #:with name= (format-id #'name "~a=" #'name) ; assignment form (begin- - (define-typed-syntax (name e1 e2) ≫ - [⊢ e1 ≫ e1- ⇒ ty1] - [⊢ e2 ≫ e2- ⇒ ty2] - #:with ty-out (common-real-type #'ty1 #'ty2) + (define-typed-syntax (name e (... ...)) ≫ + [⊢ e ≫ e- ⇒ ty] (... ...) + #:with ty-out (apply common-real-type (stx->list #'(ty (... ...)))) #:fail-unless (syntax-e #'ty-out) - (format "no common real type for operands; given ~a, ~a" - (type->str #'ty1) (type->str #'ty2)) - #:when (p? #'ty-out #'ty1 #'ty2) + (format "no common real type for operands; given ~a" + (types->str #'(ty (... ...)))) + #:when (p? #'ty-out #'(ty (... ...))) #:with convert (get-convert #'ty-out) #:with ty-base (get-base #'ty-out) #:with base-convert (get-convert #'ty-base) + #:with (x (... ...)) (generate-temporaries #'(e (... ...))) -------- [⊢ #,(if (scalar-type? #'ty-out) - #'(convert (name- (convert e1-) (convert e2-))) - #'(convert (ro:let ([a (convert e1-)][b (convert e2-)]) - (ro:for/list ([v1 a][v2 b]) - (base-convert (name- v1 v2)))))) ⇒ ty-out]) + #'(ro:#%app convert (name- (convert e-) (... ...))) + #'(ro:#%app convert (ro:let ([x (ro:#%app convert e-)] (... ...)) + (ro:for/list ([x x] (... ...)) + (ro:#%app base-convert (name- x (... ...))))))) ⇒ ty-out]) (define-typed-syntax (name= x e) ≫ - -------- - [≻ (= x (name x e))]))) + --- [≻ (= x (name x e))]))) -(define-for-syntax (int? t given1 given2) +(define-for-syntax (int? t givens) (or (typecheck/un? t #'int) (raise-syntax-error #f - (format "no common integer type for operands; given ~a, ~a" - (type->str given1) (type->str given2))))) + (format "no common integer type for operands; given ~a" + (types->str givens))))) (define-simple-macro (define-int-op o) (define-real-op o #:extra-check int?)) (define-simple-macro (define-int-ops o ...) (ro:begin (define-int-op o) ...)) @@ -631,12 +605,10 @@ (define-int-ops % <<) (define-typerule (sizeof t:type) >> - ---------- - [⊢ #,(real-type-length #'t.norm) ⇒ int]) + --- [⊢ #,(real-type-length #'t.norm) ⇒ int]) (define-typerule (print e ...) >> - ---------- - [⊢ (ro:begin (display e) ...) ⇒ void]) + --- [⊢ (ro:begin (display e) ...) ⇒ void]) (define-typed-syntax choose [(ch e ...+) ≫ @@ -656,26 +628,35 @@ (define-for-syntax (decl->seq stx) (syntax-parse stx - [((~datum :) type id (~datum in) rangeExpr) - (syntax/loc stx (id rangeExpr type))] - [((~datum :) type id) - (syntax/loc stx (id (ro:in-value (ro:let () (: type id) id)) type))])) + [((~datum :) ty:type id (~datum in) rangeExpr) + (syntax/loc stx (id rangeExpr ty.norm))] + [((~datum :) ty:type [len] id) + #:with tyout (mk-ptr #'ty) + (syntax/loc stx (id (ro:in-value (ro:let () (: ty [len] id) id)) tyout))] + [((~datum :) ty id) + (syntax/loc stx (id (ro:in-value (ro:let () (: ty id) id)) ty))])) -(define-typed-syntax (synth #:forall [decl ...] #:ensure e) ≫ +(define-typed-syntax synth + [(_ #:forall [decl ...] #:bitwidth bw #:ensure e) ≫ #:with ([id seq ty] ...) (stx-map decl->seq #'(decl ...)) #:with (id- ...) (generate-temporaries #'(id ...)) #:with (typed-seq ...) #'((with-ctx ([id id- ty] ...) seq) ...) #:with (tmp ...) (generate-temporaries #'(id ...)) -------- - [⊢ (ro:let ([id- 1] ...) ; dummy, enables simplifying stx template + [⊢ (ro:let ([id- 1] ...) ; dummy ensuring id- bound, simplifies stx template (ro:define-values (tmp ...) (ro:for*/lists (tmp ...) ([id- typed-seq] ...) (ro:values id- ...))) - (ro:parameterize ([ro:term-cache (ro:hash-copy (ro:term-cache))]) + (ro:parameterize ([ro:current-bitwidth bw] + [ro:term-cache (ro:hash-copy (ro:term-cache))]) (ro:print-forms (ro:synthesize #:forall (ro:append tmp ...) #:guarantee (ro:for ([id- tmp] ...) - (with-ctx ([id id- ty] ...) e)))))) ⇒ void]) + (with-ctx ([id id- ty] ...) e)))))) ⇒ void]] + [(_ #:forall [decl ...] #:ensure e) ≫ + --- [≻ (synth #:forall [decl ...] #:bitwidth 8 #:ensure e)]]) + + (define-typed-syntax verify [(vfy #:forall [decl ...] #:ensure e) ≫ @@ -684,7 +665,8 @@ #:with (typed-seq ...) #'((with-ctx ([id id- ty] ...) seq) ...) -------- [⊢ (ro:let ([id- 1] ...) ; dummy, enables simplifying stx template - (ro:parameterize ([ro:term-cache (ro:hash-copy (ro:term-cache))]) + (ro:parameterize ([ro:current-bitwidth 32] + [ro:term-cache (ro:hash-copy (ro:term-cache))]) (ro:for*/or ([id- typed-seq] ...) (ro:define cex (with-ctx ([id id- ty] ...) (ro:verify e))) (ro:and (ro:sat? cex) @@ -693,6 +675,12 @@ (printf "~a = ~a\n" i (ro:evaluate i- cex))))))) ⇒ void]]) (define-typed-syntax (assert e) ≫ - #:with e- (expand/ro #'e) - -------- - [⊢ (ro:assert (to-bool e-)) ⇒ void]) + --- [⊢ (ro:assert (ro:#%app to-bool #,(expand/ro #'e))) ⇒ void]) + +(define- (/ x y) + (cond- [(zero?- y) 0] + [(integer?- x) (quotient- x y)] + [else (/- x y)])) + +(define-typed-syntax (clCreateProgramWithSource ctx f) ≫ + --- [⊢ (cl:clCreateProgramWithSource ctx f) ⇒ cl_program]) diff --git a/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt b/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt new file mode 100644 index 0000000..8f5bb6b --- /dev/null +++ b/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt @@ -0,0 +1,110 @@ +#lang s-exp "../../../rosette/synthcl3.rkt" + +; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix. +(kernel void (mmulScalarKernel [int* A] [int* B] [int* C] [int p] [int m]) + (: int i j sum) + (= i (get_global_id 0)) + (= j (get_global_id 1)) + (= sum 0) + (for [(: int k in (range p))] + (+= sum (* [A (+ (* i p) k)] [B (+ (* k m) j)]))) + (= [C (+ (* i m) j)] sum)) + +; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix. +(kernel void (mmulVectorKernel [int4* A] [int4* B] [int4* C] [int p] [int m]) + (: int i j) + (: int4 sum0 sum1 sum2 sum3) + + (= i (get_global_id 0)) + (= j (get_global_id 1)) + (= sum0 0) + (= sum1 0) + (= sum2 0) + (= sum3 0) + + (for [(: int k in (range 0 p 4))] + (: int4 a0 a1 a2 a3) + (: int4 b0 b1 b2 b3) + + (= a0 [A (indexA 0 i k p)]) + (= a1 [A (indexA 1 i k p)]) + (= a2 [A (indexA 2 i k p)]) + (= a3 [A (indexA 3 i k p)]) + + (= b0 [B (indexB 0 k j m)]) + (= b1 [B (indexB 1 k j m)]) + (= b2 [B (indexB 2 k j m)]) + (= b3 [B (indexB 3 k j m)]) + + (+= sum0 (computeSum a0 b0 b1 b2 b3)) + (+= sum1 (computeSum a1 b0 b1 b2 b3)) + (+= sum2 (computeSum a2 b0 b1 b2 b3)) + (+= sum3 (computeSum a3 b0 b1 b2 b3))) + (= [C (indexC 0 i j m)] sum0) + (= [C (indexC 1 i j m)] sum1) + (= [C (indexC 2 i j m)] sum2) + (= [C (indexC 3 i j m)] sum3)) + +; Multiplies the 1x4 vector a by the 4x4 matrix with rows b0, b1, b2 and b3. +(procedure int4 (computeSum [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3]) + (int4 + (+ (* [a x] [b0 x]) (* [a y] [b1 x]) (* [a z] [b2 x]) (* [a w] [b3 x])) + (+ (* [a x] [b0 y]) (* [a y] [b1 y]) (* [a z] [b2 y]) (* [a w] [b3 y])) + (+ (* [a x] [b0 z]) (* [a y] [b1 z]) (* [a z] [b2 z]) (* [a w] [b3 z])) + (+ (* [a x] [b0 w]) (* [a y] [b1 w]) (* [a z] [b2 w]) (* [a w] [b3 w])))) + +; Function for indexing into the matrix A. +(procedure int (indexA [int off] [int i] [int k] [int p]) + ; (+ (* (+ (* i 4) off) (/ p 4)) (/ k 4))) + (: int r c w) + (= r (+ (choose i (/ i 4) (* i 4)) (choose off 0))) + (= c (+ (choose k (/ k 4) (* k 4)) (choose off 0))) + (= w (+ (choose p (/ p 4) (* p 4)) (choose off 0))) + (+ (* r w) c)) + +; Function for indexing into the matrix B. +(procedure int (indexB [int off] [int k] [int j] [int m]) + ;(+ (* (+ k off) (/ m 4)) j)) + (: int r c w) + (= r (+ (choose k (/ k 4) (* k 4)) (choose off 0))) + (= c (+ (choose j (/ j 4) (* j 4)) (choose off 0))) + (= w (+ (choose m (/ m 4) (* m 4)) (choose off 0))) + (+ (* r w) c)) + +; Function for indexing into the matrix C. +(procedure int (indexC [int off] [int i] [int j] [int m]) + ;(+ (* (+ (* i 4) off) (/ m 4)) j)) + (: int r c w) + (= r (+ (choose i (/ i 4) (* i 4)) (choose off 0))) + (= c (+ (choose j (/ j 4) (* j 4)) (choose off 0))) + (= w (+ (choose m (/ m 4) (* m 4)) (choose off 0))) + (+ (* r w) c)) + + +; Bad sketch and completions: +; Function for indexing into the matrix A. +;(procedure int (indexA [int off] [int i] [int k] [int p]) + ; (+ (* (+ (/ i 4) off) (/ p 4)) (/ k 4))) + ;(: int r c w) + ;(= r (+ (choose i (/ i 4)) (choose off 0))) + ;(= c (+ (choose k (/ k 4)) (choose off 0))) + ;(= w (+ (choose p (/ p 4)) (choose off 0))) + ;(+ (* r w) c)) + +; Function for indexing into the matrix B. +;(procedure int (indexB [int off] [int k] [int j] [int m]) + ;(+ (* (+ k off) (/ m 4)) j)) + ;(: int r c w) + ;(= r (+ (choose k (/ k 4)) (choose off 0))) + ;(= c (+ (choose j (/ j 4)) (choose off 0))) + ;(= w (+ (choose m (/ m 4)) (choose off 0))) + ;(+ (* r w) c)) + +; Function for indexing into the matrix C. +;(procedure int (indexC [int off] [int i] [int j] [int m]) + ;(+ (* (+ (/ i 4) off) (/ m 4)) j)) + ;(: int r c w) + ;(= r (+ (choose i (/ i 4)) (choose off 0))) + ;(= c (+ (choose j (/ j 4)) (choose off 0))) + ;(= w (+ (choose m (/ m 4)) (choose off 0))) + ;(+ (* r w) c)) diff --git a/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-tests.rkt b/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-tests.rkt index 9d5761c..795aefe 100644 --- a/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-tests.rkt +++ b/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-tests.rkt @@ -3,98 +3,104 @@ (prefix-in cl: sdsl/synthcl/lang/main) (prefix-in ro: (rename-in rosette [#%app a]))) -(: int2 [3] xs) -xs -(check-type xs : int2*) - -(: int [4] xs2) -xs2 -(check-type xs2 : int*) - - - -;; ; The reference implementation for square matrix multiplication. -;; ; Multiplies two squre matrices A and B, where the dimension of A is -;; ; n x p and dimension of B is p x m. Both matrices are given as -;; ; flat arrays in row-major form. The output is the matrix C = A*B, -;; ; also given in row-major form. -;; (procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m]) -;; (: int* C) -;; (= C ((int*) (malloc (* n m (sizeof int))))) -;; (for [(: int i in (range n)) -;; (: int j in (range m)) -;; (: int k in (range p))] -;; (+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)]))) -;; C) +; The reference implementation for square matrix multiplication. +; Multiplies two squre matrices A and B, where the dimension of A is +; n x p and dimension of B is p x m. Both matrices are given as +; flat arrays in row-major form. The output is the matrix C = A*B, +; also given in row-major form. +(procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m]) + (: int* C) + (= C ((int*) (malloc (* n m (sizeof int))))) + (for [(: int i in (range n)) + (: int j in (range m)) + (: int k in (range p))] + (+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)]))) + C) ; A host implementation of matrix multiplication. (procedure int* (mmulHost [char* kernelName] [int typeLen] [int* A] [int* B] [int n] [int p] [int m]) - ;; (: cl_context context) - ;; (: cl_command_queue command_queue) - ;; (: cl_program program) - ;; (: cl_kernel kernel) - ;; (: cl_mem buffer_A buffer_B buffer_C) - (: int* C) - ;; (: int[2] global) - ;; (: int dimA dimB dimC) - - ;; (= [global 0] (/ n typeLen)) - ;; (= [global 1] (/ m typeLen)) - ;; (= dimA (* n p (sizeof int))) - ;; (= dimB (* p m (sizeof int))) - ;; (= dimC (* n m (sizeof int))) - - ;; (= C ((int*) (malloc dimC))) - - ;; (= context (clCreateContext)) - - ;; (= command_queue (clCreateCommandQueue context)) - - ;; (= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA)) - ;; (= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB)) - ;; (= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC)) - - ;; (= program (clCreateProgramWithSource context "kernel.rkt")) - - ;; (clEnqueueWriteBuffer command_queue buffer_A 0 dimA A) - ;; (clEnqueueWriteBuffer command_queue buffer_B 0 dimB B) - - ;; (= kernel (clCreateKernel program kernelName)) - ;; (clSetKernelArg kernel 0 buffer_A) - ;; (clSetKernelArg kernel 1 buffer_B) - ;; (clSetKernelArg kernel 2 buffer_C) - ;; (clSetKernelArg kernel 3 p) - ;; (clSetKernelArg kernel 4 m) + (: cl_context context) + (: cl_command_queue command_queue) + (: cl_program program) + (: cl_kernel kernel) + (: cl_mem buffer_A buffer_B buffer_C) - ;; (clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL) - ;; (clEnqueueReadBuffer command_queue buffer_C 0 dimC C) + (: int* C) + (: int[2] global) + (: int dimA dimB dimC) + (= [global 0] (/ n typeLen)) + (= [global 1] (/ m typeLen)) + (= dimA (* n p (sizeof int))) + (= dimB (* p m (sizeof int))) + (= dimC (* n m (sizeof int))) + + (= C ((int*) (malloc dimC))) + + (= context (clCreateContext)) + + (= command_queue (clCreateCommandQueue context)) + + (= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA)) + (= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB)) + (= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC)) + + (= program (clCreateProgramWithSource context "matrix-synth-kernel.rkt")) + + (clEnqueueWriteBuffer command_queue buffer_A 0 dimA A) + (clEnqueueWriteBuffer command_queue buffer_B 0 dimB B) + + (= kernel (clCreateKernel program kernelName)) + (clSetKernelArg kernel 0 buffer_A) + (clSetKernelArg kernel 1 buffer_B) + (clSetKernelArg kernel 2 buffer_C) + (clSetKernelArg kernel 3 p) + (clSetKernelArg kernel 4 m) + + (clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL) + (clEnqueueReadBuffer command_queue buffer_C 0 dimC C) C) -;; ; A scalar parallel implementation of matrix multiplication. -;; (procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m]) -;; (mmulHost "mmulScalarKernel" 1 A B n p m)) +; A scalar parallel implementation of matrix multiplication. +(procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m]) + (mmulHost "mmulScalarKernel" 1 A B n p m)) ; A vector parallel implementation of matrix multiplication. The dimensions ; n and m must be evenly divisible by 4. -#;(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m]) +(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m]) (mmulHost "mmulVectorKernel" 4 A B n p m)) -;; ; Given two arrays of the same size, checks that they hold the same -;; ; values at each index. -;; (procedure void (check [int* actual] [int* expected] [int SIZE]) -;; (assert (>= SIZE 0)) -;; (for [(: int i in (range SIZE))] -;; (assert (== [actual i] [expected i])))) +; Given two arrays of the same size, checks that they hold the same +; values at each index. +(procedure void (check [int* actual] [int* expected] [int SIZE]) + (assert (>= SIZE 0)) + (for [(: int i in (range SIZE))] + (assert (== [actual i] [expected i])))) -;; (procedure void (synth_vector [int size]) -;; (synth #:forall [(: int n in (range size (+ 1 size))) -;; (: int p in (range size (+ 1 size))) -;; (: int m in (range size (+ 1 size))) -;; (: int[(* n p)] A) -;; (: int[(* p m)] B)] -;; #:ensure (check (mmulVector A B n p m) -;; (mmulSequential A B n p m) -;; (* n m)))) +(procedure void (synth_vector [int size]) + (synth #:forall [(: int n in (range size (+ 1 size))) + (: int p in (range size (+ 1 size))) + (: int m in (range size (+ 1 size))) + (: int[(* n p)] A) + (: int[(* p m)] B)] + #:ensure (check (mmulVector A B n p m) + (mmulSequential A B n p m) + (* n m)))) -;; ; (synth_vector 4) ; 20 sec -;; ; (synth_vector 8) ; 252 sec +(: int n) +(= n 4) +(: int p) +(= p 4) +(: int m) +(= m 4) +(: int[(* n p)] A) +(: int[(* p m)] B) +(check-type (mmulVector A B n p m) : int*) +(check-type (mmulSequential A B n p m) : int*) + +(check-type + (with-output-to-string + (λ () + (synth_vector 4))) ; 20 sec + : CString + -> "/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt:57:0\n'(procedure\n int\n (indexA (int off) (int i) (int k) (int p))\n (: int r c w)\n (= r (+ (choose i (/ i 4) (* i 4)) off))\n (= c (+ (choose k (/ k 4) (* k 4)) 0))\n (= w (+ (/ p 4) 0))\n (+ (* r w) c))\n/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt:66:0\n'(procedure\n int\n (indexB (int off) (int k) (int j) (int m))\n (: int r c w)\n (= r (+ (choose k (/ k 4) (* k 4)) off))\n (= c (+ (choose j (/ j 4) (* j 4)) 0))\n (= w (+ (/ m 4) 0))\n (+ (* r w) c))\n/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt:75:0\n'(procedure\n int\n (indexC (int off) (int i) (int j) (int m))\n (: int r c w)\n (= r (+ (choose i (/ i 4) (* i 4)) off))\n (= c (+ (choose j (/ j 4) (* j 4)) 0))\n (= w (+ (/ m 4) 0))\n (+ (* r w) c))\n") +;(synth_vector 8) ; 252 sec diff --git a/turnstile/examples/tests/rosette/rosette3/synthcl3-tests.rkt b/turnstile/examples/tests/rosette/rosette3/synthcl3-tests.rkt index a8037d2..4d8f261 100644 --- a/turnstile/examples/tests/rosette/rosette3/synthcl3-tests.rkt +++ b/turnstile/examples/tests/rosette/rosette3/synthcl3-tests.rkt @@ -345,3 +345,36 @@ {(assert k)})))) : CString -> "counterexample found:\nt = 2\nk = 0\np = 4\n") + +(: int2 [3] xs) +(check-type xs : int2*) + +(: int [4] xs2) +(check-type xs2 : int*) + +; basic matrix multiplying +(: int4 sum0 sum1 sum2 sum3) +(= sum0 0) +(= sum1 0) +(= sum2 0) +(= sum3 0) + +(procedure int (computeSum1 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3]) + (+ (* [a x] [b0 x]) (* [a y] [b1 x]) (* [a z] [b2 x]) (* [a w] [b3 x]))) +(procedure int (computeSum2 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3]) + (+ (* [a x] [b0 y]) (* [a y] [b1 y]) (* [a z] [b2 y]) (* [a w] [b3 y]))) +(procedure int (computeSum3 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3]) + (+ (* [a x] [b0 z]) (* [a y] [b1 z]) (* [a z] [b2 z]) (* [a w] [b3 z]))) +(procedure int (computeSum4 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3]) + (+ (* [a x] [b0 w]) (* [a y] [b1 w]) (* [a z] [b2 w]) (* [a w] [b3 w]))) + +(check-type (computeSum1 sum0 sum0 sum1 sum2 sum3) : int -> 0) +(check-type (computeSum2 sum0 sum0 sum1 sum2 sum3) : int -> 0) +(check-type (computeSum3 sum0 sum0 sum1 sum2 sum3) : int -> 0) +(check-type (computeSum4 sum0 sum0 sum1 sum2 sum3) : int -> 0) +(check-type (int4 (computeSum1 sum0 sum0 sum1 sum2 sum3) + (computeSum2 sum0 sum0 sum1 sum2 sum3) + (computeSum3 sum0 sum0 sum1 sum2 sum3) + (computeSum4 sum0 sum0 sum1 sum2 sum3)) + : int4 + -> (ro:a ro:vector-immutable 0 0 0 0))