racket/collects/math/private/distributions/truncated-dist.rkt
Neil Toronto 055512b4e8 Renamed `make-flexp/base' to `make-flexpt'
Renamed `dist' struct type to `distribution' ("dist" is too common)
2012-12-03 22:45:31 -07:00

104 lines
4.0 KiB
Racket

#lang typed/racket/base
(require racket/performance-hint
racket/promise
"../../flonum.rkt"
"dist-struct.rkt"
"utils.rkt")
(provide Truncated-Dist
truncated-dist
truncated-dist-min
truncated-dist-max
truncated-dist-original)
;; Declares the struct type behind truncated distributions using the project's
;; define-real-dist: form.  In addition to whatever that form supplies, each
;; instance stores three fields:
;;   original -- the Real-Dist being truncated
;;   min, max -- the truncation bounds, as flonums
;; truncated-dist-struct is the name used further down in this file to
;; construct instances.
;; NOTE(review): the exported accessors truncated-dist-original, -min and -max
;; are presumably generated by this form -- confirm against define-real-dist:.
(define-real-dist: truncated-dist Truncated-Dist
truncated-dist-struct ([original : Real-Dist] [min : Flonum] [max : Flonum]))
;; Public constructor for truncated distributions.
;;   (truncated-dist d)      -- bounds default to -inf.0 and +inf.0
;;   (truncated-dist d a)    -- bounds are -inf.0 and a
;;   (truncated-dist d a b)  -- bounds are a and b, given in either order
;; In the three-argument case both bounds are converted to flonums, put in
;; ascending order, and narrowed so they never extend past
;; (ordered-dist-min d) or (ordered-dist-max d); the result is then handed to
;; unsafe-truncated-dist.  The shorter forms delegate to the three-argument one.
(: truncated-dist (case-> (Real-Dist -> Truncated-Dist)
                          (Real-Dist Real -> Truncated-Dist)
                          (Real-Dist Real Real -> Truncated-Dist)))
(define truncated-dist
  (case-lambda
    [(d) (truncated-dist d -inf.0 +inf.0)]
    [(d upper) (truncated-dist d -inf.0 upper)]
    [(d bound1 bound2)
     (let* ([x (fl bound1)]
            [y (fl bound2)]
            ;; Smaller bound, but no lower than the distribution's own minimum.
            [lo (max (ordered-dist-min d) (min x y))]
            ;; Larger bound, but no higher than the distribution's own maximum.
            [hi (min (ordered-dist-max d) (max x y))])
       (unsafe-truncated-dist d lo hi))]))
