From 8069a9be5490a3b4af3074cd0316c9e90a0ccf71 Mon Sep 17 00:00:00 2001 From: Neil Toronto Date: Wed, 26 Mar 2014 14:28:35 -0600 Subject: [PATCH] Reorganize some code, fix beta PDF for a,x = 1,0 and b,x = 1,1 --- .../math-pkgs/math-lib/math/distributions.rkt | 2 + .../private/distributions/dist-functions.rkt | 117 ++++++++++++++++++ .../private/distributions/dist-struct.rkt | 75 +---------- .../private/distributions/impl/beta-pdf.rkt | 6 + .../private/distributions/truncated-dist.rkt | 1 + 5 files changed, 128 insertions(+), 73 deletions(-) create mode 100644 pkgs/math-pkgs/math-lib/math/private/distributions/dist-functions.rkt diff --git a/pkgs/math-pkgs/math-lib/math/distributions.rkt b/pkgs/math-pkgs/math-lib/math/distributions.rkt index 528c0fd24d..d12118213d 100644 --- a/pkgs/math-pkgs/math-lib/math/distributions.rkt +++ b/pkgs/math-pkgs/math-lib/math/distributions.rkt @@ -1,6 +1,7 @@ #lang racket/base (require "private/distributions/dist-struct.rkt" + "private/distributions/dist-functions.rkt" "private/distributions/delta-dist.rkt" "private/distributions/uniform-dist.rkt" "private/distributions/triangle-dist.rkt" @@ -19,6 +20,7 @@ (provide (all-from-out "private/distributions/dist-struct.rkt" + "private/distributions/dist-functions.rkt" "private/distributions/delta-dist.rkt" "private/distributions/uniform-dist.rkt" "private/distributions/triangle-dist.rkt" diff --git a/pkgs/math-pkgs/math-lib/math/private/distributions/dist-functions.rkt b/pkgs/math-pkgs/math-lib/math/private/distributions/dist-functions.rkt new file mode 100644 index 0000000000..c1c7e9be13 --- /dev/null +++ b/pkgs/math-pkgs/math-lib/math/private/distributions/dist-functions.rkt @@ -0,0 +1,117 @@ +#lang typed/racket/base + +(require racket/promise + "../../flonum.rkt" + "dist-struct.rkt") + +(provide real-dist-prob + real-dist-hpd-interval) + +;; =================================================================================================== +;; Computing probabilities + +(: real-dist-prob* 
(Real-Dist Flonum Flonum Any -> Flonum)) +;; Assumes a <= b +(define (real-dist-prob* d a b 1-p?) + (define c (force (ordered-dist-median d))) + (define cdf (ordered-dist-cdf d)) + (define p + (cond [(a . fl= . b) (if 1-p? 1.0 0.0)] + [1-p? (fl+ (cdf a #f #f) (cdf b #f #t))] + [(b . fl<= . c) + ;; Both less than the median; use lower tail only + (fl- (cdf b #f #f) (cdf a #f #f))] + [(a . fl>= . c) + ;; Both greater than the median; use upper tail only + (fl- (cdf a #f #t) (cdf b #f #t))] + [else + ;; Median between a and b; use lower for (a,c] and upper for (c,b] + (define P_x<=a (cdf a #f #f)) + (define P_x>b (cdf b #f #t)) + (fl+ (fl- 0.5 P_x<=a) (fl- 0.5 P_x>b))])) + (max 0.0 (min 1.0 p))) + +(: real-dist-log-prob* (Real-Dist Flonum Flonum Any -> Flonum)) +;; Assumes a <= b +(define (real-dist-log-prob* d a b 1-p?) + (define c (force (ordered-dist-median d))) + (define cdf (ordered-dist-cdf d)) + (define log-p + (cond [(a . fl= . b) (if 1-p? 0.0 -inf.0)] + [1-p? (lg+ (cdf a #t #f) (cdf b #t #t))] + [(b . fl<= . c) + ;; Both less than the median; use lower tail only + (define log-P_x<=a (cdf a #t #f)) + (define log-P_x<=b (cdf b #t #f)) + (cond [(log-P_x<=b . fl< . log-P_x<=a) -inf.0] + [else (lg- log-P_x<=b log-P_x<=a)])] + [(a . fl>= . c) + ;; Both greater than the median; use upper tail only + (define log-P_x>a (cdf a #t #t)) + (define log-P_x>b (cdf b #t #t)) + (cond [(log-P_x>a . fl< . log-P_x>b) -inf.0] + [else (lg- log-P_x>a log-P_x>b)])] + [else + ;; Median between a and b; try 1-upper first + (define log-P_x<=a (cdf a #t #f)) + (define log-P_x>b (cdf b #t #t)) + (define log-p (lg1- (lg+ log-P_x<=a log-P_x>b))) + (cond [(log-p . fl> . 
(log 0.1)) log-p] + [else + ;; Subtracting from 1.0 (in log space) lost bits; split and add instead + (define log-P_a<x<=c + (cond [((fllog 0.5) . fl< . log-P_x<=a) -inf.0] + [else (lg- (fllog 0.5) log-P_x<=a)])) + (define log-P_c<x<=b + (cond [((fllog 0.5) . fl< . log-P_x>b) -inf.0] + [else (lg- (fllog 0.5) log-P_x>b)])) + (lg+ log-P_a<x<=c log-P_c<x<=b)])])) + (min 0.0 log-p)) + +(: real-dist-prob (case-> (Real-Dist Real Real -> Flonum) + (Real-Dist Real Real Any -> Flonum) + (Real-Dist Real Real Any Any -> Flonum))) +(define (real-dist-prob d a b [log? #f] [1-p? #f]) + (let ([a (fl a)] [b (fl b)]) + (let ([a (flmin a b)] [b (flmax a b)]) + (cond [log? (define p (real-dist-prob* d a b 1-p?)) + (cond [(and (p . fl> . +max-subnormal.0) (p . fl< . 0.9)) (fllog p)] + [else (real-dist-log-prob* d a b 1-p?)])] + [else (real-dist-prob* d a b 1-p?)])))) + +;; =================================================================================================== +;; Highest probability density (HPD) regions + +(: real-dist-hpd-interval (-> Real-Dist Real (Values Real Real))) +(define (real-dist-hpd-interval d α) + (when (or (α . <= . 0) (α . > . 1)) + (raise-argument-error 'real-dist-hpd-interval "Real in (0,1]" 1 d α)) + + (let ([α (max (* 128.0 epsilon.0) (fl α))]) + (cond + [(α . >= . 1.0) (values (ordered-dist-min d) (ordered-dist-max d))] + [else + (define pdf (distribution-pdf d)) + (define cdf (ordered-dist-cdf d)) + (define inv-cdf (ordered-dist-inv-cdf d)) + + (: high-endpoint (-> Real Flonum)) + (define (high-endpoint a) + (inv-cdf (min 1.0 (+ (cdf a) α)))) + + (: objective (-> Flonum Flonum)) + (define (objective p) + (define a (inv-cdf p)) + (define b (high-endpoint a)) + (- (pdf a) (pdf b))) + + (define p (flbracketed-root objective +min.0 (- 1.0 α))) + + (define a + (cond [(<= 0.0 p 1.0) (inv-cdf p)] + [else + (if ((abs (objective 0.0)) . < . 
(abs (objective (- 1.0 α)))) + (inv-cdf 0.0) + (inv-cdf (- 1.0 α)))])) + + (values a (high-endpoint a))]))) diff --git a/pkgs/math-pkgs/math-lib/math/private/distributions/dist-struct.rkt b/pkgs/math-pkgs/math-lib/math/private/distributions/dist-struct.rkt index b268534388..ab1f04730b 100644 --- a/pkgs/math-pkgs/math-lib/math/private/distributions/dist-struct.rkt +++ b/pkgs/math-pkgs/math-lib/math/private/distributions/dist-struct.rkt @@ -1,15 +1,13 @@ #lang typed/racket/base -(require racket/performance-hint - racket/promise - "../../flonum.rkt") +(require racket/performance-hint) (provide PDF Sample CDF Inverse-CDF (struct-out distribution) (struct-out ordered-dist) Real-Dist - pdf sample cdf inv-cdf real-dist-prob) + pdf sample cdf inv-cdf) (define-type (PDF In) (case-> (In -> Flonum) @@ -72,72 +70,3 @@ ((ordered-dist-inv-cdf d) p log? 1-p?)) ) - -(: real-dist-prob* (Real-Dist Flonum Flonum Any -> Flonum)) -;; Assumes a <= b -(define (real-dist-prob* d a b 1-p?) - (define c (force (ordered-dist-median d))) - (define cdf (ordered-dist-cdf d)) - (define p - (cond [(a . fl= . b) (if 1-p? 1.0 0.0)] - [1-p? (fl+ (cdf a #f #f) (cdf b #f #t))] - [(b . fl<= . c) - ;; Both less than the median; use lower tail only - (fl- (cdf b #f #f) (cdf a #f #f))] - [(a . fl>= . c) - ;; Both greater than the median; use upper tail only - (fl- (cdf a #f #t) (cdf b #f #t))] - [else - ;; Median between a and b; use lower for (a,c] and upper for (c,b] - (define P_x<=a (cdf a #f #f)) - (define P_x>b (cdf b #f #t)) - (fl+ (fl- 0.5 P_x<=a) (fl- 0.5 P_x>b))])) - (max 0.0 (min 1.0 p))) - -(: real-dist-log-prob* (Real-Dist Flonum Flonum Any -> Flonum)) -;; Assumes a <= b -(define (real-dist-log-prob* d a b 1-p?) - (define c (force (ordered-dist-median d))) - (define cdf (ordered-dist-cdf d)) - (define log-p - (cond [(a . fl= . b) (if 1-p? 0.0 -inf.0)] - [1-p? (lg+ (cdf a #t #f) (cdf b #t #t))] - [(b . fl<= . 
c) - ;; Both less than the median; use lower tail only - (define log-P_x<=a (cdf a #t #f)) - (define log-P_x<=b (cdf b #t #f)) - (cond [(log-P_x<=b . fl< . log-P_x<=a) -inf.0] - [else (lg- log-P_x<=b log-P_x<=a)])] - [(a . fl>= . c) - ;; Both greater than the median; use upper tail only - (define log-P_x>a (cdf a #t #t)) - (define log-P_x>b (cdf b #t #t)) - (cond [(log-P_x>a . fl< . log-P_x>b) -inf.0] - [else (lg- log-P_x>a log-P_x>b)])] - [else - ;; Median between a and b; try 1-upper first - (define log-P_x<=a (cdf a #t #f)) - (define log-P_x>b (cdf b #t #t)) - (define log-p (lg1- (lg+ log-P_x<=a log-P_x>b))) - (cond [(log-p . fl> . (log 0.1)) log-p] - [else - ;; Subtracting from 1.0 (in log space) lost bits; split and add instead - (define log-P_a<x<=c - (cond [((fllog 0.5) . fl< . log-P_x<=a) -inf.0] - [else (lg- (fllog 0.5) log-P_x<=a)])) - (define log-P_c<x<=b - (cond [((fllog 0.5) . fl< . log-P_x>b) -inf.0] - [else (lg- (fllog 0.5) log-P_x>b)])) - (lg+ log-P_a<x<=c log-P_c<x<=b)])])) - (min 0.0 log-p)) - -(: real-dist-prob (case-> (Real-Dist Real Real -> Flonum) - (Real-Dist Real Real Any -> Flonum) - (Real-Dist Real Real Any Any -> Flonum))) -(define (real-dist-prob d a b [log? #f] [1-p? #f]) - (let ([a (fl a)] [b (fl b)]) - (let ([a (flmin a b)] [b (flmax a b)]) - (cond [log? (define p (real-dist-prob* d a b 1-p?)) - (cond [(and (p . fl> . +max-subnormal.0) (p . fl< . 0.9)) (fllog p)] - [else (real-dist-log-prob* d a b 1-p?)])] - [else (real-dist-prob* d a b 1-p?)])))) diff --git a/pkgs/math-pkgs/math-lib/math/private/distributions/impl/beta-pdf.rkt b/pkgs/math-pkgs/math-lib/math/private/distributions/impl/beta-pdf.rkt index 193e35340d..67cf8c5060 100644 --- a/pkgs/math-pkgs/math-lib/math/private/distributions/impl/beta-pdf.rkt +++ b/pkgs/math-pkgs/math-lib/math/private/distributions/impl/beta-pdf.rkt @@ -17,6 +17,12 @@ (if (fl= x 1.0) +inf.0 -inf.0)] [(or (x . fl< . 0.0) (x . fl> . 
1.0)) -inf.0] + ;; Avoid (* 0.0 -inf.0) by taking a limit from the right + [(and (fl= a 1.0) (fl= x 0.0)) + (- (fllog-beta 1.0 b))] + ;; Avoid (* 0.0 -inf.0) by taking a limit from the left + [(and (fl= b 1.0) (fl= x 1.0)) + (- (fllog-beta a 1.0))] [else (flsum (list (fl* (fl- a 1.0) (fllog x)) (fl* (fl- b 1.0) (fllog1p (- x))) diff --git a/pkgs/math-pkgs/math-lib/math/private/distributions/truncated-dist.rkt b/pkgs/math-pkgs/math-lib/math/private/distributions/truncated-dist.rkt index e7d4ad5d2d..a507bb6b0f 100644 --- a/pkgs/math-pkgs/math-lib/math/private/distributions/truncated-dist.rkt +++ b/pkgs/math-pkgs/math-lib/math/private/distributions/truncated-dist.rkt @@ -4,6 +4,7 @@ racket/promise "../../flonum.rkt" "dist-struct.rkt" + "dist-functions.rkt" "utils.rkt") (provide Truncated-Dist