Skip to content

Commit

Permalink
update Embedding constructor
Browse files Browse the repository at this point in the history
Updated Embedding constructor to use `=>` and added OneHotLikeVector and OneHotLikeMatrix consts.
  • Loading branch information
manikyabard committed Jul 12, 2021
1 parent 0ae379e commit 4532ec6
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ function Base.show(io::IO, m::Parallel)
end

"""
Embedding(in, out; init=randn)
Embedding(in => out; init=randn)
A lookup table that stores embeddings of dimension `out`
for a vocabulary of size `in`.
Expand Down Expand Up @@ -409,9 +409,9 @@ end

@functor Embedding

function Embedding(in::Integer, out::Integer;
function Embedding(dims::Pair{<:Integer, <:Integer};
init = (i...) -> randn(Float32, i...))
return Embedding(init(out, in))
return Embedding(init(last(dims), first(dims)))
end

(m::Embedding)(x::Union{OneHotLikeVector, OneHotLikeMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
Expand Down
3 changes: 3 additions & 0 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ const OneHotLike{T, L, N, var"N+1", I} =
Union{OneHotArray{T, L, N, var"N+1", I},
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}

const OneHotLikeVector{T, L} = OneHotLike{T, L, 0, 1, T}
const OneHotLikeMatrix{T, L, I} = OneHotLike{T, L, 1, 2, I}

_isonehot(x::OneHotArray) = true
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)

Expand Down
2 changes: 1 addition & 1 deletion test/cuda/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ end

@testset "Embedding" begin
vocab_size, embed_size = 10, 4
m = Embedding(vocab_size, embed_size)
m = Embedding(vocab_size => embed_size)
x = rand(1:vocab_size, 3)
y = m(x)
m_g = m |> gpu
Expand Down
2 changes: 1 addition & 1 deletion test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ import Flux: activations

@testset "Embedding" begin
vocab_size, embed_size = 10, 4
m = Embedding(vocab_size, embed_size)
m = Embedding(vocab_size => embed_size)
@test size(m.weight) == (embed_size, vocab_size)

x = rand(1:vocab_size, 3)
Expand Down

0 comments on commit 4532ec6

Please sign in to comment.