fix some bugs; matrix synth tests working

- pointer selection should have type of ptr base
- fix more invalid vectors, created by to-X, and :
- add cl fns
This commit is contained in:
Stephen Chang 2016-12-02 12:25:51 -05:00
parent f1f43697f9
commit b5e8a7bceb
5 changed files with 385 additions and 248 deletions

View File

@ -56,7 +56,7 @@
(local-expand
e 'expression
(list #'ro:#%app #'ro:choose #'ro:synthesize #'ro:let #'ro:in-value
#'ro:assert #'ro:if)))
#'ro:assert #'ro:if #'ro:?)))
; (displayln (stx->datum e+))
e+)
(define (mk-ro:-id id) (format-id id "ro:~a" id))

View File

@ -1,5 +1,5 @@
#lang turnstile
(extends "rosette3.rkt" #:except ! #%app || && void = * + - #%datum if assert verify) ; typed rosette
(extends "rosette3.rkt" #:except ! #%app || && void = * + - / #%datum if assert verify) ; typed rosette
(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped
racket/stxparam
(prefix-in ro: (combine-in rosette rosette/lib/synthax))
@ -8,7 +8,10 @@
sdsl/synthcl/model/operators sdsl/synthcl/model/errors
sdsl/synthcl/model/memory sdsl/synthcl/model/runtime
sdsl/synthcl/model/work sdsl/synthcl/model/pointers
sdsl/synthcl/lang/queries))
sdsl/synthcl/lang/queries sdsl/synthcl/model/context
sdsl/synthcl/model/queue sdsl/synthcl/model/buffer
sdsl/synthcl/model/flags sdsl/synthcl/model/program
sdsl/synthcl/model/kernel))
(for-syntax (prefix-in cl: sdsl/synthcl/lang/util)))
(begin-for-syntax
@ -19,16 +22,31 @@
procedure kernel grammar #%datum if range for print
choose locally-scoped assert synth verify
int int2 int3 int4 int16 float float2 float3 float4 float16
bool void void* char* float* int* int16* int2*
bool void void* char* float* int* int2* int3* int4* int16*
cl_context cl_command_queue cl_program cl_kernel cl_mem
: ! ?: == + * - || &&
% << ; int ops
= += -= %= ; assignment ops
sizeof
sizeof clCreateProgramWithSource
(typed-out
;[with-output-to-string : (C→ (C→ Any) char*)]
[clCreateContext : (C→ cl_context)]
[clCreateCommandQueue : (C→ cl_context cl_command_queue)]
[clCreateBuffer : (C→ cl_context int int cl_mem)]
[clEnqueueReadBuffer : (C→ cl_command_queue cl_mem int int void* void)]
[clEnqueueWriteBuffer : (C→ cl_command_queue cl_mem int int void* void)]
[clEnqueueNDRangeKernel : (C→ cl_command_queue cl_kernel int int* int* int* void)]
[clCreateKernel : (C→ cl_program char* cl_kernel)]
[clSetKernelArg : (Ccase-> (C→ cl_kernel int cl_mem void)
(C→ cl_kernel int int void)
(C→ cl_kernel int float void))]
[get_global_id : (C→ int int)]
[CL_MEM_READ_ONLY : int]
[CL_MEM_WRITE_ONLY : int]
[/ : (Ccase-> (C→ int int int) (C→ float float float))]
[malloc : (C→ int void*)]
[get_work_dim : (C→ int)]
[!= : (Ccase-> (C→ CNum CNum CBool)
(C→ CNum CNum CNum CBool)
(C→ Num Num Bool)
@ -63,9 +81,8 @@
(define common-real-type
(case-lambda
[(t) (and (real-type? t) t)]
[(t1 t2) (cond [(real-type<=? t1 t2) t2]
[(real-type<=? t2 t1) t1]
[else #f])]
[(t1 t2) (or (and (real-type<=? t1 t2) t2)
(and (real-type<=? t2 t1) t1))]
[ts (common-real-type (car ts) (apply common-real-type (cdr ts)))]))
;; implements common-real-type from model/reals.rkt
@ -96,6 +113,8 @@
(type->str from) (type->str to)
#;(if (contract? to) (contract-name to) to))
expr subexpr)))
(define (mk-ptr id) (format-id id "~a*" id))
(define (mk-mk id) (format-id id "mk-~a" id))
(define (add-convert stx fn)
(set-stx-prop/preserved stx 'convert fn))
(define (get-convert stx)
@ -114,6 +133,8 @@
((current-type-eval)
(datum->syntax ctx
(string->symbol (car (regexp-match #px"[a-z]+" (type->str ty)))))))
(define (get-pointer-base ty [ctx #'here])
(datum->syntax ctx (string->symbol (string-trim (type->str ty) "*"))))
(define (vector-type? ty)
(ty->len ty)) ; TODO: check and not pointer-type?
(define (scalar-type? ty)
@ -146,113 +167,74 @@
(ro:define (mk-int v)
(ro:#%app cl:int v))
(ro:define (to-int16* v)
(cl:pointer-cast v cl:int16))
(ro:define (to-int2* v)
(cl:pointer-cast v cl:int2))
(ro:define (to-int* v)
(cl:pointer-cast v cl:int))
(ro:define (to-float* v)
(cl:pointer-cast v cl:float))
(define-named-type-alias bool
(add-convertm rosette3:Bool to-bool))
(define-named-type-alias int
(add-convertm rosette3:Int to-int))
(define-named-type-alias float
(add-convertm rosette3:Num to-float))
(define-named-type-alias char*
(add-convertm rosette3:CString (λ (x) x)))
(define-type-constructor Pointer #:arity = 1)
;(define-named-type-alias void rosette3:CUnit)
(define-base-types void cl_context cl_command_queue cl_program cl_kernel cl_mem)
(define-named-type-alias void* (add-convertm (Pointer void) (λ (x) x)))
(define-named-type-alias bool (add-convertm rosette3:Bool to-bool))
(define-named-type-alias int (add-convertm rosette3:Int to-int))
(define-named-type-alias int* (add-convertm (Pointer int) to-int*))
(define-named-type-alias float (add-convertm rosette3:Num to-float))
(define-named-type-alias float* (add-convertm (Pointer float) to-float*))
(define-named-type-alias char* (add-convertm rosette3:CString (λ (x) x)))
(define-syntax (define-int stx)
(syntax-parse stx
[(_ n)
#:with intn (format-id #'n "int~a" (syntax->datum #'n))
#:with to-intn (format-id #'n "to-~a" #'intn)
#:with mk-intn (format-id #'n "mk-~a" #'intn)
#:with cl-mk-intn (mk-cl #'intn)
#:with (x ...) (generate-temporaries
(build-list (syntax->datum #'n) (lambda (x) x)))
#:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...))
#'(begin
(define-named-type-alias intn
(add-constructm
(add-convertm
(rosette3:CVector I ...)
to-intn)
mk-intn))
(ro:define (to-intn v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i n]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i n]) (to-int (ro:vector-ref v i))))]
[else
(ro:apply ro:vector-immutable
(ro:make-list n (to-int v)))]))
(ro:define (mk-intn x ...)
(ro:#%app cl-mk-intn x ...)
#;(ro:#%app ro:vector-immutable (to-int x) ...))
)]))
(syntax-parse stx
[(_ n)
#:with intn (format-id #'n "int~a" (syntax->datum #'n))
#:with intn* (mk-ptr #'intn)
#:with to-intn (format-id #'n "to-~a" #'intn)
#:with mk-intn (mk-mk #'intn)
#:with to-intn* (mk-ptr #'to-intn)
#:with mk-intn* (mk-ptr #'mk-intn)
#:with cl-mk-intn (mk-cl #'intn)
#:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values))
#:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...))
#'(begin
(define-named-type-alias intn
(add-constructm (add-convertm (rosette3:CVector I ...) to-intn) mk-intn))
(define-named-type-alias intn* (add-convertm (Pointer intn) to-intn*))
(ro:define (to-intn v)
(ro:cond
[(ro:list? v)
(ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:vector-ref v i))))]
[else (ro:apply mk-intn (ro:make-list n (ro:#%app to-int v)))]))
(ro:define (to-intn* v) (cl:pointer-cast v cl-mk-intn))
(ro:define (mk-intn x ...) (ro:#%app cl-mk-intn x ...)))]))
(define-simple-macro (define-ints n ...) (begin (define-int n) ...))
(define-ints 2 3 4 16)
(define-syntax (define-float stx)
(syntax-parse stx
[(_ n)
#:with floatn (format-id #'n "float~a" (syntax->datum #'n))
#:with to-floatn (format-id #'n "to-~a" #'floatn)
#:with mk-floatn (format-id #'n "mk-~a" #'floatn)
#:with cl-mk-floatn (mk-cl #'floatn)
#:with (x ...) (generate-temporaries
(build-list (syntax->datum #'n) (lambda (x) x)))
#:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...))
#'(begin
(define-named-type-alias floatn
(add-constructm
(add-convertm
(rosette3:CVector I ...)
to-floatn)
mk-floatn))
(ro:define (to-floatn v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i n]) (to-float (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i n]) (to-float (ro:vector-ref v i))))]
[else
(ro:apply ro:vector-immutable
(ro:make-list n (to-float v)))]))
(ro:define (mk-floatn x ...)
(ro:#%app cl-mk-floatn x ...)
#;(ro:#%app ro:vector-immutable (to-float x) ...))
)]))
(syntax-parse stx
[(_ n)
#:with floatn (format-id #'n "float~a" (syntax->datum #'n))
#:with to-floatn (format-id #'n "to-~a" #'floatn)
#:with mk-floatn (mk-mk #'floatn)
#:with cl-mk-floatn (mk-cl #'floatn)
#:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values))
#:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...))
#'(begin
(define-named-type-alias floatn
(add-constructm
(add-convertm (rosette3:CVector I ...) to-floatn) mk-floatn))
(ro:define (to-floatn v)
(ro:cond
[(ro:list? v)
(ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:vector-ref v i))))]
[else (ro:apply mk-floatn (ro:make-list n (ro:#%app to-float v)))]))
(ro:define (mk-floatn x ...) (ro:#%app cl-mk-floatn x ...)))]))
(define-simple-macro (define-floats n ...) (begin (define-float n) ...))
(define-floats 2 3 4 16)
(define-type-constructor Pointer #:arity = 1)
;(define-named-type-alias void rosette3:CUnit)
(define-base-type void)
#;(begin-for-syntax
(define-syntax ~void*
(pattern-expander
(make-variable-like-transformer #'(~and t:type (~parse ~void #'t.norm))))))
(define-named-type-alias void*
(add-convertm (Pointer void) (λ (x) x)))
(define-named-type-alias int*
(add-convertm (Pointer int) to-int*))
(define-named-type-alias int16*
(add-convertm (Pointer int16) to-int16*))
(define-named-type-alias int2*
(add-convertm (Pointer int2) to-int2*))
(define-named-type-alias float*
(add-convertm (Pointer float) to-float*))
(define-typed-syntax synth-app
[(_ (ty:type) e) ; cast
[ e e- ty-e]
@ -282,13 +264,8 @@
[ ptr ptr- ty-ptr]
#:when (pointer-type? #'ty-ptr) #:with ~! #'dummy ; commit
[ sel sel- int]
#:do [(define split-ty (ty->len #'ty-ptr))]
#:when (and split-ty (= 3 (length split-ty)))
#:do [(define base-str (cadr split-ty))
(define len-str (caddr split-ty))]
#:with ty-out ((current-type-eval) (format-id #'h "~a~a" base-str len-str))
--------
[ (cl:pointer-ref ptr- sel-) ty-out]]
[ (cl:pointer-ref ptr- sel-) #,(get-pointer-base #'ty-ptr)]]
[(_ vec sel) ; applying vector to one arg is selector
[ vec vec- ty-vec]
#:when (vector-type? #'ty-vec)
@ -342,13 +319,13 @@
(define- f-
(lambda- (x ...)
(rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...)
(⊢m (let- () e ... (rosette3:ann e-body : ty-out)) ty-out)))))]])
(⊢m (ro:let () e ... (rosette3:ann e-body : ty-out)) ty-out))))
(provide- f))]])
(define-typed-syntax kernel
[(_ ty-out:type (f [ty:type x:id] ...) e ...)
#:fail-unless (void? #'ty-out.norm)
(format "expected void, given ~a" (type->str #'ty-out.norm))
--------
[ (procedure void (f [ty x] ...) e ...)]])
--- [ (procedure void (f [ty x] ...) e ...)]])
(define-typed-syntax grammar
[(_ ty-out:type (f [ty:type x:id] ...) e)
#:with f- (generate-temporary #'f)
@ -369,23 +346,21 @@
(define-typed-syntax if
[(_ test {then ...} {else ...})
--------
[ (ro:if (to-bool test)
[ (ro:if (ro:#%app to-bool test)
(ro:let () then ... (ro:void))
(ro:let () else ... (ro:void))) void]]
[(_ test {then ...})
--------
[ (if test {then ...} {})]])
--- [ (if test {then ...} {})]])
(define-typed-syntax (range e ...)
[ e e- int] ...
--------
[ (ro:#%app ro:in-range e- ...) int])
--- [ (ro:#%app ro:in-range e- ...) int])
(define-typed-syntax for
[(_ [((~literal :) ty:type var:id (~datum in) rangeExpr) ...] e ...)
[[var var- : ty.norm] ... [e e- ty-e] ...]
[(_ [((~literal :) ty:type x:id (~datum in) rangeExpr) ...] e ...)
--------
[ (ro:for* ([var- rangeExpr] ...)
e- ... (ro:void)) void]])
[ (ro:for* ([x rangeExpr] ...)
(rosette3:let ([x (⊢m x ty)] ...)
(⊢m (ro:let () e ... (ro:void)) void))) void]])
;; need to redefine #%datum because rosette3:#%datum is too precise
@ -433,10 +408,11 @@
(format "no pred for ~a" (type->str #'ty))
#:with (x- ...) (generate-temporaries #'(x ...))
#:with (x-- ...) (generate-temporaries #'(x ...))
#:with mk-ty (format-id #'here "mk-~a" #'ty)
--------
[ (begin-
(ro:define-symbolic* x-- pred [#,(string->number len-str)]) ...
(ro:define x- (ro:apply ro:vector-immutable x--)) ...
(ro:define x- (ro:apply mk-ty x--)) ...
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]]
[(_ ty:type [len] x:id ...) ; array of vector types
@ -457,7 +433,7 @@
[ (begin-
(ro:define-symbolic* x-- pred [len base-len]) ...
(ro:define x-
(ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))])
(ro:let ([*x (ro:#%app to-ty* (cl:malloc (ro:* len base-len)))])
(ro:for ([i len][v x--])
(cl:pointer-set! *x i (ro:apply mk-ty v)))
*x)) ...
@ -506,12 +482,12 @@
#:with ty-base ((current-type-eval) (datum->syntax #'e (string->symbol out-base-str)))
#:with base-convert (get-convert #'ty-base)
-------
[ (convert
(ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)])
[ (ro:#%app convert
(ro:let ([a (ro:#%app convert e-)][b (ro:#%app convert e1-)][c (ro:#%app convert e2-)])
(ro:for/list ([idx #,(string->number out-len-str)])
(ro:if (ro:< (ro:vector-ref a idx) 0)
(base-convert (ro:vector-ref b idx))
(base-convert (ro:vector-ref c idx))))))
(ro:#%app base-convert (ro:vector-ref b idx))
(ro:#%app base-convert (ro:vector-ref c idx))))))
ty-out]]
[(_ ~! e e1 e2) ; should be scalar and real
[ e e- ty]
@ -535,10 +511,10 @@
[ x x- ty-x]
[ e e- ty-e]
#:fail-unless (cast-ok? #'ty-e #'ty-x stx)
(format "cannot cast ~a to ~a"
(type->str #'ty-e) (type->str #'ty-x))
(format "cannot cast ~a to ~a" (type->str #'ty-e) (type->str #'ty-x))
#:with conv (get-convert #'ty-x)
--------
[ (ro:set! x- (synth-app (ty-x) e-)) void]]
[ (ro:set! x- #,(if (syntax-e #'conv) #'(ro:#%app conv e-) #'e-)) void]]
;; selector can be list of numbers or up to wxyz for vectors of length <=4
[(_ [x:id sel] e)
[ x x- ty-x]
@ -561,11 +537,11 @@
[(_ e)
[ e e- bool]
--------
[ (cl:! e-) bool]]
[ (ro:#%app cl:! e-) bool]]
[(_ e) ; else try to coerce
[ e e- ty]
--------
[ (cl:! (synth-app (bool) e-)) bool]])
[ (ro:#%app cl:! (ro:#%app to-bool e-)) bool]])
;; TODO: this should produce int-vector result?
(define-typed-syntax ==
@ -576,7 +552,7 @@
#:when (real-type? #'ty2)
#:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len
--------
[ (to-int (cl:== e1- e2-)) int]])
[ (ro:#%app to-int (cl:== e1- e2-)) int]])
(define-simple-macro (define-bool-ops o ...+) (ro:begin (define-bool-op o) ...))
(define-simple-macro (define-bool-op name)
@ -588,8 +564,7 @@
--------
[ (name- e1- e2-) bool]]
[(_ e1 e2) ; else try to coerce
--------
[ (name- (synth-app (bool) e1) (synth-app (bool) e2)) bool]]))
--- [ (name- (ro:#%app to-bool e1) (ro:#%app to-bool e2)) bool]]))
(define-simple-macro (define-real-ops o ...) (ro:begin (define-real-op o) ...))
(define-simple-macro (define-real-op name (~optional (~seq #:extra-check p?)
@ -597,32 +572,31 @@
#:with name- (mk-cl #'name)
#:with name= (format-id #'name "~a=" #'name) ; assignment form
(begin-
(define-typed-syntax (name e1 e2)
[ e1 e1- ty1]
[ e2 e2- ty2]
#:with ty-out (common-real-type #'ty1 #'ty2)
(define-typed-syntax (name e (... ...))
[ e e- ty] (... ...)
#:with ty-out (apply common-real-type (stx->list #'(ty (... ...))))
#:fail-unless (syntax-e #'ty-out)
(format "no common real type for operands; given ~a, ~a"
(type->str #'ty1) (type->str #'ty2))
#:when (p? #'ty-out #'ty1 #'ty2)
(format "no common real type for operands; given ~a"
(types->str #'(ty (... ...))))
#:when (p? #'ty-out #'(ty (... ...)))
#:with convert (get-convert #'ty-out)
#:with ty-base (get-base #'ty-out)
#:with base-convert (get-convert #'ty-base)
#:with (x (... ...)) (generate-temporaries #'(e (... ...)))
--------
[ #,(if (scalar-type? #'ty-out)
#'(convert (name- (convert e1-) (convert e2-)))
#'(convert (ro:let ([a (convert e1-)][b (convert e2-)])
(ro:for/list ([v1 a][v2 b])
(base-convert (name- v1 v2)))))) ty-out])
#'(ro:#%app convert (name- (convert e-) (... ...)))
#'(ro:#%app convert (ro:let ([x (ro:#%app convert e-)] (... ...))
(ro:for/list ([x x] (... ...))
(ro:#%app base-convert (name- x (... ...))))))) ty-out])
(define-typed-syntax (name= x e)
--------
[ (= x (name x e))])))
--- [ (= x (name x e))])))
(define-for-syntax (int? t given1 given2)
(define-for-syntax (int? t givens)
(or (typecheck/un? t #'int)
(raise-syntax-error #f
(format "no common integer type for operands; given ~a, ~a"
(type->str given1) (type->str given2)))))
(format "no common integer type for operands; given ~a"
(types->str givens)))))
(define-simple-macro (define-int-op o) (define-real-op o #:extra-check int?))
(define-simple-macro (define-int-ops o ...) (ro:begin (define-int-op o) ...))
@ -631,12 +605,10 @@
(define-int-ops % <<)
(define-typerule (sizeof t:type) >>
----------
[ #,(real-type-length #'t.norm) int])
--- [ #,(real-type-length #'t.norm) int])
(define-typerule (print e ...) >>
----------
[ (ro:begin (display e) ...) void])
--- [ (ro:begin (display e) ...) void])
(define-typed-syntax choose
[(ch e ...+)
@ -656,26 +628,35 @@
(define-for-syntax (decl->seq stx)
(syntax-parse stx
[((~datum :) type id (~datum in) rangeExpr)
(syntax/loc stx (id rangeExpr type))]
[((~datum :) type id)
(syntax/loc stx (id (ro:in-value (ro:let () (: type id) id)) type))]))
[((~datum :) ty:type id (~datum in) rangeExpr)
(syntax/loc stx (id rangeExpr ty.norm))]
[((~datum :) ty:type [len] id)
#:with tyout (mk-ptr #'ty)
(syntax/loc stx (id (ro:in-value (ro:let () (: ty [len] id) id)) tyout))]
[((~datum :) ty id)
(syntax/loc stx (id (ro:in-value (ro:let () (: ty id) id)) ty))]))
(define-typed-syntax (synth #:forall [decl ...] #:ensure e)
(define-typed-syntax synth
[(_ #:forall [decl ...] #:bitwidth bw #:ensure e)
#:with ([id seq ty] ...) (stx-map decl->seq #'(decl ...))
#:with (id- ...) (generate-temporaries #'(id ...))
#:with (typed-seq ...) #'((with-ctx ([id id- ty] ...) seq) ...)
#:with (tmp ...) (generate-temporaries #'(id ...))
--------
[ (ro:let ([id- 1] ...) ; dummy, enables simplifying stx template
[ (ro:let ([id- 1] ...) ; dummy ensuring id- bound, simplifies stx template
(ro:define-values (tmp ...)
(ro:for*/lists (tmp ...) ([id- typed-seq] ...) (ro:values id- ...)))
(ro:parameterize ([ro:term-cache (ro:hash-copy (ro:term-cache))])
(ro:parameterize ([ro:current-bitwidth bw]
[ro:term-cache (ro:hash-copy (ro:term-cache))])
(ro:print-forms
(ro:synthesize
#:forall (ro:append tmp ...)
#:guarantee (ro:for ([id- tmp] ...)
(with-ctx ([id id- ty] ...) e)))))) void])
(with-ctx ([id id- ty] ...) e)))))) void]]
[(_ #:forall [decl ...] #:ensure e)
--- [ (synth #:forall [decl ...] #:bitwidth 8 #:ensure e)]])
(define-typed-syntax verify
[(vfy #:forall [decl ...] #:ensure e)
@ -684,7 +665,8 @@
#:with (typed-seq ...) #'((with-ctx ([id id- ty] ...) seq) ...)
--------
[ (ro:let ([id- 1] ...) ; dummy, enables simplifying stx template
(ro:parameterize ([ro:term-cache (ro:hash-copy (ro:term-cache))])
(ro:parameterize ([ro:current-bitwidth 32]
[ro:term-cache (ro:hash-copy (ro:term-cache))])
(ro:for*/or ([id- typed-seq] ...)
(ro:define cex (with-ctx ([id id- ty] ...) (ro:verify e)))
(ro:and (ro:sat? cex)
@ -693,6 +675,12 @@
(printf "~a = ~a\n" i (ro:evaluate i- cex))))))) void]])
(define-typed-syntax (assert e)
#:with e- (expand/ro #'e)
--------
[ (ro:assert (to-bool e-)) void])
--- [ (ro:assert (ro:#%app to-bool #,(expand/ro #'e))) void])
(define- (/ x y)
(cond- [(zero?- y) 0]
[(integer?- x) (quotient- x y)]
[else (/- x y)]))
(define-typed-syntax (clCreateProgramWithSource ctx f)
--- [ (cl:clCreateProgramWithSource ctx f) cl_program])

View File

@ -0,0 +1,110 @@
#lang s-exp "../../../rosette/synthcl3.rkt"
; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix.
(kernel void (mmulScalarKernel [int* A] [int* B] [int* C] [int p] [int m])
(: int i j sum)
(= i (get_global_id 0))
(= j (get_global_id 1))
(= sum 0)
(for [(: int k in (range p))]
(+= sum (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
(= [C (+ (* i m) j)] sum))
; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix.
(kernel void (mmulVectorKernel [int4* A] [int4* B] [int4* C] [int p] [int m])
(: int i j)
(: int4 sum0 sum1 sum2 sum3)
(= i (get_global_id 0))
(= j (get_global_id 1))
(= sum0 0)
(= sum1 0)
(= sum2 0)
(= sum3 0)
(for [(: int k in (range 0 p 4))]
(: int4 a0 a1 a2 a3)
(: int4 b0 b1 b2 b3)
(= a0 [A (indexA 0 i k p)])
(= a1 [A (indexA 1 i k p)])
(= a2 [A (indexA 2 i k p)])
(= a3 [A (indexA 3 i k p)])
(= b0 [B (indexB 0 k j m)])
(= b1 [B (indexB 1 k j m)])
(= b2 [B (indexB 2 k j m)])
(= b3 [B (indexB 3 k j m)])
(+= sum0 (computeSum a0 b0 b1 b2 b3))
(+= sum1 (computeSum a1 b0 b1 b2 b3))
(+= sum2 (computeSum a2 b0 b1 b2 b3))
(+= sum3 (computeSum a3 b0 b1 b2 b3)))
(= [C (indexC 0 i j m)] sum0)
(= [C (indexC 1 i j m)] sum1)
(= [C (indexC 2 i j m)] sum2)
(= [C (indexC 3 i j m)] sum3))
; Multiplies the 1x4 vector a by the 4x4 matrix with rows b0, b1, b2 and b3.
(procedure int4 (computeSum [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
(int4
(+ (* [a x] [b0 x]) (* [a y] [b1 x]) (* [a z] [b2 x]) (* [a w] [b3 x]))
(+ (* [a x] [b0 y]) (* [a y] [b1 y]) (* [a z] [b2 y]) (* [a w] [b3 y]))
(+ (* [a x] [b0 z]) (* [a y] [b1 z]) (* [a z] [b2 z]) (* [a w] [b3 z]))
(+ (* [a x] [b0 w]) (* [a y] [b1 w]) (* [a z] [b2 w]) (* [a w] [b3 w]))))
; Function for indexing into the matrix A.
(procedure int (indexA [int off] [int i] [int k] [int p])
; (+ (* (+ (* i 4) off) (/ p 4)) (/ k 4)))
(: int r c w)
(= r (+ (choose i (/ i 4) (* i 4)) (choose off 0)))
(= c (+ (choose k (/ k 4) (* k 4)) (choose off 0)))
(= w (+ (choose p (/ p 4) (* p 4)) (choose off 0)))
(+ (* r w) c))
; Function for indexing into the matrix B.
(procedure int (indexB [int off] [int k] [int j] [int m])
;(+ (* (+ k off) (/ m 4)) j))
(: int r c w)
(= r (+ (choose k (/ k 4) (* k 4)) (choose off 0)))
(= c (+ (choose j (/ j 4) (* j 4)) (choose off 0)))
(= w (+ (choose m (/ m 4) (* m 4)) (choose off 0)))
(+ (* r w) c))
; Function for indexing into the matrix C.
(procedure int (indexC [int off] [int i] [int j] [int m])
;(+ (* (+ (* i 4) off) (/ m 4)) j))
(: int r c w)
(= r (+ (choose i (/ i 4) (* i 4)) (choose off 0)))
(= c (+ (choose j (/ j 4) (* j 4)) (choose off 0)))
(= w (+ (choose m (/ m 4) (* m 4)) (choose off 0)))
(+ (* r w) c))
; Bad sketch and completions:
; Function for indexing into the matrix A.
;(procedure int (indexA [int off] [int i] [int k] [int p])
; (+ (* (+ (/ i 4) off) (/ p 4)) (/ k 4)))
;(: int r c w)
;(= r (+ (choose i (/ i 4)) (choose off 0)))
;(= c (+ (choose k (/ k 4)) (choose off 0)))
;(= w (+ (choose p (/ p 4)) (choose off 0)))
;(+ (* r w) c))
; Function for indexing into the matrix B.
;(procedure int (indexB [int off] [int k] [int j] [int m])
;(+ (* (+ k off) (/ m 4)) j))
;(: int r c w)
;(= r (+ (choose k (/ k 4)) (choose off 0)))
;(= c (+ (choose j (/ j 4)) (choose off 0)))
;(= w (+ (choose m (/ m 4)) (choose off 0)))
;(+ (* r w) c))
; Function for indexing into the matrix C.
;(procedure int (indexC [int off] [int i] [int j] [int m])
;(+ (* (+ (/ i 4) off) (/ m 4)) j))
;(: int r c w)
;(= r (+ (choose i (/ i 4)) (choose off 0)))
;(= c (+ (choose j (/ j 4)) (choose off 0)))
;(= w (+ (choose m (/ m 4)) (choose off 0)))
;(+ (* r w) c))

View File

@ -3,98 +3,104 @@
(prefix-in cl: sdsl/synthcl/lang/main)
(prefix-in ro: (rename-in rosette [#%app a])))
(: int2 [3] xs)
xs
(check-type xs : int2*)
(: int [4] xs2)
xs2
(check-type xs2 : int*)
;; ; The reference implementation for square matrix multiplication.
;; ; Multiplies two squre matrices A and B, where the dimension of A is
;; ; n x p and dimension of B is p x m. Both matrices are given as
;; ; flat arrays in row-major form. The output is the matrix C = A*B,
;; ; also given in row-major form.
;; (procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m])
;; (: int* C)
;; (= C ((int*) (malloc (* n m (sizeof int)))))
;; (for [(: int i in (range n))
;; (: int j in (range m))
;; (: int k in (range p))]
;; (+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
;; C)
; The reference implementation for square matrix multiplication.
; Multiplies two squre matrices A and B, where the dimension of A is
; n x p and dimension of B is p x m. Both matrices are given as
; flat arrays in row-major form. The output is the matrix C = A*B,
; also given in row-major form.
(procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m])
(: int* C)
(= C ((int*) (malloc (* n m (sizeof int)))))
(for [(: int i in (range n))
(: int j in (range m))
(: int k in (range p))]
(+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
C)
; A host implementation of matrix multiplication.
(procedure int* (mmulHost [char* kernelName] [int typeLen]
[int* A] [int* B] [int n] [int p] [int m])
;; (: cl_context context)
;; (: cl_command_queue command_queue)
;; (: cl_program program)
;; (: cl_kernel kernel)
;; (: cl_mem buffer_A buffer_B buffer_C)
(: int* C)
;; (: int[2] global)
;; (: int dimA dimB dimC)
;; (= [global 0] (/ n typeLen))
;; (= [global 1] (/ m typeLen))
;; (= dimA (* n p (sizeof int)))
;; (= dimB (* p m (sizeof int)))
;; (= dimC (* n m (sizeof int)))
;; (= C ((int*) (malloc dimC)))
;; (= context (clCreateContext))
;; (= command_queue (clCreateCommandQueue context))
;; (= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
;; (= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
;; (= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
;; (= program (clCreateProgramWithSource context "kernel.rkt"))
;; (clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
;; (clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
;; (= kernel (clCreateKernel program kernelName))
;; (clSetKernelArg kernel 0 buffer_A)
;; (clSetKernelArg kernel 1 buffer_B)
;; (clSetKernelArg kernel 2 buffer_C)
;; (clSetKernelArg kernel 3 p)
;; (clSetKernelArg kernel 4 m)
(: cl_context context)
(: cl_command_queue command_queue)
(: cl_program program)
(: cl_kernel kernel)
(: cl_mem buffer_A buffer_B buffer_C)
;; (clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
;; (clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
(: int* C)
(: int[2] global)
(: int dimA dimB dimC)
(= [global 0] (/ n typeLen))
(= [global 1] (/ m typeLen))
(= dimA (* n p (sizeof int)))
(= dimB (* p m (sizeof int)))
(= dimC (* n m (sizeof int)))
(= C ((int*) (malloc dimC)))
(= context (clCreateContext))
(= command_queue (clCreateCommandQueue context))
(= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
(= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
(= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
(= program (clCreateProgramWithSource context "matrix-synth-kernel.rkt"))
(clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
(clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
(= kernel (clCreateKernel program kernelName))
(clSetKernelArg kernel 0 buffer_A)
(clSetKernelArg kernel 1 buffer_B)
(clSetKernelArg kernel 2 buffer_C)
(clSetKernelArg kernel 3 p)
(clSetKernelArg kernel 4 m)
(clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
(clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
C)
;; ; A scalar parallel implementation of matrix multiplication.
;; (procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m])
;; (mmulHost "mmulScalarKernel" 1 A B n p m))
; A scalar parallel implementation of matrix multiplication.
(procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m])
(mmulHost "mmulScalarKernel" 1 A B n p m))
; A vector parallel implementation of matrix multiplication. The dimensions
; n and m must be evenly divisible by 4.
#;(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m])
(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m])
(mmulHost "mmulVectorKernel" 4 A B n p m))
;; ; Given two arrays of the same size, checks that they hold the same
;; ; values at each index.
;; (procedure void (check [int* actual] [int* expected] [int SIZE])
;; (assert (>= SIZE 0))
;; (for [(: int i in (range SIZE))]
;; (assert (== [actual i] [expected i]))))
; Given two arrays of the same size, checks that they hold the same
; values at each index.
(procedure void (check [int* actual] [int* expected] [int SIZE])
(assert (>= SIZE 0))
(for [(: int i in (range SIZE))]
(assert (== [actual i] [expected i]))))
;; (procedure void (synth_vector [int size])
;; (synth #:forall [(: int n in (range size (+ 1 size)))
;; (: int p in (range size (+ 1 size)))
;; (: int m in (range size (+ 1 size)))
;; (: int[(* n p)] A)
;; (: int[(* p m)] B)]
;; #:ensure (check (mmulVector A B n p m)
;; (mmulSequential A B n p m)
;; (* n m))))
(procedure void (synth_vector [int size])
(synth #:forall [(: int n in (range size (+ 1 size)))
(: int p in (range size (+ 1 size)))
(: int m in (range size (+ 1 size)))
(: int[(* n p)] A)
(: int[(* p m)] B)]
#:ensure (check (mmulVector A B n p m)
(mmulSequential A B n p m)
(* n m))))
;; ; (synth_vector 4) ; 20 sec
;; ; (synth_vector 8) ; 252 sec
(: int n)
(= n 4)
(: int p)
(= p 4)
(: int m)
(= m 4)
(: int[(* n p)] A)
(: int[(* p m)] B)
(check-type (mmulVector A B n p m) : int*)
(check-type (mmulSequential A B n p m) : int*)
(check-type
(with-output-to-string
(λ ()
(synth_vector 4))) ; 20 sec
: CString
-> "/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt:57:0\n'(procedure\n int\n (indexA (int off) (int i) (int k) (int p))\n (: int r c w)\n (= r (+ (choose i (/ i 4) (* i 4)) off))\n (= c (+ (choose k (/ k 4) (* k 4)) 0))\n (= w (+ (/ p 4) 0))\n (+ (* r w) c))\n/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt:66:0\n'(procedure\n int\n (indexB (int off) (int k) (int j) (int m))\n (: int r c w)\n (= r (+ (choose k (/ k 4) (* k 4)) off))\n (= c (+ (choose j (/ j 4) (* j 4)) 0))\n (= w (+ (/ m 4) 0))\n (+ (* r w) c))\n/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt:75:0\n'(procedure\n int\n (indexC (int off) (int i) (int j) (int m))\n (: int r c w)\n (= r (+ (choose i (/ i 4) (* i 4)) off))\n (= c (+ (choose j (/ j 4) (* j 4)) 0))\n (= w (+ (/ m 4) 0))\n (+ (* r w) c))\n")
;(synth_vector 8) ; 252 sec

View File

@ -345,3 +345,36 @@
{(assert k)}))))
: CString
-> "counterexample found:\nt = 2\nk = 0\np = 4\n")
(: int2 [3] xs)
(check-type xs : int2*)
(: int [4] xs2)
(check-type xs2 : int*)
; basic matrix multiplying
(: int4 sum0 sum1 sum2 sum3)
(= sum0 0)
(= sum1 0)
(= sum2 0)
(= sum3 0)
(procedure int (computeSum1 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
(+ (* [a x] [b0 x]) (* [a y] [b1 x]) (* [a z] [b2 x]) (* [a w] [b3 x])))
(procedure int (computeSum2 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
(+ (* [a x] [b0 y]) (* [a y] [b1 y]) (* [a z] [b2 y]) (* [a w] [b3 y])))
(procedure int (computeSum3 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
(+ (* [a x] [b0 z]) (* [a y] [b1 z]) (* [a z] [b2 z]) (* [a w] [b3 z])))
(procedure int (computeSum4 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
(+ (* [a x] [b0 w]) (* [a y] [b1 w]) (* [a z] [b2 w]) (* [a w] [b3 w])))
(check-type (computeSum1 sum0 sum0 sum1 sum2 sum3) : int -> 0)
(check-type (computeSum2 sum0 sum0 sum1 sum2 sum3) : int -> 0)
(check-type (computeSum3 sum0 sum0 sum1 sum2 sum3) : int -> 0)
(check-type (computeSum4 sum0 sum0 sum1 sum2 sum3) : int -> 0)
(check-type (int4 (computeSum1 sum0 sum0 sum1 sum2 sum3)
(computeSum2 sum0 sum0 sum1 sum2 sum3)
(computeSum3 sum0 sum0 sum1 sum2 sum3)
(computeSum4 sum0 sum0 sum1 sum2 sum3))
: int4
-> (ro:a ro:vector-immutable 0 0 0 0))