add base/else grammar form; walsh synth tests passing

- add ?? and @
- add comparison ops
- use cl:/ instead of my own - fixed walsh synth scalar test
- fix unbound ids err in for due to let*-like bindings
This commit is contained in:
Stephen Chang 2016-12-07 16:08:13 -05:00
parent e0a2900c77
commit 21c77d7e61
6 changed files with 367 additions and 87 deletions

View File

@ -1,3 +1,34 @@
2016-12-07 --------------------
synthcl Walsh synth tests were not working and had trouble debugging, so
documenting my process here
1) getting the error:
; ?: literal data is not allowed;
; no #%datum syntax transformer is bound
; in: #f
in general, means a stx prop is expected but not present
In this case:
- convert fn undefined for cl_mem and other similar base types
2) getting unexpected unsat
helpful debugging technique 1:
- print exns swallowed by eval/assert (used by synthesize or verify)
In this case:
- mk-float* was undefined
- bool was not considered "real" but should be
- cmps not defined
helpful debugging technique 2:
- print asserts in ∃∀-solve, and compare to expected
- may have to set error-print-width higher (default 256)
In this case, comparing asserts showed that quotient was missing from my typed
synthcl. Using the / defined by synthcl (instead of mine) fixed the problem.
2016-11-18 --------------------
working on synthcl3 lang

View File