;; Builds a Truncated-Dist over d with bounds a and b, using both bounds exactly
;; as given: nothing here orders or clamps them (the public truncated-dist
;; wrapper does that before calling this function).
;; The closures below always call d's functions with their second argument #t
;; and treat the results as logarithms; flexp is applied at the very end of pdf
;; and cdf unless the caller passed a true log?.
(: unsafe-truncated-dist (Real-Dist Float Float -> Truncated-Dist))
(define (unsafe-truncated-dist d a b)
;; The four operations of the distribution being truncated.
(define orig-pdf (distribution-pdf d))
(define orig-cdf (ordered-dist-cdf d))
(define orig-inv-cdf (ordered-dist-inv-cdf d))
(define orig-sample (distribution-sample d))
;; Normalizing constant; used below as a log value (subtracted from log
;; densities and log probabilities, added to log-p in inv-cdf).
;; NOTE(review): presumably log P(a < X <= b) under d, with #t selecting log
;; space -- confirm against real-dist-prob.
(define log-P_a<x<=b (real-dist-prob d a b #t #f))
;; Values of the original cdf at each bound, wrapped in promises so each is
;; computed at most once and only when some code path forces it.
;; NOTE(review): the flags are assumed to follow the same (x log? 1-p?) order
;; as the cdf defined below, i.e. log lower tail at a and log upper tail at b.
(define log-P_x<=a (delay (orig-cdf a #t #f)))
(define log-P_x>b (delay (orig-cdf b #t #t)))
;; Density of the truncated distribution at x.  Returns the log density when
;; log? is true, otherwise its flexp.
(: pdf (case-> (Real -> Flonum)
(Real Any -> Flonum)))
(define (pdf x [log? #f])
(let ([x (fl x)])
(define log-d
;; Outside [a, b] the log density is -inf.0 (zero after flexp).
(cond [(x . fl< . a) -inf.0]
[(x . fl> . b) -inf.0]
;; Inside, subtract the log normalizer from the original log density.
[else (fl- (orig-pdf x #t) log-P_a<x<=b)]))
(if log? log-d (flexp log-d))))
;; Cumulative probability of the truncated distribution at x.
;;   log?  -- true to return the result as a logarithm
;;   1-p?  -- true to use the branch built from d's cdf called with 1-p? = #t
;;            (x < a gives 0.0, x > b gives -inf.0); false for the mirrored
;;            branch (x < a gives -inf.0, x > b gives 0.0)
(: cdf (case-> (Real -> Flonum)
(Real Any -> Flonum)
(Real Any Any -> Flonum)))
(define (cdf x [log? #f] [1-p? #f])
(let ([x (fl x)])
(define log-p
(cond [1-p? (cond [(x . fl< . a) 0.0]
[(x . fl> . b) -inf.0]
;; Combine the cdf value at x with the one at b using lg-, subtract the
;; log normalizer, and cap the result at 0.0 (the log of 1) with flmin.
;; NOTE(review): lg- is presumably log-space subtraction,
;; log(exp(x) - exp(y)) -- it is defined outside this file; confirm.
[else (flmin 0.0 (fl- (lg- (orig-cdf x #t #t)
(force log-P_x>b))
log-P_a<x<=b))])]
[else (cond [(x . fl< . a) -inf.0]
[(x . fl> . b) 0.0]
;; Same computation using the cdf values at x and at a with 1-p? = #f.
[else (flmin 0.0 (fl- (lg- (orig-cdf x #t #f)
(force log-P_x<=a))
log-P_a<x<=b))])]))
(if log? log-p (flexp log-p))))
;; Maps a probability p back to a point, by way of d's inverse cdf.
;;   log?  -- true if p is already a logarithm; otherwise fllog is applied
;;   1-p?  -- passed through to d's inverse cdf; also swaps which bound the
;;            endpoint probabilities map to
;; Returns +nan.0 when flprobability? rejects the log probability.
(: inv-cdf (case-> (Real -> Flonum)
(Real Any -> Flonum)
(Real Any Any -> Flonum)))
(define (inv-cdf p [log? #f] [1-p? #f])
(let ([log-p (if log? (fl p) (fllog (fl p)))])
(cond
;; NOTE(review): (flprobability? log-p #t) presumably checks that log-p is a
;; valid log probability -- it is defined outside this file; confirm.
[(not (flprobability? log-p #t)) +nan.0]
[else
(define x
;; log-p of 0.0 or -inf.0 maps straight to a bound without calling d's
;; inverse cdf.
(cond [1-p? (cond [(fl= log-p 0.0) a]
[(fl= log-p -inf.0) b]
;; Otherwise add the log normalizer to log-p, combine with the cdf value
;; at b using lg+, and invert with d's inverse cdf.
;; NOTE(review): lg+ is presumably log-space addition -- confirm.
[else (orig-inv-cdf (lg+ (fl+ log-p log-P_a<x<=b)
(force log-P_x>b))
#t #t)])]
[else (cond [(fl= log-p 0.0) b]
[(fl= log-p -inf.0) a]
;; Mirrored case: combine with the cdf value at a and invert with 1-p? = #f.
[else (orig-inv-cdf (lg+ (fl+ log-p log-P_a<x<=b)
(force log-P_x<=a))
#t #f)])]))
;; Keep the result no smaller than a and no larger than b.
(min b (max a x))])))
;; Draws one sample by calling inv-cdf on 0.5 times a (random) value, with a
;; second (random) draw deciding whether the 1-p? argument is true.
;; NOTE(review): presumably this keeps the probability passed to inv-cdf small
;; on either side for accuracy -- confirm intent.
(: sample-single (-> Flonum))
(define (sample-single)
(inv-cdf (fl* 0.5 (random)) #f ((random) . fl> . 0.5)))
;; Sampling entry point stored in the struct: with no argument it returns one
;; sample; with an integer n it returns a list of n samples.  A negative n
;; raises an argument error reported under the name truncated-dist-sample.
(: sample (Sample Flonum))
(define sample
(case-lambda:
[() (sample-single)]
[([n : Integer])
(cond [(n . < . 0) (raise-argument-error 'truncated-dist-sample "Natural" n)]
[else (build-list n (λ (_) (sample-single)))])]))
;; Finally put it together
;; The last three arguments (d a b) match the original/min/max fields declared
;; for truncated-dist-struct.
;; NOTE(review): the earlier a, b and the delayed (inv-cdf 0.5) presumably fill
;; fields supplied by define-real-dist: (looks like min, max and a lazily
;; computed median) -- confirm against the struct definition.
(truncated-dist-struct pdf sample cdf inv-cdf a b (delay (inv-cdf 0.5)) d a b))