add base/else grammar form; walsh synth tests passing
- add ?? and @ - add comparison ops - use cl:/ instead of my own - fixed walsh synth scalar test - fix unbound ids err in for due to let*-like bindings
This commit is contained in:
parent
e0a2900c77
commit
21c77d7e61
|
@ -1,3 +1,34 @@
|
|||
2016-12-07 --------------------
|
||||
synthcl Walsh synth tests were not working and had trouble debugging, so
|
||||
documenting my process here
|
||||
|
||||
1) getting the error:
|
||||
; ?: literal data is not allowed;
|
||||
; no #%datum syntax transformer is bound
|
||||
; in: #f
|
||||
|
||||
in general, means a stx prop is expected but not present
|
||||
|
||||
In this case:
|
||||
- convert fn undefined for cl_mem and other similar base types
|
||||
|
||||
|
||||
2) getting unexpected unsat
|
||||
helpful debugging technique 1:
|
||||
- print exns swallowed by eval/assert (used by synthesize or verify)
|
||||
|
||||
In this case:
|
||||
- mk-float* was undefined
|
||||
- bool was not considered "real" but should be
|
||||
- cmps not defined
|
||||
|
||||
helpful debugging technique 2:
|
||||
- print asserts in ∃∀-solve, and compare to expected
|
||||
- may have to set error-print-width higher (default 256)
|
||||
|
||||
In this case, comparing asserts showed that quotient was missing from my typed
|
||||
synthcl. Using the / defined by synthcl (instead of mine) fixed the problem.
|
||||
|
||||
2016-11-18 --------------------
|
||||
|
||||
working on synthcl3 lang
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
#lang turnstile
|
||||
(extends "rosette3.rkt" #:except ! #%app || && void = * + - / #%datum if assert verify) ; typed rosette
|
||||
(extends "rosette3.rkt" #:except ! #%app || && void = * + - / #%datum if assert verify < <= > >=) ; typed rosette
|
||||
(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped
|
||||
racket/stxparam
|
||||
(prefix-in ro: (combine-in rosette rosette/lib/synthax))
|
||||
(prefix-in cl: (combine-in (except-in sdsl/synthcl/model/operators /)
|
||||
(prefix-in cl: (combine-in sdsl/synthcl/model/operators
|
||||
sdsl/synthcl/lang/forms sdsl/synthcl/model/reals
|
||||
sdsl/synthcl/model/errors sdsl/synthcl/model/kernel
|
||||
sdsl/synthcl/model/memory sdsl/synthcl/model/runtime
|
||||
|
@ -19,13 +19,14 @@
|
|||
|
||||
(provide (rename-out [synth-app #%app])
|
||||
procedure kernel grammar #%datum if range for print
|
||||
choose locally-scoped assert synth verify
|
||||
choose ?? @ locally-scoped assert synth verify
|
||||
int int2 int3 int4 int16 float float2 float3 float4 float16
|
||||
bool void void* char* float* int* int2* int3* int4* int16*
|
||||
bool void void* char*
|
||||
int* int2* int3* int4* int16* float* float2* float3* float4* float16*
|
||||
cl_context cl_command_queue cl_program cl_kernel cl_mem
|
||||
: ! ?: == + * / - || &&
|
||||
% << ; int ops
|
||||
= += -= *= /= %= ; assignment ops
|
||||
% << $ & > >= < <= ; int ops
|
||||
= += -= *= /= %= $= &= ; assignment ops
|
||||
sizeof clCreateProgramWithSource
|
||||
(typed-out
|
||||
[clCreateContext : (C→ cl_context)]
|
||||
|
@ -42,6 +43,7 @@
|
|||
[get_global_id : (C→ int int)]
|
||||
[CL_MEM_READ_ONLY : int]
|
||||
[CL_MEM_WRITE_ONLY : int]
|
||||
[CL_MEM_READ_WRITE : int]
|
||||
[malloc : (C→ int void*)]
|
||||
[get_work_dim : (C→ int)]
|
||||
[!= : (Ccase-> (C→ CNum CNum CBool)
|
||||
|
@ -57,7 +59,7 @@
|
|||
(typecheck? ((current-type-eval) t1)
|
||||
((current-type-eval) t2)))
|
||||
(define (real-type? t)
|
||||
(and (not (typecheck/un? t #'bool))
|
||||
(and #;(not (typecheck/un? t #'bool))
|
||||
(not (typecheck/un? t #'char*))
|
||||
(not (pointer-type? t))))
|
||||
(define (pointer-type? t)
|
||||
|
@ -113,16 +115,12 @@
|
|||
(define (mk-ptr id) (format-id id "~a*" id))
|
||||
(define (mk-mk id) (format-id id "mk-~a" id))
|
||||
(define (mk-to id) (format-id id "to-~a" id))
|
||||
(define (add-convert stx fn)
|
||||
(set-stx-prop/preserved stx 'convert fn))
|
||||
(define (add-construct stx fn) (set-stx-prop/preserved stx 'construct fn))
|
||||
(define (add-convert stx fn) (set-stx-prop/preserved stx 'convert fn))
|
||||
(define (get-construct stx) (syntax-property stx 'construct))
|
||||
(define (get-convert stx)
|
||||
(syntax-property stx 'convert))
|
||||
(define (add-construct stx fn)
|
||||
(set-stx-prop/preserved stx 'construct fn))
|
||||
(define (get-construct stx)
|
||||
(syntax-property stx 'construct))
|
||||
(define (ty->len ty)
|
||||
(regexp-match #px"([a-z]+)([0-9]+)" (type->str ty)))
|
||||
(let ([conv (syntax-property stx 'convert)]) (or conv #'(λ (x) x))))
|
||||
(define (ty->len ty) (regexp-match #px"([a-z]+)([0-9]+)" (type->str ty)))
|
||||
(define (real-type-length t)
|
||||
(define split-ty (ty->len t))
|
||||
(string->number
|
||||
|
@ -133,42 +131,32 @@
|
|||
(string->symbol (car (regexp-match #px"[a-z]+" (type->str ty)))))))
|
||||
(define (get-pointer-base ty [ctx #'here])
|
||||
(datum->syntax ctx (string->symbol (string-trim (type->str ty) "*"))))
|
||||
(define (vector-type? ty)
|
||||
(ty->len ty)) ; TODO: check and not pointer-type?
|
||||
(define (vector-type? ty) (ty->len ty)) ; TODO: check and not pointer-type?
|
||||
(define (scalar-type? ty)
|
||||
(or (typecheck/un? ty #'bool)
|
||||
(and (real-type? ty) (not (vector-type? ty))))))
|
||||
|
||||
(define-syntax-parser add-convertm
|
||||
[(_ stx fn) (add-convert #'stx #'fn)])
|
||||
(define-syntax-parser add-constructm
|
||||
[(_ stx fn) (add-construct #'stx #'fn)])
|
||||
(define-syntax-parser add-convertm [(_ stx fn) (add-convert #'stx #'fn)])
|
||||
(define-syntax-parser add-constructm [(_ stx fn) (add-construct #'stx #'fn)])
|
||||
|
||||
;; TODO: reuse impls in model/reals.rkt ?
|
||||
|
||||
(ro:define (to-bool v)
|
||||
(ro:cond
|
||||
[(ro:boolean? v) v]
|
||||
[(ro:number? v) (ro:! (ro:= 0 v))]
|
||||
[else (cl:raise-conversion-error v "bool")]))
|
||||
(ro:define (to-bool v) ; TODO: reuse impls in model/reals.rkt ?
|
||||
(ro:cond [(ro:boolean? v) v]
|
||||
[(ro:number? v) (ro:! (ro:= 0 v))]
|
||||
[else (cl:raise-conversion-error v "bool")]))
|
||||
(ro:define (to-int v)
|
||||
(ro:cond [(ro:boolean? v) (ro:if v 1 0)]
|
||||
[(ro:fixnum? v) v]
|
||||
[(ro:flonum? v) (ro:exact-truncate v)]
|
||||
[else (ro:real->integer v)]))
|
||||
(ro:define (to-float v)
|
||||
(ro:cond
|
||||
[(ro:boolean? v) (ro:if v 1.0 0.0)]
|
||||
[(ro:fixnum? v) (ro:exact->inexact v)]
|
||||
[(ro:flonum? v) v]
|
||||
[else (ro:type-cast ro:real? v)]))
|
||||
(ro:define (mk-int v)
|
||||
(ro:#%app cl:int v))
|
||||
|
||||
(ro:define (to-int* v)
|
||||
(cl:pointer-cast v cl:int))
|
||||
(ro:define (to-float* v)
|
||||
(cl:pointer-cast v cl:float))
|
||||
(ro:cond [(ro:boolean? v) (ro:if v 1.0 0.0)]
|
||||
[(ro:fixnum? v) (ro:exact->inexact v)]
|
||||
[(ro:flonum? v) v]
|
||||
[else (ro:type-cast ro:real? v)]))
|
||||
(ro:define (mk-int v) (ro:#%app cl:int v))
|
||||
(ro:define (mk-float v) (ro:#%app cl:float v))
|
||||
(ro:define (to-int* v) (cl:pointer-cast v cl:int))
|
||||
(ro:define (to-float* v) (cl:pointer-cast v cl:float))
|
||||
|
||||
(define-type-constructor Pointer #:arity = 1)
|
||||
;(define-named-type-alias void rosette3:CUnit)
|
||||
|
@ -213,8 +201,11 @@
|
|||
(syntax-parse stx
|
||||
[(_ n)
|
||||
#:with floatn (format-id #'n "float~a" (syntax->datum #'n))
|
||||
#:with floatn* (mk-ptr #'floatn)
|
||||
#:with to-floatn (format-id #'n "to-~a" #'floatn)
|
||||
#:with mk-floatn (mk-mk #'floatn)
|
||||
#:with to-floatn* (mk-ptr #'to-floatn)
|
||||
#:with mk-floatn* (mk-ptr #'mk-floatn)
|
||||
#:with cl-mk-floatn (mk-cl #'floatn)
|
||||
#:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values))
|
||||
#:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...))
|
||||
|
@ -222,6 +213,7 @@
|
|||
(define-named-type-alias floatn
|
||||
(add-constructm
|
||||
(add-convertm (rosette3:CVector I ...) to-floatn) mk-floatn))
|
||||
(define-named-type-alias floatn* (add-convertm (Pointer floatn) to-floatn*))
|
||||
(ro:define (to-floatn v)
|
||||
(ro:cond
|
||||
[(ro:list? v)
|
||||
|
@ -229,6 +221,7 @@
|
|||
[(ro:vector? v)
|
||||
(ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:vector-ref v i))))]
|
||||
[else (ro:apply mk-floatn (ro:make-list n (to-float v)))]))
|
||||
(ro:define (to-floatn* v) (cl:pointer-cast v cl-mk-floatn))
|
||||
(ro:define (mk-floatn x ...) (ro:#%app cl-mk-floatn x ...)))]))
|
||||
(define-simple-macro (define-floats n ...) (begin (define-float n) ...))
|
||||
(define-floats 2 3 4 16)
|
||||
|
@ -302,7 +295,7 @@
|
|||
|
||||
;; top-level fns --------------------------------------------------
|
||||
(define-typed-syntax procedure
|
||||
[(~and (_ ty-out:type (f [ty:type x:id] ...)) ~!) ≫
|
||||
[(~and (_ ty-out:type (f [ty:type x:id] ...)) ~!) ≫ ; empty body
|
||||
#:fail-unless (void? #'ty-out.norm)
|
||||
(format "expected void, given ~a" (type->str #'ty-out.norm))
|
||||
--------
|
||||
|
@ -312,8 +305,7 @@
|
|||
#:with f- (add-orig (generate-temporary #'f) #'f)
|
||||
--------
|
||||
[≻ (begin-
|
||||
(define-syntax- f
|
||||
(make-rename-transformer (⊢ f- : (C→ ty ... ty-out))))
|
||||
(define-syntax- f (make-rename-transformer (⊢ f- : (C→ ty ... ty-out))))
|
||||
(define- f-
|
||||
(lambda- (x ...)
|
||||
(rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...)
|
||||
|
@ -325,6 +317,22 @@
|
|||
(format "expected void, given ~a" (type->str #'ty-out.norm))
|
||||
--- [≻ (procedure void (f [ty x] ...) e ...)]])
|
||||
(define-typed-syntax grammar
|
||||
[(_ ty-out:type (f [ty:type x:id] ... [ty-depth k]) #:base be #:else ee) ≫
|
||||
#:with f- (generate-temporary #'f)
|
||||
#:with (a ...) (generate-temporaries #'(x ...))
|
||||
--------
|
||||
[≻ (ro:begin
|
||||
(ro:define-synthax (f- x ... k) #:base (rosette3:ann be : ty-out)
|
||||
#:else (rosette3:ann ee : ty-out))
|
||||
(define-typed-syntax f
|
||||
[(ff a ... j) ≫
|
||||
[⊢ a ≫ _ ⇐ ty] ...
|
||||
[⊢ j ≫ _ ⇐ ty-depth]
|
||||
;; j will be eval'ed, so strip its context
|
||||
#:with j- (assign-type (datum->syntax #'H (stx->datum #'j)) #'int)
|
||||
#:with f-- (replace-stx-loc #'f- #'ff)
|
||||
-----------
|
||||
[⊢ (f-- a ... j-) ⇒ ty-out]]))]]
|
||||
[(_ ty-out:type (f [ty:type x:id] ...) e) ≫
|
||||
#:with f- (generate-temporary #'f)
|
||||
--------
|
||||
|
@ -342,40 +350,36 @@
|
|||
|
||||
;; for and if statement --------------------------------------------------
|
||||
(define-typed-syntax if
|
||||
[(_ test {then ...} {else ...}) ≫
|
||||
[(_ e-test {e1 ...} {e2 ...}) ≫
|
||||
--------
|
||||
[⊢ (ro:if (to-bool test)
|
||||
(ro:let () then ... (ro:void))
|
||||
(ro:let () else ... (ro:void))) ⇒ void]]
|
||||
[(_ test {then ...}) ≫
|
||||
--- [≻ (if test {then ...} {})]])
|
||||
[⊢ (ro:if (to-bool e-test)
|
||||
(ro:let () e1 ... (ro:void))
|
||||
(ro:let () e2 ... (ro:void))) ⇒ void]]
|
||||
[(_ e-test es) ≫ --- [≻ (if e-test es {})]])
|
||||
|
||||
(define-typed-syntax (range e ...) ≫
|
||||
[⊢ e ≫ e- ⇐ int] ...
|
||||
--- [⊢ (ro:#%app ro:in-range e- ...) ⇒ int])
|
||||
(define-typed-syntax for
|
||||
[(_ [((~literal :) ty:type x:id (~datum in) rangeExpr) ...] e ...) ≫
|
||||
#:with (x- ...) (generate-temporaries #'(x ...))
|
||||
#:with (typed-seq ...) #'((with-ctx ([x x- ty] ...) rangeExpr) ...)
|
||||
--------
|
||||
[⊢ (ro:for* ([x rangeExpr] ...)
|
||||
(rosette3:let ([x (⊢m x ty)] ...)
|
||||
(⊢m (ro:let () e ... (ro:void)) void))) ⇒ void]])
|
||||
[⊢ (ro:let ([x- 1] ...) ; dummy ensuring id- bound, simplifies stx template
|
||||
(ro:for* ([x- typed-seq] ...)
|
||||
(with-ctx ([x x- ty] ...)
|
||||
(⊢m (ro:let () e ... (ro:void)) void)))) ⇒ void]])
|
||||
|
||||
|
||||
;; need to redefine #%datum because rosette3:#%datum is too precise
|
||||
(define-typed-syntax #%datum
|
||||
[(_ . b:boolean) ≫
|
||||
--------
|
||||
[⊢ (ro:#%datum . b) ⇒ bool]]
|
||||
[(_ . n:integer) ≫
|
||||
--------
|
||||
[⊢ (ro:#%datum . n) ⇒ int]]
|
||||
[(#%datum . n:number) ≫
|
||||
[(_ . b:boolean) ≫ --- [⊢ (ro:#%datum . b) ⇒ bool]]
|
||||
[(_ . s:str) ≫ --- [⊢ (ro:#%datum . s) ⇒ char*]]
|
||||
[(_ . n:integer) ≫ --- [⊢ (ro:#%datum . n) ⇒ int]]
|
||||
[(#%datum . n:number) ≫
|
||||
#:when (real? (syntax-e #'n))
|
||||
--------
|
||||
[⊢ (ro:#%datum . n) ⇒ float]]
|
||||
[(_ . s:str) ≫
|
||||
--------
|
||||
[⊢ (ro:#%datum . s) ⇒ char*]]
|
||||
[(_ . x) ≫
|
||||
--------
|
||||
[_ #:error (type-error #:src #'x #:msg "Unsupported literal: ~v" #'x)]])
|
||||
|
@ -460,12 +464,6 @@
|
|||
|
||||
;; ?: --------------------------------------------------
|
||||
(define-typed-syntax ?:
|
||||
[(_ e e1 e2) ≫
|
||||
[⊢ e ≫ e- ⇐ bool]
|
||||
[⊢ e1 ≫ e1- ⇒ ty1]
|
||||
[⊢ e2 ≫ e2- ⇒ ty2]
|
||||
-------
|
||||
[⊢ (cl:?: e- e1- e2-) ⇒ (⊔ τ1 τ2)]]
|
||||
[(_ e e1 e2) ≫
|
||||
[⊢ e ≫ e- ⇒ ty] ; vector type
|
||||
#:do [(define split-ty (ty->len #'ty))]
|
||||
|
@ -502,8 +500,7 @@
|
|||
(synth-app (ty-out) e1-)
|
||||
(synth-app (ty-out) e2-)) ⇒ ty-out]])
|
||||
|
||||
;; = --------------------------------------------------
|
||||
;; assignment
|
||||
;; = (assignment) --------------------------------------------------
|
||||
(define-typed-syntax =
|
||||
[(_ x:id e) ≫
|
||||
[⊢ x ≫ x- ⇒ ty-x]
|
||||
|
@ -540,16 +537,21 @@
|
|||
--------
|
||||
[⊢ (ro:#%app cl:! (to-bool e-)) ⇒ bool]])
|
||||
|
||||
;; TODO: this should produce int-vector result?
|
||||
(define-typed-syntax ==
|
||||
[(_ e1 e2) ≫
|
||||
[⊢ e1 ≫ e1- ⇒ ty1]
|
||||
[⊢ e2 ≫ e2- ⇒ ty2]
|
||||
#:when (real-type? #'ty1)
|
||||
#:when (real-type? #'ty2)
|
||||
#:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len
|
||||
--------
|
||||
[⊢ (to-int (cl:== e1- e2-)) ⇒ int]])
|
||||
;; TODO: comparison ops need to support vec types (and result)
|
||||
(define-simple-macro (mk-cmp cmp-op)
|
||||
(define-typed-syntax cmp-op
|
||||
[(o e1 e2) ≫
|
||||
[⊢ e1 ≫ e1- ⇒ ty1]
|
||||
[⊢ e2 ≫ e2- ⇒ ty2]
|
||||
#:when (real-type? #'ty1)
|
||||
#:when (real-type? #'ty2)
|
||||
#:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len
|
||||
#:with conv (get-convert #'ty-out)
|
||||
#:with o- (mk-cl #'o)
|
||||
--------
|
||||
[⊢ (to-int (o- (conv e1-) (conv e2-))) ⇒ int]]))
|
||||
(define-simple-macro (mk-cmps o ...) (begin- (mk-cmp o) ...))
|
||||
(mk-cmps == < <= > >=)
|
||||
|
||||
(define-simple-macro (define-bool-ops o ...+) (ro:begin (define-bool-op o) ...))
|
||||
(define-simple-macro (define-bool-op name)
|
||||
|
@ -563,7 +565,7 @@
|
|||
[(_ e1 e2) ≫ ; else try to coerce
|
||||
--- [⊢ (name- (to-bool e1) (to-bool e2)) ⇒ bool]]))
|
||||
|
||||
(define- (cl:/ x y)
|
||||
#;(define- (cl:/ x y)
|
||||
(cond- [(zero?- y) 0]
|
||||
[(integer?- x) (quotient- x y)]
|
||||
[else (/- x y)]))
|
||||
|
@ -604,7 +606,7 @@
|
|||
|
||||
(define-bool-ops || &&)
|
||||
(define-real-ops + * - /)
|
||||
(define-int-ops % <<)
|
||||
(define-int-ops % << $ &)
|
||||
|
||||
(define-typerule (sizeof t:type) >> ---[⊢ #,(real-type-length #'t.norm) ⇒ int])
|
||||
(define-typerule (print e ...) >> ---[⊢ (ro:begin (display e) ...) ⇒ void])
|
||||
|
@ -621,9 +623,30 @@
|
|||
--------
|
||||
[⊢ (ch/disarmed e/disarmed ...) ⇒ #,(stx-car #'(ty ...))]])
|
||||
|
||||
(define-typed-syntax (locally-scoped e ...) ≫
|
||||
--------
|
||||
[≻ (rosette3:let () e ...)])
|
||||
(define-typed-syntax ??
|
||||
[(qq) ≫
|
||||
#:with ??/progsrc (replace-stx-loc #'cl:?? #'qq)
|
||||
--------
|
||||
[⊢ (??/progsrc) ⇒ int]]
|
||||
[(qq ty:type) ≫
|
||||
#:with ??/progsrc (replace-stx-loc #'cl:?? #'qq)
|
||||
#:with t (datum->syntax #'here (string->symbol (type->str #'ty.norm)))
|
||||
#:with cl-t (mk-cl #'t)
|
||||
;; #:with ty-base (get-base #'ty.norm)
|
||||
;; #:with pred (get-pred ((current-type-eval) #'ty-base))
|
||||
--------
|
||||
[⊢ (??/progsrc cl-t) ⇒ ty]])
|
||||
|
||||
(define-typed-syntax (@ x:id) ≫
|
||||
[⊢ x ≫ x- ⇒ ty+] ;; TODO: check ty = real, non-ptr type
|
||||
#:with ty (datum->syntax #'x (string->symbol (type->str #'ty+)))
|
||||
#:with cl-ty (mk-cl #'ty)
|
||||
---------
|
||||
[⊢ (cl:address-of x- cl-ty) ⇒ #,(mk-ptr #'ty)])
|
||||
|
||||
(define-typed-syntax locally-scoped
|
||||
[(_ e ...) ⇐ ty ≫ --- [⊢ (ro:let () e ...)]]
|
||||
[(_ e ...) ≫ --- [≻ (⊢m (ro:let () e ...) void)]])
|
||||
|
||||
(define-for-syntax (decl->seq stx)
|
||||
(syntax-parse stx
|
||||
|
@ -646,6 +669,7 @@
|
|||
(ro:define-values (tmp ...)
|
||||
(ro:for*/lists (tmp ...) ([id- typed-seq] ...) (ro:values id- ...)))
|
||||
(ro:parameterize ([ro:current-bitwidth bw] ; matrix mult unsat w/o this
|
||||
[ro:current-oracle (ro:oracle (ro:current-oracle))]
|
||||
[ro:term-cache (ro:hash-copy (ro:term-cache))])
|
||||
(ro:print-forms
|
||||
(ro:synthesize
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
(do-tests "synthcl3-tests.rkt" "SynthCL"
|
||||
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth"
|
||||
"synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify"
|
||||
"synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify")
|
||||
"synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify"
|
||||
"synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth")
|
||||
(do-tests "bv-ref-tests.rkt" "BV SDSL - Hacker's Delight synthesis")
|
||||
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
#lang racket/base
|
||||
|
||||
(require macrotypes/examples/tests/do-tests)
|
||||
|
||||
(do-tests
|
||||
"synthcl3-tests.rkt" "SynthCL"
|
||||
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth"
|
||||
"synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify"
|
||||
"synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify"
|
||||
"synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth")
|
||||
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
#lang s-exp "../../../rosette/synthcl3.rkt"
|
||||
(require "../../rackunit-typechecking.rkt")
|
||||
; Compute the number of steps for the algorithm,
|
||||
; assuming that v is a power of 2. See the log2
|
||||
; algorithm from http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog
|
||||
(procedure int (steps [int v])
|
||||
(: int r)
|
||||
(= r 0)
|
||||
($= r (<< (!= 0 (& v #xAAAAAAAA)) 0))
|
||||
($= r (<< (!= 0 (& v #xCCCCCCCC)) 1))
|
||||
($= r (<< (!= 0 (& v #xF0F0F0F0)) 2))
|
||||
($= r (<< (!= 0 (& v #xFF00FF00)) 3))
|
||||
($= r (<< (!= 0 (& v #xFFFF0000)) 4))
|
||||
r)
|
||||
|
||||
; Reference implementation for Fast Walsh Transform. This implementation
|
||||
; requires the length of the input array to be a power of 2, and it modifies
|
||||
; the input array in place.
|
||||
(procedure float* (fwt [float* tArray] [int length])
|
||||
(for [(: int i in (range 0 (steps length)))]
|
||||
(: int step)
|
||||
(= step (<< 1 i))
|
||||
(for [(: int group in (range 0 step))
|
||||
(: int pair in (range group length (<< step 1)))]
|
||||
(: int match)
|
||||
(: float t1 t2)
|
||||
(= match (+ pair step))
|
||||
(= t1 [tArray pair])
|
||||
(= t2 [tArray match])
|
||||
(= [tArray pair] (+ t1 t2))
|
||||
(= [tArray match] (- t1 t2))))
|
||||
tArray)
|
||||
|
||||
; Scalar host for Fast Walsh Transform. This implementation
|
||||
; requires the length of the input array to be a power of 2. The
|
||||
; input array is not modified; the output is a new array that holds
|
||||
; the result of the transform.
|
||||
(procedure float* (fwtScalarHost [float* input] [int length])
|
||||
(: cl_context context)
|
||||
(: cl_command_queue command_queue)
|
||||
(: cl_program program)
|
||||
(: cl_kernel kernel)
|
||||
(: cl_mem tBuffer)
|
||||
(: float* tArray)
|
||||
(: int dim global)
|
||||
|
||||
(= dim (* length (sizeof float)))
|
||||
(= global (/ length 2))
|
||||
|
||||
(= tArray ((float*) (malloc dim)))
|
||||
|
||||
(= context (clCreateContext))
|
||||
|
||||
(= command_queue (clCreateCommandQueue context))
|
||||
|
||||
(= tBuffer (clCreateBuffer context CL_MEM_READ_WRITE dim))
|
||||
(= program (clCreateProgramWithSource context "walsh-synth-kernel.rkt"))
|
||||
|
||||
(clEnqueueWriteBuffer command_queue tBuffer 0 dim input)
|
||||
|
||||
(= kernel (clCreateKernel program "fwtKernelSketch"))
|
||||
(clSetKernelArg kernel 0 tBuffer)
|
||||
|
||||
(for [(: int i in (range 0 (steps length)))]
|
||||
(: int step)
|
||||
(= step (<< 1 i))
|
||||
(clSetKernelArg kernel 1 step)
|
||||
(clEnqueueNDRangeKernel command_queue kernel 1 NULL (@ global) NULL))
|
||||
|
||||
(clEnqueueReadBuffer command_queue tBuffer 0 dim tArray)
|
||||
tArray)
|
||||
|
||||
; Vectorized host for Fast Walsh Transform. This implementation
|
||||
; requires the length of the input array to be a power of 2. The
|
||||
; input array is not modified; the output is a new array that holds
|
||||
; the result of the transform.
|
||||
(procedure float* (fwtVectorHost [float* input] [int length])
|
||||
(: cl_context context)
|
||||
(: cl_command_queue command_queue)
|
||||
(: cl_program program)
|
||||
(: cl_mem tBuffer)
|
||||
(: float* tArray)
|
||||
(: int dim global n)
|
||||
|
||||
(= dim (* length (sizeof float)))
|
||||
(= global (/ length 2))
|
||||
|
||||
(= tArray ((float*) (malloc dim)))
|
||||
|
||||
(= context (clCreateContext))
|
||||
|
||||
(= command_queue (clCreateCommandQueue context))
|
||||
|
||||
(= tBuffer (clCreateBuffer context CL_MEM_READ_WRITE dim))
|
||||
(= program (clCreateProgramWithSource context "walsh-synth-kernel.rkt"))
|
||||
|
||||
(clEnqueueWriteBuffer command_queue tBuffer 0 dim input)
|
||||
|
||||
(= n (steps length))
|
||||
|
||||
(runKernel command_queue (clCreateKernel program "fwtKernel") tBuffer global 0 (?: (< n 2) n 2))
|
||||
(if (> n 2)
|
||||
{ (/= global 4)
|
||||
(runKernel command_queue (clCreateKernel program "fwtKernel4Sketch") tBuffer global 2 n) })
|
||||
|
||||
(clEnqueueReadBuffer command_queue tBuffer 0 dim tArray)
|
||||
tArray)
|
||||
|
||||
(procedure void (runKernel [cl_command_queue command_queue] [cl_kernel kernel] [cl_mem tBuffer]
|
||||
[int global] [int start] [int end])
|
||||
(clSetKernelArg kernel 0 tBuffer)
|
||||
(for [(: int i in (range start end))]
|
||||
(: int step)
|
||||
(= step (<< 1 i))
|
||||
(clSetKernelArg kernel 1 step)
|
||||
(clEnqueueNDRangeKernel command_queue kernel 1 NULL (@ global) NULL)))
|
||||
|
||||
; Given two arrays of the same size, checks that they hold the same
|
||||
; values at each index.
|
||||
(procedure void (check [int* actual] [int* expected] [int SIZE])
|
||||
(assert (>= SIZE 0))
|
||||
(for [(: int i in (range SIZE))]
|
||||
(assert (== [actual i] [expected i]))))
|
||||
|
||||
(procedure void (synth_scalar) ; ~7 sec
|
||||
(synth #:forall [(: int length in (range 8 9))
|
||||
(: float[length] tArray)]
|
||||
#:ensure (check (fwtScalarHost tArray length)
|
||||
(fwt tArray length)
|
||||
length)))
|
||||
|
||||
(procedure void (synth_vector) ; < 1 sec
|
||||
(synth #:forall [(: int length in (range 8 9))
|
||||
(: float[length] tArray)]
|
||||
#:ensure (check (fwtVectorHost tArray length)
|
||||
(fwt tArray length)
|
||||
length)))
|
||||
|
||||
(check-type
|
||||
(with-output-to-string (λ () (synth_scalar)))
|
||||
: CString
|
||||
-> "/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/walsh-synth-kernel.rkt:3:0\n'(kernel\n void\n (fwtKernelSketch (float* tArray) (int step))\n (: int tid group pair match)\n (: float t1 t2)\n (= tid (get_global_id 0))\n (=\n group\n (rosette3:ann\n (locally-scoped\n (: int left right)\n (= left (rosette3:ann tid : int))\n (= right (rosette3:ann step : int))\n (% left right))\n :\n int))\n (=\n pair\n (+\n (*\n (<< step 1)\n (rosette3:ann\n (locally-scoped\n (: int left right)\n (= left (rosette3:ann tid : int))\n (= right (rosette3:ann step : int))\n (/ left right))\n :\n int))\n group))\n (= match (+ pair step))\n (= t1 (tArray pair))\n (= t2 (tArray match))\n (= (tArray pair) (+ t1 t2))\n (= (tArray match) (- t1 t2)))\n")
|
||||
(check-type
|
||||
(with-output-to-string (λ () (synth_vector)))
|
||||
: CString
|
||||
-> "/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/walsh-synth-kernel.rkt:15:0\n'(kernel\n void\n (fwtKernel4Sketch (float4* tArray) (int step))\n (: int tid group pair match)\n (: float4 t1 t2)\n (= tid (get_global_id 0))\n (= step (/ step 4))\n (= group (% tid step))\n (= pair (+ (* (<< step 1) (/ tid step)) group))\n (= match (+ pair step))\n (= t1 (tArray pair))\n (= t2 (tArray match))\n (= (tArray pair) (+ t1 t2))\n (= (tArray match) (- t1 t2)))\n")
|
|
@ -0,0 +1,66 @@
|
|||
#lang s-exp "../../../rosette/synthcl3.rkt"
|
||||
|
||||
(kernel void (fwtKernelSketch [float* tArray] [int step])
|
||||
(: int tid group pair match)
|
||||
(: float t1 t2)
|
||||
(= tid (get_global_id 0))
|
||||
(= group (idx tid step 1)) ; (% tid step)
|
||||
(= pair (+ (* (<< step 1) (idx tid step 1)) group)) ; (/ tid step)
|
||||
(= match (+ pair step))
|
||||
(= t1 [tArray pair])
|
||||
(= t2 [tArray match])
|
||||
(= [tArray pair] (+ t1 t2))
|
||||
(= [tArray match] (- t1 t2)))
|
||||
|
||||
(kernel void (fwtKernel4Sketch [float4* tArray] [int step])
|
||||
(: int tid group pair match)
|
||||
(: float4 t1 t2)
|
||||
(= tid (get_global_id 0))
|
||||
(= step [choose step (/ step 4) (* step 4) (% step 4)]) ; (/ step 4)
|
||||
(= group (% tid step))
|
||||
(= pair (+ (* (<< step 1) (/ tid step)) group))
|
||||
(= match (+ pair step))
|
||||
(= t1 [tArray pair])
|
||||
(= t2 [tArray match])
|
||||
(= [tArray pair] (+ t1 t2))
|
||||
(= [tArray match] (- t1 t2)))
|
||||
|
||||
|
||||
(grammar int (idx [int tid] [int step] [int depth])
|
||||
#:base (choose tid step (?? int))
|
||||
#:else (locally-scoped
|
||||
(: int left right)
|
||||
(= left (idx tid step (- depth 1)))
|
||||
(= right (idx tid step (- depth 1)))
|
||||
[choose left
|
||||
(+ left right)
|
||||
(- left right)
|
||||
(/ left right)
|
||||
(* left right)
|
||||
(% left right)]))
|
||||
|
||||
(kernel void (fwtKernel [float* tArray] [int step])
|
||||
(: int tid group pair match)
|
||||
(: float t1 t2)
|
||||
(= tid (get_global_id 0))
|
||||
(= group (% tid step))
|
||||
(= pair (+ (* (<< step 1) (/ tid step)) group))
|
||||
(= match (+ pair step))
|
||||
(= t1 [tArray pair])
|
||||
(= t2 [tArray match])
|
||||
(= [tArray pair] (+ t1 t2))
|
||||
(= [tArray match] (- t1 t2)))
|
||||
|
||||
(kernel void (fwtKernel4 [float4* tArray] [int step])
|
||||
(: int tid group pair match)
|
||||
(: float4 t1 t2)
|
||||
(= tid (get_global_id 0))
|
||||
(= step (/ step 4))
|
||||
(= group (% tid step))
|
||||
(= pair (+ (* (<< step 1) (/ tid step)) group))
|
||||
(= match (+ pair step))
|
||||
(= t1 [tArray pair])
|
||||
(= t2 [tArray match])
|
||||
(= [tArray pair] (+ t1 t2))
|
||||
(= [tArray match] (- t1 t2)))
|
||||
|
Loading…
Reference in New Issue
Block a user