Skip to content

Commit

Permalink
fix ctc_gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 10, 2021
1 parent f28343f commit 0be6765
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
13 changes: 7 additions & 6 deletions src/losses/ctc-gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,15 @@ end
function ctc_alpha(ŷ::CuArray, y)
= logsoftmax(ŷ)
blank = size(ŷ, 1)
z′ = fill(blank, 2 * length(y) + 1)
z′[eachindex(y) .* 2] = y
ycu = cu(y)
z′ = CUDA.fill(blank, 2 * length(y) + 1)
z′[eachindex(y) .* 2] .= ycu
T = size(ŷ, 2)
U′ = 2*length(y) + 1
alphas = CUDA.fill(log(zero(ŷ[1])), U′,T)
nRepeats = count_repeats(y)
alphas = CUDA.fill(log(zero(eltype(ŷ))), U′,T)
nRepeats = count_repeats(cpu(y))
nThreads = min(U′, MAX_THREADS)
@cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, CuArray(y), CuArray(z′), alphas, blank)
@cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, ycu, z′, alphas, blank)
return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)
end

Expand All @@ -221,7 +222,7 @@ function ∇ctc_loss(ŷ::CuArray, y, out)
loss, alphas, z′, ŷ, nRepeats = out
U′, T = size(alphas)
blank = size(ŷ, 1)
typed_zero = zero(first(ŷ))
typed_zero = zero(eltype(ŷ))
betas = CUDA.fill(log(typed_zero), U′, T)
output = CUDA.fill(log(typed_zero), U′, T)
nThreads = min(U′, MAX_THREADS)
Expand Down
1 change: 1 addition & 0 deletions test/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
@test outputsize(m, (10,); padbatch=true) == (2, 1)
@test outputsize(m, (10, 30)) == (2, 30)

@info "Don't mind the following error, it's for testing purpose."
m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2))
@test_throws DimensionMismatch outputsize(m, (10,))

Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Flux
using Flux.Data
using Flux: OneHotArray, OneHotMatrix, OneHotVector
using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle
Expand Down Expand Up @@ -50,7 +51,7 @@ end
end
end

@static if VERSION == v"1.5"
@static if VERSION == v"1.6"
using Documenter
@testset "Docs" begin
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
Expand Down

0 comments on commit 0be6765

Please sign in to comment.