diff --git a/macrotypes/typecheck.rkt b/macrotypes/typecheck.rkt index 81aeac6..9524006 100644 --- a/macrotypes/typecheck.rkt +++ b/macrotypes/typecheck.rkt @@ -55,8 +55,8 @@ [(_ . stuff) (syntax/loc this-syntax (#%module-begin - ; auto-provide some useful racket forms - (provide #%module-begin #%top-interaction #%top require only-in) + (provide #%module-begin #%top-interaction #%top ; useful racket forms + require only-in prefix-in rename-in) . stuff))])) (struct exn:fail:type:runtime exn:fail:user ()) diff --git a/turnstile/examples/rosette/rosette3.rkt b/turnstile/examples/rosette/rosette3.rkt index aeaecdc..3869b4f 100644 --- a/turnstile/examples/rosette/rosette3.rkt +++ b/turnstile/examples/rosette/rosette3.rkt @@ -1,7 +1,8 @@ #lang turnstile ;; reuse unlifted forms as-is (reuse define λ let let* letrec begin void #%datum ann #%top-interaction - require only-in define-type-alias define-named-type-alias + require only-in prefix-in rename-in define-type-alias define-named-type-alias + current-join ⊔ #:from "../stlc+union.rkt") (require ;; manual imports @@ -11,7 +12,7 @@ (combine-in (only-in "../stlc+union+case.rkt" PosInt Zero NegInt Float True False String Unit [U U*] U*? - [case-> case->*] case->? → →?) + [case-> case->*] case->? → →? String?) (only-in "../stlc+cons.rkt" [List Listof]))) (only-in "../stlc+union+case.rkt" [~U* ~CU*] [~case-> ~Ccase->] [~→ ~C→]) (only-in "../stlc+cons.rkt" [~List ~CListof]) @@ -21,13 +22,14 @@ (provide (rename-out [ro:#%module-begin #%module-begin] [stlc+union:λ lambda]) + (for-syntax get-pred) Any CNothing Nothing - CU U + CU U (for-syntax ~U*) Constant C→ → (for-syntax ~C→ C→?) Ccase-> (for-syntax ~Ccase-> Ccase->?) ; TODO: sym case-> not supported CListof Listof CList CPair Pair - CVectorof MVectorof IVectorof Vectorof CMVectorof CIVectorof + CVectorof MVectorof IVectorof Vectorof CMVectorof CIVectorof CVector CParamof ; TODO: symbolic Param not supported yet CBoxof MBoxof IBoxof CMBoxof CIBoxof CHashTable CUnit Unit @@ -39,13 +41,13 @@ CFloat Float CNum Num CFalse CTrue CBool Bool - CString String + CString String (for-syntax CString?) CStx ; symblic Stx not supported CAsserts ;; BV types CBV BV CBVPred BVPred - CSolution CSolver CPict CSyntax CRegexp CSymbol CPred) + CSolution CSolver CPict CSyntax CRegexp CSymbol CPred CPredC) (begin-for-syntax (define (mk-ro:-id id) (format-id id "ro:~a" id)) @@ -85,6 +87,7 @@ (define-named-type-alias (CVectorof X) (CU (CIVectorof X) (CMVectorof X))) (define-named-type-alias (CBoxof X) (CU (CIBoxof X) (CMBoxof X))) (define-type-constructor CList #:arity >= 0) +(define-type-constructor CVector #:arity >= 0) (define-type-constructor CPair #:arity = 2) ;; TODO: update orig to use reduced type @@ -224,6 +227,7 @@ (define-symbolic-named-type-alias Num (CU CFloat CInt) #:pred ro:real?) (define-named-type-alias CPred (C→ Any Bool)) +(define-named-type-alias CPredC (C→ Any CBool)) ;; --------------------------------- ;; define-symbolic @@ -1593,13 +1597,14 @@ ;; --------------------------------- ;; Reflecting on symbolic values +;; TODO: CPredC correct here? (provide (typed-out - [term? : CPred] - [expression? : CPred] - [constant? : CPred] - [type? : CPred] - [solvable? : CPred] - [union? : CPred])) + [term? : CPredC] + [expression? : CPredC] + [constant? : CPredC] + [type? : CPredC] + [solvable? : CPredC] + [union? : CPredC])) (define-typed-syntax union-contents [(_ u) ≫ diff --git a/turnstile/examples/rosette/synthcl3.rkt b/turnstile/examples/rosette/synthcl3.rkt new file mode 100644 index 0000000..f7677f6 --- /dev/null +++ b/turnstile/examples/rosette/synthcl3.rkt @@ -0,0 +1,320 @@ +#lang turnstile +(extends "rosette3.rkt" #:except ! #%app || &&) ; typed rosette +(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped + (prefix-in ro: rosette) +; define-symbolic* #%datum define if ! = number? boolean? cond ||)) + (prefix-in cl: sdsl/synthcl/model/operators) + (prefix-in cl: sdsl/synthcl/model/errors)) + +(begin-for-syntax + (define (mk-cl id) (format-id id "cl:~a" id)) + (current-host-lang mk-cl)) + +;(define-base-types) + +(provide (rename-out + [synth-app #%app] +; [rosette3:Bool bool] ; symbolic +; [rosette3:Int int] ; symbolic +; [rosette3:Num float] ; symbolic + [rosette3:CString char*]) ; always concrete + bool int float float3 int16 + : ! ?: == + (typed-out + ;; need the concrete cases for literals; + ;; alternative is to redefine #%datum to give literals symbolic type + [% : (Ccase-> (C→ CInt CInt CInt) + (C→ Int Int Int))] + [!= : (Ccase-> (C→ CNum CNum CBool) + (C→ CNum CNum CNum CBool) + (C→ Num Num Bool) + (C→ Num Num Num Bool))] + #;[== : (Ccase-> (C→ CNum CNum CBool) + (C→ CNum CNum CNum CBool) + (C→ Num Num Bool) + (C→ Num Num Num Bool))])) +(begin-for-syntax + ;; TODO: use equality type relation instead of subtype + ;; - require reimplementing many more things, eg #%datum, +, etc +; (current-typecheck-relation (current-type=?)) + ;; typecheck unexpanded types + (define (typecheck/un? t1 t2) + (typecheck? ((current-type-eval) t1) + ((current-type-eval) t2))) + (define (real-type? t) (not (typecheck/un? t #'bool))) + (define (type-base t) + (datum->syntax t + (string->symbol + (car (regexp-match #px"[a-z]+" (type->str t)))))) + (define (real-type<=? t1 t2) + (and (real-type? t1) (real-type? t2) + (or (typecheck? t1 t2) + (typecheck/un? t1 #'bool) + (and (typecheck/un? t1 #'int) + (not (typecheck/un? t2 #'bool))) + (and (typecheck/un? t1 #'float) + (typecheck/un? (type-base t2) #'float))))) + + ; Returns the common real type of the given types, as specified in + ; Ch. 6.2.6 of opencl-1.2 specification. If there is no common + ; real type, returns #f. + (define common-real-type + (case-lambda + [(t) (and (real-type? t) t)] + [(t1 t2) (cond [(real-type<=? t1 t2) t2] + [(real-type<=? t2 t1) t1] + [else #f])] + [ts (common-real-type (car ts) (apply common-real-type (cdr ts)))])) + + ;; implements common-real-type from model/reals.rkt + ; Returns the common real type of the given types, as specified in + ; Ch. 6.2.6 of opencl-1.2 specification. If there is no common + ; real type, returns #f. + (current-join common-real-type) + ;; copied from check-implicit-conversion in lang/types.rkt + ;; TODO: this should not exception since it is used in stx-parse + ;; clauses that may want to backtrack + (define (cast-ok? from to expr [subexpr #f]) + (unless (if #t #;(type? to) + (or (typecheck/un? from to) + (and (scalar-type? from) (scalar-type? to)) + (and (scalar-type? from) (vector-type? to)) + #;(and (pointer-type? from) (pointer-type? to)) + #;(and (equal? from cl_mem) (pointer-type? to))) + (to from)) + (raise-syntax-error + #f + (format "no implicit conversion from ~a to ~a" + (type->str from) (type->str to) + #;(if (contract? to) (contract-name to) to)) + expr subexpr)) + #;(or (typecheck/un? from to) ; from == to + (and (real-type? from) + (typecheck/un? to #'bool)))) + (define (add-convert stx fn) + (set-stx-prop/preserved stx 'convert fn)) + (define (get-convert stx) + (syntax-property stx 'convert)) + (define (ty->len ty) + (regexp-match #px"([a-z]+)([0-9]+)" (type->str ty))) + (define (vector-type? ty) + (ty->len ty)) + (define (scalar-type? ty) + (not (vector-type? ty)))) + +(define-syntax-parser add-convertm + [(_ stx fn) (add-convert #'stx #'fn)]) + +;; TODO: reuse impls in model/reals.rkt ? +(ro:define (to-bool v) + (ro:cond + [(ro:boolean? v) v] + [(ro:number? v) (ro:! (ro:= 0 v))] + [else (cl:raise-conversion-error v "bool")])) +(ro:define (to-int v) + (ro:cond [(ro:boolean? v) (ro:if v 1 0)] + [(ro:fixnum? v) v] + [(ro:flonum? v) (ro:exact-truncate v)] + [else (ro:real->integer v)])) +(ro:define (to-float v) + (ro:cond + [(ro:boolean? v) (ro:if v 1.0 0.0)] + [(ro:fixnum? v) (ro:exact->inexact v)] + [(ro:flonum? v) v] + [else (ro:type-cast ro:real? v)])) +(ro:define (to-float3 v) + (ro:cond + [(ro:list? v) + (ro:apply ro:vector-immutable + (ro:for/list ([i 3]) (to-float (ro:list-ref v i))))] + [(ro:vector? v) + (ro:apply ro:vector-immutable + (ro:for/list ([i 3]) (to-float (ro:vector-ref v i))))] + [else (ro:apply ro:vector-immutable (ro:make-list 3 (to-float v)))])) +(ro:define (to-int16 v) v + #;(ro:cond + [(ro:list? v) + (ro:apply ro:vector-immutable + (ro:for/list ([i 16]) (to-int (ro:list-ref v i))))] + [(ro:vector? v) + (ro:apply ro:vector-immutable + (ro:for/list ([i 16]) (to-int (ro:vector-ref v i))))] + [else (ro:apply ro:vector-immutable (ro:make-list 16 (to-float v)))])) + +(define-named-type-alias bool + (add-convertm rosette3:Bool to-bool)) +(define-named-type-alias int + (add-convertm rosette3:Int to-int)) +(define-named-type-alias float + (add-convertm rosette3:Num to-float)) + +(define-named-type-alias float3 + (add-convertm + (rosette3:CVector rosette3:Num rosette3:Num rosette3:Num) + to-float3)) +(define-named-type-alias int16 + (add-convertm + (rosette3:CVector rosette3:Int rosette3:Int rosette3:Int rosette3:Int + rosette3:Int rosette3:Int rosette3:Int rosette3:Int + rosette3:Int rosette3:Int rosette3:Int rosette3:Int + rosette3:Int rosette3:Int rosette3:Int rosette3:Int) + to-int16)) + +(define-typed-syntax synth-app + [(_ (ty:type) e) ≫ ; cast + [⊢ e ≫ e- ⇒ ty-e] + #:fail-unless (cast-ok? #'ty-e #'ty.norm #'e) + (format "cannot cast ~a to ~a" + (type->str #'ty-e) (type->str #'ty.norm)) + #:with convert (get-convert #'ty.norm) + #:fail-unless (syntax-e #'convert) + (format "cannot cast ~a to ~a: conversion fn not found" + (type->str #'ty-e) (type->str #'ty.norm)) + -------- + [⊢ (convert e-) ⇒ ty.norm]] + [(_ . es) ≫ + -------- + [≻ (rosette3:#%app . es)]]) + +(define-typed-syntax : + [(_ ty:type x:id ...) ≫ ; special String case + #:when (rosette3:CString? #'ty.norm) + #:with (x- ...) (generate-temporaries #'(x ...)) + -------- + [≻ (begin- + (define-syntax- x + (make-rename-transformer (assign-type #'x- #'ty.norm))) ... + (ro:define x- (ro:#%datum . "")) ...)]] + ;; TODO: vector types need a better representation + ;; currently dissecting the identifier + [(_ ty:type x:id ...) ≫ + #:do [(define split-ty (ty->len #'ty))] + #:when (and split-ty (= 3 (length split-ty))) + #:do [(define base-str (cadr split-ty)) + (define len-str (caddr split-ty))] + #:with ty-base (datum->syntax #'ty (string->symbol base-str)) + #:with pred (get-pred ((current-type-eval) #'ty-base)) + #:fail-unless (syntax-e #'pred) + (format "no pred for ~a" (type->str #'ty)) + #:with (x- ...) (generate-temporaries #'(x ...)) + #:with (x-- ...) (generate-temporaries #'(x ...)) + -------- + [≻ (begin- + (ro:define-symbolic* x-- pred [#,(string->number len-str)]) ... + (ro:define x- (ro:apply ro:vector-immutable x--)) ... + (define-syntax- x + (make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]] + ;; scalar (ie non-vector) types + [(_ ty:type x:id ...) ≫ + #:with pred (get-pred #'ty.norm) + #:fail-unless (syntax-e #'pred) + (format "no pred for ~a" (type->str #'ty)) + #:with (x- ...) (generate-temporaries #'(x ...)) + -------- + [≻ (begin- + (define-syntax- x + (make-rename-transformer (assign-type #'x- #'ty.norm))) ... + (ro:define-symbolic* x- pred) ...)]]) + +(define-typed-syntax ?: + [(_ e e1 e2) ≫ + [⊢ e ≫ e- ⇐ bool] + [⊢ e1 ≫ e1- ⇒ ty1] + [⊢ e2 ≫ e2- ⇒ ty2] + ------- + [⊢ (cl:?: e- e1- e2-) ⇒ (⊔ τ1 τ2)]] + [(_ e e1 e2) ≫ + [⊢ e ≫ e- ⇒ ty] ; vector type + #:do [(define split-ty (ty->len #'ty))] + #:when (and split-ty (= 3 (length split-ty))) + [⊢ e1 ≫ e1- ⇒ ty1] + [⊢ e2 ≫ e2- ⇒ ty2] + #:with ty-out (common-real-type #'ty #'ty1 #'ty2) + #:with convert (get-convert #'ty-out) + #:do [(define split-ty-out (ty->len #'ty-out)) + (define out-base-str (cadr split-ty-out)) + (define out-len-str (caddr split-ty-out))] + #:with ty-base ((current-type-eval) (datum->syntax #'e (string->symbol out-base-str))) + #:with base-convert (get-convert #'ty-base) + ------- + [⊢ (convert + (ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)]) + (ro:for/list ([idx #,(string->number out-len-str)]) + (ro:if (ro:< (ro:vector-ref a idx) 0) + (base-convert (ro:vector-ref b idx)) + (base-convert (ro:vector-ref c idx)))))) + ⇒ ty-out]] + [(_ e e1 e2) ≫ ; should be scalar + #:when (displayln 1) + [⊢ e ≫ e- ⇒ ty] + #:fail-unless (cast-ok? #'ty #'bool #'e) + (format "cannot cast ~a to bool" (type->str #'ty)) + [⊢ e1 ≫ e1- ⇒ ty1] + [⊢ e2 ≫ e2- ⇒ ty2] + #:with ty-out ((current-join) #'ty1 #'ty2) + ------- + [⊢ (cl:?: (synth-app (bool) e-) + (synth-app (ty-out) e1-) + (synth-app (ty-out) e2-)) ⇒ ty-out]]) + +(define-typed-syntax ! + [(_ e) ≫ + [⊢ e ≫ e- ⇐ bool] + -------- + [⊢ (cl:! e-) ⇒ bool]] + ;; else try to coerce + [(_ e) ≫ + [⊢ e ≫ e- ⇒ ty] + -------- + [⊢ (cl:! (synth-app (bool) e-)) ⇒ bool]]) + +(define-syntax (define-coercing-bool-binop stx) + (syntax-parse stx + [(_ name) + #:with name- (mk-cl #'name) + #'(define-typed-syntax name + [(_ e1 e2) ≫ + [⊢ e1 ≫ e1- ⇐ bool] + [⊢ e2 ≫ e2- ⇐ bool] + -------- + [⊢ (name- e1- e2-) ⇒ bool]] + [(_ e1 e2) ≫ ; else try to coerce + [⊢ e1 ≫ e1- ⇒ ty1] + [⊢ e2 ≫ e2- ⇒ ty2] + #:fail-unless (cast-ok? #'ty1 #'bool #'e1) + (format "cannot cast ~a to bool" (type->str #'ty1)) + #:fail-unless (cast-ok? #'ty2 #'bool #'e2) + (format "cannot cast ~a to bool" (type->str #'ty2)) + -------- + [⊢ (name- (synth-app (bool) e1-) + (synth-app (bool) e2-)) ⇒ bool]])])) +(define-simple-macro (define-coercing-bool-binops o ...+) + (ro:begin + (provide o ...) + (define-coercing-bool-binop o) ...)) + +(define-coercing-bool-binops || &&) + +;; TODO: this should produce int-vector result? +(define-typed-syntax == + [(_ e1 e2) ≫ + [⊢ e1 ≫ e1- ⇒ ty1] + [⊢ e2 ≫ e2- ⇒ ty2] + #:when (real-type? #'ty1) + #:when (real-type? #'ty2) + #:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len + -------- + [⊢ (to-int (cl:== e1- e2-)) ⇒ int]]) + +#;(define-typed-syntax && + [(_ e1 e2) ≫ + [⊢ e1 ≫ e1- ⇐ bool] + [⊢ e2 ≫ e2- ⇐ bool] + -------- + [⊢ (cl:&& e1- e2-) ⇒ bool]] + ;; else try to coerce + [(_ e1 e2) ≫ + [⊢ e1 ≫ e1- ⇒ ty1] + [⊢ e2 ≫ e2- ⇒ ty2] + -------- + [⊢ (cl:&& (synth-app (bool) e1-) (synth-app (bool) e2-)) ⇒ bool]]) diff --git a/turnstile/examples/tests/rosette/rosette3/synthcl3-tests.rkt b/turnstile/examples/tests/rosette/rosette3/synthcl3-tests.rkt new file mode 100644 index 0000000..3f314e0 --- /dev/null +++ b/turnstile/examples/tests/rosette/rosette3/synthcl3-tests.rkt @@ -0,0 +1,167 @@ +#lang s-exp "../../../rosette/synthcl3.rkt" +(require "../../rackunit-typechecking.rkt" + (prefix-in cl: sdsl/synthcl/lang/main) + (prefix-in ro: (rename-in rosette [#%app a]))) + +;; from synthcl/test/snippets.rkt and more-snippets.rkt +;; (ro:define-symbolic b ro:boolean?) + +(: int x) +(check-type x : int -> x) +(check-not-type x : CInt) +;; TODO: should these be defined in synthcl? +;; I think no, synthcl is not an extension of rosette +;; (check-type (term? x) : CBool -> #t) +;; (check-type (expression? x) : CBool -> #f) +;; (check-type (constant? x) : CBool -> #t) + +(assert (+ x 1)) +(assert (% (+ x 2) 3)) +(assert (!= x 2)) + +(check-type "" : char*) +(: char* y) +(check-type y : char* -> "") + +(: float v) +(check-type v : float -> v) +;; (check-type (term? v) : CBool -> #t) +;; (check-type (expression? v) : CBool -> #f) +;; (check-type (constant? v) : CBool -> #t) + +(check-type ((bool) v) : bool -> (ro:a ro:! (ro:a ro:= 0 v))) +(check-type (! ((bool) x)) : bool -> (ro:a ro:= 0 x)) +(assert (! ((bool) x))) + +(check-type (|| x (! v)) : bool + -> (ro:a ro:|| (ro:a ro:! (ro:a ro:= 0 x)) + (ro:a ro:&& (ro:a ro:= 0 x) (ro:a ro:= 0 v)))) +(assert (|| x (! v))) + +(check-type (== x v) : int + -> (ro:if (ro:a ro:= v (ro:a ro:integer->real x)) 1 0)) +(assert (== x v)) + +(: float3 z) +(check-type z : float3 -> z) +;; check coercions +(check-type ((bool) x) : bool -> (ro:a ro:! (ro:a ro:= 0 x))) +(check-type ((float) x) : float -> (ro:a ro:integer->real x)) +(check-type ((float3) x) : float3 + -> (ro:a ro:vector-immutable + (ro:a ro:integer->real x) + (ro:a ro:integer->real x) + (ro:a ro:integer->real x))) +(check-type ((float3) z) : float3 -> z) + +;; expected: +;; (vector +;; (ite (= 0 x$0) z$0 (integer->real x$0)) +;; (ite (= 0 x$0) z$1 (integer->real x$0)) +;; (ite (= 0 x$0) z$2 (integer->real x$0))) +(check-type (?: x x z) : float3 + -> (ro:if (ro:a ro:= 0 x) + z + (ro:a ro:vector-immutable + (ro:a ro:integer->real x) + (ro:a ro:integer->real x) + (ro:a ro:integer->real x)))) + +(typecheck-fail ((bool) z) + #:with-msg "no implicit conversion from float3 to bool") + +(check-type (?: z x x) : float3 + -> (ro:a ro:vector-immutable + (ro:a ro:integer->real x) + (ro:a ro:integer->real x) + (ro:a ro:integer->real x))) + +(: int16 u) + +(check-type u : int16 -> u) + +;; NULL +;; ((int16) v) + +;; (= x 3.4) +;; x + +;; (+= z 2) +;; z + +;; (%= x 3) +;; x + +;; (int3 4 5 6) +;; (= [u xyz] (int3 4 5 6)) +;; u + +;; (+ (int3 1 2 3) 4) + +;; ((int4 5 6 7 8) s03) + +;; (if x {}{}) + +;; (if x +;; { (= [u sf] 10) } +;; { (= [u sf] 9) } +;; ) +;; u + +;; (if (! x) +;; { (: int g) (= g 3) (= [u sf] g) } +;; { (= [u sf] 9) } +;; ) +;; u + +;; (for [(: int i in (range 0 4 1))] ) +;; (for [(: int i in (range 0 4 1))] +;; (if (! x) +;; { (: int g) (= g i) (+= [u sf] g)} ) +;; ) +;; u + + +;; (: int16* w) +;; (= w ((int16*) (malloc 32))) +;; (= [w 0] 1) +;; (= [w 1] 2) +;; w +;; [w 0] +;; [w 1] + +;; (get_work_dim) + +;; (procedure void (nop1)) +;; (nop1) +;; (kernel void (nop2)) +;; (nop2) + +;; (procedure int (int_iden [int x]) x) +;; (int_iden ((int) 4.5)) +;; (int_iden #t) +;; (int_iden 4.5) + + + +;; ;;;;;; assertion failure localization ;;;;;; +;; ; (assert #f) + +;; ;;;;;; bad types etc ;;;;;; +;; ;(: float* NULL) +;; ;(+ x y) +;; ;(?: "" x x) +;; ;((int) z) +;; ;(-= z w) +;; ;(%= z 3) +;; ;(NULL 3) +;; ;(if x) +;; ;(for [() () "" (-= x 1)]) +;; ;[w ""] +;; ;[w 2] +;; ;(procedure int (bad)) +;; ;(procedure) +;; ;(kernel int (bad) 1) +;; ;(procedure void (w)) +;; ;(int_iden "") +;; ;(procedure float (bad) "")