Skip to content

Commit

Permalink
use randn32
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 12, 2021
1 parent 058a4a0 commit dfb390d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
7 changes: 2 additions & 5 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,8 @@ end

@functor Embedding

function Embedding(in::Integer, out::Integer;
init = (i...) -> randn(Float32, i...))
return Embedding(init(out, in))
end

Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))

(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
(m::Embedding)(x::Integer) = m([x])
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
Expand Down
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,8 @@ identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identit

ones32(dims...) = Base.ones(Float32, dims...)
zeros32(dims...) = Base.zeros(Float32, dims...)
rand32(dims...) = Base.rand(Float32, dims...)
randn32(dims...) = Base.randn(Float32, dims...)

"""
create_bias(weights, bias, length)
Expand Down

0 comments on commit dfb390d

Please sign in to comment.