support array decls; fix procedure to support local decls
This commit is contained in:
parent
5e0acd3f9b
commit
f1f43697f9
|
@ -19,7 +19,7 @@
|
|||
procedure kernel grammar #%datum if range for print
|
||||
choose locally-scoped assert synth verify
|
||||
int int2 int3 int4 int16 float float2 float3 float4 float16
|
||||
bool void void* char* float* int* int16*
|
||||
bool void void* char* float* int* int16* int2*
|
||||
: ! ?: == + * - || &&
|
||||
% << ; int ops
|
||||
= += -= %= ; assignment ops
|
||||
|
@ -48,10 +48,6 @@
|
|||
(define (pointer-type? t)
|
||||
(Pointer? t)
|
||||
#;(regexp-match #px"\\*$" (type->str t)))
|
||||
(define (type-base t)
|
||||
(datum->syntax t
|
||||
(string->symbol
|
||||
(car (regexp-match #px"[a-z]+" (type->str t))))))
|
||||
(define (real-type<=? t1 t2)
|
||||
(and (real-type? t1) (real-type? t2)
|
||||
(or ((current-type=?) t1 t2) ; need type= to distinguish reals/ints
|
||||
|
@ -59,7 +55,7 @@
|
|||
(and (typecheck/un? t1 #'int)
|
||||
(not (typecheck/un? t2 #'bool)))
|
||||
(and (typecheck/un? t1 #'float)
|
||||
(typecheck/un? (type-base t2) #'float)))))
|
||||
(typecheck/un? (get-base t2) #'float)))))
|
||||
|
||||
; Returns the common real type of the given types, as specified in
|
||||
; Ch. 6.2.6 of opencl-1.2 specification. If there is no common
|
||||
|
@ -112,14 +108,14 @@
|
|||
(regexp-match #px"([a-z]+)([0-9]+)" (type->str ty)))
|
||||
(define (real-type-length t)
|
||||
(define split-ty (ty->len t))
|
||||
(or (and split-ty (third split-ty)) 1))
|
||||
(string->number
|
||||
(or (and split-ty (third split-ty)) "1")))
|
||||
(define (get-base ty [ctx #'here])
|
||||
((current-type-eval)
|
||||
(datum->syntax ctx
|
||||
(string->symbol (car (regexp-match #px"[a-z]+" (type->str ty)))))))
|
||||
(define (vector-type? ty)
|
||||
;; TODO: and not pointer-type?
|
||||
(ty->len ty))
|
||||
(ty->len ty)) ; TODO: check and not pointer-type?
|
||||
(define (scalar-type? ty)
|
||||
(or (typecheck/un? ty #'bool)
|
||||
(and (real-type? ty) (not (vector-type? ty))))))
|
||||
|
@ -147,9 +143,13 @@
|
|||
[(ro:fixnum? v) (ro:exact->inexact v)]
|
||||
[(ro:flonum? v) v]
|
||||
[else (ro:type-cast ro:real? v)]))
|
||||
(ro:define (mk-int v)
|
||||
(ro:#%app cl:int v))
|
||||
|
||||
(ro:define (to-int16* v)
|
||||
(cl:pointer-cast v cl:int16))
|
||||
(ro:define (to-int2* v)
|
||||
(cl:pointer-cast v cl:int2))
|
||||
(ro:define (to-int* v)
|
||||
(cl:pointer-cast v cl:int))
|
||||
(ro:define (to-float* v)
|
||||
|
@ -161,7 +161,8 @@
|
|||
(add-convertm rosette3:Int to-int))
|
||||
(define-named-type-alias float
|
||||
(add-convertm rosette3:Num to-float))
|
||||
(define-named-type-alias char* rosette3:CString)
|
||||
(define-named-type-alias char*
|
||||
(add-convertm rosette3:CString (λ (x) x)))
|
||||
|
||||
(define-syntax (define-int stx)
|
||||
(syntax-parse stx
|
||||
|
@ -169,6 +170,7 @@
|
|||
#:with intn (format-id #'n "int~a" (syntax->datum #'n))
|
||||
#:with to-intn (format-id #'n "to-~a" #'intn)
|
||||
#:with mk-intn (format-id #'n "mk-~a" #'intn)
|
||||
#:with cl-mk-intn (mk-cl #'intn)
|
||||
#:with (x ...) (generate-temporaries
|
||||
(build-list (syntax->datum #'n) (lambda (x) x)))
|
||||
#:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...))
|
||||
|
@ -191,7 +193,8 @@
|
|||
(ro:apply ro:vector-immutable
|
||||
(ro:make-list n (to-int v)))]))
|
||||
(ro:define (mk-intn x ...)
|
||||
(ro:#%app ro:vector-immutable (to-int x) ...))
|
||||
(ro:#%app cl-mk-intn x ...)
|
||||
#;(ro:#%app ro:vector-immutable (to-int x) ...))
|
||||
)]))
|
||||
(define-simple-macro (define-ints n ...) (begin (define-int n) ...))
|
||||
(define-ints 2 3 4 16)
|
||||
|
@ -202,6 +205,7 @@
|
|||
#:with floatn (format-id #'n "float~a" (syntax->datum #'n))
|
||||
#:with to-floatn (format-id #'n "to-~a" #'floatn)
|
||||
#:with mk-floatn (format-id #'n "mk-~a" #'floatn)
|
||||
#:with cl-mk-floatn (mk-cl #'floatn)
|
||||
#:with (x ...) (generate-temporaries
|
||||
(build-list (syntax->datum #'n) (lambda (x) x)))
|
||||
#:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...))
|
||||
|
@ -224,7 +228,8 @@
|
|||
(ro:apply ro:vector-immutable
|
||||
(ro:make-list n (to-float v)))]))
|
||||
(ro:define (mk-floatn x ...)
|
||||
(ro:#%app ro:vector-immutable (to-float x) ...))
|
||||
(ro:#%app cl-mk-floatn x ...)
|
||||
#;(ro:#%app ro:vector-immutable (to-float x) ...))
|
||||
)]))
|
||||
(define-simple-macro (define-floats n ...) (begin (define-float n) ...))
|
||||
(define-floats 2 3 4 16)
|
||||
|
@ -243,6 +248,8 @@
|
|||
(add-convertm (Pointer int) to-int*))
|
||||
(define-named-type-alias int16*
|
||||
(add-convertm (Pointer int16) to-int16*))
|
||||
(define-named-type-alias int2*
|
||||
(add-convertm (Pointer int2) to-int2*))
|
||||
(define-named-type-alias float*
|
||||
(add-convertm (Pointer float) to-float*))
|
||||
|
||||
|
@ -325,11 +332,17 @@
|
|||
(format "expected void, given ~a" (type->str #'ty-out.norm))
|
||||
--------
|
||||
[≻ (rosette3:define (f [x : ty] ... -> void) (⊢m (ro:void) void))]]
|
||||
[(_ ty-out:type (f [ty:type x:id] ...) e ...+) ≫
|
||||
[(_ ty-out:type (f [ty:type x:id] ...) e ... e-body) ≫
|
||||
#:with (conv ...) (stx-map get-convert #'(ty.norm ...))
|
||||
#:with f- (add-orig (generate-temporary #'f) #'f)
|
||||
--------
|
||||
[≻ (rosette3:define (f [x : ty] ... -> ty-out)
|
||||
(rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...) e ...))]])
|
||||
[≻ (begin-
|
||||
(define-syntax- f
|
||||
(make-rename-transformer (⊢ f- : (C→ ty ... ty-out))))
|
||||
(define- f-
|
||||
(lambda- (x ...)
|
||||
(rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...)
|
||||
(⊢m (let- () e ... (rosette3:ann e-body : ty-out)) ty-out)))))]])
|
||||
(define-typed-syntax kernel
|
||||
[(_ ty-out:type (f [ty:type x:id] ...) e ...) ≫
|
||||
#:fail-unless (void? #'ty-out.norm)
|
||||
|
@ -407,6 +420,7 @@
|
|||
(ro:define x- (ro:#%datum . "")) ...)]]
|
||||
;; TODO: vector types need a better representation
|
||||
;; currently dissecting the identifier
|
||||
;; TODO: combine vector and scalar cases
|
||||
[(_ ty:type x:id ...) ≫
|
||||
#:when (real-type? #'ty.norm)
|
||||
#:do [(define split-ty (ty->len #'ty))]
|
||||
|
@ -425,6 +439,30 @@
|
|||
(ro:define x- (ro:apply ro:vector-immutable x--)) ...
|
||||
(define-syntax- x
|
||||
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]]
|
||||
[(_ ty:type [len] x:id ...) ≫ ; array of vector types
|
||||
#:when (real-type? #'ty.norm)
|
||||
[⊢ len ≫ len- ⇐ int]
|
||||
#:with ty-base (get-base #'ty.norm)
|
||||
#:with base-len (datum->syntax #'ty (real-type-length #'ty.norm))
|
||||
#:with ty* (format-id #'ty "~a*" #'ty)
|
||||
#:with to-ty* (format-id #'here "to-~a" #'ty*)
|
||||
#:with pred (get-pred ((current-type-eval) #'ty-base))
|
||||
#:fail-unless (syntax-e #'pred)
|
||||
(format "no pred for ~a" (type->str #'ty))
|
||||
#:with (x- ...) (generate-temporaries #'(x ...))
|
||||
#:with (*x ...) (generate-temporaries #'(x ...))
|
||||
#:with (x-- ...) (generate-temporaries #'(x ...))
|
||||
#:with mk-ty (format-id #'here "mk-~a" #'ty)
|
||||
--------
|
||||
[≻ (begin-
|
||||
(ro:define-symbolic* x-- pred [len base-len]) ...
|
||||
(ro:define x-
|
||||
(ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))])
|
||||
(ro:for ([i len][v x--])
|
||||
(cl:pointer-set! *x i (ro:apply mk-ty v)))
|
||||
*x)) ...
|
||||
(define-syntax- x
|
||||
(make-rename-transformer (assign-type #'x- #'ty*))) ...)]]
|
||||
;; real, scalar (ie non-vector) types
|
||||
[(_ ty:type x:id ...) ≫
|
||||
#:when (real-type? #'ty.norm)
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
#lang s-exp "../../../rosette/synthcl3.rkt"
|
||||
(require "../../rackunit-typechecking.rkt"
|
||||
(prefix-in cl: sdsl/synthcl/lang/main)
|
||||
(prefix-in ro: (rename-in rosette [#%app a])))
|
||||
|
||||
(: int2 [3] xs)
|
||||
xs
|
||||
(check-type xs : int2*)
|
||||
|
||||
(: int [4] xs2)
|
||||
xs2
|
||||
(check-type xs2 : int*)
|
||||
|
||||
|
||||
|
||||
;; ; The reference implementation for square matrix multiplication.
|
||||
;; ; Multiplies two squre matrices A and B, where the dimension of A is
|
||||
;; ; n x p and dimension of B is p x m. Both matrices are given as
|
||||
;; ; flat arrays in row-major form. The output is the matrix C = A*B,
|
||||
;; ; also given in row-major form.
|
||||
;; (procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m])
|
||||
;; (: int* C)
|
||||
;; (= C ((int*) (malloc (* n m (sizeof int)))))
|
||||
;; (for [(: int i in (range n))
|
||||
;; (: int j in (range m))
|
||||
;; (: int k in (range p))]
|
||||
;; (+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
|
||||
;; C)
|
||||
|
||||
; A host implementation of matrix multiplication.
|
||||
(procedure int* (mmulHost [char* kernelName] [int typeLen]
|
||||
[int* A] [int* B] [int n] [int p] [int m])
|
||||
;; (: cl_context context)
|
||||
;; (: cl_command_queue command_queue)
|
||||
;; (: cl_program program)
|
||||
;; (: cl_kernel kernel)
|
||||
;; (: cl_mem buffer_A buffer_B buffer_C)
|
||||
(: int* C)
|
||||
;; (: int[2] global)
|
||||
;; (: int dimA dimB dimC)
|
||||
|
||||
;; (= [global 0] (/ n typeLen))
|
||||
;; (= [global 1] (/ m typeLen))
|
||||
;; (= dimA (* n p (sizeof int)))
|
||||
;; (= dimB (* p m (sizeof int)))
|
||||
;; (= dimC (* n m (sizeof int)))
|
||||
|
||||
;; (= C ((int*) (malloc dimC)))
|
||||
|
||||
;; (= context (clCreateContext))
|
||||
|
||||
;; (= command_queue (clCreateCommandQueue context))
|
||||
|
||||
;; (= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
|
||||
;; (= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
|
||||
;; (= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
|
||||
|
||||
;; (= program (clCreateProgramWithSource context "kernel.rkt"))
|
||||
|
||||
;; (clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
|
||||
;; (clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
|
||||
|
||||
;; (= kernel (clCreateKernel program kernelName))
|
||||
;; (clSetKernelArg kernel 0 buffer_A)
|
||||
;; (clSetKernelArg kernel 1 buffer_B)
|
||||
;; (clSetKernelArg kernel 2 buffer_C)
|
||||
;; (clSetKernelArg kernel 3 p)
|
||||
;; (clSetKernelArg kernel 4 m)
|
||||
|
||||
;; (clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
|
||||
;; (clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
|
||||
C)
|
||||
;; ; A scalar parallel implementation of matrix multiplication.
|
||||
;; (procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m])
|
||||
;; (mmulHost "mmulScalarKernel" 1 A B n p m))
|
||||
|
||||
; A vector parallel implementation of matrix multiplication. The dimensions
|
||||
; n and m must be evenly divisible by 4.
|
||||
#;(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m])
|
||||
(mmulHost "mmulVectorKernel" 4 A B n p m))
|
||||
|
||||
;; ; Given two arrays of the same size, checks that they hold the same
|
||||
;; ; values at each index.
|
||||
;; (procedure void (check [int* actual] [int* expected] [int SIZE])
|
||||
;; (assert (>= SIZE 0))
|
||||
;; (for [(: int i in (range SIZE))]
|
||||
;; (assert (== [actual i] [expected i]))))
|
||||
|
||||
;; (procedure void (synth_vector [int size])
|
||||
;; (synth #:forall [(: int n in (range size (+ 1 size)))
|
||||
;; (: int p in (range size (+ 1 size)))
|
||||
;; (: int m in (range size (+ 1 size)))
|
||||
;; (: int[(* n p)] A)
|
||||
;; (: int[(* p m)] B)]
|
||||
;; #:ensure (check (mmulVector A B n p m)
|
||||
;; (mmulSequential A B n p m)
|
||||
;; (* n m))))
|
||||
|
||||
;; ; (synth_vector 4) ; 20 sec
|
||||
;; ; (synth_vector 8) ; 252 sec
|
Loading…
Reference in New Issue
Block a user