80 lines
3.2 KiB
Racket
80 lines
3.2 KiB
Racket
#lang typed/racket/base
|
|
|
|
(require racket/performance-hint
|
|
racket/promise
|
|
"../../flonum.rkt"
|
|
"../unsafe.rkt"
|
|
"dist-struct.rkt"
|
|
"utils.rkt")
|
|
|
|
(provide flbernoulli-pdf
|
|
flbernoulli-cdf
|
|
flbernoulli-inv-cdf
|
|
flbernoulli-sample
|
|
Bernoulli-Dist bernoulli-dist bernoulli-dist-prob)
|
|
|
|
;; Probability mass function of a Bernoulli distribution with success
;; probability `q`, evaluated at outcome `k`.
;;   q    : success probability; anything failing `flprobability?` gives +nan.0
;;   k    : outcome; only exactly 0.0 and 1.0 carry mass, everything else
;;          (including +nan.0) gives +nan.0
;;   log? : when true, the log of the mass is returned instead of the mass
(: flbernoulli-pdf (Flonum Flonum Any -> Flonum))
(define (flbernoulli-pdf q k log?)
  (if (flprobability? q)
      (let ([failure? (fl= k 0.0)]
            [success? (fl= k 1.0)])
        (cond
          ;; mass at 0 is 1-q; log1p keeps precision when q is tiny
          [failure? (if log? (fllog1p (- q)) (fl- 1.0 q))]
          ;; mass at 1 is q
          [success? (if log? (fllog q) q)]
          [else +nan.0]))
      +nan.0))
|
|
|
|
;; Cumulative distribution function of a Bernoulli distribution with
;; success probability `q`, evaluated at `k`.
;;   q    : success probability; anything failing `flprobability?` gives +nan.0
;;   k    : evaluation point
;;   log? : when true, return the log of the probability
;;   1-p? : when true, return the upper-tail probability P(X > k) instead of
;;          the lower-tail probability P(X <= k)
;; The result is piecewise constant over k < 0, 0 <= k < 1, and everything
;; else (the final branch also catches a +nan.0 `k`, since both comparisons
;; are then false).
(: flbernoulli-cdf (Flonum Flonum Any Any -> Flonum))
(define (flbernoulli-cdf q k log? 1-p?)
  (cond
    [(not (flprobability? q)) +nan.0]
    ;; below the support: lower tail is 0, upper tail is 1
    [(k . fl< . 0.0)
     (cond [1-p? (if log? 0.0 1.0)]
           [else (if log? -inf.0 0.0)])]
    ;; between the two atoms: lower tail is 1-q, upper tail is q
    [(k . fl< . 1.0)
     (cond [1-p? (if log? (fllog q) q)]
           [else (if log? (fllog1p (- q)) (- 1.0 q))])]
    ;; at or above 1: lower tail is 1, upper tail is 0
    [else
     (cond [1-p? (if log? -inf.0 0.0)]
           [else (if log? 0.0 1.0)])]))
|
|
|
|
;; Inverse CDF (quantile function) of a Bernoulli distribution with success
;; probability `q`: the smaller of the two outcomes 0.0 / 1.0 whose
;; cumulative probability reaches `p`.
;;   q    : success probability; anything failing `flprobability?` gives +nan.0
;;   p    : probability to invert (a log probability when `log?` is true)
;;   log? : when true, `p` is compared against log probabilities
;;   1-p? : when true, `p` is an upper-tail probability P(X > k)
;;
;; Lower tail: P(X <= 0) = 1-q, so the answer is 0 exactly when p <= 1-q.
;; Upper tail: P(X > 0) = q, so the answer is 0 exactly when p >= q.  This is
;; the lower-tail rule with p replaced by 1-p, and it agrees with
;; `flbernoulli-cdf`, whose upper tail is q on 0 <= k < 1 and 0 for k >= 1.
;; (The previous upper-tail test was `p < q`, which returned the opposite
;; outcome, e.g. 0.0 for an upper-tail probability of 0.)
;;
;; NOTE(review): `p` itself is not range-checked here; out-of-range or NaN
;; values of `p` fall through the comparisons rather than yielding +nan.0.
(: flbernoulli-inv-cdf (Flonum Flonum Any Any -> Flonum))
(define (flbernoulli-inv-cdf q p log? 1-p?)
  (cond [(not (flprobability? q))  +nan.0]
        [1-p?  (cond [log?  (if (p . fl>= . (fllog q)) 0.0 1.0)]
                     [else  (if (p . fl>= . q) 0.0 1.0)])]
        [else  (cond [log?  (if (p . fl<= . (fllog1p (- q))) 0.0 1.0)]
                     [else  (if (p . fl<= . (fl- 1.0 q)) 0.0 1.0)])]))
