Skip to content

Commit

Permalink
Slightly fancier test for allreduce
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt committed Dec 4, 2019
1 parent d9b4e16 commit bb10cac
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ using Test
@test device(comms[i]) == i-1
@test size(comms[i]) == length(CUDAdrv.devices())
end
id = UniqueID()
#=num_devs = length(CUDAdrv.devices())
comm = Communicator(num_devs, id, 0)
@test device(comm) == 0=#
#id = UniqueID()
#num_devs = length(CUDAdrv.devices())
#comm = Communicator(num_devs, id, 0)
#@test device(comm) == 0
end
@testset "Allreduce!" begin
devs = CUDAdrv.devices()
Expand All @@ -37,6 +37,32 @@ using Test
crecv = collect(recvbuf[ii])
@test all(crecv .== answer)
end
# more complex example?
recvbuf = Vector{CuMatrix{Float64}}(undef, length(devs))
sendbuf = Vector{CuMatrix{Float64}}(undef, length(devs))
streams = Vector{CuStream}(undef, length(devs))
m = 256
k = 512
n = 256
As = [rand(m, k) for i in 1:length(devs)]
Bs = [rand(k, n) for i in 1:length(devs)]
C = sum(As .* Bs)
for (ii, dev) in enumerate(devs)
CUDAnative.device!(ii - 1)
sendbuf[ii] = cu(As[ii]) * cu(Bs[ii])
recvbuf[ii] = CuArrays.zeros(Float64, m, n)
streams[ii] = CuStream()
end
groupStart()
for ii in 1:length(devs)
Allreduce!(sendbuf[ii], recvbuf[ii], m*n, NCCL.ncclSum, comms[ii], stream=streams[ii])
end
groupEnd()
for (ii, dev) in enumerate(devs)
device!(ii - 1)
crecv = collect(recvbuf[ii])
@test crecv C rtol=1e-6
end
end
@testset "Broadcast!" begin
devs = CUDAdrv.devices()
Expand Down

0 comments on commit bb10cac

Please sign in to comment.