#lang racket/base

;; Distribution structures and interval-probability helpers.
;; All typed definitions live in the `typed-defs` submodule; this untyped outer module re-exports
;; them and adds short aliases (rename transformers) for several `ordered-dist` accessors.
(provide (all-from-out (submod "." typed-defs))
         ;; Rename transformers
         dist-cdf dist-inv-cdf dist-min dist-max)

(module typed-defs typed/racket/base
  (require racket/performance-hint
           racket/promise
           "../../flonum.rkt")

  (provide PDF Sample CDF Inverse-CDF
           (struct-out dist)
           (struct-out ordered-dist)
           Real-Dist
           dist-median
           sample
           real-dist-prob)

  ;; Density function over values of type In. It accepts one optional extra argument
  ;; (NOTE(review): presumably a log? flag analogous to CDF's first flag - confirm).
  (define-type (PDF In)
    (case-> (In -> Flonum)
            (In Any -> Flonum)))

  ;; Sampler: called with no arguments it returns a single Out; called with an Integer
  ;; it returns a (Listof Out).
  (define-type (Sample Out)
    (case-> (-> Out)
            (Integer -> (Listof Out))))

  ;; Cumulative distribution function. As used in this module, the two optional arguments are
  ;; (log? upper-tail?): log? = #t asks for a log-space result, and upper-tail? = #t asks for
  ;; P(X > x) instead of P(X <= x).
  (define-type (CDF In)
    (case-> (In -> Flonum)
            (In Any -> Flonum)
            (In Any Any -> Flonum)))

  ;; Inverse CDF (quantile function) taking a Real probability plus two optional arguments.
  ;; NOTE(review): the optional arguments presumably mirror CDF's (log? upper-tail?); this
  ;; module never calls an inverse CDF, so confirm against the distribution implementations.
  (define-type (Inverse-CDF Out)
    (case-> (Real -> Out)
            (Real Any -> Out)
            (Real Any Any -> Out)))

  ;; Any distribution: a density and a sampler.
  (struct: (In Out) dist ([pdf : (PDF In)] [sample : (Sample Out)])
    #:transparent)

  ;; A distribution over an ordered domain: adds a CDF, its inverse, the support bounds,
  ;; and the median. The median is stored as a promise and forced on demand by `dist-median`.
  (struct: (In Out) ordered-dist dist
    ([cdf : (CDF In)]
     [inv-cdf : (Inverse-CDF Out)]
     [min : Out]
     [max : Out]
     [median : (Promise Out)])
    #:transparent)

  ;; Ordered distribution that accepts any Real and produces Flonums.
  (define-type Real-Dist (ordered-dist Real Flonum))

  ;; =================================================================================================

  ;; Small wrappers; `begin-encourage-inline` hints the compiler to inline them at call sites.
  (begin-encourage-inline

    ;; Returns the distribution's median, forcing the stored promise (computed at most once).
    (: dist-median (All (In Out) ((ordered-dist In Out) -> Out)))
    (define (dist-median d) (force (ordered-dist-median d)))

    ;; (sample d)   -> one value from d's sampler
    ;; (sample d n) -> whatever d's sampler returns for n (a list of values)
    (: sample (All (In Out) (case-> ((dist In Out) -> Out)
                                    ((dist In Out) Integer -> (Listof Out)))))
    (define sample
      (case-lambda
        [(d)  ((dist-sample d))]
        [(d n)  ((dist-sample d) n)]))

    )

  (: real-dist-prob* (Real-Dist Flonum Flonum Any -> Flonum))
  ;; Linear-space probability that X lies in (a, b], or outside (a, b] when 1-p? is true.
  ;; Assumes a <= b.
  ;; To limit cancellation, the CDF tail used at each endpoint is chosen relative to the median c;
  ;; the split case relies on P(X <= c) = 0.5. The result is clamped to [0, 1].
  (define (real-dist-prob* d a b 1-p?)
    (define c (dist-median d))
    (define cdf (ordered-dist-cdf d))
    (define p
      (cond [(a . fl= . b)  (if 1-p? 1.0 0.0)]
            [1-p?  (fl+ (cdf a #f #f) (cdf b #f #t))]
            [(b . fl<= . c)
             ;; Both less than the median; use lower tail only
             (fl- (cdf b #f #f) (cdf a #f #f))]
            [(a . fl>= . c)
             ;; Both greater than the median; use upper tail only
             (fl- (cdf a #f #t) (cdf b #f #t))]
            [else
             ;; Median between a and b; use lower for (a,c] and upper for (c,b]
             (define P_x<=a (cdf a #f #f))
             (define P_x>b (cdf b #f #t))
             (fl+ (fl- 0.5 P_x<=a) (fl- 0.5 P_x>b))]))
    (max 0.0 (min 1.0 p)))

  (: real-dist-log-prob* (Real-Dist Flonum Flonum Any -> Flonum))
  ;; Log-space counterpart of real-dist-prob*: log P(a < X <= b), or the log of the complement
  ;; when 1-p? is true. Assumes a <= b. The result is clamped to at most 0.0.
  (define (real-dist-log-prob* d a b 1-p?)
    (define c (dist-median d))
    (define cdf (ordered-dist-cdf d))
    (define log-p
      (cond [(a . fl= . b)  (if 1-p? 0.0 -inf.0)]
            [1-p?  (lg+ (cdf a #t #f) (cdf b #t #t))]
            [(b . fl<= . c)
             ;; Both less than the median; use lower tail only
             (define log-P_x<=a (cdf a #t #f))
             (define log-P_x<=b (cdf b #t #f))
             ;; Guard: a CDF that is numerically non-monotone would make the difference negative
             (cond [(log-P_x<=b . fl< . log-P_x<=a)  -inf.0]
                   [else  (lg- log-P_x<=b log-P_x<=a)])]
            [(a . fl>= . c)
             ;; Both greater than the median; use upper tail only
             (define log-P_x>a (cdf a #t #t))
             (define log-P_x>b (cdf b #t #t))
             (cond [(log-P_x>a . fl< . log-P_x>b)  -inf.0]
                   [else  (lg- log-P_x>a log-P_x>b)])]
            [else
             ;; Median between a and b; try 1-upper first
             (define log-P_x<=a (cdf a #t #f))
             (define log-P_x>b (cdf b #t #t))
             (define log-p (lg1- (lg+ log-P_x<=a log-P_x>b)))
             (cond [(log-p . fl> . (fllog 0.1))  log-p]
                   [else
                    ;; Subtracting from 1.0 (in log space) lost bits; split and add instead.
                    ;; Uses P(X <= c) = 0.5 at the median: P(a<X<=c) = 0.5 - P(X<=a) and
                    ;; P(c<X<=b) = 0.5 - P(X>b); a tail already above 0.5 contributes nothing.
                    (define log-P_a<x<=c
                      (cond [((fllog 0.5) . fl< . log-P_x<=a)  -inf.0]
                            [else  (lg- (fllog 0.5) log-P_x<=a)]))
                    (define log-P_c<x<=b
                      (cond [((fllog 0.5) . fl< . log-P_x>b)  -inf.0]
                            [else  (lg- (fllog 0.5) log-P_x>b)]))
                    (lg+ log-P_a<x<=c log-P_c<x<=b)])]))
    ;; A log probability can never exceed log 1 = 0
    (min 0.0 log-p))

  (: real-dist-prob (case-> (Real-Dist Real Real -> Flonum)
                            (Real-Dist Real Real Any -> Flonum)
                            (Real-Dist Real Real Any Any -> Flonum)))
  ;; Probability that d assigns to the interval between a and b: (min a b) < X <= (max a b).
  ;; The endpoints may be given in either order and are converted to flonums.
  ;;   log? - when true, return the log probability
  ;;   1-p? - when true, return the probability of the complement of the interval
  (define (real-dist-prob d a b [log? #f] [1-p? #f])
    (let ([a  (fl a)] [b  (fl b)])
      (let ([a  (flmin a b)] [b  (flmax a b)])
        (cond [log?
               (define p (real-dist-prob* d a b 1-p?))
               ;; Take the cheap linear-space answer's log only when p is comfortably normal and
               ;; not close to 1; otherwise recompute entirely in log space.
               (cond [(and (p . fl> . +max-subnormal.0) (p . fl< . 0.9))  (fllog p)]
                     [else  (real-dist-log-prob* d a b 1-p?)])]
              [else
               (real-dist-prob* d a b 1-p?)]))))

  )

(require (submod "." typed-defs)
         (for-syntax racket/base))

;; Short aliases for the ordered-dist accessors, defined as rename transformers so each alias
;; refers to the same binding as the accessor it names.
(define-syntax dist-cdf (make-rename-transformer #'ordered-dist-cdf))
(define-syntax dist-inv-cdf (make-rename-transformer #'ordered-dist-inv-cdf))
(define-syntax dist-min (make-rename-transformer #'ordered-dist-min))
(define-syntax dist-max (make-rename-transformer #'ordered-dist-max))