Skip to content

Commit

Permalink
bpart: Track whether any binding replacement has happened in image mo…
Browse files Browse the repository at this point in the history
…dules (#57433)

This implements the optimization proposed in #57426 by keeping track of
whether any bindings were replaced in image modules (excluding `Main` as
facilitated by #57426). In addition, we augment serialization to keep
track of whether a method body contains any GlobalRefs that point to a
loaded (system or package) image. If both of these flags are true, we
can skip scanning the body of the method, since we know that we neither
need to add any additional backedges nor were any of the referenced
bindings invalidated. The performance impact on end-to-end load time is
small, but measurable. Overall `@time using ModelingToolkit`
consistently improves about 5% using this PR. However, I should note
that using time is still about 40% slower than 1.11. This is not
necessarily an Apples-to-Apples comparison as there were substantial
other changes on 1.12 (as well as current load-time-tunings targeting
older versions), but I wanted to put the number context.

(cherry picked from commit f6e2b98)
  • Loading branch information
Keno authored and KristofferC committed Feb 17, 2025
1 parent 1c9d39d commit 5e3f967
Show file tree
Hide file tree
Showing 13 changed files with 122 additions and 20 deletions.
8 changes: 6 additions & 2 deletions base/client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ function exec_options(opts)
distributed_mode = (opts.worker == 1) || (opts.nprocs > 0) || (opts.machine_file != C_NULL)
if distributed_mode
let Distributed = require(PkgId(UUID((0x8ba89e20_285c_5b6f, 0x9357_94700520ee1b)), "Distributed"))
Core.eval(MainInclude, :(const Distributed = $Distributed))
MainInclude.Distributed = Distributed
Core.eval(Main, :(using Base.MainInclude.Distributed))
invokelatest(Distributed.process_opts, opts)
end
Expand Down Expand Up @@ -400,7 +400,7 @@ function load_InteractiveUtils(mod::Module=Main)
try
# TODO: we have to use require_stdlib here because it is a dependency of REPL, but we would sort of prefer not to
let InteractiveUtils = require_stdlib(PkgId(UUID(0xb77e0a4c_d291_57a0_90e8_8db25a27a240), "InteractiveUtils"))
Core.eval(MainInclude, :(const InteractiveUtils = $InteractiveUtils))
MainInclude.InteractiveUtils = InteractiveUtils
end
catch ex
@warn "Failed to import InteractiveUtils into module $mod" exception=(ex, catch_backtrace())
Expand Down Expand Up @@ -535,6 +535,10 @@ The thrown errors are collected in a stack of exceptions.
"""
global err = nothing

# Used for memoizing require_stdlib of these modules
global InteractiveUtils::Module
global Distributed::Module

# weakly exposes ans and err variables to Main
export ans, err
end
Expand Down
18 changes: 13 additions & 5 deletions base/invalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,26 +180,34 @@ function binding_was_invalidated(b::Core.Binding)
b.partitions.min_world > unsafe_load(cglobal(:jl_require_world, UInt))
end

function scan_new_method!(methods_with_invalidated_source::IdSet{Method}, method::Method)
function scan_new_method!(methods_with_invalidated_source::IdSet{Method}, method::Method, image_backedges_only::Bool)
isdefined(method, :source) || return
if image_backedges_only && !has_image_globalref(method)
return
end
src = _uncompressed_ir(method)
mod = method.module
foreachgr(src) do gr::GlobalRef
b = convert(Core.Binding, gr)
binding_was_invalidated(b) && push!(methods_with_invalidated_source, method)
if binding_was_invalidated(b)
# TODO: We could turn this into an addition if condition. For now, use it as a reasonably cheap
# additional consistency chekc

Check warning on line 194 in base/invalidation.jl

View workflow job for this annotation

GitHub Actions / Check for new typos

perhaps "chekc" should be "check".
@assert !image_backedges_only
push!(methods_with_invalidated_source, method)
end
maybe_add_binding_backedge!(b, method)
end
end

function scan_new_methods(extext_methods::Vector{Any}, internal_methods::Vector{Any})
function scan_new_methods(extext_methods::Vector{Any}, internal_methods::Vector{Any}, image_backedges_only::Bool)
methods_with_invalidated_source = IdSet{Method}()
for method in internal_methods
if isa(method, Method)
scan_new_method!(methods_with_invalidated_source, method)
scan_new_method!(methods_with_invalidated_source, method, image_backedges_only)
end
end
for tme::Core.TypeMapEntry in extext_methods
scan_new_method!(methods_with_invalidated_source, tme.func::Method)
scan_new_method!(methods_with_invalidated_source, tme.func::Method, image_backedges_only)
end
return methods_with_invalidated_source
end
2 changes: 2 additions & 0 deletions base/runtime_internals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1656,3 +1656,5 @@ isempty(mt::Core.MethodTable) = (mt.defs === nothing)
uncompressed_ir(m::Method) = isdefined(m, :source) ? _uncompressed_ir(m) :
isdefined(m, :generator) ? error("Method is @generated; try `code_lowered` instead.") :
error("Code for this Method is not available.")

has_image_globalref(m::Method) = ccall(:jl_ir_flag_has_image_globalref, Bool, (Any,), m.source)
3 changes: 2 additions & 1 deletion base/staticdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ end
function insert_backedges(edges::Vector{Any}, ext_ci_list::Union{Nothing,Vector{Any}}, extext_methods::Vector{Any}, internal_methods::Vector{Any})
# determine which CodeInstance objects are still valid in our image
# to enable any applicable new codes
methods_with_invalidated_source = Base.scan_new_methods(extext_methods, internal_methods)
backedges_only = unsafe_load(cglobal(:jl_first_image_replacement_world, UInt)) == typemax(UInt)
methods_with_invalidated_source = Base.scan_new_methods(extext_methods, internal_methods, backedges_only)
stack = CodeInstance[]
visiting = IdDict{CodeInstance,Int}()
_insert_backedges(edges, stack, visiting, methods_with_invalidated_source)
Expand Down
16 changes: 14 additions & 2 deletions src/ircode.c
Original file line number Diff line number Diff line change
Expand Up @@ -547,14 +547,15 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal)
}
}

static jl_code_info_flags_t code_info_flags(uint8_t propagate_inbounds, uint8_t has_fcall,
static jl_code_info_flags_t code_info_flags(uint8_t propagate_inbounds, uint8_t has_fcall, uint8_t has_image_globalref,
uint8_t nospecializeinfer, uint8_t isva,
uint8_t inlining, uint8_t constprop, uint8_t nargsmatchesmethod,
jl_array_t *ssaflags)
{
jl_code_info_flags_t flags;
flags.bits.propagate_inbounds = propagate_inbounds;
flags.bits.has_fcall = has_fcall;
flags.bits.has_image_globalref = has_image_globalref;
flags.bits.nospecializeinfer = nospecializeinfer;
flags.bits.isva = isva;
flags.bits.inlining = inlining;
Expand Down Expand Up @@ -1036,7 +1037,7 @@ JL_DLLEXPORT jl_string_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
};

uint8_t nargsmatchesmethod = code->nargs == m->nargs;
jl_code_info_flags_t flags = code_info_flags(code->propagate_inbounds, code->has_fcall,
jl_code_info_flags_t flags = code_info_flags(code->propagate_inbounds, code->has_fcall, code->has_image_globalref,
code->nospecializeinfer, code->isva,
code->inlining, code->constprop,
nargsmatchesmethod,
Expand Down Expand Up @@ -1134,6 +1135,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
code->constprop = flags.bits.constprop;
code->propagate_inbounds = flags.bits.propagate_inbounds;
code->has_fcall = flags.bits.has_fcall;
code->has_image_globalref = flags.bits.has_image_globalref;
code->nospecializeinfer = flags.bits.nospecializeinfer;
code->isva = flags.bits.isva;
code->purity.bits = read_uint16(s.s);
Expand Down Expand Up @@ -1228,6 +1230,16 @@ JL_DLLEXPORT uint8_t jl_ir_flag_has_fcall(jl_string_t *data)
return flags.bits.has_fcall;
}

JL_DLLEXPORT uint8_t jl_ir_flag_has_image_globalref(jl_string_t *data)
{
if (jl_is_code_info(data))
return ((jl_code_info_t*)data)->has_image_globalref;
assert(jl_is_string(data));
jl_code_info_flags_t flags;
flags.packed = jl_string_data(data)[ir_offset_flags];
return flags.bits.has_image_globalref;
}

JL_DLLEXPORT uint16_t jl_ir_inlining_cost(jl_string_t *data)
{
if (jl_is_code_info(data))
Expand Down
6 changes: 4 additions & 2 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3485,7 +3485,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(22,
jl_perm_symsvec(23,
"code",
"debuginfo",
"ssavaluetypes",
Expand All @@ -3502,13 +3502,14 @@ void jl_init_types(void) JL_GC_DISABLED
"nargs",
"propagate_inbounds",
"has_fcall",
"has_image_globalref",
"nospecializeinfer",
"isva",
"inlining",
"constprop",
"purity",
"inlining_cost"),
jl_svec(22,
jl_svec(23,
jl_array_any_type,
jl_debuginfo_type,
jl_any_type,
Expand All @@ -3527,6 +3528,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_uint8_type,
jl_uint8_type,
jl_uint16_type,
Expand Down
2 changes: 2 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ typedef struct _jl_code_info_t {
// various boolean properties:
uint8_t propagate_inbounds;
uint8_t has_fcall;
uint8_t has_image_globalref;
uint8_t nospecializeinfer;
uint8_t isva;
// uint8 settings
Expand Down Expand Up @@ -2263,6 +2264,7 @@ JL_DLLEXPORT jl_value_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code);
JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t *metadata, jl_value_t *data);
JL_DLLEXPORT uint8_t jl_ir_flag_inlining(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_flag_has_fcall(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_flag_has_image_globalref(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t jl_ir_inlining_cost(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT ssize_t jl_ir_nslots(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_slotflag(jl_value_t *data, size_t i) JL_NOTSAFEPOINT;
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ STATIC_INLINE jl_value_t *undefref_check(jl_datatype_t *dt, jl_value_t *v) JL_NO
typedef struct {
uint16_t propagate_inbounds:1;
uint16_t has_fcall:1;
uint16_t has_image_globalref:1;
uint16_t nospecializeinfer:1;
uint16_t isva:1;
uint16_t nargsmatchesmethod:1;
Expand Down
6 changes: 6 additions & 0 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,11 @@ jl_code_info_t *jl_new_code_info_from_ir(jl_expr_t *ir)
is_flag_stmt = 1;
else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == jl_return_sym)
jl_array_ptr_set(body, j, jl_new_struct(jl_returnnode_type, jl_exprarg(st, 0)));
else if (jl_is_globalref(st)) {
jl_globalref_t *gr = (jl_globalref_t*)st;
if (jl_object_in_image((jl_value_t*)gr->mod))
li->has_image_globalref = 1;
}
else {
if (jl_is_expr(st) && ((jl_expr_t*)st)->head == jl_assign_sym)
st = jl_exprarg(st, 1);
Expand Down Expand Up @@ -593,6 +598,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
src->max_world = ~(size_t)0;
src->propagate_inbounds = 0;
src->has_fcall = 0;
src->has_image_globalref = 0;
src->nospecializeinfer = 0;
src->constprop = 0;
src->inlining = 0;
Expand Down
9 changes: 9 additions & 0 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -1294,10 +1294,19 @@ JL_DLLEXPORT jl_binding_partition_t *jl_replace_binding_locked(jl_binding_t *b,
new_world);
}

extern JL_DLLEXPORT _Atomic(size_t) jl_first_image_replacement_world;
JL_DLLEXPORT jl_binding_partition_t *jl_replace_binding_locked2(jl_binding_t *b,
jl_binding_partition_t *old_bpart, jl_value_t *restriction_val, size_t kind, size_t new_world)
{
check_safe_newbinding(b->globalref->mod, b->globalref->name);

// Check if this is a replacing a binding in the system or a package image.
// Until the first such replacement, we can fast-path validation.
// For these purposes, we consider the `Main` module to be a non-sysimg module.
// This is legal, because we special case the `Main` in check_safe_import_from.
if (jl_object_in_image((jl_value_t*)b) && b->globalref->mod != jl_main_module && jl_atomic_load_relaxed(&jl_first_image_replacement_world) == ~(size_t)0)
jl_atomic_store_relaxed(&jl_first_image_replacement_world, new_world);

assert(jl_atomic_load_relaxed(&b->partitions) == old_bpart);
jl_atomic_store_release(&old_bpart->max_world, new_world-1);
jl_binding_partition_t *new_bpart = new_binding_partition();
Expand Down
17 changes: 10 additions & 7 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ External links:

static const size_t WORLD_AGE_REVALIDATION_SENTINEL = 0x1;
JL_DLLEXPORT size_t jl_require_world = ~(size_t)0;
JL_DLLEXPORT _Atomic(size_t) jl_first_image_replacement_world = ~(size_t)0;

#include "staticdata_utils.c"
#include "precompile_utils.c"
Expand Down Expand Up @@ -3541,7 +3542,7 @@ extern void export_jl_small_typeof(void);
int IMAGE_NATIVE_CODE_TAINTED = 0;

// TODO: This should possibly be in Julia
static int jl_validate_binding_partition(jl_binding_t *b, jl_binding_partition_t *bpart, size_t mod_idx, int unchanged_implicit)
static int jl_validate_binding_partition(jl_binding_t *b, jl_binding_partition_t *bpart, size_t mod_idx, int unchanged_implicit, int no_replacement)
{
if (jl_atomic_load_relaxed(&bpart->max_world) != ~(size_t)0)
return 1;
Expand All @@ -3556,10 +3557,13 @@ static int jl_validate_binding_partition(jl_binding_t *b, jl_binding_partition_t
if (!jl_bkind_is_some_import(kind))
return 1;
jl_binding_t *imported_binding = (jl_binding_t*)bpart->restriction;
if (no_replacement)
goto add_backedge;
jl_binding_partition_t *latest_imported_bpart = jl_atomic_load_relaxed(&imported_binding->partitions);
if (!latest_imported_bpart)
return 1;
if (latest_imported_bpart->min_world <= bpart->min_world) {
add_backedge:
// Imported binding is still valid
if ((kind == BINDING_KIND_EXPLICIT || kind == BINDING_KIND_IMPORTED) &&
external_blob_index((jl_value_t*)imported_binding) != mod_idx) {
Expand All @@ -3583,7 +3587,7 @@ static int jl_validate_binding_partition(jl_binding_t *b, jl_binding_partition_t
jl_binding_t *bedge = (jl_binding_t*)edge;
if (!jl_atomic_load_relaxed(&bedge->partitions))
continue;
jl_validate_binding_partition(bedge, jl_atomic_load_relaxed(&bedge->partitions), mod_idx, 0);
jl_validate_binding_partition(bedge, jl_atomic_load_relaxed(&bedge->partitions), mod_idx, 0, 0);
}
}
if (bpart->kind & BINDING_FLAG_EXPORTED) {
Expand All @@ -3600,7 +3604,7 @@ static int jl_validate_binding_partition(jl_binding_t *b, jl_binding_partition_t
if (!jl_atomic_load_relaxed(&importee->partitions))
continue;
JL_UNLOCK(&mod->lock);
jl_validate_binding_partition(importee, jl_atomic_load_relaxed(&importee->partitions), mod_idx, 0);
jl_validate_binding_partition(importee, jl_atomic_load_relaxed(&importee->partitions), mod_idx, 0, 0);
JL_LOCK(&mod->lock);
}
}
Expand Down Expand Up @@ -4070,22 +4074,21 @@ static void jl_restore_system_image_from_stream_(ios_t *f, jl_image_t *image, jl
}
}
if (s.incremental) {
// This needs to be done in a second pass after the binding partitions
// have the proper ABI again.
int no_replacement = jl_atomic_load_relaxed(&jl_first_image_replacement_world) == ~(size_t)0;
for (size_t i = 0; i < s.fixup_objs.len; i++) {
uintptr_t item = (uintptr_t)s.fixup_objs.items[i];
jl_value_t *obj = (jl_value_t*)(image_base + item);
if (jl_is_module(obj)) {
jl_module_t *mod = (jl_module_t*)obj;
size_t mod_idx = external_blob_index((jl_value_t*)mod);
jl_svec_t *table = jl_atomic_load_relaxed(&mod->bindings);
int unchanged_implicit = all_usings_unchanged_implicit(mod);
int unchanged_implicit = no_replacement || all_usings_unchanged_implicit(mod);
for (size_t i = 0; i < jl_svec_len(table); i++) {
jl_binding_t *b = (jl_binding_t*)jl_svecref(table, i);
if ((jl_value_t*)b == jl_nothing)
continue;
jl_binding_partition_t *bpart = jl_atomic_load_relaxed(&b->partitions);
if (!jl_validate_binding_partition(b, bpart, mod_idx, unchanged_implicit)) {
if (!jl_validate_binding_partition(b, bpart, mod_idx, unchanged_implicit, no_replacement)) {
unchanged_implicit = all_usings_unchanged_implicit(mod);
}
}
Expand Down
5 changes: 4 additions & 1 deletion stdlib/Serialization/src/Serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ const TAGS = Any[
const NTAGS = length(TAGS)
@assert NTAGS == 255

const ser_version = 29 # do not make changes without bumping the version #!
const ser_version = 30 # do not make changes without bumping the version #!

format_version(::AbstractSerializer) = ser_version
format_version(s::Serializer) = s.version
Expand Down Expand Up @@ -1268,6 +1268,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo})
if format_version(s) >= 20
ci.has_fcall = deserialize(s)
end
if format_version(s) >= 30
ci.has_image_globalref = deserialize(s)::Bool
end
if format_version(s) >= 24
ci.nospecializeinfer = deserialize(s)::Bool
end
Expand Down
49 changes: 49 additions & 0 deletions test/rebinding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,52 @@ module Regression
end
@test GeoParams57377.B.C.h() == GeoParams57377.B.C.S()
end

# Test that the validation bypass fast path is not defeated by loading InteractiveUtils
@test parse(UInt, readchomp(`$(Base.julia_cmd()) -e 'using InteractiveUtils; show(unsafe_load(cglobal(:jl_first_image_replacement_world, UInt)))'`)) == typemax(UInt)

# Test that imported module binding backedges are still added in a new module that has the fast path active
let test_code =
"""
using Test
@assert unsafe_load(cglobal(:jl_first_image_replacement_world, UInt)) == typemax(UInt)
include("precompile_utils.jl")
precompile_test_harness("rebinding precompile") do load_path
write(joinpath(load_path, "LotsOfBindingsToDelete2.jl"),
"module LotsOfBindingsToDelete2
const delete_me_6 = 6
end")
Base.compilecache(Base.PkgId("LotsOfBindingsToDelete2"))
write(joinpath(load_path, "UseTheBindings2.jl"),
"module UseTheBindings2
import LotsOfBindingsToDelete2: delete_me_6
f_use_bindings6() = delete_me_6
# Code Instances for each of these
@assert (f_use_bindings6(),) == (6,)
end")
Base.compilecache(Base.PkgId("UseTheBindings2"))
@eval using LotsOfBindingsToDelete2
@eval using UseTheBindings2
invokelatest() do
@test UseTheBindings2.f_use_bindings6() == 6
Base.delete_binding(LotsOfBindingsToDelete2, :delete_me_6)
invokelatest() do
@test_throws UndefVarError UseTheBindings2.f_use_bindings6()
end
end
end
finish_precompile_test!()
"""
@test success(pipeline(`$(Base.julia_cmd()) -e $test_code`; stderr))
end

# Image Globalref smoke test
module ImageGlobalRefFlag
using Test
@eval fimage() = $(GlobalRef(Base, :sin))
fnoimage() = x
@test Base.has_image_globalref(first(methods(fimage)))
@test !Base.has_image_globalref(first(methods(fnoimage)))
end

0 comments on commit 5e3f967

Please sign in to comment.