diff --git a/pkgs/racket-test-extra/tests/ffi/union.rkt b/pkgs/racket-test-extra/tests/ffi/union.rkt new file mode 100644 index 0000000000..9029d3729f --- /dev/null +++ b/pkgs/racket-test-extra/tests/ffi/union.rkt @@ -0,0 +1,31 @@ +#lang racket/base + +;; Tests for FFI unions + +(require ffi/unsafe + rackunit) + +(define union-type (_union (_list-struct _int _int))) +(define union-type-2 (_union (_list-struct _int _int) + (_list-struct _double _double))) + +(define val (cast (list 1 2) (_list-struct _int _int) union-type)) +(define val-2 (cast (list 1.2 2.2) (_list-struct _double _double) union-type-2)) + +(check-equal? (car (union-ref val 0)) 1) +(check-equal? (car (union-ref val-2 1)) 1.2) + +(union-set! val-2 0 (list 5 4)) +(check-equal? (car (union-ref val-2 0)) 5) + +(check-exn #rx"expected: list of c types" (λ () (_union 3))) +(check-exn #rx"expected: list of c types" (λ () (_union _int 4))) +(check-not-exn (λ () (_union _int _int))) + +(check-exn #rx"too large" (λ () (union-ref val 1))) +(check-exn #rx"nonnegative-integer" (λ () (union-ref val -1))) +(check-exn #rx"nonnegative-integer" (λ () (union-ref val "foo"))) +(check-exn #rx"too large" (λ () (union-ref val-2 2))) +(check-exn #rx"too large" (λ () (union-set! val 1 (list 1 2)))) +(check-exn #rx"nonnegative-integer" (λ () (union-set! val -1 (list 1 2)))) +(check-exn #rx"nonnegative-integer" (λ () (union-set! val "foo" (list 1 2)))) diff --git a/racket/collects/ffi/unsafe.rkt b/racket/collects/ffi/unsafe.rkt index 6829304bc8..18260f5f3e 100644 --- a/racket/collects/ffi/unsafe.rkt +++ b/racket/collects/ffi/unsafe.rkt @@ -1233,6 +1233,8 @@ (protect-out union-ref union-set!)) (define (_union t . ts) + (unless (and (ctype? t) (andmap ctype? ts)) + (raise-argument-error '_union "list of c types" (cons t ts))) (let ([ts (cons t ts)]) (make-ctype (apply make-union-type ts) (lambda (v) (union-ptr v)) @@ -1240,8 +1242,26 @@ (define-struct union (ptr types)) (define (union-ref u i) + (unless (union? u) + (raise-argument-error 'union-ref "union value" 0 u i)) + (unless (exact-nonnegative-integer? i) + (raise-argument-error 'union-ref "exact-nonnegative-integer?" 1 u i)) + (unless (< i (length (union-types u))) + (raise-arguments-error 'union-ref + "index too large for union" + "index" + i)) (ptr-ref (union-ptr u) (list-ref (union-types u) i))) (define (union-set! u i v) + (unless (union? u) + (raise-argument-error 'union-ref "union value" 0 u i)) + (unless (exact-nonnegative-integer? i) + (raise-argument-error 'union-ref "exact-nonnegative-integer?" 1 u i)) + (unless (< i (length (union-types u))) + (raise-arguments-error 'union-ref + "index too large for union" + "index" + i)) (ptr-set! (union-ptr u) (list-ref (union-types u) i) v)) ;; ----------------------------------------------------------------------------