Skip to content

Commit

Permalink
Elide bounds checks when kernels contains manual ones. (#2621)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Jan 16, 2025
1 parent 6ef1a3d commit d07a245
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
17 changes: 14 additions & 3 deletions src/device/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Base.size(g::CuDeviceArray) = g.dims
Base.sizeof(x::CuDeviceArray) = Base.elsize(x) * length(x)

# we store the array length too; computing prod(size) is expensive
Base.size(g::CuDeviceArray{<:Any,1}) = (g.len,)
Base.length(g::CuDeviceArray) = g.len

Base.pointer(x::CuDeviceArray{T,<:Any,A}) where {T,A} = Base.unsafe_convert(LLVMPtr{T,A}, x)
Expand Down Expand Up @@ -78,7 +79,11 @@ Base.unsafe_convert(::Type{LLVMPtr{T,A}}, x::CuDeviceArray{T,<:Any,A}) where {T,
end

@device_function @inline function arrayref(A::CuDeviceArray{T}, index::Integer) where {T}
@boundscheck checkbounds(A, index)
# simplified bounds check to avoid the OneTo construction, which calls `max`
# and breaks elimination of redundant bounds checks in the generated code.
#@boundscheck checkbounds(A, index)
@boundscheck index <= length(A) || Base.throw_boundserror(A, index)

if Base.isbitsunion(T)
arrayref_union(A, index)
else
Expand Down Expand Up @@ -120,7 +125,10 @@ end
end

@device_function @inline function arrayset(A::CuDeviceArray{T}, x::T, index::Integer) where {T}
@boundscheck checkbounds(A, index)
# simplified bounds check (see `arrayref`)
#@boundscheck checkbounds(A, index)
@boundscheck index <= length(A) || Base.throw_boundserror(A, index)

if Base.isbitsunion(T)
arrayset_union(A, x, index)
else
Expand Down Expand Up @@ -151,7 +159,10 @@ end
end

@device_function @inline function const_arrayref(A::CuDeviceArray{T}, index::Integer) where {T}
@boundscheck checkbounds(A, index)
# simplified bounds check (see `arrayset`)
#@boundscheck checkbounds(A, index)
@boundscheck index <= length(A) || Base.throw_boundserror(A, index)

align = alignment(A)
unsafe_cached_load(pointer(A), index, Val(align))
end
Expand Down
1 change: 0 additions & 1 deletion test/base/exceptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ let (proc, out, err) = julia_exec(`-g2 -e $script`)
@test count(device_error_re, out) == 1
@test count("BoundsError", out) == 1
@test count("Out-of-bounds array access", out) == 1
@test occursin("] checkbounds at $(joinpath(".", "abstractarray.jl"))", out)
@test occursin("] kernel at $(joinpath(".", "none"))", out)
end

Expand Down
18 changes: 18 additions & 0 deletions test/core/device/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,24 @@ end
@test !occursin("jl_invoke", ir)
CUDA.code_ptx(devnull, kernel, tt)
end

# test that we don't do needless bounds checking when the kernel already does it
# (enabled by the fact that we store `len` next to `dims`)
let
function kernel(A)
idx = threadIdx().x
if idx <= length(A)
# we did our own bounds checking, so no check should be left!
A[idx] = 1
end
return
end

for N in 1:3
ir = sprint(io->CUDA.code_llvm(io, kernel, Tuple{CuDeviceArray{Int,N,AS.Global}}))
@test !occursin("boundserror", ir)
end
end
end

@testset "views" begin
Expand Down

0 comments on commit d07a245

Please sign in to comment.