diff --git a/collects/data/union-find.rkt b/collects/data/union-find.rkt index 51da566340..06ae0cf4d9 100644 --- a/collects/data/union-find.rkt +++ b/collects/data/union-find.rkt @@ -20,36 +20,47 @@ [else (λ (p port) (print p port mode))])) (recur (uf-find uf) port) (write-string ">" port)))]) + (define (uf-new x) (uf-set (box x) 0)) -(define (uf-union! a b) - (define a-rank (uf-set-rank a)) - (define b-rank (uf-set-rank b)) - (cond - [(< a-rank b-rank) - (set-uf-set-x! a b)] - [else - (set-uf-set-x! b a) - (when (= a-rank b-rank) - (set-uf-set-rank! a (+ a-rank 1)))])) -(define (uf-find a) - (define bx (uf-get-box a)) - (unbox bx)) + +(define (uf-union! _a _b) + (define a (uf-get-root _a)) + (define b (uf-get-root _b)) + (unless (eq? a b) + (define a-rank (uf-set-rank a)) + (define b-rank (uf-set-rank b)) + (cond + [(< a-rank b-rank) + (set-uf-set-x! a b)] + [else + (set-uf-set-x! b a) + (when (= a-rank b-rank) + (set-uf-set-rank! a (+ a-rank 1)))]))) + +(define (uf-find a) (unbox (uf-get-box a))) + (define (uf-set-canonical! a b) (set-box! (uf-get-box a) b)) -(define (uf-get-box a) - (let loop ([a (uf-set-x a)]) - (cond - [(box? a) a] - [else - (define fnd (loop (uf-set-x a))) - (set-uf-set-x! a fnd) - fnd]))) + (define (uf-same-set? a b) (eq? (uf-get-box a) (uf-get-box b))) +(define (uf-get-box a) (uf-set-x (uf-get-root a))) + +(define (uf-get-root a) + (let loop ([c a] + [p (uf-set-x a)]) + (cond + [(box? p) c] + [else + (define fnd (loop p (uf-set-x p))) + (set-uf-set-x! c fnd) + fnd]))) + (module+ test - (require rackunit - racket/list) + (require rackunit + racket/pretty + racket/set) (check-equal? (uf-find (uf-new 1)) 1) (check-equal? (let ([a (uf-new 1)] @@ -74,6 +85,12 @@ (uf-find b) (uf-find b)) 1) + + (check-equal? (let ([a (uf-new 1)]) + (uf-union! a a) + (uf-find a)) + 1) + (check-equal? (uf-same-set? (uf-new 1) (uf-new 2)) #f) (check-equal? (uf-same-set? (uf-new 1) (uf-new 1)) #f) (check-equal? (let ([a (uf-new 1)] @@ -107,6 +124,55 @@ (get-output-string sp)) "#0=#") + + (let ([a (uf-new 1)] + [b (uf-new 2)] + [c (uf-new 3)] + [d (uf-new 4)] + [e (uf-new 5)]) + (uf-union! a b) + (uf-union! c d) + (uf-union! b d) + (uf-union! c e) + (check-equal? (uf-find a) + (uf-find e))) + + (let ([a (uf-new 1)] + [b (uf-new 2)] + [c (uf-new 3)] + [d (uf-new 4)] + [e (uf-new 5)]) + (uf-union! a b) + (uf-union! c d) + (uf-union! a c) + (uf-union! c e) + (check-equal? (uf-find a) + (uf-find e))) + + (let ([a (uf-new 1)] + [b (uf-new 2)] + [c (uf-new 3)] + [d (uf-new 4)] + [e (uf-new 5)]) + (uf-union! a b) + (uf-union! c d) + (uf-union! a d) + (uf-union! c e) + (check-equal? (uf-find a) + (uf-find e))) + + (let ([a (uf-new 1)] + [b (uf-new 2)] + [c (uf-new 3)] + [d (uf-new 4)] + [e (uf-new 5)]) + (uf-union! a b) + (uf-union! c d) + (uf-union! b c) + (uf-union! c e) + (check-equal? (uf-find a) + (uf-find e))) + (check-equal? (let ([a (uf-new 1)] [b (uf-new 2)] [c (uf-new 3)] @@ -120,39 +186,81 @@ (uf-set-rank d))) 2) - (define (check-ranks uf) - (let loop ([uf/box uf] - [rank -inf.0]) - (cond - [(box? uf/box) (void)] - [else - (unless (< rank (uf-set-rank uf/box)) - (error 'check-ranks "failed for ~s" - (let loop ([uf uf]) - (cond - [(box? uf) `(box ,(unbox uf))] - [else `(uf-set ,(loop (uf-set-x uf)) - ,(uf-set-rank uf))])))) - (loop (uf-set-x uf/box) - (uf-set-rank uf/box))]))) + (let ((uf-sets (for/list ((x (in-range 8))) (uf-new x)))) + (uf-union! (list-ref uf-sets 5) (list-ref uf-sets 7)) + (uf-union! (list-ref uf-sets 1) (list-ref uf-sets 6)) + (uf-union! (list-ref uf-sets 6) (list-ref uf-sets 5)) + (uf-union! (list-ref uf-sets 4) (list-ref uf-sets 7)) + (uf-union! (list-ref uf-sets 2) (list-ref uf-sets 0)) + (uf-union! (list-ref uf-sets 2) (list-ref uf-sets 5)) + (check-equal? (uf-find (list-ref uf-sets 4)) + (uf-find (list-ref uf-sets 7)))) - (for ([x (in-range 1000)]) - (define num-sets (+ 2 (random 40))) - (define uf-sets - (shuffle - (for/list ([x (in-range num-sets)]) - (uf-new x)))) - (let loop ([uf-set (car uf-sets)] - [uf-sets (cdr uf-sets)]) - (when (zero? (random 3)) - (uf-find uf-set)) - (unless (null? uf-sets) - (uf-union! uf-set (car uf-sets)) - (loop (car uf-sets) - (cdr uf-sets)))) - (check-true - (apply = (map uf-find uf-sets))) + + (define (run-random-tests) + (define (make-random-sets num-sets) + (define uf-sets + (for/list ([x (in-range num-sets)]) + (uf-new x))) + (define edges (make-hash (build-list num-sets (λ (x) (cons x (set)))))) + (define (add-edge a-num b-num) + (hash-set! edges a-num (set-add (hash-ref edges a-num) b-num))) + (define ops '()) + (for ([op (in-range (random 10))]) + (define a-num (random num-sets)) + (define b-num (random num-sets)) + (define a (list-ref uf-sets a-num)) + (define b (list-ref uf-sets b-num)) + (set! ops (cons `(uf-union! (list-ref uf-sets ,a-num) + (list-ref uf-sets ,b-num)) + ops)) + (uf-union! a b) + (add-edge a-num b-num) + (add-edge b-num a-num)) + (define code `(let ([uf-sets + (for/list ([x (in-range ,num-sets)]) + (uf-new x))]) + ,@(reverse ops))) + (values uf-sets edges code)) - (for ([uf (in-list uf-sets)]) - (check-ranks uf)))) - + (define (check-canonical-has-path uf-sets edges code) + (for ([set (in-list uf-sets)] + [i (in-naturals)]) + (define canon (uf-find set)) + (define visited (make-hash)) + (define found? + (let loop ([node i]) + (cond + [(= node canon) #t] + [(hash-ref visited node #f) #f] + [else + (hash-set! visited node #t) + (for/or ([neighbor (in-set (hash-ref edges node))]) + (loop neighbor))]))) + (unless found? + (pretty-print code (current-error-port)) + (error 'union-find.rkt "mismatch; expected a link from ~s to ~s, didn't find it" + i canon)))) + + (define (check-edges-share-canonical uf-sets edges code) + (for ([(src dests) (in-hash edges)]) + (for ([dest (in-set dests)]) + (define sc (uf-find (list-ref uf-sets src))) + (define dc (uf-find (list-ref uf-sets dest))) + (unless (= sc dc) + (pretty-print code (current-error-port)) + (error 'union-find.rkt + "mismatch; expected sets ~s and ~s to have the same canonical element, got ~s and ~s" + src dest + sc dc))))) + + (for ([x (in-range 10000)]) + (define-values (sets edges code) + (make-random-sets (+ 2 (random (+ 1 (floor (/ x 100))))))) + (check-canonical-has-path sets edges code) + (check-edges-share-canonical sets edges code))) + + (run-random-tests) + + (random-seed 0) + (time (run-random-tests)))