|
|
|
|
;; Draws `n` independent Bernoulli(q) samples and returns them as a flvector
;; of 0.0 / 1.0 values.
;;   q : success probability; if it fails `flprobability?`, every element of
;;       the result is +nan.0
;;   n : number of samples; a negative count raises an argument error
;; Each sample compares one `(random)` draw against `q`: a draw greater than
;; `q` yields 0.0, otherwise 1.0.
(: flbernoulli-sample (Flonum Integer -> FlVector))
(define (flbernoulli-sample q n)
  (if (< n 0)
      (raise-argument-error 'flbernoulli-sample "Natural" 1 q n)
      (build-flvector n
                      (if (flprobability? q)
                          (λ (_) (if ((random) . > . q) 0.0 1.0))
                          (λ (_) +nan.0)))))
|
|
|
|
;; Declares the Bernoulli distribution type through the `define-real-dist:`
;; form (the file requires "dist-struct.rkt" and "utils.rkt"; which one
;; defines the form is not visible here).
;; The names supplied are `bernoulli-dist`, `Bernoulli-Dist` and
;; `bernoulli-dist-struct`, plus one extra field `prob` of type Flonum.
;; NOTE(review): the expansion is not visible in this file; presumably it
;; generates the `Bernoulli-Dist` type, the `bernoulli-dist-struct`
;; constructor used by `bernoulli-dist` below, and the `bernoulli-dist-prob`
;; accessor listed in `provide` — confirm against the macro definition.
(define-real-dist: bernoulli-dist Bernoulli-Dist
  bernoulli-dist-struct ([prob : Flonum]))
|
|
|
|
(begin-encourage-inline

  ;; Builds a Bernoulli distribution object with success probability `q`
  ;; (default 0.5).  `q` may be any Real; it is converted to a flonum once
  ;; and shared by the closures below.
  ;;
  ;; The object is assembled from:
  ;;   - a pdf, cdf and inverse cdf that convert their Real argument to a
  ;;     flonum and delegate to the `flbernoulli-*` functions, with the
  ;;     `log?` / `1-p?` flags defaulting to #f
  ;;   - a sampler returning one flonum when called with no arguments, or a
  ;;     list of `n` flonums when given a count
  ;;   - the constants 0.0 and 1.0 and a lazily computed value that is 0.0
  ;;     when the probability is at most 0.5 and 1.0 otherwise
  ;;     (NOTE(review): presumably support min, support max and median —
  ;;     confirm against the struct definition)
  ;;   - the flonum probability itself
  (: bernoulli-dist (case-> (-> Bernoulli-Dist)
                            (Real -> Bernoulli-Dist)))
  (define (bernoulli-dist [q 0.5])
    (define prob (fl q))
    (define draw
      (case-lambda:
        [() (unsafe-flvector-ref (flbernoulli-sample prob 1) 0)]
        [([n : Integer]) (flvector->list (flbernoulli-sample prob n))]))
    (define mass
      (opt-lambda: ([k : Real] [log? : Any #f])
        (flbernoulli-pdf prob (fl k) log?)))
    (define cumulative
      (opt-lambda: ([k : Real] [log? : Any #f] [1-p? : Any #f])
        (flbernoulli-cdf prob (fl k) log? 1-p?)))
    (define quantile
      (opt-lambda: ([p : Real] [log? : Any #f] [1-p? : Any #f])
        (flbernoulli-inv-cdf prob (fl p) log? 1-p?)))
    (bernoulli-dist-struct
     mass draw cumulative quantile
     0.0 1.0 (delay (if (prob . fl<= . 0.5) 0.0 1.0))
     prob))

  )
|