working with app snippet tests: assignment, vectors, pointers, selectors, procedures

This commit is contained in:
Stephen Chang 2016-11-09 17:03:46 -05:00
parent 8e2710e133
commit cdedc4b556
2 changed files with 393 additions and 68 deletions

View File

@ -1,11 +1,18 @@
#lang turnstile
(extends "rosette3.rkt" #:except ! #%app || && void = +) ; typed rosette
(extends "rosette3.rkt" #:except ! #%app || && void = + #%datum if) ; typed rosette
(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped
racket/stxparam
(prefix-in ro: rosette)
; define-symbolic* #%datum define if ! = number? boolean? cond ||))
(prefix-in cl: sdsl/synthcl/lang/forms)
(prefix-in cl: sdsl/synthcl/model/reals)
(prefix-in cl: sdsl/synthcl/model/operators)
(prefix-in cl: sdsl/synthcl/model/errors)
(prefix-in cl: sdsl/synthcl/model/memory))
(prefix-in cl: sdsl/synthcl/model/memory)
(prefix-in cl: sdsl/synthcl/model/runtime)
(prefix-in cl: sdsl/synthcl/model/work)
(prefix-in cl: sdsl/synthcl/model/pointers)
(for-syntax (prefix-in cl: sdsl/synthcl/lang/util)))
(begin-for-syntax
(define (mk-cl id) (format-id id "cl:~a" id))
@ -13,20 +20,25 @@
;(define-base-types)
(provide (rename-out
(provide (rename-out
[synth-app #%app]
; [rosette3:Bool bool] ; symbolic
; [rosette3:Int int] ; symbolic
; [rosette3:Num float] ; symbolic
[rosette3:CString char*]) ; always concrete
bool int float float3 int16 void void*
: ! ?: == +
#;[rosette3:CString char*]) ; always concrete
procedure kernel
#%datum if range for
bool int int2 int3 int4 float float3 int16 void
void* char* int* int16*
: ! ?: == + %
;; assignment ops
= +=
= += %=
(typed-out
;; need the concrete cases for literals;
;; alternative is to redefine #%datum to give literals symbolic type
[% : (Ccase-> (C→ CInt CInt CInt)
[malloc : (C→ int void*)]
[get_work_dim : (C→ int)]
#;[% : (Ccase-> (C→ CInt CInt CInt)
(C→ Int Int Int))]
[!= : (Ccase-> (C→ CNum CNum CBool)
(C→ CNum CNum CNum CBool)
@ -45,7 +57,11 @@
(define (typecheck/un? t1 t2)
(typecheck? ((current-type-eval) t1)
((current-type-eval) t2)))
(define (real-type? t) (not (typecheck/un? t #'bool)))
(define (real-type? t)
(and (not (typecheck/un? t #'bool))
(not (pointer-type? t))))
(define (pointer-type? t)
(regexp-match #px"\\*$" (type->str t)))
(define (type-base t)
(datum->syntax t
(string->symbol
@ -78,12 +94,12 @@
;; copied from check-implicit-conversion in lang/types.rkt
;; TODO: this should not exception since it is used in stx-parse
;; clauses that may want to backtrack
(define (cast-ok? from to expr [subexpr #f])
(define (cast-ok? from to [expr #f] [subexpr #f])
(unless (if #t #;(type? to)
(or (typecheck/un? from to)
(and (scalar-type? from) (scalar-type? to))
(and (scalar-type? from) (vector-type? to))
#;(and (pointer-type? from) (pointer-type? to))
(and (pointer-type? from) (pointer-type? to))
#;(and (equal? from cl_mem) (pointer-type? to)))
(to from))
(raise-syntax-error
@ -99,6 +115,10 @@
(set-stx-prop/preserved stx 'convert fn))
(define (get-convert stx)
(syntax-property stx 'convert))
(define (add-construct stx fn)
(set-stx-prop/preserved stx 'construct fn))
(define (get-construct stx)
(syntax-property stx 'construct))
(define (ty->len ty)
(regexp-match #px"([a-z]+)([0-9]+)" (type->str ty)))
(define (get-base ty [ctx #'here])
@ -106,12 +126,15 @@
(datum->syntax ctx
(string->symbol (car (regexp-match #px"[a-z]+" (type->str ty)))))))
(define (vector-type? ty)
;; TODO: and not pointer-type?
(ty->len ty))
(define (scalar-type? ty)
(not (vector-type? ty))))
(define-syntax-parser add-convertm
[(_ stx fn) (add-convert #'stx #'fn)])
(define-syntax-parser add-constructm
[(_ stx fn) (add-construct #'stx #'fn)])
;; TODO: reuse impls in model/reals.rkt ?
(ro:define (to-bool v)
@ -139,6 +162,45 @@
(ro:apply ro:vector-immutable
(ro:for/list ([i 3]) (to-float (ro:vector-ref v i))))]
[else (ro:apply ro:vector-immutable (ro:make-list 3 (to-float v)))]))
(ro:define (to-int2 v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 2]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 2]) (to-int (ro:vector-ref v i))))]
[else
(ro:apply ro:vector-immutable
(ro:make-list 2 (to-int v)))]))
(ro:define (to-int3 v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 3]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 3]) (to-int (ro:vector-ref v i))))]
[else
(ro:apply ro:vector-immutable
(ro:make-list 3 (to-int v)))]))
(ro:define (to-int4 v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 4]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 4]) (to-int (ro:vector-ref v i))))]
[else
(ro:apply ro:vector-immutable
(ro:make-list 4 (to-int v)))]))
(ro:define (mk-int2 x y)
(ro:#%app ro:vector-immutable (to-int x) (to-int y)))
(ro:define (mk-int3 x y z)
(ro:#%app ro:vector-immutable (to-int x) (to-int y) (to-int z)))
(ro:define (mk-int4 w x y z)
(ro:#%app ro:vector-immutable (to-int w) (to-int x) (to-int y) (to-int z)))
(ro:define (to-int16 v)
(ro:cond
[(ro:list? v)
@ -150,6 +212,8 @@
[else
(ro:apply ro:vector-immutable
(ro:make-list 16 (to-int v)))]))
(ro:define (to-int16* v)
(cl:pointer-cast v cl:int16))
(define-named-type-alias bool
(add-convertm rosette3:Bool to-bool))
@ -157,11 +221,30 @@
(add-convertm rosette3:Int to-int))
(define-named-type-alias float
(add-convertm rosette3:Num to-float))
(define-named-type-alias char* rosette3:CString)
(define-named-type-alias float3
(add-convertm
(rosette3:CVector rosette3:Num rosette3:Num rosette3:Num)
to-float3))
(define-named-type-alias int2
(add-constructm
(add-convertm
(rosette3:CVector rosette3:Int rosette3:Int)
to-int2)
mk-int2))
(define-named-type-alias int3
(add-constructm
(add-convertm
(rosette3:CVector rosette3:Int rosette3:Int rosette3:Int)
to-int3)
mk-int3))
(define-named-type-alias int4
(add-constructm
(add-convertm
(rosette3:CVector rosette3:Int rosette3:Int rosette3:Int rosette3:Int)
to-int4)
mk-int4))
(define-named-type-alias int16
(add-convertm
(rosette3:CVector rosette3:Int rosette3:Int rosette3:Int rosette3:Int
@ -169,8 +252,13 @@
rosette3:Int rosette3:Int rosette3:Int rosette3:Int
rosette3:Int rosette3:Int rosette3:Int rosette3:Int)
to-int16))
(define-named-type-alias void* rosette3:Unit)
(define-type-constructor Pointer #:arity = 1)
(define-named-type-alias void rosette3:CUnit)
(define-named-type-alias void* (Pointer rosette3:CUnit))
(define-named-type-alias int*
(Pointer int))
(define-named-type-alias int16*
(add-convertm (Pointer int16) to-int16*))
(define-typed-syntax synth-app
[(_ (ty:type) e) ; cast
@ -183,11 +271,146 @@
(format "cannot cast ~a to ~a: conversion fn not found"
(type->str #'ty-e) (type->str #'ty.norm))
--------
[ (convert e-) ty.norm]]
[ (ro:#%app convert e-) ty.norm]]
[(_ ty:type e ...) ; construct
[ e e- ty-e] ...
;; #:fail-unless (cast-ok? #'ty-e #'ty.norm #'e)
;; (format "cannot cast ~a to ~a"
;; (type->str #'ty-e) (type->str #'ty.norm))
#:with construct (get-construct #'ty.norm)
#:fail-unless (syntax-e #'construct)
(format "no constructor found for ~a type"
(type->str #'ty.norm))
--------
[ (ro:#%app construct e- ...) ty.norm]]
[(_ ptr sel) ; applying ptr to one arg is selector
[ ptr ptr- ty-ptr]
#:when (pointer-type? #'ty-ptr)
#:do [(define split-ty (ty->len #'ty-ptr))]
#:when (and split-ty (= 3 (length split-ty)))
#:do [(define base-str (cadr split-ty))
(define len-str (caddr split-ty))]
#:with e-out #'(cl:pointer-ref ptr- sel)
#:with ty-out ((current-type-eval)
(format-id #'here "~a~a"
base-str len-str))
; #:with convert (get-convert #'ty-out)
--------
[ e-out ty-out]]
[(_ vec sel) ; applying vector to one arg is selector
[ vec vec- ty-vec]
#:when (vector-type? #'ty-vec)
#:with selector (cl:parse-selector #t #'sel stx)
#:do [(define split-ty (ty->len #'ty-vec))]
#:when (and split-ty (= 3 (length split-ty)))
#:do [(define base-str (cadr split-ty))
(define len-str (caddr split-ty))]
#:do [(define sels (length (stx->list #'selector)))]
#:with e-out (if (= sels 1)
#'(ro:vector-ref vec- (ro:car 'selector))
#'(for/list ([idx 'selector])
(ro:vector-ref vec- idx)))
#:with ty-out ((current-type-eval)
(if (= sels 1)
(format-id #'here "~a" base-str)
(format-id #'here "~a~a"
base-str (length (stx->list #'selector)))))
#:with convert (get-convert #'ty-out)
--------
[ (ro:#%app convert e-out) ty-out]]
[(_ f e ...)
[ f f- (~C→ ty-in ... ty-out)]
[ e e- ty-e] ...
#:when (stx-andmap cast-ok? #'(ty-e ...) #'(ty-in ...))
--------
[ (ro:#%app f- e- ...) ty-out]]
[(_ . es)
--------
[ (rosette3:#%app . es)]])
(define-syntax (⊢m stx)
(syntax-parse stx #:datum-literals (:)
[(_ e : τ) (assign-type #`e #`τ)]
[(_ e τ) (assign-type #`e #`τ)]))
;; top-level fns --------------------------------------------------
(define-typed-syntax procedure
[(_ ty-out:type (f [ty:type x:id] ...) e ...)
#:with col (datum->syntax #f ':)
#:with arr (datum->syntax #f '->)
#:with (conv ...) (stx-map get-convert #'(ty.norm ...))
--------
[ #,(if (typecheck/un? #'ty-out #'void)
#'(rosette3:define (f [x col ty] ... arr ty-out)
;; TODO: this is deviating from rosette's impl
;; but I think it's a bug in rosette
;; otherwise it's unsound
; (⊢ (ro:set! x (ro:a conv x)) void) ...
(⊢m (ro:let ([x (ro:#%app conv x)] ...)
e ...
(rosette3:#%app rosette3:void))
ty-out))
#'(rosette3:define (f [x col ty] ... arr ty-out)
; (⊢ (ro:set! x (ro:a conv x)) void) ...
(⊢m (ro:let ([x (ro:#%app conv x)] ...)
(rosette3:#%app rosette3:void)
e ...)
ty-out)))]])
(define-typed-syntax kernel
[(_ ty-out:type (f [ty:type x:id] ...) e ...)
--------
[ (procedure ty-out (f [ty x] ...) e ...)]])
;; for and if statement --------------------------------------------------
(define-typed-syntax if
[(_ test {then ...} {else ...})
--------
[ (ro:if (to-bool test)
(ro:let () then ... (ro:void))
(ro:let () else ... (ro:void))) void]]
[(_ test {then ...})
--------
[ (if test {then ...} {})]])
;(define-syntax-parameter range (syntax-rules ()))
(define-typed-syntax (range e ...)
[ e e- int] ...
--------
[ (ro:#%app ro:in-range e- ...) int])
(define-typed-syntax for
[(_ [((~literal :) ty:type var:id (~datum in) rangeExpr) ...] e ...)
; [⊢ rangeExpr ≫ rangeExpr- ⇒ _] ...
[[var var- : ty.norm] ... [e e- ty-e] ...]
--------
[ (ro:for* ([var- rangeExpr] ...)
e- ... (ro:void)) void]])
;; need to redefine #%datum because rosette3:#%datum is too precise
(define-typed-syntax #%datum
[(_ . b:boolean)
; #:with ty_out (if (syntax-e #'b) #'True #'False)
--------
[ (ro:#%datum . b) bool]]
[(_ . n:integer)
;; #:with ty_out (let ([m (syntax-e #'n)])
;; (cond [(zero? m) #'Zero]
;; [(> m 0) #'PosInt]
;; [else #'NegInt]))
--------
[ (ro:#%datum . n) int]]
[(#%datum . n:number)
#:when (real? (syntax-e #'n))
--------
[ (ro:#%datum . n) float]]
[(_ . s:str)
--------
[ (ro:#%datum . s) char*]]
[(_ . x)
--------
[_ #:error (type-error #:src #'x #:msg "Unsupported literal: ~v" #'x)]])
;; : --------------------------------------------------
(define-typed-syntax :
[(_ ty:type x:id ...) ; special String case
@ -201,6 +424,7 @@
;; TODO: vector types need a better representation
;; currently dissecting the identifier
[(_ ty:type x:id ...)
#:when (real-type? #'ty.norm)
#:do [(define split-ty (ty->len #'ty))]
#:when (and split-ty (= 3 (length split-ty)))
#:do [(define base-str (cadr split-ty))
@ -217,8 +441,9 @@
(ro:define x- (ro:apply ro:vector-immutable x--)) ...
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]]
;; scalar (ie non-vector) types
;; real, scalar (ie non-vector) types
[(_ ty:type x:id ...)
#:when (real-type? #'ty.norm)
#:with pred (get-pred #'ty.norm)
#:fail-unless (syntax-e #'pred)
(format "no pred for ~a" (type->str #'ty))
@ -227,7 +452,16 @@
[ (begin-
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...
(ro:define-symbolic* x- pred) ...)]])
(ro:define-symbolic* x- pred) ...)]]
;; else init to NULLs
[(_ ty:type x:id ...)
; #:when (not (real-type? #'ty.norm))
#:with (x- ...) (generate-temporaries #'(x ...))
--------
[ (begin-
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...
(ro:define x- cl:NULL) ...)]])
;; ?: --------------------------------------------------
(define-typed-syntax ?:
@ -280,16 +514,24 @@
(format "cannot cast ~a to ~a"
(type->str #'ty-e) (type->str #'ty-x))
--------
[ (ro:set! x- (synth-app (ty-x) e-)) void]])
(define-typed-syntax +=
[(_ x:id e)
[ (ro:set! x- (synth-app (ty-x) e-)) void]]
;; selector can be list of numbers or up to wxyz for vectors of length <=4
[(_ [x:id sel] e)
[ x x- ty-x]
[ e e- ty-e]
#:fail-unless (cast-ok? #'ty-e #'ty-x stx)
(format "cannot cast ~a to ~a"
(type->str #'ty-e) (type->str #'ty-x))
#:with out-e (if (pointer-type? #'ty-x)
#'(ro:begin
(cl:pointer-set! x- sel e-)
x-)
(with-syntax ([selector (cl:parse-selector #f #'sel stx)])
#`(ro:let ([out (ro:vector-copy x-)])
#,(if (= 1 (length (stx->list #'selector)))
#`(ro:vector-set! out (car 'selector) e-)
#'(ro:for ([idx 'selector] [v e-])
(ro:vector-set! out idx v)))
out)))
--------
[ (ro:set! x- (+ x- (synth-app (ty-x) e-))) void]])
[ (ro:set! x- out-e) void]])
(define-typed-syntax !
[(_ e)
@ -357,8 +599,45 @@
#'(convert
(ro:let ([a (convert e1-)][b (convert e2-)])
(ro:for/list ([v1 a][v2 b])
(base-convert (ro:+ v1 v2)))))) ty-out]])
(base-convert (cl:+ v1 v2)))))) ty-out]])
(define-typed-syntax +=
[(_ x e)
--------
[ (= x (+ x e))]])
(define-typed-syntax %
[(_ e1 e2)
[ e1 e1- ty1]
[ e2 e2- ty2]
;; #:when (real-type? #'ty1)
;; #:when (real-type? #'ty2)
#:with ty-out (common-real-type #'ty1 #'ty2)
#:with convert (get-convert #'ty-out)
#:with ty-base (get-base #'ty-out)
#:with base-convert (get-convert #'ty-base)
--------
[ #,(if (scalar-type? #'ty-out)
#'(convert (cl:% (synth-app (ty-out) e1-)
(synth-app (ty-out) e2-)))
#'(convert
(ro:let ([a (convert e1-)][b (convert e2-)])
(ro:for/list ([v1 a][v2 b])
(base-convert (cl:% v1 v2)))))) ty-out]])
(define-typed-syntax %=
[(_ x e)
--------
[ (= x (% x e))]])
#;(define-typed-syntax %=
[(_ x:id e)
[ x x- ty-x]
[ e e- ty-e]
#:fail-unless (cast-ok? #'ty-e #'ty-x stx)
(format "cannot cast ~a to ~a"
(type->str #'ty-e) (type->str #'ty-x))
--------
[ (ro:set! x- (% x- (synth-app (ty-x) e-))) void]])
#;(define-typed-syntax &&
[(_ e1 e2)

View File

@ -15,9 +15,9 @@
;; (check-type (expression? x) : CBool -> #f)
;; (check-type (constant? x) : CBool -> #t)
;; (assert (+ x 1))
;; (assert (% (+ x 2) 3))
;; (assert (!= x 2))
(assert (+ x 1))
(assert (% (+ x 2) 3))
(assert (!= x 2))
(check-type "" : char*)
(: char* y)
@ -76,9 +76,10 @@
(ro:a ro:integer->real x)
(ro:a ro:integer->real x)))
(: int16 u)
(: int16 u u2)
(= u2 u)
(check-type u : int16 -> u)
(check-type u2 : int16 -> u)
(check-type NULL : void* -> NULL)
(check-type ((int16) v) : int16
@ -110,58 +111,103 @@
;; (ro:a ro:+ 2.0 (ro:a ro:vector-ref z 1))
;; (ro:a ro:+ 2.0 (ro:a ro:vector-ref z 2))))
;; (%= x 3)
;; x
(%= x 3)
(check-type x : int -> 0)
;; (int3 4 5 6)
;; (= [u xyz] (int3 4 5 6))
;; u
(check-type (int3 4 5 6) : int3
-> (ro:a ro:vector-immutable 4 5 6))
;; (+ (int3 1 2 3) 4)
(= [u xyz] (int3 4 5 6))
;; ((int4 5 6 7 8) s03)
(check-type u : int16
-> (ro:let ([out (ro:a ro:vector-copy u2)])
(ro:a ro:vector-set! out 0 4)
(ro:a ro:vector-set! out 1 5)
(ro:a ro:vector-set! out 2 6)
out))
;; (if x {}{})
(check-type (+ (int3 1 2 3) 4) : int3
-> (ro:a ro:vector-immutable 5 6 7))
;; (if x
;; { (= [u sf] 10) }
;; { (= [u sf] 9) }
;; )
;; u
(check-type ((int4 5 6 7 8) s03) : int2
-> (ro:a ro:vector-immutable 5 8))
;; (if (! x)
;; { (: int g) (= g 3) (= [u sf] g) }
;; { (= [u sf] 9) }
;; )
;; u
(check-type (if x {}{}) : void -> (= x x))
;; (for [(: int i in (range 0 4 1))] )
;; (for [(: int i in (range 0 4 1))]
;; (if (! x)
;; { (: int g) (= g i) (+= [u sf] g)} )
;; )
;; u
(= u2 u)
(if x
{ (= [u sf] 10) }
{ (= [u sf] 9) }
)
(check-type u : int16
-> (ro:let ([out (ro:a ro:vector-copy u2)])
(ro:a ro:vector-set! out 15 9)
out))
(if (! x)
{ (: int g) (= g 3) (= [u sf] g) }
{ (= [u sf] 9) }
)
(check-type u : int16
-> (ro:let ([out (ro:a ro:vector-copy u2)])
(ro:a ro:vector-set! out 15 3)
out))
;; (: int16* w)
;; (= w ((int16*) (malloc 32)))
;; (= [w 0] 1)
;; (= [w 1] 2)
;; w
;; [w 0]
;; [w 1]
(check-type (for [(: int i in (range 0 4 1))] ) : void -> (= x x))
;; (get_work_dim)
(check-type (! x) : bool -> #t)
(: int g1)
(= g1 3)
(check-type (u sf) : int -> 3)
(+= [u sf] g1)
(check-type u : int16
-> (ro:let ([out (ro:a ro:vector-copy u2)])
(ro:a ro:vector-set! out 15 6)
out))
(= [u sf] 3)
;; (procedure void (nop1))
;; (nop1)
;; (kernel void (nop2))
;; (nop2)
(for [(: int i in (range 0 4 1))]
(if (! x)
{ (: int g) (= g i) (+= [u sf] g)} )
)
(check-type u : int16
-> (ro:let ([out (ro:a ro:vector-copy u2)])
(ro:a ro:vector-set! out 15 9)
out))
;; (procedure int (int_iden [int x]) x)
;; (int_iden ((int) 4.5))
;; (int_iden #t)
;; (int_iden 4.5)
(: int16* w)
(check-type w : int16* -> NULL)
(check-type (malloc 32) : void*)
(= w ((int16*) (malloc 32)))
; TODO: how to check this?
;#x0#(0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0)
(check-type w : int16*)
(= [w 0] 1)
;#x0#(1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0)
(= [w 1] 2)
;#x0#(1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2)
(check-type [w 0] : int16
-> (ro:a ro:vector-immutable 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1))
(check-type [w 1] : int16
-> (ro:a ro:vector-immutable 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2))
(check-type (get_work_dim) : int -> 0)
(procedure void (nop1))
(check-type (nop1) : void -> (= x x))
(kernel void (nop2))
(check-type (nop2) : void -> (= x x))
(procedure int (int_iden [int x]) x)
;; huh? these are unsound?
;; but match rosette's implementation
;; specifically, procedure does not coerce (but kernel does)
(check-type (int_iden ((int) 4.5)) : int -> 4)
;; (check-type (int_iden #t) : int -> #t)
;; (check-type (int_iden 4.5) : int -> 4.5)
(check-type (int_iden #t) : int -> 1)
(check-type (int_iden 4.5) : int -> 4)