diff --git a/turnstile/examples/rosette/synthcl3.rkt b/turnstile/examples/rosette/synthcl3.rkt index 97101d5..0e8e853 100644 --- a/turnstile/examples/rosette/synthcl3.rkt +++ b/turnstile/examples/rosette/synthcl3.rkt @@ -19,7 +19,7 @@ procedure kernel grammar #%datum if range for print choose locally-scoped assert synth verify int int2 int3 int4 int16 float float2 float3 float4 float16 - bool void void* char* float* int* int16* + bool void void* char* float* int* int16* int2* : ! ?: == + * - || && % << ; int ops = += -= %= ; assignment ops @@ -48,10 +48,6 @@ (define (pointer-type? t) (Pointer? t) #;(regexp-match #px"\\*$" (type->str t))) - (define (type-base t) - (datum->syntax t - (string->symbol - (car (regexp-match #px"[a-z]+" (type->str t)))))) (define (real-type<=? t1 t2) (and (real-type? t1) (real-type? t2) (or ((current-type=?) t1 t2) ; need type= to distinguish reals/ints @@ -59,7 +55,7 @@ (and (typecheck/un? t1 #'int) (not (typecheck/un? t2 #'bool))) (and (typecheck/un? t1 #'float) - (typecheck/un? (type-base t2) #'float))))) + (typecheck/un? (get-base t2) #'float))))) ; Returns the common real type of the given types, as specified in ; Ch. 6.2.6 of opencl-1.2 specification. If there is no common @@ -112,14 +108,14 @@ (regexp-match #px"([a-z]+)([0-9]+)" (type->str ty))) (define (real-type-length t) (define split-ty (ty->len t)) - (or (and split-ty (third split-ty)) 1)) + (string->number + (or (and split-ty (third split-ty)) "1"))) (define (get-base ty [ctx #'here]) ((current-type-eval) (datum->syntax ctx (string->symbol (car (regexp-match #px"[a-z]+" (type->str ty))))))) (define (vector-type? ty) - ;; TODO: and not pointer-type? - (ty->len ty)) + (ty->len ty)) ; TODO: check and not pointer-type? (define (scalar-type? ty) (or (typecheck/un? ty #'bool) (and (real-type? ty) (not (vector-type? ty)))))) @@ -147,9 +143,13 @@ [(ro:fixnum? v) (ro:exact->inexact v)] [(ro:flonum? v) v] [else (ro:type-cast ro:real? v)])) +(ro:define (mk-int v) + (ro:#%app cl:int v)) (ro:define (to-int16* v) (cl:pointer-cast v cl:int16)) +(ro:define (to-int2* v) + (cl:pointer-cast v cl:int2)) (ro:define (to-int* v) (cl:pointer-cast v cl:int)) (ro:define (to-float* v) @@ -161,7 +161,8 @@ (add-convertm rosette3:Int to-int)) (define-named-type-alias float (add-convertm rosette3:Num to-float)) -(define-named-type-alias char* rosette3:CString) +(define-named-type-alias char* + (add-convertm rosette3:CString (λ (x) x))) (define-syntax (define-int stx) (syntax-parse stx @@ -169,6 +170,7 @@ #:with intn (format-id #'n "int~a" (syntax->datum #'n)) #:with to-intn (format-id #'n "to-~a" #'intn) #:with mk-intn (format-id #'n "mk-~a" #'intn) + #:with cl-mk-intn (mk-cl #'intn) #:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) (lambda (x) x))) #:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...)) @@ -191,7 +193,8 @@ (ro:apply ro:vector-immutable (ro:make-list n (to-int v)))])) (ro:define (mk-intn x ...) - (ro:#%app ro:vector-immutable (to-int x) ...)) + (ro:#%app cl-mk-intn x ...) + #;(ro:#%app ro:vector-immutable (to-int x) ...)) )])) (define-simple-macro (define-ints n ...) (begin (define-int n) ...)) (define-ints 2 3 4 16) @@ -202,6 +205,7 @@ #:with floatn (format-id #'n "float~a" (syntax->datum #'n)) #:with to-floatn (format-id #'n "to-~a" #'floatn) #:with mk-floatn (format-id #'n "mk-~a" #'floatn) + #:with cl-mk-floatn (mk-cl #'floatn) #:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) (lambda (x) x))) #:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...)) @@ -224,7 +228,8 @@ (ro:apply ro:vector-immutable (ro:make-list n (to-float v)))])) (ro:define (mk-floatn x ...) - (ro:#%app ro:vector-immutable (to-float x) ...)) + (ro:#%app cl-mk-floatn x ...) + #;(ro:#%app ro:vector-immutable (to-float x) ...)) )])) (define-simple-macro (define-floats n ...) (begin (define-float n) ...)) (define-floats 2 3 4 16) @@ -243,6 +248,8 @@ (add-convertm (Pointer int) to-int*)) (define-named-type-alias int16* (add-convertm (Pointer int16) to-int16*)) +(define-named-type-alias int2* + (add-convertm (Pointer int2) to-int2*)) (define-named-type-alias float* (add-convertm (Pointer float) to-float*)) @@ -325,11 +332,17 @@ (format "expected void, given ~a" (type->str #'ty-out.norm)) -------- [≻ (rosette3:define (f [x : ty] ... -> void) (⊢m (ro:void) void))]] - [(_ ty-out:type (f [ty:type x:id] ...) e ...+) ≫ + [(_ ty-out:type (f [ty:type x:id] ...) e ... e-body) ≫ #:with (conv ...) (stx-map get-convert #'(ty.norm ...)) + #:with f- (add-orig (generate-temporary #'f) #'f) -------- - [≻ (rosette3:define (f [x : ty] ... -> ty-out) - (rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...) e ...))]]) + [≻ (begin- + (define-syntax- f + (make-rename-transformer (⊢ f- : (C→ ty ... ty-out)))) + (define- f- + (lambda- (x ...) + (rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...) + (⊢m (let- () e ... (rosette3:ann e-body : ty-out)) ty-out)))))]]) (define-typed-syntax kernel [(_ ty-out:type (f [ty:type x:id] ...) e ...) ≫ #:fail-unless (void? #'ty-out.norm) @@ -407,6 +420,7 @@ (ro:define x- (ro:#%datum . "")) ...)]] ;; TODO: vector types need a better representation ;; currently dissecting the identifier + ;; TODO: combine vector and scalar cases [(_ ty:type x:id ...) ≫ #:when (real-type? #'ty.norm) #:do [(define split-ty (ty->len #'ty))] @@ -425,6 +439,30 @@ (ro:define x- (ro:apply ro:vector-immutable x--)) ... (define-syntax- x (make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]] + [(_ ty:type [len] x:id ...) ≫ ; array of vector types + #:when (real-type? #'ty.norm) + [⊢ len ≫ len- ⇐ int] + #:with ty-base (get-base #'ty.norm) + #:with base-len (datum->syntax #'ty (real-type-length #'ty.norm)) + #:with ty* (format-id #'ty "~a*" #'ty) + #:with to-ty* (format-id #'here "to-~a" #'ty*) + #:with pred (get-pred ((current-type-eval) #'ty-base)) + #:fail-unless (syntax-e #'pred) + (format "no pred for ~a" (type->str #'ty)) + #:with (x- ...) (generate-temporaries #'(x ...)) + #:with (*x ...) (generate-temporaries #'(x ...)) + #:with (x-- ...) (generate-temporaries #'(x ...)) + #:with mk-ty (format-id #'here "mk-~a" #'ty) + -------- + [≻ (begin- + (ro:define-symbolic* x-- pred [len base-len]) ... + (ro:define x- + (ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))]) + (ro:for ([i len][v x--]) + (cl:pointer-set! *x i (ro:apply mk-ty v))) + *x)) ... + (define-syntax- x + (make-rename-transformer (assign-type #'x- #'ty*))) ...)]] ;; real, scalar (ie non-vector) types [(_ ty:type x:id ...) ≫ #:when (real-type? #'ty.norm) diff --git a/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-tests.rkt b/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-tests.rkt new file mode 100644 index 0000000..9d5761c --- /dev/null +++ b/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-tests.rkt @@ -0,0 +1,100 @@ +#lang s-exp "../../../rosette/synthcl3.rkt" +(require "../../rackunit-typechecking.rkt" + (prefix-in cl: sdsl/synthcl/lang/main) + (prefix-in ro: (rename-in rosette [#%app a]))) + +(: int2 [3] xs) +xs +(check-type xs : int2*) + +(: int [4] xs2) +xs2 +(check-type xs2 : int*) + + + +;; ; The reference implementation for square matrix multiplication. +;; ; Multiplies two squre matrices A and B, where the dimension of A is +;; ; n x p and dimension of B is p x m. Both matrices are given as +;; ; flat arrays in row-major form. The output is the matrix C = A*B, +;; ; also given in row-major form. +;; (procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m]) +;; (: int* C) +;; (= C ((int*) (malloc (* n m (sizeof int))))) +;; (for [(: int i in (range n)) +;; (: int j in (range m)) +;; (: int k in (range p))] +;; (+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)]))) +;; C) + +; A host implementation of matrix multiplication. +(procedure int* (mmulHost [char* kernelName] [int typeLen] + [int* A] [int* B] [int n] [int p] [int m]) + ;; (: cl_context context) + ;; (: cl_command_queue command_queue) + ;; (: cl_program program) + ;; (: cl_kernel kernel) + ;; (: cl_mem buffer_A buffer_B buffer_C) + (: int* C) + ;; (: int[2] global) + ;; (: int dimA dimB dimC) + + ;; (= [global 0] (/ n typeLen)) + ;; (= [global 1] (/ m typeLen)) + ;; (= dimA (* n p (sizeof int))) + ;; (= dimB (* p m (sizeof int))) + ;; (= dimC (* n m (sizeof int))) + + ;; (= C ((int*) (malloc dimC))) + + ;; (= context (clCreateContext)) + + ;; (= command_queue (clCreateCommandQueue context)) + + ;; (= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA)) + ;; (= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB)) + ;; (= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC)) + + ;; (= program (clCreateProgramWithSource context "kernel.rkt")) + + ;; (clEnqueueWriteBuffer command_queue buffer_A 0 dimA A) + ;; (clEnqueueWriteBuffer command_queue buffer_B 0 dimB B) + + ;; (= kernel (clCreateKernel program kernelName)) + ;; (clSetKernelArg kernel 0 buffer_A) + ;; (clSetKernelArg kernel 1 buffer_B) + ;; (clSetKernelArg kernel 2 buffer_C) + ;; (clSetKernelArg kernel 3 p) + ;; (clSetKernelArg kernel 4 m) + + ;; (clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL) + ;; (clEnqueueReadBuffer command_queue buffer_C 0 dimC C) + C) +;; ; A scalar parallel implementation of matrix multiplication. +;; (procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m]) +;; (mmulHost "mmulScalarKernel" 1 A B n p m)) + +; A vector parallel implementation of matrix multiplication. The dimensions +; n and m must be evenly divisible by 4. +#;(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m]) + (mmulHost "mmulVectorKernel" 4 A B n p m)) + +;; ; Given two arrays of the same size, checks that they hold the same +;; ; values at each index. +;; (procedure void (check [int* actual] [int* expected] [int SIZE]) +;; (assert (>= SIZE 0)) +;; (for [(: int i in (range SIZE))] +;; (assert (== [actual i] [expected i])))) + +;; (procedure void (synth_vector [int size]) +;; (synth #:forall [(: int n in (range size (+ 1 size))) +;; (: int p in (range size (+ 1 size))) +;; (: int m in (range size (+ 1 size))) +;; (: int[(* n p)] A) +;; (: int[(* p m)] B)] +;; #:ensure (check (mmulVector A B n p m) +;; (mmulSequential A B n p m) +;; (* n m)))) + +;; ; (synth_vector 4) ; 20 sec +;; ; (synth_vector 8) ; 252 sec