From 2398dfce1556b0811ddd598ea25335b2c3d14b99 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 23 Feb 2021 23:09:04 +0100 Subject: [PATCH] fix bug --- src/layers/basic.jl | 7 ++++--- src/layers/stateless.jl | 14 +------------- test/cuda/layers.jl | 3 --- test/layers/basic.jl | 6 +++++- 4 files changed, 10 insertions(+), 20 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 9cb5c58c92..c8d7f6f677 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -120,12 +120,13 @@ end @functor Dense -function (a::Dense)(x::Union{AbstractVector, AbstractMatrix}) +function (a::Dense)(x::AbstractVecOrMat) W, b, σ = a.W, a.b, a.σ return σ.(W*x .+ b) end -(a::Dense)(x::AbstractArray) = reshape(a(mat(x)), :, size(x)[2:end]...) +(a::Dense)(x::AbstractArray) = + reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...) function Base.show(io::IO, l::Dense) print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1)) @@ -418,7 +419,7 @@ end (m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)] (m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x] -(m::Embedding)(x::AbstractArray) = reshape(m(mat(x)), :, size(x)[2:end]...) +(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...) function Base.show(io::IO, m::Embedding) print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))") diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index c02cbccad2..1a3a0df5ec 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -4,7 +4,7 @@ Reshape arbitrarly-shaped input into a matrix-shaped output, preserving the size of the last dimension. -See also [`unsqueeze`](@ref) and [`mat`](@ref). +See also [`unsqueeze`](@ref). # Examples ```jldoctest @@ -26,18 +26,6 @@ function flatten(x::AbstractArray) return reshape(x, :, size(x)[end]) end -""" - mat(x::AbstractArray) - -Reshape arbitrarly-shaped input into a matrix-shaped output, -preserving the size of the first dimension. - -See also [`flatten`](@ref) and [`unsqueeze`](@ref). -""" -function mat(x::AbstractArray) - return reshape(x, size(x,1), :) -end - """ normalise(x; dims=ndims(x), ϵ=1e-5) diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl index 0019b37a97..379ae35864 100644 --- a/test/cuda/layers.jl +++ b/test/cuda/layers.jl @@ -87,9 +87,6 @@ pixelshuffle = [PixelShuffle] gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3) gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3) -embedding = [Embedding] -gpu_gradtest("Embedding", embedding, rand(1:10, 3), 10, 4) - @testset "function layers" begin x = rand(Float32, 3,3) gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 184bb8796b..a43583b8df 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -163,13 +163,17 @@ import Flux: activations y = m(x) @test y isa Matrix{Float32} @test y ≈ m.weight[:,x] - x2 = OneHotMatrix(x, vocab_size) y2 = m(x2) @test y2 isa Matrix{Float32} @test y2 ≈ y @test_throws DimensionMismatch m(OneHotMatrix(x, 1000)) + x = rand(1:vocab_size, 3, 4) + y = m(x) + @test y isa Array{Float32, 3} + @test size(y) == (embed_size, 3, 4) + @test m(2) ≈ m.weight[:,2] @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3] @test_throws DimensionMismatch m(OneHotVector(3, 1000))