synthcl3: add macros to abstract definition of common forms and types

This commit is contained in:
Stephen Chang 2016-11-15 15:00:39 -05:00
parent cdedc4b556
commit d20a238961

View File

@ -3,7 +3,6 @@
(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped
racket/stxparam
(prefix-in ro: rosette)
; define-symbolic* #%datum define if ! = number? boolean? cond ||))
(prefix-in cl: sdsl/synthcl/lang/forms)
(prefix-in cl: sdsl/synthcl/model/reals)
(prefix-in cl: sdsl/synthcl/model/operators)
@ -18,41 +17,24 @@
(define (mk-cl id) (format-id id "cl:~a" id))
(current-host-lang mk-cl))
;(define-base-types)
(provide (rename-out
[synth-app #%app]
; [rosette3:Bool bool] ; symbolic
; [rosette3:Int int] ; symbolic
; [rosette3:Num float] ; symbolic
#;[rosette3:CString char*]) ; always concrete
procedure kernel
#%datum if range for
bool int int2 int3 int4 float float3 int16 void
void* char* int* int16*
: ! ?: == + %
;; assignment ops
= += %=
(provide (rename-out [synth-app #%app])
procedure kernel #%datum if range for
int int2 int3 int4 int16 float float2 float3 float4 float16
bool void void* char* int* int16*
: ! ?: == + % || &&
= += %= ; assignment ops
(typed-out
;; need the concrete cases for literals;
;; alternative is to redefine #%datum to give literals symbolic type
[malloc : (C→ int void*)]
[get_work_dim : (C→ int)]
#;[% : (Ccase-> (C→ CInt CInt CInt)
(C→ Int Int Int))]
[!= : (Ccase-> (C→ CNum CNum CBool)
(C→ CNum CNum CNum CBool)
(C→ Num Num Bool)
(C→ Num Num Num Bool))]
[NULL : void*]
#;[== : (Ccase-> (C→ CNum CNum CBool)
(C→ CNum CNum CNum CBool)
(C→ Num Num Bool)
(C→ Num Num Num Bool))]))
[NULL : void*]))
(begin-for-syntax
;; TODO: use equality type relation instead of subtype
;; - require reimplementing many more things, eg #%datum, +, etc
; (current-typecheck-relation (current-type=?))
;; typecheck unexpanded types
(define (typecheck/un? t1 t2)
(typecheck? ((current-type-eval) t1)
@ -107,10 +89,7 @@
(format "no implicit conversion from ~a to ~a"
(type->str from) (type->str to)
#;(if (contract? to) (contract-name to) to))
expr subexpr))
#;(or (typecheck/un? from to) ; from == to
(and (real-type? from)
(typecheck/un? to #'bool))))
expr subexpr)))
(define (add-convert stx fn)
(set-stx-prop/preserved stx 'convert fn))
(define (get-convert stx)
@ -137,6 +116,7 @@
[(_ stx fn) (add-construct #'stx #'fn)])
;; TODO: reuse impls in model/reals.rkt ?
(ro:define (to-bool v)
(ro:cond
[(ro:boolean? v) v]
@ -153,65 +133,7 @@
[(ro:fixnum? v) (ro:exact->inexact v)]
[(ro:flonum? v) v]
[else (ro:type-cast ro:real? v)]))
(ro:define (to-float3 v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 3]) (to-float (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 3]) (to-float (ro:vector-ref v i))))]
[else (ro:apply ro:vector-immutable (ro:make-list 3 (to-float v)))]))
(ro:define (to-int2 v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 2]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 2]) (to-int (ro:vector-ref v i))))]
[else
(ro:apply ro:vector-immutable
(ro:make-list 2 (to-int v)))]))
(ro:define (to-int3 v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 3]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 3]) (to-int (ro:vector-ref v i))))]
[else
(ro:apply ro:vector-immutable
(ro:make-list 3 (to-int v)))]))
(ro:define (to-int4 v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 4]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 4]) (to-int (ro:vector-ref v i))))]
[else
(ro:apply ro:vector-immutable
(ro:make-list 4 (to-int v)))]))
(ro:define (mk-int2 x y)
(ro:#%app ro:vector-immutable (to-int x) (to-int y)))
(ro:define (mk-int3 x y z)
(ro:#%app ro:vector-immutable (to-int x) (to-int y) (to-int z)))
(ro:define (mk-int4 w x y z)
(ro:#%app ro:vector-immutable (to-int w) (to-int x) (to-int y) (to-int z)))
(ro:define (to-int16 v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 16]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 16]) (to-int (ro:vector-ref v i))))]
[else
(ro:apply ro:vector-immutable
(ro:make-list 16 (to-int v)))]))
(ro:define (to-int16* v)
(cl:pointer-cast v cl:int16))
@ -223,35 +145,73 @@
(add-convertm rosette3:Num to-float))
(define-named-type-alias char* rosette3:CString)
(define-named-type-alias float3
(add-convertm
(rosette3:CVector rosette3:Num rosette3:Num rosette3:Num)
to-float3))
(define-named-type-alias int2
(add-constructm
(add-convertm
(rosette3:CVector rosette3:Int rosette3:Int)
to-int2)
mk-int2))
(define-named-type-alias int3
(add-constructm
(add-convertm
(rosette3:CVector rosette3:Int rosette3:Int rosette3:Int)
to-int3)
mk-int3))
(define-named-type-alias int4
(add-constructm
(add-convertm
(rosette3:CVector rosette3:Int rosette3:Int rosette3:Int rosette3:Int)
to-int4)
mk-int4))
(define-named-type-alias int16
(add-convertm
(rosette3:CVector rosette3:Int rosette3:Int rosette3:Int rosette3:Int
rosette3:Int rosette3:Int rosette3:Int rosette3:Int
rosette3:Int rosette3:Int rosette3:Int rosette3:Int
rosette3:Int rosette3:Int rosette3:Int rosette3:Int)
to-int16))
(define-syntax (define-int stx)
(syntax-parse stx
[(_ n)
#:with intn (format-id #'n "int~a" (syntax->datum #'n))
#:with to-intn (format-id #'n "to-~a" #'intn)
#:with mk-intn (format-id #'n "mk-~a" #'intn)
#:with (x ...) (generate-temporaries
(build-list (syntax->datum #'n) (lambda (x) x)))
#:with (I ...) (stx-map (lambda _ #'rosette3:Int) #'(x ...))
#'(begin
(define-named-type-alias intn
(add-constructm
(add-convertm
(rosette3:CVector I ...)
to-intn)
mk-intn))
(ro:define (to-intn v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i n]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i n]) (to-int (ro:vector-ref v i))))]
[else
(ro:apply ro:vector-immutable
(ro:make-list n (to-int v)))]))
(ro:define (mk-intn x ...)
(ro:#%app ro:vector-immutable (to-int x) ...))
)]))
(define-simple-macro (define-ints n ...) (begin (define-int n) ...))
(define-ints 2 3 4 16)
(define-syntax (define-float stx)
(syntax-parse stx
[(_ n)
#:with floatn (format-id #'n "float~a" (syntax->datum #'n))
#:with to-floatn (format-id #'n "to-~a" #'floatn)
#:with mk-floatn (format-id #'n "mk-~a" #'floatn)
#:with (x ...) (generate-temporaries
(build-list (syntax->datum #'n) (lambda (x) x)))
#:with (I ...) (stx-map (lambda _ #'rosette3:Num) #'(x ...))
#'(begin
(define-named-type-alias floatn
(add-constructm
(add-convertm
(rosette3:CVector I ...)
to-floatn)
mk-floatn))
(ro:define (to-floatn v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i n]) (to-float (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i n]) (to-float (ro:vector-ref v i))))]
[else
(ro:apply ro:vector-immutable
(ro:make-list n (to-float v)))]))
(ro:define (mk-floatn x ...)
(ro:#%app ro:vector-immutable (to-float x) ...))
)]))
(define-simple-macro (define-floats n ...) (begin (define-float n) ...))
(define-floats 2 3 4 16)
(define-type-constructor Pointer #:arity = 1)
(define-named-type-alias void rosette3:CUnit)
(define-named-type-alias void* (Pointer rosette3:CUnit))
@ -274,9 +234,6 @@
[ (ro:#%app convert e-) ty.norm]]
[(_ ty:type e ...) ; construct
[ e e- ty-e] ...
;; #:fail-unless (cast-ok? #'ty-e #'ty.norm #'e)
;; (format "cannot cast ~a to ~a"
;; (type->str #'ty-e) (type->str #'ty.norm))
#:with construct (get-construct #'ty.norm)
#:fail-unless (syntax-e #'construct)
(format "no constructor found for ~a type"
@ -294,7 +251,6 @@
#:with ty-out ((current-type-eval)
(format-id #'here "~a~a"
base-str len-str))
; #:with convert (get-convert #'ty-out)
--------
[ e-out ty-out]]
[(_ vec sel) ; applying vector to one arg is selector
@ -342,20 +298,18 @@
--------
[ #,(if (typecheck/un? #'ty-out #'void)
#'(rosette3:define (f [x col ty] ... arr ty-out)
;; TODO: this is deviating from rosette's impl
;; but I think it's a bug in rosette
;; otherwise it's unsound
; (⊢ (ro:set! x (ro:a conv x)) void) ...
(⊢m (ro:let ([x (ro:#%app conv x)] ...)
e ...
(rosette3:#%app rosette3:void))
ty-out))
;; TODO: this is deviating from rosette's impl
;; (to use let instead of set!)
;; but I think it's a bug in rosette, otherwise it's unsound
; (⊢ (ro:set! x (ro:a conv x)) void) ...
(⊢m (ro:let ([x (ro:#%app conv x)] ...)
e ... (rosette3:#%app rosette3:void))
ty-out))
#'(rosette3:define (f [x col ty] ... arr ty-out)
; (⊢ (ro:set! x (ro:a conv x)) void) ...
(⊢m (ro:let ([x (ro:#%app conv x)] ...)
(rosette3:#%app rosette3:void)
e ...)
ty-out)))]])
; (⊢ (ro:set! x (ro:a conv x)) void) ...
(⊢m (ro:let ([x (ro:#%app conv x)] ...)
(rosette3:#%app rosette3:void) e ...)
ty-out)))]])
(define-typed-syntax kernel
[(_ ty-out:type (f [ty:type x:id] ...) e ...)
--------
@ -372,14 +326,12 @@
--------
[ (if test {then ...} {})]])
;(define-syntax-parameter range (syntax-rules ()))
(define-typed-syntax (range e ...)
[ e e- int] ...
--------
[ (ro:#%app ro:in-range e- ...) int])
(define-typed-syntax for
[(_ [((~literal :) ty:type var:id (~datum in) rangeExpr) ...] e ...)
; [⊢ rangeExpr ≫ rangeExpr- ⇒ _] ...
[[var var- : ty.norm] ... [e e- ty-e] ...]
--------
[ (ro:for* ([var- rangeExpr] ...)
@ -389,14 +341,9 @@
;; need to redefine #%datum because rosette3:#%datum is too precise
(define-typed-syntax #%datum
[(_ . b:boolean)
; #:with ty_out (if (syntax-e #'b) #'True #'False)
--------
[ (ro:#%datum . b) bool]]
[(_ . n:integer)
;; #:with ty_out (let ([m (syntax-e #'n)])
;; (cond [(zero? m) #'Zero]
;; [(> m 0) #'PosInt]
;; [else #'NegInt]))
--------
[ (ro:#%datum . n) int]]
[(#%datum . n:number)
@ -565,9 +512,7 @@
[ (name- (synth-app (bool) e1-)
(synth-app (bool) e2-)) bool]])]))
(define-simple-macro (define-coercing-bool-binops o ...+)
(ro:begin
(provide o ...)
(define-coercing-bool-binop o) ...))
(ro:begin (define-coercing-bool-binop o) ...))
(define-coercing-bool-binops || &&)
@ -582,72 +527,32 @@
--------
[ (to-int (cl:== e1- e2-)) int]])
(define-typed-syntax +
[(_ e1 e2)
[ e1 e1- ty1]
[ e2 e2- ty2]
;; #:when (real-type? #'ty1)
;; #:when (real-type? #'ty2)
#:with ty-out (common-real-type #'ty1 #'ty2)
#:with convert (get-convert #'ty-out)
#:with ty-base (get-base #'ty-out)
#:with base-convert (get-convert #'ty-base)
--------
[ #,(if (scalar-type? #'ty-out)
#'(convert (cl:+ (synth-app (ty-out) e1-)
(synth-app (ty-out) e2-)))
#'(convert
(ro:let ([a (convert e1-)][b (convert e2-)])
(ro:for/list ([v1 a][v2 b])
(base-convert (cl:+ v1 v2)))))) ty-out]])
(define-typed-syntax +=
[(_ x e)
--------
[ (= x (+ x e))]])
(define-typed-syntax %
[(_ e1 e2)
[ e1 e1- ty1]
[ e2 e2- ty2]
;; #:when (real-type? #'ty1)
;; #:when (real-type? #'ty2)
#:with ty-out (common-real-type #'ty1 #'ty2)
#:with convert (get-convert #'ty-out)
#:with ty-base (get-base #'ty-out)
#:with base-convert (get-convert #'ty-base)
--------
[ #,(if (scalar-type? #'ty-out)
#'(convert (cl:% (synth-app (ty-out) e1-)
(synth-app (ty-out) e2-)))
#'(convert
(ro:let ([a (convert e1-)][b (convert e2-)])
(ro:for/list ([v1 a][v2 b])
(base-convert (cl:% v1 v2)))))) ty-out]])
(define-typed-syntax %=
[(_ x e)
--------
[ (= x (% x e))]])
#;(define-typed-syntax %=
[(_ x:id e)
[ x x- ty-x]
[ e e- ty-e]
#:fail-unless (cast-ok? #'ty-e #'ty-x stx)
(format "cannot cast ~a to ~a"
(type->str #'ty-e) (type->str #'ty-x))
--------
[ (ro:set! x- (% x- (synth-app (ty-x) e-))) void]])
#;(define-typed-syntax &&
[(_ e1 e2)
[ e1 e1- bool]
[ e2 e2- bool]
--------
[ (cl:&& e1- e2-) bool]]
;; else try to coerce
[(_ e1 e2)
[ e1 e1- ty1]
[ e2 e2- ty2]
--------
[ (cl:&& (synth-app (bool) e1-) (synth-app (bool) e2-)) bool]])
(define-syntax (define-coercing-real-binop stx)
(syntax-parse stx
[(_ name)
#:with name- (mk-cl #'name)
#:with name= (format-id #'name "~a=" #'name) ; assignment form
#'(begin-
(define-typed-syntax name
[(_ e1 e2)
[ e1 e1- ty1]
[ e2 e2- ty2]
#:with ty-out (common-real-type #'ty1 #'ty2)
#:with convert (get-convert #'ty-out)
#:with ty-base (get-base #'ty-out)
#:with base-convert (get-convert #'ty-base)
--------
[ #,(if (scalar-type? #'ty-out)
#'(convert (name- (synth-app (ty-out) e1-)
(synth-app (ty-out) e2-)))
#'(convert
(ro:let ([a (convert e1-)][b (convert e2-)])
(ro:for/list ([v1 a][v2 b])
(base-convert (name- v1 v2)))))) ty-out]])
(define-typed-syntax name=
[(_ x e)
--------
[ (= x (name x e))]]))]))
(define-simple-macro (define-real-binops o ...)
(ro:begin (define-coercing-real-binop o) ...))
(define-real-binops + %)