diff --git a/turnstile/examples/rosette/rosette3.rkt b/turnstile/examples/rosette/rosette3.rkt index 63b917d..12aa27a 100644 --- a/turnstile/examples/rosette/rosette3.rkt +++ b/turnstile/examples/rosette/rosette3.rkt @@ -56,7 +56,7 @@ (local-expand e 'expression (list #'ro:#%app #'ro:choose #'ro:synthesize #'ro:let #'ro:in-value - #'ro:assert #'ro:if #'ro:?))) + #'ro:assert #'ro:if #'ro:? #'ro:verify))) ; (displayln (stx->datum e+)) e+) (define (mk-ro:-id id) (format-id id "ro:~a" id)) diff --git a/turnstile/examples/rosette/synthcl3.rkt b/turnstile/examples/rosette/synthcl3.rkt index 193fdad..738c6a8 100644 --- a/turnstile/examples/rosette/synthcl3.rkt +++ b/turnstile/examples/rosette/synthcl3.rkt @@ -3,15 +3,14 @@ (require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped racket/stxparam (prefix-in ro: (combine-in rosette rosette/lib/synthax)) - (prefix-in cl: (combine-in + (prefix-in cl: (combine-in (except-in sdsl/synthcl/model/operators /) sdsl/synthcl/lang/forms sdsl/synthcl/model/reals - sdsl/synthcl/model/operators sdsl/synthcl/model/errors + sdsl/synthcl/model/errors sdsl/synthcl/model/kernel sdsl/synthcl/model/memory sdsl/synthcl/model/runtime sdsl/synthcl/model/work sdsl/synthcl/model/pointers sdsl/synthcl/lang/queries sdsl/synthcl/model/context sdsl/synthcl/model/queue sdsl/synthcl/model/buffer - sdsl/synthcl/model/flags sdsl/synthcl/model/program - sdsl/synthcl/model/kernel)) + sdsl/synthcl/model/flags sdsl/synthcl/model/program)) (for-syntax (prefix-in cl: sdsl/synthcl/lang/util))) (begin-for-syntax @@ -24,12 +23,11 @@ int int2 int3 int4 int16 float float2 float3 float4 float16 bool void void* char* float* int* int2* int3* int4* int16* cl_context cl_command_queue cl_program cl_kernel cl_mem - : ! ?: == + * - || && + : ! 
?: == + * / - || && % << ; int ops - = += -= %= ; assignment ops + = += -= *= /= %= ; assignment ops sizeof clCreateProgramWithSource (typed-out - ;[with-output-to-string : (C→ (C→ Any) char*)] [clCreateContext : (C→ cl_context)] [clCreateCommandQueue : (C→ cl_context cl_command_queue)] [clCreateBuffer : (C→ cl_context int int cl_mem)] @@ -44,7 +42,6 @@ [get_global_id : (C→ int int)] [CL_MEM_READ_ONLY : int] [CL_MEM_WRITE_ONLY : int] - [/ : (Ccase-> (C→ int int int) (C→ float float float))] [malloc : (C→ int void*)] [get_work_dim : (C→ int)] [!= : (Ccase-> (C→ CNum CNum CBool) @@ -115,6 +112,7 @@ expr subexpr))) (define (mk-ptr id) (format-id id "~a*" id)) (define (mk-mk id) (format-id id "mk-~a" id)) + (define (mk-to id) (format-id id "to-~a" id)) (define (add-convert stx fn) (set-stx-prop/preserved stx 'convert fn)) (define (get-convert stx) @@ -202,10 +200,10 @@ (ro:define (to-intn v) (ro:cond [(ro:list? v) - (ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:list-ref v i))))] + (ro:apply mk-intn (ro:for/list ([i n]) (to-int (ro:list-ref v i))))] [(ro:vector? v) - (ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:vector-ref v i))))] - [else (ro:apply mk-intn (ro:make-list n (ro:#%app to-int v)))])) + (ro:apply mk-intn (ro:for/list ([i n]) (to-int (ro:vector-ref v i))))] + [else (ro:apply mk-intn (ro:make-list n (to-int v)))])) (ro:define (to-intn* v) (cl:pointer-cast v cl-mk-intn)) (ro:define (mk-intn x ...) (ro:#%app cl-mk-intn x ...)))])) (define-simple-macro (define-ints n ...) (begin (define-int n) ...)) @@ -227,10 +225,10 @@ (ro:define (to-floatn v) (ro:cond [(ro:list? v) - (ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:list-ref v i))))] + (ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:list-ref v i))))] [(ro:vector? 
v) - (ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:vector-ref v i))))] - [else (ro:apply mk-floatn (ro:make-list n (ro:#%app to-float v)))])) + (ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:vector-ref v i))))] + [else (ro:apply mk-floatn (ro:make-list n (to-float v)))])) (ro:define (mk-floatn x ...) (ro:#%app cl-mk-floatn x ...)))])) (define-simple-macro (define-floats n ...) (begin (define-float n) ...)) (define-floats 2 3 4 16) @@ -346,7 +344,7 @@ (define-typed-syntax if [(_ test {then ...} {else ...}) ≫ -------- - [⊢ (ro:if (ro:#%app to-bool test) + [⊢ (ro:if (to-bool test) (ro:let () then ... (ro:void)) (ro:let () else ... (ro:void))) ⇒ void]] [(_ test {then ...}) ≫ @@ -433,7 +431,7 @@ [≻ (begin- (ro:define-symbolic* x-- pred [len base-len]) ... (ro:define x- - (ro:let ([*x (ro:#%app to-ty* (cl:malloc (ro:* len base-len)))]) + (ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))]) (ro:for ([i len][v x--]) (cl:pointer-set! *x i (ro:apply mk-ty v))) *x)) ... @@ -482,12 +480,12 @@ #:with ty-base ((current-type-eval) (datum->syntax #'e (string->symbol out-base-str))) #:with base-convert (get-convert #'ty-base) ------- - [⊢ (ro:#%app convert - (ro:let ([a (ro:#%app convert e-)][b (ro:#%app convert e1-)][c (ro:#%app convert e2-)]) + [⊢ (convert + (ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)]) (ro:for/list ([idx #,(string->number out-len-str)]) (ro:if (ro:< (ro:vector-ref a idx) 0) - (ro:#%app base-convert (ro:vector-ref b idx)) - (ro:#%app base-convert (ro:vector-ref c idx)))))) + (base-convert (ro:vector-ref b idx)) + (base-convert (ro:vector-ref c idx)))))) ⇒ ty-out]] [(_ ~! e e1 e2) ≫ ; should be scalar and real [⊢ e ≫ e- ⇒ ty] @@ -514,22 +512,21 @@ (format "cannot cast ~a to ~a" (type->str #'ty-e) (type->str #'ty-x)) #:with conv (get-convert #'ty-x) -------- - [⊢ (ro:set! x- #,(if (syntax-e #'conv) #'(ro:#%app conv e-) #'e-)) ⇒ void]] + [⊢ (ro:set! 
x- #,(if (syntax-e #'conv) #'(conv e-) #'e-)) ⇒ void]] ;; selector can be list of numbers or up to wxyz for vectors of length <=4 [(_ [x:id sel] e) ≫ [⊢ x ≫ x- ⇒ ty-x] [⊢ e ≫ e- ⇒ ty-e] #:with out-e (if (pointer-type? #'ty-x) - #'(ro:begin - (cl:pointer-set! x- sel e-) - x-) + (with-syntax ([conv (mk-to (get-pointer-base #'ty-x))]) + #'(ro:begin (cl:pointer-set! x- sel (conv e-)) x-)) (with-syntax ([selector (cl:parse-selector #f #'sel stx)]) - #`(ro:let ([out (ro:vector-copy x-)]) - #,(if (= 1 (length (stx->list #'selector))) + #`(ro:let ([out (ro:vector-copy x-)]) + #,(if (= 1 (length (stx->list #'selector))) #`(ro:vector-set! out (car 'selector) e-) #'(ro:for ([idx 'selector] [v e-]) (ro:vector-set! out idx v))) - out))) + out))) ; TODO: need mk-ty here? -------- [⊢ (ro:set! x- out-e) ⇒ void]]) @@ -541,7 +538,7 @@ [(_ e) ≫ ; else try to coerce [⊢ e ≫ e- ⇒ ty] -------- - [⊢ (ro:#%app cl:! (ro:#%app to-bool e-)) ⇒ bool]]) + [⊢ (ro:#%app cl:! (to-bool e-)) ⇒ bool]]) ;; TODO: this should produce int-vector result? (define-typed-syntax == @@ -552,7 +549,7 @@ #:when (real-type? #'ty2) #:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len -------- - [⊢ (ro:#%app to-int (cl:== e1- e2-)) ⇒ int]]) + [⊢ (to-int (cl:== e1- e2-)) ⇒ int]]) (define-simple-macro (define-bool-ops o ...+) (ro:begin (define-bool-op o) ...)) (define-simple-macro (define-bool-op name) @@ -564,7 +561,12 @@ -------- [⊢ (name- e1- e2-) ⇒ bool]] [(_ e1 e2) ≫ ; else try to coerce - --- [⊢ (name- (ro:#%app to-bool e1) (ro:#%app to-bool e2)) ⇒ bool]])) + --- [⊢ (name- (to-bool e1) (to-bool e2)) ⇒ bool]])) + +(define- (cl:/ x y) + (cond- [(zero?- y) 0] + [(integer?- x) (quotient- x y)] + [else (/- x y)])) (define-simple-macro (define-real-ops o ...) (ro:begin (define-real-op o) ...)) (define-simple-macro (define-real-op name (~optional (~seq #:extra-check p?) @@ -585,10 +587,10 @@ #:with (x (... ...)) (generate-temporaries #'(e (... ...))) -------- [⊢ #,(if (scalar-type? 
#'ty-out) - #'(ro:#%app convert (name- (convert e-) (... ...))) - #'(ro:#%app convert (ro:let ([x (ro:#%app convert e-)] (... ...)) + #'(convert (name- (convert e-) (... ...))) + #'(convert (ro:let ([x (convert e-)] (... ...)) (ro:for/list ([x x] (... ...)) - (ro:#%app base-convert (name- x (... ...))))))) ⇒ ty-out]) + (base-convert (name- x (... ...))))))) ⇒ ty-out]) (define-typed-syntax (name= x e) ≫ --- [≻ (= x (name x e))]))) @@ -601,14 +603,11 @@ (define-simple-macro (define-int-ops o ...) (ro:begin (define-int-op o) ...)) (define-bool-ops || &&) -(define-real-ops + * -) +(define-real-ops + * - /) (define-int-ops % <<) -(define-typerule (sizeof t:type) >> - --- [⊢ #,(real-type-length #'t.norm) ⇒ int]) - -(define-typerule (print e ...) >> - --- [⊢ (ro:begin (display e) ...) ⇒ void]) +(define-typerule (sizeof t:type) >> ---[⊢ #,(real-type-length #'t.norm) ⇒ int]) +(define-typerule (print e ...) >> ---[⊢ (ro:begin (display e) ...) ⇒ void]) (define-typed-syntax choose [(ch e ...+) ≫ @@ -656,10 +655,8 @@ [(_ #:forall [decl ...] #:ensure e) ≫ --- [≻ (synth #:forall [decl ...] #:bitwidth 8 #:ensure e)]]) - - (define-typed-syntax verify - [(vfy #:forall [decl ...] #:ensure e) ≫ + [(_ #:forall [decl ...] #:ensure e) ≫ #:with ([id seq ty] ...) (stx-map decl->seq #'(decl ...)) #:with (id- ...) (generate-temporaries #'(id ...)) #:with (typed-seq ...) #'((with-ctx ([id id- ty] ...) seq) ...) @@ -667,20 +664,17 @@ [⊢ (ro:let ([id- 1] ...) ; dummy, enables simplifying stx template (ro:parameterize ([ro:current-bitwidth 32] [ro:term-cache (ro:hash-copy (ro:term-cache))]) - (ro:for*/or ([id- typed-seq] ...) - (ro:define cex (with-ctx ([id id- ty] ...) (ro:verify e))) - (ro:and (ro:sat? cex) - (displayln "counterexample found:") - (ro:for ([i '(id ...)] [i- (ro:list id- ...)]) - (printf "~a = ~a\n" i (ro:evaluate i- cex))))))) ⇒ void]]) + (ro:or (ro:for*/or ([id- typed-seq] ...) + (ro:define cex (with-ctx ([id id- ty] ...) (ro:verify e))) + (ro:and (ro:sat? 
cex) + (displayln "counterexample found:") + (ro:for ([i '(id ...)] [i- (ro:list id- ...)]) + (printf "~a = ~a\n" i (ro:evaluate i- cex))) + cex)) + (begin (displayln "no counterexample found") (ro:unsat))))) ⇒ void]]) (define-typed-syntax (assert e) ≫ - --- [⊢ (ro:assert (ro:#%app to-bool #,(expand/ro #'e))) ⇒ void]) - -(define- (/ x y) - (cond- [(zero?- y) 0] - [(integer?- x) (quotient- x y)] - [else (/- x y)])) + --- [⊢ (ro:assert (to-bool #,(expand/ro #'e))) ⇒ void]) (define-typed-syntax (clCreateProgramWithSource ctx f) ≫ --- [⊢ (cl:clCreateProgramWithSource ctx f) ⇒ cl_program]) diff --git a/turnstile/examples/tests/rosette/rosette3/matrix-verify-kernel.rkt b/turnstile/examples/tests/rosette/rosette3/matrix-verify-kernel.rkt new file mode 100644 index 0000000..9f0dfb3 --- /dev/null +++ b/turnstile/examples/tests/rosette/rosette3/matrix-verify-kernel.rkt @@ -0,0 +1,127 @@ +#lang s-exp "../../../rosette/synthcl3.rkt" + +; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix. +(kernel void (mmulScalarKernel [int* A] [int* B] [int* C] [int p] [int m]) + (: int i j sum) + (= i (get_global_id 0)) + (= j (get_global_id 1)) + (= sum 0) + (for [(: int k in (range p))] + (+= sum (* [A (+ (* i p) k)] [B (+ (* k m) j)]))) + (= [C (+ (* i m) j)] sum)) + +;;--------------- Vectorized kernel ---------------;; + +; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix. 
+(kernel void (mmulVectorKernel [int4* A] [int4* B] [int4* C] [int p] [int m])
+  (: int i j)                    ; set from get_global_id 0 and 1 below
+  (: int4 sum0 sum1 sum2 sum3)   ; accumulators, stored to C at offsets 0..3 at the end
+
+  (= i (get_global_id 0))
+  (= j (get_global_id 1))
+  (= sum0 0)
+  (= sum1 0)
+  (= sum2 0)
+  (= sum3 0)
+
+  (for [(: int k in (range 0 p 4))]   ; k advances from 0 toward p in steps of 4
+    (: int4 a0 a1 a2 a3)
+    (: int4 b0 b1 b2 b3)
+
+    (= a0 [A (indexA 0 i k p)])
+    (= a1 [A (indexA 1 i k p)])
+    (= a2 [A (indexA 2 i k p)])
+    (= a3 [A (indexA 3 i k p)])
+
+    (= b0 [B (indexB 0 k j m)])
+    (= b1 [B (indexB 1 k j m)])
+    (= b2 [B (indexB 2 k j m)])
+    (= b3 [B (indexB 3 k j m)])
+
+    (+= sum0 (computeSum a0 b0 b1 b2 b3))
+    (+= sum1 (computeSum a1 b0 b1 b2 b3))
+    (+= sum2 (computeSum a2 b0 b1 b2 b3))
+    (+= sum3 (computeSum a3 b0 b1 b2 b3)))
+
+  (= [C (indexC 0 i j m)] sum0)
+  (= [C (indexC 1 i j m)] sum1)
+  (= [C (indexC 2 i j m)] sum2)
+  (= [C (indexC 3 i j m)] sum3))
+
+; Multiplies the 1x4 vector a by the 4x4 matrix with rows b0, b1, b2 and b3.
+(procedure int4 (computeSum [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
+  (int4
+   (+ (* [a x] [b0 x]) (* [a y] [b1 x]) (* [a z] [b2 x]) (* [a w] [b3 x]))
+   (+ (* [a x] [b0 y]) (* [a y] [b1 y]) (* [a z] [b2 y]) (* [a w] [b3 y]))
+   (+ (* [a x] [b0 z]) (* [a y] [b1 z]) (* [a z] [b2 z]) (* [a w] [b3 z]))
+   (+ (* [a x] [b0 w]) (* [a y] [b1 w]) (* [a z] [b2 w]) (* [a w] [b3 w]))))
+; Offset used to read A: (4*i + off) * (p/4) + k/4.
+(procedure int (indexA [int off] [int i] [int k] [int p])
+  (+ (* (+ (* i 4) off) (/ p 4)) (/ k 4)))
+; Offset used to read B: (k + off) * (m/4) + j.
+(procedure int (indexB [int off] [int k] [int j] [int m])
+  (+ (* (+ k off) (/ m 4)) j))
+; Offset used to write C: (4*i + off) * (m/4) + j.
+(procedure int (indexC [int off] [int i] [int j] [int m])
+  (+ (* (+ (* i 4) off) (/ m 4)) j))
+
+
+;;; ---------------- Optimized kernel implementation transcribed from AMD's apps ---------------- ;;;
+(: int TILEX TILEX_SHIFT TILEY TILEY_SHIFT)
+(= TILEX 4)
+(= TILEX_SHIFT 2)   ; presumably log2 of TILEX (4 = 1 << 2); every declared constant now gets a value
+(= TILEY 4)
+(= TILEY_SHIFT 2)   ; presumably log2 of TILEY (4 = 1 << 2)
+
+; Matrix multiplication C = A * B, where A is an n x p matrix and B is a
+(kernel void (mmulVectorKernelOpt [int4* A] [int4* B] [int4* C] [int p] [int m]) + (: int2 pos) + (: int4 sum0 sum1 sum2 sum3) + + (= pos (int2 (get_global_id 0) (get_global_id 1))) + (= sum0 0) + (= sum1 0) + (= sum2 0) + (= sum3 0) + + (/= m 4) + + (for [(: int i in (range 0 p 4))] + (: int4 a0 a1 a2 a3) + (: int4 b0 b1 b2 b3) + + (= a0 [A (+ (/ i 4) (* (<< [pos x] TILEY_SHIFT) (/ p 4)))]) + (= a1 [A (+ (/ i 4) (* (+ (<< [pos x] TILEY_SHIFT) 1) (/ p 4)))]) + (= a2 [A (+ (/ i 4) (* (+ (<< [pos x] TILEY_SHIFT) 2) (/ p 4)))]) + (= a3 [A (+ (/ i 4) (* (+ (<< [pos x] TILEY_SHIFT) 3) (/ p 4)))]) + + (= b0 [B (+ [pos y] (* i m))]) + (= b1 [B (+ [pos y] (* (+ i 1) m))]) + (= b2 [B (+ [pos y] (* (+ i 2) m))]) + (= b3 [B (+ [pos y] (* (+ i 3) m))]) + + (+= [sum0 x] (+ (* [a0 x] [b0 x]) (* [a0 y] [b1 x]) (* [a0 z] [b2 x]) (* [a0 w] [b3 x]))) + (+= [sum0 y] (+ (* [a0 x] [b0 y]) (* [a0 y] [b1 y]) (* [a0 z] [b2 y]) (* [a0 w] [b3 y]))) + (+= [sum0 z] (+ (* [a0 x] [b0 z]) (* [a0 y] [b1 z]) (* [a0 z] [b2 z]) (* [a0 w] [b3 z]))) + (+= [sum0 w] (+ (* [a0 x] [b0 w]) (* [a0 y] [b1 w]) (* [a0 z] [b2 w]) (* [a0 w] [b3 w]))) + + (+= [sum1 x] (+ (* [a1 x] [b0 x]) (* [a1 y] [b1 x]) (* [a1 z] [b2 x]) (* [a1 w] [b3 x]))) + (+= [sum1 y] (+ (* [a1 x] [b0 y]) (* [a1 y] [b1 y]) (* [a1 z] [b2 y]) (* [a1 w] [b3 y]))) + (+= [sum1 z] (+ (* [a1 x] [b0 z]) (* [a1 y] [b1 z]) (* [a1 z] [b2 z]) (* [a1 w] [b3 z]))) + (+= [sum1 w] (+ (* [a1 x] [b0 w]) (* [a1 y] [b1 w]) (* [a1 z] [b2 w]) (* [a1 w] [b3 w]))) + + (+= [sum2 x] (+ (* [a2 x] [b0 x]) (* [a2 y] [b1 x]) (* [a2 z] [b2 x]) (* [a2 w] [b3 x]))) + (+= [sum2 y] (+ (* [a2 x] [b0 y]) (* [a2 y] [b1 y]) (* [a2 z] [b2 y]) (* [a2 w] [b3 y]))) + (+= [sum2 z] (+ (* [a2 x] [b0 z]) (* [a2 y] [b1 z]) (* [a2 z] [b2 z]) (* [a2 w] [b3 z]))) + (+= [sum2 w] (+ (* [a2 x] [b0 w]) (* [a2 y] [b1 w]) (* [a2 z] [b2 w]) (* [a2 w] [b3 w]))) + + (+= [sum3 x] (+ (* [a3 x] [b0 x]) (* [a3 y] [b1 x]) (* [a3 z] [b2 x]) (* [a3 w] [b3 x]))) + (+= [sum3 y] (+ (* [a3 x] [b0 y]) (* [a3 
y] [b1 y]) (* [a3 z] [b2 y]) (* [a3 w] [b3 y]))) + (+= [sum3 z] (+ (* [a3 x] [b0 z]) (* [a3 y] [b1 z]) (* [a3 z] [b2 z]) (* [a3 w] [b3 z]))) + (+= [sum3 w] (+ (* [a3 x] [b0 w]) (* [a3 y] [b1 w]) (* [a3 z] [b2 w]) (* [a3 w] [b3 w])))) + + (= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 0) m))] sum0) + (= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 1) m))] sum1) + (= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 2) m))] sum2) + (= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 3) m))] sum3)) diff --git a/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt b/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt index 1a03093..cd69a2d 100644 --- a/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt +++ b/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt @@ -20,8 +20,10 @@ (do-tests "bv-tests.rkt" "BV SDSL - General" "fsm3-tests.rkt" "FSM" - "ifc3-tests.rkt" "IFC" - "synthcl3-tests.rkt" "SynthCL" - "synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth") + "ifc3-tests.rkt" "IFC") +(do-tests "synthcl3-tests.rkt" "SynthCL" + "synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth" + "synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify" + "synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify") (do-tests "bv-ref-tests.rkt" "BV SDSL - Hacker's Delight synthesis") diff --git a/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-verify-buggy-tests.rkt b/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-verify-buggy-tests.rkt new file mode 100644 index 0000000..5a18df0 --- /dev/null +++ b/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-verify-buggy-tests.rkt @@ -0,0 +1,126 @@ +#lang s-exp "../../../rosette/synthcl3.rkt" +(require "../../rackunit-typechecking.rkt") + +; A buggy reference implementation for square matrix multiplication. 
+; Multiplies two square matrices A and B, where the dimension of A is
+; n x p and dimension of B is p x m. Both matrices are given as
+; flat arrays in row-major form. The output is the matrix C = A*B,
+; also given in row-major form.
+(procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m])
+  (: int* C)
+  (= C ((int*) (malloc (* n m (sizeof int)))))
+  (for [(: int i in (range n))
+        (: int j in (range m))
+        (: int k in (range 1 p))] ; seeded bug: k starts at 1, so the k = 0 products are never added
+    (+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
+  C)
+
+; A host implementation of matrix multiplication.
+(procedure int* (mmulHost [char* kernelName] [int typeLen]
+                          [int* A] [int* B] [int n] [int p] [int m])
+  (: cl_context context)
+  (: cl_command_queue command_queue)
+  (: cl_program program)
+  (: cl_kernel kernel)
+  (: cl_mem buffer_A buffer_B buffer_C)
+  (: int* C)
+  (: int[2] global)
+  (: int dimA dimB dimC)
+  ; global holds n/typeLen and m/typeLen; it is handed to clEnqueueNDRangeKernel below.
+  (= [global 0] (/ n typeLen))
+  (= [global 1] (/ m typeLen))
+  (= dimA (* n p (sizeof int)))
+  (= dimB (* p m (sizeof int)))
+  (= dimC (* n m (sizeof int)))
+  ; C is the array returned to the caller; it is allocated with the same size as buffer_C.
+  (= C ((int*) (malloc dimC)))
+
+  (= context (clCreateContext))
+
+  (= command_queue (clCreateCommandQueue context))
+  ; A and B buffers are created with CL_MEM_READ_ONLY, the C buffer with CL_MEM_WRITE_ONLY.
+  (= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
+  (= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
+  (= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
+  ; The program is built from the kernel file; kernelName selects one of its kernels below.
+  (= program (clCreateProgramWithSource context "matrix-verify-kernel.rkt"))
+  ; NOTE(review): presumably copies host arrays A and B into their buffers -- confirm against the runtime.
+  (clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
+  (clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
+  ; Arguments 0..4 are buffer_A, buffer_B, buffer_C, p, m.
+  (= kernel (clCreateKernel program kernelName))
+  (clSetKernelArg kernel 0 buffer_A)
+  (clSetKernelArg kernel 1 buffer_B)
+  (clSetKernelArg kernel 2 buffer_C)
+  (clSetKernelArg kernel 3 p)
+  (clSetKernelArg kernel 4 m)
+  ; NOTE(review): presumably launches the kernel over `global`, then copies buffer_C into C -- confirm.
+  (clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
+  (clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
+  C)
+; A scalar parallel implementation of matrix multiplication.
+(procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m]) + (mmulHost "mmulScalarKernel" 1 A B n p m)) + +; A vector parallel implementation of matrix multiplication. The dimensions +; n and m must be evenly divisible by 4. +(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m]) + (mmulHost "mmulVectorKernel" 4 A B n p m)) + +; An optimized vector parallel implementation of matrix multiplication. The dimensions +; n and m must be evenly divisible by 4. +(procedure int* (mmulVectorOpt [int* A] [int* B] [int n] [int p] [int m]) + (mmulHost "mmulVectorKernelOpt" 4 A B n p m)) + +; Given two arrays of the same size, checks that they hold the same +; values at each index. +(procedure void (check [int* actual] [int* expected] [int SIZE]) + (assert (>= SIZE 0)) + (for [(: int i in (range SIZE))] + (assert (== [actual i] [expected i])))) + +(procedure void (verify_scalar [int from] [int to]) + (verify #:forall [(: int n in (range from to)) + (: int p in (range from to)) + (: int m in (range from to)) + (: int[(* n p)] A) + (: int[(* p m)] B)] + #:ensure (check (mmulScalar A B n p m) + (mmulSequential A B n p m) + (* n m)))) + +(procedure void (verify_vector [int from] [int to]) + (verify #:forall [(: int n in (range from to 4)) + (: int p in (range from to 4)) + (: int m in (range from to 4)) + (: int[(* n p)] A) + (: int[(* p m)] B)] + #:ensure (check (mmulVector A B n p m) + (mmulSequential A B n p m) + (* n m)))) + +(procedure void (verify_vector_opt [int from] [int to]) + (verify #:forall [(: int n in (range from to 4)) + (: int p in (range from to 4)) + (: int m in (range from to 4)) + (: int[(* n p)] A) + (: int[(* p m)] B)] + #:ensure (check (mmulVectorOpt A B n p m) + (mmulSequential A B n p m) + (* n m)))) +(check-type + (with-output-to-string (λ () (verify_scalar 1 5))) + : CString -> "counterexample found:\nn = 1\np = 1\nm = 1\nA = #x0#(-1)\nB = #x1#(1)\n") +(check-type + (with-output-to-string (λ () (verify_vector 4 9))) + : 
: CString -> "counterexample found:\nn = 4\np = 4\nm = 4\nA = #x5#(3 0 0 0 3 0 0 0 98355 0 0 0 98307 0 0 0)\nB = #x6#(-1431655765 -1431655765 1431661227 0 0 0 0 0 0 0 0 0 0 0 0 0)\n") +(check-type + (with-output-to-string (λ () (verify_vector_opt 4 9))) + : CString -> "counterexample found:\nn = 4\np = 4\nm = 4\nA = #xa#(3 0 0 0 3 0 0 0 98355 0 0 0 98307 0 0 0)\nB = #xb#(-1431655765 -1431655765 1431661227 0 0 0 0 0 0 0 0 0 0 0 0 0)\n") + +;(: int n p m) +;(= n 8) (= p 4) (= m 4) +;(: int[(* n p)] A) (: int[(* p m)] B) +;(mmulVector A B n p m) + + diff --git a/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-verify-tests.rkt b/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-verify-tests.rkt new file mode 100644 index 0000000..2be467d --- /dev/null +++ b/turnstile/examples/tests/rosette/rosette3/synthcl3-matrix-verify-tests.rkt @@ -0,0 +1,125 @@ +#lang s-exp "../../../rosette/synthcl3.rkt" +(require "../../rackunit-typechecking.rkt") + +; The reference implementation for square matrix multiplication. +; Multiplies two square matrices A and B, where the dimension of A is +; n x p and dimension of B is p x m. Both matrices are given as +; flat arrays in row-major form. The output is the matrix C = A*B, +; also given in row-major form. +(procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m]) + (: int* C) + (= C ((int*) (malloc (* n m (sizeof int))))) + (for [(: int i in (range n)) + (: int j in (range m)) + (: int k in (range p))] + (+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)]))) + C) + +; A host implementation of matrix multiplication. 
+(procedure int* (mmulHost [char* kernelName] [int typeLen]
+                          [int* A] [int* B] [int n] [int p] [int m])
+  (: cl_context context)
+  (: cl_command_queue command_queue)
+  (: cl_program program)
+  (: cl_kernel kernel)
+  (: cl_mem buffer_A buffer_B buffer_C)
+  (: int* C)
+  (: int[2] global)
+  (: int dimA dimB dimC)
+  ; global holds n/typeLen and m/typeLen; it is handed to clEnqueueNDRangeKernel below.
+  (= [global 0] (/ n typeLen))
+  (= [global 1] (/ m typeLen))
+  (= dimA (* n p (sizeof int)))
+  (= dimB (* p m (sizeof int)))
+  (= dimC (* n m (sizeof int)))
+  ; C is the array returned to the caller; it is allocated with the same size as buffer_C.
+  (= C ((int*) (malloc dimC)))
+
+  (= context (clCreateContext))
+
+  (= command_queue (clCreateCommandQueue context))
+  ; A and B buffers are created with CL_MEM_READ_ONLY, the C buffer with CL_MEM_WRITE_ONLY.
+  (= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
+  (= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
+  (= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
+  ; The program is built from the kernel file; kernelName selects one of its kernels below.
+  (= program (clCreateProgramWithSource context "matrix-verify-kernel.rkt"))
+  ; NOTE(review): presumably copies host arrays A and B into their buffers -- confirm against the runtime.
+  (clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
+  (clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
+  ; Arguments 0..4 are buffer_A, buffer_B, buffer_C, p, m.
+  (= kernel (clCreateKernel program kernelName))
+  (clSetKernelArg kernel 0 buffer_A)
+  (clSetKernelArg kernel 1 buffer_B)
+  (clSetKernelArg kernel 2 buffer_C)
+  (clSetKernelArg kernel 3 p)
+  (clSetKernelArg kernel 4 m)
+  ; NOTE(review): presumably launches the kernel over `global`, then copies buffer_C into C -- confirm.
+  (clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
+  (clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
+  C)
+; A scalar parallel implementation of matrix multiplication.
+(procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m])
+  (mmulHost "mmulScalarKernel" 1 A B n p m))   ; typeLen = 1
+
+; A vector parallel implementation of matrix multiplication. The dimensions
+; n and m must be evenly divisible by 4.
+(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m])
+  (mmulHost "mmulVectorKernel" 4 A B n p m))   ; typeLen = 4
+
+; An optimized vector parallel implementation of matrix multiplication. The dimensions
+; n and m must be evenly divisible by 4.
+(procedure int* (mmulVectorOpt [int* A] [int* B] [int n] [int p] [int m])
+  (mmulHost "mmulVectorKernelOpt" 4 A B n p m))   ; typeLen = 4
+
+; Given two arrays of the same size, checks that they hold the same
+; values at each index.
+(procedure void (check [int* actual] [int* expected] [int SIZE])
+  (assert (>= SIZE 0))
+  (for [(: int i in (range SIZE))]
+    (assert (== [actual i] [expected i]))))
+; Checks mmulScalar against mmulSequential for n, p, m drawn from (range from to).
+(procedure void (verify_scalar [int from] [int to])
+  (verify #:forall [(: int n in (range from to))
+                    (: int p in (range from to))
+                    (: int m in (range from to))
+                    (: int[(* n p)] A)
+                    (: int[(* p m)] B)]
+          #:ensure (check (mmulScalar A B n p m)
+                          (mmulSequential A B n p m)
+                          (* n m))))
+; Checks mmulVector against mmulSequential for n, p, m drawn from (range from to 4).
+(procedure void (verify_vector [int from] [int to])
+  (verify #:forall [(: int n in (range from to 4))
+                    (: int p in (range from to 4))
+                    (: int m in (range from to 4))
+                    (: int[(* n p)] A)
+                    (: int[(* p m)] B)]
+          #:ensure (check (mmulVector A B n p m)
+                          (mmulSequential A B n p m)
+                          (* n m))))
+; Checks mmulVectorOpt against mmulSequential for n, p, m drawn from (range from to 4).
+(procedure void (verify_vector_opt [int from] [int to])
+  (verify #:forall [(: int n in (range from to 4))
+                    (: int p in (range from to 4))
+                    (: int m in (range from to 4))
+                    (: int[(* n p)] A)
+                    (: int[(* p m)] B)]
+          #:ensure (check (mmulVectorOpt A B n p m)
+                          (mmulSequential A B n p m)
+                          (* n m))))
+(check-type   ; each test captures the text printed by a verify_* call and compares it as a string
+ (with-output-to-string (λ () (verify_scalar 1 5)))
+ : CString -> "no counterexample found\n")
+(check-type
+ (with-output-to-string (λ () (verify_vector 4 9)))
+ : CString -> "no counterexample found\n")
+(check-type
+ (with-output-to-string (λ () (verify_vector_opt 4 9)))
+ : CString -> "no counterexample found\n")
+
+;(: int n p m)
+;(= n 8) (= p 4) (= m 4)
+;(: int[(* n p)] A) (: int[(* p m)] B)
+;(mmulVector A B n p m)
+