Skip to content

Commit

Permalink
update Embedding constructor
Browse files Browse the repository at this point in the history
Updated Embedding constructor to use `=>` and added OneHotLikeVector and OneHotLikeMatrix consts.
  • Loading branch information
manikyabard committed Jul 14, 2021
1 parent 6a7688a commit 0eb7ed9
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ function Base.show(io::IO, m::Parallel)
end

"""
Embedding(in, out; init=randn)
Embedding(in => out; init=randn)
A lookup table that stores embeddings of dimension `out`
for a vocabulary of size `in`.
Expand Down Expand Up @@ -466,7 +466,7 @@ end

@functor Embedding

Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
# Convenience constructor: `Embedding(vocab_size => embed_size)`.
# Builds the weight matrix as `init(embed_size, vocab_size)` (default `randn32`),
# i.e. one column per vocabulary entry.
function Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32)
    vocab_size, embed_size = first(dims), last(dims)
    return Embedding(init(embed_size, vocab_size))
end


# Applying the layer to a single integer index returns the corresponding
# embedding vector (a column slice of the weight matrix).
function (layer::Embedding)(token::Integer)
    return layer.weight[:, token]
end
Expand Down
3 changes: 3 additions & 0 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ const OneHotLike{T, L, N, var"N+1", I} =
Union{OneHotArray{T, L, N, var"N+1", I},
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}

# Vector/matrix specializations of `OneHotLike`, fixing the dimension
# parameters (`N`, `N+1`) of the union: the vector case has no trailing
# dims (0, 1) and stores its index as a single `T`; the matrix case has
# one trailing dim (1, 2) with index-container type `I`. Like `OneHotLike`,
# these also match `Base.ReshapedArray` wrappers around a one-hot array.
const OneHotLikeVector{T, L} = OneHotLike{T, L, 0, 1, T}
const OneHotLikeMatrix{T, L, I} = OneHotLike{T, L, 1, 2, I}

# A genuine `OneHotArray` is one-hot by construction; a reshaped wrapper
# only still qualifies when its first dimension equals the label count `L`
# (i.e. the reshape kept the one-hot axis intact).
_isonehot(::OneHotArray) = true
_isonehot(r::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where {L} = size(r, 1) == L

Expand Down
14 changes: 7 additions & 7 deletions test/cuda/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)

embedding = [Flux.Embedding]  # layer constructor(s) handed to gpu_gradtest below
gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)
# GPU gradient tests for Embedding using the `vocab_size => embed_size`
# constructor form. Index inputs cover: integer vector (with and without
# repeats), single integer, 2-d integer matrix, and one-hot vector/matrix
# (with and without repeated indices).
gpu_gradtest("Embedding", embedding, [1,3,5], 5 => 2)
gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5 => 2)
gpu_gradtest("Embedding integer index", embedding, 1, 5 => 2)
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5 => 2)
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5 => 2)
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5 => 2)
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5 => 2)

@testset "function layers" begin
x = rand(Float32, 3,3)
Expand Down
2 changes: 1 addition & 1 deletion test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ import Flux: activations

@testset "Embedding" begin
vocab_size, embed_size = 10, 4
m = Flux.Embedding(vocab_size, embed_size)
m = Flux.Embedding(vocab_size => embed_size)
@test size(m.weight) == (embed_size, vocab_size)

x = rand(1:vocab_size, 3)
Expand Down

0 comments on commit 0eb7ed9

Please sign in to comment.