Skip to content

Commit

Permalink
Remove new op
Browse files Browse the repository at this point in the history
  • Loading branch information
duanqn committed Jun 12, 2024
1 parent b6bc59d commit db5fcb2
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 280 deletions.
166 changes: 0 additions & 166 deletions onnxruntime/contrib_ops/cpu/bert/flash_attention.cc

This file was deleted.

22 changes: 0 additions & 22 deletions onnxruntime/contrib_ops/cpu/bert/flash_attention.h

This file was deleted.

41 changes: 30 additions & 11 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FlashAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
Expand Down Expand Up @@ -65,10 +64,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid);
Expand Down Expand Up @@ -201,15 +204,32 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConcat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearWhere)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t,
DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t,
DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t,
DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t,
DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t,
DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2,
DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2,
DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t,
QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t,
QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t,
QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t,
QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2,
QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2,
QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid)>,
Expand Down Expand Up @@ -259,7 +279,6 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FlashAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
Expand Down
83 changes: 4 additions & 79 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1043,85 +1043,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
MultiHeadAttentionTypeAndShapeInference(ctx, 6);
}));

ONNX_MS_OPERATOR_SET_SCHEMA(
FlashAttention, 1,
OpSchema()
.SetDoc("")
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
.Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT, OPTIONAL_VALUE)
.Attr("scale",
"Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr("unidirectional",
"Whether every token can only attend to previous tokens. Default value is 0.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)",
"T")
.Input(1,
"key",
"Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size), "
"or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Input(2,
"value",
"Value with shape (batch_size, kv_sequence_length, v_hidden_size), or past_value with shape (batch_size, num_heads, kv_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Input(3,
"bias",
"Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection",
"T",
OpSchema::Optional)
.Input(4,
"key_padding_mask",
"Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), "
"or (batch_size, sequence_length, total_sequence_length)",
"M",
OpSchema::Optional)
.Input(5,
"relative_position_bias",
"relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)"
" or (1, num_heads, sequence_length, total_sequence_length)",
"T",
OpSchema::Optional)
.Input(6,
"past_key",
"past state for self attention key with shape (batch_size, num_heads, past_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Input(7,
"past_value",
"past state for self attention value with shape (batch_size, num_heads, past_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, v_hidden_size)",
"T")
.Output(1,
"present_key",
"present state for cross attention key with shape (batch_size, num_heads, kv_sequence_length, head_size)"
"or present state for self attention key with shape (batch_size, num_heads, total_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Output(2,
"present_value",
"present state for cross attention value with shape (batch_size, num_heads, kv_sequence_length, head_size)"
"or present state for self attention value with shape (batch_size, num_heads, total_sequence_length, head_size)",
"T",
OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
MultiHeadAttentionTypeAndShapeInference(ctx, 6);
}));

constexpr const char* GroupQueryAttention_ver1_doc = R"DOC(
Group Query Self/Cross Attention.
Expand Down Expand Up @@ -1427,6 +1348,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Number of attention heads. Default value is 0. Must use with rotary_embedding_dim",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("is_packed_batching",
"ragged batch inputs or not. Default value is 0",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0,
"input",
"3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)",
Expand Down
Loading

0 comments on commit db5fcb2

Please sign in to comment.