Added SmeLU #263
Conversation
Excellent, thanks @kashif! A couple of comments, I hope that helps; there's a trick with the definition of beta, I think.
Right, I don't think the beta default arg would work as it is... I was just about to ask you how to deal with that?
xformers/triton/k_activations.py (outdated)

@triton.jit
def smelu(x, beta=2.0):
I don't think you can pass a default param with Triton actually; it only works with a subset of the Python syntax, and my guess is that this is out of it (cc @ptillet). Something worth trying could be a getter for this kernel, like the following:
def get_smelu_kernel(beta: float = 2.0):
    @triton.jit
    def smelu(x):
        pass  # use beta here, but maybe this will fail at the JIT phase
If that does not work:
- for a start we could have a fixed beta, then iterate on the implementation to expose it (completely fine by me)
- it could be that the activation kernel takes another parameter, which in that case would be the beta value (rough sketch below), or that we figure out with Phil how to generate the kernel code on the fly with the proper beta
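To illustrate that second option, something along these lines (an untested sketch, hypothetical signature, with beta passed in at call time from the calling kernel):

import triton
import triton.language as tl

@triton.jit
def smelu(x, beta):
    """
    SmeLU_ activation - Smooth ReLU, with beta supplied by the caller
    .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
    """
    # x for x >= beta, 0 below
    relu = tl.where(x >= beta, x, 0.0)
    # quadratic branch used in the transition region
    quad = (x + beta) * (x + beta) / (4.0 * beta)
    return tl.where(tl.abs(x) <= beta, quad, relu)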
Thanks @blefaudeux, I'll give it a try... it's a bit late here, so I wanted to give it a shot in the morning 😴
Checking this with Philippe, the default value should actually work; maybe it needs to be annotated as ": float" or similar?
ah ok! cool let me check
To be clear, what should work (for now) is default arguments for tl.constexpr-annotated arguments, and with Triton 2.0 :p I'm not too sure about Triton 1.x
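Something along these lines, I think (untested sketch, Triton 2.0 only, names are illustrative):

import triton
import triton.language as tl

@triton.jit
def smelu(x, beta: tl.constexpr = 2.0):
    # beta is a compile-time constant here, so the default value
    # should be folded in at JIT time rather than passed at runtime
    relu = tl.where(x >= beta, x, 0.0)
    quad = (x + beta) * (x + beta) / (4.0 * beta)
    return tl.where(tl.abs(x) <= beta, quad, relu)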
Ah right... I'm on Triton 1.x at the moment...
We need to update to Triton 2; CI is blocking right now, I hope to get that sorted out this weekend.
ah cool! having a look
hey @kashif, let me know if I can help
So I tried

def get_smelu_kernel(x, beta: float = 2.0):
    @triton.jit
    def smelu(x, beta):
        """
        SmeLU_ activation - Smooth ReLU
        .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
        """
        zero = 0.0
        four = 4.0
        beta = beta.to(x.dtype)
        output = (x + beta) * (x + beta) / (four.to(x.dtype) * beta)
        relu = tl.where(x >= beta, x, zero.to(x.dtype))
        return tl.where(tl.abs(x) <= beta, output, relu)

    smelu(x, beta)

but that didn't work either, so I am setting the beta param to 2.0 for now.
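For reference, the piecewise form the code above computes (per the SmeLU paper linked in the docstring), with beta fixed to 2.0 for now:

SmeLU(x) = 0                          if x <= -beta
SmeLU(x) = (x + beta)^2 / (4 * beta)  if -beta <= x <= beta
SmeLU(x) = x                          if x >= beta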
If you can, I would really recommend setting up pre-commit; it helps with all the linting. Some explanations here.
Thanks for the updates @kashif, looks good to me; we can always iterate to expose beta down the line. It looks like the errors are unrelated to your changes, maybe due to main having changed. Could you try to rebase? I can also do that if you'd like.
LGTM, we need to solve these CI issues (my guess is that it's waiting for a rebase), thanks a lot @kashif!
Thank you! I learned a lot, and now I have a 3090 Ti to test on!
I was just heading out to eat... can you kindly rebase?
Hmm, thanks for the update Kashif, looks like the errors are still there; I don't understand how they can be related. I'll have a look.
Hmm, I can repro the CI error, but it should be unrelated to your changes @kashif; it means there's something wrong either in the Triton stack or in the CUDA kernels :( I'm trying to sort that out.
ok, at least I got something wrong: @fmassa, if I run
Note that it happens without fairscale (so without the MixtureOfExperts), same error, on "global" again.
I'm looking at the issue.
This should be fixed with #300. Looks like some of the configurations in the test are generating a fully-empty (all-zeros) matrix. Might be good to have a look to see if this is intended.
For the record, here are my numbers on a desktop 3080 with an incoming fused linear PR (/main numbers should be somewhat close):
--- Type: torch.float16 --- (benchmark table not reproduced here)
What does this PR do?
Fixes #262.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.