synthcl3: fix pointer deref assignment bug; all matrix mult tests passing!

This commit is contained in:
Stephen Chang 2016-12-02 16:25:35 -05:00
parent 57bf9a5543
commit e0a2900c77
6 changed files with 432 additions and 58 deletions

View File

@ -56,7 +56,7 @@
(local-expand
e 'expression
(list #'ro:#%app #'ro:choose #'ro:synthesize #'ro:let #'ro:in-value
#'ro:assert #'ro:if #'ro:?)))
#'ro:assert #'ro:if #'ro:? #'ro:verify)))
; (displayln (stx->datum e+))
e+)
(define (mk-ro:-id id) (format-id id "ro:~a" id))

View File

@ -3,15 +3,14 @@
(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped
racket/stxparam
(prefix-in ro: (combine-in rosette rosette/lib/synthax))
(prefix-in cl: (combine-in
(prefix-in cl: (combine-in (except-in sdsl/synthcl/model/operators /)
sdsl/synthcl/lang/forms sdsl/synthcl/model/reals
sdsl/synthcl/model/operators sdsl/synthcl/model/errors
sdsl/synthcl/model/errors sdsl/synthcl/model/kernel
sdsl/synthcl/model/memory sdsl/synthcl/model/runtime
sdsl/synthcl/model/work sdsl/synthcl/model/pointers
sdsl/synthcl/lang/queries sdsl/synthcl/model/context
sdsl/synthcl/model/queue sdsl/synthcl/model/buffer
sdsl/synthcl/model/flags sdsl/synthcl/model/program
sdsl/synthcl/model/kernel))
sdsl/synthcl/model/flags sdsl/synthcl/model/program))
(for-syntax (prefix-in cl: sdsl/synthcl/lang/util)))
(begin-for-syntax
@ -24,12 +23,11 @@
int int2 int3 int4 int16 float float2 float3 float4 float16
bool void void* char* float* int* int2* int3* int4* int16*
cl_context cl_command_queue cl_program cl_kernel cl_mem
: ! ?: == + * - || &&
: ! ?: == + * / - || &&
% << ; int ops
= += -= %= ; assignment ops
= += -= *= /= %= ; assignment ops
sizeof clCreateProgramWithSource
(typed-out
;[with-output-to-string : (C→ (C→ Any) char*)]
[clCreateContext : (C→ cl_context)]
[clCreateCommandQueue : (C→ cl_context cl_command_queue)]
[clCreateBuffer : (C→ cl_context int int cl_mem)]
@ -44,7 +42,6 @@
[get_global_id : (C→ int int)]
[CL_MEM_READ_ONLY : int]
[CL_MEM_WRITE_ONLY : int]
[/ : (Ccase-> (C→ int int int) (C→ float float float))]
[malloc : (C→ int void*)]
[get_work_dim : (C→ int)]
[!= : (Ccase-> (C→ CNum CNum CBool)
@ -115,6 +112,7 @@
expr subexpr)))
(define (mk-ptr id) (format-id id "~a*" id))
(define (mk-mk id) (format-id id "mk-~a" id))
(define (mk-to id) (format-id id "to-~a" id))
(define (add-convert stx fn)
(set-stx-prop/preserved stx 'convert fn))
(define (get-convert stx)
@ -202,10 +200,10 @@
(ro:define (to-intn v)
(ro:cond
[(ro:list? v)
(ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:list-ref v i))))]
(ro:apply mk-intn (ro:for/list ([i n]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:vector-ref v i))))]
[else (ro:apply mk-intn (ro:make-list n (ro:#%app to-int v)))]))
(ro:apply mk-intn (ro:for/list ([i n]) (to-int (ro:vector-ref v i))))]
[else (ro:apply mk-intn (ro:make-list n (to-int v)))]))
(ro:define (to-intn* v) (cl:pointer-cast v cl-mk-intn))
(ro:define (mk-intn x ...) (ro:#%app cl-mk-intn x ...)))]))
(define-simple-macro (define-ints n ...) (begin (define-int n) ...))
@ -227,10 +225,10 @@
(ro:define (to-floatn v)
(ro:cond
[(ro:list? v)
(ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:list-ref v i))))]
(ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:vector-ref v i))))]
[else (ro:apply mk-floatn (ro:make-list n (ro:#%app to-float v)))]))
(ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:vector-ref v i))))]
[else (ro:apply mk-floatn (ro:make-list n (to-float v)))]))
(ro:define (mk-floatn x ...) (ro:#%app cl-mk-floatn x ...)))]))
(define-simple-macro (define-floats n ...) (begin (define-float n) ...))
(define-floats 2 3 4 16)
@ -346,7 +344,7 @@
(define-typed-syntax if
[(_ test {then ...} {else ...})
--------
[ (ro:if (ro:#%app to-bool test)
[ (ro:if (to-bool test)
(ro:let () then ... (ro:void))
(ro:let () else ... (ro:void))) void]]
[(_ test {then ...})
@ -433,7 +431,7 @@
[ (begin-
(ro:define-symbolic* x-- pred [len base-len]) ...
(ro:define x-
(ro:let ([*x (ro:#%app to-ty* (cl:malloc (ro:* len base-len)))])
(ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))])
(ro:for ([i len][v x--])
(cl:pointer-set! *x i (ro:apply mk-ty v)))
*x)) ...
@ -482,12 +480,12 @@
#:with ty-base ((current-type-eval) (datum->syntax #'e (string->symbol out-base-str)))
#:with base-convert (get-convert #'ty-base)
-------
[ (ro:#%app convert
(ro:let ([a (ro:#%app convert e-)][b (ro:#%app convert e1-)][c (ro:#%app convert e2-)])
[ (convert
(ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)])
(ro:for/list ([idx #,(string->number out-len-str)])
(ro:if (ro:< (ro:vector-ref a idx) 0)
(ro:#%app base-convert (ro:vector-ref b idx))
(ro:#%app base-convert (ro:vector-ref c idx))))))
(base-convert (ro:vector-ref b idx))
(base-convert (ro:vector-ref c idx))))))
ty-out]]
[(_ ~! e e1 e2) ; should be scalar and real
[ e e- ty]
@ -514,22 +512,21 @@
(format "cannot cast ~a to ~a" (type->str #'ty-e) (type->str #'ty-x))
#:with conv (get-convert #'ty-x)
--------
[ (ro:set! x- #,(if (syntax-e #'conv) #'(ro:#%app conv e-) #'e-)) void]]
[ (ro:set! x- #,(if (syntax-e #'conv) #'(conv e-) #'e-)) void]]
;; selector can be list of numbers or up to wxyz for vectors of length <=4
[(_ [x:id sel] e)
[ x x- ty-x]
[ e e- ty-e]
#:with out-e (if (pointer-type? #'ty-x)
#'(ro:begin
(cl:pointer-set! x- sel e-)
x-)
(with-syntax ([conv (mk-to (get-pointer-base #'ty-x))])
#'(ro:begin (cl:pointer-set! x- sel (conv e-)) x-))
(with-syntax ([selector (cl:parse-selector #f #'sel stx)])
#`(ro:let ([out (ro:vector-copy x-)])
#,(if (= 1 (length (stx->list #'selector)))
#`(ro:let ([out (ro:vector-copy x-)])
#,(if (= 1 (length (stx->list #'selector)))
#`(ro:vector-set! out (car 'selector) e-)
#'(ro:for ([idx 'selector] [v e-])
(ro:vector-set! out idx v)))
out)))
out))) ; TODO: need mk-ty here?
--------
[ (ro:set! x- out-e) void]])
@ -541,7 +538,7 @@
[(_ e) ; else try to coerce
[ e e- ty]
--------
[ (ro:#%app cl:! (ro:#%app to-bool e-)) bool]])
[ (ro:#%app cl:! (to-bool e-)) bool]])
;; TODO: this should produce int-vector result?
(define-typed-syntax ==
@ -552,7 +549,7 @@
#:when (real-type? #'ty2)
#:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len
--------
[ (ro:#%app to-int (cl:== e1- e2-)) int]])
[ (to-int (cl:== e1- e2-)) int]])
(define-simple-macro (define-bool-ops o ...+) (ro:begin (define-bool-op o) ...))
(define-simple-macro (define-bool-op name)
@ -564,7 +561,12 @@
--------
[ (name- e1- e2-) bool]]
[(_ e1 e2) ; else try to coerce
--- [ (name- (ro:#%app to-bool e1) (ro:#%app to-bool e2)) bool]]))
--- [ (name- (to-bool e1) (to-bool e2)) bool]]))
(define- (cl:/ x y)
(cond- [(zero?- y) 0]
[(integer?- x) (quotient- x y)]
[else (/- x y)]))
(define-simple-macro (define-real-ops o ...) (ro:begin (define-real-op o) ...))
(define-simple-macro (define-real-op name (~optional (~seq #:extra-check p?)
@ -585,10 +587,10 @@
#:with (x (... ...)) (generate-temporaries #'(e (... ...)))
--------
[ #,(if (scalar-type? #'ty-out)
#'(ro:#%app convert (name- (convert e-) (... ...)))
#'(ro:#%app convert (ro:let ([x (ro:#%app convert e-)] (... ...))
#'(convert (name- (convert e-) (... ...)))
#'(convert (ro:let ([x (convert e-)] (... ...))
(ro:for/list ([x x] (... ...))
(ro:#%app base-convert (name- x (... ...))))))) ty-out])
(base-convert (name- x (... ...))))))) ty-out])
(define-typed-syntax (name= x e)
--- [ (= x (name x e))])))
@ -601,14 +603,11 @@
(define-simple-macro (define-int-ops o ...) (ro:begin (define-int-op o) ...))
(define-bool-ops || &&)
(define-real-ops + * -)
(define-real-ops + * - /)
(define-int-ops % <<)
(define-typerule (sizeof t:type) >>
--- [ #,(real-type-length #'t.norm) int])
(define-typerule (print e ...) >>
--- [ (ro:begin (display e) ...) void])
(define-typerule (sizeof t:type) >> ---[ #,(real-type-length #'t.norm) int])
(define-typerule (print e ...) >> ---[ (ro:begin (display e) ...) void])
(define-typed-syntax choose
[(ch e ...+)
@ -656,10 +655,8 @@
[(_ #:forall [decl ...] #:ensure e)
--- [ (synth #:forall [decl ...] #:bitwidth 8 #:ensure e)]])
(define-typed-syntax verify
[(vfy #:forall [decl ...] #:ensure e)
[(_ #:forall [decl ...] #:ensure e)
#:with ([id seq ty] ...) (stx-map decl->seq #'(decl ...))
#:with (id- ...) (generate-temporaries #'(id ...))
#:with (typed-seq ...) #'((with-ctx ([id id- ty] ...) seq) ...)
@ -667,20 +664,17 @@
[ (ro:let ([id- 1] ...) ; dummy, enables simplifying stx template
(ro:parameterize ([ro:current-bitwidth 32]
[ro:term-cache (ro:hash-copy (ro:term-cache))])
(ro:for*/or ([id- typed-seq] ...)
(ro:define cex (with-ctx ([id id- ty] ...) (ro:verify e)))
(ro:and (ro:sat? cex)
(displayln "counterexample found:")
(ro:for ([i '(id ...)] [i- (ro:list id- ...)])
(printf "~a = ~a\n" i (ro:evaluate i- cex))))))) void]])
(ro:or (ro:for*/or ([id- typed-seq] ...)
(ro:define cex (with-ctx ([id id- ty] ...) (ro:verify e)))
(ro:and (ro:sat? cex)
(displayln "counterexample found:")
(ro:for ([i '(id ...)] [i- (ro:list id- ...)])
(printf "~a = ~a\n" i (ro:evaluate i- cex)))
cex))
(begin (displayln "no counterexample found") (ro:unsat))))) void]])
(define-typed-syntax (assert e)
--- [ (ro:assert (ro:#%app to-bool #,(expand/ro #'e))) void])
(define- (/ x y)
(cond- [(zero?- y) 0]
[(integer?- x) (quotient- x y)]
[else (/- x y)]))
--- [ (ro:assert (to-bool #,(expand/ro #'e))) void])
(define-typed-syntax (clCreateProgramWithSource ctx f)
--- [ (cl:clCreateProgramWithSource ctx f) cl_program])

View File

@ -0,0 +1,127 @@
#lang s-exp "../../../rosette/synthcl3.rkt"
; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix.
(kernel void (mmulScalarKernel [int* A] [int* B] [int* C] [int p] [int m])
(: int i j sum)
(= i (get_global_id 0))
(= j (get_global_id 1))
(= sum 0)
(for [(: int k in (range p))]
(+= sum (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
(= [C (+ (* i m) j)] sum))
;;--------------- Vectorized kernel ---------------;;
; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix.
(kernel void (mmulVectorKernel [int4* A] [int4* B] [int4* C] [int p] [int m])
(: int i j)
(: int4 sum0 sum1 sum2 sum3)
(= i (get_global_id 0))
(= j (get_global_id 1))
(= sum0 0)
(= sum1 0)
(= sum2 0)
(= sum3 0)
(for [(: int k in (range 0 p 4))]
(: int4 a0 a1 a2 a3)
(: int4 b0 b1 b2 b3)
(= a0 [A (indexA 0 i k p)])
(= a1 [A (indexA 1 i k p)])
(= a2 [A (indexA 2 i k p)])
(= a3 [A (indexA 3 i k p)])
(= b0 [B (indexB 0 k j m)])
(= b1 [B (indexB 1 k j m)])
(= b2 [B (indexB 2 k j m)])
(= b3 [B (indexB 3 k j m)])
(+= sum0 (computeSum a0 b0 b1 b2 b3))
(+= sum1 (computeSum a1 b0 b1 b2 b3))
(+= sum2 (computeSum a2 b0 b1 b2 b3))
(+= sum3 (computeSum a3 b0 b1 b2 b3)))
(= [C (indexC 0 i j m)] sum0)
(= [C (indexC 1 i j m)] sum1)
(= [C (indexC 2 i j m)] sum2)
(= [C (indexC 3 i j m)] sum3))
; Multiplies the 1x4 vector a by the 4x4 matrix with rows b0, b1, b2 and b3.
(procedure int4 (computeSum [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
(int4
(+ (* [a x] [b0 x]) (* [a y] [b1 x]) (* [a z] [b2 x]) (* [a w] [b3 x]))
(+ (* [a x] [b0 y]) (* [a y] [b1 y]) (* [a z] [b2 y]) (* [a w] [b3 y]))
(+ (* [a x] [b0 z]) (* [a y] [b1 z]) (* [a z] [b2 z]) (* [a w] [b3 z]))
(+ (* [a x] [b0 w]) (* [a y] [b1 w]) (* [a z] [b2 w]) (* [a w] [b3 w]))))
(procedure int (indexA [int off] [int i] [int k] [int p])
(+ (* (+ (* i 4) off) (/ p 4)) (/ k 4)))
(procedure int (indexB [int off] [int k] [int j] [int m])
(+ (* (+ k off) (/ m 4)) j))
(procedure int (indexC [int off] [int i] [int j] [int m])
(+ (* (+ (* i 4) off) (/ m 4)) j))
;;; ---------------- Optimized kernel implementation transcribed from AMD's apps ---------------- ;;;
(: int TILEX TILEX_SHIFT TILEY TILEY_SHIFT)
(= TILEX 4)
(= TILEY_SHIFT 2)
(= TILEY 4)
(= TILEY_SHIFT 2)
; Matrix multiplication C = A * B, where A is an n x p matrix and B is an
; p x m matrix.
(kernel void (mmulVectorKernelOpt [int4* A] [int4* B] [int4* C] [int p] [int m])
(: int2 pos)
(: int4 sum0 sum1 sum2 sum3)
(= pos (int2 (get_global_id 0) (get_global_id 1)))
(= sum0 0)
(= sum1 0)
(= sum2 0)
(= sum3 0)
(/= m 4)
(for [(: int i in (range 0 p 4))]
(: int4 a0 a1 a2 a3)
(: int4 b0 b1 b2 b3)
(= a0 [A (+ (/ i 4) (* (<< [pos x] TILEY_SHIFT) (/ p 4)))])
(= a1 [A (+ (/ i 4) (* (+ (<< [pos x] TILEY_SHIFT) 1) (/ p 4)))])
(= a2 [A (+ (/ i 4) (* (+ (<< [pos x] TILEY_SHIFT) 2) (/ p 4)))])
(= a3 [A (+ (/ i 4) (* (+ (<< [pos x] TILEY_SHIFT) 3) (/ p 4)))])
(= b0 [B (+ [pos y] (* i m))])
(= b1 [B (+ [pos y] (* (+ i 1) m))])
(= b2 [B (+ [pos y] (* (+ i 2) m))])
(= b3 [B (+ [pos y] (* (+ i 3) m))])
(+= [sum0 x] (+ (* [a0 x] [b0 x]) (* [a0 y] [b1 x]) (* [a0 z] [b2 x]) (* [a0 w] [b3 x])))
(+= [sum0 y] (+ (* [a0 x] [b0 y]) (* [a0 y] [b1 y]) (* [a0 z] [b2 y]) (* [a0 w] [b3 y])))
(+= [sum0 z] (+ (* [a0 x] [b0 z]) (* [a0 y] [b1 z]) (* [a0 z] [b2 z]) (* [a0 w] [b3 z])))
(+= [sum0 w] (+ (* [a0 x] [b0 w]) (* [a0 y] [b1 w]) (* [a0 z] [b2 w]) (* [a0 w] [b3 w])))
(+= [sum1 x] (+ (* [a1 x] [b0 x]) (* [a1 y] [b1 x]) (* [a1 z] [b2 x]) (* [a1 w] [b3 x])))
(+= [sum1 y] (+ (* [a1 x] [b0 y]) (* [a1 y] [b1 y]) (* [a1 z] [b2 y]) (* [a1 w] [b3 y])))
(+= [sum1 z] (+ (* [a1 x] [b0 z]) (* [a1 y] [b1 z]) (* [a1 z] [b2 z]) (* [a1 w] [b3 z])))
(+= [sum1 w] (+ (* [a1 x] [b0 w]) (* [a1 y] [b1 w]) (* [a1 z] [b2 w]) (* [a1 w] [b3 w])))
(+= [sum2 x] (+ (* [a2 x] [b0 x]) (* [a2 y] [b1 x]) (* [a2 z] [b2 x]) (* [a2 w] [b3 x])))
(+= [sum2 y] (+ (* [a2 x] [b0 y]) (* [a2 y] [b1 y]) (* [a2 z] [b2 y]) (* [a2 w] [b3 y])))
(+= [sum2 z] (+ (* [a2 x] [b0 z]) (* [a2 y] [b1 z]) (* [a2 z] [b2 z]) (* [a2 w] [b3 z])))
(+= [sum2 w] (+ (* [a2 x] [b0 w]) (* [a2 y] [b1 w]) (* [a2 z] [b2 w]) (* [a2 w] [b3 w])))
(+= [sum3 x] (+ (* [a3 x] [b0 x]) (* [a3 y] [b1 x]) (* [a3 z] [b2 x]) (* [a3 w] [b3 x])))
(+= [sum3 y] (+ (* [a3 x] [b0 y]) (* [a3 y] [b1 y]) (* [a3 z] [b2 y]) (* [a3 w] [b3 y])))
(+= [sum3 z] (+ (* [a3 x] [b0 z]) (* [a3 y] [b1 z]) (* [a3 z] [b2 z]) (* [a3 w] [b3 z])))
(+= [sum3 w] (+ (* [a3 x] [b0 w]) (* [a3 y] [b1 w]) (* [a3 z] [b2 w]) (* [a3 w] [b3 w]))))
(= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 0) m))] sum0)
(= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 1) m))] sum1)
(= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 2) m))] sum2)
(= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 3) m))] sum3))

View File

@ -20,8 +20,10 @@
(do-tests "bv-tests.rkt" "BV SDSL - General"
"fsm3-tests.rkt" "FSM"
"ifc3-tests.rkt" "IFC"
"synthcl3-tests.rkt" "SynthCL"
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth")
"ifc3-tests.rkt" "IFC")
(do-tests "synthcl3-tests.rkt" "SynthCL"
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth"
"synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify"
"synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify")
(do-tests "bv-ref-tests.rkt" "BV SDSL - Hacker's Delight synthesis")

View File

@ -0,0 +1,126 @@
#lang s-exp "../../../rosette/synthcl3.rkt"
(require "../../rackunit-typechecking.rkt")
; A buggy reference implementation for square matrix multiplication.
; Multiplies two squre matrices A and B, where the dimension of A is
; n x p and dimension of B is p x m. Both matrices are given as
; flat arrays in row-major form. The output is the matrix C = A*B,
; also given in row-major form.
(procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m])
(: int* C)
(= C ((int*) (malloc (* n m (sizeof int)))))
(for [(: int i in (range n))
(: int j in (range m))
(: int k in (range 1 p))] ; seeded bug
(+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
C)
; A host implementation of matrix multiplication.
(procedure int* (mmulHost [char* kernelName] [int typeLen]
[int* A] [int* B] [int n] [int p] [int m])
(: cl_context context)
(: cl_command_queue command_queue)
(: cl_program program)
(: cl_kernel kernel)
(: cl_mem buffer_A buffer_B buffer_C)
(: int* C)
(: int[2] global)
(: int dimA dimB dimC)
(= [global 0] (/ n typeLen))
(= [global 1] (/ m typeLen))
(= dimA (* n p (sizeof int)))
(= dimB (* p m (sizeof int)))
(= dimC (* n m (sizeof int)))
(= C ((int*) (malloc dimC)))
(= context (clCreateContext))
(= command_queue (clCreateCommandQueue context))
(= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
(= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
(= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
(= program (clCreateProgramWithSource context "matrix-verify-kernel.rkt"))
(clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
(clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
(= kernel (clCreateKernel program kernelName))
(clSetKernelArg kernel 0 buffer_A)
(clSetKernelArg kernel 1 buffer_B)
(clSetKernelArg kernel 2 buffer_C)
(clSetKernelArg kernel 3 p)
(clSetKernelArg kernel 4 m)
(clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
(clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
C)
; A scalar parallel implementation of matrix multiplication.
(procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m])
(mmulHost "mmulScalarKernel" 1 A B n p m))
; A vector parallel implementation of matrix multiplication. The dimensions
; n and m must be evenly divisible by 4.
(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m])
(mmulHost "mmulVectorKernel" 4 A B n p m))
; An optimized vector parallel implementation of matrix multiplication. The dimensions
; n and m must be evenly divisible by 4.
(procedure int* (mmulVectorOpt [int* A] [int* B] [int n] [int p] [int m])
(mmulHost "mmulVectorKernelOpt" 4 A B n p m))
; Given two arrays of the same size, checks that they hold the same
; values at each index.
(procedure void (check [int* actual] [int* expected] [int SIZE])
(assert (>= SIZE 0))
(for [(: int i in (range SIZE))]
(assert (== [actual i] [expected i]))))
(procedure void (verify_scalar [int from] [int to])
(verify #:forall [(: int n in (range from to))
(: int p in (range from to))
(: int m in (range from to))
(: int[(* n p)] A)
(: int[(* p m)] B)]
#:ensure (check (mmulScalar A B n p m)
(mmulSequential A B n p m)
(* n m))))
(procedure void (verify_vector [int from] [int to])
(verify #:forall [(: int n in (range from to 4))
(: int p in (range from to 4))
(: int m in (range from to 4))
(: int[(* n p)] A)
(: int[(* p m)] B)]
#:ensure (check (mmulVector A B n p m)
(mmulSequential A B n p m)
(* n m))))
(procedure void (verify_vector_opt [int from] [int to])
(verify #:forall [(: int n in (range from to 4))
(: int p in (range from to 4))
(: int m in (range from to 4))
(: int[(* n p)] A)
(: int[(* p m)] B)]
#:ensure (check (mmulVectorOpt A B n p m)
(mmulSequential A B n p m)
(* n m))))
(check-type
(with-output-to-string (λ () (verify_scalar 1 5)))
: CString -> "counterexample found:\nn = 1\np = 1\nm = 1\nA = #x0#(-1)\nB = #x1#(1)\n")
(check-type
(with-output-to-string (λ () (verify_vector 4 9)))
: CString -> "counterexample found:\nn = 4\np = 4\nm = 4\nA = #x5#(3 0 0 0 3 0 0 0 98355 0 0 0 98307 0 0 0)\nB = #x6#(-1431655765 -1431655765 1431661227 0 0 0 0 0 0 0 0 0 0 0 0 0)\n")
(check-type
(with-output-to-string (λ () (verify_vector_opt 4 9)))
: CString -> "counterexample found:\nn = 4\np = 4\nm = 4\nA = #xa#(3 0 0 0 3 0 0 0 98355 0 0 0 98307 0 0 0)\nB = #xb#(-1431655765 -1431655765 1431661227 0 0 0 0 0 0 0 0 0 0 0 0 0)\n")
;(: int n p m)
;(= n 8) (= p 4) (= m 4)
;(: int[(* n p)] A) (: int[(* p m)] B)
;(mmulVector A B n p m)

View File

@ -0,0 +1,125 @@
#lang s-exp "../../../rosette/synthcl3.rkt"
(require "../../rackunit-typechecking.rkt")
; The reference implementation for square matrix multiplication.
; Multiplies two squre matrices A and B, where the dimension of A is
; n x p and dimension of B is p x m. Both matrices are given as
; flat arrays in row-major form. The output is the matrix C = A*B,
; also given in row-major form.
(procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m])
(: int* C)
(= C ((int*) (malloc (* n m (sizeof int)))))
(for [(: int i in (range n))
(: int j in (range m))
(: int k in (range p))]
(+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
C)
; A host implementation of matrix multiplication.
(procedure int* (mmulHost [char* kernelName] [int typeLen]
[int* A] [int* B] [int n] [int p] [int m])
(: cl_context context)
(: cl_command_queue command_queue)
(: cl_program program)
(: cl_kernel kernel)
(: cl_mem buffer_A buffer_B buffer_C)
(: int* C)
(: int[2] global)
(: int dimA dimB dimC)
(= [global 0] (/ n typeLen))
(= [global 1] (/ m typeLen))
(= dimA (* n p (sizeof int)))
(= dimB (* p m (sizeof int)))
(= dimC (* n m (sizeof int)))
(= C ((int*) (malloc dimC)))
(= context (clCreateContext))
(= command_queue (clCreateCommandQueue context))
(= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
(= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
(= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
(= program (clCreateProgramWithSource context "matrix-verify-kernel.rkt"))
(clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
(clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
(= kernel (clCreateKernel program kernelName))
(clSetKernelArg kernel 0 buffer_A)
(clSetKernelArg kernel 1 buffer_B)
(clSetKernelArg kernel 2 buffer_C)
(clSetKernelArg kernel 3 p)
(clSetKernelArg kernel 4 m)
(clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
(clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
C)
; A scalar parallel implementation of matrix multiplication.
(procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m])
(mmulHost "mmulScalarKernel" 1 A B n p m))
; A vector parallel implementation of matrix multiplication. The dimensions
; n and m must be evenly divisible by 4.
(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m])
(mmulHost "mmulVectorKernel" 4 A B n p m))
; An optimized vector parallel implementation of matrix multiplication. The dimensions
; n and m must be evenly divisible by 4.
(procedure int* (mmulVectorOpt [int* A] [int* B] [int n] [int p] [int m])
(mmulHost "mmulVectorKernelOpt" 4 A B n p m))
; Given two arrays of the same size, checks that they hold the same
; values at each index.
(procedure void (check [int* actual] [int* expected] [int SIZE])
(assert (>= SIZE 0))
(for [(: int i in (range SIZE))]
(assert (== [actual i] [expected i]))))
(procedure void (verify_scalar [int from] [int to])
(verify #:forall [(: int n in (range from to))
(: int p in (range from to))
(: int m in (range from to))
(: int[(* n p)] A)
(: int[(* p m)] B)]
#:ensure (check (mmulScalar A B n p m)
(mmulSequential A B n p m)
(* n m))))
(procedure void (verify_vector [int from] [int to])
(verify #:forall [(: int n in (range from to 4))
(: int p in (range from to 4))
(: int m in (range from to 4))
(: int[(* n p)] A)
(: int[(* p m)] B)]
#:ensure (check (mmulVector A B n p m)
(mmulSequential A B n p m)
(* n m))))
(procedure void (verify_vector_opt [int from] [int to])
(verify #:forall [(: int n in (range from to 4))
(: int p in (range from to 4))
(: int m in (range from to 4))
(: int[(* n p)] A)
(: int[(* p m)] B)]
#:ensure (check (mmulVectorOpt A B n p m)
(mmulSequential A B n p m)
(* n m))))
(check-type
(with-output-to-string (λ () (verify_scalar 1 5)))
: CString -> "no counterexample found\n")
(check-type
(with-output-to-string (λ () (verify_vector 4 9)))
: CString -> "no counterexample found\n")
(check-type
(with-output-to-string (λ () (verify_vector_opt 4 9)))
: CString -> "no counterexample found\n")
;(: int n p m)
;(= n 8) (= p 4) (= m 4)
;(: int[(* n p)] A) (: int[(* p m)] B)
;(mmulVector A B n p m)