From 21c77d7e61d0da2dfa264f618ad7d8a2e666dde4 Mon Sep 17 00:00:00 2001 From: Stephen Chang Date: Wed, 7 Dec 2016 16:08:13 -0500 Subject: [PATCH] add base/else grammar form; walsh synth tests passing - add ?? and @ - add comparison ops - use cl:/ instead of my own - fixed walsh synth scalar test - fix unbound ids err in for due to let*-like bindings --- turnstile/examples/rosette/rosette-notes.txt | 31 +++ turnstile/examples/rosette/synthcl3.rkt | 196 ++++++++++-------- .../rosette3/run-all-rosette-tests-script.rkt | 3 +- .../rosette3/run-all-synthcl-tests.rkt | 12 ++ .../rosette3/synthcl3-walsh-synth-tests.rkt | 146 +++++++++++++ .../rosette/rosette3/walsh-synth-kernel.rkt | 66 ++++++ 6 files changed, 367 insertions(+), 87 deletions(-) create mode 100644 turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt create mode 100644 turnstile/examples/tests/rosette/rosette3/synthcl3-walsh-synth-tests.rkt create mode 100644 turnstile/examples/tests/rosette/rosette3/walsh-synth-kernel.rkt diff --git a/turnstile/examples/rosette/rosette-notes.txt b/turnstile/examples/rosette/rosette-notes.txt index 88b7c8e..ad47dc6 100644 --- a/turnstile/examples/rosette/rosette-notes.txt +++ b/turnstile/examples/rosette/rosette-notes.txt @@ -1,3 +1,34 @@ +2016-12-07 -------------------- +synthcl Walsh synth tests were not working and had trouble debugging, so +documenting my process here + +1) getting the error: +; ?: literal data is not allowed; +; no #%datum syntax transformer is bound +; in: #f + +in general, means a stx prop is expected but not present + +In this case: +- convert fn undefined for cl_mem and other similar base types + + +2) getting unexpected unsat +helpful debugging technique 1: +- print exns swallowed by eval/assert (used by synthesize or verify) + +In this case: +- mk-float* was undefined +- bool was not considered "real" but should be +- cmps not defined + +helpful debugging technique 2: +- print asserts in ∃∀-solve, and compare to expected +- may have to set error-print-width higher (default 256) + +In this case, comparing asserts showed that quotient was missing from my typed +synthcl. Using the / defined by synthcl (instead of mine) fixed the problem. + 2016-11-18 -------------------- working on synthcl3 lang diff --git a/turnstile/examples/rosette/synthcl3.rkt b/turnstile/examples/rosette/synthcl3.rkt index 738c6a8..42fbabe 100644 --- a/turnstile/examples/rosette/synthcl3.rkt +++ b/turnstile/examples/rosette/synthcl3.rkt @@ -1,9 +1,9 @@ #lang turnstile -(extends "rosette3.rkt" #:except ! #%app || && void = * + - / #%datum if assert verify) ; typed rosette +(extends "rosette3.rkt" #:except ! #%app || && void = * + - / #%datum if assert verify < <= > >=) ; typed rosette (require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped racket/stxparam (prefix-in ro: (combine-in rosette rosette/lib/synthax)) - (prefix-in cl: (combine-in (except-in sdsl/synthcl/model/operators /) + (prefix-in cl: (combine-in sdsl/synthcl/model/operators sdsl/synthcl/lang/forms sdsl/synthcl/model/reals sdsl/synthcl/model/errors sdsl/synthcl/model/kernel sdsl/synthcl/model/memory sdsl/synthcl/model/runtime @@ -19,13 +19,14 @@ (provide (rename-out [synth-app #%app]) procedure kernel grammar #%datum if range for print - choose locally-scoped assert synth verify + choose ?? @ locally-scoped assert synth verify int int2 int3 int4 int16 float float2 float3 float4 float16 - bool void void* char* float* int* int2* int3* int4* int16* + bool void void* char* + int* int2* int3* int4* int16* float* float2* float3* float4* float16* cl_context cl_command_queue cl_program cl_kernel cl_mem : ! ?: == + * / - || && - % << ; int ops - = += -= *= /= %= ; assignment ops + % << $ & > >= < <= ; int ops + = += -= *= /= %= $= &= ; assignment ops sizeof clCreateProgramWithSource (typed-out [clCreateContext : (C→ cl_context)] @@ -42,6 +43,7 @@ [get_global_id : (C→ int int)] [CL_MEM_READ_ONLY : int] [CL_MEM_WRITE_ONLY : int] + [CL_MEM_READ_WRITE : int] [malloc : (C→ int void*)] [get_work_dim : (C→ int)] [!= : (Ccase-> (C→ CNum CNum CBool) @@ -57,7 +59,7 @@ (typecheck? ((current-type-eval) t1) ((current-type-eval) t2))) (define (real-type? t) - (and (not (typecheck/un? t #'bool)) + (and #;(not (typecheck/un? t #'bool)) (not (typecheck/un? t #'char*)) (not (pointer-type? t)))) (define (pointer-type? t) @@ -113,16 +115,12 @@ (define (mk-ptr id) (format-id id "~a*" id)) (define (mk-mk id) (format-id id "mk-~a" id)) (define (mk-to id) (format-id id "to-~a" id)) - (define (add-convert stx fn) - (set-stx-prop/preserved stx 'convert fn)) + (define (add-construct stx fn) (set-stx-prop/preserved stx 'construct fn)) + (define (add-convert stx fn) (set-stx-prop/preserved stx 'convert fn)) + (define (get-construct stx) (syntax-property stx 'construct)) (define (get-convert stx) - (syntax-property stx 'convert)) - (define (add-construct stx fn) - (set-stx-prop/preserved stx 'construct fn)) - (define (get-construct stx) - (syntax-property stx 'construct)) - (define (ty->len ty) - (regexp-match #px"([a-z]+)([0-9]+)" (type->str ty))) + (let ([conv (syntax-property stx 'convert)]) (or conv #'(λ (x) x)))) + (define (ty->len ty) (regexp-match #px"([a-z]+)([0-9]+)" (type->str ty))) (define (real-type-length t) (define split-ty (ty->len t)) (string->number @@ -133,42 +131,32 @@ (string->symbol (car (regexp-match #px"[a-z]+" (type->str ty))))))) (define (get-pointer-base ty [ctx #'here]) (datum->syntax ctx (string->symbol (string-trim (type->str ty) "*")))) - (define (vector-type? ty) - (ty->len ty)) ; TODO: check and not pointer-type? + (define (vector-type? ty) (ty->len ty)) ; TODO: check and not pointer-type? (define (scalar-type? ty) (or (typecheck/un? ty #'bool) (and (real-type? ty) (not (vector-type? ty)))))) -(define-syntax-parser add-convertm - [(_ stx fn) (add-convert #'stx #'fn)]) -(define-syntax-parser add-constructm - [(_ stx fn) (add-construct #'stx #'fn)]) +(define-syntax-parser add-convertm [(_ stx fn) (add-convert #'stx #'fn)]) +(define-syntax-parser add-constructm [(_ stx fn) (add-construct #'stx #'fn)]) -;; TODO: reuse impls in model/reals.rkt ? - -(ro:define (to-bool v) - (ro:cond - [(ro:boolean? v) v] - [(ro:number? v) (ro:! (ro:= 0 v))] - [else (cl:raise-conversion-error v "bool")])) +(ro:define (to-bool v) ; TODO: reuse impls in model/reals.rkt ? + (ro:cond [(ro:boolean? v) v] + [(ro:number? v) (ro:! (ro:= 0 v))] + [else (cl:raise-conversion-error v "bool")])) (ro:define (to-int v) (ro:cond [(ro:boolean? v) (ro:if v 1 0)] [(ro:fixnum? v) v] [(ro:flonum? v) (ro:exact-truncate v)] [else (ro:real->integer v)])) (ro:define (to-float v) - (ro:cond - [(ro:boolean? v) (ro:if v 1.0 0.0)] - [(ro:fixnum? v) (ro:exact->inexact v)] - [(ro:flonum? v) v] - [else (ro:type-cast ro:real? v)])) -(ro:define (mk-int v) - (ro:#%app cl:int v)) - -(ro:define (to-int* v) - (cl:pointer-cast v cl:int)) -(ro:define (to-float* v) - (cl:pointer-cast v cl:float)) + (ro:cond [(ro:boolean? v) (ro:if v 1.0 0.0)] + [(ro:fixnum? v) (ro:exact->inexact v)] + [(ro:flonum? v) v] + [else (ro:type-cast ro:real? v)])) +(ro:define (mk-int v) (ro:#%app cl:int v)) +(ro:define (mk-float v) (ro:#%app cl:float v)) +(ro:define (to-int* v) (cl:pointer-cast v cl:int)) +(ro:define (to-float* v) (cl:pointer-cast v cl:float)) (define-type-constructor Pointer #:arity = 1) ;(define-named-type-alias void rosette3:CUnit) @@ -213,8 +201,11 @@ (syntax-parse stx [(_ n) #:with floatn (format-id #'n "float~a" (syntax->datum #'n)) + #:with floatn* (mk-ptr #'floatn) #:with to-floatn (format-id #'n "to-~a" #'floatn) #:with mk-floatn (mk-mk #'floatn) + #:with to-floatn* (mk-ptr #'to-floatn) + #:with mk-floatn* (mk-ptr #'mk-floatn) #:with cl-mk-floatn (mk-cl #'floatn) #:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values)) #:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...)) @@ -222,6 +213,7 @@ (define-named-type-alias floatn (add-constructm (add-convertm (rosette3:CVector I ...) to-floatn) mk-floatn)) + (define-named-type-alias floatn* (add-convertm (Pointer floatn) to-floatn*)) (ro:define (to-floatn v) (ro:cond [(ro:list? v) @@ -229,6 +221,7 @@ [(ro:vector? v) (ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:vector-ref v i))))] [else (ro:apply mk-floatn (ro:make-list n (to-float v)))])) + (ro:define (to-floatn* v) (cl:pointer-cast v cl-mk-floatn)) (ro:define (mk-floatn x ...) (ro:#%app cl-mk-floatn x ...)))])) (define-simple-macro (define-floats n ...) (begin (define-float n) ...)) (define-floats 2 3 4 16) @@ -302,7 +295,7 @@ ;; top-level fns -------------------------------------------------- (define-typed-syntax procedure - [(~and (_ ty-out:type (f [ty:type x:id] ...)) ~!) ≫ + [(~and (_ ty-out:type (f [ty:type x:id] ...)) ~!) ≫ ; empty body #:fail-unless (void? #'ty-out.norm) (format "expected void, given ~a" (type->str #'ty-out.norm)) -------- @@ -312,8 +305,7 @@ #:with f- (add-orig (generate-temporary #'f) #'f) -------- [≻ (begin- - (define-syntax- f - (make-rename-transformer (⊢ f- : (C→ ty ... ty-out)))) + (define-syntax- f (make-rename-transformer (⊢ f- : (C→ ty ... ty-out)))) (define- f- (lambda- (x ...) (rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...) @@ -325,6 +317,22 @@ (format "expected void, given ~a" (type->str #'ty-out.norm)) --- [≻ (procedure void (f [ty x] ...) e ...)]]) (define-typed-syntax grammar + [(_ ty-out:type (f [ty:type x:id] ... [ty-depth k]) #:base be #:else ee) ≫ + #:with f- (generate-temporary #'f) + #:with (a ...) (generate-temporaries #'(x ...)) + -------- + [≻ (ro:begin + (ro:define-synthax (f- x ... k) #:base (rosette3:ann be : ty-out) + #:else (rosette3:ann ee : ty-out)) + (define-typed-syntax f + [(ff a ... j) ≫ + [⊢ a ≫ _ ⇐ ty] ... + [⊢ j ≫ _ ⇐ ty-depth] + ;; j will be eval'ed, so strip its context + #:with j- (assign-type (datum->syntax #'H (stx->datum #'j)) #'int) + #:with f-- (replace-stx-loc #'f- #'ff) + ----------- + [⊢ (f-- a ... j-) ⇒ ty-out]]))]] [(_ ty-out:type (f [ty:type x:id] ...) e) ≫ #:with f- (generate-temporary #'f) -------- @@ -342,40 +350,36 @@ ;; for and if statement -------------------------------------------------- (define-typed-syntax if - [(_ test {then ...} {else ...}) ≫ + [(_ e-test {e1 ...} {e2 ...}) ≫ -------- - [⊢ (ro:if (to-bool test) - (ro:let () then ... (ro:void)) - (ro:let () else ... (ro:void))) ⇒ void]] - [(_ test {then ...}) ≫ - --- [≻ (if test {then ...} {})]]) + [⊢ (ro:if (to-bool e-test) + (ro:let () e1 ... (ro:void)) + (ro:let () e2 ... (ro:void))) ⇒ void]] + [(_ e-test es) ≫ --- [≻ (if e-test es {})]]) (define-typed-syntax (range e ...) ≫ [⊢ e ≫ e- ⇐ int] ... --- [⊢ (ro:#%app ro:in-range e- ...) ⇒ int]) (define-typed-syntax for [(_ [((~literal :) ty:type x:id (~datum in) rangeExpr) ...] e ...) ≫ + #:with (x- ...) (generate-temporaries #'(x ...)) + #:with (typed-seq ...) #'((with-ctx ([x x- ty] ...) rangeExpr) ...) -------- - [⊢ (ro:for* ([x rangeExpr] ...) - (rosette3:let ([x (⊢m x ty)] ...) - (⊢m (ro:let () e ... (ro:void)) void))) ⇒ void]]) + [⊢ (ro:let ([x- 1] ...) ; dummy ensuring id- bound, simplifies stx template + (ro:for* ([x- typed-seq] ...) + (with-ctx ([x x- ty] ...) + (⊢m (ro:let () e ... (ro:void)) void)))) ⇒ void]]) ;; need to redefine #%datum because rosette3:#%datum is too precise (define-typed-syntax #%datum - [(_ . b:boolean) ≫ - -------- - [⊢ (ro:#%datum . b) ⇒ bool]] - [(_ . n:integer) ≫ - -------- - [⊢ (ro:#%datum . n) ⇒ int]] - [(#%datum . n:number) ≫ + [(_ . b:boolean) ≫ --- [⊢ (ro:#%datum . b) ⇒ bool]] + [(_ . s:str) ≫ --- [⊢ (ro:#%datum . s) ⇒ char*]] + [(_ . n:integer) ≫ --- [⊢ (ro:#%datum . n) ⇒ int]] + [(#%datum . n:number) ≫ #:when (real? (syntax-e #'n)) -------- [⊢ (ro:#%datum . n) ⇒ float]] - [(_ . s:str) ≫ - -------- - [⊢ (ro:#%datum . s) ⇒ char*]] [(_ . x) ≫ -------- [_ #:error (type-error #:src #'x #:msg "Unsupported literal: ~v" #'x)]]) @@ -460,12 +464,6 @@ ;; ?: -------------------------------------------------- (define-typed-syntax ?: - [(_ e e1 e2) ≫ - [⊢ e ≫ e- ⇐ bool] - [⊢ e1 ≫ e1- ⇒ ty1] - [⊢ e2 ≫ e2- ⇒ ty2] - ------- - [⊢ (cl:?: e- e1- e2-) ⇒ (⊔ τ1 τ2)]] [(_ e e1 e2) ≫ [⊢ e ≫ e- ⇒ ty] ; vector type #:do [(define split-ty (ty->len #'ty))] @@ -502,8 +500,7 @@ (synth-app (ty-out) e1-) (synth-app (ty-out) e2-)) ⇒ ty-out]]) -;; = -------------------------------------------------- -;; assignment +;; = (assignment) -------------------------------------------------- (define-typed-syntax = [(_ x:id e) ≫ [⊢ x ≫ x- ⇒ ty-x] @@ -540,16 +537,21 @@ -------- [⊢ (ro:#%app cl:! (to-bool e-)) ⇒ bool]]) -;; TODO: this should produce int-vector result? -(define-typed-syntax == - [(_ e1 e2) ≫ - [⊢ e1 ≫ e1- ⇒ ty1] - [⊢ e2 ≫ e2- ⇒ ty2] - #:when (real-type? #'ty1) - #:when (real-type? #'ty2) - #:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len - -------- - [⊢ (to-int (cl:== e1- e2-)) ⇒ int]]) +;; TODO: comparison ops need to support vec types (and result) +(define-simple-macro (mk-cmp cmp-op) + (define-typed-syntax cmp-op + [(o e1 e2) ≫ + [⊢ e1 ≫ e1- ⇒ ty1] + [⊢ e2 ≫ e2- ⇒ ty2] + #:when (real-type? #'ty1) + #:when (real-type? #'ty2) + #:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len + #:with conv (get-convert #'ty-out) + #:with o- (mk-cl #'o) + -------- + [⊢ (to-int (o- (conv e1-) (conv e2-))) ⇒ int]])) +(define-simple-macro (mk-cmps o ...) (begin- (mk-cmp o) ...)) +(mk-cmps == < <= > >=) (define-simple-macro (define-bool-ops o ...+) (ro:begin (define-bool-op o) ...)) (define-simple-macro (define-bool-op name) @@ -563,7 +565,7 @@ [(_ e1 e2) ≫ ; else try to coerce --- [⊢ (name- (to-bool e1) (to-bool e2)) ⇒ bool]])) -(define- (cl:/ x y) +#;(define- (cl:/ x y) (cond- [(zero?- y) 0] [(integer?- x) (quotient- x y)] [else (/- x y)])) @@ -604,7 +606,7 @@ (define-bool-ops || &&) (define-real-ops + * - /) -(define-int-ops % <<) +(define-int-ops % << $ &) (define-typerule (sizeof t:type) >> ---[⊢ #,(real-type-length #'t.norm) ⇒ int]) (define-typerule (print e ...) >> ---[⊢ (ro:begin (display e) ...) ⇒ void]) @@ -621,9 +623,30 @@ -------- [⊢ (ch/disarmed e/disarmed ...) ⇒ #,(stx-car #'(ty ...))]]) -(define-typed-syntax (locally-scoped e ...) ≫ - -------- - [≻ (rosette3:let () e ...)]) +(define-typed-syntax ?? + [(qq) ≫ + #:with ??/progsrc (replace-stx-loc #'cl:?? #'qq) + -------- + [⊢ (??/progsrc) ⇒ int]] + [(qq ty:type) ≫ + #:with ??/progsrc (replace-stx-loc #'cl:?? #'qq) + #:with t (datum->syntax #'here (string->symbol (type->str #'ty.norm))) + #:with cl-t (mk-cl #'t) + ;; #:with ty-base (get-base #'ty.norm) + ;; #:with pred (get-pred ((current-type-eval) #'ty-base)) + -------- + [⊢ (??/progsrc cl-t) ⇒ ty]]) + +(define-typed-syntax (@ x:id) ≫ + [⊢ x ≫ x- ⇒ ty+] ;; TODO: check ty = real, non-ptr type + #:with ty (datum->syntax #'x (string->symbol (type->str #'ty+))) + #:with cl-ty (mk-cl #'ty) + --------- + [⊢ (cl:address-of x- cl-ty) ⇒ #,(mk-ptr #'ty)]) + +(define-typed-syntax locally-scoped + [(_ e ...) ⇐ ty ≫ --- [⊢ (ro:let () e ...)]] + [(_ e ...) ≫ --- [≻ (⊢m (ro:let () e ...) void)]]) (define-for-syntax (decl->seq stx) (syntax-parse stx @@ -646,6 +669,7 @@ (ro:define-values (tmp ...) (ro:for*/lists (tmp ...) ([id- typed-seq] ...) (ro:values id- ...))) (ro:parameterize ([ro:current-bitwidth bw] ; matrix mult unsat w/o this + [ro:current-oracle (ro:oracle (ro:current-oracle))] [ro:term-cache (ro:hash-copy (ro:term-cache))]) (ro:print-forms (ro:synthesize diff --git a/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt b/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt index cd69a2d..324c3a9 100644 --- a/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt +++ b/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt @@ -24,6 +24,7 @@ (do-tests "synthcl3-tests.rkt" "SynthCL" "synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth" "synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify" - "synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify") + "synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify" + "synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth") (do-tests "bv-ref-tests.rkt" "BV SDSL - Hacker's Delight synthesis") diff --git a/turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt b/turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt new file mode 100644 index 0000000..54713d7 --- /dev/null +++ b/turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt @@ -0,0 +1,12 @@ +#lang racket/base + +(require macrotypes/examples/tests/do-tests) + +(do-tests + "synthcl3-tests.rkt" "SynthCL" + "synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth" + "synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify" + "synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify" + "synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth") + + diff --git a/turnstile/examples/tests/rosette/rosette3/synthcl3-walsh-synth-tests.rkt b/turnstile/examples/tests/rosette/rosette3/synthcl3-walsh-synth-tests.rkt new file mode 100644 index 0000000..cc0a48c --- /dev/null +++ b/turnstile/examples/tests/rosette/rosette3/synthcl3-walsh-synth-tests.rkt @@ -0,0 +1,146 @@ +#lang s-exp "../../../rosette/synthcl3.rkt" +(require "../../rackunit-typechecking.rkt") +; Compute the number of steps for the algorithm, +; assuming that v is a power of 2. See the log2 +; algorithm from http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog +(procedure int (steps [int v]) + (: int r) + (= r 0) + ($= r (<< (!= 0 (& v #xAAAAAAAA)) 0)) + ($= r (<< (!= 0 (& v #xCCCCCCCC)) 1)) + ($= r (<< (!= 0 (& v #xF0F0F0F0)) 2)) + ($= r (<< (!= 0 (& v #xFF00FF00)) 3)) + ($= r (<< (!= 0 (& v #xFFFF0000)) 4)) + r) + +; Reference implementation for Fast Walsh Transform. This implementation +; requires the length of the input array to be a power of 2, and it modifies +; the input array in place. +(procedure float* (fwt [float* tArray] [int length]) + (for [(: int i in (range 0 (steps length)))] + (: int step) + (= step (<< 1 i)) + (for [(: int group in (range 0 step)) + (: int pair in (range group length (<< step 1)))] + (: int match) + (: float t1 t2) + (= match (+ pair step)) + (= t1 [tArray pair]) + (= t2 [tArray match]) + (= [tArray pair] (+ t1 t2)) + (= [tArray match] (- t1 t2)))) + tArray) + +; Scalar host for Fast Walsh Transform. This implementation +; requires the length of the input array to be a power of 2. The +; input array is not modified; the output is a new array that holds +; the result of the transform. +(procedure float* (fwtScalarHost [float* input] [int length]) + (: cl_context context) + (: cl_command_queue command_queue) + (: cl_program program) + (: cl_kernel kernel) + (: cl_mem tBuffer) + (: float* tArray) + (: int dim global) + + (= dim (* length (sizeof float))) + (= global (/ length 2)) + + (= tArray ((float*) (malloc dim))) + + (= context (clCreateContext)) + + (= command_queue (clCreateCommandQueue context)) + + (= tBuffer (clCreateBuffer context CL_MEM_READ_WRITE dim)) + (= program (clCreateProgramWithSource context "walsh-synth-kernel.rkt")) + + (clEnqueueWriteBuffer command_queue tBuffer 0 dim input) + + (= kernel (clCreateKernel program "fwtKernelSketch")) + (clSetKernelArg kernel 0 tBuffer) + + (for [(: int i in (range 0 (steps length)))] + (: int step) + (= step (<< 1 i)) + (clSetKernelArg kernel 1 step) + (clEnqueueNDRangeKernel command_queue kernel 1 NULL (@ global) NULL)) + + (clEnqueueReadBuffer command_queue tBuffer 0 dim tArray) + tArray) + +; Vectorized host for Fast Walsh Transform. This implementation +; requires the length of the input array to be a power of 2. The +; input array is not modified; the output is a new array that holds +; the result of the transform. +(procedure float* (fwtVectorHost [float* input] [int length]) + (: cl_context context) + (: cl_command_queue command_queue) + (: cl_program program) + (: cl_mem tBuffer) + (: float* tArray) + (: int dim global n) + + (= dim (* length (sizeof float))) + (= global (/ length 2)) + + (= tArray ((float*) (malloc dim))) + + (= context (clCreateContext)) + + (= command_queue (clCreateCommandQueue context)) + + (= tBuffer (clCreateBuffer context CL_MEM_READ_WRITE dim)) + (= program (clCreateProgramWithSource context "walsh-synth-kernel.rkt")) + + (clEnqueueWriteBuffer command_queue tBuffer 0 dim input) + + (= n (steps length)) + + (runKernel command_queue (clCreateKernel program "fwtKernel") tBuffer global 0 (?: (< n 2) n 2)) + (if (> n 2) + { (/= global 4) + (runKernel command_queue (clCreateKernel program "fwtKernel4Sketch") tBuffer global 2 n) }) + + (clEnqueueReadBuffer command_queue tBuffer 0 dim tArray) + tArray) + +(procedure void (runKernel [cl_command_queue command_queue] [cl_kernel kernel] [cl_mem tBuffer] + [int global] [int start] [int end]) + (clSetKernelArg kernel 0 tBuffer) + (for [(: int i in (range start end))] + (: int step) + (= step (<< 1 i)) + (clSetKernelArg kernel 1 step) + (clEnqueueNDRangeKernel command_queue kernel 1 NULL (@ global) NULL))) + +; Given two arrays of the same size, checks that they hold the same +; values at each index. +(procedure void (check [int* actual] [int* expected] [int SIZE]) + (assert (>= SIZE 0)) + (for [(: int i in (range SIZE))] + (assert (== [actual i] [expected i])))) + +(procedure void (synth_scalar) ; ~7 sec + (synth #:forall [(: int length in (range 8 9)) + (: float[length] tArray)] + #:ensure (check (fwtScalarHost tArray length) + (fwt tArray length) + length))) + +(procedure void (synth_vector) ; < 1 sec + (synth #:forall [(: int length in (range 8 9)) + (: float[length] tArray)] + #:ensure (check (fwtVectorHost tArray length) + (fwt tArray length) + length))) + +(check-type + (with-output-to-string (λ () (synth_scalar))) + : CString + -> "/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/walsh-synth-kernel.rkt:3:0\n'(kernel\n void\n (fwtKernelSketch (float* tArray) (int step))\n (: int tid group pair match)\n (: float t1 t2)\n (= tid (get_global_id 0))\n (=\n group\n (rosette3:ann\n (locally-scoped\n (: int left right)\n (= left (rosette3:ann tid : int))\n (= right (rosette3:ann step : int))\n (% left right))\n :\n int))\n (=\n pair\n (+\n (*\n (<< step 1)\n (rosette3:ann\n (locally-scoped\n (: int left right)\n (= left (rosette3:ann tid : int))\n (= right (rosette3:ann step : int))\n (/ left right))\n :\n int))\n group))\n (= match (+ pair step))\n (= t1 (tArray pair))\n (= t2 (tArray match))\n (= (tArray pair) (+ t1 t2))\n (= (tArray match) (- t1 t2)))\n") +(check-type + (with-output-to-string (λ () (synth_vector))) + : CString + -> "/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/walsh-synth-kernel.rkt:15:0\n'(kernel\n void\n (fwtKernel4Sketch (float4* tArray) (int step))\n (: int tid group pair match)\n (: float4 t1 t2)\n (= tid (get_global_id 0))\n (= step (/ step 4))\n (= group (% tid step))\n (= pair (+ (* (<< step 1) (/ tid step)) group))\n (= match (+ pair step))\n (= t1 (tArray pair))\n (= t2 (tArray match))\n (= (tArray pair) (+ t1 t2))\n (= (tArray match) (- t1 t2)))\n") diff --git a/turnstile/examples/tests/rosette/rosette3/walsh-synth-kernel.rkt b/turnstile/examples/tests/rosette/rosette3/walsh-synth-kernel.rkt new file mode 100644 index 0000000..9bfcc03 --- /dev/null +++ b/turnstile/examples/tests/rosette/rosette3/walsh-synth-kernel.rkt @@ -0,0 +1,66 @@ +#lang s-exp "../../../rosette/synthcl3.rkt" + +(kernel void (fwtKernelSketch [float* tArray] [int step]) + (: int tid group pair match) + (: float t1 t2) + (= tid (get_global_id 0)) + (= group (idx tid step 1)) ; (% tid step) + (= pair (+ (* (<< step 1) (idx tid step 1)) group)) ; (/ tid step) + (= match (+ pair step)) + (= t1 [tArray pair]) + (= t2 [tArray match]) + (= [tArray pair] (+ t1 t2)) + (= [tArray match] (- t1 t2))) + +(kernel void (fwtKernel4Sketch [float4* tArray] [int step]) + (: int tid group pair match) + (: float4 t1 t2) + (= tid (get_global_id 0)) + (= step [choose step (/ step 4) (* step 4) (% step 4)]) ; (/ step 4) + (= group (% tid step)) + (= pair (+ (* (<< step 1) (/ tid step)) group)) + (= match (+ pair step)) + (= t1 [tArray pair]) + (= t2 [tArray match]) + (= [tArray pair] (+ t1 t2)) + (= [tArray match] (- t1 t2))) + + +(grammar int (idx [int tid] [int step] [int depth]) + #:base (choose tid step (?? int)) + #:else (locally-scoped + (: int left right) + (= left (idx tid step (- depth 1))) + (= right (idx tid step (- depth 1))) + [choose left + (+ left right) + (- left right) + (/ left right) + (* left right) + (% left right)])) + +(kernel void (fwtKernel [float* tArray] [int step]) + (: int tid group pair match) + (: float t1 t2) + (= tid (get_global_id 0)) + (= group (% tid step)) + (= pair (+ (* (<< step 1) (/ tid step)) group)) + (= match (+ pair step)) + (= t1 [tArray pair]) + (= t2 [tArray match]) + (= [tArray pair] (+ t1 t2)) + (= [tArray match] (- t1 t2))) + +(kernel void (fwtKernel4 [float4* tArray] [int step]) + (: int tid group pair match) + (: float4 t1 t2) + (= tid (get_global_id 0)) + (= step (/ step 4)) + (= group (% tid step)) + (= pair (+ (* (<< step 1) (/ tid step)) group)) + (= match (+ pair step)) + (= t1 [tArray pair]) + (= t2 [tArray match]) + (= [tArray pair] (+ t1 t2)) + (= [tArray match] (- t1 t2))) +