fix some bugs; matrix synth tests working
- pointer selection should have type of ptr base - fix more invalid vectors, created by to-X, and : - add cl fns
This commit is contained in:
parent
f1f43697f9
commit
b5e8a7bceb
|
@ -56,7 +56,7 @@
|
|||
(local-expand
|
||||
e 'expression
|
||||
(list #'ro:#%app #'ro:choose #'ro:synthesize #'ro:let #'ro:in-value
|
||||
#'ro:assert #'ro:if)))
|
||||
#'ro:assert #'ro:if #'ro:?)))
|
||||
; (displayln (stx->datum e+))
|
||||
e+)
|
||||
(define (mk-ro:-id id) (format-id id "ro:~a" id))
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#lang turnstile
|
||||
(extends "rosette3.rkt" #:except ! #%app || && void = * + - #%datum if assert verify) ; typed rosette
|
||||
(extends "rosette3.rkt" #:except ! #%app || && void = * + - / #%datum if assert verify) ; typed rosette
|
||||
(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped
|
||||
racket/stxparam
|
||||
(prefix-in ro: (combine-in rosette rosette/lib/synthax))
|
||||
|
@ -8,7 +8,10 @@
|
|||
sdsl/synthcl/model/operators sdsl/synthcl/model/errors
|
||||
sdsl/synthcl/model/memory sdsl/synthcl/model/runtime
|
||||
sdsl/synthcl/model/work sdsl/synthcl/model/pointers
|
||||
sdsl/synthcl/lang/queries))
|
||||
sdsl/synthcl/lang/queries sdsl/synthcl/model/context
|
||||
sdsl/synthcl/model/queue sdsl/synthcl/model/buffer
|
||||
sdsl/synthcl/model/flags sdsl/synthcl/model/program
|
||||
sdsl/synthcl/model/kernel))
|
||||
(for-syntax (prefix-in cl: sdsl/synthcl/lang/util)))
|
||||
|
||||
(begin-for-syntax
|
||||
|
@ -19,16 +22,31 @@
|
|||
procedure kernel grammar #%datum if range for print
|
||||
choose locally-scoped assert synth verify
|
||||
int int2 int3 int4 int16 float float2 float3 float4 float16
|
||||
bool void void* char* float* int* int16* int2*
|
||||
bool void void* char* float* int* int2* int3* int4* int16*
|
||||
cl_context cl_command_queue cl_program cl_kernel cl_mem
|
||||
: ! ?: == + * - || &&
|
||||
% << ; int ops
|
||||
= += -= %= ; assignment ops
|
||||
sizeof
|
||||
sizeof clCreateProgramWithSource
|
||||
(typed-out
|
||||
;[with-output-to-string : (C→ (C→ Any) char*)]
|
||||
[clCreateContext : (C→ cl_context)]
|
||||
[clCreateCommandQueue : (C→ cl_context cl_command_queue)]
|
||||
[clCreateBuffer : (C→ cl_context int int cl_mem)]
|
||||
[clEnqueueReadBuffer : (C→ cl_command_queue cl_mem int int void* void)]
|
||||
|
||||
[clEnqueueWriteBuffer : (C→ cl_command_queue cl_mem int int void* void)]
|
||||
[clEnqueueNDRangeKernel : (C→ cl_command_queue cl_kernel int int* int* int* void)]
|
||||
[clCreateKernel : (C→ cl_program char* cl_kernel)]
|
||||
[clSetKernelArg : (Ccase-> (C→ cl_kernel int cl_mem void)
|
||||
(C→ cl_kernel int int void)
|
||||
(C→ cl_kernel int float void))]
|
||||
[get_global_id : (C→ int int)]
|
||||
[CL_MEM_READ_ONLY : int]
|
||||
[CL_MEM_WRITE_ONLY : int]
|
||||
[/ : (Ccase-> (C→ int int int) (C→ float float float))]
|
||||
[malloc : (C→ int void*)]
|
||||
[get_work_dim : (C→ int)]
|
||||
|
||||
[!= : (Ccase-> (C→ CNum CNum CBool)
|
||||
(C→ CNum CNum CNum CBool)
|
||||
(C→ Num Num Bool)
|
||||
|
@ -63,9 +81,8 @@
|
|||
(define common-real-type
|
||||
(case-lambda
|
||||
[(t) (and (real-type? t) t)]
|
||||
[(t1 t2) (cond [(real-type<=? t1 t2) t2]
|
||||
[(real-type<=? t2 t1) t1]
|
||||
[else #f])]
|
||||
[(t1 t2) (or (and (real-type<=? t1 t2) t2)
|
||||
(and (real-type<=? t2 t1) t1))]
|
||||
[ts (common-real-type (car ts) (apply common-real-type (cdr ts)))]))
|
||||
|
||||
;; implements common-real-type from model/reals.rkt
|
||||
|
@ -96,6 +113,8 @@
|
|||
(type->str from) (type->str to)
|
||||
#;(if (contract? to) (contract-name to) to))
|
||||
expr subexpr)))
|
||||
(define (mk-ptr id) (format-id id "~a*" id))
|
||||
(define (mk-mk id) (format-id id "mk-~a" id))
|
||||
(define (add-convert stx fn)
|
||||
(set-stx-prop/preserved stx 'convert fn))
|
||||
(define (get-convert stx)
|
||||
|
@ -114,6 +133,8 @@
|
|||
((current-type-eval)
|
||||
(datum->syntax ctx
|
||||
(string->symbol (car (regexp-match #px"[a-z]+" (type->str ty)))))))
|
||||
(define (get-pointer-base ty [ctx #'here])
|
||||
(datum->syntax ctx (string->symbol (string-trim (type->str ty) "*"))))
|
||||
(define (vector-type? ty)
|
||||
(ty->len ty)) ; TODO: check and not pointer-type?
|
||||
(define (scalar-type? ty)
|
||||
|
@ -146,113 +167,74 @@
|
|||
(ro:define (mk-int v)
|
||||
(ro:#%app cl:int v))
|
||||
|
||||
(ro:define (to-int16* v)
|
||||
(cl:pointer-cast v cl:int16))
|
||||
(ro:define (to-int2* v)
|
||||
(cl:pointer-cast v cl:int2))
|
||||
(ro:define (to-int* v)
|
||||
(cl:pointer-cast v cl:int))
|
||||
(ro:define (to-float* v)
|
||||
(cl:pointer-cast v cl:float))
|
||||
|
||||
(define-named-type-alias bool
|
||||
(add-convertm rosette3:Bool to-bool))
|
||||
(define-named-type-alias int
|
||||
(add-convertm rosette3:Int to-int))
|
||||
(define-named-type-alias float
|
||||
(add-convertm rosette3:Num to-float))
|
||||
(define-named-type-alias char*
|
||||
(add-convertm rosette3:CString (λ (x) x)))
|
||||
(define-type-constructor Pointer #:arity = 1)
|
||||
;(define-named-type-alias void rosette3:CUnit)
|
||||
(define-base-types void cl_context cl_command_queue cl_program cl_kernel cl_mem)
|
||||
(define-named-type-alias void* (add-convertm (Pointer void) (λ (x) x)))
|
||||
(define-named-type-alias bool (add-convertm rosette3:Bool to-bool))
|
||||
(define-named-type-alias int (add-convertm rosette3:Int to-int))
|
||||
(define-named-type-alias int* (add-convertm (Pointer int) to-int*))
|
||||
(define-named-type-alias float (add-convertm rosette3:Num to-float))
|
||||
(define-named-type-alias float* (add-convertm (Pointer float) to-float*))
|
||||
(define-named-type-alias char* (add-convertm rosette3:CString (λ (x) x)))
|
||||
|
||||
(define-syntax (define-int stx)
|
||||
(syntax-parse stx
|
||||
[(_ n)
|
||||
#:with intn (format-id #'n "int~a" (syntax->datum #'n))
|
||||
#:with to-intn (format-id #'n "to-~a" #'intn)
|
||||
#:with mk-intn (format-id #'n "mk-~a" #'intn)
|
||||
#:with cl-mk-intn (mk-cl #'intn)
|
||||
#:with (x ...) (generate-temporaries
|
||||
(build-list (syntax->datum #'n) (lambda (x) x)))
|
||||
#:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...))
|
||||
#'(begin
|
||||
(define-named-type-alias intn
|
||||
(add-constructm
|
||||
(add-convertm
|
||||
(rosette3:CVector I ...)
|
||||
to-intn)
|
||||
mk-intn))
|
||||
(ro:define (to-intn v)
|
||||
(ro:cond
|
||||
[(ro:list? v)
|
||||
(ro:apply ro:vector-immutable
|
||||
(ro:for/list ([i n]) (to-int (ro:list-ref v i))))]
|
||||
[(ro:vector? v)
|
||||
(ro:apply ro:vector-immutable
|
||||
(ro:for/list ([i n]) (to-int (ro:vector-ref v i))))]
|
||||
[else
|
||||
(ro:apply ro:vector-immutable
|
||||
(ro:make-list n (to-int v)))]))
|
||||
(ro:define (mk-intn x ...)
|
||||
(ro:#%app cl-mk-intn x ...)
|
||||
#;(ro:#%app ro:vector-immutable (to-int x) ...))
|
||||
)]))
|
||||
(syntax-parse stx
|
||||
[(_ n)
|
||||
#:with intn (format-id #'n "int~a" (syntax->datum #'n))
|
||||
#:with intn* (mk-ptr #'intn)
|
||||
#:with to-intn (format-id #'n "to-~a" #'intn)
|
||||
#:with mk-intn (mk-mk #'intn)
|
||||
#:with to-intn* (mk-ptr #'to-intn)
|
||||
#:with mk-intn* (mk-ptr #'mk-intn)
|
||||
#:with cl-mk-intn (mk-cl #'intn)
|
||||
#:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values))
|
||||
#:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...))
|
||||
#'(begin
|
||||
(define-named-type-alias intn
|
||||
(add-constructm (add-convertm (rosette3:CVector I ...) to-intn) mk-intn))
|
||||
(define-named-type-alias intn* (add-convertm (Pointer intn) to-intn*))
|
||||
(ro:define (to-intn v)
|
||||
(ro:cond
|
||||
[(ro:list? v)
|
||||
(ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:list-ref v i))))]
|
||||
[(ro:vector? v)
|
||||
(ro:apply mk-intn (ro:for/list ([i n]) (ro:#%app to-int (ro:vector-ref v i))))]
|
||||
[else (ro:apply mk-intn (ro:make-list n (ro:#%app to-int v)))]))
|
||||
(ro:define (to-intn* v) (cl:pointer-cast v cl-mk-intn))
|
||||
(ro:define (mk-intn x ...) (ro:#%app cl-mk-intn x ...)))]))
|
||||
(define-simple-macro (define-ints n ...) (begin (define-int n) ...))
|
||||
(define-ints 2 3 4 16)
|
||||
|
||||
(define-syntax (define-float stx)
|
||||
(syntax-parse stx
|
||||
[(_ n)
|
||||
#:with floatn (format-id #'n "float~a" (syntax->datum #'n))
|
||||
#:with to-floatn (format-id #'n "to-~a" #'floatn)
|
||||
#:with mk-floatn (format-id #'n "mk-~a" #'floatn)
|
||||
#:with cl-mk-floatn (mk-cl #'floatn)
|
||||
#:with (x ...) (generate-temporaries
|
||||
(build-list (syntax->datum #'n) (lambda (x) x)))
|
||||
#:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...))
|
||||
#'(begin
|
||||
(define-named-type-alias floatn
|
||||
(add-constructm
|
||||
(add-convertm
|
||||
(rosette3:CVector I ...)
|
||||
to-floatn)
|
||||
mk-floatn))
|
||||
(ro:define (to-floatn v)
|
||||
(ro:cond
|
||||
[(ro:list? v)
|
||||
(ro:apply ro:vector-immutable
|
||||
(ro:for/list ([i n]) (to-float (ro:list-ref v i))))]
|
||||
[(ro:vector? v)
|
||||
(ro:apply ro:vector-immutable
|
||||
(ro:for/list ([i n]) (to-float (ro:vector-ref v i))))]
|
||||
[else
|
||||
(ro:apply ro:vector-immutable
|
||||
(ro:make-list n (to-float v)))]))
|
||||
(ro:define (mk-floatn x ...)
|
||||
(ro:#%app cl-mk-floatn x ...)
|
||||
#;(ro:#%app ro:vector-immutable (to-float x) ...))
|
||||
)]))
|
||||
(syntax-parse stx
|
||||
[(_ n)
|
||||
#:with floatn (format-id #'n "float~a" (syntax->datum #'n))
|
||||
#:with to-floatn (format-id #'n "to-~a" #'floatn)
|
||||
#:with mk-floatn (mk-mk #'floatn)
|
||||
#:with cl-mk-floatn (mk-cl #'floatn)
|
||||
#:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values))
|
||||
#:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...))
|
||||
#'(begin
|
||||
(define-named-type-alias floatn
|
||||
(add-constructm
|
||||
(add-convertm (rosette3:CVector I ...) to-floatn) mk-floatn))
|
||||
(ro:define (to-floatn v)
|
||||
(ro:cond
|
||||
[(ro:list? v)
|
||||
(ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:list-ref v i))))]
|
||||
[(ro:vector? v)
|
||||
(ro:apply mk-floatn (ro:for/list ([i n]) (ro:#%app to-float (ro:vector-ref v i))))]
|
||||
[else (ro:apply mk-floatn (ro:make-list n (ro:#%app to-float v)))]))
|
||||
(ro:define (mk-floatn x ...) (ro:#%app cl-mk-floatn x ...)))]))
|
||||
(define-simple-macro (define-floats n ...) (begin (define-float n) ...))
|
||||
(define-floats 2 3 4 16)
|
||||
|
||||
|
||||
(define-type-constructor Pointer #:arity = 1)
|
||||
;(define-named-type-alias void rosette3:CUnit)
|
||||
(define-base-type void)
|
||||
#;(begin-for-syntax
|
||||
(define-syntax ~void*
|
||||
(pattern-expander
|
||||
(make-variable-like-transformer #'(~and t:type (~parse ~void #'t.norm))))))
|
||||
(define-named-type-alias void*
|
||||
(add-convertm (Pointer void) (λ (x) x)))
|
||||
(define-named-type-alias int*
|
||||
(add-convertm (Pointer int) to-int*))
|
||||
(define-named-type-alias int16*
|
||||
(add-convertm (Pointer int16) to-int16*))
|
||||
(define-named-type-alias int2*
|
||||
(add-convertm (Pointer int2) to-int2*))
|
||||
(define-named-type-alias float*
|
||||
(add-convertm (Pointer float) to-float*))
|
||||
|
||||
(define-typed-syntax synth-app
|
||||
[(_ (ty:type) e) ≫ ; cast
|
||||
[⊢ e ≫ e- ⇒ ty-e]
|
||||
|
@ -282,13 +264,8 @@
|
|||
[⊢ ptr ≫ ptr- ⇒ ty-ptr]
|
||||
#:when (pointer-type? #'ty-ptr) #:with ~! #'dummy ; commit
|
||||
[⊢ sel ≫ sel- ⇐ int]
|
||||
#:do [(define split-ty (ty->len #'ty-ptr))]
|
||||
#:when (and split-ty (= 3 (length split-ty)))
|
||||
#:do [(define base-str (cadr split-ty))
|
||||
(define len-str (caddr split-ty))]
|
||||
#:with ty-out ((current-type-eval) (format-id #'h "~a~a" base-str len-str))
|
||||
--------
|
||||
[⊢ (cl:pointer-ref ptr- sel-) ⇒ ty-out]]
|
||||
[⊢ (cl:pointer-ref ptr- sel-) ⇒ #,(get-pointer-base #'ty-ptr)]]
|
||||
[(_ vec sel) ≫ ; applying vector to one arg is selector
|
||||
[⊢ vec ≫ vec- ⇒ ty-vec]
|
||||
#:when (vector-type? #'ty-vec)
|
||||
|
@ -342,13 +319,13 @@
|
|||
(define- f-
|
||||
(lambda- (x ...)
|
||||
(rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...)
|
||||
(⊢m (let- () e ... (rosette3:ann e-body : ty-out)) ty-out)))))]])
|
||||
(⊢m (ro:let () e ... (rosette3:ann e-body : ty-out)) ty-out))))
|
||||
(provide- f))]])
|
||||
(define-typed-syntax kernel
|
||||
[(_ ty-out:type (f [ty:type x:id] ...) e ...) ≫
|
||||
#:fail-unless (void? #'ty-out.norm)
|
||||
(format "expected void, given ~a" (type->str #'ty-out.norm))
|
||||
--------
|
||||
[≻ (procedure void (f [ty x] ...) e ...)]])
|
||||
--- [≻ (procedure void (f [ty x] ...) e ...)]])
|
||||
(define-typed-syntax grammar
|
||||
[(_ ty-out:type (f [ty:type x:id] ...) e) ≫
|
||||
#:with f- (generate-temporary #'f)
|
||||
|
@ -369,23 +346,21 @@
|
|||
(define-typed-syntax if
|
||||
[(_ test {then ...} {else ...}) ≫
|
||||
--------
|
||||
[⊢ (ro:if (to-bool test)
|
||||
[⊢ (ro:if (ro:#%app to-bool test)
|
||||
(ro:let () then ... (ro:void))
|
||||
(ro:let () else ... (ro:void))) ⇒ void]]
|
||||
[(_ test {then ...}) ≫
|
||||
--------
|
||||
[≻ (if test {then ...} {})]])
|
||||
--- [≻ (if test {then ...} {})]])
|
||||
|
||||
(define-typed-syntax (range e ...) ≫
|
||||
[⊢ e ≫ e- ⇐ int] ...
|
||||
--------
|
||||
[⊢ (ro:#%app ro:in-range e- ...) ⇒ int])
|
||||
--- [⊢ (ro:#%app ro:in-range e- ...) ⇒ int])
|
||||
(define-typed-syntax for
|
||||
[(_ [((~literal :) ty:type var:id (~datum in) rangeExpr) ...] e ...) ≫
|
||||
[[var ≫ var- : ty.norm] ... ⊢ [e ≫ e- ⇒ ty-e] ...]
|
||||
[(_ [((~literal :) ty:type x:id (~datum in) rangeExpr) ...] e ...) ≫
|
||||
--------
|
||||
[⊢ (ro:for* ([var- rangeExpr] ...)
|
||||
e- ... (ro:void)) ⇒ void]])
|
||||
[⊢ (ro:for* ([x rangeExpr] ...)
|
||||
(rosette3:let ([x (⊢m x ty)] ...)
|
||||
(⊢m (ro:let () e ... (ro:void)) void))) ⇒ void]])
|
||||
|
||||
|
||||
;; need to redefine #%datum because rosette3:#%datum is too precise
|
||||
|
@ -433,10 +408,11 @@
|
|||
(format "no pred for ~a" (type->str #'ty))
|
||||
#:with (x- ...) (generate-temporaries #'(x ...))
|
||||
#:with (x-- ...) (generate-temporaries #'(x ...))
|
||||
#:with mk-ty (format-id #'here "mk-~a" #'ty)
|
||||
--------
|
||||
[≻ (begin-
|
||||
(ro:define-symbolic* x-- pred [#,(string->number len-str)]) ...
|
||||
(ro:define x- (ro:apply ro:vector-immutable x--)) ...
|
||||
(ro:define x- (ro:apply mk-ty x--)) ...
|
||||
(define-syntax- x
|
||||
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]]
|
||||
[(_ ty:type [len] x:id ...) ≫ ; array of vector types
|
||||
|
@ -457,7 +433,7 @@
|
|||
[≻ (begin-
|
||||
(ro:define-symbolic* x-- pred [len base-len]) ...
|
||||
(ro:define x-
|
||||
(ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))])
|
||||
(ro:let ([*x (ro:#%app to-ty* (cl:malloc (ro:* len base-len)))])
|
||||
(ro:for ([i len][v x--])
|
||||
(cl:pointer-set! *x i (ro:apply mk-ty v)))
|
||||
*x)) ...
|
||||
|
@ -506,12 +482,12 @@
|
|||
#:with ty-base ((current-type-eval) (datum->syntax #'e (string->symbol out-base-str)))
|
||||
#:with base-convert (get-convert #'ty-base)
|
||||
-------
|
||||
[⊢ (convert
|
||||
(ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)])
|
||||
[⊢ (ro:#%app convert
|
||||
(ro:let ([a (ro:#%app convert e-)][b (ro:#%app convert e1-)][c (ro:#%app convert e2-)])
|
||||
(ro:for/list ([idx #,(string->number out-len-str)])
|
||||
(ro:if (ro:< (ro:vector-ref a idx) 0)
|
||||
(base-convert (ro:vector-ref b idx))
|
||||
(base-convert (ro:vector-ref c idx))))))
|
||||
(ro:#%app base-convert (ro:vector-ref b idx))
|
||||
(ro:#%app base-convert (ro:vector-ref c idx))))))
|
||||
⇒ ty-out]]
|
||||
[(_ ~! e e1 e2) ≫ ; should be scalar and real
|
||||
[⊢ e ≫ e- ⇒ ty]
|
||||
|
@ -535,10 +511,10 @@
|
|||
[⊢ x ≫ x- ⇒ ty-x]
|
||||
[⊢ e ≫ e- ⇒ ty-e]
|
||||
#:fail-unless (cast-ok? #'ty-e #'ty-x stx)
|
||||
(format "cannot cast ~a to ~a"
|
||||
(type->str #'ty-e) (type->str #'ty-x))
|
||||
(format "cannot cast ~a to ~a" (type->str #'ty-e) (type->str #'ty-x))
|
||||
#:with conv (get-convert #'ty-x)
|
||||
--------
|
||||
[⊢ (ro:set! x- (synth-app (ty-x) e-)) ⇒ void]]
|
||||
[⊢ (ro:set! x- #,(if (syntax-e #'conv) #'(ro:#%app conv e-) #'e-)) ⇒ void]]
|
||||
;; selector can be list of numbers or up to wxyz for vectors of length <=4
|
||||
[(_ [x:id sel] e) ≫
|
||||
[⊢ x ≫ x- ⇒ ty-x]
|
||||
|
@ -561,11 +537,11 @@
|
|||
[(_ e) ≫
|
||||
[⊢ e ≫ e- ⇐ bool]
|
||||
--------
|
||||
[⊢ (cl:! e-) ⇒ bool]]
|
||||
[⊢ (ro:#%app cl:! e-) ⇒ bool]]
|
||||
[(_ e) ≫ ; else try to coerce
|
||||
[⊢ e ≫ e- ⇒ ty]
|
||||
--------
|
||||
[⊢ (cl:! (synth-app (bool) e-)) ⇒ bool]])
|
||||
[⊢ (ro:#%app cl:! (ro:#%app to-bool e-)) ⇒ bool]])
|
||||
|
||||
;; TODO: this should produce int-vector result?
|
||||
(define-typed-syntax ==
|
||||
|
@ -576,7 +552,7 @@
|
|||
#:when (real-type? #'ty2)
|
||||
#:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len
|
||||
--------
|
||||
[⊢ (to-int (cl:== e1- e2-)) ⇒ int]])
|
||||
[⊢ (ro:#%app to-int (cl:== e1- e2-)) ⇒ int]])
|
||||
|
||||
(define-simple-macro (define-bool-ops o ...+) (ro:begin (define-bool-op o) ...))
|
||||
(define-simple-macro (define-bool-op name)
|
||||
|
@ -588,8 +564,7 @@
|
|||
--------
|
||||
[⊢ (name- e1- e2-) ⇒ bool]]
|
||||
[(_ e1 e2) ≫ ; else try to coerce
|
||||
--------
|
||||
[⊢ (name- (synth-app (bool) e1) (synth-app (bool) e2)) ⇒ bool]]))
|
||||
--- [⊢ (name- (ro:#%app to-bool e1) (ro:#%app to-bool e2)) ⇒ bool]]))
|
||||
|
||||
(define-simple-macro (define-real-ops o ...) (ro:begin (define-real-op o) ...))
|
||||
(define-simple-macro (define-real-op name (~optional (~seq #:extra-check p?)
|
||||
|
@ -597,32 +572,31 @@
|
|||
#:with name- (mk-cl #'name)
|
||||
#:with name= (format-id #'name "~a=" #'name) ; assignment form
|
||||
(begin-
|
||||
(define-typed-syntax (name e1 e2) ≫
|
||||
[⊢ e1 ≫ e1- ⇒ ty1]
|
||||
[⊢ e2 ≫ e2- ⇒ ty2]
|
||||
#:with ty-out (common-real-type #'ty1 #'ty2)
|
||||
(define-typed-syntax (name e (... ...)) ≫
|
||||
[⊢ e ≫ e- ⇒ ty] (... ...)
|
||||
#:with ty-out (apply common-real-type (stx->list #'(ty (... ...))))
|
||||
#:fail-unless (syntax-e #'ty-out)
|
||||
(format "no common real type for operands; given ~a, ~a"
|
||||
(type->str #'ty1) (type->str #'ty2))
|
||||
#:when (p? #'ty-out #'ty1 #'ty2)
|
||||
(format "no common real type for operands; given ~a"
|
||||
(types->str #'(ty (... ...))))
|
||||
#:when (p? #'ty-out #'(ty (... ...)))
|
||||
#:with convert (get-convert #'ty-out)
|
||||
#:with ty-base (get-base #'ty-out)
|
||||
#:with base-convert (get-convert #'ty-base)
|
||||
#:with (x (... ...)) (generate-temporaries #'(e (... ...)))
|
||||
--------
|
||||
[⊢ #,(if (scalar-type? #'ty-out)
|
||||
#'(convert (name- (convert e1-) (convert e2-)))
|
||||
#'(convert (ro:let ([a (convert e1-)][b (convert e2-)])
|
||||
(ro:for/list ([v1 a][v2 b])
|
||||
(base-convert (name- v1 v2)))))) ⇒ ty-out])
|
||||
#'(ro:#%app convert (name- (convert e-) (... ...)))
|
||||
#'(ro:#%app convert (ro:let ([x (ro:#%app convert e-)] (... ...))
|
||||
(ro:for/list ([x x] (... ...))
|
||||
(ro:#%app base-convert (name- x (... ...))))))) ⇒ ty-out])
|
||||
(define-typed-syntax (name= x e) ≫
|
||||
--------
|
||||
[≻ (= x (name x e))])))
|
||||
--- [≻ (= x (name x e))])))
|
||||
|
||||
(define-for-syntax (int? t given1 given2)
|
||||
(define-for-syntax (int? t givens)
|
||||
(or (typecheck/un? t #'int)
|
||||
(raise-syntax-error #f
|
||||
(format "no common integer type for operands; given ~a, ~a"
|
||||
(type->str given1) (type->str given2)))))
|
||||
(format "no common integer type for operands; given ~a"
|
||||
(types->str givens)))))
|
||||
(define-simple-macro (define-int-op o) (define-real-op o #:extra-check int?))
|
||||
(define-simple-macro (define-int-ops o ...) (ro:begin (define-int-op o) ...))
|
||||
|
||||
|
@ -631,12 +605,10 @@
|
|||
(define-int-ops % <<)
|
||||
|
||||
(define-typerule (sizeof t:type) >>
|
||||
----------
|
||||
[⊢ #,(real-type-length #'t.norm) ⇒ int])
|
||||
--- [⊢ #,(real-type-length #'t.norm) ⇒ int])
|
||||
|
||||
(define-typerule (print e ...) >>
|
||||
----------
|
||||
[⊢ (ro:begin (display e) ...) ⇒ void])
|
||||
--- [⊢ (ro:begin (display e) ...) ⇒ void])
|
||||
|
||||
(define-typed-syntax choose
|
||||
[(ch e ...+) ≫
|
||||
|
@ -656,26 +628,35 @@
|
|||
|
||||
(define-for-syntax (decl->seq stx)
|
||||
(syntax-parse stx
|
||||
[((~datum :) type id (~datum in) rangeExpr)
|
||||
(syntax/loc stx (id rangeExpr type))]
|
||||
[((~datum :) type id)
|
||||
(syntax/loc stx (id (ro:in-value (ro:let () (: type id) id)) type))]))
|
||||
[((~datum :) ty:type id (~datum in) rangeExpr)
|
||||
(syntax/loc stx (id rangeExpr ty.norm))]
|
||||
[((~datum :) ty:type [len] id)
|
||||
#:with tyout (mk-ptr #'ty)
|
||||
(syntax/loc stx (id (ro:in-value (ro:let () (: ty [len] id) id)) tyout))]
|
||||
[((~datum :) ty id)
|
||||
(syntax/loc stx (id (ro:in-value (ro:let () (: ty id) id)) ty))]))
|
||||
|
||||
(define-typed-syntax (synth #:forall [decl ...] #:ensure e) ≫
|
||||
(define-typed-syntax synth
|
||||
[(_ #:forall [decl ...] #:bitwidth bw #:ensure e) ≫
|
||||
#:with ([id seq ty] ...) (stx-map decl->seq #'(decl ...))
|
||||
#:with (id- ...) (generate-temporaries #'(id ...))
|
||||
#:with (typed-seq ...) #'((with-ctx ([id id- ty] ...) seq) ...)
|
||||
#:with (tmp ...) (generate-temporaries #'(id ...))
|
||||
--------
|
||||
[⊢ (ro:let ([id- 1] ...) ; dummy, enables simplifying stx template
|
||||
[⊢ (ro:let ([id- 1] ...) ; dummy ensuring id- bound, simplifies stx template
|
||||
(ro:define-values (tmp ...)
|
||||
(ro:for*/lists (tmp ...) ([id- typed-seq] ...) (ro:values id- ...)))
|
||||
(ro:parameterize ([ro:term-cache (ro:hash-copy (ro:term-cache))])
|
||||
(ro:parameterize ([ro:current-bitwidth bw]
|
||||
[ro:term-cache (ro:hash-copy (ro:term-cache))])
|
||||
(ro:print-forms
|
||||
(ro:synthesize
|
||||
#:forall (ro:append tmp ...)
|
||||
#:guarantee (ro:for ([id- tmp] ...)
|
||||
(with-ctx ([id id- ty] ...) e)))))) ⇒ void])
|
||||
(with-ctx ([id id- ty] ...) e)))))) ⇒ void]]
|
||||
[(_ #:forall [decl ...] #:ensure e) ≫
|
||||
--- [≻ (synth #:forall [decl ...] #:bitwidth 8 #:ensure e)]])
|
||||
|
||||
|
||||
|
||||
(define-typed-syntax verify
|
||||
[(vfy #:forall [decl ...] #:ensure e) ≫
|
||||
|
@ -684,7 +665,8 @@
|
|||
#:with (typed-seq ...) #'((with-ctx ([id id- ty] ...) seq) ...)
|
||||
--------
|
||||
[⊢ (ro:let ([id- 1] ...) ; dummy, enables simplifying stx template
|
||||
(ro:parameterize ([ro:term-cache (ro:hash-copy (ro:term-cache))])
|
||||
(ro:parameterize ([ro:current-bitwidth 32]
|
||||
[ro:term-cache (ro:hash-copy (ro:term-cache))])
|
||||
(ro:for*/or ([id- typed-seq] ...)
|
||||
(ro:define cex (with-ctx ([id id- ty] ...) (ro:verify e)))
|
||||
(ro:and (ro:sat? cex)
|
||||
|
@ -693,6 +675,12 @@
|
|||
(printf "~a = ~a\n" i (ro:evaluate i- cex))))))) ⇒ void]])
|
||||
|
||||
(define-typed-syntax (assert e) ≫
|
||||
#:with e- (expand/ro #'e)
|
||||
--------
|
||||
[⊢ (ro:assert (to-bool e-)) ⇒ void])
|
||||
--- [⊢ (ro:assert (ro:#%app to-bool #,(expand/ro #'e))) ⇒ void])
|
||||
|
||||
(define- (/ x y)
|
||||
(cond- [(zero?- y) 0]
|
||||
[(integer?- x) (quotient- x y)]
|
||||
[else (/- x y)]))
|
||||
|
||||
(define-typed-syntax (clCreateProgramWithSource ctx f) ≫
|
||||
--- [⊢ (cl:clCreateProgramWithSource ctx f) ⇒ cl_program])
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
#lang s-exp "../../../rosette/synthcl3.rkt"
|
||||
|
||||
; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix.
|
||||
(kernel void (mmulScalarKernel [int* A] [int* B] [int* C] [int p] [int m])
|
||||
(: int i j sum)
|
||||
(= i (get_global_id 0))
|
||||
(= j (get_global_id 1))
|
||||
(= sum 0)
|
||||
(for [(: int k in (range p))]
|
||||
(+= sum (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
|
||||
(= [C (+ (* i m) j)] sum))
|
||||
|
||||
; Matrix multiplication C = A * B, where A is an n x p matrix and B is a p x m matrix.
|
||||
(kernel void (mmulVectorKernel [int4* A] [int4* B] [int4* C] [int p] [int m])
|
||||
(: int i j)
|
||||
(: int4 sum0 sum1 sum2 sum3)
|
||||
|
||||
(= i (get_global_id 0))
|
||||
(= j (get_global_id 1))
|
||||
(= sum0 0)
|
||||
(= sum1 0)
|
||||
(= sum2 0)
|
||||
(= sum3 0)
|
||||
|
||||
(for [(: int k in (range 0 p 4))]
|
||||
(: int4 a0 a1 a2 a3)
|
||||
(: int4 b0 b1 b2 b3)
|
||||
|
||||
(= a0 [A (indexA 0 i k p)])
|
||||
(= a1 [A (indexA 1 i k p)])
|
||||
(= a2 [A (indexA 2 i k p)])
|
||||
(= a3 [A (indexA 3 i k p)])
|
||||
|
||||
(= b0 [B (indexB 0 k j m)])
|
||||
(= b1 [B (indexB 1 k j m)])
|
||||
(= b2 [B (indexB 2 k j m)])
|
||||
(= b3 [B (indexB 3 k j m)])
|
||||
|
||||
(+= sum0 (computeSum a0 b0 b1 b2 b3))
|
||||
(+= sum1 (computeSum a1 b0 b1 b2 b3))
|
||||
(+= sum2 (computeSum a2 b0 b1 b2 b3))
|
||||
(+= sum3 (computeSum a3 b0 b1 b2 b3)))
|
||||
(= [C (indexC 0 i j m)] sum0)
|
||||
(= [C (indexC 1 i j m)] sum1)
|
||||
(= [C (indexC 2 i j m)] sum2)
|
||||
(= [C (indexC 3 i j m)] sum3))
|
||||
|
||||
; Multiplies the 1x4 vector a by the 4x4 matrix with rows b0, b1, b2 and b3.
|
||||
(procedure int4 (computeSum [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
|
||||
(int4
|
||||
(+ (* [a x] [b0 x]) (* [a y] [b1 x]) (* [a z] [b2 x]) (* [a w] [b3 x]))
|
||||
(+ (* [a x] [b0 y]) (* [a y] [b1 y]) (* [a z] [b2 y]) (* [a w] [b3 y]))
|
||||
(+ (* [a x] [b0 z]) (* [a y] [b1 z]) (* [a z] [b2 z]) (* [a w] [b3 z]))
|
||||
(+ (* [a x] [b0 w]) (* [a y] [b1 w]) (* [a z] [b2 w]) (* [a w] [b3 w]))))
|
||||
|
||||
; Function for indexing into the matrix A.
|
||||
(procedure int (indexA [int off] [int i] [int k] [int p])
|
||||
; (+ (* (+ (* i 4) off) (/ p 4)) (/ k 4)))
|
||||
(: int r c w)
|
||||
(= r (+ (choose i (/ i 4) (* i 4)) (choose off 0)))
|
||||
(= c (+ (choose k (/ k 4) (* k 4)) (choose off 0)))
|
||||
(= w (+ (choose p (/ p 4) (* p 4)) (choose off 0)))
|
||||
(+ (* r w) c))
|
||||
|
||||
; Function for indexing into the matrix B.
|
||||
(procedure int (indexB [int off] [int k] [int j] [int m])
|
||||
;(+ (* (+ k off) (/ m 4)) j))
|
||||
(: int r c w)
|
||||
(= r (+ (choose k (/ k 4) (* k 4)) (choose off 0)))
|
||||
(= c (+ (choose j (/ j 4) (* j 4)) (choose off 0)))
|
||||
(= w (+ (choose m (/ m 4) (* m 4)) (choose off 0)))
|
||||
(+ (* r w) c))
|
||||
|
||||
; Function for indexing into the matrix C.
|
||||
(procedure int (indexC [int off] [int i] [int j] [int m])
|
||||
;(+ (* (+ (* i 4) off) (/ m 4)) j))
|
||||
(: int r c w)
|
||||
(= r (+ (choose i (/ i 4) (* i 4)) (choose off 0)))
|
||||
(= c (+ (choose j (/ j 4) (* j 4)) (choose off 0)))
|
||||
(= w (+ (choose m (/ m 4) (* m 4)) (choose off 0)))
|
||||
(+ (* r w) c))
|
||||
|
||||
|
||||
; Bad sketch and completions:
|
||||
; Function for indexing into the matrix A.
|
||||
;(procedure int (indexA [int off] [int i] [int k] [int p])
|
||||
; (+ (* (+ (/ i 4) off) (/ p 4)) (/ k 4)))
|
||||
;(: int r c w)
|
||||
;(= r (+ (choose i (/ i 4)) (choose off 0)))
|
||||
;(= c (+ (choose k (/ k 4)) (choose off 0)))
|
||||
;(= w (+ (choose p (/ p 4)) (choose off 0)))
|
||||
;(+ (* r w) c))
|
||||
|
||||
; Function for indexing into the matrix B.
|
||||
;(procedure int (indexB [int off] [int k] [int j] [int m])
|
||||
;(+ (* (+ k off) (/ m 4)) j))
|
||||
;(: int r c w)
|
||||
;(= r (+ (choose k (/ k 4)) (choose off 0)))
|
||||
;(= c (+ (choose j (/ j 4)) (choose off 0)))
|
||||
;(= w (+ (choose m (/ m 4)) (choose off 0)))
|
||||
;(+ (* r w) c))
|
||||
|
||||
; Function for indexing into the matrix C.
|
||||
;(procedure int (indexC [int off] [int i] [int j] [int m])
|
||||
;(+ (* (+ (/ i 4) off) (/ m 4)) j))
|
||||
;(: int r c w)
|
||||
;(= r (+ (choose i (/ i 4)) (choose off 0)))
|
||||
;(= c (+ (choose j (/ j 4)) (choose off 0)))
|
||||
;(= w (+ (choose m (/ m 4)) (choose off 0)))
|
||||
;(+ (* r w) c))
|
|
@ -3,98 +3,104 @@
|
|||
(prefix-in cl: sdsl/synthcl/lang/main)
|
||||
(prefix-in ro: (rename-in rosette [#%app a])))
|
||||
|
||||
(: int2 [3] xs)
|
||||
xs
|
||||
(check-type xs : int2*)
|
||||
|
||||
(: int [4] xs2)
|
||||
xs2
|
||||
(check-type xs2 : int*)
|
||||
|
||||
|
||||
|
||||
;; ; The reference implementation for square matrix multiplication.
|
||||
;; ; Multiplies two squre matrices A and B, where the dimension of A is
|
||||
;; ; n x p and dimension of B is p x m. Both matrices are given as
|
||||
;; ; flat arrays in row-major form. The output is the matrix C = A*B,
|
||||
;; ; also given in row-major form.
|
||||
;; (procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m])
|
||||
;; (: int* C)
|
||||
;; (= C ((int*) (malloc (* n m (sizeof int)))))
|
||||
;; (for [(: int i in (range n))
|
||||
;; (: int j in (range m))
|
||||
;; (: int k in (range p))]
|
||||
;; (+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
|
||||
;; C)
|
||||
; The reference implementation for square matrix multiplication.
|
||||
; Multiplies two squre matrices A and B, where the dimension of A is
|
||||
; n x p and dimension of B is p x m. Both matrices are given as
|
||||
; flat arrays in row-major form. The output is the matrix C = A*B,
|
||||
; also given in row-major form.
|
||||
(procedure int* (mmulSequential [int* A] [int* B] [int n] [int p] [int m])
|
||||
(: int* C)
|
||||
(= C ((int*) (malloc (* n m (sizeof int)))))
|
||||
(for [(: int i in (range n))
|
||||
(: int j in (range m))
|
||||
(: int k in (range p))]
|
||||
(+= [C (+ (* i m) j)] (* [A (+ (* i p) k)] [B (+ (* k m) j)])))
|
||||
C)
|
||||
|
||||
; A host implementation of matrix multiplication.
|
||||
(procedure int* (mmulHost [char* kernelName] [int typeLen]
|
||||
[int* A] [int* B] [int n] [int p] [int m])
|
||||
;; (: cl_context context)
|
||||
;; (: cl_command_queue command_queue)
|
||||
;; (: cl_program program)
|
||||
;; (: cl_kernel kernel)
|
||||
;; (: cl_mem buffer_A buffer_B buffer_C)
|
||||
(: int* C)
|
||||
;; (: int[2] global)
|
||||
;; (: int dimA dimB dimC)
|
||||
|
||||
;; (= [global 0] (/ n typeLen))
|
||||
;; (= [global 1] (/ m typeLen))
|
||||
;; (= dimA (* n p (sizeof int)))
|
||||
;; (= dimB (* p m (sizeof int)))
|
||||
;; (= dimC (* n m (sizeof int)))
|
||||
|
||||
;; (= C ((int*) (malloc dimC)))
|
||||
|
||||
;; (= context (clCreateContext))
|
||||
|
||||
;; (= command_queue (clCreateCommandQueue context))
|
||||
|
||||
;; (= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
|
||||
;; (= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
|
||||
;; (= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
|
||||
|
||||
;; (= program (clCreateProgramWithSource context "kernel.rkt"))
|
||||
|
||||
;; (clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
|
||||
;; (clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
|
||||
|
||||
;; (= kernel (clCreateKernel program kernelName))
|
||||
;; (clSetKernelArg kernel 0 buffer_A)
|
||||
;; (clSetKernelArg kernel 1 buffer_B)
|
||||
;; (clSetKernelArg kernel 2 buffer_C)
|
||||
;; (clSetKernelArg kernel 3 p)
|
||||
;; (clSetKernelArg kernel 4 m)
|
||||
(: cl_context context)
|
||||
(: cl_command_queue command_queue)
|
||||
(: cl_program program)
|
||||
(: cl_kernel kernel)
|
||||
(: cl_mem buffer_A buffer_B buffer_C)
|
||||
|
||||
;; (clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
|
||||
;; (clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
|
||||
(: int* C)
|
||||
(: int[2] global)
|
||||
(: int dimA dimB dimC)
|
||||
(= [global 0] (/ n typeLen))
|
||||
(= [global 1] (/ m typeLen))
|
||||
(= dimA (* n p (sizeof int)))
|
||||
(= dimB (* p m (sizeof int)))
|
||||
(= dimC (* n m (sizeof int)))
|
||||
|
||||
(= C ((int*) (malloc dimC)))
|
||||
|
||||
(= context (clCreateContext))
|
||||
|
||||
(= command_queue (clCreateCommandQueue context))
|
||||
|
||||
(= buffer_A (clCreateBuffer context CL_MEM_READ_ONLY dimA))
|
||||
(= buffer_B (clCreateBuffer context CL_MEM_READ_ONLY dimB))
|
||||
(= buffer_C (clCreateBuffer context CL_MEM_WRITE_ONLY dimC))
|
||||
|
||||
(= program (clCreateProgramWithSource context "matrix-synth-kernel.rkt"))
|
||||
|
||||
(clEnqueueWriteBuffer command_queue buffer_A 0 dimA A)
|
||||
(clEnqueueWriteBuffer command_queue buffer_B 0 dimB B)
|
||||
|
||||
(= kernel (clCreateKernel program kernelName))
|
||||
(clSetKernelArg kernel 0 buffer_A)
|
||||
(clSetKernelArg kernel 1 buffer_B)
|
||||
(clSetKernelArg kernel 2 buffer_C)
|
||||
(clSetKernelArg kernel 3 p)
|
||||
(clSetKernelArg kernel 4 m)
|
||||
|
||||
(clEnqueueNDRangeKernel command_queue kernel 2 NULL global NULL)
|
||||
(clEnqueueReadBuffer command_queue buffer_C 0 dimC C)
|
||||
C)
|
||||
;; ; A scalar parallel implementation of matrix multiplication.
|
||||
;; (procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m])
|
||||
;; (mmulHost "mmulScalarKernel" 1 A B n p m))
|
||||
; A scalar parallel implementation of matrix multiplication.
|
||||
(procedure int* (mmulScalar [int* A] [int* B] [int n] [int p] [int m])
|
||||
(mmulHost "mmulScalarKernel" 1 A B n p m))
|
||||
|
||||
; A vector parallel implementation of matrix multiplication. The dimensions
|
||||
; n and m must be evenly divisible by 4.
|
||||
#;(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m])
|
||||
(procedure int* (mmulVector [int* A] [int* B] [int n] [int p] [int m])
|
||||
(mmulHost "mmulVectorKernel" 4 A B n p m))
|
||||
|
||||
;; ; Given two arrays of the same size, checks that they hold the same
|
||||
;; ; values at each index.
|
||||
;; (procedure void (check [int* actual] [int* expected] [int SIZE])
|
||||
;; (assert (>= SIZE 0))
|
||||
;; (for [(: int i in (range SIZE))]
|
||||
;; (assert (== [actual i] [expected i]))))
|
||||
; Given two arrays of the same size, checks that they hold the same
|
||||
; values at each index.
|
||||
(procedure void (check [int* actual] [int* expected] [int SIZE])
|
||||
(assert (>= SIZE 0))
|
||||
(for [(: int i in (range SIZE))]
|
||||
(assert (== [actual i] [expected i]))))
|
||||
|
||||
;; (procedure void (synth_vector [int size])
|
||||
;; (synth #:forall [(: int n in (range size (+ 1 size)))
|
||||
;; (: int p in (range size (+ 1 size)))
|
||||
;; (: int m in (range size (+ 1 size)))
|
||||
;; (: int[(* n p)] A)
|
||||
;; (: int[(* p m)] B)]
|
||||
;; #:ensure (check (mmulVector A B n p m)
|
||||
;; (mmulSequential A B n p m)
|
||||
;; (* n m))))
|
||||
(procedure void (synth_vector [int size])
|
||||
(synth #:forall [(: int n in (range size (+ 1 size)))
|
||||
(: int p in (range size (+ 1 size)))
|
||||
(: int m in (range size (+ 1 size)))
|
||||
(: int[(* n p)] A)
|
||||
(: int[(* p m)] B)]
|
||||
#:ensure (check (mmulVector A B n p m)
|
||||
(mmulSequential A B n p m)
|
||||
(* n m))))
|
||||
|
||||
;; ; (synth_vector 4) ; 20 sec
|
||||
;; ; (synth_vector 8) ; 252 sec
|
||||
(: int n)
|
||||
(= n 4)
|
||||
(: int p)
|
||||
(= p 4)
|
||||
(: int m)
|
||||
(= m 4)
|
||||
(: int[(* n p)] A)
|
||||
(: int[(* p m)] B)
|
||||
(check-type (mmulVector A B n p m) : int*)
|
||||
(check-type (mmulSequential A B n p m) : int*)
|
||||
|
||||
(check-type
|
||||
(with-output-to-string
|
||||
(λ ()
|
||||
(synth_vector 4))) ; 20 sec
|
||||
: CString
|
||||
-> "/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt:57:0\n'(procedure\n int\n (indexA (int off) (int i) (int k) (int p))\n (: int r c w)\n (= r (+ (choose i (/ i 4) (* i 4)) off))\n (= c (+ (choose k (/ k 4) (* k 4)) 0))\n (= w (+ (/ p 4) 0))\n (+ (* r w) c))\n/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt:66:0\n'(procedure\n int\n (indexB (int off) (int k) (int j) (int m))\n (: int r c w)\n (= r (+ (choose k (/ k 4) (* k 4)) off))\n (= c (+ (choose j (/ j 4) (* j 4)) 0))\n (= w (+ (/ m 4) 0))\n (+ (* r w) c))\n/home/stchang/NEU_Research/macrotypes/turnstile/examples/tests/rosette/rosette3/matrix-synth-kernel.rkt:75:0\n'(procedure\n int\n (indexC (int off) (int i) (int j) (int m))\n (: int r c w)\n (= r (+ (choose i (/ i 4) (* i 4)) off))\n (= c (+ (choose j (/ j 4) (* j 4)) 0))\n (= w (+ (/ m 4) 0))\n (+ (* r w) c))\n")
|
||||
;(synth_vector 8) ; 252 sec
|
||||
|
|
|
@ -345,3 +345,36 @@
|
|||
{(assert k)}))))
|
||||
: CString
|
||||
-> "counterexample found:\nt = 2\nk = 0\np = 4\n")
|
||||
|
||||
(: int2 [3] xs)
|
||||
(check-type xs : int2*)
|
||||
|
||||
(: int [4] xs2)
|
||||
(check-type xs2 : int*)
|
||||
|
||||
; basic matrix multiplying
|
||||
(: int4 sum0 sum1 sum2 sum3)
|
||||
(= sum0 0)
|
||||
(= sum1 0)
|
||||
(= sum2 0)
|
||||
(= sum3 0)
|
||||
|
||||
(procedure int (computeSum1 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
|
||||
(+ (* [a x] [b0 x]) (* [a y] [b1 x]) (* [a z] [b2 x]) (* [a w] [b3 x])))
|
||||
(procedure int (computeSum2 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
|
||||
(+ (* [a x] [b0 y]) (* [a y] [b1 y]) (* [a z] [b2 y]) (* [a w] [b3 y])))
|
||||
(procedure int (computeSum3 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
|
||||
(+ (* [a x] [b0 z]) (* [a y] [b1 z]) (* [a z] [b2 z]) (* [a w] [b3 z])))
|
||||
(procedure int (computeSum4 [int4 a] [int4 b0] [int4 b1] [int4 b2] [int4 b3])
|
||||
(+ (* [a x] [b0 w]) (* [a y] [b1 w]) (* [a z] [b2 w]) (* [a w] [b3 w])))
|
||||
|
||||
(check-type (computeSum1 sum0 sum0 sum1 sum2 sum3) : int -> 0)
|
||||
(check-type (computeSum2 sum0 sum0 sum1 sum2 sum3) : int -> 0)
|
||||
(check-type (computeSum3 sum0 sum0 sum1 sum2 sum3) : int -> 0)
|
||||
(check-type (computeSum4 sum0 sum0 sum1 sum2 sum3) : int -> 0)
|
||||
(check-type (int4 (computeSum1 sum0 sum0 sum1 sum2 sum3)
|
||||
(computeSum2 sum0 sum0 sum1 sum2 sum3)
|
||||
(computeSum3 sum0 sum0 sum1 sum2 sum3)
|
||||
(computeSum4 sum0 sum0 sum1 sum2 sum3))
|
||||
: int4
|
||||
-> (ro:a ro:vector-immutable 0 0 0 0))
|
||||
|
|
Loading…
Reference in New Issue
Block a user