FastGelu float16 #621
Merged (13 commits) on Dec 11, 2023
includes/custom_op_lite.h (173 additions, 0 deletions)

@@ -3,6 +3,7 @@

#pragma once
#include "onnxruntime_customop.hpp"
#include "onnxruntime_f16.h"
#include <optional>
#include <numeric>

@@ -83,6 +84,40 @@ struct Span {
const T* data() const { return data_; }
};

#if ORT_API_VERSION >= 16

template <>
struct Span<MFloat16> {
const MFloat16* data_ = {};
size_t size_ = {};
void Assign(const MFloat16* data, size_t size) {
data_ = data;
size_ = size;
}
size_t size() const { return size_; }
MFloat16 operator[](size_t indice) const {
return data_[indice];
}
const MFloat16* data() const { return data_; }
};

template <>
struct Span<BFloat16> {
const BFloat16* data_ = {};
size_t size_ = {};
void Assign(const BFloat16* data, size_t size) {
data_ = data;
size_ = size;
}
size_t size() const { return size_; }
BFloat16 operator[](size_t indice) const {
return data_[indice];
}
const BFloat16* data() const { return data_; }
};

#endif
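
For illustration, here is how half-precision elements read out of one of these spans might be consumed. This is a hypothetical helper, not part of the PR; it assumes MFloat16 exposes a ToFloat() conversion, mirroring ONNX Runtime's MLFloat16.

// Hypothetical: sum a half-precision span, accumulating in float to avoid
// fp16 rounding error. Assumes MFloat16::ToFloat() exists (as in ORT's MLFloat16).
inline float SumSpan(const Span<MFloat16>& span) {
  float acc = 0.0f;
  for (size_t i = 0; i < span.size(); ++i) {
    acc += span[i].ToFloat();
  }
  return acc;
}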

template <typename T>
class Tensor : public TensorBase {
public:
@@ -316,6 +351,134 @@ class Tensor<std::string_view> : public TensorBase {
std::vector<std::string_view> input_string_views_; // for input
};

#if ORT_API_VERSION >= 16

template <>
struct Tensor<MFloat16> : public TensorBase {
Tensor(const OrtW::CustomOpApi& api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : TensorBase(api,
ctx,
indice,
is_input) {
type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
if (is_input_) {
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
if (indice >= input_count) {
ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
}
const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
auto* info = api_.GetTensorTypeAndShape(const_value_);
shape_ = api_.GetTensorShape(info);
type_ = api_.GetTensorElementType(info);
api_.ReleaseTensorTypeAndShapeInfo(info);
const OrtMemoryInfo* mem_info = {};
api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
if (mem_info) {
api_.ThrowOnError(api_.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
}
}
}

// The CustomOpApi accessors are only instantiated for the raw uint16_t
// storage type, so Data() reinterprets that buffer as MFloat16 (same size
// and layout, just a typed wrapper over the bit pattern).
const MFloat16* Data() const {
return reinterpret_cast<const MFloat16*>(api_.GetTensorData<uint16_t>(const_value_));
}

MFloat16* Allocate(const std::vector<int64_t>& shape) {
if (!data_) {
OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
shape_ = shape;
data_ = reinterpret_cast<MFloat16*>(api_.GetTensorMutableData<uint16_t>(out));
}
return data_;
}

const Span<MFloat16>& AsSpan() {
ORTX_CXX_API_THROW("AsSpan for MFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
}

const MFloat16& AsScalar() {
ORTX_CXX_API_THROW("AsScalar for MFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
}

const void* DataRaw() const override {
return reinterpret_cast<const void*>(Data());
}

size_t SizeInBytes() const override {
return NumberOfElement() * sizeof(uint16_t);
}

private:
const OrtValue* const_value_{}; // for input
MFloat16* data_{}; // for output
};

template <>
struct Tensor<BFloat16> : public TensorBase {
Tensor(const OrtW::CustomOpApi& api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : TensorBase(api,
ctx,
indice,
is_input) {
type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
if (is_input_) {
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
if (indice >= input_count) {
ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
}
const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
auto* info = api_.GetTensorTypeAndShape(const_value_);
shape_ = api_.GetTensorShape(info);
type_ = api_.GetTensorElementType(info);
api_.ReleaseTensorTypeAndShapeInfo(info);
const OrtMemoryInfo* mem_info = {};
api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
if (mem_info) {
api_.ThrowOnError(api_.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
}
}
}

const BFloat16* Data() const {
return reinterpret_cast<const BFloat16*>(api_.GetTensorData<uint16_t>(const_value_));
}

BFloat16* Allocate(const std::vector<int64_t>& shape) {
if (!data_) {
OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
shape_ = shape;
data_ = reinterpret_cast<BFloat16*>(api_.GetTensorMutableData<uint16_t>(out));
}
return data_;
}

const Span<BFloat16>& AsSpan() {
ORTX_CXX_API_THROW("AsSpan for BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
}

const BFloat16& AsScalar() {
ORTX_CXX_API_THROW("AsScalar for BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
}

const void* DataRaw() const override {
return reinterpret_cast<const void*>(Data());
}

size_t SizeInBytes() const override {
return NumberOfElement() * sizeof(uint16_t);
}

private:
const OrtValue* const_value_{}; // for input
BFloat16* data_{}; // for output
};

#endif
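
Since this PR wires FastGelu up for float16, here is a minimal sketch of the kind of kernel these Tensor specializations enable. It is illustrative only: the function name is hypothetical, the tanh-approximation constants are the standard FastGelu ones, and it assumes TensorBase exposes Shape() alongside the NumberOfElement() used in SizeInBytes() above, and that MFloat16 converts to and from float (ToFloat() plus a float constructor, mirroring ONNX Runtime's MLFloat16).

#include <cmath>

// Hypothetical fp16 FastGelu kernel: compute in float, store back as MFloat16.
// FastGelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
void FastGeluHalfKernel(const Custom::Tensor<MFloat16>& input,
                        Custom::Tensor<MFloat16>& output) {
  const MFloat16* x = input.Data();
  MFloat16* y = output.Allocate(input.Shape());  // Shape() assumed on TensorBase
  for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
    float v = x[i].ToFloat();
    float g = 0.5f * v * (1.0f + std::tanh(0.7978845608f * (v + 0.044715f * v * v * v)));
    y[i] = MFloat16(g);  // float constructor assumed
  }
}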

using TensorPtr = std::unique_ptr<Custom::TensorBase>;
using TensorPtrs = std::vector<TensorPtr>;

@@ -438,6 +601,8 @@ struct CudaContext {

#endif

// using mf16_t = uint16_t;

struct OrtLiteCustomOp : public OrtCustomOp {
// CreateTuple
template <size_t ith_input, size_t ith_output, typename... Ts>
@@ -638,6 +803,10 @@ struct OrtLiteCustomOp : public OrtCustomOp {
CREATE_TUPLE_OUTPUT(data_type)

CREATE_TUPLE(bool)
#if ORT_API_VERSION >= 16
CREATE_TUPLE(MFloat16)
CREATE_TUPLE(BFloat16)
#endif
CREATE_TUPLE(float)
CREATE_TUPLE(double)
CREATE_TUPLE(int8_t)
@@ -759,6 +928,10 @@ struct OrtLiteCustomOp : public OrtCustomOp {
PARSE_OUTPUT(data_type, onnx_type)

PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
#if ORT_API_VERSION >= 16
PARSE_ARGS(MFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
PARSE_ARGS(BFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
#endif
PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
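
The matching PARSE_ARGS entries above feed schema inference: when the framework walks a kernel's parameter types to report input and output element types to ONNX Runtime, the two new mappings apply. Illustratively (this is the effect, not the literal macro expansion):

//   const Custom::Tensor<MFloat16>&  ->  ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
//   Custom::Tensor<BFloat16>&        ->  ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16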