diff --git a/pkgs/racket-pkgs/racket-test/tests/racket/optimize.rktl b/pkgs/racket-pkgs/racket-test/tests/racket/optimize.rktl index b4deac79d5..245eb2cb51 100644 --- a/pkgs/racket-pkgs/racket-test/tests/racket/optimize.rktl +++ b/pkgs/racket-pkgs/racket-test/tests/racket/optimize.rktl @@ -3244,6 +3244,38 @@ (phase1-eval))) (test #t syntax? (f)) +;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Check unboxing through conditionals + +(let () + (define (check pred t1 e1) + (define v (* 2.0 (if (eval (pred 7.0)) + (eval (t1 7.0)) + (eval (e1 7.0))))) + (test v (eval `(lambda (arg) + (let ([x (if ,(pred 'arg) + ,(t1 'arg) + ,(e1 'arg))]) + (fl+ x x)))) + 7.0) + (test v (eval `(lambda (arg) + (fl* 2.0 (if ,(pred 'arg) + ,(t1 'arg) + ,(e1 'arg))))) + 7.0)) + (for ([pred (in-list (list + (lambda (arg) `(negative? ,arg)) + (lambda (arg) `(positive? ,arg)) + (lambda (arg) `(even? (fl* ,arg ,arg)))))]) + (for ([t1 (in-list (list + (lambda (arg) `(fl+ ,arg 8.0)) + (lambda (arg) `(fl- (fl+ ,arg 8.0) 1.0))))]) + (for ([e1 (in-list (list (lambda (arg) `(fl* 8.0 ,arg)) + (lambda (arg) `(begin + (display "") + (fl* 8.0 ,arg)))))]) + (check pred t1 e1))))) + ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; diff --git a/racket/src/racket/src/jit.c b/racket/src/racket/src/jit.c index ac0ad88dcc..fc81b73d1a 100644 --- a/racket/src/racket/src/jit.c +++ b/racket/src/racket/src/jit.c @@ -1641,6 +1641,8 @@ static int generate_branch(Scheme_Object *obj, mz_jit_state *jitter, int is_tail Branch_Info for_this_branch; GC_CAN_IGNORE Branch_Info_Addr addrs[NUM_QUICK_INFO_ADDRS]; GC_CAN_IGNORE jit_insn *ref2; + mz_jit_unbox_state ubs; + int ubd, save_ubd; int pushed_marks; int nsrs, nsrs1, g1, g2, amt, need_sync, flostack, flostack_pos; int else_is_empty = 0, i, can_chain_branch, chain_true, chain_false, old_self_pos; @@ -1709,6 +1711,8 @@ static int generate_branch(Scheme_Object *obj, mz_jit_state *jitter, int is_tail mz_rs_sync(); + scheme_mz_unbox_save(jitter, &ubs); + if (!scheme_generate_inlined_test(jitter, branch->test, then_short_ok, &for_this_branch, need_sync)) { CHECK_LIMIT(); generate_non_tail_with_branch(branch->test, jitter, 0, 1, 0, &for_this_branch); @@ -1719,6 +1723,9 @@ static int generate_branch(Scheme_Object *obj, mz_jit_state *jitter, int is_tail } CHECK_LIMIT(); + save_ubd = jitter->unbox_depth; + scheme_mz_unbox_restore(jitter, &ubs); + /* True branch */ scheme_mz_runstack_saved(jitter); flostack = scheme_mz_flostack_save(jitter, &flostack_pos); @@ -1786,6 +1793,10 @@ static int generate_branch(Scheme_Object *obj, mz_jit_state *jitter, int is_tail if (old_self_pos != jitter->self_pos) scheme_signal_error("internal error: self position moved across branch"); + ubd = jitter->unbox_depth; + jitter->unbox_depth = save_ubd; + scheme_mz_unbox_restore(jitter, &ubs); + /* False branch */ mz_SET_REG_STATUS_VALID(0); scheme_mz_runstack_saved(jitter); @@ -1843,6 +1854,9 @@ static int generate_branch(Scheme_Object *obj, mz_jit_state *jitter, int is_tail END_JIT_DATA(12); + if (ubd != jitter->unbox_depth) + scheme_signal_error("internal error: different unbox depth for branches"); + /* Return result */ if ((g1 == 2) && (g2 == 2)) @@ -2600,10 +2614,13 @@ int scheme_generate(Scheme_Object *obj, mz_jit_state *jitter, int is_tail, int w { Scheme_Sequence *seq = (Scheme_Sequence *)obj; int cnt = seq->count, i; + mz_jit_unbox_state ubs; START_JIT_DATA(); LOG_IT(("begin\n")); + scheme_mz_unbox_save(jitter, &ubs); + for (i = 0; i < cnt - 1; i++) { scheme_generate_non_tail(seq->array[i], jitter, 1, 1, 1); CHECK_LIMIT(); @@ -2611,6 +2628,8 @@ int scheme_generate(Scheme_Object *obj, mz_jit_state *jitter, int is_tail, int w END_JIT_DATA(11); + scheme_mz_unbox_restore(jitter, &ubs); + return scheme_generate(seq->array[cnt - 1], jitter, is_tail, wcm_may_replace, multi_ok, orig_target, for_branch); } diff --git a/racket/src/racket/src/jitarith.c b/racket/src/racket/src/jitarith.c index f2e7748297..f3dc0502db 100644 --- a/racket/src/racket/src/jitarith.c +++ b/racket/src/racket/src/jitarith.c @@ -283,7 +283,7 @@ int scheme_can_unbox_inline(Scheme_Object *obj, int fuel, int regs, int unsafely } } -int scheme_can_unbox_directly(Scheme_Object *obj, int extfl) +int can_unbox_directly(Scheme_Object *obj, int extfl, int bfuel) /* Used only when !can_unbox_inline(). Detects safe operations that produce flonums when they don't raise an exception, and that the JIT supports directly unboxing. */ @@ -348,12 +348,28 @@ int scheme_can_unbox_directly(Scheme_Object *obj, int extfl) case scheme_letrec_type: obj = ((Scheme_Letrec *)obj)->body; break; + case scheme_branch_type: + if (!bfuel) + return 0; + bfuel--; + if (!can_unbox_directly(((Scheme_Branch_Rec *)obj)->tbranch, extfl, bfuel)) + return 0; + obj = ((Scheme_Branch_Rec *)obj)->fbranch; + break; + case scheme_sequence_type: + obj = ((Scheme_Sequence *)obj)->array[((Scheme_Sequence *)obj)->count - 1]; + break; default: return 0; } } } +int scheme_can_unbox_directly(Scheme_Object *obj, int extfl) +{ + return can_unbox_directly(obj, extfl, 3); +} + static jit_insn *generate_arith_slow_path(mz_jit_state *jitter, Scheme_Object *rator, jit_insn **_ref, jit_insn **_ref4, Branch_Info *for_branch, int branch_short, diff --git a/racket/src/racket/src/jitstate.c b/racket/src/racket/src/jitstate.c index 38c959af51..60f153deef 100644 --- a/racket/src/racket/src/jitstate.c +++ b/racket/src/racket/src/jitstate.c @@ -805,6 +805,7 @@ void scheme_mz_unbox_save(mz_jit_state *jitter, mz_jit_unbox_state *r) } void scheme_mz_unbox_restore(mz_jit_state *jitter, mz_jit_unbox_state *r) +/* can be called multipel times for an `r` by generate_branch() */ { jitter->unbox = r->unbox; #ifdef MZ_LONG_DOUBLE diff --git a/racket/src/racket/src/optimize.c b/racket/src/racket/src/optimize.c index ee3790b5db..10c83118c8 100644 --- a/racket/src/racket/src/optimize.c +++ b/racket/src/racket/src/optimize.c @@ -1971,6 +1971,20 @@ static int expr_produces_local_type(Scheme_Object *expr, int fuel) return produces_local_type(app->rator, 2); } break; + case scheme_branch_type: + { + Scheme_Branch_Rec *b = (Scheme_Branch_Rec *)expr; + return (expr_produces_local_type(b->tbranch, fuel / 2) + && expr_produces_local_type(b->fbranch, fuel / 2)); + } + break; + case scheme_sequence_type: + { + Scheme_Sequence *seq = (Scheme_Sequence *)expr; + + expr = seq->array[seq->count-1]; + break; + } case scheme_compiled_let_void_type: { Scheme_Let_Header *lh = (Scheme_Let_Header *)expr; diff --git a/racket/src/racket/src/validate.c b/racket/src/racket/src/validate.c index 80fd2091e7..4a160f6edd 100644 --- a/racket/src/racket/src/validate.c +++ b/racket/src/racket/src/validate.c @@ -1621,7 +1621,8 @@ static int validate_expr(Mz_CPort *port, Scheme_Object *expr, int cnt; int i, r; - no_typed(need_local_type, port); + if (type != scheme_sequence_type) + no_typed(need_local_type, port); cnt = seq->count; @@ -1642,8 +1643,6 @@ static int validate_expr(Mz_CPort *port, Scheme_Object *expr, Scheme_Branch_Rec *b; int vc_pos, vc_ncpos, r; - no_typed(need_local_type, port); - b = (Scheme_Branch_Rec *)expr; r = validate_expr(port, b->test, stack, tls, depth, letlimit, delta, num_toplevels, num_stxes, num_lifts, tl_use_map, @@ -1660,7 +1659,7 @@ static int validate_expr(Mz_CPort *port, Scheme_Object *expr, r = validate_expr(port, b->tbranch, stack, tls, depth, letlimit, delta, num_toplevels, num_stxes, num_lifts, tl_use_map, tl_state, tl_timestamp, - NULL, 0, result_ignored, vc, tailpos, 0, procs, + NULL, 0, result_ignored, vc, tailpos, need_local_type, procs, expected_results, NULL); result = validate_join_seq(result, r);