synthcl3: fix pointer deref assignment bug; all matrix mult tests passing!
This commit is contained in:
parent
57bf9a5543
commit
e0a2900c77
|
@ -56,7 +56,7 @@
|
|||
(local-expand
|
||||
e 'expression
|
||||
(list #'ro:#%app #'ro:choose #'ro:synthesize #'ro:let #'ro:in-value
|
||||
#'ro:assert #'ro:if #'ro:?)))
|
||||
#'ro:assert #'ro:if #'ro:? #'ro:verify)))
|
||||
; (displayln (stx->datum e+))
|
||||
e+)
|
||||
(define (mk-ro:-id id) (format-id id "ro:~a" id))
|
||||
|
|
|
@ -3,15 +3,14 @@
|
|||
(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped
|
||||
racket/stxparam
|
||||
(prefix-in ro: (combine-in rosette rosette/lib/synthax))
|
||||
(prefix-in cl: (combine-in
|
||||
(prefix-in cl: (combine-in (except-in sdsl/synthcl/model/operators /)
|
||||
sdsl/synthcl/lang/forms sdsl/synthcl/model/reals
|
||||
sdsl/synthcl/model/operators sdsl/synthcl/model/errors
|
||||
sdsl/synthcl/model/errors sdsl/synthcl/model/kernel
|
||||
sdsl/synthcl/model/memory sdsl/synthcl/model/runtime
|
||||
sdsl/synthcl/model/work sdsl/synthcl/model/pointers
|
||||
sdsl/synthcl/lang/queries sdsl/synthcl/model/context
|
||||
sdsl/synthcl/model/queue sdsl/synthcl/model/buffer
|
||||
sdsl/synthcl/model/flags sdsl/synthcl/model/program
|
||||
sdsl/synthcl/model/kernel))
|
||||
sdsl/synthcl/model/flags sdsl/synthcl/model/program))
|
||||
(for-syntax (prefix-in cl: sdsl/synthcl/lang/util)))
|
||||
|
||||
(begin-for-syntax
|
||||
|
@ -24,12 +23,11 @@
|
|||
int int2 int3 int4 int16 float float2 float3 float4 float16
|
||||
bool void void* char* float* int* int2* int3* int4* int16*
|
||||
cl_context cl_command_queue cl_program cl_kernel cl_mem
|
||||
: ! ?: == + * - || &&
|
||||
: ! ?: == + * / - || &&
|
||||
% << ; int ops
|
||||
= += -= %= ; assignment ops
|
||||
= += -= *= /= %= ; assignment ops
|
||||
sizeof clCreateProgramWithSource
|
||||
(typed-out
|
||||
;[with-output-to-string : (C→ (C→ Any) char*)]
|
||||
[clCreateContext : (C→ cl_context)]
|
||||
[clCreateCommandQueue : (C→ cl_context cl_command_queue)]
|
||||
[clCreateBuffer : (C→ cl_context int int cl_mem)]
|
||||
|
@ -44,7 +42,6 @@
|
|||
[get_global_id : (C→ int int)]
|
||||
[CL_MEM_READ_ONLY : int]
|
||||
[CL_MEM_WRITE_ONLY : int]
|
||||
[/ : (Ccase-> (C→ int int int) (C→ float float float))]
|
||||
[malloc : (C→ int void*)]
|
||||
[get_work_dim : (C→ int)]
|
||||
[!= : (Ccase-> (C→ CNum CNum CBool)
|
||||
|
@ -115,6 +112,7 @@
|
|||
expr subexpr)))
|
||||
(define (mk-ptr id) (format-id id "~a*" id))
|
||||
(define (mk-mk id) (format-id id "mk-~a" id))
|
||||
(define (mk-to id) (format-id id "to-~a" id))
|
||||
(define (add-convert stx fn)
|
||||
(set-stx-prop/preserved stx 'convert fn))
|
||||
(define (get-convert stx)
|
||||
|
@ -202,10 +200,10 @@
|
|||
(ro:define (to-intn v)
|
||||
(ro:cond
|
||||
[(ro:list? v)
|
||||
(ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:list-ref v i))))]
|
||||
(ro:apply mk-intn (ro:for/list ([i n]) (to-int (ro:list-ref v i))))]
|
||||
[(ro:vector? v)
|
||||
(ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:vector-ref v i))))]
|
||||
[else (ro:apply mk-intn (ro:make-list n (ro:#%app to-int v)))]))
|
||||
(ro:apply mk-intn (ro:for/list ([i n]) (to-int (ro:vector-ref v i))))]
|
||||
[else (ro:apply mk-intn (ro:make-list n (to-int v)))]))
|
||||
(ro:define (to-intn* v) (cl:pointer-cast v cl-mk-intn))
|
||||
(ro:define (mk-intn x ...) (ro:#%app cl-mk-intn x ...)))]))
|
||||
(define-simple-macro (define-ints n ...) (begin (define-int n) ...))
|
||||
|
@ -227,10 +225,10 @@
|
|||
(ro:define (to-floatn v)
|
||||
(ro:cond
|
||||
[(ro:list? v)
|
||||
(ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:list-ref v i))))]
|
||||
(ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:list-ref v i))))]
|
||||
[(ro:vector? v)
|
||||
(ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:vector-ref v i))))]
|
||||
[else (ro:apply mk-floatn (ro:make-list n (ro:#%app to-float v)))]))
|
||||
(ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:vector-ref v i))))]
|
||||
[else (ro:apply mk-floatn (ro:make-list n (to-float v)))]))
|
||||
(ro:define (mk-floatn x ...) (ro:#%app cl-mk-floatn x ...)))]))
|
||||
(define-simple-macro (define-floats n ...) (begin (define-float n) ...))
|
||||
(define-floats 2 3 4 16)
|
||||
|
@ -346,7 +344,7 @@
|
|||
(define-typed-syntax if
|
||||
[(_ test {then ...} {else ...}) ≫
|
||||
--------
|
||||
[⊢ (ro:if (ro:#%app to-bool test)
|
||||
[⊢ (ro:if (to-bool test)
|
||||
(ro:let () then ... (ro:void))
|
||||
(ro:let () else ... (ro:void))) ⇒ void]]
|
||||
[(_ test {then ...}) ≫
|
||||
|
@ -433,7 +431,7 @@
|
|||
[≻ (begin-
|
||||
(ro:define-symbolic* x-- pred [len base-len]) ...
|
||||
(ro:define x-
|
||||
(ro:let ([*x (ro:#%app to-ty* (cl:malloc (ro:* len base-len)))])
|
||||
(ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))])
|
||||
(ro:for ([i len][v x--])
|
||||
(cl:pointer-set! *x i (ro:apply mk-ty v)))
|
||||
*x)) ...
|
||||
|
@ -482,12 +480,12 @@
|
|||
#:with ty-base ((current-type-eval) (datum->syntax #'e (string->symbol out-base-str)))
|
||||
#:with base-convert (get-convert #'ty-base)
|
||||
-------
|
||||
[⊢ (ro:#%app convert
|
||||
(ro:let ([a (ro:#%app convert e-)][b (ro:#%app convert e1-)][c (ro:#%app convert e2-)])
|
||||
[⊢ (convert
|
||||
(ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)])
|
||||
(ro:for/list ([idx #,(string->number out-len-str)])
|
||||
(ro:if (ro:< (ro:vector-ref a idx) 0)
|
||||
(ro:#%app base-convert (ro:vector-ref b idx))
|
||||
(ro:#%app base-convert (ro:vector-ref c idx))))))
|
||||
(base-convert (ro:vector-ref b idx))
|
||||
(base-convert (ro:vector-ref c idx))))))
|
||||
⇒ ty-out]]
|
||||
[(_ ~! e e1 e2) ≫ ; should be scalar and real
|
||||
[⊢ e ≫ e- ⇒ ty]
|
||||
|
@ -514,22 +512,21 @@
|
|||
(format "cannot cast ~a to ~a" (type->str #'ty-e) (type->str #'ty-x))
|
||||
#:with conv (get-convert #'ty-x)
|
||||
--------
|
||||
[⊢ (ro:set! x- #,(if (syntax-e #'conv) #'(ro:#%app conv e-) #'e-)) ⇒ void]]
|
||||
[⊢ (ro:set! x- #,(if (syntax-e #'conv) #'(conv e-) #'e-)) ⇒ void]]
|
||||
;; selector can be list of numbers or up to wxyz for vectors of length <=4
|
||||
[(_ [x:id sel] e) ≫
|
||||
[⊢ x ≫ x- ⇒ ty-x]
|
||||
[⊢ e ≫ e- ⇒ ty-e]
|
||||
#:with out-e (if (pointer-type? #'ty-x)
|
||||
#'(ro:begin
|
||||
(cl:pointer-set! x- sel e-)
|
||||
x-)
|
||||
(with-syntax ([conv (mk-to (get-pointer-base #'ty-x))])
|
||||
#'(ro:begin (cl:pointer-set! x- sel (conv e-)) x-))
|
||||
(with-syntax ([selector (cl:parse-selector #f #'sel stx)])
|
||||
#`(ro:let ([out (ro:vector-copy x-)])
|
||||
#,(if (= 1 (length (stx->list #'selector)))
|
||||
#`(ro:let ([out (ro:vector-copy x-)])
|
||||
#,(if (= 1 (length (stx->list #'selector)))
|
||||
#`(ro:vector-set! out (car 'selector) e-)
|
||||
#'(ro:for ([idx 'selector] [v e-])
|
||||
(ro:vector-set! out idx v)))
|
||||
out)))
|
||||
out))) ; TODO: need mk-ty here?
|
||||
--------
|
||||
[⊢ (ro:set! x- out-e) ⇒ void]])
|
||||
|
||||
|
@ -541,7 +538,7 @@
|
|||
[(_ e) ≫ ; else try to coerce
|
||||
[⊢ e ≫ e- ⇒ ty]
|
||||
--------
|
||||
[⊢ (ro:#%app cl:! (ro:#%app to-bool e-)) ⇒ bool]])
|
||||
[⊢ (ro:#%app cl:! (to-bool e-)) ⇒ bool]])
|
||||
|
||||
;; TODO: this should produce int-vector result?
|
||||
(define-typed-syntax ==
|
||||
|
@ -552,7 +549,7 @@
|
|||
#:when (real-type? #'ty2)
|
||||
#:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len
|
||||
--------
|
||||
[⊢ (ro:#%app to-int (cl:== e1- e2-)) ⇒ int]])
|
||||
[⊢ (to-int (cl:== e1- e2-)) ⇒ int]])
|
||||
|
||||
(define-simple-macro (define-bool-ops o ...+) (ro:begin (define-bool-op o) ...))
|
||||
(define-simple-macro (define-bool-op name)
|
||||
|
@ -564,7 +561,12 @@
|
|||
--------
|
||||
[⊢ (name- e1- e2-) ⇒ bool]]
|
||||
[(_ e1 e2) ≫ ; else try to coerce
|
||||
--- [⊢ (name- (ro:#%app to-bool e1) (ro:#%app to-bool e2)) ⇒ bool]]))
|
||||
--- [⊢ (name- (to-bool e1) (to-bool e2)) ⇒ bool]]))
|
||||
|
||||
(define- (cl:/ x y)
|
||||
(cond- [(zero?- y) 0]
|
||||
[(integer?- x) (quotient- x y)]
|
||||
[else (/- x y)]))
|
||||
|
||||
(define-simple-macro (define-real-ops o ...) (ro:begin (define-real-op o) ...))
|
||||
(define-simple-macro (define-real-op name (~optional (~seq #:extra-check p?)
|
||||
|
@ -585,10 +587,10 @@
|
|||
#:with (x (... ...)) (generate-temporaries #'(e (... ...)))
|
||||
--------
|
||||
[⊢ #,(if (scalar-type? #'ty-out)
|
||||
#'(ro:#%app convert (name- (convert e-) (... ...)))
|
||||
#'(ro:#%app convert (ro:let ([x (ro:#%app convert e-)] (... ...))
|
||||
#'(convert (name- (convert e-) (... ...)))
|
||||
#'(convert (ro:let ([x (convert e-)] (... ...))
|
||||
(ro:for/list ([x x] (... ...))
|
||||
(ro:#%app base-convert (name- x (... ...))))))) ⇒ ty-out])
|
||||
(base-convert (name- x (... ...))))))) ⇒ ty-out])
|
||||
(define-typed-syntax (name= x e) ≫
|
||||
--- [≻ (= x (name x e))])))
|
||||
|
||||
|
@ -601,14 +603,11 @@
|
|||
(define-simple-macro (define-int-ops o ...) (ro:begin (define-int-op o) ...))
|
||||
|
||||
(define-bool-ops || &&)
|
||||
(define-real-ops + * -)
|
||||
(define-real-ops + * - /)
|
||||
(define-int-ops % <<)
|
||||
|
||||
(define-typerule (sizeof t:type) >>
|
||||
--- [⊢ #,(real-type-length #'t.norm) ⇒ int])
|
||||
|
||||
(define-typerule (print e ...) >>
|
||||
--- [⊢ (ro:begin (display e) ...) ⇒ void])
|
||||
(define-typerule (sizeof t:type) >> ---[⊢ #,(real-type-length #'t.norm) ⇒ int])
|
||||
(define-typerule (print e ...) >> ---[⊢ (ro:begin (display e) ...) ⇒ void])
|
||||
|
||||
(define-typed-syntax choose
|
||||
[(ch e ...+) ≫
|
||||
|
@ -656,10 +655,8 @@
|
|||
[(_ #:forall [decl ...] #:ensure e) ≫
|
||||
--- [≻ (synth #:forall [decl ...] #:bitwidth 8 #:ensure e)]])
|
||||
|
||||
|
||||
|
||||
(define-typed-syntax verify
|
||||
[(vfy #:forall [decl ...] #:ensure e) ≫
|
||||
[(_ #:forall [decl ...] #:ensure e) ≫
|
||||
#:with ([id seq ty] ...) (stx-map decl->seq #'(decl ...))
|
||||
#:with (id- ...) (generate-temporaries #'(id ...))
|
||||
#:with (typed-seq ...) #'((with-ctx ([id id- ty] ...) seq) ...)
|
||||
|
@ -667,20 +664,17 @@
|
|||
[⊢ (ro:let ([id- 1] ...) ; dummy, enables simplifying stx template
|
||||
(ro:parameterize ([ro:current-bitwidth 32]
|
||||
[ro:term-cache (ro:hash-copy (ro:term-cache))])
|
||||
(ro:for*/or ([id- typed-seq] ...)
|
||||
(ro:define cex (with-ctx ([id id- ty] ...) (ro:verify e)))
|
||||
(ro:and (ro:sat? cex)
|
||||
(displayln "counterexample found:")
|
||||
(ro:for ([i '(id ...)] [i- (ro:list id- ...)])
|
||||
(printf "~a = ~a\n" i (ro:evaluate i- cex))))))) ⇒ void]])
|
||||
(ro:or (ro:for*/or ([id- typed-seq] ...)
|
||||
(ro:define cex (with-ctx ([id id- ty] ...) (ro:verify e)))
|
||||
(ro:and (ro:sat? cex)
|
||||
(displayln "counterexample found:")
|
||||
(ro:for ([i '(id ...)] [i- (ro:list id- ...)])
|
||||
(printf "~a = ~a\n" i (ro:evaluate i- cex)))
|
||||
cex))
|
||||
(begin (displayln "no counterexample found") (ro:unsat))))) ⇒ void]])
|
||||
|
||||
(define-typed-syntax (assert e) ≫
|
||||
--- [⊢ (ro:assert (ro:#%app to-bool #,(expand/ro #'e))) ⇒ void])
|
||||
|
||||
(define- (/ x y)
|
||||
(cond- [(zero?- y) 0]
|
||||
[(integer?- x) (quotient- x y)]
|
||||
[else (/- x y)]))
|
||||
--- [⊢ (ro:assert (to-bool #,(expand/ro #'e))) ⇒ void])
|
||||
|
||||
(define-typed-syntax (clCreateProgramWithSource ctx f) ≫
|
||||
--- [⊢ (cl:clCreateProgramWithSource ctx f) ⇒ cl_program])
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
#lang s-exp "../../../rosette/synthcl3.rkt"
|
||||
|
||||
; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix.
|
||||
(kernel void (mmulScalarKernel [int* A] [int* B] [int* C] [int p] [int m])
|
||||
(: int i j sum)
|
||||
(= i (get_global_id 0))
|
||||
(= j (get_global_id 1))
|
||||
(= sum 0)
|
||||
(for [(: int k in (range p))]
|
||||
(+= sum (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
|
||||
(= [C (+ (* i m) j)] sum))
|
||||
|
||||
;;--------------- Vectorized kernel ---------------;;
|
||||
|
||||
; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix.
|
||||
(kernel void (mmulVectorKernel [int4* A] [int4* B] [int4* C] [int p] [int m])
|
||||
(: int i j)
|
||||
(: int4 sum0 sum1 sum2 sum3)
|
||||
|
||||
(= i (get_global_id 0))
|
||||
(= j (get_global_id 1))
|
||||
(= sum0 0)
|
||||
(= sum1 0)
|
||||
(= sum2 0)
|
||||
(= sum3 0)
|
||||
|
||||
(for [(: int k in (range 0 p 4))]
|
||||
(: int4 a0 a1 a2 a3)
|
||||
(: int4 b0 b1 b2 b3)
|
||||
|
||||
(= a0 [A (indexA 0 i k p)])
|
||||
(= a1 [A (indexA 1 i k p)])
|
||||
(= a2 [A (indexA 2 i k p)])
|
||||
(= a3 [A (indexA 3 i k p)])
|
||||
|
||||
(= b0 [B (indexB 0 k j m)])
|
||||
(= b1 [B (indexB 1 k j m)])
|
||||
(= b2 [B (indexB 2 k j m)])
|
||||
(= b3 [B (indexB 3 k j m)])
|
||||
|
||||
(+= sum0 (computeSum a0 b0 b1 b2 b3))
|
||||
(+= sum1 (computeSum a1 b0 b1 b2 b3))
|
||||
(+= sum2 (computeSum a2 b0 b1 b2 b3))
|
||||
(+= sum3 (computeSum a3 b0 b1 b2 b3)))
|
||||
|
||||
(= [C (indexC 0 i j m)] sum0)
|
||||
(= [C (indexC 1 i j m)] sum1)
|
||||
(= [C (indexC 2 i j m)] sum2)
|
||||
(= [C (indexC 3 i j m)] sum3))
|
||||
|
||||
; Multiplies the 1x4 vector a by the 4x4 matrix with rows b0, b1, b2 and b3.
|
||||
(procedure int4 (computeSum [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
|
||||
(int4
|
||||
(+ (* [a x] [b0 x]) (* [a y] [b1 x]) (* [a z] [b2 x]) (* [a w] [b3 x]))
|
||||
(+ (* [a x] [b0 y]) (* [a y] [b1 y]) (* [a z] [b2 y]) (* [a w] [b3 y]))
|
||||
(+ (* [a x] [b0 z]) (* [a y] [b1 z]) (* [a z] [b2 z]) (* [a w] [b3 z]))
|
||||
(+ (* [a x] [b0 w]) (* [a y] [b1 w]) (* [a z] [b2 w]) (* [a w] [b3 w]))))
|
||||
|
||||
; Returns (4*i + off) * (p/4) + k/4.
; mmulVectorKernel calls this with off = 0..3 to index its int4* buffer A.
(procedure int (indexA [int off] [int i] [int k] [int p])
|
||||
(+ (* (+ (* i 4) off) (/ p 4)) (/ k 4)))
|
||||
|
||||
; Returns (k + off) * (m/4) + j.
; mmulVectorKernel calls this with off = 0..3 to index its int4* buffer B.
(procedure int (indexB [int off] [int k] [int j] [int m])
|
||||
(+ (* (+ k off) (/ m 4)) j))
|
||||
|
||||
; Returns (4*i + off) * (m/4) + j.
; mmulVectorKernel calls this with off = 0..3 to pick the int4* slot of C it writes.
(procedure int (indexC [int off] [int i] [int j] [int m])
|
||||
(+ (* (+ (* i 4) off) (/ m 4)) j))
|
||||
|
||||
|
||||
;;; ---------------- Optimized kernel implementation transcribed from AMD's apps ---------------- ;;;
|
||||
(: int TILEX TILEX_SHIFT TILEY TILEY_SHIFT)
|
||||
(= TILEX 4)
|
||||
(= TILEX_SHIFT 2) ; was a duplicate TILEY_SHIFT assignment, which left the declared TILEX_SHIFT unset
|
||||
(= TILEY 4)
|
||||
(= TILEY_SHIFT 2)
|
||||
|
||||
; Matrix multiplication C = A * B, where A is an n x p matrix and B is an
|
||||
; p x m matrix.
|
||||
(kernel void (mmulVectorKernelOpt [int4* A] [int4* B] [int4* C] [int p] [int m])
|
||||
(: int2 pos)
|
||||
(: int4 sum0 sum1 sum2 sum3)
|
||||
|
||||
(= pos (int2 (get_global_id 0) (get_global_id 1)))
|
||||
(= sum0 0)
|
||||
(= sum1 0)
|
||||
(= sum2 0)
|
||||
(= sum3 0)
|
||||
|
||||
(/= m 4)
|
||||
|
||||
(for [(: int i in (range 0 p 4))]
|
||||
(: int4 a0 a1 a2 a3)
|
||||
(: int4 b0 b1 b2 b3)
|
||||
|
||||
(= a0 [A (+ (/ i 4) (* (<< [pos x] TILEY_SHIFT) (/ p 4)))])
|
||||
(= a1 [A (+ (/ i 4) (* (+ (<< [pos x] TILEY_SHIFT) 1) (/ p 4)))])
|
||||
(= a2 [A (+ (/ i 4) (* (+ (<< [pos x] TILEY_SHIFT) 2) (/ p 4)))])
|
||||
(= a3 [A (+ (/ i 4) (* (+ (<< [pos x] TILEY_SHIFT) 3) (/ p 4)))])
|
||||
|
||||
(= b0 [B (+ [pos y] (* i m))])
|
||||
(= b1 [B (+ [pos y] (* (+ i 1) m))])
|
||||
(= b2 [B (+ [pos y] (* (+ i 2) m))])
|
||||
(= b3 [B (+ [pos y] (* (+ i 3) m))])
|
||||
|
||||
(+= [sum0 x] (+ (* [a0 x] [b0 x]) (* [a0 y] [b1 x]) (* [a0 z] [b2 x]) (* [a0 w] [b3 x])))
|
||||
(+= [sum0 y] (+ (* [a0 x] [b0 y]) (* [a0 y] [b1 y]) (* [a0 z] [b2 y]) (* [a0 w] [b3 y])))
|
||||
(+= [sum0 z] (+ (* [a0 x] [b0 z]) (* [a0 y] [b1 z]) (* [a0 z] [b2 z]) (* [a0 w] [b3 z])))
|
||||
(+= [sum0 w] (+ (* [a0 x] [b0 w]) (* [a0 y] [b1 w]) (* [a0 z] [b2 w]) (* [a0 w] [b3 w])))
|
||||
|
||||
(+= [sum1 x] (+ (* [a1 x] [b0 x]) (* [a1 y] [b1 x]) (* [a1 z] [b2 x]) (* [a1 w] [b3 x])))
|
||||
(+= [sum1 y] (+ (* [a1 x] [b0 y]) (* [a1 y] [b1 y]) (* [a1 z] [b2 y]) (* [a1 w] [b3 y])))
|
||||
(+= [sum1 z] (+ (* [a1 x] [b0 z]) (* [a1 y] [b1 z]) (* [a1 z] [b2 z]) (* [a1 w] [b3 z])))
|
||||
(+= [sum1 w] (+ (* [a1 x] [b0 w]) (* [a1 y] [b1 w]) (* [a1 z] [b2 w]) (* [a1 w] [b3 w])))
|
||||
|
||||
(+= [sum2 x] (+ (* [a2 x] [b0 x]) (* [a2 y] [b1 x]) (* [a2 z] [b2 x]) (* [a2 w] [b3 x])))
|
||||
(+= [sum2 y] (+ (* [a2 x] [b0 y]) (* [a2 y] [b1 y]) (* [a2 z] [b2 y]) (* [a2 w] [b3 y])))
|
||||
(+= [sum2 z] (+ (* [a2 x] [b0 z]) (* [a2 y] [b1 z]) (* [a2 z] [b2 z]) (* [a2 w] [b3 z])))
|
||||
(+= [sum2 w] (+ (* [a2 x] [b0 w]) (* [a2 y] [b1 w]) (* [a2 z] [b2 w]) (* [a2 w] [b3 w])))
|
||||
|
||||
(+= [sum3 x] (+ (* [a3 x] [b0 x]) (* [a3 y] [b1 x]) (* [a3 z] [b2 x]) (* [a3 w] [b3 x])))
|
||||
(+= [sum3 y] (+ (* [a3 x] [b0 y]) (* [a3 y] [b1 y]) (* [a3 z] [b2 y]) (* [a3 w] [b3 y])))
|
||||
(+= [sum3 z] (+ (* [a3 x] [b0 z]) (* [a3 y] [b1 z]) (* [a3 z] [b2 z]) (* [a3 w] [b3 z])))
|
||||
(+= [sum3 w] (+ (* [a3 x] [b0 w]) (* [a3 y] [b1 w]) (* [a3 z] [b2 w]) (* [a3 w] [b3 w]))))
|
||||
|
||||
(= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 0) m))] sum0)
|
||||
(= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 1) m))] sum1)
|
||||
(= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 2) m))] sum2)
|
||||
(= [C (+ [pos y] (* (+ (<< [pos x] TILEY_SHIFT) 3) m))] sum3))
|
|
@ -20,8 +20,10 @@
|
|||
|
||||
(do-tests "bv-tests.rkt" "BV SDSL - General"
|
||||
"fsm3-tests.rkt" "FSM"
|
||||
"ifc3-tests.rkt" "IFC"
|
||||
"synthcl3-tests.rkt" "SynthCL"
|
||||
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth")
|
||||
"ifc3-tests.rkt" "IFC")
|
||||
(do-tests "synthcl3-tests.rkt" "SynthCL"
|
||||
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth"
|
||||
"synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify"
|
||||
"synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify")
|
||||
(do-tests "bv-ref-tests.rkt" "BV SDSL - Hacker's Delight synthesis")
|
||||
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
#lang s-exp "../../../rosette/synthcl3.rkt"
|
||||
(require "../../rackunit-typechecking.rkt")
|
||||
|
||||
; A buggy reference implementation for square matrix multiplication.
|
||||
; Multiplies two square matrices A and B, where the dimension of A is
|
||||
; n x p and dimension of B is p x m. Both matrices are given as
|
||||
; flat arrays in row-major form. The output is the matrix C = A*B,
|
||||
; also given in row-major form.
|
||||
(procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m])
|
||||
(: int* C)
|
||||
(= C ((int*) (malloc (* n m (sizeof int)))))
|
||||
(for [(: int i in (range n))
|
||||
(: int j in (range m))
|
||||
(: int k in (range 1 p))] ; seeded bug
|
||||
(+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
|
||||
C)
|
||||
|
||||
; A host implementation of matrix multiplication.
|
||||
(procedure int* (mmulHost [char* kernelName] [int typeLen]
|
||||
[int* A] [int* B] [int n] [int p] [int m])
|
||||
(: cl_context context)
|
||||
(: cl_command_queue command_queue)
|
||||
(: cl_program program)
|
||||
(: cl_kernel kernel)
|
||||
(: cl_mem buffer_A buffer_B buffer_C)
|
||||
(: int* C)
|
||||
(: int[2] global)
|
||||
(: int dimA dimB dimC)
|
||||
|
||||
(= [global 0] (/ n typeLen))
|
||||
(= [global 1] (/ m typeLen))
|
||||
(= dimA (* n p (sizeof int)))
|
||||
(= dimB (* p m (sizeof int)))
|
||||
(= dimC (* n m (sizeof int)))
|
||||
|
||||
(= C ((int*) (malloc dimC)))
|
||||
|
||||
(= context (clCreateContext))
|
||||
|
||||
(= command_queue (clCreateCommandQueue context))
|
||||
|
||||
(= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
|
||||
(= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
|
||||
(= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
|
||||
|
||||
(= program (clCreateProgramWithSource context "matrix-verify-kernel.rkt"))
|
||||
|
||||
(clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
|
||||
(clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
|
||||
|
||||
(= kernel (clCreateKernel program kernelName))
|
||||
(clSetKernelArg kernel 0 buffer_A)
|
||||
(clSetKernelArg kernel 1 buffer_B)
|
||||
(clSetKernelArg kernel 2 buffer_C)
|
||||
(clSetKernelArg kernel 3 p)
|
||||
(clSetKernelArg kernel 4 m)
|
||||
|
||||
(clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
|
||||
(clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
|
||||
C)
|
||||
; A scalar parallel implementation of matrix multiplication.
|
||||
(procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m])
|
||||
(mmulHost "mmulScalarKernel" 1 A B n p m))
|
||||
|
||||
; A vector parallel implementation of matrix multiplication. The dimensions
|
||||
; n and m must be evenly divisible by 4.
|
||||
(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m])
|
||||
(mmulHost "mmulVectorKernel" 4 A B n p m))
|
||||
|
||||
; An optimized vector parallel implementation of matrix multiplication. The dimensions
|
||||
; n and m must be evenly divisible by 4.
|
||||
(procedure int* (mmulVectorOpt [int* A] [int* B] [int n] [int p] [int m])
|
||||
(mmulHost "mmulVectorKernelOpt" 4 A B n p m))
|
||||
|
||||
; Given two arrays of the same size, checks that they hold the same
|
||||
; values at each index.
|
||||
(procedure void (check [int* actual] [int* expected] [int SIZE])
|
||||
(assert (>= SIZE 0))
|
||||
(for [(: int i in (range SIZE))]
|
||||
(assert (== [actual i] [expected i]))))
|
||||
|
||||
(procedure void (verify_scalar [int from] [int to])
|
||||
(verify #:forall [(: int n in (range from to))
|
||||
(: int p in (range from to))
|
||||
(: int m in (range from to))
|
||||
(: int[(* n p)] A)
|
||||
(: int[(* p m)] B)]
|
||||
#:ensure (check (mmulScalar A B n p m)
|
||||
(mmulSequential A B n p m)
|
||||
(* n m))))
|
||||
|
||||
(procedure void (verify_vector [int from] [int to])
|
||||
(verify #:forall [(: int n in (range from to 4))
|
||||
(: int p in (range from to 4))
|
||||
(: int m in (range from to 4))
|
||||
(: int[(* n p)] A)
|
||||
(: int[(* p m)] B)]
|
||||
#:ensure (check (mmulVector A B n p m)
|
||||
(mmulSequential A B n p m)
|
||||
(* n m))))
|
||||
|
||||
(procedure void (verify_vector_opt [int from] [int to])
|
||||
(verify #:forall [(: int n in (range from to 4))
|
||||
(: int p in (range from to 4))
|
||||
(: int m in (range from to 4))
|
||||
(: int[(* n p)] A)
|
||||
(: int[(* p m)] B)]
|
||||
#:ensure (check (mmulVectorOpt A B n p m)
|
||||
(mmulSequential A B n p m)
|
||||
(* n m))))
|
||||
(check-type
|
||||
(with-output-to-string (λ () (verify_scalar 1 5)))
|
||||
: CString -> "counterexample found:\nn = 1\np = 1\nm = 1\nA = #x0#(-1)\nB = #x1#(1)\n")
|
||||
(check-type
|
||||
(with-output-to-string (λ () (verify_vector 4 9)))
|
||||
: CString -> "counterexample found:\nn = 4\np = 4\nm = 4\nA = #x5#(3 0 0 0 3 0 0 0 98355 0 0 0 98307 0 0 0)\nB = #x6#(-1431655765 -1431655765 1431661227 0 0 0 0 0 0 0 0 0 0 0 0 0)\n")
|
||||
(check-type
|
||||
(with-output-to-string (λ () (verify_vector_opt 4 9)))
|
||||
: CString -> "counterexample found:\nn = 4\np = 4\nm = 4\nA = #xa#(3 0 0 0 3 0 0 0 98355 0 0 0 98307 0 0 0)\nB = #xb#(-1431655765 -1431655765 1431661227 0 0 0 0 0 0 0 0 0 0 0 0 0)\n")
|
||||
|
||||
;(: int n p m)
|
||||
;(= n 8) (= p 4) (= m 4)
|
||||
;(: int[(* n p)] A) (: int[(* p m)] B)
|
||||
;(mmulVector A B n p m)
|
||||
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
#lang s-exp "../../../rosette/synthcl3.rkt"
|
||||
(require "../../rackunit-typechecking.rkt")
|
||||
|
||||
; The reference implementation for square matrix multiplication.
|
||||
; Multiplies two square matrices A and B, where the dimension of A is
|
||||
; n x p and dimension of B is p x m. Both matrices are given as
|
||||
; flat arrays in row-major form. The output is the matrix C = A*B,
|
||||
; also given in row-major form.
|
||||
(procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m])
|
||||
(: int* C)
|
||||
(= C ((int*) (malloc (* n m (sizeof int)))))
|
||||
(for [(: int i in (range n))
|
||||
(: int j in (range m))
|
||||
(: int k in (range p))]
|
||||
(+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
|
||||
C)
|
||||
|
||||
; A host implementation of matrix multiplication.
|
||||
(procedure int* (mmulHost [char* kernelName] [int typeLen]
|
||||
[int* A] [int* B] [int n] [int p] [int m])
|
||||
(: cl_context context)
|
||||
(: cl_command_queue command_queue)
|
||||
(: cl_program program)
|
||||
(: cl_kernel kernel)
|
||||
(: cl_mem buffer_A buffer_B buffer_C)
|
||||
(: int* C)
|
||||
(: int[2] global)
|
||||
(: int dimA dimB dimC)
|
||||
|
||||
(= [global 0] (/ n typeLen))
|
||||
(= [global 1] (/ m typeLen))
|
||||
(= dimA (* n p (sizeof int)))
|
||||
(= dimB (* p m (sizeof int)))
|
||||
(= dimC (* n m (sizeof int)))
|
||||
|
||||
(= C ((int*) (malloc dimC)))
|
||||
|
||||
(= context (clCreateContext))
|
||||
|
||||
(= command_queue (clCreateCommandQueue context))
|
||||
|
||||
(= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
|
||||
(= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
|
||||
(= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
|
||||
|
||||
(= program (clCreateProgramWithSource context "matrix-verify-kernel.rkt"))
|
||||
|
||||
(clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
|
||||
(clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
|
||||
|
||||
(= kernel (clCreateKernel program kernelName))
|
||||
(clSetKernelArg kernel 0 buffer_A)
|
||||
(clSetKernelArg kernel 1 buffer_B)
|
||||
(clSetKernelArg kernel 2 buffer_C)
|
||||
(clSetKernelArg kernel 3 p)
|
||||
(clSetKernelArg kernel 4 m)
|
||||
|
||||
(clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
|
||||
(clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
|
||||
C)
|
||||
; A scalar parallel implementation of matrix multiplication.
|
||||
(procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m])
|
||||
(mmulHost "mmulScalarKernel" 1 A B n p m))
|
||||
|
||||
; A vector parallel implementation of matrix multiplication. The dimensions
|
||||
; n and m must be evenly divisible by 4.
|
||||
(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m])
|
||||
(mmulHost "mmulVectorKernel" 4 A B n p m))
|
||||
|
||||
; An optimized vector parallel implementation of matrix multiplication. The dimensions
|
||||
; n and m must be evenly divisible by 4.
|
||||
(procedure int* (mmulVectorOpt [int* A] [int* B] [int n] [int p] [int m])
|
||||
(mmulHost "mmulVectorKernelOpt" 4 A B n p m))
|
||||
|
||||
; Given two arrays of the same size, checks that they hold the same
|
||||
; values at each index.
|
||||
(procedure void (check [int* actual] [int* expected] [int SIZE])
|
||||
(assert (>= SIZE 0))
|
||||
(for [(: int i in (range SIZE))]
|
||||
(assert (== [actual i] [expected i]))))
|
||||
|
||||
(procedure void (verify_scalar [int from] [int to])
|
||||
(verify #:forall [(: int n in (range from to))
|
||||
(: int p in (range from to))
|
||||
(: int m in (range from to))
|
||||
(: int[(* n p)] A)
|
||||
(: int[(* p m)] B)]
|
||||
#:ensure (check (mmulScalar A B n p m)
|
||||
(mmulSequential A B n p m)
|
||||
(* n m))))
|
||||
|
||||
(procedure void (verify_vector [int from] [int to])
|
||||
(verify #:forall [(: int n in (range from to 4))
|
||||
(: int p in (range from to 4))
|
||||
(: int m in (range from to 4))
|
||||
(: int[(* n p)] A)
|
||||
(: int[(* p m)] B)]
|
||||
#:ensure (check (mmulVector A B n p m)
|
||||
(mmulSequential A B n p m)
|
||||
(* n m))))
|
||||
|
||||
(procedure void (verify_vector_opt [int from] [int to])
|
||||
(verify #:forall [(: int n in (range from to 4))
|
||||
(: int p in (range from to 4))
|
||||
(: int m in (range from to 4))
|
||||
(: int[(* n p)] A)
|
||||
(: int[(* p m)] B)]
|
||||
#:ensure (check (mmulVectorOpt A B n p m)
|
||||
(mmulSequential A B n p m)
|
||||
(* n m))))
|
||||
(check-type
|
||||
(with-output-to-string (λ () (verify_scalar 1 5)))
|
||||
: CString -> "no counterexample found\n")
|
||||
(check-type
|
||||
(with-output-to-string (λ () (verify_vector 4 9)))
|
||||
: CString -> "no counterexample found\n")
|
||||
(check-type
|
||||
(with-output-to-string (λ () (verify_vector_opt 4 9)))
|
||||
: CString -> "no counterexample found\n")
|
||||
|
||||
;(: int n p m)
|
||||
;(= n 8) (= p 4) (= m 4)
|
||||
;(: int[(* n p)] A) (: int[(* p m)] B)
|
||||
;(mmulVector A B n p m)
|
||||
|
Loading…
Reference in New Issue
Block a user