From b25a8a77c26e1e7093214284ce55f7f998b54caf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 7 Feb 2025 07:16:07 -0500 Subject: [PATCH 01/16] Extend `PartialStruct` with an additional `defined` field --- Compiler/src/abstractinterpretation.jl | 25 ++---- Compiler/src/tfuncs.jl | 6 +- Compiler/src/typelattice.jl | 31 +++++-- Compiler/src/typelimits.jl | 116 +++++++++++++++++++------ Compiler/test/inference.jl | 76 ++++++++++++++++ base/boot.jl | 2 +- base/coreir.jl | 40 +++++++-- src/jltypes.c | 6 +- 8 files changed, 239 insertions(+), 63 deletions(-) diff --git a/Compiler/src/abstractinterpretation.jl b/Compiler/src/abstractinterpretation.jl index 8a7e8aee715a6..95284990f3e8a 100644 --- a/Compiler/src/abstractinterpretation.jl +++ b/Compiler/src/abstractinterpretation.jl @@ -2148,23 +2148,16 @@ function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name)) isabstracttype(objt) && return nothing fldidx = try_compute_fieldidx(objt, name.val) fldidx === nothing && return nothing + isa(obj, PartialStruct) && return define_field(obj, fldidx, fieldtype(objt0, fldidx)) nminfld = datatype_min_ninitialized(objt) - if ismutabletype(objt) - # A mutable struct can have non-contiguous undefined fields, but `PartialStruct` cannot - # model such a state. So here `PartialStruct` can be used to represent only the - # objects where the field following the minimum initialized fields is also defined. - if fldidx ≠ nminfld+1 - # if it is already represented as a `PartialStruct`, we can add one more - # `isdefined`-field information on top of those implied by its `fields` - if !(obj isa PartialStruct && fldidx == length(obj.fields)+1) - return nothing - end - end - else - fldidx > nminfld || return nothing - end - return PartialStruct(fallback_lattice, objt0, Any[obj isa PartialStruct && i≤length(obj.fields) ? - obj.fields[i] : fieldtype(objt0,i) for i = 1:fldidx]) + fldidx > nminfld || return nothing + fields = Any[fieldtype(objt0, i) for i in 1:nminfld] + nmaxfld = something(datatype_fieldcount(objt), fldidx) + defined = falses(nmaxfld) + for i in 1:nminfld defined[i] = true end + defined[fldidx] = true + push!(fields, fieldtype(objt0, fldidx)) + return PartialStruct(fallback_lattice, objt0, defined, fields) end function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta) diff --git a/Compiler/src/tfuncs.jl b/Compiler/src/tfuncs.jl index 50b88bb0222ce..3ffcec66fbb85 100644 --- a/Compiler/src/tfuncs.jl +++ b/Compiler/src/tfuncs.jl @@ -439,7 +439,7 @@ end end elseif isa(arg1, PartialStruct) if !isvarargtype(arg1.fields[end]) - if 1 ≤ idx ≤ length(arg1.fields) + if is_field_defined(arg1, idx) return Const(true) end end @@ -1143,8 +1143,8 @@ end sty = unwrap_unionall(s)::DataType if isa(name, Const) nv = _getfield_fieldindex(sty, name) - if isa(nv, Int) && 1 <= nv <= length(s00.fields) - return unwrapva(s00.fields[nv]) + if isa(nv, Int) && is_field_defined(s00, nv) + return unwrapva(get_defined_field(s00, nv)) end end s00 = s diff --git a/Compiler/src/typelattice.jl b/Compiler/src/typelattice.jl index 6f7612b836c89..a9322268777fa 100644 --- a/Compiler/src/typelattice.jl +++ b/Compiler/src/typelattice.jl @@ -431,9 +431,17 @@ end return false end end - for i in 1:length(b.fields) - af = a.fields[i] - bf = b.fields[i] + length(a.defined) ≥ length(b.defined) || return false + n = length(b.defined) + ai = bi = 0 + for i in 1:n + ai += a.defined[i] + bi += b.defined[i] + !a.defined[i] && b.defined[i] && return false + !b.defined[i] && continue + # Field is defined for both `a` and `b` + af = a.fields[ai] + bf = b.fields[bi] if i == length(b.fields) if isvarargtype(af) # If `af` is vararg, so must bf by the <: above @@ -468,10 +476,11 @@ end n_initialized(a) ≥ length(b.fields) || return false end nf = nfields(a.val) + bi = 0 for i in 1:nf - isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T - i > length(b.fields) && break # `a` has more information than `b` that is partially initialized struct - bfᵢ = b.fields[i] + !isdefined(a.val, i) && b.defined[i] && return false + !b.defined[i] && continue + bfᵢ = b.fields[bi += 1] if i == nf bfᵢ = unwrapva(bfᵢ) end @@ -541,6 +550,7 @@ end if isa(a, PartialStruct) isa(b, PartialStruct) || return false length(a.fields) == length(b.fields) || return false + a.defined == b.defined || return false widenconst(a) == widenconst(b) || return false a.fields === b.fields && return true # fast path for i in 1:length(a.fields) @@ -751,5 +761,12 @@ function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), fields::Vecto for i = 1:length(fields) assert_nested_slotwrapper(fields[i]) end - return Core._PartialStruct(typ, fields) + return PartialStruct(typ, fields) +end + +function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), defined::BitVector, fields::Vector{Any}) + for i = 1:length(fields) + assert_nested_slotwrapper(fields[i]) + end + return Core._PartialStruct(typ, defined, fields) end diff --git a/Compiler/src/typelimits.jl b/Compiler/src/typelimits.jl index 536b5fb34d1b1..c7210bbc35569 100644 --- a/Compiler/src/typelimits.jl +++ b/Compiler/src/typelimits.jl @@ -326,6 +326,71 @@ function n_initialized(t::Const) return something(findfirst(i::Int->!isdefined(t.val,i), 1:nf), nf+1)-1 end +defined_fields(pstruct::PartialStruct) = pstruct.defined + +function defined_field_index(pstruct::PartialStruct, fi) + i = 0 + for iter in 1:fi + iter ≤ length(pstruct.defined) && pstruct.defined[iter] && (i += 1) + end + i +end + +get_defined_field(pstruct::PartialStruct, fi) = pstruct.fields[defined_field_index(pstruct, fi)] +is_field_defined(pstruct::PartialStruct, fi) = get(pstruct.defined, fi, false) + +function define_field(pstruct::PartialStruct, fi, @nospecialize(ft)) + n = length(pstruct.defined) + if fi ≤ n && pstruct.defined[fi] + # XXX: merge new information? + # `setfield!(..., rand()); setfield!(..., 2.0)` + return nothing + end + defined = falses(max(fi, n)) + for i in 1:n + defined[i] = pstruct.defined[i] + end + fields = copy(pstruct.fields) + defined[fi] = true + i = defined_field_index(pstruct, fi) + insert!(fields, i + 1, ft) + PartialStruct(fallback_lattice, pstruct.typ, defined, fields) +end + +# needed while we are missing functions such as broadcasting or ranges + +function _bitvector(nt::NTuple) + bv = BitVector(undef, length(nt)) + i = 1 + while i ≤ length(nt) + bv[i] = nt[i] + i += 1 + end + bv +end + +function _count(bv::BitVector) + n = 0 + for val in bv + n += val + end + n +end + +#- + +function defined_fields(t::Const) + nf = nfields(t.val) + _bitvector(ntuple(i -> isdefined(t.val, i), nf)) +end + +function defined_fields(x, y) + xdef = defined_fields(x) + ydef = defined_fields(y) + n = min(length(xdef), length(ydef)) + _bitvector(ntuple(i -> xdef[i] & ydef[i], n)) +end + # A simplified type_more_complex query over the extended lattice # (assumes typeb ⊑ typea) @nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb)) @@ -333,15 +398,14 @@ end typea === typeb && return true if typea isa PartialStruct aty = widenconst(typea) - if typeb isa Const - @assert length(typea.fields) ≤ n_initialized(typeb) "typeb ⊑ typea is assumed" - elseif typeb isa PartialStruct - @assert length(typea.fields) ≤ length(typeb.fields) "typeb ⊑ typea is assumed" - else - return false - end - for i = 1:length(typea.fields) - ai = unwrapva(typea.fields[i]) + isa(typeb, Const) || isa(typeb, PartialStruct) || return false + @assert all(x & y == x for (x, y) in zip(defined_fields(typea), defined_fields(typeb))) "typeb ⊑ typea is assumed" + fi = 0 + nf = length(typea.defined) + for i = 1:nf + typea.defined[i] || continue + fi += 1 + ai = unwrapva(typea.fields[fi]) bi = fieldtype(aty, i) is_lattice_equal(𝕃, ai, bi) && continue tni = _typename(widenconst(ai)) @@ -588,21 +652,15 @@ end aty = widenconst(typea) bty = widenconst(typeb) if aty === bty && !isType(aty) - if typea isa PartialStruct - if typeb isa PartialStruct - nflds = min(length(typea.fields), length(typeb.fields)) - else - nflds = min(length(typea.fields), n_initialized(typeb::Const)) - end - elseif typeb isa PartialStruct - nflds = min(n_initialized(typea::Const), length(typeb.fields)) - else - nflds = min(n_initialized(typea::Const), n_initialized(typeb::Const)) - end - nflds == 0 && return nothing - fields = Vector{Any}(undef, nflds) - anyrefine = nflds > datatype_min_ninitialized(aty) - for i = 1:nflds + typea::Union{PartialStruct, Const} + typeb::Union{PartialStruct, Const} + defined = defined_fields(typea, typeb) + ndefined = _count(defined) + ndefined == 0 && return nothing + fields = [] + anyrefine = ndefined > datatype_min_ninitialized(aty) + for (i, def) in enumerate(defined) + def || continue ai = getfield_tfunc(𝕃, typea, Const(i)) bi = getfield_tfunc(𝕃, typeb, Const(i)) ft = fieldtype(aty, i) @@ -632,13 +690,19 @@ end tyi = ft end end - fields[i] = tyi + push!(fields, tyi) if !anyrefine anyrefine = has_nontrivial_extended_info(𝕃, tyi) || # extended information ⋤(𝕃, tyi, ft) # just a type-level information, but more precise than the declared type end end - anyrefine && return PartialStruct(𝕃, aty, fields) + if isa(typea, PartialStruct) && isa(typeb, PartialStruct) && + isvarargtype(typea.fields[end]) && isvarargtype(typeb.fields[end]) + # XXX: If it may be more precise than `Vararg` (e.g. `Vararg{T}`), + # handle that in the main loop above to get a more accurate type. + push!(fields, Vararg) + end + anyrefine && return PartialStruct(𝕃, aty, defined, fields) end return nothing end diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index 563828ac77296..a5d6f754048d7 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -4783,10 +4783,36 @@ end @test a == Tuple end +module _Partials_inference + mutable struct Partial + x::String + y::Integer + z::Any + Partial(args...) = new(args...) + end + + struct Partial2 + x::String + y::Integer + z::Any + Partial2(args...) = new(args...) + end + + struct Partial3 + x::Int + y::String + z::Float64 + Partial3(args...) = new(args...) + end +end + let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) ⊔ = Compiler.join(Compiler.fallback_lattice) 𝕃 = Compiler.fallback_lattice Const, PartialStruct = Core.Const, Core.PartialStruct + form_partially_defined_struct = Compiler.form_partially_defined_struct + M = Partials_inference + Partial, Partial2, Partial3 = M.Partial, M.Partial2, M.Partial3 @test (Const((1,2)) ⊑ PartialStruct(𝕃, Tuple{Int,Int}, Any[Const(1),Int])) @test !(Const((1,2)) ⊑ PartialStruct(𝕃, Tuple{Int,Int,Int}, Any[Const(1),Int,Int])) @@ -4807,6 +4833,56 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) @test t isa PartialStruct && length(t.fields) == 2 && t.fields[1] === Const(false) t = t ⊔ Const((false, false, 0)) @test t ⊑ Union{Tuple{Bool,Bool},Tuple{Bool,Bool,Int}} + + t = PartialStruct(𝕃, Tuple{Int, Int}, Any[Const(1), Int]) + @test t.defined == [true, true] + t = PartialStruct(𝕃, Partial, Any[String, Const(2)]) + @test t.defined == [true, true, false] + @test t ⊑ t && t ⊔ t === t + + t1 = PartialStruct(𝕃, Partial, Any[String, Const(3)]) + t2 = PartialStruct(𝕃, Partial, Any[Const("x"), Int]) + @test !(t1 ⊑ t2) && !(t2 ⊑ t1) + t3 = t1 ⊔ t2 + @test t3.fields == Any[String, Int] + + t1 = PartialStruct(𝕃, Partial, BitVector([false, true, true]), Any[Int, Const(3)]) + @test t1 ⊑ t1 && t1 ⊔ t1 === t1 + t2 = PartialStruct(𝕃, Partial, BitVector([true, false, true]), Any[Const("x"), Int]) + t3 = t1 ⊔ t2 + @test t3.defined == [false, false, true] && t3.fields == Any[Int] + + t1 = PartialStruct(𝕃, Tuple, Any[Int, String, Vararg]) + @test t1.defined == [true, true] + @test t1 ⊑ t1 && t1 ⊔ t1 == t1 + t2 = PartialStruct(𝕃, Tuple, Any[Int, Any]) + @test !(t1 ⊑ t2) && !(t2 ⊑ t1) + t3 = t1 ⊔ t2 + @test t3.defined == [true, true] && t3.fields == Any[Int, Any] + t2 = PartialStruct(𝕃, Tuple, Any[Int, Any, Vararg]) + @test t1 ⊑ t2 + @test t1 ⊔ t2 === t2 + + t = PartialStruct(𝕃, Partial, Any[String, Const(2)]) + @test form_partially_defined_struct(t, Const(:x)) === nothing + t′ = form_partially_defined_struct(t, Const(:z)) + @test t′ == PartialStruct(𝕃, Partial, Any[String, Const(2), Any]) + + t = PartialStruct(𝕃, Partial2, Any[String, Const(2)]) + @test form_partially_defined_struct(t, Const(:x)) === nothing + t′ = form_partially_defined_struct(t, Const(:z)) + @test t′ == PartialStruct(𝕃, Partial2, Any[String, Const(2), Any]) + + @test form_partially_defined_struct(Partial3, Const(:x)) === nothing + t = form_partially_defined_struct(Partial3, Const(:y)) + @test t == PartialStruct(𝕃, Partial3, Any[Int, String]) + t = form_partially_defined_struct(Partial3, Const(:z)) + @test t == PartialStruct(𝕃, Partial3, BitVector([true, false, true]), Any[Int, Float64]) + t = form_partially_defined_struct(t, Const(:y)) + @test t == PartialStruct(𝕃, Partial3, Any[Int, String, Float64]) + t = PartialStruct(𝕃, Partial3, Any[Int, String]) + t′ = form_partially_defined_struct(t, Const(:z)) + @test t′ == PartialStruct(𝕃, Partial3, Any[Int, String, Float64]) end # Test that a function-wise `@max_methods` works as expected diff --git a/base/boot.jl b/base/boot.jl index 26a405f92f884..a4f955001a570 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -540,7 +540,7 @@ eval(Core, quote UpsilonNode(@nospecialize(val)) = $(Expr(:new, :UpsilonNode, :val)) UpsilonNode() = $(Expr(:new, :UpsilonNode)) Const(@nospecialize(v)) = $(Expr(:new, :Const, :v)) - _PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields)) + _PartialStruct(@nospecialize(typ), defined, fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :defined, :fields)) PartialOpaque(@nospecialize(typ), @nospecialize(env), parent::MethodInstance, source) = $(Expr(:new, :PartialOpaque, :typ, :env, :parent, :source)) InterConditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :thentype, :elsetype)) MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) = $(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers)) diff --git a/base/coreir.jl b/base/coreir.jl index 5199dfd35f028..6b7745342272d 100644 --- a/base/coreir.jl +++ b/base/coreir.jl @@ -14,7 +14,8 @@ Core.Const """ struct PartialStruct typ - fields::Vector{Any} # elements are other type lattice members + defined::BitVector # sorted list of fields that are known to be defined + fields::Vector{Any} # i-th element describes the lattice element for the i-th defined field end This extended lattice element is introduced when we have information about an object's @@ -23,19 +24,44 @@ some elements are known to be constants or a struct whose `Any`-typed field is i with `Int` values. - `typ` indicates the type of the object -- `fields` holds the lattice elements corresponding to each field of the object +- `defined` records which fields are defined +- `fields` holds the lattice elements corresponding to each defined field of the object -If `typ` is a struct, `fields` represents the fields of the struct that are guaranteed to be -initialized. For instance, if the length of `fields` of `PartialStruct` representing a -struct with 4 fields is 3, the 4th field may not be initialized. If the length is 4, all -fields are guaranteed to be initialized. +If `typ` is a struct, `defined` represents whether the corresponding field of the struct is guaranteed to be +initialized. For any defined field (`defined[i] === true`), there is a corresponding `fields` element +which provides information about the type of the defined field. If `typ` is a tuple, the last element of `fields` may be `Vararg`. In this case, it is guaranteed that the number of elements in the tuple is at least `length(fields)-1`, but the -exact number of elements is unknown. +exact number of elements is unknown (`defined` then has a length of `length(fields)-1`). """ Core.PartialStruct +function Core.PartialStruct(@nospecialize(typ), fields::Vector{Any}) + ndefined = lastindex(fields) + fields[end] === Vararg && (ndefined -= 1) + t = typ + (isa(t, UnionAll) || isa(t, Union)) && (t = argument_datatype(t)) + nfields = isa(t, DataType) ? datatype_fieldcount(t) : nothing + if nfields === nothing || nfields == ndefined + defined = trues(ndefined) + else + @assert nfields ≥ ndefined + defined = falses(nfields) + for i in 1:ndefined + defined[i] = true + end + end + Core._PartialStruct(typ, defined, fields) +end + +(==)(a::PartialStruct, b::PartialStruct) = a.typ === b.typ && a.defined == b.defined && a.fields == b.fields + +function Base.getproperty(pstruct::Core.PartialStruct, name::Symbol) + name === :defined && return getfield(pstruct, :defined)::BitVector + getfield(pstruct, name) +end + """ struct InterConditional slot::Int diff --git a/src/jltypes.c b/src/jltypes.c index dc7422460c0cf..7c3edf07d9164 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -3691,9 +3691,9 @@ void jl_init_types(void) JL_GC_DISABLED jl_emptysvec, 0, 0, 1); jl_partial_struct_type = jl_new_datatype(jl_symbol("PartialStruct"), core, jl_any_type, jl_emptysvec, - jl_perm_symsvec(2, "typ", "fields"), - jl_svec2(jl_any_type, jl_array_any_type), - jl_emptysvec, 0, 0, 2); + jl_perm_symsvec(3, "typ", "defined", "fields"), + jl_svec(3, jl_any_type, jl_any_type, jl_array_any_type), + jl_emptysvec, 0, 0, 3); jl_interconditional_type = jl_new_datatype(jl_symbol("InterConditional"), core, jl_any_type, jl_emptysvec, jl_perm_symsvec(3, "slot", "thentype", "elsetype"), From 67b12f81f2dbb626d339f38cd194e62286c0dd10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 7 Feb 2025 07:21:01 -0500 Subject: [PATCH 02/16] Add constprop test --- Compiler/test/irpasses.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/Compiler/test/irpasses.jl b/Compiler/test/irpasses.jl index 0d0b0e4daf83e..22910b49883ed 100644 --- a/Compiler/test/irpasses.jl +++ b/Compiler/test/irpasses.jl @@ -2042,3 +2042,23 @@ let src = code_typed1(()) do end @test count(iscall((src, setfield!)), src.code) == 1 end + +module _Partials_irpasses + mutable struct Partial + x::String + y::Integer + z::Any + Partial() = new() + end +end + +# once `isdefined(p, name)` holds, this information should be kept +# as a `PartialStruct` over `p` for subsequent constant propagation. +let src = code_typed1(()) do + p = _Partials_irpasses.Partial() + invokelatest(identity, p) + isdefined(p, :z) && isdefined(p, :x) || return nothing + isdefined(p, :x) & isdefined(p, :z) + end + @test count(iscall((src, isdefined)), src.code) == 2 +end From 58827f0a707d758876392fa1b254557ef59ad35b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 7 Feb 2025 07:44:38 -0500 Subject: [PATCH 03/16] Satisfy typo checker --- Compiler/src/typelimits.jl | 6 +++--- base/coreir.jl | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Compiler/src/typelimits.jl b/Compiler/src/typelimits.jl index c7210bbc35569..68c4ef8825b7a 100644 --- a/Compiler/src/typelimits.jl +++ b/Compiler/src/typelimits.jl @@ -655,10 +655,10 @@ end typea::Union{PartialStruct, Const} typeb::Union{PartialStruct, Const} defined = defined_fields(typea, typeb) - ndefined = _count(defined) - ndefined == 0 && return nothing + ndef = _count(defined) + ndef == 0 && return nothing fields = [] - anyrefine = ndefined > datatype_min_ninitialized(aty) + anyrefine = ndef > datatype_min_ninitialized(aty) for (i, def) in enumerate(defined) def || continue ai = getfield_tfunc(𝕃, typea, Const(i)) diff --git a/base/coreir.jl b/base/coreir.jl index 6b7745342272d..f4cc87661a158 100644 --- a/base/coreir.jl +++ b/base/coreir.jl @@ -38,17 +38,17 @@ exact number of elements is unknown (`defined` then has a length of `length(fiel Core.PartialStruct function Core.PartialStruct(@nospecialize(typ), fields::Vector{Any}) - ndefined = lastindex(fields) - fields[end] === Vararg && (ndefined -= 1) + ndef = lastindex(fields) + fields[end] === Vararg && (ndef -= 1) t = typ (isa(t, UnionAll) || isa(t, Union)) && (t = argument_datatype(t)) nfields = isa(t, DataType) ? datatype_fieldcount(t) : nothing - if nfields === nothing || nfields == ndefined - defined = trues(ndefined) + if nfields === nothing || nfields == ndef + defined = trues(ndef) else - @assert nfields ≥ ndefined + @assert nfields ≥ ndef defined = falses(nfields) - for i in 1:ndefined + for i in 1:ndef defined[i] = true end end From 1274f23a8a644a2ed901fd8246ab7bc80130a3cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 7 Feb 2025 07:46:19 -0500 Subject: [PATCH 04/16] Remove trailing whitespace --- Compiler/test/irpasses.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/test/irpasses.jl b/Compiler/test/irpasses.jl index 22910b49883ed..91348ca1203c5 100644 --- a/Compiler/test/irpasses.jl +++ b/Compiler/test/irpasses.jl @@ -2051,7 +2051,7 @@ module _Partials_irpasses Partial() = new() end end - + # once `isdefined(p, name)` holds, this information should be kept # as a `PartialStruct` over `p` for subsequent constant propagation. let src = code_typed1(()) do From ede6dc77905c208c8d64bb153e01637bd150ace2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 7 Feb 2025 09:15:46 -0500 Subject: [PATCH 05/16] Fix binding access --- Compiler/test/inference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index a5d6f754048d7..6a55a62c74c2d 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -4811,7 +4811,7 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) 𝕃 = Compiler.fallback_lattice Const, PartialStruct = Core.Const, Core.PartialStruct form_partially_defined_struct = Compiler.form_partially_defined_struct - M = Partials_inference + M = _Partials_inference Partial, Partial2, Partial3 = M.Partial, M.Partial2, M.Partial3 @test (Const((1,2)) ⊑ PartialStruct(𝕃, Tuple{Int,Int}, Any[Const(1),Int])) From f361ab2818d167c2f7193e39ecd72cf3779da6f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 7 Feb 2025 09:23:11 -0500 Subject: [PATCH 06/16] Retrigger tests From 67b916839a083545315530e5c6f17df2c1ade574 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Tue, 11 Feb 2025 08:22:50 -0500 Subject: [PATCH 07/16] defined -> undef --- Compiler/src/abstractinterpretation.jl | 8 ++++---- Compiler/src/typelattice.jl | 22 +++++++++++----------- Compiler/src/typelimits.jl | 24 ++++++++++++------------ Compiler/test/inference.jl | 16 ++++++++-------- base/boot.jl | 2 +- base/coreir.jl | 22 +++++++++++----------- src/jltypes.c | 2 +- 7 files changed, 48 insertions(+), 48 deletions(-) diff --git a/Compiler/src/abstractinterpretation.jl b/Compiler/src/abstractinterpretation.jl index 973543c95b18c..ff5a2b1de2e1c 100644 --- a/Compiler/src/abstractinterpretation.jl +++ b/Compiler/src/abstractinterpretation.jl @@ -2153,11 +2153,11 @@ function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name)) fldidx > nminfld || return nothing fields = Any[fieldtype(objt0, i) for i in 1:nminfld] nmaxfld = something(datatype_fieldcount(objt), fldidx) - defined = falses(nmaxfld) - for i in 1:nminfld defined[i] = true end - defined[fldidx] = true + undef = trues(nmaxfld) + for i in 1:nminfld undef[i] = false end + undef[fldidx] = false push!(fields, fieldtype(objt0, fldidx)) - return PartialStruct(fallback_lattice, objt0, defined, fields) + return PartialStruct(fallback_lattice, objt0, undef, fields) end function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta) diff --git a/Compiler/src/typelattice.jl b/Compiler/src/typelattice.jl index a9322268777fa..daf169a77195d 100644 --- a/Compiler/src/typelattice.jl +++ b/Compiler/src/typelattice.jl @@ -431,14 +431,14 @@ end return false end end - length(a.defined) ≥ length(b.defined) || return false - n = length(b.defined) + length(a.undef) ≥ length(b.undef) || return false + n = length(b.undef) ai = bi = 0 for i in 1:n - ai += a.defined[i] - bi += b.defined[i] - !a.defined[i] && b.defined[i] && return false - !b.defined[i] && continue + ai += !a.undef[i] + bi += !b.undef[i] + a.undef[i] && !b.undef[i] && return false + b.undef[i] && continue # Field is defined for both `a` and `b` af = a.fields[ai] bf = b.fields[bi] @@ -478,8 +478,8 @@ end nf = nfields(a.val) bi = 0 for i in 1:nf - !isdefined(a.val, i) && b.defined[i] && return false - !b.defined[i] && continue + !isdefined(a.val, i) && !b.undef[i] && return false + b.undef[i] && continue bfᵢ = b.fields[bi += 1] if i == nf bfᵢ = unwrapva(bfᵢ) @@ -550,7 +550,7 @@ end if isa(a, PartialStruct) isa(b, PartialStruct) || return false length(a.fields) == length(b.fields) || return false - a.defined == b.defined || return false + a.undef == b.undef || return false widenconst(a) == widenconst(b) || return false a.fields === b.fields && return true # fast path for i in 1:length(a.fields) @@ -764,9 +764,9 @@ function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), fields::Vecto return PartialStruct(typ, fields) end -function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), defined::BitVector, fields::Vector{Any}) +function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), undef::BitVector, fields::Vector{Any}) for i = 1:length(fields) assert_nested_slotwrapper(fields[i]) end - return Core._PartialStruct(typ, defined, fields) + return Core._PartialStruct(typ, undef, fields) end diff --git a/Compiler/src/typelimits.jl b/Compiler/src/typelimits.jl index 68c4ef8825b7a..ef03b1a238d15 100644 --- a/Compiler/src/typelimits.jl +++ b/Compiler/src/typelimits.jl @@ -326,35 +326,35 @@ function n_initialized(t::Const) return something(findfirst(i::Int->!isdefined(t.val,i), 1:nf), nf+1)-1 end -defined_fields(pstruct::PartialStruct) = pstruct.defined +defined_fields(pstruct::PartialStruct) = _bitvector(ntuple(i -> !pstruct.undef[i], length(pstruct.undef))) function defined_field_index(pstruct::PartialStruct, fi) i = 0 for iter in 1:fi - iter ≤ length(pstruct.defined) && pstruct.defined[iter] && (i += 1) + iter ≤ length(pstruct.undef) && !pstruct.undef[iter] && (i += 1) end i end get_defined_field(pstruct::PartialStruct, fi) = pstruct.fields[defined_field_index(pstruct, fi)] -is_field_defined(pstruct::PartialStruct, fi) = get(pstruct.defined, fi, false) +is_field_defined(pstruct::PartialStruct, fi) = !get(pstruct.undef, fi, true) function define_field(pstruct::PartialStruct, fi, @nospecialize(ft)) - n = length(pstruct.defined) - if fi ≤ n && pstruct.defined[fi] + n = length(pstruct.undef) + if fi ≤ n && !pstruct.undef[fi] # XXX: merge new information? # `setfield!(..., rand()); setfield!(..., 2.0)` return nothing end - defined = falses(max(fi, n)) + undef = trues(max(fi, n)) for i in 1:n - defined[i] = pstruct.defined[i] + undef[i] = pstruct.undef[i] end fields = copy(pstruct.fields) - defined[fi] = true + undef[fi] = false i = defined_field_index(pstruct, fi) insert!(fields, i + 1, ft) - PartialStruct(fallback_lattice, pstruct.typ, defined, fields) + PartialStruct(fallback_lattice, pstruct.typ, undef, fields) end # needed while we are missing functions such as broadcasting or ranges @@ -401,9 +401,9 @@ end isa(typeb, Const) || isa(typeb, PartialStruct) || return false @assert all(x & y == x for (x, y) in zip(defined_fields(typea), defined_fields(typeb))) "typeb ⊑ typea is assumed" fi = 0 - nf = length(typea.defined) + nf = length(typea.undef) for i = 1:nf - typea.defined[i] || continue + !typea.undef[i] || continue fi += 1 ai = unwrapva(typea.fields[fi]) bi = fieldtype(aty, i) @@ -702,7 +702,7 @@ end # handle that in the main loop above to get a more accurate type. push!(fields, Vararg) end - anyrefine && return PartialStruct(𝕃, aty, defined, fields) + anyrefine && return PartialStruct(𝕃, aty, _bitvector(ntuple(i -> !defined[i], length(defined))), fields) end return nothing end diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index b20f97f44f415..e3159e13bbda8 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -4835,9 +4835,9 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) @test t ⊑ Union{Tuple{Bool,Bool},Tuple{Bool,Bool,Int}} t = PartialStruct(𝕃, Tuple{Int, Int}, Any[Const(1), Int]) - @test t.defined == [true, true] + @test t.undef == [false, false] t = PartialStruct(𝕃, Partial, Any[String, Const(2)]) - @test t.defined == [true, true, false] + @test t.undef == [false, false, true] @test t ⊑ t && t ⊔ t === t t1 = PartialStruct(𝕃, Partial, Any[String, Const(3)]) @@ -4846,19 +4846,19 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) t3 = t1 ⊔ t2 @test t3.fields == Any[String, Int] - t1 = PartialStruct(𝕃, Partial, BitVector([false, true, true]), Any[Int, Const(3)]) + t1 = PartialStruct(𝕃, Partial, BitVector([true, false, false]), Any[Int, Const(3)]) @test t1 ⊑ t1 && t1 ⊔ t1 === t1 - t2 = PartialStruct(𝕃, Partial, BitVector([true, false, true]), Any[Const("x"), Int]) + t2 = PartialStruct(𝕃, Partial, BitVector([false, true, false]), Any[Const("x"), Int]) t3 = t1 ⊔ t2 - @test t3.defined == [false, false, true] && t3.fields == Any[Int] + @test t3.undef == [true, true, false] && t3.fields == Any[Int] t1 = PartialStruct(𝕃, Tuple, Any[Int, String, Vararg]) - @test t1.defined == [true, true] + @test t1.undef == [false, false] @test t1 ⊑ t1 && t1 ⊔ t1 == t1 t2 = PartialStruct(𝕃, Tuple, Any[Int, Any]) @test !(t1 ⊑ t2) && !(t2 ⊑ t1) t3 = t1 ⊔ t2 - @test t3.defined == [true, true] && t3.fields == Any[Int, Any] + @test t3.undef == [false, false] && t3.fields == Any[Int, Any] t2 = PartialStruct(𝕃, Tuple, Any[Int, Any, Vararg]) @test t1 ⊑ t2 @test t1 ⊔ t2 === t2 @@ -4877,7 +4877,7 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) t = form_partially_defined_struct(Partial3, Const(:y)) @test t == PartialStruct(𝕃, Partial3, Any[Int, String]) t = form_partially_defined_struct(Partial3, Const(:z)) - @test t == PartialStruct(𝕃, Partial3, BitVector([true, false, true]), Any[Int, Float64]) + @test t == PartialStruct(𝕃, Partial3, BitVector([false, true, false]), Any[Int, Float64]) t = form_partially_defined_struct(t, Const(:y)) @test t == PartialStruct(𝕃, Partial3, Any[Int, String, Float64]) t = PartialStruct(𝕃, Partial3, Any[Int, String]) diff --git a/base/boot.jl b/base/boot.jl index a4f955001a570..dab1904bb64f7 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -540,7 +540,7 @@ eval(Core, quote UpsilonNode(@nospecialize(val)) = $(Expr(:new, :UpsilonNode, :val)) UpsilonNode() = $(Expr(:new, :UpsilonNode)) Const(@nospecialize(v)) = $(Expr(:new, :Const, :v)) - _PartialStruct(@nospecialize(typ), defined, fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :defined, :fields)) + _PartialStruct(@nospecialize(typ), undef, fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :undef, :fields)) PartialOpaque(@nospecialize(typ), @nospecialize(env), parent::MethodInstance, source) = $(Expr(:new, :PartialOpaque, :typ, :env, :parent, :source)) InterConditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :thentype, :elsetype)) MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) = $(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers)) diff --git a/base/coreir.jl b/base/coreir.jl index 8c789e3c69c40..8e95517b060c2 100644 --- a/base/coreir.jl +++ b/base/coreir.jl @@ -14,7 +14,7 @@ Core.Const """ struct PartialStruct typ - defined::BitVector # sorted list of fields that are known to be defined + undef::BitVector # represents whether a given field may be undefined fields::Vector{Any} # i-th element describes the lattice element for the i-th defined field end @@ -24,16 +24,16 @@ some elements are known to be constants or a struct whose `Any`-typed field is i with `Int` values. - `typ` indicates the type of the object -- `defined` records which fields are defined +- `undef` records which fields are possibly undefined - `fields` holds the lattice elements corresponding to each defined field of the object -If `typ` is a struct, `defined` represents whether the corresponding field of the struct is guaranteed to be -initialized. For any defined field (`defined[i] === true`), there is a corresponding `fields` element +If `typ` is a struct, `undef` represents whether the corresponding field of the struct is guaranteed to be +initialized. For any defined field (`undef[i] === false`), there is a corresponding `fields` element which provides information about the type of the defined field. If `typ` is a tuple, the last element of `fields` may be `Vararg`. In this case, it is guaranteed that the number of elements in the tuple is at least `length(fields)-1`, but the -exact number of elements is unknown (`defined` then has a length of `length(fields)-1`). +exact number of elements is unknown (`undef` then has a length of `length(fields)-1`). """ Core.PartialStruct @@ -44,21 +44,21 @@ function Core.PartialStruct(@nospecialize(typ), fields::Vector{Any}) (isa(t, UnionAll) || isa(t, Union)) && (t = argument_datatype(t)) nfields = isa(t, DataType) ? datatype_fieldcount(t) : nothing if nfields === nothing || nfields == ndef - defined = trues(ndef) + undef = falses(ndef) else @assert nfields ≥ ndef - defined = falses(nfields) + undef = trues(nfields) for i in 1:ndef - defined[i] = true + undef[i] = false end end - Core._PartialStruct(typ, defined, fields) + Core._PartialStruct(typ, undef, fields) end -(==)(a::PartialStruct, b::PartialStruct) = a.typ === b.typ && a.defined == b.defined && a.fields == b.fields +(==)(a::PartialStruct, b::PartialStruct) = a.typ === b.typ && a.undef == b.undef && a.fields == b.fields function Base.getproperty(pstruct::Core.PartialStruct, name::Symbol) - name === :defined && return getfield(pstruct, :defined)::BitVector + name === :undef && return getfield(pstruct, :undef)::BitVector getfield(pstruct, name) end diff --git a/src/jltypes.c b/src/jltypes.c index 7c3edf07d9164..98294038f0395 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -3691,7 +3691,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_emptysvec, 0, 0, 1); jl_partial_struct_type = jl_new_datatype(jl_symbol("PartialStruct"), core, jl_any_type, jl_emptysvec, - jl_perm_symsvec(3, "typ", "defined", "fields"), + jl_perm_symsvec(3, "typ", "undef", "fields"), jl_svec(3, jl_any_type, jl_any_type, jl_array_any_type), jl_emptysvec, 0, 0, 3); From 3857bafe73c176005dc3fbee25d1dca593392ea7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 14 Feb 2025 05:19:11 -0500 Subject: [PATCH 08/16] Record undef & field information for all PartialStruct fields --- Compiler/src/Compiler.jl | 2 +- Compiler/src/abstractinterpretation.jl | 11 ++-- Compiler/src/tfuncs.jl | 2 +- Compiler/src/typelattice.jl | 43 ++++++------ Compiler/src/typelimits.jl | 90 +++++++++++++------------- Compiler/src/typeutils.jl | 33 ---------- Compiler/test/inference.jl | 13 ++-- base/coreir.jl | 48 ++++++++++---- base/essentials.jl | 33 ++++++++++ 9 files changed, 149 insertions(+), 126 deletions(-) diff --git a/Compiler/src/Compiler.jl b/Compiler/src/Compiler.jl index 2c68729ee1dc2..394039c75877f 100644 --- a/Compiler/src/Compiler.jl +++ b/Compiler/src/Compiler.jl @@ -67,7 +67,7 @@ using Base: @_foldable_meta, @_gc_preserve_begin, @_gc_preserve_end, @nospeciali partition_restriction, quoted, rename_unionall, rewrap_unionall, specialize_method, structdiff, tls_world_age, unconstrain_vararg_length, unionlen, uniontype_layout, uniontypes, unsafe_convert, unwrap_unionall, unwrapva, vect, widen_diagonal, - _uncompressed_ir, maybe_add_binding_backedge! + _uncompressed_ir, maybe_add_binding_backedge!, datatype_min_ninitialized using Base.Order import Base: ==, _topmod, append!, convert, copy, copy!, findall, first, get, get!, diff --git a/Compiler/src/abstractinterpretation.jl b/Compiler/src/abstractinterpretation.jl index ff5a2b1de2e1c..9f774637e70f8 100644 --- a/Compiler/src/abstractinterpretation.jl +++ b/Compiler/src/abstractinterpretation.jl @@ -2151,12 +2151,10 @@ function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name)) isa(obj, PartialStruct) && return define_field(obj, fldidx, fieldtype(objt0, fldidx)) nminfld = datatype_min_ninitialized(objt) fldidx > nminfld || return nothing - fields = Any[fieldtype(objt0, i) for i in 1:nminfld] + fields = collect(Any, fieldtypes(objt0)) nmaxfld = something(datatype_fieldcount(objt), fldidx) undef = trues(nmaxfld) - for i in 1:nminfld undef[i] = false end undef[fldidx] = false - push!(fields, fieldtype(objt0, fldidx)) return PartialStruct(fallback_lattice, objt0, undef, fields) end @@ -3137,7 +3135,7 @@ function abstract_eval_splatnew(interp::AbstractInterpreter, e::Expr, sstate::St all(i::Int -> ⊑(𝕃ᵢ, (at.fields::Vector{Any})[i], fieldtype(t, i)), 1:n) end)) nothrow = isexact - rt = PartialStruct(𝕃ᵢ, rt, at.fields::Vector{Any}) + rt = PartialStruct(𝕃ᵢ, rt, at.undef, at.fields::Vector{Any}) end else rt = refine_partial_type(rt) @@ -3718,8 +3716,7 @@ end @nospecializeinfer function widenreturn_partials(𝕃ᵢ::PartialsLattice, @nospecialize(rt), info::BestguessInfo) if isa(rt, PartialStruct) fields = copy(rt.fields) - anyrefine = !isvarargtype(rt.fields[end]) && - length(rt.fields) > datatype_min_ninitialized(rt.typ) + anyrefine = refines_definedness_information(rt) 𝕃 = typeinf_lattice(info.interp) ⊏ = strictpartialorder(𝕃) for i in 1:length(fields) @@ -3731,7 +3728,7 @@ end end fields[i] = a end - anyrefine && return PartialStruct(𝕃ᵢ, rt.typ, fields) + anyrefine && return PartialStruct(𝕃ᵢ, rt.typ, rt.undef, fields) end if isa(rt, PartialOpaque) return rt # XXX: this case was missed in #39512 diff --git a/Compiler/src/tfuncs.jl b/Compiler/src/tfuncs.jl index 0c10000a1901f..c384885df5c3f 100644 --- a/Compiler/src/tfuncs.jl +++ b/Compiler/src/tfuncs.jl @@ -1142,7 +1142,7 @@ end if isa(name, Const) nv = _getfield_fieldindex(sty, name) if isa(nv, Int) && is_field_defined(s00, nv) - return unwrapva(get_defined_field(s00, nv)) + return unwrapva(s00.fields[nv]) end end s00 = s diff --git a/Compiler/src/typelattice.jl b/Compiler/src/typelattice.jl index daf169a77195d..e439ce929f0a5 100644 --- a/Compiler/src/typelattice.jl +++ b/Compiler/src/typelattice.jl @@ -325,16 +325,21 @@ end end end return Conditional(slot, - thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, thenfields), - elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, elsefields)) + thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, vartyp.undef, thenfields), + elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, vartyp.undef, elsefields)) else vartyp_widened = widenconst(vartyp) thenfields = thentype === Bottom ? nothing : Any[] elsefields = elsetype === Bottom ? nothing : Any[] - for i in 1:fieldcount(vartyp_widened) + nf = fieldcount(vartyp_widened) + undef = trues(nf) + for i in 1:nf if i == fldidx thenfields === nothing || push!(thenfields, thentype) elsefields === nothing || push!(elsefields, elsetype) + if thenfields === nothing && elsefields === nothing + undef[i] = false # this field was already accessed + end else t = fieldtype(vartyp_widened, i) thenfields === nothing || push!(thenfields, t) @@ -342,8 +347,8 @@ end end end return Conditional(slot, - thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp_widened, thenfields), - elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp_widened, elsefields)) + thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp_widened, undef, thenfields), + elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp_widened, undef, elsefields)) end end @@ -432,16 +437,12 @@ end end end length(a.undef) ≥ length(b.undef) || return false - n = length(b.undef) - ai = bi = 0 - for i in 1:n - ai += !a.undef[i] - bi += !b.undef[i] - a.undef[i] && !b.undef[i] && return false - b.undef[i] && continue + for i in 1:length(b.fields) + !is_field_defined(a, i) && is_field_defined(b, i) && return false + !is_field_defined(b, i) && continue # Field is defined for both `a` and `b` - af = a.fields[ai] - bf = b.fields[bi] + af = a.fields[i] + bf = b.fields[i] if i == length(b.fields) if isvarargtype(af) # If `af` is vararg, so must bf by the <: above @@ -473,14 +474,16 @@ end else widea <: wideb || return false # for structs we need to check that `a` has more information than `b` that may be partially initialized - n_initialized(a) ≥ length(b.fields) || return false + # it may happen that `b` has more information beyond the first undefined field + # but in this case we choose `Const` nonetheless. + n_initialized(a) ≥ n_initialized(b) || return false end nf = nfields(a.val) - bi = 0 for i in 1:nf - !isdefined(a.val, i) && !b.undef[i] && return false - b.undef[i] && continue - bfᵢ = b.fields[bi += 1] + isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T + i > length(b.fields) && break # `a` has more information than `b` that is partially initialized struct + is_field_defined(b, i) || continue # `a` gives a decisive answer as to whether the field is defined or undefined + bfᵢ = b.fields[i] if i == nf bfᵢ = unwrapva(bfᵢ) end @@ -768,5 +771,5 @@ function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), undef::BitVec for i = 1:length(fields) assert_nested_slotwrapper(fields[i]) end - return Core._PartialStruct(typ, undef, fields) + return PartialStruct(typ, undef, fields) end diff --git a/Compiler/src/typelimits.jl b/Compiler/src/typelimits.jl index ef03b1a238d15..94a33973e178b 100644 --- a/Compiler/src/typelimits.jl +++ b/Compiler/src/typelimits.jl @@ -326,34 +326,34 @@ function n_initialized(t::Const) return something(findfirst(i::Int->!isdefined(t.val,i), 1:nf), nf+1)-1 end -defined_fields(pstruct::PartialStruct) = _bitvector(ntuple(i -> !pstruct.undef[i], length(pstruct.undef))) +function n_initialized(pstruct::PartialStruct) + i = findfirst(pstruct.undef) + i !== nothing && return i - 1 + length(pstruct.undef) +end -function defined_field_index(pstruct::PartialStruct, fi) - i = 0 - for iter in 1:fi - iter ≤ length(pstruct.undef) && !pstruct.undef[iter] && (i += 1) - end - i +maybeundef_fields(pstruct::PartialStruct) = pstruct.undef +is_field_defined(pstruct::PartialStruct, fi) = 1 ≤ fi ≤ length(pstruct.undef) && !pstruct.undef[fi] + +refines_definedness_information(pstruct::PartialStruct) = !isvarargtype(pstruct.fields[end]) && refines_definedness_information(pstruct.typ, pstruct.undef) +function refines_definedness_information(@nospecialize(typ), undef) + nflds = length(undef) + something(findfirst(undef), nflds + 1) > datatype_min_ninitialized(typ) + 1 end -get_defined_field(pstruct::PartialStruct, fi) = pstruct.fields[defined_field_index(pstruct, fi)] -is_field_defined(pstruct::PartialStruct, fi) = !get(pstruct.undef, fi, true) +# Returns an iterator over contiguously defined fields +function defined_fields(pstruct::PartialStruct) + i = findfirst(pstruct.undef) + i === nothing && return pstruct.fields + Any[pstruct.fields[i] for i in 1:(i - 1)] +end function define_field(pstruct::PartialStruct, fi, @nospecialize(ft)) - n = length(pstruct.undef) - if fi ≤ n && !pstruct.undef[fi] - # XXX: merge new information? - # `setfield!(..., rand()); setfield!(..., 2.0)` - return nothing - end - undef = trues(max(fi, n)) - for i in 1:n - undef[i] = pstruct.undef[i] - end + !pstruct.undef[fi] && return nothing # no new information to be gained + undef = copy(pstruct.undef) fields = copy(pstruct.fields) undef[fi] = false - i = defined_field_index(pstruct, fi) - insert!(fields, i + 1, ft) + fields[fi] = ft PartialStruct(fallback_lattice, pstruct.typ, undef, fields) end @@ -379,16 +379,17 @@ end #- -function defined_fields(t::Const) +maybeundef_fields(t::Const) = undefined_fields(t) +function undefined_fields(t::Const) nf = nfields(t.val) - _bitvector(ntuple(i -> isdefined(t.val, i), nf)) + _bitvector(ntuple(i -> !isdefined(t.val, i), nf)) end -function defined_fields(x, y) - xdef = defined_fields(x) - ydef = defined_fields(y) +function maybeundef_fields(x, y) + xdef = maybeundef_fields(x) + ydef = maybeundef_fields(y) n = min(length(xdef), length(ydef)) - _bitvector(ntuple(i -> xdef[i] & ydef[i], n)) + _bitvector(ntuple(i -> xdef[i] | ydef[i], n)) end # A simplified type_more_complex query over the extended lattice @@ -398,14 +399,16 @@ end typea === typeb && return true if typea isa PartialStruct aty = widenconst(typea) - isa(typeb, Const) || isa(typeb, PartialStruct) || return false - @assert all(x & y == x for (x, y) in zip(defined_fields(typea), defined_fields(typeb))) "typeb ⊑ typea is assumed" - fi = 0 - nf = length(typea.undef) - for i = 1:nf - !typea.undef[i] || continue - fi += 1 - ai = unwrapva(typea.fields[fi]) + if typeb isa Const + @assert n_initialized(typea) ≤ n_initialized(typeb) "typeb ⊑ typea is assumed" + elseif typeb isa PartialStruct + @assert length(typea.fields) ≤ length(typeb.fields) && + all(!b | a for (a, b) in zip(typea.undef, typeb.undef)) "typeb ⊑ typea is assumed" + else + return false + end + for i = 1:length(typea.fields) + ai = unwrapva(typea.fields[i]) bi = fieldtype(aty, i) is_lattice_equal(𝕃, ai, bi) && continue tni = _typename(widenconst(ai)) @@ -654,13 +657,12 @@ end if aty === bty && !isType(aty) typea::Union{PartialStruct, Const} typeb::Union{PartialStruct, Const} - defined = defined_fields(typea, typeb) - ndef = _count(defined) - ndef == 0 && return nothing - fields = [] - anyrefine = ndef > datatype_min_ninitialized(aty) - for (i, def) in enumerate(defined) - def || continue + undefined = maybeundef_fields(typea, typeb) + all(undefined) && return nothing + nflds = length(undefined) + fields = Vector{Any}(undef, nflds) + anyrefine = refines_definedness_information(aty, undefined) + for i = 1:nflds ai = getfield_tfunc(𝕃, typea, Const(i)) bi = getfield_tfunc(𝕃, typeb, Const(i)) ft = fieldtype(aty, i) @@ -690,7 +692,7 @@ end tyi = ft end end - push!(fields, tyi) + fields[i] = tyi if !anyrefine anyrefine = has_nontrivial_extended_info(𝕃, tyi) || # extended information ⋤(𝕃, tyi, ft) # just a type-level information, but more precise than the declared type @@ -702,7 +704,7 @@ end # handle that in the main loop above to get a more accurate type. push!(fields, Vararg) end - anyrefine && return PartialStruct(𝕃, aty, _bitvector(ntuple(i -> !defined[i], length(defined))), fields) + anyrefine && return PartialStruct(𝕃, aty, undefined, fields) end return nothing end diff --git a/Compiler/src/typeutils.jl b/Compiler/src/typeutils.jl index d588a9aee1a6c..a368ef805d98f 100644 --- a/Compiler/src/typeutils.jl +++ b/Compiler/src/typeutils.jl @@ -67,39 +67,6 @@ function isknownlength(t::DataType) return isdefined(va, :N) && va.N isa Int end -# Compute the minimum number of initialized fields for a particular datatype -# (therefore also a lower bound on the number of fields) -function datatype_min_ninitialized(@nospecialize t0) - t = unwrap_unionall(t0) - t isa DataType || return 0 - isabstracttype(t) && return 0 - if t.name === _NAMEDTUPLE_NAME - names, types = t.parameters[1], t.parameters[2] - if names isa Tuple - return length(names) - end - t = argument_datatype(types) - t isa DataType || return 0 - t.name === Tuple.name || return 0 - end - if t.name === Tuple.name - n = length(t.parameters) - n == 0 && return 0 - va = t.parameters[n] - if isvarargtype(va) - n -= 1 - if isdefined(va, :N) - va = va.N - if va isa Int - n += va - end - end - end - return n - end - return length(t.name.names) - t.name.n_uninitialized -end - has_concrete_subtype(d::DataType) = d.flags & 0x0020 == 0x0020 # n.b. often computed only after setting the type and layout fields # determine whether x is a valid lattice element diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index e3159e13bbda8..ea3ee18cad86a 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -4727,7 +4727,7 @@ end c = a ⊔ b @test a ⊑ c && b ⊑ c @test c isa PartialStruct - @test length(c.fields) == 1 + @test c.undef == a.undef == [0, 1, 1] end let T = Base.ImmutableDict{Number,Number} a = PartialStruct(𝕃, T, Any[T]) @@ -4838,19 +4838,20 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) @test t.undef == [false, false] t = PartialStruct(𝕃, Partial, Any[String, Const(2)]) @test t.undef == [false, false, true] + @test t.fields == Any[String, Const(2), Any] @test t ⊑ t && t ⊔ t === t t1 = PartialStruct(𝕃, Partial, Any[String, Const(3)]) t2 = PartialStruct(𝕃, Partial, Any[Const("x"), Int]) @test !(t1 ⊑ t2) && !(t2 ⊑ t1) t3 = t1 ⊔ t2 - @test t3.fields == Any[String, Int] + @test t3.fields == Any[String, Int, Any] - t1 = PartialStruct(𝕃, Partial, BitVector([true, false, false]), Any[Int, Const(3)]) + t1 = PartialStruct(𝕃, Partial, BitVector([true, false, false]), Any[String, Int, Const(3)]) @test t1 ⊑ t1 && t1 ⊔ t1 === t1 - t2 = PartialStruct(𝕃, Partial, BitVector([false, true, false]), Any[Const("x"), Int]) + t2 = PartialStruct(𝕃, Partial, BitVector([false, true, false]), Any[Const("x"), Int, Any]) t3 = t1 ⊔ t2 - @test t3.undef == [true, true, false] && t3.fields == Any[Int] + @test t3 === Partial t1 = PartialStruct(𝕃, Tuple, Any[Int, String, Vararg]) @test t1.undef == [false, false] @@ -4877,7 +4878,7 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) t = form_partially_defined_struct(Partial3, Const(:y)) @test t == PartialStruct(𝕃, Partial3, Any[Int, String]) t = form_partially_defined_struct(Partial3, Const(:z)) - @test t == PartialStruct(𝕃, Partial3, BitVector([false, true, false]), Any[Int, Float64]) + @test t == PartialStruct(𝕃, Partial3, BitVector([false, true, false]), Any[Int, String, Float64]) t = form_partially_defined_struct(t, Const(:y)) @test t == PartialStruct(𝕃, Partial3, Any[Int, String, Float64]) t = PartialStruct(𝕃, Partial3, Any[Int, String]) diff --git a/base/coreir.jl b/base/coreir.jl index 8e95517b060c2..44f1bf4fd641c 100644 --- a/base/coreir.jl +++ b/base/coreir.jl @@ -37,24 +37,44 @@ exact number of elements is unknown (`undef` then has a length of `length(fields """ Core.PartialStruct -function Core.PartialStruct(@nospecialize(typ), fields::Vector{Any}) - ndef = lastindex(fields) - fields[end] === Vararg && (ndef -= 1) - t = typ - (isa(t, UnionAll) || isa(t, Union)) && (t = argument_datatype(t)) - nfields = isa(t, DataType) ? datatype_fieldcount(t) : nothing - if nfields === nothing || nfields == ndef - undef = falses(ndef) - else - @assert nfields ≥ ndef - undef = trues(nfields) - for i in 1:ndef - undef[i] = false - end +function Core.PartialStruct(@nospecialize(typ), undef::BitVector, fields::Vector{Any}) + @assert length(undef) ≥ length(fields) - (fields[end] === Vararg) + for i in 1:datatype_min_ninitialized(typ) + undef[i] = false end Core._PartialStruct(typ, undef, fields) end +function get_fieldcount(@nospecialize(t)) + if isa(t, UnionAll) || isa(t, Union) + t = argument_datatype(t) + t === nothing && return nothing + end + isa(t, DataType) || return nothing + return datatype_fieldcount(t) +end + +function Core.PartialStruct(@nospecialize(typ), fields::Vector{Any}) + nf = length(fields) + fields[end] === Vararg && (nf -= 1) + nflds = get_fieldcount(typ) + nflds === nothing && (nflds = nf) + undef = trues(nflds) + + # The provided fields (in absence of an `undef` argument) + # are assumed to be defined. + for i in 1:nf + undef[i] = false + end + + # Make sure no field is missing. + if nflds > nf + fields = Any[get(fields, i, fieldtype(typ, i)) for i in 1:nflds] + end + + Core.PartialStruct(typ, undef, fields) +end + (==)(a::PartialStruct, b::PartialStruct) = a.typ === b.typ && a.undef == b.undef && a.fields == b.fields function Base.getproperty(pstruct::Core.PartialStruct, name::Symbol) diff --git a/base/essentials.jl b/base/essentials.jl index 5db7a5f6fb0d9..091fc114f0a7d 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -593,6 +593,39 @@ function unconstrain_vararg_length(va::Core.TypeofVararg) return Vararg{unwrapva(va)} end +# Compute the minimum number of initialized fields for a particular datatype +# (therefore also a lower bound on the number of fields) +function datatype_min_ninitialized(@nospecialize t0) + t = unwrap_unionall(t0) + t isa DataType || return 0 + isabstracttype(t) && return 0 + if t.name === _NAMEDTUPLE_NAME + names, types = t.parameters[1], t.parameters[2] + if names isa Tuple + return length(names) + end + t = argument_datatype(types) + t isa DataType || return 0 + t.name === Tuple.name || return 0 + end + if t.name === Tuple.name + n = length(t.parameters) + n == 0 && return 0 + va = t.parameters[n] + if isvarargtype(va) + n -= 1 + if isdefined(va, :N) + va = va.N + if va isa Int + n += va + end + end + end + return n + end + return length(t.name.names) - t.name.n_uninitialized +end + import Core: typename _tuple_error(T::Type, x) = (@noinline; throw(MethodError(convert, (T, x)))) From c9ee9b01e1d43436f11a52cfa87cf3c85016bc72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 14 Feb 2025 08:19:27 -0500 Subject: [PATCH 09/16] Fix tests --- Compiler/src/typeinfer.jl | 7 ++++--- Compiler/src/typelimits.jl | 14 +++++++++----- base/coreir.jl | 2 +- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/Compiler/src/typeinfer.jl b/Compiler/src/typeinfer.jl index ddcca9a6ffaa1..39086b510b8e5 100644 --- a/Compiler/src/typeinfer.jl +++ b/Compiler/src/typeinfer.jl @@ -513,7 +513,7 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter) rettype_const = result_type.parameters[1] const_flags = 0x2 elseif isa(result_type, PartialStruct) - rettype_const = result_type.fields + rettype_const = (result_type.undef, result_type.fields) const_flags = 0x2 elseif isa(result_type, InterConditional) rettype_const = result_type @@ -957,8 +957,9 @@ function cached_return_type(code::CodeInstance) rettype_const = code.rettype_const # the second subtyping/egal conditions are necessary to distinguish usual cases # from rare cases when `Const` wrapped those extended lattice type objects - if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype) - return PartialStruct(fallback_lattice, rettype, rettype_const) + if isa(rettype_const, Tuple{BitVector, Vector{Any}}) && !(Tuple{BitVector, Vector{Any}} <: rettype) + undef, fields = rettype_const + return PartialStruct(fallback_lattice, rettype, undef, fields) elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure return rettype_const elseif isa(rettype_const, InterConditional) && rettype !== InterConditional diff --git a/Compiler/src/typelimits.jl b/Compiler/src/typelimits.jl index 94a33973e178b..f7e8e7018a4df 100644 --- a/Compiler/src/typelimits.jl +++ b/Compiler/src/typelimits.jl @@ -657,11 +657,15 @@ end if aty === bty && !isType(aty) typea::Union{PartialStruct, Const} typeb::Union{PartialStruct, Const} - undefined = maybeundef_fields(typea, typeb) - all(undefined) && return nothing - nflds = length(undefined) + maybeundef = maybeundef_fields(typea, typeb) + if all(maybeundef) + # We could also preserve information about refined field types + # (e.g. to better infer non-throwing `getfield` branches). + return nothing + end + nflds = length(maybeundef) fields = Vector{Any}(undef, nflds) - anyrefine = refines_definedness_information(aty, undefined) + anyrefine = refines_definedness_information(aty, maybeundef) for i = 1:nflds ai = getfield_tfunc(𝕃, typea, Const(i)) bi = getfield_tfunc(𝕃, typeb, Const(i)) @@ -704,7 +708,7 @@ end # handle that in the main loop above to get a more accurate type. push!(fields, Vararg) end - anyrefine && return PartialStruct(𝕃, aty, undefined, fields) + anyrefine && return PartialStruct(𝕃, aty, maybeundef, fields) end return nothing end diff --git a/base/coreir.jl b/base/coreir.jl index 44f1bf4fd641c..17af961907245 100644 --- a/base/coreir.jl +++ b/base/coreir.jl @@ -37,7 +37,7 @@ exact number of elements is unknown (`undef` then has a length of `length(fields """ Core.PartialStruct -function Core.PartialStruct(@nospecialize(typ), undef::BitVector, fields::Vector{Any}) +function Core.PartialStruct(typ::Type, undef::BitVector, fields::Vector{Any}) @assert length(undef) ≥ length(fields) - (fields[end] === Vararg) for i in 1:datatype_min_ninitialized(typ) undef[i] = false From ced6296de4229e28831769e19aa758a1be443ca7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 14 Feb 2025 11:22:34 -0500 Subject: [PATCH 10/16] Remove unnecessary function --- Compiler/src/typelimits.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/Compiler/src/typelimits.jl b/Compiler/src/typelimits.jl index f7e8e7018a4df..9ece9a2089e84 100644 --- a/Compiler/src/typelimits.jl +++ b/Compiler/src/typelimits.jl @@ -369,14 +369,6 @@ function _bitvector(nt::NTuple) bv end -function _count(bv::BitVector) - n = 0 - for val in bv - n += val - end - n -end - #- maybeundef_fields(t::Const) = undefined_fields(t) From b44dfb571889eb91e4aa39ef84e06c4137f55a5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Mon, 17 Feb 2025 06:48:55 -0500 Subject: [PATCH 11/16] =?UTF-8?q?Fix=20unsoundness=20for=20`=E2=8A=91`,=20?= =?UTF-8?q?don't=20widen=20undef=20information=20for=20PartialStruct?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Compiler/src/typelattice.jl | 16 ++++++---------- Compiler/test/inference.jl | 13 +++++++++++-- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/Compiler/src/typelattice.jl b/Compiler/src/typelattice.jl index e439ce929f0a5..8800de7e8cbf0 100644 --- a/Compiler/src/typelattice.jl +++ b/Compiler/src/typelattice.jl @@ -318,28 +318,26 @@ end fields = vartyp.fields thenfields = thentype === Bottom ? nothing : copy(fields) elsefields = elsetype === Bottom ? nothing : copy(fields) + undef = copy(vartyp.undef) for i in 1:length(fields) if i == fldidx thenfields === nothing || (thenfields[i] = thentype) elsefields === nothing || (elsefields[i] = elsetype) + undef[i] = false end end return Conditional(slot, - thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, vartyp.undef, thenfields), - elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, vartyp.undef, elsefields)) + thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undef, thenfields), + elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undef, elsefields)) else vartyp_widened = widenconst(vartyp) thenfields = thentype === Bottom ? nothing : Any[] elsefields = elsetype === Bottom ? nothing : Any[] nf = fieldcount(vartyp_widened) - undef = trues(nf) for i in 1:nf if i == fldidx thenfields === nothing || push!(thenfields, thentype) elsefields === nothing || push!(elsefields, elsetype) - if thenfields === nothing && elsefields === nothing - undef[i] = false # this field was already accessed - end else t = fieldtype(vartyp_widened, i) thenfields === nothing || push!(thenfields, t) @@ -347,8 +345,8 @@ end end end return Conditional(slot, - thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp_widened, undef, thenfields), - elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp_widened, undef, elsefields)) + thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp_widened, thenfields), + elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp_widened, elsefields)) end end @@ -439,8 +437,6 @@ end length(a.undef) ≥ length(b.undef) || return false for i in 1:length(b.fields) !is_field_defined(a, i) && is_field_defined(b, i) && return false - !is_field_defined(b, i) && continue - # Field is defined for both `a` and `b` af = a.fields[i] bf = b.fields[i] if i == length(b.fields) diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index ea3ee18cad86a..a4b7ef53e69c7 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -4807,6 +4807,7 @@ module _Partials_inference end let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) + ⋢ = !⊑ ⊔ = Compiler.join(Compiler.fallback_lattice) 𝕃 = Compiler.fallback_lattice Const, PartialStruct = Core.Const, Core.PartialStruct @@ -4843,7 +4844,7 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) t1 = PartialStruct(𝕃, Partial, Any[String, Const(3)]) t2 = PartialStruct(𝕃, Partial, Any[Const("x"), Int]) - @test !(t1 ⊑ t2) && !(t2 ⊑ t1) + @test t1 ⋢ t2 && t2 ⋢ t1 t3 = t1 ⊔ t2 @test t3.fields == Any[String, Int, Any] @@ -4857,7 +4858,7 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) @test t1.undef == [false, false] @test t1 ⊑ t1 && t1 ⊔ t1 == t1 t2 = PartialStruct(𝕃, Tuple, Any[Int, Any]) - @test !(t1 ⊑ t2) && !(t2 ⊑ t1) + @test t1 ⋢ t2 && t2 ⋢ t1 t3 = t1 ⊔ t2 @test t3.undef == [false, false] && t3.fields == Any[Int, Any] t2 = PartialStruct(𝕃, Tuple, Any[Int, Any, Vararg]) @@ -4884,6 +4885,14 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) t = PartialStruct(𝕃, Partial3, Any[Int, String]) t′ = form_partially_defined_struct(t, Const(:z)) @test t′ == PartialStruct(𝕃, Partial3, Any[Int, String, Float64]) + + t1 = PartialStruct(𝕃, Partial3, Any[Int, String]) + t2 = PartialStruct(𝕃, Partial3, Any[Const(1)]) + @test t1 ⋢ t2 && t2 ⋢ t1 + c = @eval Const($(Expr(:new, Partial3, 1))) + @test c ⋢ t1 && t1 ⋢ c && c ⊑ t2 && t2 ⋢ c + t3 = PartialStruct(𝕃, Partial3, Any[Const(1), Const("x")]) + @test c ⋢ t3 && t3 ⋢ c end # Test that a function-wise `@max_methods` works as expected From b1efbe556321be3f07b9666108856ba1fbe36636 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Tue, 18 Feb 2025 14:31:49 +0100 Subject: [PATCH 12/16] Apply suggestions from code review Co-authored-by: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> --- Compiler/src/abstractinterpretation.jl | 4 ++-- Compiler/src/typelattice.jl | 12 +++++------- base/coreir.jl | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/Compiler/src/abstractinterpretation.jl b/Compiler/src/abstractinterpretation.jl index 8c9b2c9950a39..37273d36c9f99 100644 --- a/Compiler/src/abstractinterpretation.jl +++ b/Compiler/src/abstractinterpretation.jl @@ -2151,10 +2151,10 @@ function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name)) isa(obj, PartialStruct) && return define_field(obj, fldidx, fieldtype(objt0, fldidx)) nminfld = datatype_min_ninitialized(objt) fldidx > nminfld || return nothing - fields = collect(Any, fieldtypes(objt0)) nmaxfld = something(datatype_fieldcount(objt), fldidx) undef = trues(nmaxfld) undef[fldidx] = false + fields = Any[fieldtype(objt0, i) for i = 1:nmaxfld] return PartialStruct(fallback_lattice, objt0, undef, fields) end @@ -3135,7 +3135,7 @@ function abstract_eval_splatnew(interp::AbstractInterpreter, e::Expr, sstate::St all(i::Int -> ⊑(𝕃ᵢ, (at.fields::Vector{Any})[i], fieldtype(t, i)), 1:n) end)) nothrow = isexact - rt = PartialStruct(𝕃ᵢ, rt, at.undef, at.fields::Vector{Any}) + rt = PartialStruct(𝕃ᵢ, rt, trues(length(at.undef)), at.fields::Vector{Any}) end else rt = refine_partial_type(rt) diff --git a/Compiler/src/typelattice.jl b/Compiler/src/typelattice.jl index 8800de7e8cbf0..0200d9b1ce1fc 100644 --- a/Compiler/src/typelattice.jl +++ b/Compiler/src/typelattice.jl @@ -319,12 +319,10 @@ end thenfields = thentype === Bottom ? nothing : copy(fields) elsefields = elsetype === Bottom ? nothing : copy(fields) undef = copy(vartyp.undef) - for i in 1:length(fields) - if i == fldidx - thenfields === nothing || (thenfields[i] = thentype) - elsefields === nothing || (elsefields[i] = elsetype) - undef[i] = false - end + if 1 ≤ fldidx ≤ length(fields) + thenfields === nothing || (thenfields[fldidx] = thentype) + elsefields === nothing || (elsefields[fldidx] = elsetype) + undef[fldidx] = false end return Conditional(slot, thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undef, thenfields), @@ -436,7 +434,7 @@ end end length(a.undef) ≥ length(b.undef) || return false for i in 1:length(b.fields) - !is_field_defined(a, i) && is_field_defined(b, i) && return false + is_field_defined(a, i) ≥ is_field_defined(b, i) || return false af = a.fields[i] bf = b.fields[i] if i == length(b.fields) diff --git a/base/coreir.jl b/base/coreir.jl index 17af961907245..a402bc844c86e 100644 --- a/base/coreir.jl +++ b/base/coreir.jl @@ -38,7 +38,7 @@ exact number of elements is unknown (`undef` then has a length of `length(fields Core.PartialStruct function Core.PartialStruct(typ::Type, undef::BitVector, fields::Vector{Any}) - @assert length(undef) ≥ length(fields) - (fields[end] === Vararg) + @assert length(undef) == length(fields) - (isvarargtype(fields[end])) for i in 1:datatype_min_ninitialized(typ) undef[i] = false end From 73095063eb5ca05bf012ca12ca7b404ca031fa27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Thu, 20 Feb 2025 14:03:01 -0500 Subject: [PATCH 13/16] Allow sparse field type information, remove mutations in constructor --- Compiler/src/Compiler.jl | 3 +- Compiler/src/abstractinterpretation.jl | 7 +- Compiler/src/tfuncs.jl | 6 +- Compiler/src/typelattice.jl | 32 ++++--- Compiler/src/typelimits.jl | 117 ++++++++++++++----------- Compiler/test/inference.jl | 52 ++++++----- base/coreir.jl | 47 ++++------ 7 files changed, 138 insertions(+), 126 deletions(-) diff --git a/Compiler/src/Compiler.jl b/Compiler/src/Compiler.jl index 394039c75877f..d29f227c79e6b 100644 --- a/Compiler/src/Compiler.jl +++ b/Compiler/src/Compiler.jl @@ -67,7 +67,8 @@ using Base: @_foldable_meta, @_gc_preserve_begin, @_gc_preserve_end, @nospeciali partition_restriction, quoted, rename_unionall, rewrap_unionall, specialize_method, structdiff, tls_world_age, unconstrain_vararg_length, unionlen, uniontype_layout, uniontypes, unsafe_convert, unwrap_unionall, unwrapva, vect, widen_diagonal, - _uncompressed_ir, maybe_add_binding_backedge!, datatype_min_ninitialized + _uncompressed_ir, maybe_add_binding_backedge!, datatype_min_ninitialized, + partialstruct_undef_length, partialstruct_init_undef using Base.Order import Base: ==, _topmod, append!, convert, copy, copy!, findall, first, get, get!, diff --git a/Compiler/src/abstractinterpretation.jl b/Compiler/src/abstractinterpretation.jl index 37273d36c9f99..519795626cf4e 100644 --- a/Compiler/src/abstractinterpretation.jl +++ b/Compiler/src/abstractinterpretation.jl @@ -2151,10 +2151,9 @@ function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name)) isa(obj, PartialStruct) && return define_field(obj, fldidx, fieldtype(objt0, fldidx)) nminfld = datatype_min_ninitialized(objt) fldidx > nminfld || return nothing - nmaxfld = something(datatype_fieldcount(objt), fldidx) - undef = trues(nmaxfld) + undef = partialstruct_init_undef(objt, fldidx; all_defined = false) undef[fldidx] = false - fields = Any[fieldtype(objt0, i) for i = 1:nmaxfld] + fields = Any[fieldtype(objt0, i) for i = 1:fldidx] return PartialStruct(fallback_lattice, objt0, undef, fields) end @@ -3135,7 +3134,7 @@ function abstract_eval_splatnew(interp::AbstractInterpreter, e::Expr, sstate::St all(i::Int -> ⊑(𝕃ᵢ, (at.fields::Vector{Any})[i], fieldtype(t, i)), 1:n) end)) nothrow = isexact - rt = PartialStruct(𝕃ᵢ, rt, trues(length(at.undef)), at.fields::Vector{Any}) + rt = PartialStruct(𝕃ᵢ, rt, at.fields::Vector{Any}) end else rt = refine_partial_type(rt) diff --git a/Compiler/src/tfuncs.jl b/Compiler/src/tfuncs.jl index c384885df5c3f..ff57495f1848f 100644 --- a/Compiler/src/tfuncs.jl +++ b/Compiler/src/tfuncs.jl @@ -439,7 +439,7 @@ end end elseif isa(arg1, PartialStruct) if !isvarargtype(arg1.fields[end]) - if is_field_defined(arg1, idx) + if is_field_initialized(arg1, idx) return Const(true) end end @@ -1141,8 +1141,8 @@ end sty = unwrap_unionall(s)::DataType if isa(name, Const) nv = _getfield_fieldindex(sty, name) - if isa(nv, Int) && is_field_defined(s00, nv) - return unwrapva(s00.fields[nv]) + if isa(nv, Int) && is_field_initialized(s00, nv) + return unwrapva(partialstruct_getfield(s00, nv)) end end s00 = s diff --git a/Compiler/src/typelattice.jl b/Compiler/src/typelattice.jl index 0200d9b1ce1fc..5c3ff95366d59 100644 --- a/Compiler/src/typelattice.jl +++ b/Compiler/src/typelattice.jl @@ -331,8 +331,7 @@ end vartyp_widened = widenconst(vartyp) thenfields = thentype === Bottom ? nothing : Any[] elsefields = elsetype === Bottom ? nothing : Any[] - nf = fieldcount(vartyp_widened) - for i in 1:nf + for i in 1:fieldcount(vartyp_widened) if i == fldidx thenfields === nothing || push!(thenfields, thentype) elsefields === nothing || push!(elsefields, elsetype) @@ -432,12 +431,14 @@ end return false end end - length(a.undef) ≥ length(b.undef) || return false - for i in 1:length(b.fields) - is_field_defined(a, i) ≥ is_field_defined(b, i) || return false - af = a.fields[i] - bf = b.fields[i] - if i == length(b.fields) + na = length(a.fields) + nb = length(b.fields) + nmax = max(na, nb) + for i in 1:nmax + is_field_initialized(a, i) ≥ is_field_initialized(b, i) || return false + af = partialstruct_getfield(a, i) + bf = partialstruct_getfield(b, i) + if i == na || i == nb if isvarargtype(af) # If `af` is vararg, so must bf by the <: above @assert isvarargtype(bf) @@ -467,16 +468,14 @@ end nfields(a.val) == length(b.fields) || return false else widea <: wideb || return false - # for structs we need to check that `a` has more information than `b` that may be partially initialized - # it may happen that `b` has more information beyond the first undefined field - # but in this case we choose `Const` nonetheless. + # for structs we need to check that `a` does not have less information than `b` that may be partially initialized n_initialized(a) ≥ n_initialized(b) || return false end nf = nfields(a.val) for i in 1:nf isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T i > length(b.fields) && break # `a` has more information than `b` that is partially initialized struct - is_field_defined(b, i) || continue # `a` gives a decisive answer as to whether the field is defined or undefined + is_field_initialized(b, i) || continue # `a` gives a decisive answer as to whether the field is defined or undefined bfᵢ = b.fields[i] if i == nf bfᵢ = unwrapva(bfᵢ) @@ -754,11 +753,10 @@ end # The ::AbstractLattice argument is unused and simply serves to disambiguate # different instances of the compiler that may share the `Core.PartialStruct` # type. -function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), fields::Vector{Any}) - for i = 1:length(fields) - assert_nested_slotwrapper(fields[i]) - end - return PartialStruct(typ, fields) + +function Core.PartialStruct(𝕃::AbstractLattice, @nospecialize(typ), fields::Vector{Any}; all_defined::Bool = true) + undef = partialstruct_init_undef(typ, fields; all_defined) + return PartialStruct(𝕃, typ, undef, fields) end function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), undef::BitVector, fields::Vector{Any}) diff --git a/Compiler/src/typelimits.jl b/Compiler/src/typelimits.jl index 9ece9a2089e84..a97e07f1d3022 100644 --- a/Compiler/src/typelimits.jl +++ b/Compiler/src/typelimits.jl @@ -326,35 +326,55 @@ function n_initialized(t::Const) return something(findfirst(i::Int->!isdefined(t.val,i), 1:nf), nf+1)-1 end +is_field_initialized(t::Const, i) = isdefined(t.val, i) + function n_initialized(pstruct::PartialStruct) i = findfirst(pstruct.undef) - i !== nothing && return i - 1 - length(pstruct.undef) + nmin = datatype_min_ninitialized(pstruct.typ) + i === nothing && return max(length(pstruct.undef), nmin) + n = i::Int - 1 + @assert n ≥ nmin + n end -maybeundef_fields(pstruct::PartialStruct) = pstruct.undef -is_field_defined(pstruct::PartialStruct, fi) = 1 ≤ fi ≤ length(pstruct.undef) && !pstruct.undef[fi] +function is_field_initialized(pstruct::PartialStruct, fi) + fi ≥ 1 || return false + fi ≤ length(pstruct.undef) && return !pstruct.undef[fi] + fi ≤ datatype_min_ninitialized(pstruct.typ) +end -refines_definedness_information(pstruct::PartialStruct) = !isvarargtype(pstruct.fields[end]) && refines_definedness_information(pstruct.typ, pstruct.undef) -function refines_definedness_information(@nospecialize(typ), undef) - nflds = length(undef) - something(findfirst(undef), nflds + 1) > datatype_min_ninitialized(typ) + 1 +function partialstruct_getfield(pstruct::PartialStruct, fi::Integer) + @assert fi > 0 + fi ≤ length(pstruct.fields) && return pstruct.fields[fi] + fieldtype(pstruct.typ, fi) end -# Returns an iterator over contiguously defined fields -function defined_fields(pstruct::PartialStruct) - i = findfirst(pstruct.undef) - i === nothing && return pstruct.fields - Any[pstruct.fields[i] for i in 1:(i - 1)] +function refines_definedness_information(pstruct::PartialStruct) + nflds = length(pstruct.undef) + something(findfirst(pstruct.undef), nflds + 1) - 1 > datatype_min_ninitialized(pstruct.typ) end function define_field(pstruct::PartialStruct, fi, @nospecialize(ft)) - !pstruct.undef[fi] && return nothing # no new information to be gained - undef = copy(pstruct.undef) - fields = copy(pstruct.fields) + if is_field_initialized(pstruct, fi) + # no new information to be gained + return nothing + end + + n = length(pstruct.undef) + undef = partialstruct_init_undef(pstruct.typ, max(fi, n); all_defined = false) + for i in 1:n + undef[i] &= pstruct.undef[i] + end undef[fi] = false + + fields = copy(pstruct.fields) + nf = length(fields) + typ = pstruct.typ + for i in (nf + 1):fi + push!(fields, fieldtype(typ, i)) + end fields[fi] = ft - PartialStruct(fallback_lattice, pstruct.typ, undef, fields) + PartialStruct(fallback_lattice, typ, undef, fields) end # needed while we are missing functions such as broadcasting or ranges @@ -371,19 +391,6 @@ end #- -maybeundef_fields(t::Const) = undefined_fields(t) -function undefined_fields(t::Const) - nf = nfields(t.val) - _bitvector(ntuple(i -> !isdefined(t.val, i), nf)) -end - -function maybeundef_fields(x, y) - xdef = maybeundef_fields(x) - ydef = maybeundef_fields(y) - n = min(length(xdef), length(ydef)) - _bitvector(ntuple(i -> xdef[i] | ydef[i], n)) -end - # A simplified type_more_complex query over the extended lattice # (assumes typeb ⊑ typea) @nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb)) @@ -391,11 +398,11 @@ end typea === typeb && return true if typea isa PartialStruct aty = widenconst(typea) - if typeb isa Const + if typeb isa Const || typeb isa PartialStruct @assert n_initialized(typea) ≤ n_initialized(typeb) "typeb ⊑ typea is assumed" elseif typeb isa PartialStruct - @assert length(typea.fields) ≤ length(typeb.fields) && - all(!b | a for (a, b) in zip(typea.undef, typeb.undef)) "typeb ⊑ typea is assumed" + @assert n_initialized(typea) ≤ n_initialized(typeb) && + all(b < a for (a, b) in zip(typea.undef, typeb.undef)) "typeb ⊑ typea is assumed" else return false end @@ -647,17 +654,27 @@ end aty = widenconst(typea) bty = widenconst(typeb) if aty === bty && !isType(aty) - typea::Union{PartialStruct, Const} - typeb::Union{PartialStruct, Const} - maybeundef = maybeundef_fields(typea, typeb) - if all(maybeundef) - # We could also preserve information about refined field types - # (e.g. to better infer non-throwing `getfield` branches). - return nothing - end - nflds = length(maybeundef) + if typea isa PartialStruct + if typeb isa PartialStruct + nflds = min(length(typea.fields), length(typeb.fields)) + nundef = nflds - (isvarargtype(typea.fields[end]) && isvarargtype(typeb.fields[end])) + else + nflds = min(length(typea.fields), n_initialized(typeb::Const)) + nundef = nflds + end + elseif typeb isa PartialStruct + nflds = min(n_initialized(typea::Const), length(typeb.fields)) + nundef = nflds + else + nflds = min(n_initialized(typea::Const), n_initialized(typeb::Const)) + nundef = nflds + end + nflds == 0 && return nothing fields = Vector{Any}(undef, nflds) - anyrefine = refines_definedness_information(aty, maybeundef) + _undef = trues(nundef) + fldmin = datatype_min_ninitialized(aty) + n_initialized_merged = min(n_initialized(typea::Union{Const, PartialStruct}), n_initialized(typeb::Union{Const, PartialStruct})) + anyrefine = n_initialized_merged > fldmin for i = 1:nflds ai = getfield_tfunc(𝕃, typea, Const(i)) bi = getfield_tfunc(𝕃, typeb, Const(i)) @@ -689,18 +706,16 @@ end end end fields[i] = tyi + if i ≤ nundef + _undef[i] = !is_field_initialized(typea, i) || !is_field_initialized(typeb, i) + end if !anyrefine anyrefine = has_nontrivial_extended_info(𝕃, tyi) || # extended information - ⋤(𝕃, tyi, ft) # just a type-level information, but more precise than the declared type + ⋤(𝕃, tyi, ft) || # just a type-level information, but more precise than the declared type + !get(_undef, i, true) && i > fldmin # possibly initialized field is known to be initialized end end - if isa(typea, PartialStruct) && isa(typeb, PartialStruct) && - isvarargtype(typea.fields[end]) && isvarargtype(typeb.fields[end]) - # XXX: If it may be more precise than `Vararg` (e.g. `Vararg{T}`), - # handle that in the main loop above to get a more accurate type. - push!(fields, Vararg) - end - anyrefine && return PartialStruct(𝕃, aty, maybeundef, fields) + anyrefine && return PartialStruct(𝕃, aty, _undef, fields) end return nothing end diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index a4b7ef53e69c7..ce6e38bfe1d76 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -4727,7 +4727,7 @@ end c = a ⊔ b @test a ⊑ c && b ⊑ c @test c isa PartialStruct - @test c.undef == a.undef == [0, 1, 1] + @test length(c.fields) == 1 && c.undef == [0] end let T = Base.ImmutableDict{Number,Number} a = PartialStruct(𝕃, T, Any[T]) @@ -4788,21 +4788,28 @@ module _Partials_inference x::String y::Integer z::Any - Partial(args...) = new(args...) + Partial() = new() end struct Partial2 x::String y::Integer z::Any - Partial2(args...) = new(args...) + Partial2(x) = new(x) end struct Partial3 x::Int y::String z::Float64 - Partial3(args...) = new(args...) + Partial3(x, y) = new(x, y) + end + + struct Partial4 + x::Int + y::String + z::Float64 + Partial4(x) = new(x) end end @@ -4813,7 +4820,7 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) Const, PartialStruct = Core.Const, Core.PartialStruct form_partially_defined_struct = Compiler.form_partially_defined_struct M = _Partials_inference - Partial, Partial2, Partial3 = M.Partial, M.Partial2, M.Partial3 + Partial, Partial2, Partial3, Partial4 = M.Partial, M.Partial2, M.Partial3, M.Partial4 @test (Const((1,2)) ⊑ PartialStruct(𝕃, Tuple{Int,Int}, Any[Const(1),Int])) @test !(Const((1,2)) ⊑ PartialStruct(𝕃, Tuple{Int,Int,Int}, Any[Const(1),Int,Int])) @@ -4835,22 +4842,26 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) t = t ⊔ Const((false, false, 0)) @test t ⊑ Union{Tuple{Bool,Bool},Tuple{Bool,Bool,Int}} - t = PartialStruct(𝕃, Tuple{Int, Int}, Any[Const(1), Int]) - @test t.undef == [false, false] + t = PartialStruct(𝕃, Tuple{Int, Int}, Any[Const(1)]) + @test t.undef == [false] + @test Compiler.is_field_initialized(t, 2) + @test Compiler.n_initialized(t) == 2 t = PartialStruct(𝕃, Partial, Any[String, Const(2)]) - @test t.undef == [false, false, true] - @test t.fields == Any[String, Const(2), Any] + @test t.undef == [false, false] + @test t.fields == Any[String, Const(2)] @test t ⊑ t && t ⊔ t === t t1 = PartialStruct(𝕃, Partial, Any[String, Const(3)]) - t2 = PartialStruct(𝕃, Partial, Any[Const("x"), Int]) + t2 = PartialStruct(𝕃, Partial, Any[Const("x")]) @test t1 ⋢ t2 && t2 ⋢ t1 t3 = t1 ⊔ t2 - @test t3.fields == Any[String, Int, Any] + @test t3.fields == Any[String] t1 = PartialStruct(𝕃, Partial, BitVector([true, false, false]), Any[String, Int, Const(3)]) + @test Compiler.n_initialized(t1) == 0 @test t1 ⊑ t1 && t1 ⊔ t1 === t1 - t2 = PartialStruct(𝕃, Partial, BitVector([false, true, false]), Any[Const("x"), Int, Any]) + t2 = PartialStruct(𝕃, Partial, BitVector([false, true]), Any[Const("x"), Int]) + @test Compiler.n_initialized(t2) == 1 t3 = t1 ⊔ t2 @test t3 === Partial @@ -4865,6 +4876,10 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) @test t1 ⊑ t2 @test t1 ⊔ t2 === t2 + t = PartialStruct(𝕃, Partial, Any[Const("x")]) + @test form_partially_defined_struct(t, Const(:x)) === nothing + t′ = form_partially_defined_struct(t, Const(:z)) + @test t′ == PartialStruct(𝕃, Partial, BitVector([false, true, false]), Any[Const("x"), Integer, Any]) t = PartialStruct(𝕃, Partial, Any[String, Const(2)]) @test form_partially_defined_struct(t, Const(:x)) === nothing t′ = form_partially_defined_struct(t, Const(:z)) @@ -4876,22 +4891,19 @@ let ⊑ = Compiler.partialorder(Compiler.fallback_lattice) @test t′ == PartialStruct(𝕃, Partial2, Any[String, Const(2), Any]) @test form_partially_defined_struct(Partial3, Const(:x)) === nothing - t = form_partially_defined_struct(Partial3, Const(:y)) - @test t == PartialStruct(𝕃, Partial3, Any[Int, String]) + @test form_partially_defined_struct(Partial3, Const(:y)) === nothing t = form_partially_defined_struct(Partial3, Const(:z)) - @test t == PartialStruct(𝕃, Partial3, BitVector([false, true, false]), Any[Int, String, Float64]) - t = form_partially_defined_struct(t, Const(:y)) @test t == PartialStruct(𝕃, Partial3, Any[Int, String, Float64]) t = PartialStruct(𝕃, Partial3, Any[Int, String]) t′ = form_partially_defined_struct(t, Const(:z)) @test t′ == PartialStruct(𝕃, Partial3, Any[Int, String, Float64]) - t1 = PartialStruct(𝕃, Partial3, Any[Int, String]) - t2 = PartialStruct(𝕃, Partial3, Any[Const(1)]) + t1 = PartialStruct(𝕃, Partial4, Any[Int, String]) + t2 = PartialStruct(𝕃, Partial4, Any[Const(1)]) @test t1 ⋢ t2 && t2 ⋢ t1 - c = @eval Const($(Expr(:new, Partial3, 1))) + c = Const(Partial4(1)) @test c ⋢ t1 && t1 ⋢ c && c ⊑ t2 && t2 ⋢ c - t3 = PartialStruct(𝕃, Partial3, Any[Const(1), Const("x")]) + t3 = PartialStruct(𝕃, Partial4, Any[Const(1), Const("x")]) @test c ⋢ t3 && t3 ⋢ c end diff --git a/base/coreir.jl b/base/coreir.jl index a402bc844c86e..a0a164b4cfc2f 100644 --- a/base/coreir.jl +++ b/base/coreir.jl @@ -28,8 +28,8 @@ with `Int` values. - `fields` holds the lattice elements corresponding to each defined field of the object If `typ` is a struct, `undef` represents whether the corresponding field of the struct is guaranteed to be -initialized. For any defined field (`undef[i] === false`), there is a corresponding `fields` element -which provides information about the type of the defined field. +initialized. For any defined field, there is a corresponding `fields` element which provides information +about the type of the defined field. If `typ` is a tuple, the last element of `fields` may be `Vararg`. In this case, it is guaranteed that the number of elements in the tuple is at least `length(fields)-1`, but the @@ -38,41 +38,28 @@ exact number of elements is unknown (`undef` then has a length of `length(fields Core.PartialStruct function Core.PartialStruct(typ::Type, undef::BitVector, fields::Vector{Any}) - @assert length(undef) == length(fields) - (isvarargtype(fields[end])) - for i in 1:datatype_min_ninitialized(typ) - undef[i] = false - end + @assert length(undef) == length(fields) - isvarargtype(fields[end]) Core._PartialStruct(typ, undef, fields) end -function get_fieldcount(@nospecialize(t)) - if isa(t, UnionAll) || isa(t, Union) - t = argument_datatype(t) - t === nothing && return nothing - end - isa(t, DataType) || return nothing - return datatype_fieldcount(t) +function Core.PartialStruct(@nospecialize(typ), fields::Vector{Any}) + Core.PartialStruct(typ, partialstruct_init_undef(typ, fields), fields) end -function Core.PartialStruct(@nospecialize(typ), fields::Vector{Any}) - nf = length(fields) - fields[end] === Vararg && (nf -= 1) - nflds = get_fieldcount(typ) - nflds === nothing && (nflds = nf) - undef = trues(nflds) - - # The provided fields (in absence of an `undef` argument) - # are assumed to be defined. - for i in 1:nf - undef[i] = false - end +partialstruct_undef_length(fields) = length(fields) - isvarargtype(fields[end]) - # Make sure no field is missing. - if nflds > nf - fields = Any[get(fields, i, fieldtype(typ, i)) for i in 1:nflds] - end +function partialstruct_init_undef(@nospecialize(typ), fields; all_defined = true) + n = partialstruct_undef_length(fields) + partialstruct_init_undef(typ, n; all_defined) +end - Core.PartialStruct(typ, undef, fields) +function partialstruct_init_undef(@nospecialize(typ), n::Integer; all_defined = true) + all_defined && return falses(n) + undef = trues(n) + for i in 1:min(datatype_min_ninitialized(typ), n) + undef[i] = false + end + undef end (==)(a::PartialStruct, b::PartialStruct) = a.typ === b.typ && a.undef == b.undef && a.fields == b.fields From d771b94e270b0d9ff876ee35f5617dba04d528ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Thu, 20 Feb 2025 14:32:00 -0500 Subject: [PATCH 14/16] Remove unused utility --- Compiler/src/typelimits.jl | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/Compiler/src/typelimits.jl b/Compiler/src/typelimits.jl index a97e07f1d3022..ba0298b10aca6 100644 --- a/Compiler/src/typelimits.jl +++ b/Compiler/src/typelimits.jl @@ -377,20 +377,6 @@ function define_field(pstruct::PartialStruct, fi, @nospecialize(ft)) PartialStruct(fallback_lattice, typ, undef, fields) end -# needed while we are missing functions such as broadcasting or ranges - -function _bitvector(nt::NTuple) - bv = BitVector(undef, length(nt)) - i = 1 - while i ≤ length(nt) - bv[i] = nt[i] - i += 1 - end - bv -end - -#- - # A simplified type_more_complex query over the extended lattice # (assumes typeb ⊑ typea) @nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb)) From b20cf698ea9ac0abe1529e81c34439ecf8868699 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Thu, 20 Feb 2025 18:24:57 -0500 Subject: [PATCH 15/16] Refactor `define_field`, don't overwrite defined field type --- Compiler/src/abstractinterpretation.jl | 2 +- Compiler/src/typelimits.jl | 27 +++++++++++++++----------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/Compiler/src/abstractinterpretation.jl b/Compiler/src/abstractinterpretation.jl index 519795626cf4e..3039d17bd9ba1 100644 --- a/Compiler/src/abstractinterpretation.jl +++ b/Compiler/src/abstractinterpretation.jl @@ -2148,7 +2148,7 @@ function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name)) isabstracttype(objt) && return nothing fldidx = try_compute_fieldidx(objt, name.val) fldidx === nothing && return nothing - isa(obj, PartialStruct) && return define_field(obj, fldidx, fieldtype(objt0, fldidx)) + isa(obj, PartialStruct) && return define_field(obj, fldidx) nminfld = datatype_min_ninitialized(objt) fldidx > nminfld || return nothing undef = partialstruct_init_undef(objt, fldidx; all_defined = false) diff --git a/Compiler/src/typelimits.jl b/Compiler/src/typelimits.jl index ba0298b10aca6..2bad6bc7817aa 100644 --- a/Compiler/src/typelimits.jl +++ b/Compiler/src/typelimits.jl @@ -354,27 +354,32 @@ function refines_definedness_information(pstruct::PartialStruct) something(findfirst(pstruct.undef), nflds + 1) - 1 > datatype_min_ninitialized(pstruct.typ) end -function define_field(pstruct::PartialStruct, fi, @nospecialize(ft)) +function define_field(pstruct::PartialStruct, fi::Int) if is_field_initialized(pstruct, fi) # no new information to be gained return nothing end + new = expand_partialstruct(pstruct, fi) + if new === nothing + new = PartialStruct(fallback_lattice, pstruct.typ, copy(pstruct.undef), copy(pstruct.fields)) + end + new.undef[fi] = false + return new +end + +function expand_partialstruct(pstruct::PartialStruct, until::Int) n = length(pstruct.undef) - undef = partialstruct_init_undef(pstruct.typ, max(fi, n); all_defined = false) + until ≤ n && return nothing + + undef = partialstruct_init_undef(pstruct.typ, until; all_defined = false) for i in 1:n undef[i] &= pstruct.undef[i] end - undef[fi] = false - - fields = copy(pstruct.fields) - nf = length(fields) + nf = length(pstruct.fields) typ = pstruct.typ - for i in (nf + 1):fi - push!(fields, fieldtype(typ, i)) - end - fields[fi] = ft - PartialStruct(fallback_lattice, typ, undef, fields) + fields = Any[i ≤ nf ? pstruct.fields[i] : fieldtype(typ, i) for i in 1:until] + return PartialStruct(fallback_lattice, typ, undef, fields) end # A simplified type_more_complex query over the extended lattice From 32d3a893690bc5bcd7e334c15d6f661a02ed1891 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Thu, 20 Feb 2025 18:33:22 -0500 Subject: [PATCH 16/16] Minor polish --- Compiler/src/typelimits.jl | 4 ++-- base/coreir.jl | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Compiler/src/typelimits.jl b/Compiler/src/typelimits.jl index 2bad6bc7817aa..a30b630b70bd7 100644 --- a/Compiler/src/typelimits.jl +++ b/Compiler/src/typelimits.jl @@ -661,8 +661,8 @@ end nundef = nflds end nflds == 0 && return nothing + _undef = partialstruct_init_undef(aty, nundef; all_defined = false) fields = Vector{Any}(undef, nflds) - _undef = trues(nundef) fldmin = datatype_min_ninitialized(aty) n_initialized_merged = min(n_initialized(typea::Union{Const, PartialStruct}), n_initialized(typeb::Union{Const, PartialStruct})) anyrefine = n_initialized_merged > fldmin @@ -703,7 +703,7 @@ end if !anyrefine anyrefine = has_nontrivial_extended_info(𝕃, tyi) || # extended information ⋤(𝕃, tyi, ft) || # just a type-level information, but more precise than the declared type - !get(_undef, i, true) && i > fldmin # possibly initialized field is known to be initialized + !get(_undef, i, true) && i > fldmin # possibly uninitialized field is known to be initialized end end anyrefine && return PartialStruct(𝕃, aty, _undef, fields) diff --git a/base/coreir.jl b/base/coreir.jl index a0a164b4cfc2f..f36617be5fba4 100644 --- a/base/coreir.jl +++ b/base/coreir.jl @@ -39,18 +39,18 @@ Core.PartialStruct function Core.PartialStruct(typ::Type, undef::BitVector, fields::Vector{Any}) @assert length(undef) == length(fields) - isvarargtype(fields[end]) - Core._PartialStruct(typ, undef, fields) + return Core._PartialStruct(typ, undef, fields) end function Core.PartialStruct(@nospecialize(typ), fields::Vector{Any}) - Core.PartialStruct(typ, partialstruct_init_undef(typ, fields), fields) + return Core.PartialStruct(typ, partialstruct_init_undef(typ, fields), fields) end partialstruct_undef_length(fields) = length(fields) - isvarargtype(fields[end]) function partialstruct_init_undef(@nospecialize(typ), fields; all_defined = true) n = partialstruct_undef_length(fields) - partialstruct_init_undef(typ, n; all_defined) + return partialstruct_init_undef(typ, n; all_defined) end function partialstruct_init_undef(@nospecialize(typ), n::Integer; all_defined = true) @@ -59,14 +59,14 @@ function partialstruct_init_undef(@nospecialize(typ), n::Integer; all_defined = for i in 1:min(datatype_min_ninitialized(typ), n) undef[i] = false end - undef + return undef end (==)(a::PartialStruct, b::PartialStruct) = a.typ === b.typ && a.undef == b.undef && a.fields == b.fields function Base.getproperty(pstruct::Core.PartialStruct, name::Symbol) name === :undef && return getfield(pstruct, :undef)::BitVector - getfield(pstruct, name) + return getfield(pstruct, name) end """