diff --git a/turnstile/examples/rosette/synthcl-model.rkt b/turnstile/examples/rosette/synthcl-model.rkt new file mode 100644 index 0000000..c568316 --- /dev/null +++ b/turnstile/examples/rosette/synthcl-model.rkt @@ -0,0 +1,30 @@ +#lang racket +(require (for-syntax syntax/stx racket/syntax) syntax/parse/define) + +;; in general, must work with forms and fns from model/, since lang/ often +;; uses (synthcl) type-directed macros (and not typed rosette types) + +(define-for-syntax (mk-model-path m) (format-id m "sdsl/synthcl/model/~a" m)) + +(define-simple-macro (require+provide/synthcl/model x ...) + #:with (m ...) (stx-map mk-model-path #'(x ...)) + (begin (require (combine-in m ...)) + (provide (combine-out (all-from-out m) ...)))) + +(require+provide/synthcl/model buffer + context + errors + flags + kernel + memory + operators + pointers + program + queue + reals + runtime + work) + +(require (for-syntax (only-in sdsl/synthcl/lang/util parse-selector)) + (only-in sdsl/synthcl/lang/forms ??)) +(provide (for-syntax parse-selector) ??) diff --git a/turnstile/examples/rosette/synthcl3.rkt b/turnstile/examples/rosette/synthcl3.rkt index ed3bb04..81fa7da 100644 --- a/turnstile/examples/rosette/synthcl3.rkt +++ b/turnstile/examples/rosette/synthcl3.rkt @@ -1,120 +1,77 @@ #lang turnstile (extends "rosette3.rkt" #:except ! #%app || && void = * + - / #%datum if assert verify < <= > >=) ; typed rosette -(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped - racket/stxparam - (prefix-in ro: (combine-in rosette rosette/lib/synthax)) - (prefix-in cl: (combine-in sdsl/synthcl/model/operators - sdsl/synthcl/lang/forms sdsl/synthcl/model/reals - sdsl/synthcl/model/errors sdsl/synthcl/model/kernel - sdsl/synthcl/model/memory sdsl/synthcl/model/runtime - sdsl/synthcl/model/work sdsl/synthcl/model/pointers - sdsl/synthcl/lang/queries sdsl/synthcl/model/context - sdsl/synthcl/model/queue sdsl/synthcl/model/buffer - sdsl/synthcl/model/flags sdsl/synthcl/model/program)) - (for-syntax (prefix-in cl: sdsl/synthcl/lang/util))) +(require (prefix-in ro: (combine-in rosette rosette/lib/synthax)) + (prefix-in cl: "synthcl-model.rkt")) (begin-for-syntax (define (mk-cl id) (format-id #'here "cl:~a" id)) (current-host-lang mk-cl)) -(provide (rename-out [synth-app #%app]) - procedure kernel grammar #%datum if range for print - choose ?? @ locally-scoped assert synth verify - int int2 int3 int4 int16 float float2 float3 float4 float16 - bool void void* char* - int* int2* int3* int4* int16* float* float2* float3* float4* float16* - cl_context cl_command_queue cl_program cl_kernel cl_mem - : ! ?: == + * / - sqrt || && - % << $ & > >= < <= ; int ops - = += -= *= /= %= $= &= ; assignment ops - sizeof clCreateProgramWithSource - (typed-out - [clCreateContext : (C→ cl_context)] - [clCreateCommandQueue : (C→ cl_context cl_command_queue)] - [clCreateBuffer : (C→ cl_context int int cl_mem)] - [clEnqueueReadBuffer : (C→ cl_command_queue cl_mem int int void* void)] +(provide + (rename-out [synth-app #%app]) + procedure kernel grammar #%datum if range for print + choose ?? @ locally-scoped assert synth verify + int int2 int3 int4 int16 float float2 float3 float4 float16 + int* int2* int3* int4* int16* float* float2* float3* float4* float16* + bool void void* char* + cl_context cl_command_queue cl_program cl_kernel cl_mem + : ! ?: == + * / - sqrt || && % << $ & > >= < <= != = += -= *= /= %= $= &= + sizeof clCreateProgramWithSource + (typed-out + [clCreateContext : (C→ cl_context)] + [clCreateCommandQueue : (C→ cl_context cl_command_queue)] + [clCreateBuffer : (C→ cl_context int int cl_mem)] + [clEnqueueReadBuffer : (C→ cl_command_queue cl_mem int int void* void)] + [clEnqueueWriteBuffer : (C→ cl_command_queue cl_mem int int void* void)] + [clEnqueueNDRangeKernel : (C→ cl_command_queue cl_kernel int int* int* int* void)] + [clCreateKernel : (C→ cl_program char* cl_kernel)] + [clSetKernelArg : (Ccase-> (C→ cl_kernel int cl_mem void) + (C→ cl_kernel int int void) + (C→ cl_kernel int float void))] + [get_global_id : (C→ int int)] + [CL_MEM_READ_ONLY : int] [CL_MEM_WRITE_ONLY : int] [CL_MEM_READ_WRITE : int] + [malloc : (C→ int void*)] + [memset : (C→ void* int int void*)] + [convert_float4 : (Ccase-> (C→ int4 float4) (C→ float4 float4))] + [convert_int4 : (Ccase-> (C→ int4 int4) (C→ float4 int4))] + [get_work_dim : (C→ int)] + [NULL : void*])) - [clEnqueueWriteBuffer : (C→ cl_command_queue cl_mem int int void* void)] - [clEnqueueNDRangeKernel : (C→ cl_command_queue cl_kernel int int* int* int* void)] - [clCreateKernel : (C→ cl_program char* cl_kernel)] - [clSetKernelArg : (Ccase-> (C→ cl_kernel int cl_mem void) - (C→ cl_kernel int int void) - (C→ cl_kernel int float void))] - [get_global_id : (C→ int int)] - [CL_MEM_READ_ONLY : int] - [CL_MEM_WRITE_ONLY : int] - [CL_MEM_READ_WRITE : int] - [malloc : (C→ int void*)] - [memset : (C→ void* int int void*)] - [convert_float4 : (Ccase-> (C→ int4 float4) (C→ float4 float4))] - [convert_int4 : (Ccase-> (C→ int4 int4) (C→ float4 int4))] - [get_work_dim : (C→ int)] - [!= : (Ccase-> (C→ CNum CNum CBool) - (C→ CNum CNum CNum CBool) - (C→ Num Num Bool) - (C→ Num Num Num Bool))] - [NULL : void*])) (begin-for-syntax - ;; TODO: use equality type relation instead of subtype - ;; - require reimplementing many more things, eg #%datum, +, etc - ;; typecheck unexpanded types - (define (typecheck/un? t1 t2) - (typecheck? ((current-type-eval) t1) - ((current-type-eval) t2))) + (current-typecheck-relation (current-type=?)) ; no subtyping + (define (typecheck/un? t1 t2) ; typecheck unexpanded types + (typecheck? ((current-type-eval) t1) ((current-type-eval) t2))) + (define (pointer-type? t) (Pointer? t)) (define (real-type? t) - (and #;(not (typecheck/un? t #'bool)) - (not (typecheck/un? t #'char*)) - (not (pointer-type? t)))) - (define (pointer-type? t) - (Pointer? t) - #;(regexp-match #px"\\*$" (type->str t))) + (and (not (pointer-type? t)) (not (typecheck/un? t #'char*)))) (define (real-type<=? t1 t2) (and (real-type? t1) (real-type? t2) - (or ((current-type=?) t1 t2) ; need type= to distinguish reals/ints + (or (typecheck? t1 t2) ; need type= to distinguish reals/ints (typecheck/un? t1 #'bool) - (and (typecheck/un? t1 #'int) - (not (typecheck/un? t2 #'bool))) + (and (typecheck/un? t1 #'int) (not (typecheck/un? t2 #'bool))) (and (typecheck/un? t1 #'float) - (typecheck/un? (get-base t2) #'float))))) - - ; Returns the common real type of the given types, as specified in - ; Ch. 6.2.6 of opencl-1.2 specification. If there is no common - ; real type, returns #f. + (typecheck/un? (get-base/un t2) #'float))))) + ;; same as common-real-type from model/reals.rkt + ;; Returns the common real type of the given types, as specified in + ;; Ch. 6.2.6 of opencl-1.2 specification. returns #f if none (define common-real-type (case-lambda [(t) (and (real-type? t) t)] [(t1 t2) (or (and (real-type<=? t1 t2) t2) (and (real-type<=? t2 t1) t1))] [ts (common-real-type (car ts) (apply common-real-type (cdr ts)))])) - - ;; implements common-real-type from model/reals.rkt - ; Returns the common real type of the given types, as specified in - ; Ch. 6.2.6 of opencl-1.2 specification. If there is no common - ; real type, returns #f. (current-join common-real-type) ;; copied from check-implicit-conversion in lang/types.rkt - ;; TODO: this should not exception since it is used in stx-parse - ;; clauses that may want to backtrack + ;; TODO: should this exn? it is used in stx-parse that may want to backtrack (define (cast-ok? from to [expr #f] [subexpr #f]) - #;(printf "casting ~a to ~a: ~a\n" (type->str from) (type->str to) - (or (typecheck/un? from to) - (and (scalar-type? from) (scalar-type? to)) - (and (scalar-type? from) (vector-type? to)) - (and (pointer-type? from) (pointer-type? to)) - #;(and (equal? from cl_mem) (pointer-type? to)))) - (unless (if #t #;(type? to) - (or (typecheck/un? from to) - (and (scalar-type? from) (scalar-type? to)) - (and (scalar-type? from) (vector-type? to)) - (and (pointer-type? from) (pointer-type? to)) - #;(and (equal? from cl_mem) (pointer-type? to))) - (to from)) - (raise-syntax-error - #f - (format "no implicit conversion from ~a to ~a" - (type->str from) (type->str to) - #;(if (contract? to) (contract-name to) to)) - expr subexpr))) + (unless (or (typecheck/un? from to) + (and (scalar-type? from) (scalar-type? to)) + (and (scalar-type? from) (vector-type? to)) + (and (pointer-type? from) (pointer-type? to)) + #;(and (equal? from cl_mem) (pointer-type? to))) + (raise-syntax-error #f + (format "no implicit conversion from ~a to ~a" + (type->str from) (type->str to)) expr subexpr))) (define (mk-ptr id) (format-id id "~a*" id)) (define (mk-mk id) (format-id id "mk-~a" id)) (define (mk-to id) (format-id id "to-~a" id)) @@ -128,11 +85,11 @@ (define split-ty (ty->len t)) (string->number (or (and split-ty (third split-ty)) "1"))) - (define (get-base ty [ctx #'here]) - ((current-type-eval) - (datum->syntax ctx - (string->symbol (car (regexp-match #px"[a-z]+" (type->str ty))))))) - (define (get-pointer-base ty [ctx #'here]) + (define (get-base/un ty [ctx #'here]) ; returns unexpanded base type + (datum->syntax ctx + (string->symbol (car (regexp-match #px"[a-z]+" (type->str ty)))))) + (define (get-base ty [ctx #'here]) ((current-type-eval) (get-base/un ty ctx))) + (define (get-pointer-base ty [ctx #'here]) ; returns unexpanded ptr base (datum->syntax ctx (string->symbol (string-trim (type->str ty) "*")))) (define (vector-type? ty) (define tstr (type->str ty)) @@ -144,91 +101,57 @@ (define-syntax-parser add-convertm [(_ stx fn) (add-convert #'stx #'fn)]) (define-syntax-parser add-constructm [(_ stx fn) (add-construct #'stx #'fn)]) -(ro:define (to-bool v) ; TODO: reuse impls in model/reals.rkt ? - (ro:cond [(ro:boolean? v) v] - [(ro:number? v) (ro:! (ro:= 0 v))] - [else (cl:raise-conversion-error v "bool")])) -(ro:define (to-int v) - (ro:cond [(ro:boolean? v) (ro:if v 1 0)] - [(ro:fixnum? v) v] - [(ro:flonum? v) (ro:exact-truncate v)] - [else (ro:real->integer v)])) -(ro:define (to-float v) - (ro:cond [(ro:boolean? v) (ro:if v 1.0 0.0)] - [(ro:fixnum? v) (ro:exact->inexact v)] - [(ro:flonum? v) v] - [else (ro:type-cast ro:real? v)])) -(ro:define (mk-int v) (ro:#%app cl:int v)) -(ro:define (mk-float v) (ro:#%app cl:float v)) -(ro:define (to-int* v) (cl:pointer-cast v cl:int)) -(ro:define (to-float* v) (cl:pointer-cast v cl:float)) +(ro:define (to-bool v) (ro:#%app (ro:#%app cl:bool) v)) -(define-type-constructor Pointer #:arity = 1) -;(define-named-type-alias void rosette3:CUnit) (define-base-types void cl_context cl_command_queue cl_program cl_kernel cl_mem) -(define-named-type-alias void* (add-convertm (Pointer void) (λ (x) x))) -(define-named-type-alias bool (add-convertm rosette3:Bool to-bool)) -(define-named-type-alias int (add-convertm rosette3:Int to-int)) -(define-named-type-alias int* (add-convertm (Pointer int) to-int*)) -(define-named-type-alias float (add-convertm rosette3:Num to-float)) -(define-named-type-alias float* (add-convertm (Pointer float) to-float*)) -(define-named-type-alias char* (add-convertm rosette3:CString (λ (x) x))) +(define-type-constructor Pointer #:arity = 1) +(define-named-type-alias void* (Pointer void)) +(define-named-type-alias char* rosette3:CString) +(define-named-type-alias bool (add-convertm rosette3:Bool to-bool)) -(define-syntax (define-int stx) - (syntax-parse stx - [(_ n) - #:with intn (format-id #'n "int~a" (syntax->datum #'n)) - #:with intn* (mk-ptr #'intn) - #:with to-intn (format-id #'n "to-~a" #'intn) - #:with mk-intn (mk-mk #'intn) - #:with to-intn* (mk-ptr #'to-intn) - #:with mk-intn* (mk-ptr #'mk-intn) - #:with cl-mk-intn (mk-cl #'intn) - #:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values)) - #:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...)) - #'(begin - (define-named-type-alias intn - (add-constructm (add-convertm (rosette3:CVector I ...) to-intn) mk-intn)) - (define-named-type-alias intn* (add-convertm (Pointer intn) to-intn*)) - (ro:define (to-intn v) - (ro:cond - [(ro:list? v) - (ro:apply mk-intn (ro:for/list ([i n]) (to-int (ro:list-ref v i))))] - [(ro:vector? v) - (ro:apply mk-intn (ro:for/list ([i n]) (to-int (ro:vector-ref v i))))] - [else (ro:apply mk-intn (ro:make-list n (to-int v)))])) - (ro:define (to-intn* v) (cl:pointer-cast v cl-mk-intn)) - (ro:define (mk-intn x ...) (ro:#%app cl-mk-intn x ...)))])) -(define-simple-macro (define-ints n ...) (begin (define-int n) ...)) +(define-simple-macro (define-scalar-type TY #:from BASE) + #:with define-TY (format-id #'TY "define-~a" #'TY) + #:with define-TYs (format-id #'TY "define-~as" #'TY) + #:with TY* (mk-ptr #'TY) + #:with to-TY (mk-to #'TY) + #:with to-TY* (mk-ptr #'to-TY) + #:with mk-TY (mk-mk #'TY) + #:with cl-TY (mk-cl #'TY) + (begin- + (ro:define (to-TY v) (ro:#%app (ro:#%app cl-TY) v)) + (ro:define (to-TY* v) (cl:pointer-cast v cl-TY)) + (ro:define (mk-TY v) (ro:#%app cl-TY v)) + (define-named-type-alias TY (add-convertm BASE to-TY)) + (define-named-type-alias TY* (add-convertm (Pointer TY) to-TY*)) + (define-syntax define-TY ; defines a TY vector type of length n + (syntax-parser + [(_ n) + #:with TYn (format-id #'n "~a~a" #'TY (syntax->datum #'n)) + #:with TYn* (mk-ptr #'TYn) + #:with to-TYn (mk-to #'TYn) + #:with mk-TYn (mk-mk #'TYn) + #:with to-TYn* (mk-ptr #'to-TYn) + #:with mk-TYn* (mk-ptr #'mk-TYn) + #:with cl-TYn (mk-cl #'TYn) + #:with TYs (build-list (stx->datum #'n) (λ _ #'TY)) + #'(begin- + (define-named-type-alias TYn + (add-constructm (add-convertm (rosette3:CVector . TYs) to-TYn) mk-TYn)) + (define-named-type-alias TYn* (add-convertm (Pointer TYn) to-TYn*)) + (ro:define (to-TYn v) ; not using cl-Tyn bc I need to handle lists + (ro:cond + [(ro:list? v) + (ro:apply mk-TYn (ro:for/list ([i n]) (to-TY (ro:list-ref v i))))] + [(ro:vector? v) + (ro:apply mk-TYn (ro:for/list ([i n]) (to-TY (ro:vector-ref v i))))] + [else (ro:apply mk-TYn (ro:make-list n (to-TY v)))])) + (ro:define (to-TYn* v) (cl:pointer-cast v cl-TYn)) + (ro:define (mk-TYn . ns) (ro:apply cl-TYn ns)))])) + (... (define-simple-macro (define-TYs n ...) (begin- (define-TY n) ...))))) + +(define-scalar-type int #:from rosette3:Int) (define-ints 2 3 4 16) - -(define-syntax (define-float stx) - (syntax-parse stx - [(_ n) - #:with floatn (format-id #'n "float~a" (syntax->datum #'n)) - #:with floatn* (mk-ptr #'floatn) - #:with to-floatn (format-id #'n "to-~a" #'floatn) - #:with mk-floatn (mk-mk #'floatn) - #:with to-floatn* (mk-ptr #'to-floatn) - #:with mk-floatn* (mk-ptr #'mk-floatn) - #:with cl-mk-floatn (mk-cl #'floatn) - #:with (x ...) (generate-temporaries (build-list (syntax->datum #'n) values)) - #:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...)) - #'(begin - (define-named-type-alias floatn - (add-constructm - (add-convertm (rosette3:CVector I ...) to-floatn) mk-floatn)) - (define-named-type-alias floatn* (add-convertm (Pointer floatn) to-floatn*)) - (ro:define (to-floatn v) - (ro:cond - [(ro:list? v) - (ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:list-ref v i))))] - [(ro:vector? v) - (ro:apply mk-floatn (ro:for/list ([i n]) (to-float (ro:vector-ref v i))))] - [else (ro:apply mk-floatn (ro:make-list n (to-float v)))])) - (ro:define (to-floatn* v) (cl:pointer-cast v cl-mk-floatn)) - (ro:define (mk-floatn x ...) (ro:#%app cl-mk-floatn x ...)))])) -(define-simple-macro (define-floats n ...) (begin (define-float n) ...)) +(define-scalar-type float #:from rosette3:Num) (define-floats 2 3 4 16) (define-typed-syntax synth-app @@ -556,7 +479,7 @@ -------- [⊢ (to-int (o- (conv e1-) (conv e2-))) ⇒ int]])) (define-simple-macro (mk-cmps o ...) (begin- (mk-cmp o) ...)) -(mk-cmps == < <= > >=) +(mk-cmps == < <= > >= !=) (define-simple-macro (define-bool-ops o ...+) (ro:begin (define-bool-op o) ...)) (define-simple-macro (define-bool-op name) diff --git a/turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt b/turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt index 07e8f25..a5e1a08 100644 --- a/turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt +++ b/turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt @@ -3,10 +3,10 @@ (require macrotypes/examples/tests/do-tests) (do-tests - "synthcl3-tests.rkt" "SynthCL" - "synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth" - "synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify" + "synthcl3-tests.rkt" "SynthCL general" + "synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth" + "synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify" "synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify" - "synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth" - "synthcl3-walsh-verify-tests.rkt" "SynthCL Walsh Transform: verify" - "synthcl3-sobel-tests.rkt" "SynthCL Sobel Filter: synth and verify") + "synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth" + "synthcl3-walsh-verify-tests.rkt" "SynthCL Walsh Transform: verify" + "synthcl3-sobel-tests.rkt" "SynthCL Sobel Filter: synth and verify") diff --git a/turnstile/examples/tests/rosette/rosette3/synthcl3-sobel-tests.rkt b/turnstile/examples/tests/rosette/rosette3/synthcl3-sobel-tests.rkt index c9d3a01..8398bc2 100644 --- a/turnstile/examples/tests/rosette/rosette3/synthcl3-sobel-tests.rkt +++ b/turnstile/examples/tests/rosette/rosette3/synthcl3-sobel-tests.rkt @@ -170,7 +170,6 @@ (check-type (with-output-to-string (λ () (verify_scalar))) : CString -> "no counterexample found\n") -(verify_scalar) (check-type (with-output-to-string (λ () (verify_vectorized))) : CString -> "no counterexample found\n")