add macro to define int and float types; other cleanup

This commit is contained in:
Stephen Chang 2016-12-08 15:13:08 -05:00
parent f0cf86e9bd
commit 25a66ae60e
4 changed files with 140 additions and 188 deletions

View File

@ -0,0 +1,30 @@
#lang racket
(require (for-syntax syntax/stx racket/syntax) syntax/parse/define)
;; in general, must work with forms and fns from model/, since lang/ often
;; uses (synthcl) type-directed macros (and not typed rosette types)
(define-for-syntax (mk-model-path m) (format-id m "sdsl/synthcl/model/~a" m))
(define-simple-macro (require+provide/synthcl/model x ...)
#:with (m ...) (stx-map mk-model-path #'(x ...))
(begin (require (combine-in m ...))
(provide (combine-out (all-from-out m) ...))))
(require+provide/synthcl/model buffer
context
errors
flags
kernel
memory
operators
pointers
program
queue
reals
runtime
work)
(require (for-syntax (only-in sdsl/synthcl/lang/util parse-selector))
(only-in sdsl/synthcl/lang/forms ??))
(provide (for-syntax parse-selector) ??)

View File

@ -1,120 +1,77 @@
#lang turnstile
(extends "rosette3.rkt" #:except ! #%app || && void = * + - / #%datum if assert verify < <= > >=) ; typed rosette
(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped
racket/stxparam
(prefix-in ro: (combine-in rosette rosette/lib/synthax))
(prefix-in cl: (combine-in sdsl/synthcl/model/operators
sdsl/synthcl/lang/forms sdsl/synthcl/model/reals
sdsl/synthcl/model/errors sdsl/synthcl/model/kernel
sdsl/synthcl/model/memory sdsl/synthcl/model/runtime
sdsl/synthcl/model/work sdsl/synthcl/model/pointers
sdsl/synthcl/lang/queries sdsl/synthcl/model/context
sdsl/synthcl/model/queue sdsl/synthcl/model/buffer
sdsl/synthcl/model/flags sdsl/synthcl/model/program))
(for-syntax (prefix-in cl: sdsl/synthcl/lang/util)))
(require (prefix-in ro: (combine-in rosette rosette/lib/synthax))
(prefix-in cl: "synthcl-model.rkt"))
(begin-for-syntax
(define (mk-cl id) (format-id #'here "cl:~a" id))
(current-host-lang mk-cl))
(provide (rename-out [synth-app #%app])
procedure kernel grammar #%datum if range for print
choose ?? @ locally-scoped assert synth verify
int int2 int3 int4 int16 float float2 float3 float4 float16
bool void void* char*
int* int2* int3* int4* int16* float* float2* float3* float4* float16*
cl_context cl_command_queue cl_program cl_kernel cl_mem
: ! ?: == + * / - sqrt || &&
% << $ & > >= < <= ; int ops
= += -= *= /= %= $= &= ; assignment ops
sizeof clCreateProgramWithSource
(typed-out
[clCreateContext : (C→ cl_context)]
[clCreateCommandQueue : (C→ cl_context cl_command_queue)]
[clCreateBuffer : (C→ cl_context int int cl_mem)]
[clEnqueueReadBuffer : (C→ cl_command_queue cl_mem int int void* void)]
(provide
(rename-out [synth-app #%app])
procedure kernel grammar #%datum if range for print
choose ?? @ locally-scoped assert synth verify
int int2 int3 int4 int16 float float2 float3 float4 float16
int* int2* int3* int4* int16* float* float2* float3* float4* float16*
bool void void* char*
cl_context cl_command_queue cl_program cl_kernel cl_mem
: ! ?: == + * / - sqrt || && % << $ & > >= < <= != = += -= *= /= %= $= &=
sizeof clCreateProgramWithSource
(typed-out
[clCreateContext : (C→ cl_context)]
[clCreateCommandQueue : (C→ cl_context cl_command_queue)]
[clCreateBuffer : (C→ cl_context int int cl_mem)]
[clEnqueueReadBuffer : (C→ cl_command_queue cl_mem int int void* void)]
[clEnqueueWriteBuffer : (C→ cl_command_queue cl_mem int int void* void)]
[clEnqueueNDRangeKernel : (C→ cl_command_queue cl_kernel int int* int* int* void)]
[clCreateKernel : (C→ cl_program char* cl_kernel)]
[clSetKernelArg : (Ccase-> (C→ cl_kernel int cl_mem void)
(C→ cl_kernel int int void)
(C→ cl_kernel int float void))]
[get_global_id : (C→ int int)]
[CL_MEM_READ_ONLY : int] [CL_MEM_WRITE_ONLY : int] [CL_MEM_READ_WRITE : int]
[malloc : (C→ int void*)]
[memset : (C→ void* int int void*)]
[convert_float4 : (Ccase-> (C→ int4 float4) (C→ float4 float4))]
[convert_int4 : (Ccase-> (C→ int4 int4) (C→ float4 int4))]
[get_work_dim : (C→ int)]
[NULL : void*]))
[clEnqueueWriteBuffer : (C→ cl_command_queue cl_mem int int void* void)]
[clEnqueueNDRangeKernel : (C→ cl_command_queue cl_kernel int int* int* int* void)]
[clCreateKernel : (C→ cl_program char* cl_kernel)]
[clSetKernelArg : (Ccase-> (C→ cl_kernel int cl_mem void)
(C→ cl_kernel int int void)
(C→ cl_kernel int float void))]
[get_global_id : (C→ int int)]
[CL_MEM_READ_ONLY : int]
[CL_MEM_WRITE_ONLY : int]
[CL_MEM_READ_WRITE : int]
[malloc : (C→ int void*)]
[memset : (C→ void* int int void*)]
[convert_float4 : (Ccase-> (C→ int4 float4) (C→ float4 float4))]
[convert_int4 : (Ccase-> (C→ int4 int4) (C→ float4 int4))]
[get_work_dim : (C→ int)]
[!= : (Ccase-> (C→ CNum CNum CBool)
(C→ CNum CNum CNum CBool)
(C→ Num Num Bool)
(C→ Num Num Num Bool))]
[NULL : void*]))
(begin-for-syntax
;; TODO: use equality type relation instead of subtype
;; - require reimplementing many more things, eg #%datum, +, etc
;; typecheck unexpanded types
(define (typecheck/un? t1 t2)
(typecheck? ((current-type-eval) t1)
((current-type-eval) t2)))
(current-typecheck-relation (current-type=?)) ; no subtyping
(define (typecheck/un? t1 t2) ; typecheck unexpanded types
(typecheck? ((current-type-eval) t1) ((current-type-eval) t2)))
(define (pointer-type? t) (Pointer? t))
(define (real-type? t)
(and #;(not (typecheck/un? t #'bool))
(not (typecheck/un? t #'char*))
(not (pointer-type? t))))
(define (pointer-type? t)
(Pointer? t)
#;(regexp-match #px"\\*$" (type->str t)))
(and (not (pointer-type? t)) (not (typecheck/un? t #'char*))))
(define (real-type<=? t1 t2)
(and (real-type? t1) (real-type? t2)
(or ((current-type=?) t1 t2) ; need type= to distinguish reals/ints
(or (typecheck? t1 t2) ; need type= to distinguish reals/ints
(typecheck/un? t1 #'bool)
(and (typecheck/un? t1 #'int)
(not (typecheck/un? t2 #'bool)))
(and (typecheck/un? t1 #'int) (not (typecheck/un? t2 #'bool)))
(and (typecheck/un? t1 #'float)
(typecheck/un? (get-base t2) #'float)))))
; Returns the common real type of the given types, as specified in
; Ch. 6.2.6 of opencl-1.2 specification. If there is no common
; real type, returns #f.
(typecheck/un? (get-base/un t2) #'float)))))
;; same as common-real-type from model/reals.rkt
;; Returns the common real type of the given types, as specified in
;; Ch. 6.2.6 of opencl-1.2 specification. returns #f if none
(define common-real-type
(case-lambda
[(t) (and (real-type? t) t)]
[(t1 t2) (or (and (real-type<=? t1 t2) t2)
(and (real-type<=? t2 t1) t1))]
[ts (common-real-type (car ts) (apply common-real-type (cdr ts)))]))
;; implements common-real-type from model/reals.rkt
; Returns the common real type of the given types, as specified in
; Ch. 6.2.6 of opencl-1.2 specification. If there is no common
; real type, returns #f.
(current-join common-real-type)
;; copied from check-implicit-conversion in lang/types.rkt
;; TODO: this should not exception since it is used in stx-parse
;; clauses that may want to backtrack
;; TODO: should this exn? it is used in stx-parse that may want to backtrack
(define (cast-ok? from to [expr #f] [subexpr #f])
#;(printf "casting ~a to ~a: ~a\n" (type->str from) (type->str to)
(or (typecheck/un? from to)
(and (scalar-type? from) (scalar-type? to))
(and (scalar-type? from) (vector-type? to))
(and (pointer-type? from) (pointer-type? to))
#;(and (equal? from cl_mem) (pointer-type? to))))
(unless (if #t #;(type? to)
(or (typecheck/un? from to)
(and (scalar-type? from) (scalar-type? to))
(and (scalar-type? from) (vector-type? to))
(and (pointer-type? from) (pointer-type? to))
#;(and (equal? from cl_mem) (pointer-type? to)))
(to from))
(raise-syntax-error
#f
(format "no implicit conversion from ~a to ~a"
(type->str from) (type->str to)
#;(if (contract? to) (contract-name to) to))
expr subexpr)))
(unless (or (typecheck/un? from to)
(and (scalar-type? from) (scalar-type? to))
(and (scalar-type? from) (vector-type? to))
(and (pointer-type? from) (pointer-type? to))
#;(and (equal? from cl_mem) (pointer-type? to)))
(raise-syntax-error #f
(format "no implicit conversion from ~a to ~a"
(type->str from) (type->str to)) expr subexpr)))
(define (mk-ptr id) (format-id id "~a*" id))
(define (mk-mk id) (format-id id "mk-~a" id))
(define (mk-to id) (format-id id "to-~a" id))
@ -128,11 +85,11 @@
(define split-ty (ty->len t))
(string->number
(or (and split-ty (third split-ty)) "1")))
(define (get-base ty [ctx #'here])
((current-type-eval)
(datum->syntax ctx
(string->symbol (car (regexp-match #px"[a-z]+" (type->str ty)))))))
(define (get-pointer-base ty [ctx #'here])
(define (get-base/un ty [ctx #'here]) ; returns unexpanded base type
(datum->syntax ctx
(string->symbol (car (regexp-match #px"[a-z]+" (type->str ty))))))
(define (get-base ty [ctx #'here]) ((current-type-eval) (get-base/un ty ctx)))
(define (get-pointer-base ty [ctx #'here]) ; returns unexpanded ptr base
(datum->syntax ctx (string->symbol (string-trim (type->str ty) "*"))))
(define (vector-type? ty)
(define tstr (type->str ty))
@ -144,91 +101,57 @@
(define-syntax-parser add-convertm [(_ stx fn) (add-convert #'stx #'fn)])
(define-syntax-parser add-constructm [(_ stx fn) (add-construct #'stx #'fn)])
(ro:define (to-bool v) ; TODO: reuse impls in model/reals.rkt ?
(ro:cond [(ro:boolean? v) v]
[(ro:number? v) (ro:! (ro:= 0 v))]
[else (cl:raise-conversion-error v "bool")]))
(ro:define (to-int v)
(ro:cond [(ro:boolean? v) (ro:if v 1 0)]
[(ro:fixnum? v) v]
[(ro:flonum? v) (ro:exact-truncate v)]
[else (ro:real->integer v)]))
(ro:define (to-float v)
(ro:cond [(ro:boolean? v) (ro:if v 1.0 0.0)]
[(ro:fixnum? v) (ro:exact->inexact v)]
[(ro:flonum? v) v]
[else (ro:type-cast ro:real? v)]))
(ro:define (mk-int v) (ro:#%app cl:int v))
(ro:define (mk-float v) (ro:#%app cl:float v))
(ro:define (to-int* v) (cl:pointer-cast v cl:int))
(ro:define (to-float* v) (cl:pointer-cast v cl:float))
(ro:define (to-bool v) (ro:#%app (ro:#%app cl:bool) v))
(define-type-constructor Pointer #:arity = 1)
;(define-named-type-alias void rosette3:CUnit)
(define-base-types void cl_context cl_command_queue cl_program cl_kernel cl_mem)
(define-named-type-alias void* (add-convertm (Pointer void) (λ (x) x)))
(define-named-type-alias bool (add-convertm rosette3:Bool to-bool))
(define-named-type-alias int (add-convertm rosette3:Int to-int))
(define-named-type-alias int* (add-convertm (Pointer int) to-int*))
(define-named-type-alias float (add-convertm rosette3:Num to-float))
(define-named-type-alias float* (add-convertm (Pointer float) to-float*))
(define-named-type-alias char* (add-convertm rosette3:CString (λ (x) x)))
(define-type-constructor Pointer #:arity = 1)
(define-named-type-alias void* (Pointer void))
(define-named-type-alias char* rosette3:CString)
(define-named-type-alias bool (add-convertm rosette3:Bool to-bool))
(define-syntax (define-int stx)
(syntax-parse stx
[(_ n)
#:with intn (format-id #'n "int~a" (syntax->datum #'n))
#:with intn* (mk-ptr #'intn)
#:with to-intn (format-id #'n "to-~a" #'intn)
#:with mk-intn (mk-mk #'intn)
#:with to-intn* (mk-ptr #'to-intn)
#:with mk-intn* (mk-ptr #'mk-intn)
#:with cl-mk-intn (mk-cl #'intn)
#:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values))
#:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...))
#'(begin
(define-named-type-alias intn
(add-constructm (add-convertm (rosette3:CVector I ...) to-intn) mk-intn))
(define-named-type-alias intn* (add-convertm (Pointer intn) to-intn*))
(ro:define (to-intn v)
(ro:cond
[(ro:list? v)
(ro:apply mk-intn (ro:for/list ([i n]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply mk-intn (ro:for/list ([i n]) (to-int (ro:vector-ref v i))))]
[else (ro:apply mk-intn (ro:make-list n (to-int v)))]))
(ro:define (to-intn* v) (cl:pointer-cast v cl-mk-intn))
(ro:define (mk-intn x ...) (ro:#%app cl-mk-intn x ...)))]))
(define-simple-macro (define-ints n ...) (begin (define-int n) ...))
(define-simple-macro (define-scalar-type TY #:from BASE)
#:with define-TY (format-id #'TY "define-~a" #'TY)
#:with define-TYs (format-id #'TY "define-~as" #'TY)
#:with TY* (mk-ptr #'TY)
#:with to-TY (mk-to #'TY)
#:with to-TY* (mk-ptr #'to-TY)
#:with mk-TY (mk-mk #'TY)
#:with cl-TY (mk-cl #'TY)
(begin-
(ro:define (to-TY v) (ro:#%app (ro:#%app cl-TY) v))
(ro:define (to-TY* v) (cl:pointer-cast v cl-TY))
(ro:define (mk-TY v) (ro:#%app cl-TY v))
(define-named-type-alias TY (add-convertm BASE to-TY))
(define-named-type-alias TY* (add-convertm (Pointer TY) to-TY*))
(define-syntax define-TY ; defines a TY vector type of length n
(syntax-parser
[(_ n)
#:with TYn (format-id #'n "~a~a" #'TY (syntax->datum #'n))
#:with TYn* (mk-ptr #'TYn)
#:with to-TYn (mk-to #'TYn)
#:with mk-TYn (mk-mk #'TYn)
#:with to-TYn* (mk-ptr #'to-TYn)
#:with mk-TYn* (mk-ptr #'mk-TYn)
#:with cl-TYn (mk-cl #'TYn)
#:with TYs (build-list (stx->datum #'n) (λ _ #'TY))
#'(begin-
(define-named-type-alias TYn
(add-constructm (add-convertm (rosette3:CVector . TYs) to-TYn) mk-TYn))
(define-named-type-alias TYn* (add-convertm (Pointer TYn) to-TYn*))
(ro:define (to-TYn v) ; not using cl-Tyn bc I need to handle lists
(ro:cond
[(ro:list? v)
(ro:apply mk-TYn (ro:for/list ([i n]) (to-TY (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply mk-TYn (ro:for/list ([i n]) (to-TY (ro:vector-ref v i))))]
[else (ro:apply mk-TYn (ro:make-list n (to-TY v)))]))
(ro:define (to-TYn* v) (cl:pointer-cast v cl-TYn))
(ro:define (mk-TYn . ns) (ro:apply cl-TYn ns)))]))
(... (define-simple-macro (define-TYs n ...) (begin- (define-TY n) ...)))))
(define-scalar-type int #:from rosette3:Int)
(define-ints 2 3 4 16)
(define-syntax (define-float stx)
(syntax-parse stx
[(_ n)
#:with floatn (format-id #'n "float~a" (syntax->datum #'n))
#:with floatn* (mk-ptr #'floatn)
#:with to-floatn (format-id #'n "to-~a" #'floatn)
#:with mk-floatn (mk-mk #'floatn)
#:with to-floatn* (mk-ptr #'to-floatn)
#:with mk-floatn* (mk-ptr #'mk-floatn)
#:with cl-mk-floatn (mk-cl #'floatn)
#:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values))
#:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...))
#'(begin
(define-named-type-alias floatn
(add-constructm
(add-convertm (rosette3:CVector I ...) to-floatn) mk-floatn))
(define-named-type-alias floatn* (add-convertm (Pointer floatn) to-floatn*))
(ro:define (to-floatn v)
(ro:cond
[(ro:list? v)
(ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:vector-ref v i))))]
[else (ro:apply mk-floatn (ro:make-list n (to-float v)))]))
(ro:define (to-floatn* v) (cl:pointer-cast v cl-mk-floatn))
(ro:define (mk-floatn x ...) (ro:#%app cl-mk-floatn x ...)))]))
(define-simple-macro (define-floats n ...) (begin (define-float n) ...))
(define-scalar-type float #:from rosette3:Num)
(define-floats 2 3 4 16)
(define-typed-syntax synth-app
@ -556,7 +479,7 @@
--------
[ (to-int (o- (conv e1-) (conv e2-))) int]]))
(define-simple-macro (mk-cmps o ...) (begin- (mk-cmp o) ...))
(mk-cmps == < <= > >=)
(mk-cmps == < <= > >= !=)
(define-simple-macro (define-bool-ops o ...+) (ro:begin (define-bool-op o) ...))
(define-simple-macro (define-bool-op name)

View File

@ -3,10 +3,10 @@
(require macrotypes/examples/tests/do-tests)
(do-tests
"synthcl3-tests.rkt" "SynthCL"
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth"
"synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify"
"synthcl3-tests.rkt" "SynthCL general"
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth"
"synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify"
"synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify"
"synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth"
"synthcl3-walsh-verify-tests.rkt" "SynthCL Walsh Transform: verify"
"synthcl3-sobel-tests.rkt" "SynthCL Sobel Filter: synth and verify")
"synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth"
"synthcl3-walsh-verify-tests.rkt" "SynthCL Walsh Transform: verify"
"synthcl3-sobel-tests.rkt" "SynthCL Sobel Filter: synth and verify")

View File

@ -170,7 +170,6 @@
(check-type
(with-output-to-string (λ () (verify_scalar)))
: CString -> "no counterexample found\n")
(verify_scalar)
(check-type
(with-output-to-string (λ () (verify_vectorized)))
: CString -> "no counterexample found\n")