more synthcl3 code cleanup

This commit is contained in:
Stephen Chang 2016-12-08 16:04:40 -05:00
parent 25a66ae60e
commit f0646f300b
2 changed files with 90 additions and 141 deletions

View File

@ -798,9 +798,14 @@
(define (mk-tyvar X) (attach X 'tyvar #t))
(define (tyvar? X) (syntax-property X 'tyvar))
(define type-pat "[A-Za-z]+")
;; TODO: remove this? only benefit is single control point for current-promote
(define type-pat "[A-Za-z]+"))
(define-syntax (⊢m stx)
(syntax-parse stx #:datum-literals (:)
[(_ e : τ) (assign-type #`e #`τ)]
[(_ e τ) (assign-type #`e #`τ)]))
(begin-for-syntax
;; - infers type of e
;; - checks that type of e matches the specified type
;; - erases types in e

View File

@ -73,7 +73,7 @@
(format "no implicit conversion from ~a to ~a"
(type->str from) (type->str to)) expr subexpr)))
(define (mk-ptr id) (format-id id "~a*" id))
(define (mk-mk id) (format-id id "mk-~a" id))
(define (mk-mk id [ctx id]) (format-id ctx "mk-~a" id))
(define (mk-to id) (format-id id "to-~a" id))
(define (add-construct stx fn) (set-stx-prop/preserved stx 'construct fn))
(define (add-convert stx fn) (set-stx-prop/preserved stx 'convert fn))
@ -85,12 +85,13 @@
(define split-ty (ty->len t))
(string->number
(or (and split-ty (third split-ty)) "1")))
(define (uneval ty #:ctx [ctx #'here] #:str-fn [str-fn (λ (x) x)])
(datum->syntax ctx (string->symbol (str-fn (type->str ty)))))
(define (get-base/un ty [ctx #'here]) ; returns unexpanded base type
(datum->syntax ctx
(string->symbol (car (regexp-match #px"[a-z]+" (type->str ty))))))
(uneval ty #:ctx ctx #:str-fn (λ (s) (car (regexp-match #px"[a-z]+" s)))))
(define (get-base ty [ctx #'here]) ((current-type-eval) (get-base/un ty ctx)))
(define (get-pointer-base ty [ctx #'here]) ; returns unexpanded ptr base
(datum->syntax ctx (string->symbol (string-trim (type->str ty) "*"))))
(uneval ty #:ctx ctx #:str-fn (λ (s) (string-trim s "*"))))
(define (vector-type? ty)
(define tstr (type->str ty))
(ormap (λ (x) (string=? x tstr)) '("int2" "int3" "int4" "int16" "float2" "float3" "float4" "float16")))
@ -157,21 +158,15 @@
(define-typed-syntax synth-app
[(_ (ty:type) e) ; cast
[ e e- ty-e]
#:fail-unless (cast-ok? #'ty-e #'ty.norm #'e)
(format "cannot cast ~a to ~a"
(type->str #'ty-e) (type->str #'ty.norm))
#:when (cast-ok? #'ty-e #'ty.norm #'e) ; raises exn
#:with convert (get-convert #'ty.norm)
#:fail-unless (syntax-e #'convert)
(format "cannot cast ~a to ~a: conversion fn not found"
(type->str #'ty-e) (type->str #'ty.norm))
--------
[ (ro:#%app convert e-) ty.norm]]
[(_ ty:type e ...) ; construct
[ e e- ty-e] ...
#:with construct (get-construct #'ty.norm)
#:fail-unless (syntax-e #'construct)
(format "no constructor found for ~a type"
(type->str #'ty.norm))
(format "no constructor found for ~a type" (type->str #'ty.norm))
--------
[ (ro:#%app construct e- ...) ty.norm]]
[(_ p _) ; applying ptr to one arg is selector
@ -194,15 +189,13 @@
#:do [(define base-str (cadr split-ty))
(define len-str (caddr split-ty))]
#:do [(define sels (length (stx->list #'selector)))]
#:with e-out (if (= sels 1)
#'(ro:vector-ref vec- (ro:car 'selector))
#'(for/list ([idx 'selector])
(ro:vector-ref vec- idx)))
#:with e-out (if (= sels 1) #'(ro:vector-ref vec- (ro:car 'selector))
#'(for/list ([idx 'selector])
(ro:vector-ref vec- idx)))
#:with ty-out ((current-type-eval)
(if (= sels 1)
(format-id #'here "~a" base-str)
(format-id #'here "~a~a"
base-str (length (stx->list #'selector)))))
(if (= sels 1) (format-id #'here "~a" base-str)
(format-id #'here "~a~a"
base-str (length (stx->list #'selector)))))
#:with convert (get-convert #'ty-out)
--------
[ (ro:#%app convert e-out) ty-out]]
@ -212,14 +205,7 @@
#:when (stx-andmap cast-ok? #'(ty-e ...) #'(ty-in ...))
--------
[ (ro:#%app f- e- ...) ty-out]]
[(_ . es)
--------
[ (rosette3:#%app . es)]])
(define-syntax (⊢m stx)
(syntax-parse stx #:datum-literals (:)
[(_ e : τ) (assign-type #`e #`τ)]
[(_ e τ) (assign-type #`e #`τ)]))
[(_ . es) --- [ (rosette3:#%app . es)]])
;; top-level fns --------------------------------------------------
(define-typed-syntax procedure
@ -239,11 +225,12 @@
(rosette3:let ([x (⊢m (ro:#%app conv x) ty)] ...)
(⊢m (ro:let () e ... (rosette3:ann e-body : ty-out)) ty-out))))
(provide- f))]])
(define-typed-syntax kernel
[(_ ty-out:type (f [ty:type x:id] ...) e ...)
#:fail-unless (void? #'ty-out.norm)
(format "expected void, given ~a" (type->str #'ty-out.norm))
--- [ (procedure void (f [ty x] ...) e ...)]])
(define-typed-syntax (kernel ty-out:type (f [ty:type x:id] ...) e ...)
#:fail-unless (void? #'ty-out.norm)
(format "expected void, given ~a" (type->str #'ty-out.norm))
--- [ (procedure void (f [ty x] ...) e ...)])
(define-typed-syntax grammar
[(_ ty-out:type (f [ty:type x:id] ... [ty-depth k]) #:base be #:else ee)
#:with f- (generate-temporary #'f)
@ -255,8 +242,7 @@
(define-typed-syntax f
[(ff a ... j)
[ a _ ty] ...
[ j _ ty-depth]
;; j will be eval'ed, so strip its context
[ j _ ty-depth] ; j will be eval'ed, so strip its context
#:with j- (assign-type (datum->syntax #'H (stx->datum #'j)) #'int)
#:with f-- (replace-stx-loc #'f- #'ff)
-----------
@ -281,7 +267,7 @@
[(_ e-test {e1 ...} {e2 ...})
--------
[ (ro:if (to-bool e-test)
(ro:let () e1 ... (ro:void))
(ro:let () e1 ... (ro:void))
(ro:let () e2 ... (ro:void))) void]]
[(_ e-test es) --- [ (if e-test es {})]])
@ -298,53 +284,44 @@
(with-ctx ([x x- ty] ...)
(⊢m (ro:let () e ... (ro:void)) void)))) void]])
;; need to redefine #%datum because rosette3:#%datum is too precise
(define-typed-syntax #%datum
(define-typed-syntax #%datum ; redefine bc rosette3:#%datum is too precise
[(_ . b:boolean) --- [ (ro:#%datum . b) bool]]
[(_ . s:str) --- [ (ro:#%datum . s) char*]]
[(_ . n:integer) --- [ (ro:#%datum . n) int]]
[(#%datum . n:number)
#:when (real? (syntax-e #'n))
[(_ . n:number)
#:when (real? (syntax-e #'n))
--------
[ (ro:#%datum . n) float]]
[(_ . x)
--------
[_ #:error (type-error #:src #'x #:msg "Unsupported literal: ~v" #'x)]])
;; : --------------------------------------------------
;; : (var declaration) --------------------------------------------------
(define-typed-syntax :
[(_ ty:type x:id ...) ; special String case
#:when (rosette3:CString? #'ty.norm)
#:with (x- ...) (generate-temporaries #'(x ...))
--------
[ (begin-
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...
(ro:define x- (ro:#%datum . "")) ...)]]
;; TODO: vector types need a better representation
;; currently dissecting the identifier
;; TODO: combine vector and scalar cases
[ (begin- (define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...
(ro:define x- (ro:#%datum . "")) ...)]]
[(_ ty:type x:id ...)
#:when (real-type? #'ty.norm)
#:do [(define split-ty (ty->len #'ty))]
#:when (and split-ty (= 3 (length split-ty)))
#:do [(define base-str (cadr split-ty))
#:do [(define base-str (cadr split-ty))
(define len-str (caddr split-ty))]
#:with ty-base (datum->syntax #'ty (string->symbol base-str))
#:with pred (get-pred ((current-type-eval) #'ty-base))
#:fail-unless (syntax-e #'pred)
(format "no pred for ~a" (type->str #'ty))
#:fail-unless (syntax-e #'pred) (format "no pred for ~a" (type->str #'ty))
#:with (x- ...) (generate-temporaries #'(x ...))
#:with (x-- ...) (generate-temporaries #'(x ...))
#:with mk-ty (format-id #'here "mk-~a" #'ty)
#:with mk-ty (mk-mk #'ty #'here)
--------
[ (begin-
(ro:define-symbolic* x-- pred [#,(string->number len-str)]) ...
(ro:define x- (ro:apply mk-ty x--)) ...
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]]
[ (begin- (ro:define-symbolic* x-- pred [#,(string->number len-str)]) ...
(ro:define x- (ro:apply mk-ty x--)) ...
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]]
[(_ ty:type [len] x:id ...) ; array of vector types
#:when (real-type? #'ty.norm)
[ len len- int]
@ -353,42 +330,37 @@
#:with ty* (format-id #'ty "~a*" #'ty)
#:with to-ty* (format-id #'here "to-~a" #'ty*)
#:with pred (get-pred ((current-type-eval) #'ty-base))
#:fail-unless (syntax-e #'pred)
(format "no pred for ~a" (type->str #'ty))
#:fail-unless (syntax-e #'pred) (format "no pred for ~a" (type->str #'ty))
#:with (x- ...) (generate-temporaries #'(x ...))
#:with (*x ...) (generate-temporaries #'(x ...))
#:with (x-- ...) (generate-temporaries #'(x ...))
#:with mk-ty (format-id #'here "mk-~a" #'ty)
--------
[ (begin-
(ro:define-symbolic* x-- pred [len base-len]) ...
(ro:define x-
(ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))])
(ro:for ([i len][v x--])
[ (begin- (ro:define-symbolic* x-- pred [len base-len]) ...
(ro:define x-
(ro:let ([*x (to-ty* (cl:malloc (ro:* len base-len)))])
(ro:for ([i len][v x--])
(cl:pointer-set! *x i (ro:apply mk-ty v)))
*x)) ...
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty*))) ...)]]
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty*))) ...)]]
;; real, scalar (ie non-vector) types
[(_ ty:type x:id ...)
#:when (real-type? #'ty.norm)
#:with pred (get-pred #'ty.norm)
#:fail-unless (syntax-e #'pred)
(format "no pred for ~a" (type->str #'ty))
#:fail-unless (syntax-e #'pred) (format "no pred for ~a" (type->str #'ty))
#:with (x- ...) (generate-temporaries #'(x ...))
--------
[ (begin-
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...
(ro:define-symbolic* x- pred) ...)]]
[ (begin- (define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...
(ro:define-symbolic* x- pred) ...)]]
;; else init to NULLs
[(_ ty:type x:id ...)
#:with (x- ...) (generate-temporaries #'(x ...))
--------
[ (begin-
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...
(ro:define x- cl:NULL) ...)]])
[ (begin- (define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...
(ro:define x- cl:NULL) ...)]])
;; ?: --------------------------------------------------
(define-typed-syntax ?:
@ -406,20 +378,16 @@
#:with ty-base ((current-type-eval) (datum->syntax #'e (string->symbol out-base-str)))
#:with base-convert (get-convert #'ty-base)
-------
[ (convert
(ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)])
(ro:for/list ([idx #,(string->number out-len-str)])
(ro:if (ro:< (ro:vector-ref a idx) 0)
(base-convert (ro:vector-ref b idx))
(base-convert (ro:vector-ref c idx))))))
ty-out]]
[ (convert (ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)])
(ro:for/list ([idx #,(string->number out-len-str)])
(ro:if (ro:< (ro:vector-ref a idx) 0)
(base-convert (ro:vector-ref b idx))
(base-convert (ro:vector-ref c idx)))))) ty-out]]
[(_ ~! e e1 e2) ; should be scalar and real
[ e e- ty]
#:fail-unless (real-type? #'ty)
(format "not a real type: ~s has type ~a"
(syntax->datum #'e) (type->str #'ty))
#:fail-unless (cast-ok? #'ty #'bool #'e)
(format "cannot cast ~a to bool" (type->str #'ty))
#:fail-unless (real-type? #'ty) (format "not a real type: ~s has type ~a"
(syntax->datum #'e) (type->str #'ty))
#:when (cast-ok? #'ty #'bool #'e)
[ e1 e1- ty1]
[ e2 e2- ty2]
#:with ty-out ((current-join) #'ty1 #'ty2)
@ -465,19 +433,15 @@
--------
[ (ro:#%app cl:! (to-bool e-)) bool]])
;; TODO: comparison ops need to support vec types (and result)
(define-simple-macro (mk-cmp cmp-op)
;TODO: cmps should produce vec int result with same length as comm-real-ty
(define-simple-macro (mk-cmp cmp-op)
(define-typed-syntax cmp-op
[(o e1 e2)
[ e1 e1- ty1]
[ e2 e2- ty2]
#:when (real-type? #'ty1)
#:when (real-type? #'ty2)
#:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len
#:with conv (get-convert #'ty-out)
#:with o- (mk-cl #'o)
#:with conv (get-convert ((current-join) #'ty1 #'ty2))
--------
[ (to-int (o- (conv e1-) (conv e2-))) int]]))
[ (to-int (#,(mk-cl #'o) (conv e1-) (conv e2-))) int]]))
(define-simple-macro (mk-cmps o ...) (begin- (mk-cmp o) ...))
(mk-cmps == < <= > >= !=)
@ -490,17 +454,11 @@
[ e2 e2- bool]
--------
[ (name- e1- e2-) bool]]
[(_ e1 e2) ; else try to coerce
--- [ (name- (to-bool e1) (to-bool e2)) bool]]))
#;(define- (cl:/ x y)
(cond- [(zero?- y) 0]
[(integer?- x) (quotient- x y)]
[else (/- x y)]))
[(_ e1 e2) --- [ (name- (to-bool e1) (to-bool e2)) bool]])) ; coerce
(define-simple-macro (define-real-ops o ...) (ro:begin (define-real-op o) ...))
(define-simple-macro (define-real-op name (~optional (~seq #:extra-check p?)
#:defaults ([p? #'(λ _ #t)])))
#:defaults ([p? #'(λ _ #t)])))
#:with name- (mk-cl #'name)
#:with name= (format-id #'name "~a=" #'name) ; assignment form
(begin-
@ -521,14 +479,13 @@
#'(convert (ro:let ([x (convert e-)] (... ...))
(ro:for/list ([x x] (... ...))
(base-convert (name- x (... ...))))))) ty-out])
(define-typed-syntax (name= x e)
--- [ (= x (name x e))])))
(define-typed-syntax (name= x e) --- [ (= x (name x e))])))
(define-for-syntax (int? t givens)
(or (typecheck/un? t #'int)
(raise-syntax-error #f
(format "no common integer type for operands; given ~a"
(types->str givens)))))
(raise-syntax-error #f
(format "no common integer type for operands; given ~a"
(types->str givens)))))
(define-simple-macro (define-int-op o) (define-real-op o #:extra-check int?))
(define-simple-macro (define-int-ops o ...) (ro:begin (define-int-op o) ...))
@ -536,41 +493,34 @@
(define-real-ops + * - / sqrt)
(define-int-ops % << $ &)
(define-typerule (sizeof t:type) >> ---[ #,(real-type-length #'t.norm) int])
(define-typerule (print e ...) >> ---[ (ro:begin (display e) ...) void])
(define-typerule (sizeof t:type) --- [ #,(real-type-length #'t.norm) int])
(define-typerule (print e ...) --- [ (ro:begin (display e) ...) void])
(define-typerule (assert e) --- [ (ro:assert (to-bool e)) void])
(define-typerule (clCreateProgramWithSource ctx f)
--- [ (cl:clCreateProgramWithSource ctx f) cl_program])
(define-typed-syntax choose
[(ch e ...+)
#:with (e- ...) (stx-map expand/ro #'(e ...))
#:with (ty ...) (stx-map typeof #'(e- ...))
#:when (same-types? #'(ty ...))
#:with (e/disarmed ...) (stx-map replace-stx-loc #'(e- ...) #'(e ...))
;; the #'choose identifier itself must have the location of its use
;; see define-synthax implementation, specifically syntax/source in utils
#:with ch/disarmed (replace-stx-loc #'ro:choose #'ch)
#:with ch- (replace-stx-loc #'ro:choose #'ch)
--------
[ (ch/disarmed e/disarmed ...) #,(stx-car #'(ty ...))]])
[ (ch- e- ...) #,(stx-car #'(ty ...))]])
(define-typed-syntax ??
[(qq)
#:with ??/progsrc (replace-stx-loc #'cl:?? #'qq)
--------
[ (??/progsrc) int]]
[(qq ty:type)
#:with ??/progsrc (replace-stx-loc #'cl:?? #'qq)
#:with t (datum->syntax #'here (string->symbol (type->str #'ty.norm)))
#:with cl-t (mk-cl #'t)
;; #:with ty-base (get-base #'ty.norm)
;; #:with pred (get-pred ((current-type-eval) #'ty-base))
#:with qq- (replace-stx-loc #'cl:?? #'qq)
#:with cl-t (mk-cl (uneval #'ty.norm))
--------
[ (??/progsrc cl-t) ty]])
[ (qq- cl-t) ty.norm]]
[(qq) --- [ (#,(replace-stx-loc #'cl:?? #'qq)) int]])
(define-typed-syntax (@ x:id)
[ x x- ty+] ;; TODO: check ty = real, non-ptr type
#:with ty (datum->syntax #'x (string->symbol (type->str #'ty+)))
#:with cl-ty (mk-cl #'ty)
#:with ty (uneval #'ty+)
---------
[ (cl:address-of x- cl-ty) #,(mk-ptr #'ty)])
[ (cl:address-of x- #,(mk-cl #'ty)) #,(mk-ptr #'ty)])
(define-typed-syntax locally-scoped
[(_ e ...) ty --- [ (ro:let () e ...)]]
@ -603,7 +553,7 @@
(ro:synthesize
#:forall (ro:append tmp ...)
#:guarantee (ro:for ([id- tmp] ...)
(with-ctx ([id id- ty] ...) e)))))) void]]
(with-ctx ([id id- ty] ...) e)))))) void]]
[(_ #:forall [decl ...] #:ensure e)
--- [ (synth #:forall [decl ...] #:bitwidth 8 #:ensure e)]])
@ -624,9 +574,3 @@
(printf "~a = ~a\n" i (ro:evaluate i- cex)))
cex))
(begin (displayln "no counterexample found") (ro:unsat))))) void]])
(define-typed-syntax (assert e)
--- [ (ro:assert (to-bool #,(expand/ro #'e))) void])
(define-typed-syntax (clCreateProgramWithSource ctx f)
--- [ (cl:clCreateProgramWithSource ctx f) cl_program])