diff --git a/collects/ffi/unsafe/com.rkt b/collects/ffi/unsafe/com.rkt index f214a50053..e5e06a0045 100644 --- a/collects/ffi/unsafe/com.rkt +++ b/collects/ffi/unsafe/com.rkt @@ -745,7 +745,7 @@ (GetTypeInfo/tl type-lib coclass-index)) (Release type-lib))])) -(define (event-type-info-from-com-object obj [exn? #t]) +(define (event-type-info-from-com-object obj) (or (com-object-event-type-info obj) (let ([dispatch (com-object-get-dispatch obj)]) (define provide-class-info (QueryInterface dispatch IID_IProvideClassInfo _IProvideClassInfo-pointer)) @@ -807,14 +807,14 @@ (cons name accum) (ReleaseVarDesc type-info var-desc)))))) -(define (extract-type-info who obj) +(define (extract-type-info who obj exn?) (cond - [(com-object? obj) (type-info-from-com-object obj)] + [(com-object? obj) (type-info-from-com-object obj exn?)] [(com-type? obj) (com-type-type-info obj)] [else (raise-type-error who "com-object or com-type" obj)])) (define (do-get-methods who obj inv-kind) - (define type-info (extract-type-info who obj)) + (define type-info (extract-type-info who obj #t)) (define type-attr (GetTypeAttr type-info)) (begin0 (sort (get-type-names type-info type-attr null inv-kind) string-cisymbol (format "COM-0x~x" vt))])) + [else + (if (= VT_ARRAY (bitwise-and vt VT_ARRAY)) + `(array ? ,(vt-to-scheme-type (- vt VT_ARRAY))) + (string->symbol (format "COM-0x~x" vt)))])) (define (arg-to-type arg) (cond + [(boolean? arg) 'boolean] [(signed-int? arg 32) 'int] [(unsigned-int? arg 32) 'unsigned-int] [(signed-int? arg 64) 'long-long] @@ -1103,7 +1106,14 @@ [(real? arg) 'double] [(com-object? arg) 'com-object] [(IUnknown? arg) 'iunknown] - [else 'any])) + [(vector? arg) `(array ,(vector-length arg) + ,(if (zero? (vector-length arg)) + 'int + (for/fold ([t (arg-to-type (vector-ref arg 0))]) ([v (in-vector arg)]) + (if (equal? t (arg-to-type v)) + t + 'any))))] + [else (error 'com "cannot infer marshal format for value: ~e" arg)])) (define (elem-desc-ref func-desc i) (ptr-add (FUNCDESC-lprgelemdescParam func-desc) i _ELEMDESC)) @@ -1132,13 +1142,13 @@ 0))) (define (do-get-method-type who obj name inv-kind internal?) - (define type-info (extract-type-info who obj)) + (define type-info (extract-type-info who obj (not internal?))) (when (and (= inv-kind INVOKE_FUNC) (is-dispatch-name? name)) (error who "IDispatch methods not available")) (define mx-type-desc (cond - [(com-object? obj) (get-method-type obj name inv-kind)] + [(com-object? obj) (get-method-type obj name inv-kind (not internal?))] [else (define x-type-info (if (= inv-kind INVOKE_EVENT) (event-type-info-from-com-type obj) @@ -1384,8 +1394,11 @@ (set-VARIANT-vt! var (get-var-type-from-elem-desc elem-desc)) (variant-set! var (to-ctype scheme-type) a)] [else - (set-VARIANT-vt! var (to-vt scheme-type)) - (variant-set! var (to-ctype scheme-type) a)])) + (define use-scheme-type (if (eq? scheme-type 'any) + (arg-to-type a) + scheme-type)) + (set-VARIANT-vt! var (to-vt use-scheme-type)) + (variant-set! var (to-ctype use-scheme-type) a)])) (define _float* (make-ctype _float @@ -1405,6 +1418,61 @@ (lambda (p) (ptr-ref p _t)))) +(define (make-a-VARIANT) + (define var (cast (malloc _VARIANT 'atomic-interior) + _pointer + _VARIANT-pointer)) + (VariantInit var) + var) + +(define (extract-variant-pointer var get?) + (define vt (VARIANT-vt var)) + (define ptr (union-ptr (VARIANT-u var))) + (switch + vt + [VT_BSTR (if get? ptr (ptr-ref ptr _pointer))] + [VT_DISPATCH (if get? ptr (ptr-ref ptr _pointer))] + [VT_UNKNOWN (if get? ptr (ptr-ref ptr _pointer))] + [VT_VARIANT var] + [else ptr])) + +(define (_safe-array/vectors dims base) + (make-ctype _pointer + (lambda (v) + (define sa (SafeArrayCreate (to-vt base) + (length dims) + (for/list ([d (in-list dims)]) + (make-SAFEARRAYBOUND d 0)))) + (register-cleanup! + (lambda () (SafeArrayDestroy sa))) + (let loop ([v v] [index null] [dims dims]) + (for ([v (in-vector v)] + [i (in-naturals)]) + (define idx (cons i index)) + (if (null? (cdr dims)) + (let ([var (make-a-VARIANT)]) + (scheme-to-variant! var v #f base) + (SafeArrayPutElement sa (reverse idx) + (extract-variant-pointer var #f))) + (loop v idx (cdr dims))))) + sa) + (lambda (_sa) + (define sa (cast _sa _pointer _SAFEARRAY-pointer)) + (define dims (for/list ([i (in-range (SafeArrayGetDim sa))]) + (- (add1 (SafeArrayGetUBound sa (add1 i))) + (SafeArrayGetLBound sa (add1 i))))) + (define vt (SafeArrayGetVartype sa)) + (let loop ([dims dims] [level 1] [index null]) + (define lb (SafeArrayGetLBound sa level)) + (for/vector ([i (in-range (car dims))]) + (if (null? (cdr dims)) + (let ([var (make-a-VARIANT)]) + (set-VARIANT-vt! var vt) + (SafeArrayGetElement sa (reverse (cons i index)) + (extract-variant-pointer var #t)) + (variant-to-scheme var)) + (loop (cdr dims) (add1 level) (cons i index)))))))) + (define (to-ctype type) (cond [(symbol? type) @@ -1435,8 +1503,15 @@ [(eq? 'box (car type)) (_box/permanent (to-ctype (cadr type)))] [(eq? 'array (car type)) - (_array/vector (to-ctype (caddr type)) - (cadr type))] + (define-values (dims base) + (let loop ([t type]) + (cond + [(and (pair? t) (eq? 'array (car t))) + (define-values (d b) (loop (caddr t))) + (values (cons (cadr t) d) b)] + [else + (values null t)]))) + (_safe-array/vectors dims base)] [else #f])) (define (to-vt type) @@ -1459,7 +1534,13 @@ [(boolean) VT_BOOL] [(iunknown) VT_UNKNOWN] [(com-object) VT_DISPATCH] - [else (error 'to-vt "Internal error: unsupported type ~s" type)])) + [(any) VT_VARIANT] + [else + (case (and (pair? type) + (car type)) + [(array) (bitwise-ior VT_ARRAY (to-vt (caddr type)))] + [else + (error 'to-vt "Internal error: unsupported type ~s" type)])])) (define (build-method-arguments-using-function-desc func-desc scheme-types inv-kind args) (define lcid-index (and func-desc (get-lcid-param-index func-desc))) @@ -1528,7 +1609,7 @@ (define (do-com-invoke who obj name args inv-kind) (check-com-obj who obj) (unless (string? name) (raise-type-error who "string" name)) - (let ([t (or (do-get-method-type 'com-invoke obj name inv-kind #t) + (let ([t (or (do-get-method-type who obj name inv-kind #t) ;; wing it by inferring types from the arguments: `(-> ,(map arg-to-type args) any))]) (unless (<= (length (filter (lambda (v) (not (and (pair? v) (eq? (car v) 'opt)))) @@ -1539,7 +1620,7 @@ (for ([arg (in-list args)] [type (in-list (cadr t))]) (check-argument 'com-invoke name arg type)) - (define type-desc (get-method-type obj name inv-kind)) ; cached + (define type-desc (get-method-type obj name inv-kind #f)) ; cached (cond [(if type-desc (mx-com-type-desc-memid type-desc) @@ -1551,13 +1632,22 @@ inv-kind args)) ;; from this point, don't escape/return without running cleanups + (when #f + ;; for debugging, inspect constructed arguments: + (eprintf "~e ~e\n" + t + (reverse + (for/list ([i (in-range num-params-passed)]) + (variant-to-scheme (ptr-ref (DISPPARAMS-rgvarg method-arguments) + _VARIANT + i)))))) (define method-result (if (= inv-kind INVOKE_PROPERTYPUT) #f (cast (malloc 'atomic _VARIANT) _pointer _VARIANT-pointer))) (when method-result (VariantInit method-result)) - (define-values (hr exn-info error-index) + (define-values (hr exn-info error-index) (Invoke (com-object-get-dispatch obj) memid IID_NULL LOCALE_SYSTEM_DEFAULT inv-kind method-arguments diff --git a/collects/ffi/unsafe/private/win32.rkt b/collects/ffi/unsafe/private/win32.rkt index 95a06ab0a5..f954871e5e 100644 --- a/collects/ffi/unsafe/private/win32.rkt +++ b/collects/ffi/unsafe/private/win32.rkt @@ -108,6 +108,7 @@ (define _VVAL (_union _double _intptr ;; etc. + (_array _pointer 2) )) (define-cstruct _VARIANT ([vt _VARTYPE] @@ -347,3 +348,33 @@ (let ([p (ptr-ref v _gcpointer)]) (let ([len (utf-16-length s)]) (SysAllocStringLen p len))))) + +(define _SAFEARRAY-pointer (_cpointer 'SAFEARRAY)) + +(define-oleaut SafeArrayCreate (_wfun _VARTYPE + _UINT + (dims : (_list i _SAFEARRAYBOUND)) + -> _SAFEARRAY-pointer)) +(define-oleaut SafeArrayDestroy (_hfun _SAFEARRAY-pointer + -> SafeArrayDestroy (void))) +(define-oleaut SafeArrayGetVartype (_hfun _SAFEARRAY-pointer + (vt : (_ptr o _VARTYPE)) + -> SafeArrayGetVartype vt)) +(define-oleaut SafeArrayGetLBound (_hfun _SAFEARRAY-pointer + _UINT + (v : (_ptr o _LONG)) + -> SafeArrayGetLBound v)) +(define-oleaut SafeArrayGetUBound (_hfun _SAFEARRAY-pointer + _UINT + (v : (_ptr o _LONG)) + -> SafeArrayGetUBound v)) +(define-oleaut SafeArrayPutElement (_hfun _SAFEARRAY-pointer + (_list i _LONG) + _pointer + -> SafeArrayPutElement (void))) +(define-oleaut SafeArrayGetElement (_hfun _SAFEARRAY-pointer + (_list i _LONG) + _pointer + -> SafeArrayGetElement (void))) +(define-oleaut SafeArrayGetDim (_wfun _SAFEARRAY-pointer + -> _UINT))