support array decls; fix procedure to support local decls

This commit is contained in:
Stephen Chang 2016-11-30 16:52:14 -05:00
parent 5e0acd3f9b
commit f1f43697f9
2 changed files with 153 additions and 15 deletions

View File

@ -19,7 +19,7 @@
procedure kernel grammar #%datum if range for print
choose locally-scoped assert synth verify
int int2 int3 int4 int16 float float2 float3 float4 float16
bool void void* char* float* int* int16*
bool void void* char* float* int* int16* int2*
: ! ?: == + * - || &&
% << ; int ops
= += -= %= ; assignment ops
@ -48,10 +48,6 @@
(define (pointer-type? t)
(Pointer? t)
#;(regexp-match #px"\\*$" (type->str t)))
(define (type-base t)
(datum->syntax t
(string->symbol
(car (regexp-match #px"[a-z]+" (type->str t))))))
(define (real-type<=? t1 t2)
(and (real-type? t1) (real-type? t2)
(or ((current-type=?) t1 t2) ; need type= to distinguish reals/ints
@ -59,7 +55,7 @@
(and (typecheck/un? t1 #'int)
(not (typecheck/un? t2 #'bool)))
(and (typecheck/un? t1 #'float)
(typecheck/un? (type-base t2) #'float)))))
(typecheck/un? (get-base t2) #'float)))))
; Returns the common real type of the given types, as specified in
; Ch. 6.2.6 of opencl-1.2 specification. If there is no common
@ -112,14 +108,14 @@
(regexp-match #px"([a-z]+)([0-9]+)" (type->str ty)))
(define (real-type-length t)
(define split-ty (ty->len t))
(or (and split-ty (third split-ty)) 1))
(string->number
(or (and split-ty (third split-ty)) "1")))
(define (get-base ty [ctx #'here])
((current-type-eval)
(datum->syntax ctx
(string->symbol (car (regexp-match #px"[a-z]+" (type->str ty)))))))
(define (vector-type? ty)
;; TODO: and not pointer-type?
(ty->len ty))
(ty->len ty)) ; TODO: check and not pointer-type?
(define (scalar-type? ty)
(or (typecheck/un? ty #'bool)
(and (real-type? ty) (not (vector-type? ty))))))
@ -147,9 +143,13 @@
[(ro:fixnum? v) (ro:exact->inexact v)]
[(ro:flonum? v) v]
[else (ro:type-cast ro:real? v)]))
(ro:define (mk-int v)
(ro:#%app cl:int v))
(ro:define (to-int16* v)
(cl:pointer-cast v cl:int16))
(ro:define (to-int2* v)
(cl:pointer-cast v cl:int2))
(ro:define (to-int* v)
(cl:pointer-cast v cl:int))
(ro:define (to-float* v)
@ -161,7 +161,8 @@
(add-convertm rosette3:Int to-int))
(define-named-type-alias float
(add-convertm rosette3:Num to-float))
(define-named-type-alias char* rosette3:CString)
(define-named-type-alias char*
(add-convertm rosette3:CString (λ (x) x)))
(define-syntax (define-int stx)
(syntax-parse stx
@ -169,6 +170,7 @@
#:with intn (format-id #'n "int~a" (syntax->datum #'n))
#:with to-intn (format-id #'n "to-~a" #'intn)
#:with mk-intn (format-id #'n "mk-~a" #'intn)
#:with cl-mk-intn (mk-cl #'intn)
#:with (x ...) (generate-temporaries
(build-list (syntax->datum #'n) (lambda (x) x)))
#:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...))
@ -191,7 +193,8 @@
(ro:apply ro:vector-immutable
(ro:make-list n (to-int v)))]))
(ro:define (mk-intn x ...)
(ro:#%app ro:vector-immutable (to-int x) ...))
(ro:#%app cl-mk-intn x ...)
#;(ro:#%app ro:vector-immutable (to-int x) ...))
)]))
(define-simple-macro (define-ints n ...) (begin (define-int n) ...))
(define-ints 2 3 4 16)
@ -202,6 +205,7 @@
#:with floatn (format-id #'n "float~a" (syntax->datum #'n))
#:with to-floatn (format-id #'n "to-~a" #'floatn)
#:with mk-floatn (format-id #'n "mk-~a" #'floatn)
#:with cl-mk-floatn (mk-cl #'floatn)
#:with (x ...) (generate-temporaries
(build-list (syntax->datum #'n) (lambda (x) x)))
#:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...))
@ -224,7 +228,8 @@
(ro:apply ro:vector-immutable
(ro:make-list n (to-float v)))]))
(ro:define (mk-floatn x ...)
(ro:#%app ro:vector-immutable (to-float x) ...))
(ro:#%app cl-mk-floatn x ...)
#;(ro:#%app ro:vector-immutable (to-float x) ...))
)]))
(define-simple-macro (define-floats n ...) (begin (define-float n) ...))
(define-floats 2 3 4 16)
@ -243,6 +248,8 @@
(add-convertm (Pointer int) to-int*))
(define-named-type-alias int16*
(add-convertm (Pointer int16) to-int16*))
(define-named-type-alias int2*
(add-convertm (Pointer int2) to-int2*))
(define-named-type-alias float*
(add-convertm (Pointer float) to-float*))
@ -325,11 +332,17 @@
(format "expected void, given ~a" (type->str #'ty-out.norm))
--------
[ (rosette3:define (f [x : ty] ... -> void) (⊢m (ro:void) void))]]
[(_ ty-out:type (f [ty:type x:id] ...) e ...+)
[(_ ty-out:type (f [ty:type x:id] ...) e ... e-body)
#:with (conv ...) (stx-map get-convert #'(ty.norm ...))
#:with f- (add-orig (generate-temporary #'f) #'f)
--------
[ (rosette3:define (f [x : ty] ... -> ty-out)
(rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...) e ...))]])
[ (begin-
(define-syntax- f
(make-rename-transformer ( f- : (C→ ty ... ty-out))))
(define- f-
(lambda- (x ...)
(rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...)
(⊢m (let- () e ... (rosette3:ann e-body : ty-out)) ty-out)))))]])
(define-typed-syntax kernel
[(_ ty-out:type (f [ty:type x:id] ...) e ...)
#:fail-unless (void? #'ty-out.norm)
@ -407,6 +420,7 @@
(ro:define x- (ro:#%datum . "")) ...)]]
;; TODO: vector types need a better representation
;; currently dissecting the identifier
;; TODO: combine vector and scalar cases
[(_ ty:type x:id ...)
#:when (real-type? #'ty.norm)
#:do [(define split-ty (ty->len #'ty))]
@ -425,6 +439,30 @@
(ro:define x- (ro:apply ro:vector-immutable x--)) ...
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]]
[(_ ty:type [len] x:id ...) ; array of vector types
#:when (real-type? #'ty.norm)
[ len len- int]
#:with ty-base (get-base #'ty.norm)
#:with base-len (datum->syntax #'ty (real-type-length #'ty.norm))
#:with ty* (format-id #'ty "~a*" #'ty)
#:with to-ty* (format-id #'here "to-~a" #'ty*)
#:with pred (get-pred ((current-type-eval) #'ty-base))
#:fail-unless (syntax-e #'pred)
(format "no pred for ~a" (type->str #'ty))
#:with (x- ...) (generate-temporaries #'(x ...))
#:with (*x ...) (generate-temporaries #'(x ...))
#:with (x-- ...) (generate-temporaries #'(x ...))
#:with mk-ty (format-id #'here "mk-~a" #'ty)
--------
[ (begin-
(ro:define-symbolic* x-- pred [len base-len]) ...
(ro:define x-
(ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))])
(ro:for ([i len][v x--])
(cl:pointer-set! *x i (ro:apply mk-ty v)))
*x)) ...
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty*))) ...)]]
;; real, scalar (ie non-vector) types
[(_ ty:type x:id ...)
#:when (real-type? #'ty.norm)

View File

@ -0,0 +1,100 @@
#lang s-exp "../../../rosette/synthcl3.rkt"
(require "../../rackunit-typechecking.rkt"
(prefix-in cl: sdsl/synthcl/lang/main)
(prefix-in ro: (rename-in rosette [#%app a])))
(: int2 [3] xs)
xs
(check-type xs : int2*)
(: int [4] xs2)
xs2
(check-type xs2 : int*)
;; ; The reference implementation for square matrix multiplication.
;; ; Multiplies two squre matrices A and B, where the dimension of A is
;; ; n x p and dimension of B is p x m. Both matrices are given as
;; ; flat arrays in row-major form. The output is the matrix C = A*B,
;; ; also given in row-major form.
;; (procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m])
;; (: int* C)
;; (= C ((int*) (malloc (* n m (sizeof int)))))
;; (for [(: int i in (range n))
;; (: int j in (range m))
;; (: int k in (range p))]
;; (+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
;; C)
; A host implementation of matrix multiplication.
(procedure int* (mmulHost [char* kernelName] [int typeLen]
[int* A] [int* B] [int n] [int p] [int m])
;; (: cl_context context)
;; (: cl_command_queue command_queue)
;; (: cl_program program)
;; (: cl_kernel kernel)
;; (: cl_mem buffer_A buffer_B buffer_C)
(: int* C)
;; (: int[2] global)
;; (: int dimA dimB dimC)
;; (= [global 0] (/ n typeLen))
;; (= [global 1] (/ m typeLen))
;; (= dimA (* n p (sizeof int)))
;; (= dimB (* p m (sizeof int)))
;; (= dimC (* n m (sizeof int)))
;; (= C ((int*) (malloc dimC)))
;; (= context (clCreateContext))
;; (= command_queue (clCreateCommandQueue context))
;; (= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
;; (= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
;; (= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
;; (= program (clCreateProgramWithSource context "kernel.rkt"))
;; (clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
;; (clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
;; (= kernel (clCreateKernel program kernelName))
;; (clSetKernelArg kernel 0 buffer_A)
;; (clSetKernelArg kernel 1 buffer_B)
;; (clSetKernelArg kernel 2 buffer_C)
;; (clSetKernelArg kernel 3 p)
;; (clSetKernelArg kernel 4 m)
;; (clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
;; (clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
C)
;; ; A scalar parallel implementation of matrix multiplication.
;; (procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m])
;; (mmulHost "mmulScalarKernel" 1 A B n p m))
; A vector parallel implementation of matrix multiplication. The dimensions
; n and m must be evenly divisible by 4.
#;(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m])
(mmulHost "mmulVectorKernel" 4 A B n p m))
;; ; Given two arrays of the same size, checks that they hold the same
;; ; values at each index.
;; (procedure void (check [int* actual] [int* expected] [int SIZE])
;; (assert (>= SIZE 0))
;; (for [(: int i in (range SIZE))]
;; (assert (== [actual i] [expected i]))))
;; (procedure void (synth_vector [int size])
;; (synth #:forall [(: int n in (range size (+ 1 size)))
;; (: int p in (range size (+ 1 size)))
;; (: int m in (range size (+ 1 size)))
;; (: int[(* n p)] A)
;; (: int[(* p m)] B)]
;; #:ensure (check (mmulVector A B n p m)
;; (mmulSequential A B n p m)
;; (* n m))))
;; ; (synth_vector 4) ; 20 sec
;; ; (synth_vector 8) ; 252 sec