Commit 2398dfc

fix bug

CarloLucibello committed Feb 23, 2021
1 parent 49e42a7 commit 2398dfc
Showing 4 changed files with 10 additions and 20 deletions.
7 changes: 4 additions & 3 deletions src/layers/basic.jl
@@ -120,12 +120,13 @@ end
 
 @functor Dense
 
-function (a::Dense)(x::Union{AbstractVector, AbstractMatrix})
+function (a::Dense)(x::AbstractVecOrMat)
   W, b, σ = a.W, a.b, a.σ
   return σ.(W*x .+ b)
 end
 
-(a::Dense)(x::AbstractArray) = reshape(a(mat(x)), :, size(x)[2:end]...)
+(a::Dense)(x::AbstractArray) =
+  reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
 
 function Base.show(io::IO, l::Dense)
   print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
@@ -418,7 +419,7 @@ end
 
 (m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
 (m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x]
-(m::Embedding)(x::AbstractArray) = reshape(m(mat(x)), :, size(x)[2:end]...)
+(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
 
 function Base.show(io::IO, m::Embedding)
   print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
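This is the bug the commit message refers to: for a higher-rank array of indices, the old `mat(x)` call produced a *matrix* of integers, which matches neither the `Int`/`AbstractVector` method nor the one-hot method, so the call would dispatch straight back to the same `AbstractArray` method and recurse without progress. The new method flattens the indices with `vec`, looks them up as a vector, and keeps the full index shape in the trailing output dimensions. A sketch, assuming the `Embedding(vocab, embed)` constructor implied by the `show` method above:

```julia
using Flux

vocab_size, embed_size = 10, 4
m = Flux.Embedding(vocab_size, embed_size)  # weight is 4×10

x = rand(1:vocab_size, 3, 4)  # an arbitrary-rank array of token indices

# equivalent to the new method body:
y = reshape(m(vec(x)), :, size(x)...)
size(y)  # (4, 3, 4): one embedding column per index, index shape preserved
```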
14 changes: 1 addition & 13 deletions src/layers/stateless.jl
@@ -4,7 +4,7 @@
 
 Reshape arbitrarily-shaped input into a matrix-shaped output,
 preserving the size of the last dimension.
-See also [`unsqueeze`](@ref) and [`mat`](@ref).
+See also [`unsqueeze`](@ref).
 
 # Examples
 ```jldoctest
@@ -26,18 +26,6 @@ function flatten(x::AbstractArray)
   return reshape(x, :, size(x)[end])
 end
 
-"""
-    mat(x::AbstractArray)
-
-Reshape arbitrarily-shaped input into a matrix-shaped output,
-preserving the size of the first dimension.
-
-See also [`flatten`](@ref) and [`unsqueeze`](@ref).
-"""
-function mat(x::AbstractArray)
-  return reshape(x, size(x,1), :)
-end
-
 """
     normalise(x; dims=ndims(x), ϵ=1e-5)
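The removed `mat` was a one-line reshape that kept the first dimension, the mirror image of `flatten`, which keeps the last. Both of its call sites (`Dense` and `Embedding` in `basic.jl`) now inline or avoid the reshape, so the helper is deleted along with its docstring. For comparison, a small sketch with hypothetical sizes:

```julia
using Flux

x = rand(Float32, 2, 3, 4)

size(Flux.flatten(x))            # (6, 4)  – last dimension preserved
size(reshape(x, size(x, 1), :))  # (2, 12) – first dimension preserved (old mat)
```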
3 changes: 0 additions & 3 deletions test/cuda/layers.jl
@@ -87,9 +87,6 @@ pixelshuffle = [PixelShuffle]
 gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
 gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
 
-embedding = [Embedding]
-gpu_gradtest("Embedding", embedding, rand(1:10, 3), 10, 4)
-
 @testset "function layers" begin
   x = rand(Float32, 3,3)
   gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
6 changes: 5 additions & 1 deletion test/layers/basic.jl
@@ -163,13 +163,17 @@ import Flux: activations
     y = m(x)
     @test y isa Matrix{Float32}
     @test y ≈ m.weight[:,x]
-
     x2 = OneHotMatrix(x, vocab_size)
     y2 = m(x2)
     @test y2 isa Matrix{Float32}
     @test y2 ≈ y
     @test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
 
+    x = rand(1:vocab_size, 3, 4)
+    y = m(x)
+    @test y isa Array{Float32, 3}
+    @test size(y) == (embed_size, 3, 4)
+
     @test m(2) ≈ m.weight[:,2]
     @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
     @test_throws DimensionMismatch m(OneHotVector(3, 1000))
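The added assertions exercise exactly the fixed code path: a 3×4 index array should come back as an `embed_size`×3×4 array. The surrounding `≈` checks rely on the equivalence noted in `basic.jl`, namely that one-hot multiplication and integer indexing select the same weight columns. A standalone sketch of that equivalence (names and sizes are illustrative):

```julia
using Flux

vocab_size, embed_size = 10, 4
m = Flux.Embedding(vocab_size, embed_size)

x = rand(1:vocab_size, 3)
m(Flux.onehotbatch(x, 1:vocab_size)) ≈ m(x)  # matmul vs. column indexing
```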
