-
-
Notifications
You must be signed in to change notification settings - Fork 609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add Embedding layer #1516
add Embedding layer #1516
Conversation
Needs GPU tests |
For such a straightforward routine, do we need a layer? I guess people are used to seeing them elsewhere, but why should that make it a necessary thing. Might just make it a necessary evil. Oh well. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems simple, makes sense to add.
Yeah, I don't think we want to take the charge on |
Since I need to define and use |
|
please take care of the rest of the comments in the meantime? I don't see how mat was something needed to be defined. The reshape api is pretty clever so as to be self documenting, but I can deal with that later. |
don't see anything to take care of
it is defined because it is convenient, same as flatten, when a code pattern is used in multiple spots it is typically factorized into a function, this is how people write programs |
These ones. |
I don't know what you refer to here
I asked to keep the change here |
@mcabbott given the aliasing problem, I think that the thing to do here is to restrict the input to cpu only (i.e. error out on CuArrays) for the time being so that we can merge this safely. When we have the gather/scatter infrastructure in place in NNlib.jl and CUDA.jl we switch the backend here. |
Is there a reason to favor the reliance on ScatterNNLib for the embedding operations rather than having a fix for the incorrect gradient when indices have repeats? It's my understanding that JuliaLang/julia#31407 was to address it, but it looks stalled. Any idea how to raise awareness on this issue would be welcome! |
Yes, it would be good to fix FluxML/Zygote.jl#821. |
this can now be implemented on top of the scatter/gather functions implemented in NNlib |
This layer needs to special case (like the normalization layers) in Something like (m::Embedding)(x::AbstractVector{<:Nil}) = fill(nil, size(m.weight, 1), length(x))
(m::Embedding)(x::AbstractArray{<:Nil}) = fill(nil, size(m.weight, 1), last(size(x))) |
2670684
to
4d3944c
Compare
gpu failure seems unrelated |
@DhairyaLGandhi can you dismiss your change request? |
This PR was included in a batch that successfully built, but then failed to merge into master. It will not be retried. Additional information: {"message":"1 review requesting changes and 1 approving review by reviewers with write access.","documentation_url":"https://docs.github.com/articles/about-protected-branches"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've left a couple thoughts.
There seem to be a lot of changes around code not intended to be changed in this PR (around the field names of Dense
) which should go in a different PR.
Also needs GPU tests with one-hot and integers. |
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
Co-authored-by: Manikya <manikyabard@gmail.com>
Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
My comment pointed out adding tests for the features this PR wants to bring in. #1656 aims to build on top of this, and has its own opinions about API, not implementation. |
need an approval again |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bors r+
Build succeeded: |
Basic implementation.
Maybe could be improved when FluxML/NNlib.jl#255 lands
PR Checklist
@dhairyagandhi96
(for API changes).