Skip to content

Commit

Permalink
Sd3 text encoder refactoring (openvinotoolkit#1313)
Browse files Browse the repository at this point in the history
- Added a unified `concat` function which accepts two tensors and the axis to
concatenate along
- Optimized `padding_right` function to avoid padding second tensor
before calling this function
- Optimized split by batch to avoid memory copy
  • Loading branch information
ilya-lavrenov authored Dec 5, 2024
1 parent 3a1bd24 commit b26fc8b
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 205 deletions.
2 changes: 1 addition & 1 deletion src/cpp/src/image_generation/flux_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ class FluxPipeline : public DiffusionPipeline {
void check_inputs(const ImageGenerationConfig& generation_config, ov::Tensor initial_image) const override {
check_image_size(generation_config.width, generation_config.height);

OPENVINO_ASSERT(generation_config.max_sequence_length <= 512, "T5's 'max_sequence_length' must be less than 512");
OPENVINO_ASSERT(generation_config.max_sequence_length <= 512, "T5's 'max_sequence_length' must be less or equal to 512");

OPENVINO_ASSERT(generation_config.negative_prompt == std::nullopt, "Negative prompt is not used by FluxPipeline");
OPENVINO_ASSERT(generation_config.negative_prompt_2 == std::nullopt, "Negative prompt 2 is not used by FluxPipeline");
Expand Down
51 changes: 46 additions & 5 deletions src/cpp/src/image_generation/numpy_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ std::vector<float> interp(const std::vector<std::int64_t>& x, const std::vector<
return interp_res;
}

void concat_3d_by_rows(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2) {
namespace {

void concat_3d_axis_2(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2) {
OPENVINO_ASSERT(shape_1.size() == 3 && shape_2.size() == 3, "Shape dimensions must be 3");
OPENVINO_ASSERT(shape_1[0] == shape_2[0] && shape_1[1] == shape_2[1], "Tensors for concatenation must have the same dimensions");

Expand All @@ -91,7 +93,7 @@ void concat_3d_by_rows(const float* data_1, const float* data_2, float* res, con
}
}

void concat_2d_by_rows(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2) {
void concat_2d_axis_1(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2) {
OPENVINO_ASSERT(shape_1.size() == 2 && shape_2.size() == 2, "Shape dimensions must be 2");
OPENVINO_ASSERT(shape_1[0] == shape_2[0], "Tensors for concatenation must have the same dimensions");

Expand All @@ -108,7 +110,7 @@ void concat_2d_by_rows(const float* data_1, const float* data_2, float* res, con
}
}

void concat_3d_by_cols(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2) {
void concat_3d_axis_1(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2) {
OPENVINO_ASSERT(shape_1.size() == 3 && shape_2.size() == 3, "Shape dimensions must be 3");
OPENVINO_ASSERT(shape_1[0] == shape_2[0] && shape_1[2] == shape_2[2], "Tensors for concatenation must have the same dimensions");

Expand All @@ -123,7 +125,7 @@ void concat_3d_by_cols(const float* data_1, const float* data_2, float* res, con
}
}

void concat_3d_by_channels(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2) {
void concat_3d_axis_0(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2) {
OPENVINO_ASSERT(shape_1.size() == 3 && shape_2.size() == 3, "Shape dimensions must be 3");
OPENVINO_ASSERT(shape_1[1] == shape_2[1] && shape_1[2] == shape_2[2], "Tensors for concatenation must have the same dimensions");

Expand All @@ -134,7 +136,7 @@ void concat_3d_by_channels(const float* data_1, const float* data_2, float* res,
std::memcpy(res + size_1, data_2, size_2 * sizeof(float));
}

void concat_2d_by_channels(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2) {
void concat_2d_axis_0(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2) {
OPENVINO_ASSERT(shape_1.size() == 2 && shape_2.size() == 2, "Shape dimensions must be 2");
OPENVINO_ASSERT(shape_1[1] == shape_2[1], "Tensors for concatenation must have the same dimensions");

Expand All @@ -145,6 +147,45 @@ void concat_2d_by_channels(const float* data_1, const float* data_2, float* res,
std::memcpy(res + size_1, data_2, size_2 * sizeof(float));
}

} // namespace

// Concatenates two fp32 tensors of equal rank along `axis` and returns the result
// in a newly created tensor.
//
// @param tensor_1 first input; its data is placed first along `axis`
// @param tensor_2 second input; must have the same rank as `tensor_1` and the same
//                 extent in every dimension except `axis`
// @param axis     dimension to concatenate along; negative values count from the
//                 last dimension (valid range is [-rank, rank - 1])
// @return tensor whose shape equals the input shapes except along `axis`, where the
//         extent is the sum of both inputs' extents
//
// Raises via OPENVINO_ASSERT when the rank is 0 or greater than 3, the ranks differ,
// an element type is not f32, `axis` is out of range, a non-concatenated dimension
// mismatches, or no kernel exists for the (rank, axis) pair (all rank 1 cases).
ov::Tensor concat(ov::Tensor tensor_1, ov::Tensor tensor_2, int axis) {
    const ov::Shape shape_1 = tensor_1.get_shape(), shape_2 = tensor_2.get_shape();
    const size_t rank = shape_1.size();

    constexpr size_t MAX_RANK = 3;
    OPENVINO_ASSERT(rank <= MAX_RANK, "Maximum support rank of concatenated tensors is ", MAX_RANK, ", given rank is ", rank);
    // rank 0 would make the `rank - 1` table index below wrap around
    OPENVINO_ASSERT(rank >= 1, "Concatenated tensors must have rank of at least 1, given rank is ", rank);

    OPENVINO_ASSERT(rank == shape_2.size(), "Shapes for concatenated tensors must have the same rank");
    OPENVINO_ASSERT(tensor_1.get_element_type() == ov::element::f32 && tensor_2.get_element_type() == ov::element::f32,
                    "Concat supports only tensor of fp32 data type");

    // Validate the axis before it is used as an array index; without this check an
    // out-of-range axis would read past the bounds of the dispatch table.
    const int signed_rank = static_cast<int>(rank);
    OPENVINO_ASSERT(axis >= -signed_rank && axis < signed_rank,
                    "Axis ", axis, " is out of range [", -signed_rank, ", ", signed_rank - 1, "] for tensors of rank ", rank);

    // Negative axis counts from the end; keep the normalized value unsigned so that
    // comparisons with dimension indices do not mix signedness.
    const size_t norm_axis = static_cast<size_t>(axis < 0 ? axis + signed_rank : axis);

    ov::Shape dst_shape(rank);
    for (size_t d = 0; d < rank; ++d) {
        OPENVINO_ASSERT(d == norm_axis || shape_1[d] == shape_2[d], "Dimension ", d, " must be the same for tensor_1 (", shape_1[d], ") and tensor_2 (", shape_2[d], ")");
        dst_shape[d] = d == norm_axis ? shape_1[d] + shape_2[d] : shape_1[d];
    }

    // Dispatch table indexed as [rank - 1][axis]; nullptr marks unsupported combinations.
    using concat_func_type = void (*)(const float*, const float*, float*, const ov::Shape, const ov::Shape);
    const concat_func_type concat_funcs[MAX_RANK][MAX_RANK] = {
        { nullptr, nullptr, nullptr },
        { concat_2d_axis_0, concat_2d_axis_1, nullptr },
        { concat_3d_axis_0, concat_3d_axis_1, concat_3d_axis_2 }
    };

    const concat_func_type concat_func = concat_funcs[rank - 1][norm_axis];
    OPENVINO_ASSERT(concat_func != nullptr, "Unsupported combination of input tensors rank ", rank, " and axis ", norm_axis);

    ov::Tensor dst_tensor(tensor_1.get_element_type(), dst_shape);
    concat_func(tensor_1.data<const float>(), tensor_2.data<const float>(), dst_tensor.data<float>(), shape_1, shape_2);

    return dst_tensor;
}

void batch_copy(ov::Tensor src, ov::Tensor dst, size_t src_batch, size_t dst_batch, size_t batch_size) {
const ov::Shape src_shape = src.get_shape(), dst_shape = dst.get_shape();
ov::Coordinate src_start(src_shape.size(), 0), src_end = src_shape;
Expand Down
7 changes: 2 additions & 5 deletions src/cpp/src/image_generation/numpy_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,8 @@ void rescale_zero_terminal_snr(std::vector<float>& betas);
// np.interp(...) implementation
std::vector<float> interp(const std::vector<std::int64_t>& x, const std::vector<size_t>& xp, const std::vector<float>& fp);

void concat_3d_by_rows(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2);
void concat_3d_by_cols(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2);
void concat_3d_by_channels(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2);
void concat_2d_by_rows(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2);
void concat_2d_by_channels(const float* data_1, const float* data_2, float* res, const ov::Shape shape_1, const ov::Shape shape_2);
// concatenates two tensors along a given dimension (axis)
ov::Tensor concat(ov::Tensor tensor_1, ov::Tensor tensor_2, int axis);

void batch_copy(ov::Tensor src, ov::Tensor dst, size_t src_batch, size_t dst_batch, size_t batch_size = 1);
ov::Tensor repeat(const ov::Tensor input, const size_t num_images_per_prompt);
Expand Down
Loading

0 comments on commit b26fc8b

Please sign in to comment.