Skip to content

Commit

Permalink
update Embedding constructor
Browse files Browse the repository at this point in the history
Updated Embedding constructor to use `=>` and added OneHotLikeVector and OneHotLikeMatrix consts.
  • Loading branch information
manikyabard committed Jul 12, 2021
1 parent 1c61da8 commit b7c588a
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ function Base.show(io::IO, m::Parallel)
end

"""
Embedding(in, out; init=randn)
Embedding(in => out; init=randn)
A lookup table that stores embeddings of dimension `out`
for a vocabulary of size `in`.
Expand Down Expand Up @@ -466,7 +466,7 @@ end

@functor Embedding

Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(last(dims), first(dims)))

(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
(m::Embedding)(x::Integer) = m([x])
Expand Down
3 changes: 3 additions & 0 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ const OneHotLike{T, L, N, var"N+1", I} =
Union{OneHotArray{T, L, N, var"N+1", I},
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}

const OneHotLikeVector{T, L} = OneHotLike{T, L, 0, 1, T}
const OneHotLikeMatrix{T, L, I} = OneHotLike{T, L, 1, 2, I}

_isonehot(x::OneHotArray) = true
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)

Expand Down
4 changes: 2 additions & 2 deletions test/cuda/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ end

@testset "Embedding" begin
vocab_size, embed_size = 5, 2
m = Flux.Embedding(vocab_size, embed_size)

m = Flux.Embedding(vocab_size => embed_size)
x = [1, 3, 5]
y = m(x)
m_g = m |> gpu
Expand Down
2 changes: 1 addition & 1 deletion test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ import Flux: activations

@testset "Embedding" begin
vocab_size, embed_size = 10, 4
m = Flux.Embedding(vocab_size, embed_size)
m = Flux.Embedding(vocab_size => embed_size)
@test size(m.weight) == (embed_size, vocab_size)

x = rand(1:vocab_size, 3)
Expand Down

0 comments on commit b7c588a

Please sign in to comment.