racket/collects/math/private/matrix/matrix-multiply.rkt
Neil Toronto f2dc2027f6 Initial math library commit. The history for these changes is preserved
in the original GitHub fork:

  https://github.com/ntoronto/racket

Some things about this are known to be broken (most egregious is that the
array tests DO NOT RUN because of a problem in typed/rackunit), about half
has no coverage in the tests, and half has no documentation. Fixes and
docs are coming. This is committed now to allow others to find errors and
inconsistency in the things that appear to be working, and to give the
author a (rather incomplete) sense of closure.
2012-11-16 11:39:51 -07:00

61 lines
2.8 KiB
Racket

#lang typed/racket
(require "../unsafe.rkt"
"../../array.rkt"
"matrix-types.rkt")
(provide matrix*)
;; The `make-matrix-*' operators have to be macros; see ../array/array-pointwise.rkt for an
;; explanation.
;; Type the expansion is meant to have (kept as a datum comment because this is a macro):
#;(: make-matrix-multiply (All (A) (Symbol
                                    ((Array A) Integer -> (Array A))
                                    ((Array A) (Array A) -> (Array A))
                                    -> ((Array A) (Array A) -> (Array A)))))
;; (make-matrix-multiply name array-axis-sum array*) expands to a procedure (λ (arr brr) ...)
;; computing the matrix product of an m x n matrix `arr' and an n x p matrix `brr':
;;   name           - symbol passed to the error-raising calls
;;   array-axis-sum - called as (array-axis-sum arr 2) to reduce the last axis (presumably a sum)
;;   array*         - combines the two m x p x n views elementwise (presumably a pointwise product)
;; Calls raise-type-error if either argument fails `array-matrix?', and `error' if the column
;; count of `arr' differs from the row count of `brr'.
(define-syntax-rule (make-matrix-multiply name array-axis-sum array*)
  (λ (arr brr)
    (unless (array-matrix? arr) (raise-type-error name "matrix" 0 arr brr))
    (unless (array-matrix? brr) (raise-type-error name "matrix" 1 arr brr))
    ;; arr is m x n, brr is n* x p; the product is defined only when n = n*
    (match-define (vector m n) (array-shape arr))
    (match-define (vector n* p) (array-shape brr))
    (unless (= n n*)
      (error name
             "1st argument column size and 2nd argument row size are not equal; given ~e and ~e"
             arr brr))
    ;; Take strict (mutable) copies of both operands and keep only their backing vectors,
    ;; because every element is read multiple times below
    (define a-data (mutable-array-data (array->mutable-array arr)))
    (define b-data (mutable-array-data (array->mutable-array brr)))
    ;; Both operands are viewed as m x p x n arrays indexed by (i, k, j). Building the views by
    ;; hand, indexing the backing vectors in row-major order, is much faster than array-permute.
    (define: ext-shape : (Vectorof Index) (vector m p n))
    ;; View of arr repeated along the middle axis: (i, k, j) -> arr[i, j]
    (define a-ext
      (unsafe-build-array
       ext-shape
       (λ: ([idx : (Vectorof Index)])
         (let ([i (unsafe-vector-ref idx 0)]
               [j (unsafe-vector-ref idx 2)])
           ;; (unsafe-array-ref* arr i j) was noted as twice as slow
           (unsafe-vector-ref a-data (unsafe-fx+ j (unsafe-fx* i n)))))))
    ;; Transposed view of brr repeated along the first axis: (i, k, j) -> brr[j, k]
    (define b-ext
      (unsafe-build-array
       ext-shape
       (λ: ([idx : (Vectorof Index)])
         (let ([j (unsafe-vector-ref idx 2)]
               [k (unsafe-vector-ref idx 1)])
           ;; (unsafe-array-ref* brr j k) was noted as twice as slow
           (unsafe-vector-ref b-data (unsafe-fx+ k (unsafe-fx* j p)))))))
    ;; Combine elementwise, then reduce the inner axis j, leaving the m x p product
    (array-axis-sum (array* a-ext b-ext) 2)))
;; ---------------------------------------------------------------------------------------------------
;; General matrix product: (matrix* A B) multiplies an m x n matrix A by an n x p matrix B.
;; Real operands produce a (Matrix Real); otherwise the result is typed as (Matrix Number).
(define: matrix* : (case-> ((Matrix Real) (Matrix Real) -> (Matrix Real))
                           ((Matrix Number) (Matrix Number) -> (Matrix Number)))
  (make-matrix-multiply 'matrix* array-axis-sum array*))
;; Flonum-specialized variant, currently disabled:
;(: matrix-fl* ((Array Float) (Array Float) -> (Array Float)))
;(define matrix-fl* (make-matrix-multiply 'matrix-fl* array-axis-flsum array-fl*))