Skip to content

Commit

Permalink
add outputsize special case for NNlib.gather
Browse files Browse the repository at this point in the history
  • Loading branch information
manikyabard committed Feb 13, 2022
1 parent 383bc02 commit de43bf5
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(las


(m::Embedding)(x::Integer) = m.weight[:, x]
(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
(m::Embedding)(x::AbstractVecOrMat{Bool}) = m.weight * x # handles OneHotLikeVector, OneHotLikeMatrix
Expand Down
2 changes: 1 addition & 1 deletion src/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,4 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
end
end

(m::Embedding)(x::AbstractVecOrMat{<:Nil}) = fill(nil, size(m.weight, 1), length(x))
NNlib.gather!(dst::AbstractArray, ::AbstractArray, ::AbstractArray{<:Nil}) = fill(nil, size(dst)...)

0 comments on commit de43bf5

Please sign in to comment.