From 5e22bb81f5bc53839e3add5b5c6ba342e6ebc22a Mon Sep 17 00:00:00 2001 From: Matthew Flatt Date: Wed, 7 Apr 2010 12:48:15 +0000 Subject: [PATCH] constant-folding repairs to some unsafe operations svn: r18745 --- collects/tests/mzscheme/unsafe.ss | 12 +++++++++--- src/mzscheme/src/numarith.c | 10 +++++++++- src/mzscheme/src/numcomp.c | 12 +++++++----- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/collects/tests/mzscheme/unsafe.ss b/collects/tests/mzscheme/unsafe.ss index a7bfccc2f5..9013a3383e 100644 --- a/collects/tests/mzscheme/unsafe.ss +++ b/collects/tests/mzscheme/unsafe.ss @@ -27,11 +27,14 @@ (test result (compose post (eval `(lambda (x y) (,proc x y ,z)))) x y) (pre) (test result (compose post (eval `(lambda (x) (,proc x ,y ,z)))) x) + (pre) (when lit-ok? (pre) (test result (compose post (eval `(lambda (y) (,proc ,x y ,z)))) y) (pre) - (test result (compose post (eval `(lambda (z) (,proc ,x ,y z)))) z))) + (test result (compose post (eval `(lambda (z) (,proc ,x ,y z)))) z) + (pre) + (test result (compose post (eval `(lambda () (,proc ,x ,y ,z))))))) (define (test-bin result proc x y #:pre [pre void] #:post [post (lambda (x) x)] @@ -42,12 +45,15 @@ (test result (compose post (eval `(lambda (x y) (,proc x y)))) x y) (when lit-ok? (pre) - (test result (compose post (eval `(lambda (y) (,proc ,x y)))) y)) + (test result (compose post (eval `(lambda (y) (,proc ,x y)))) y) + (pre) + (test result (compose post (eval `(lambda () (,proc ,x ,y)))))) (pre) (test result (compose post (eval `(lambda (x) (,proc x ,y)))) x)) (define (test-un result proc x) (test result (eval proc) x) - (test result (eval `(lambda (x) (,proc x))) x)) + (test result (eval `(lambda (x) (,proc x))) x) + (test result (eval `(lambda () (,proc ',x))))) (test-bin 3 'unsafe-fx+ 1 2) (test-bin -1 'unsafe-fx+ 1 -2) diff --git a/src/mzscheme/src/numarith.c b/src/mzscheme/src/numarith.c index fbf2c024b3..d4259e4500 100644 --- a/src/mzscheme/src/numarith.c +++ b/src/mzscheme/src/numarith.c @@ -1005,7 +1005,15 @@ UNSAFE_FL(unsafe_fl_div, /, div_prim) } UNSAFE_FL1(unsafe_fl_abs, fabs, scheme_abs) -UNSAFE_FL1(unsafe_fl_sqrt, sqrt, scheme_sqrt) + +static Scheme_Object *pos_sqrt(int argc, Scheme_Object **argv) +{ + if (SCHEME_DBLP(argv[0]) && (SCHEME_DBL_VAL(argv[0]) < 0.0)) + return scheme_nan_object; + return scheme_sqrt(argc, argv); +} + +UNSAFE_FL1(unsafe_fl_sqrt, sqrt, pos_sqrt) #define SAFE_FL(name, sname, op) \ static Scheme_Object *name(int argc, Scheme_Object *argv[]) \ diff --git a/src/mzscheme/src/numcomp.c b/src/mzscheme/src/numcomp.c index 001f9481d0..260a5d84fe 100644 --- a/src/mzscheme/src/numcomp.c +++ b/src/mzscheme/src/numcomp.c @@ -505,17 +505,18 @@ SAFE_FX(fx_gt_eq, "fx>=", >=) SAFE_FX_X(fx_min, "fxmin", <, argv[0], argv[1]) SAFE_FX_X(fx_max, "fxmax", >, argv[0], argv[1]) -#define UNSAFE_FX_X(name, op, fold, T, F) \ +#define UNSAFE_FX_X(name, op, fold, T, F, SEL) \ static Scheme_Object *name(int argc, Scheme_Object *argv[]) \ { \ - if (scheme_current_thread->constant_folding) return (fold(argv[0], argv[1]) ? scheme_true : scheme_false); \ + if (scheme_current_thread->constant_folding) return SEL(fold(argv[0], argv[1])); \ if (SCHEME_INT_VAL(argv[0]) op SCHEME_INT_VAL(argv[1])) \ return T; \ else \ return F; \ } -#define UNSAFE_FX(name, op, fold) UNSAFE_FX_X(name, op, fold, scheme_true, scheme_false) +#define FX_SEL_BOOLEAN(e) (e ? scheme_true : scheme_false) +#define UNSAFE_FX(name, op, fold) UNSAFE_FX_X(name, op, fold, scheme_true, scheme_false, FX_SEL_BOOLEAN) UNSAFE_FX(unsafe_fx_eq, ==, scheme_bin_eq) UNSAFE_FX(unsafe_fx_lt, <, scheme_bin_lt) @@ -523,8 +524,9 @@ UNSAFE_FX(unsafe_fx_gt, >, scheme_bin_gt) UNSAFE_FX(unsafe_fx_lt_eq, <=, scheme_bin_lt_eq) UNSAFE_FX(unsafe_fx_gt_eq, >=, scheme_bin_gt_eq) -UNSAFE_FX_X(unsafe_fx_min, <, bin_min, argv[0], argv[1]) -UNSAFE_FX_X(unsafe_fx_max, >, bin_max, argv[0], argv[1]) +#define FX_SEL_ID(e) e +UNSAFE_FX_X(unsafe_fx_min, <, bin_min, argv[0], argv[1], FX_SEL_ID) +UNSAFE_FX_X(unsafe_fx_max, >, bin_max, argv[0], argv[1], FX_SEL_ID) #define SAFE_FL_X(name, sname, op, T, F) \ static Scheme_Object *name(int argc, Scheme_Object *argv[]) \