
FastGelu float16 #621

Merged: 13 commits merged into main from rashuai/MFloat16 on Dec 11, 2023

Conversation

@RandySheriffH (Contributor) commented on Dec 7, 2023:

Add float16 support for contrib cuda ops.

@RandySheriffH RandySheriffH marked this pull request as ready for review December 9, 2023 05:39
@RandySheriffH RandySheriffH requested a review from a team as a code owner December 9, 2023 05:39
#include "onnxruntime_c_api.h"
#if ORT_API_VERSION >= 16

#include "onnxruntime_float16.h"
A reviewer (Member) commented:

Is this file shipped with the ORT C++ package?

@RandySheriffH (Contributor, Author) replied:

yes it is, since 1.16.

const T* bias_data = bias.has_value() ? (*bias)->Data() : nullptr;
auto bias_length = bias.has_value() ? (*bias)->NumberOfElement() : 0;
using TT = typename CudaT<T>::MappedType;
LaunchFastGeluKernel<TT>(reinterpret_cast<cudaStream_t>(ctx.cuda_stream),
A reviewer (Member) commented:

should the return error code be handled here?

@RandySheriffH (Contributor, Author) replied:

Good point, will report it in a coming iteration.

]

input0 = helper.make_tensor_value_info(
    'x', onnx_proto.TensorProto.FLOAT16, [])
A reviewer (Member) commented:

Dumb question: which fp16 was tested here, MFloat16 or BFloat16?

@RandySheriffH (Contributor, Author) replied on Dec 11, 2023:

It is MFloat16.
For BFloat16, we need to test it with native test cases, since the type is not exposed via Python.

@RandySheriffH RandySheriffH merged commit 1ccc405 into main Dec 11, 2023
@RandySheriffH RandySheriffH deleted the rashuai/MFloat16 branch December 11, 2023 22:31
3 participants