@ -1,9 +1,9 @@
#lang turnstile
(extends "rosette3.rkt" #:except ! #%app || && void = * + - / #%datum if assert verify) ; typed rosette
(extends "rosette3.rkt" #:except ! #%app || && void = * + - / #%datum if assert verify < <= > >=) ; typed rosette
(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped
racket/stxparam
(prefix-in ro: (combine-in rosette rosette/lib/synthax))
(prefix-in cl: (combine-in (except-in sdsl/synthcl/model/operators /)
(prefix-in cl: (combine-in sdsl/synthcl/model/operators
sdsl/synthcl/lang/forms sdsl/synthcl/model/reals
sdsl/synthcl/model/errors sdsl/synthcl/model/kernel
sdsl/synthcl/model/memory sdsl/synthcl/model/runtime
@ -19,13 +19,14 @@
(provide (rename-out [synth-app #%app])
procedure kernel grammar #%datum if range for print
choose locally-scoped assert synth verify
choose ?? @ locally-scoped assert synth verify
int int2 int3 int4 int16 float float2 float3 float4 float16
bool void void* char* float* int* int2* int3* int4* int16*
bool void void* char*
int* int2* int3* int4* int16* float* float2* float3* float4* float16*
cl_context cl_command_queue cl_program cl_kernel cl_mem
: ! ?: == + * / - || &&
% << ; int ops
= += -= *= /= %= ; assignment ops
% << $ & > >= < <= ; int ops
= += -= *= /= %= $= &= ; assignment ops
sizeof clCreateProgramWithSource
(typed-out
[clCreateContext : (C→ cl_context)]
@ -42,6 +43,7 @@
[get_global_id : (C→ int int)]
[CL_MEM_READ_ONLY : int]
[CL_MEM_WRITE_ONLY : int]
[CL_MEM_READ_WRITE : int]
[malloc : (C→ int void*)]
[get_work_dim : (C→ int)]
[!= : (Ccase-> (C→ CNum CNum CBool)
@ -57,7 +59,7 @@
(typecheck? ((current-type-eval) t1)
((current-type-eval) t2)))
(define (real-type? t)
(and (not (typecheck/un? t #'bool))
(and #;(not (typecheck/un? t #'bool))
(not (typecheck/un? t #'char*))
(not (pointer-type? t))))
(define (pointer-type? t)
@ -113,16 +115,12 @@
(define (mk-ptr id) (format-id id "~a*" id))
(define (mk-mk id) (format-id id "mk-~a" id))
(define (mk-to id) (format-id id "to-~a" id))
(define (add-convert stx fn)
(set-stx-prop/preserved stx 'convert fn))
(define (add-construct stx fn) (set-stx-prop/preserved stx 'construct fn))
(define (add-convert stx fn) (set-stx-prop/preserved stx 'convert fn))
(define (get-construct stx) (syntax-property stx 'construct))
(define (get-convert stx)
(syntax-property stx 'convert))
(define (add-construct stx fn)
(set-stx-prop/preserved stx 'construct fn))
(define (get-construct stx)
(syntax-property stx 'construct))
(define (ty->len ty)
(regexp-match #px"([a-z]+)([0-9]+)" (type->str ty)))
(let ([conv (syntax-property stx 'convert)]) (or conv #'(λ (x) x))))
(define (ty->len ty) (regexp-match #px"([a-z]+)([0-9]+)" (type->str ty)))
(define (real-type-length t)
(define split-ty (ty->len t))
(string->number
@ -133,42 +131,32 @@
(string->symbol (car (regexp-match #px"[a-z]+" (type->str ty)))))))
(define (get-pointer-base ty [ctx #'here])
(datum->syntax ctx (string->symbol (string-trim (type->str ty) "*"))))
(define (vector-type? ty)
(ty->len ty)) ; TODO: check and not pointer-type?
(define (vector-type? ty) (ty->len ty)) ; TODO: check and not pointer-type?
(define (scalar-type? ty)
(or (typecheck/un? ty #'bool)
(and (real-type? ty) (not (vector-type? ty))))))
(define-syntax-parser add-convertm
[(_ stx fn) (add-convert #'stx #'fn)])
(define-syntax-parser add-constructm
[(_ stx fn) (add-construct #'stx #'fn)])
(define-syntax-parser add-convertm [(_ stx fn) (add-convert #'stx #'fn)])
(define-syntax-parser add-constructm [(_ stx fn) (add-construct #'stx #'fn)])
;; TODO: reuse impls in model/reals.rkt ?
(ro:define (to-bool v)
(ro:cond
[(ro:boolean? v) v]
[(ro:number? v) (ro:! (ro:= 0 v))]
[else (cl:raise-conversion-error v "bool")]))
(ro:define (to-bool v) ; TODO: reuse impls in model/reals.rkt ?
(ro:cond [(ro:boolean? v) v]
[(ro:number? v) (ro:! (ro:= 0 v))]
[else (cl:raise-conversion-error v "bool")]))
(ro:define (to-int v)
(ro:cond [(ro:boolean? v) (ro:if v 1 0)]
[(ro:fixnum? v) v]
[(ro:flonum? v) (ro:exact-truncate v)]
[else (ro:real->integer v)]))
(ro:define (to-float v)
(ro:cond
[(ro:boolean? v) (ro:if v 1.0 0.0)]
[(ro:fixnum? v) (ro:exact->inexact v)]
[(ro:flonum? v) v]
[else (ro:type-cast ro:real? v)]))
(ro:define (mk-int v)
(ro:#%app cl:int v))
(ro:define (to-int* v)
(cl:pointer-cast v cl:int))
(ro:define (to-float* v)
(cl:pointer-cast v cl:float))
(ro:cond [(ro:boolean? v) (ro:if v 1.0 0.0)]
[(ro:fixnum? v) (ro:exact->inexact v)]
[(ro:flonum? v) v]
[else (ro:type-cast ro:real? v)]))
(ro:define (mk-int v) (ro:#%app cl:int v))
(ro:define (mk-float v) (ro:#%app cl:float v))
(ro:define (to-int* v) (cl:pointer-cast v cl:int))
(ro:define (to-float* v) (cl:pointer-cast v cl:float))
(define-type-constructor Pointer #:arity = 1)
;(define-named-type-alias void rosette3:CUnit)
@ -213,8 +201,11 @@
(syntax-parse stx
[(_ n)
#:with floatn (format-id #'n "float~a" (syntax->datum #'n))
#:with floatn* (mk-ptr #'floatn)
#:with to-floatn (format-id #'n "to-~a" #'floatn)
#:with mk-floatn (mk-mk #'floatn)
#:with to-floatn* (mk-ptr #'to-floatn)
#:with mk-floatn* (mk-ptr #'mk-floatn)
#:with cl-mk-floatn (mk-cl #'floatn)
#:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values))
#:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...))
@ -222,6 +213,7 @@
(define-named-type-alias floatn
(add-constructm
(add-convertm (rosette3:CVector I ...) to-floatn) mk-floatn))
(define-named-type-alias floatn* (add-convertm (Pointer floatn) to-floatn*))
(ro:define (to-floatn v)
(ro:cond
[(ro:list? v)
@ -229,6 +221,7 @@
[(ro:vector? v)
(ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:vector-ref v i))))]
[else (ro:apply mk-floatn (ro:make-list n (to-float v)))]))
(ro:define (to-floatn* v) (cl:pointer-cast v cl-mk-floatn))
(ro:define (mk-floatn x ...) (ro:#%app cl-mk-floatn x ...)))]))
(define-simple-macro (define-floats n ...) (begin (define-float n) ...))
(define-floats 2 3 4 16)
@ -302,7 +295,7 @@
;; top-level fns --------------------------------------------------
(define-typed-syntax procedure
[(~and (_ ty-out:type (f [ty:type x:id] ...)) ~!)
[(~and (_ ty-out:type (f [ty:type x:id] ...)) ~!) ; empty body
#:fail-unless (void? #'ty-out.norm)
(format "expected void, given ~a" (type->str #'ty-out.norm))
--------
@ -312,8 +305,7 @@
#:with f- (add-orig (generate-temporary #'f) #'f)
--------
[ (begin-
(define-syntax- f
(make-rename-transformer ( f- : (C→ ty ... ty-out))))
(define-syntax- f (make-rename-transformer ( f- : (C→ ty ... ty-out))))
(define- f-
(lambda- (x ...)
(rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...)
@ -325,6 +317,22 @@
(format "expected void, given ~a" (type->str #'ty-out.norm))
--- [ (procedure void (f [ty x] ...) e ...)]])
(define-typed-syntax grammar
[(_ ty-out:type (f [ty:type x:id] ... [ty-depth k]) #:base be #:else ee)
#:with f- (generate-temporary #'f)
#:with (a ...) (generate-temporaries #'(x ...))
--------
[ (ro:begin
(ro:define-synthax (f- x ... k) #:base (rosette3:ann be : ty-out)
#:else (rosette3:ann ee : ty-out))
(define-typed-syntax f
[(ff a ... j)
[ a _ ty] ...
[ j _ ty-depth]
;; j will be eval'ed, so strip its context
#:with j- (assign-type (datum->syntax #'H (stx->datum #'j)) #'int)
#:with f-- (replace-stx-loc #'f- #'ff)
-----------
[ (f-- a ... j-) ty-out]]))]]
[(_ ty-out:type (f [ty:type x:id] ...) e)
#:with f- (generate-temporary #'f)
--------
@ -342,40 +350,36 @@
;; for and if statement --------------------------------------------------
(define-typed-syntax if
[(_ test {then ...} {else ...})
[(_ e-test {e1 ...} {e2 ...})
--------
[ (ro:if (to-bool test)
(ro:let () then ... (ro:void))
(ro:let () else ... (ro:void))) void]]
[(_ test {then ...})
--- [ (if test {then ...} {})]])
[ (ro:if (to-bool e-test)
(ro:let () e1 ... (ro:void))
(ro:let () e2 ... (ro:void))) void]]
[(_ e-test es) --- [ (if e-test es {})]])
(define-typed-syntax (range e ...)
[ e e- int] ...
--- [ (ro:#%app ro:in-range e- ...) int])
(define-typed-syntax for
[(_ [((~literal :) ty:type x:id (~datum in) rangeExpr) ...] e ...)
#:with (x- ...) (generate-temporaries #'(x ...))
#:with (typed-seq ...) #'((with-ctx ([x x- ty] ...) rangeExpr) ...)
--------
[ (ro:for* ([x rangeExpr] ...)
(rosette3:let ([x (⊢m x ty)] ...)
(⊢m (ro:let () e ... (ro:void)) void))) void]])
[ (ro:let ([x- 1] ...) ; dummy ensuring id- bound, simplifies stx template
(ro:for* ([x- typed-seq] ...)
(with-ctx ([x x- ty] ...)
(⊢m (ro:let () e ... (ro:void)) void)))) void]])
;; need to redefine #%datum because rosette3:#%datum is too precise
(define-typed-syntax #%datum
[(_ . b:boolean)
--------
[ (ro:#%datum . b) bool]]
[(_ . n:integer)
--------
[ (ro:#%datum . n) int]]
[(#%datum . n:number)
[(_ . b:boolean) --- [ (ro:#%datum . b) bool]]
[(_ . s:str) --- [ (ro:#%datum . s) char*]]
[(_ . n:integer) --- [ (ro:#%datum . n) int]]
[(#%datum . n:number)
#:when (real? (syntax-e #'n))
--------
[ (ro:#%datum . n) float]]
[(_ . s:str)
--------
[ (ro:#%datum . s) char*]]
[(_ . x)
--------
[_ #:error (type-error #:src #'x #:msg "Unsupported literal: ~v" #'x)]])
@ -460,12 +464,6 @@
;; ?: --------------------------------------------------
(define-typed-syntax ?:
[(_ e e1 e2)
[ e e- bool]
[ e1 e1- ty1]
[ e2 e2- ty2]
-------
[ (cl:?: e- e1- e2-) ( τ1 τ2)]]
[(_ e e1 e2)
[ e e- ty] ; vector type
#:do [(define split-ty (ty->len #'ty))]
@ -502,8 +500,7 @@
(synth-app (ty-out) e1-)
(synth-app (ty-out) e2-)) ty-out]])
;; = --------------------------------------------------
;; assignment
;; = (assignment) --------------------------------------------------
(define-typed-syntax =
[(_ x:id e)
[ x x- ty-x]
@ -540,16 +537,21 @@
--------
[ (ro:#%app cl:! (to-bool e-)) bool]])
;; TODO: this should produce int-vector result?
(define-typed-syntax ==
[(_ e1 e2)
[ e1 e1- ty1]
[ e2 e2- ty2]
#:when (real-type? #'ty1)
#:when (real-type? #'ty2)
#:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len
--------
[ (to-int (cl:== e1- e2-)) int]])
;; TODO: comparison ops need to support vec types (and result)
(define-simple-macro (mk-cmp cmp-op)
(define-typed-syntax cmp-op
[(o e1 e2)
[ e1 e1- ty1]
[ e2 e2- ty2]
#:when (real-type? #'ty1)
#:when (real-type? #'ty2)
#:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len
#:with conv (get-convert #'ty-out)
#:with o- (mk-cl #'o)
--------
[ (to-int (o- (conv e1-) (conv e2-))) int]]))
(define-simple-macro (mk-cmps o ...) (begin- (mk-cmp o) ...))
(mk-cmps == < <= > >=)
(define-simple-macro (define-bool-ops o ...+) (ro:begin (define-bool-op o) ...))
(define-simple-macro (define-bool-op name)
@ -563,7 +565,7 @@
[(_ e1 e2) ; else try to coerce
--- [ (name- (to-bool e1) (to-bool e2)) bool]]))
(define- (cl:/ x y)
#;(define- (cl:/ x y)
(cond- [(zero?- y) 0]
[(integer?- x) (quotient- x y)]
[else (/- x y)]))
@ -604,7 +606,7 @@
(define-bool-ops || &&)
(define-real-ops + * - /)
(define-int-ops % <<)
(define-int-ops % << $ &)
(define-typerule (sizeof t:type) >> ---[ #,(real-type-length #'t.norm) int])
(define-typerule (print e ...) >> ---[ (ro:begin (display e) ...) void])
@ -621,9 +623,30 @@
--------
[ (ch/disarmed e/disarmed ...) #,(stx-car #'(ty ...))]])
(define-typed-syntax (locally-scoped e ...)
--------
[ (rosette3:let () e ...)])
(define-typed-syntax ??
[(qq)
#:with ??/progsrc (replace-stx-loc #'cl:?? #'qq)
--------
[ (??/progsrc) int]]
[(qq ty:type)
#:with ??/progsrc (replace-stx-loc #'cl:?? #'qq)
#:with t (datum->syntax #'here (string->symbol (type->str #'ty.norm)))
#:with cl-t (mk-cl #'t)
;; #:with ty-base (get-base #'ty.norm)
;; #:with pred (get-pred ((current-type-eval) #'ty-base))
--------
[ (??/progsrc cl-t) ty]])
(define-typed-syntax (@ x:id)
[ x x- ty+] ;; TODO: check ty = real, non-ptr type
#:with ty (datum->syntax #'x (string->symbol (type->str #'ty+)))
#:with cl-ty (mk-cl #'ty)
---------
[ (cl:address-of x- cl-ty) #,(mk-ptr #'ty)])
(define-typed-syntax locally-scoped
[(_ e ...) ty --- [ (ro:let () e ...)]]
[(_ e ...) --- [ (⊢m (ro:let () e ...) void)]])
(define-for-syntax (decl->seq stx)
(syntax-parse stx
@ -646,6 +669,7 @@
(ro:define-values (tmp ...)
(ro:for*/lists (tmp ...) ([id- typed-seq] ...) (ro:values id- ...)))
(ro:parameterize ([ro:current-bitwidth bw] ; matrix mult unsat w/o this
[ro:current-oracle (ro:oracle (ro:current-oracle))]
[ro:term-cache (ro:hash-copy (ro:term-cache))])
(ro:print-forms
(ro:synthesize

View File

@ -24,6 +24,7 @@
(do-tests "synthcl3-tests.rkt" "SynthCL"
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth"
"synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify"
"synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify")
"synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify"
"synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth")
(do-tests "bv-ref-tests.rkt" "BV SDSL - Hacker's Delight synthesis")

View File

@ -0,0 +1,12 @@
#lang racket/base
(require macrotypes/examples/tests/do-tests)
(do-tests
"synthcl3-tests.rkt" "SynthCL"
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth"
"synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify"
"synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify"
"synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth")

View File

@ -0,0 +1,146 @@
#lang s-exp "../../../rosette/synthcl3.rkt"
(require "../../rackunit-typechecking.rkt")
; Compute the number of steps for the algorithm,
; assuming that v is a power of 2. See the log2
; algorithm from http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog
(procedure int (steps [int v])
(: int r)
(= r 0)
($= r (<< (!= 0 (& v #xAAAAAAAA)) 0))
($= r (<< (!= 0 (& v #xCCCCCCCC)) 1))
($= r (<< (!= 0 (& v #xF0F0F0F0)) 2))
($= r (<< (!= 0 (& v #xFF00FF00)) 3))
($= r (<< (!= 0 (& v #xFFFF0000)) 4))
r)
; Reference implementation for Fast Walsh Transform. This implementation
; requires the length of the input array to be a power of 2, and it modifies
; the input array in place.
(procedure float* (fwt [float* tArray] [int length])
(for [(: int i in (range 0 (steps length)))]
(: int step)
(= step (<< 1 i))
(for [(: int group in (range 0 step))
(: int pair in (range group length (<< step 1)))]
(: int match)
(: float t1 t2)
(= match (+ pair step))
(= t1 [tArray pair])
(= t2 [tArray match])
(= [tArray pair] (+ t1 t2))
(= [tArray match] (- t1 t2))))
tArray)
; Scalar host for Fast Walsh Transform. This implementation
; requires the length of the input array to be a power of 2. The
; input array is not modified; the output is a new array that holds
; the result of the transform.
(procedure float* (fwtScalarHost [float* input] [int length])
(: cl_context context)
(: cl_command_queue command_queue)
(: cl_program program)
(: cl_kernel kernel)
(: cl_mem tBuffer)
(: float* tArray)
(: int dim global)
(= dim (* length (sizeof float)))
(= global (/ length 2))
(= tArray ((float*) (malloc dim)))
(= context (clCreateContext))
(= command_queue (clCreateCommandQueue context))
(= tBuffer (clCreateBuffer context CL_MEM_READ_WRITE dim))
(= program (clCreateProgramWithSource context "walsh-synth-kernel.rkt"))
(clEnqueueWriteBuffer command_queue tBuffer 0 dim input)
(= kernel (clCreateKernel program "fwtKernelSketch"))
(clSetKernelArg kernel 0 tBuffer)
(for [(: int i in (range 0 (steps length)))]
(: int step)
(= step (<< 1 i))
(clSetKernelArg kernel 1 step)
(clEnqueueNDRangeKernel command_queue kernel 1 NULL (@ global) NULL))
(clEnqueueReadBuffer command_queue tBuffer 0 dim tArray)
tArray)
; Vectorized host for Fast Walsh Transform. This implementation
; requires the length of the input array to be a power of 2. The
; input array is not modified; the output is a new array that holds
; the result of the transform.
(procedure float* (fwtVectorHost [float* input] [int length])
(: cl_context context)
(: cl_command_queue command_queue)
(: cl_program program)
(: cl_mem tBuffer)
(: float* tArray)
(: int dim global n)
(= dim (* length (sizeof float)))
(= global (/ length 2))
(= tArray ((float*) (malloc dim)))
(= context (clCreateContext))
(= command_queue (clCreateCommandQueue context))
(= tBuffer (clCreateBuffer context CL_MEM_READ_WRITE dim))
(= program (clCreateProgramWithSource context "walsh-synth-kernel.rkt"))
(clEnqueueWriteBuffer command_queue tBuffer 0 dim input)
(= n (steps length))
(runKernel command_queue (clCreateKernel program "fwtKernel") tBuffer global 0 (?: (< n 2) n 2))
(if (> n 2)
{ (/= global 4)
(runKernel command_queue (clCreateKernel program "fwtKernel4Sketch") tBuffer global 2 n) })
(clEnqueueReadBuffer command_queue tBuffer 0 dim tArray)
tArray)
(procedure void (runKernel [cl_command_queue command_queue] [cl_kernel kernel] [cl_mem tBuffer]
[int global] [int start] [int end])
(clSetKernelArg kernel 0 tBuffer)
(for [(: int i in (range start end))]
(: int step)
(= step (<< 1 i))
(clSetKernelArg kernel 1 step)
(clEnqueueNDRangeKernel command_queue kernel 1 NULL (@ global) NULL)))
; Given two arrays of the same size, checks that they hold the same
; values at each index.
(procedure void (check [int* actual] [int* expected] [int SIZE])
(assert (>= SIZE 0))
(for [(: int i in (range SIZE))]
(assert (== [actual i] [expected i]))))
(procedure void (synth_scalar) ; ~7 sec
(synth #:forall [(: int length in (range 8 9))
(: float[length] tArray)]
#:ensure (check (fwtScalarHost tArray length)
(fwt tArray length)
length)))
(procedure void (synth_vector) ; < 1 sec
(synth #:forall [(: int length in (range 8 9))
(: float[length] tArray)]
#:ensure (check (fwtVectorHost tArray length)
(fwt tArray length)
length)))
(check-type
(with-output-to-string (λ () (synth_scalar)))
: CString
-> "/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/walsh-synth-kernel.rkt:3:0\n'(kernel\n void\n (fwtKernelSketch (float* tArray) (int step))\n (: int tid group pair match)\n (: float t1 t2)\n (= tid (get_global_id 0))\n (=\n group\n (rosette3:ann\n (locally-scoped\n (: int left right)\n (= left (rosette3:ann tid : int))\n (= right (rosette3:ann step : int))\n (% left right))\n :\n int))\n (=\n pair\n (+\n (*\n (<< step 1)\n (rosette3:ann\n (locally-scoped\n (: int left right)\n (= left (rosette3:ann tid : int))\n (= right (rosette3:ann step : int))\n (/ left right))\n :\n int))\n group))\n (= match (+ pair step))\n (= t1 (tArray pair))\n (= t2 (tArray match))\n (= (tArray pair) (+ t1 t2))\n (= (tArray match) (- t1 t2)))\n")
(check-type
(with-output-to-string (λ () (synth_vector)))
: CString
-> "/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/walsh-synth-kernel.rkt:15:0\n'(kernel\n void\n (fwtKernel4Sketch (float4* tArray) (int step))\n (: int tid group pair match)\n (: float4 t1 t2)\n (= tid (get_global_id 0))\n (= step (/ step 4))\n (= group (% tid step))\n (= pair (+ (* (<< step 1) (/ tid step)) group))\n (= match (+ pair step))\n (= t1 (tArray pair))\n (= t2 (tArray match))\n (= (tArray pair) (+ t1 t2))\n (= (tArray match) (- t1 t2)))\n")

View File

@ -0,0 +1,66 @@
#lang s-exp "../../../rosette/synthcl3.rkt"
(kernel void (fwtKernelSketch [float* tArray] [int step])
(: int tid group pair match)
(: float t1 t2)
(= tid (get_global_id 0))
(= group (idx tid step 1)) ; (% tid step)
(= pair (+ (* (<< step 1) (idx tid step 1)) group)) ; (/ tid step)
(= match (+ pair step))
(= t1 [tArray pair])
(= t2 [tArray match])
(= [tArray pair] (+ t1 t2))
(= [tArray match] (- t1 t2)))
(kernel void (fwtKernel4Sketch [float4* tArray] [int step])
(: int tid group pair match)
(: float4 t1 t2)
(= tid (get_global_id 0))
(= step [choose step (/ step 4) (* step 4) (% step 4)]) ; (/ step 4)
(= group (% tid step))
(= pair (+ (* (<< step 1) (/ tid step)) group))
(= match (+ pair step))
(= t1 [tArray pair])
(= t2 [tArray match])
(= [tArray pair] (+ t1 t2))
(= [tArray match] (- t1 t2)))
(grammar int (idx [int tid] [int step] [int depth])
#:base (choose tid step (?? int))
#:else (locally-scoped
(: int left right)
(= left (idx tid step (- depth 1)))
(= right (idx tid step (- depth 1)))
[choose left
(+ left right)
(- left right)
(/ left right)
(* left right)
(% left right)]))
(kernel void (fwtKernel [float* tArray] [int step])
(: int tid group pair match)
(: float t1 t2)
(= tid (get_global_id 0))
(= group (% tid step))
(= pair (+ (* (<< step 1) (/ tid step)) group))
(= match (+ pair step))
(= t1 [tArray pair])
(= t2 [tArray match])
(= [tArray pair] (+ t1 t2))
(= [tArray match] (- t1 t2)))
(kernel void (fwtKernel4 [float4* tArray] [int step])
(: int tid group pair match)
(: float4 t1 t2)
(= tid (get_global_id 0))
(= step (/ step 4))
(= group (% tid step))
(= pair (+ (* (<< step 1) (/ tid step)) group))
(= match (+ pair step))
(= t1 [tArray pair])
(= t2 [tArray match])
(= [tArray pair] (+ t1 t2))
(= [tArray match] (- t1 t2)))