start synthcl3: snippets up to int16 working (float3 and ?: selection)

This commit is contained in:
Stephen Chang 2016-11-04 15:36:35 -04:00
parent 5d8557bbca
commit 56b8b52ea6
4 changed files with 506 additions and 14 deletions

View File

@ -55,8 +55,8 @@
[(_ . stuff)
(syntax/loc this-syntax
(#%module-begin
; auto-provide some useful racket forms
(provide #%module-begin #%top-interaction #%top require only-in)
(provide #%module-begin #%top-interaction #%top ; useful racket forms
require only-in prefix-in rename-in)
. stuff))]))
(struct exn:fail:type:runtime exn:fail:user ())

View File

@ -1,7 +1,8 @@
#lang turnstile
;; reuse unlifted forms as-is
(reuse define λ let let* letrec begin void #%datum ann #%top-interaction
require only-in define-type-alias define-named-type-alias
require only-in prefix-in rename-in define-type-alias define-named-type-alias
current-join
#:from "../stlc+union.rkt")
(require
;; manual imports
@ -11,7 +12,7 @@
(combine-in
(only-in "../stlc+union+case.rkt"
PosInt Zero NegInt Float True False String Unit [U U*] U*?
[case-> case->*] case->? →?)
[case-> case->*] case->? →? String?)
(only-in "../stlc+cons.rkt" [List Listof])))
(only-in "../stlc+union+case.rkt" [~U* ~CU*] [~case-> ~Ccase->] [~→ ~C→])
(only-in "../stlc+cons.rkt" [~List ~CListof])
@ -21,13 +22,14 @@
(provide (rename-out [ro:#%module-begin #%module-begin]
[stlc+union:λ lambda])
(for-syntax get-pred)
Any CNothing Nothing
CU U
CU U (for-syntax ~U*)
Constant
C→ (for-syntax ~C→ C→?)
Ccase-> (for-syntax ~Ccase-> Ccase->?) ; TODO: sym case-> not supported
CListof Listof CList CPair Pair
CVectorof MVectorof IVectorof Vectorof CMVectorof CIVectorof
CVectorof MVectorof IVectorof Vectorof CMVectorof CIVectorof CVector
CParamof ; TODO: symbolic Param not supported yet
CBoxof MBoxof IBoxof CMBoxof CIBoxof CHashTable
CUnit Unit
@ -39,13 +41,13 @@
CFloat Float
CNum Num
CFalse CTrue CBool Bool
CString String
CString String (for-syntax CString?)
CStx ; symblic Stx not supported
CAsserts
;; BV types
CBV BV
CBVPred BVPred
CSolution CSolver CPict CSyntax CRegexp CSymbol CPred)
CSolution CSolver CPict CSyntax CRegexp CSymbol CPred CPredC)
(begin-for-syntax
(define (mk-ro:-id id) (format-id id "ro:~a" id))
@ -85,6 +87,7 @@
(define-named-type-alias (CVectorof X) (CU (CIVectorof X) (CMVectorof X)))
(define-named-type-alias (CBoxof X) (CU (CIBoxof X) (CMBoxof X)))
(define-type-constructor CList #:arity >= 0)
(define-type-constructor CVector #:arity >= 0)
(define-type-constructor CPair #:arity = 2)
;; TODO: update orig to use reduced type
@ -224,6 +227,7 @@
(define-symbolic-named-type-alias Num (CU CFloat CInt) #:pred ro:real?)
(define-named-type-alias CPred (C→ Any Bool))
(define-named-type-alias CPredC (C→ Any CBool))
;; ---------------------------------
;; define-symbolic
@ -1593,13 +1597,14 @@
;; ---------------------------------
;; Reflecting on symbolic values
;; TODO: CPredC correct here?
(provide (typed-out
[term? : CPred]
[expression? : CPred]
[constant? : CPred]
[type? : CPred]
[solvable? : CPred]
[union? : CPred]))
[term? : CPredC]
[expression? : CPredC]
[constant? : CPredC]
[type? : CPredC]
[solvable? : CPredC]
[union? : CPredC]))
(define-typed-syntax union-contents
[(_ u)

View File

@ -0,0 +1,320 @@
#lang turnstile
(extends "rosette3.rkt" #:except ! #%app || &&) ; typed rosette
(require ;(prefix-in ro: (except-in rosette verify sqrt range print)) ; untyped
(prefix-in ro: rosette)
; define-symbolic* #%datum define if ! = number? boolean? cond ||))
(prefix-in cl: sdsl/synthcl/model/operators)
(prefix-in cl: sdsl/synthcl/model/errors))
(begin-for-syntax
(define (mk-cl id) (format-id id "cl:~a" id))
(current-host-lang mk-cl))
;(define-base-types)
(provide (rename-out
[synth-app #%app]
; [rosette3:Bool bool] ; symbolic
; [rosette3:Int int] ; symbolic
; [rosette3:Num float] ; symbolic
[rosette3:CString char*]) ; always concrete
bool int float float3 int16
: ! ?: ==
(typed-out
;; need the concrete cases for literals;
;; alternative is to redefine #%datum to give literals symbolic type
[% : (Ccase-> (C→ CInt CInt CInt)
(C→ Int Int Int))]
[!= : (Ccase-> (C→ CNum CNum CBool)
(C→ CNum CNum CNum CBool)
(C→ Num Num Bool)
(C→ Num Num Num Bool))]
#;[== : (Ccase-> (C→ CNum CNum CBool)
(C→ CNum CNum CNum CBool)
(C→ Num Num Bool)
(C→ Num Num Num Bool))]))
(begin-for-syntax
;; TODO: use equality type relation instead of subtype
;; - require reimplementing many more things, eg #%datum, +, etc
; (current-typecheck-relation (current-type=?))
;; typecheck unexpanded types
(define (typecheck/un? t1 t2)
(typecheck? ((current-type-eval) t1)
((current-type-eval) t2)))
(define (real-type? t) (not (typecheck/un? t #'bool)))
(define (type-base t)
(datum->syntax t
(string->symbol
(car (regexp-match #px"[a-z]+" (type->str t))))))
(define (real-type<=? t1 t2)
(and (real-type? t1) (real-type? t2)
(or (typecheck? t1 t2)
(typecheck/un? t1 #'bool)
(and (typecheck/un? t1 #'int)
(not (typecheck/un? t2 #'bool)))
(and (typecheck/un? t1 #'float)
(typecheck/un? (type-base t2) #'float)))))
; Returns the common real type of the given types, as specified in
; Ch. 6.2.6 of opencl-1.2 specification. If there is no common
; real type, returns #f.
(define common-real-type
(case-lambda
[(t) (and (real-type? t) t)]
[(t1 t2) (cond [(real-type<=? t1 t2) t2]
[(real-type<=? t2 t1) t1]
[else #f])]
[ts (common-real-type (car ts) (apply common-real-type (cdr ts)))]))
;; implements common-real-type from model/reals.rkt
; Returns the common real type of the given types, as specified in
; Ch. 6.2.6 of opencl-1.2 specification. If there is no common
; real type, returns #f.
(current-join common-real-type)
;; copied from check-implicit-conversion in lang/types.rkt
;; TODO: this should not exception since it is used in stx-parse
;; clauses that may want to backtrack
(define (cast-ok? from to expr [subexpr #f])
(unless (if #t #;(type? to)
(or (typecheck/un? from to)
(and (scalar-type? from) (scalar-type? to))
(and (scalar-type? from) (vector-type? to))
#;(and (pointer-type? from) (pointer-type? to))
#;(and (equal? from cl_mem) (pointer-type? to)))
(to from))
(raise-syntax-error
#f
(format "no implicit conversion from ~a to ~a"
(type->str from) (type->str to)
#;(if (contract? to) (contract-name to) to))
expr subexpr))
#;(or (typecheck/un? from to) ; from == to
(and (real-type? from)
(typecheck/un? to #'bool))))
(define (add-convert stx fn)
(set-stx-prop/preserved stx 'convert fn))
(define (get-convert stx)
(syntax-property stx 'convert))
(define (ty->len ty)
(regexp-match #px"([a-z]+)([0-9]+)" (type->str ty)))
(define (vector-type? ty)
(ty->len ty))
(define (scalar-type? ty)
(not (vector-type? ty))))
(define-syntax-parser add-convertm
[(_ stx fn) (add-convert #'stx #'fn)])
;; TODO: reuse impls in model/reals.rkt ?
(ro:define (to-bool v)
(ro:cond
[(ro:boolean? v) v]
[(ro:number? v) (ro:! (ro:= 0 v))]
[else (cl:raise-conversion-error v "bool")]))
(ro:define (to-int v)
(ro:cond [(ro:boolean? v) (ro:if v 1 0)]
[(ro:fixnum? v) v]
[(ro:flonum? v) (ro:exact-truncate v)]
[else (ro:real->integer v)]))
(ro:define (to-float v)
(ro:cond
[(ro:boolean? v) (ro:if v 1.0 0.0)]
[(ro:fixnum? v) (ro:exact->inexact v)]
[(ro:flonum? v) v]
[else (ro:type-cast ro:real? v)]))
(ro:define (to-float3 v)
(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 3]) (to-float (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 3]) (to-float (ro:vector-ref v i))))]
[else (ro:apply ro:vector-immutable (ro:make-list 3 (to-float v)))]))
(ro:define (to-int16 v) v
#;(ro:cond
[(ro:list? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 16]) (to-int (ro:list-ref v i))))]
[(ro:vector? v)
(ro:apply ro:vector-immutable
(ro:for/list ([i 16]) (to-int (ro:vector-ref v i))))]
[else (ro:apply ro:vector-immutable (ro:make-list 16 (to-float v)))]))
(define-named-type-alias bool
(add-convertm rosette3:Bool to-bool))
(define-named-type-alias int
(add-convertm rosette3:Int to-int))
(define-named-type-alias float
(add-convertm rosette3:Num to-float))
(define-named-type-alias float3
(add-convertm
(rosette3:CVector rosette3:Num rosette3:Num rosette3:Num)
to-float3))
(define-named-type-alias int16
(add-convertm
(rosette3:CVector rosette3:Int rosette3:Int rosette3:Int rosette3:Int
rosette3:Int rosette3:Int rosette3:Int rosette3:Int
rosette3:Int rosette3:Int rosette3:Int rosette3:Int
rosette3:Int rosette3:Int rosette3:Int rosette3:Int)
to-int16))
(define-typed-syntax synth-app
[(_ (ty:type) e) ; cast
[ e e- ty-e]
#:fail-unless (cast-ok? #'ty-e #'ty.norm #'e)
(format "cannot cast ~a to ~a"
(type->str #'ty-e) (type->str #'ty.norm))
#:with convert (get-convert #'ty.norm)
#:fail-unless (syntax-e #'convert)
(format "cannot cast ~a to ~a: conversion fn not found"
(type->str #'ty-e) (type->str #'ty.norm))
--------
[ (convert e-) ty.norm]]
[(_ . es)
--------
[ (rosette3:#%app . es)]])
(define-typed-syntax :
[(_ ty:type x:id ...) ; special String case
#:when (rosette3:CString? #'ty.norm)
#:with (x- ...) (generate-temporaries #'(x ...))
--------
[ (begin-
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...
(ro:define x- (ro:#%datum . "")) ...)]]
;; TODO: vector types need a better representation
;; currently dissecting the identifier
[(_ ty:type x:id ...)
#:do [(define split-ty (ty->len #'ty))]
#:when (and split-ty (= 3 (length split-ty)))
#:do [(define base-str (cadr split-ty))
(define len-str (caddr split-ty))]
#:with ty-base (datum->syntax #'ty (string->symbol base-str))
#:with pred (get-pred ((current-type-eval) #'ty-base))
#:fail-unless (syntax-e #'pred)
(format "no pred for ~a" (type->str #'ty))
#:with (x- ...) (generate-temporaries #'(x ...))
#:with (x-- ...) (generate-temporaries #'(x ...))
--------
[ (begin-
(ro:define-symbolic* x-- pred [#,(string->number len-str)]) ...
(ro:define x- (ro:apply ro:vector-immutable x--)) ...
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...)]]
;; scalar (ie non-vector) types
[(_ ty:type x:id ...)
#:with pred (get-pred #'ty.norm)
#:fail-unless (syntax-e #'pred)
(format "no pred for ~a" (type->str #'ty))
#:with (x- ...) (generate-temporaries #'(x ...))
--------
[ (begin-
(define-syntax- x
(make-rename-transformer (assign-type #'x- #'ty.norm))) ...
(ro:define-symbolic* x- pred) ...)]])
(define-typed-syntax ?:
[(_ e e1 e2)
[ e e- bool]
[ e1 e1- ty1]
[ e2 e2- ty2]
-------
[ (cl:?: e- e1- e2-) ( τ1 τ2)]]
[(_ e e1 e2)
[ e e- ty] ; vector type
#:do [(define split-ty (ty->len #'ty))]
#:when (and split-ty (= 3 (length split-ty)))
[ e1 e1- ty1]
[ e2 e2- ty2]
#:with ty-out (common-real-type #'ty #'ty1 #'ty2)
#:with convert (get-convert #'ty-out)
#:do [(define split-ty-out (ty->len #'ty-out))
(define out-base-str (cadr split-ty-out))
(define out-len-str (caddr split-ty-out))]
#:with ty-base ((current-type-eval) (datum->syntax #'e (string->symbol out-base-str)))
#:with base-convert (get-convert #'ty-base)
-------
[ (convert
(ro:let ([a (convert e-)][b (convert e1-)][c (convert e2-)])
(ro:for/list ([idx #,(string->number out-len-str)])
(ro:if (ro:< (ro:vector-ref a idx) 0)
(base-convert (ro:vector-ref b idx))
(base-convert (ro:vector-ref c idx))))))
ty-out]]
[(_ e e1 e2) ; should be scalar
#:when (displayln 1)
[ e e- ty]
#:fail-unless (cast-ok? #'ty #'bool #'e)
(format "cannot cast ~a to bool" (type->str #'ty))
[ e1 e1- ty1]
[ e2 e2- ty2]
#:with ty-out ((current-join) #'ty1 #'ty2)
-------
[ (cl:?: (synth-app (bool) e-)
(synth-app (ty-out) e1-)
(synth-app (ty-out) e2-)) ty-out]])
(define-typed-syntax !
[(_ e)
[ e e- bool]
--------
[ (cl:! e-) bool]]
;; else try to coerce
[(_ e)
[ e e- ty]
--------
[ (cl:! (synth-app (bool) e-)) bool]])
(define-syntax (define-coercing-bool-binop stx)
(syntax-parse stx
[(_ name)
#:with name- (mk-cl #'name)
#'(define-typed-syntax name
[(_ e1 e2)
[ e1 e1- bool]
[ e2 e2- bool]
--------
[ (name- e1- e2-) bool]]
[(_ e1 e2) ; else try to coerce
[ e1 e1- ty1]
[ e2 e2- ty2]
#:fail-unless (cast-ok? #'ty1 #'bool #'e1)
(format "cannot cast ~a to bool" (type->str #'ty1))
#:fail-unless (cast-ok? #'ty2 #'bool #'e2)
(format "cannot cast ~a to bool" (type->str #'ty2))
--------
[ (name- (synth-app (bool) e1-)
(synth-app (bool) e2-)) bool]])]))
(define-simple-macro (define-coercing-bool-binops o ...+)
(ro:begin
(provide o ...)
(define-coercing-bool-binop o) ...))
(define-coercing-bool-binops || &&)
;; TODO: this should produce int-vector result?
(define-typed-syntax ==
[(_ e1 e2)
[ e1 e1- ty1]
[ e2 e2- ty2]
#:when (real-type? #'ty1)
#:when (real-type? #'ty2)
#:with ty-out ((current-join) #'ty1 #'ty2) ; only need this for the len
--------
[ (to-int (cl:== e1- e2-)) int]])
#;(define-typed-syntax &&
[(_ e1 e2)
[ e1 e1- bool]
[ e2 e2- bool]
--------
[ (cl:&& e1- e2-) bool]]
;; else try to coerce
[(_ e1 e2)
[ e1 e1- ty1]
[ e2 e2- ty2]
--------
[ (cl:&& (synth-app (bool) e1-) (synth-app (bool) e2-)) bool]])

View File

@ -0,0 +1,167 @@
#lang s-exp "../../../rosette/synthcl3.rkt"
(require "../../rackunit-typechecking.rkt"
(prefix-in cl: sdsl/synthcl/lang/main)
(prefix-in ro: (rename-in rosette [#%app a])))
;; from synthcl/test/snippets.rkt and more-snippets.rkt
;; (ro:define-symbolic b ro:boolean?)
(: int x)
(check-type x : int -> x)
(check-not-type x : CInt)
;; TODO: should these be defined in synthcl?
;; I think no, synthcl is not an extension of rosette
;; (check-type (term? x) : CBool -> #t)
;; (check-type (expression? x) : CBool -> #f)
;; (check-type (constant? x) : CBool -> #t)
(assert (+ x 1))
(assert (% (+ x 2) 3))
(assert (!= x 2))
(check-type "" : char*)
(: char* y)
(check-type y : char* -> "")
(: float v)
(check-type v : float -> v)
;; (check-type (term? v) : CBool -> #t)
;; (check-type (expression? v) : CBool -> #f)
;; (check-type (constant? v) : CBool -> #t)
(check-type ((bool) v) : bool -> (ro:a ro:! (ro:a ro:= 0 v)))
(check-type (! ((bool) x)) : bool -> (ro:a ro:= 0 x))
(assert (! ((bool) x)))
(check-type (|| x (! v)) : bool
-> (ro:a ro:|| (ro:a ro:! (ro:a ro:= 0 x))
(ro:a ro:&& (ro:a ro:= 0 x) (ro:a ro:= 0 v))))
(assert (|| x (! v)))
(check-type (== x v) : int
-> (ro:if (ro:a ro:= v (ro:a ro:integer->real x)) 1 0))
(assert (== x v))
(: float3 z)
(check-type z : float3 -> z)
;; check coercions
(check-type ((bool) x) : bool -> (ro:a ro:! (ro:a ro:= 0 x)))
(check-type ((float) x) : float -> (ro:a ro:integer->real x))
(check-type ((float3) x) : float3
-> (ro:a ro:vector-immutable
(ro:a ro:integer->real x)
(ro:a ro:integer->real x)
(ro:a ro:integer->real x)))
(check-type ((float3) z) : float3 -> z)
;; expected:
;; (vector
;; (ite (= 0 x$0) z$0 (integer->real x$0))
;; (ite (= 0 x$0) z$1 (integer->real x$0))
;; (ite (= 0 x$0) z$2 (integer->real x$0)))
(check-type (?: x x z) : float3
-> (ro:if (ro:a ro:= 0 x)
z
(ro:a ro:vector-immutable
(ro:a ro:integer->real x)
(ro:a ro:integer->real x)
(ro:a ro:integer->real x))))
(typecheck-fail ((bool) z)
#:with-msg "no implicit conversion from float3 to bool")
(check-type (?: z x x) : float3
-> (ro:a ro:vector-immutable
(ro:a ro:integer->real x)
(ro:a ro:integer->real x)
(ro:a ro:integer->real x)))
(: int16 u)
(check-type u : int16 -> u)
;; NULL
;; ((int16) v)
;; (= x 3.4)
;; x
;; (+= z 2)
;; z
;; (%= x 3)
;; x
;; (int3 4 5 6)
;; (= [u xyz] (int3 4 5 6))
;; u
;; (+ (int3 1 2 3) 4)
;; ((int4 5 6 7 8) s03)
;; (if x {}{})
;; (if x
;; { (= [u sf] 10) }
;; { (= [u sf] 9) }
;; )
;; u
;; (if (! x)
;; { (: int g) (= g 3) (= [u sf] g) }
;; { (= [u sf] 9) }
;; )
;; u
;; (for [(: int i in (range 0 4 1))] )
;; (for [(: int i in (range 0 4 1))]
;; (if (! x)
;; { (: int g) (= g i) (+= [u sf] g)} )
;; )
;; u
;; (: int16* w)
;; (= w ((int16*) (malloc 32)))
;; (= [w 0] 1)
;; (= [w 1] 2)
;; w
;; [w 0]
;; [w 1]
;; (get_work_dim)
;; (procedure void (nop1))
;; (nop1)
;; (kernel void (nop2))
;; (nop2)
;; (procedure int (int_iden [int x]) x)
;; (int_iden ((int) 4.5))
;; (int_iden #t)
;; (int_iden 4.5)
;; ;;;;;; assertion failure localization ;;;;;;
;; ; (assert #f)
;; ;;;;;; bad types etc ;;;;;;
;; ;(: float* NULL)
;; ;(+ x y)
;; ;(?: "" x x)
;; ;((int) z)
;; ;(-= z w)
;; ;(%= z 3)
;; ;(NULL 3)
;; ;(if x)
;; ;(for [() () "" (-= x 1)])
;; ;[w ""]
;; ;[w 2]
;; ;(procedure int (bad))
;; ;(procedure)
;; ;(kernel int (bad) 1)
;; ;(procedure void (w))
;; ;(int_iden "")
;; ;(procedure float (bad) "")