Skip to content

Commit

Permalink
Fix Arrow convertor to honor dictionary encoding inside complex types. (
Browse files Browse the repository at this point in the history
#7171)

Summary:
Pull Request resolved: #7171

Arrow converter currently doesnt check for dictionary encoding etc when inside a complex type.

Reviewed By: pedroerp

Differential Revision: D50515953

fbshipit-source-id: de13583297cc3471b18b8a7d73f9a7cf1fd7aa11
  • Loading branch information
Krishna Pai authored and facebook-github-bot committed Nov 18, 2023
1 parent 460c4f1 commit a86cd7b
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 118 deletions.
234 changes: 127 additions & 107 deletions velox/vector/arrow/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,116 @@ void exportToArrowImpl(
out.release = releaseArrowArray;
}

TypePtr importFromArrowImpl(
const char* format,
const ArrowSchema& arrowSchema) {
VELOX_CHECK_NOT_NULL(format);

switch (format[0]) {
case 'b':
return BOOLEAN();
case 'c':
return TINYINT();
case 's':
return SMALLINT();
case 'i':
return INTEGER();
case 'l':
return BIGINT();
case 'f':
return REAL();
case 'g':
return DOUBLE();

// Map both utf-8 and large utf-8 string to varchar.
case 'u':
case 'U':
return VARCHAR();

// Same for binary.
case 'z':
case 'Z':
return VARBINARY();

case 't': // temporal types.
// Mapping it to ttn for now.
if (format[1] == 't' && format[2] == 'n') {
return TIMESTAMP();
}
if (format[1] == 'd' && format[2] == 'D') {
return DATE();
}
break;

case 'd': { // decimal types.
try {
std::string::size_type sz;
// Parse "d:".
int precision = std::stoi(&format[2], &sz);
// Parse ",".
int scale = std::stoi(&format[2 + sz + 1], &sz);
return DECIMAL(precision, scale);
} catch (std::invalid_argument& err) {
VELOX_USER_FAIL(
"Unable to convert '{}' ArrowSchema decimal format to Velox decimal",
format);
}
}

// Complex types.
case '+': {
switch (format[1]) {
// Array/list.
case 'l':
VELOX_CHECK_EQ(arrowSchema.n_children, 1);
VELOX_CHECK_NOT_NULL(arrowSchema.children[0]);
return ARRAY(importFromArrow(*arrowSchema.children[0]));

// Map.
case 'm': {
VELOX_CHECK_EQ(arrowSchema.n_children, 1);
VELOX_CHECK_NOT_NULL(arrowSchema.children[0]);
auto& child = *arrowSchema.children[0];
VELOX_CHECK_EQ(strcmp(child.format, "+s"), 0);
VELOX_CHECK_EQ(child.n_children, 2);
VELOX_CHECK_NOT_NULL(child.children[0]);
VELOX_CHECK_NOT_NULL(child.children[1]);
return MAP(
importFromArrow(*child.children[0]),
importFromArrow(*child.children[1]));
}

// Struct/rows.
case 's': {
// Loop collecting the child types and names.
std::vector<TypePtr> childTypes;
std::vector<std::string> childNames;
childTypes.reserve(arrowSchema.n_children);
childNames.reserve(arrowSchema.n_children);

for (size_t i = 0; i < arrowSchema.n_children; ++i) {
VELOX_CHECK_NOT_NULL(arrowSchema.children[i]);
childTypes.emplace_back(importFromArrow(*arrowSchema.children[i]));
childNames.emplace_back(
arrowSchema.children[i]->name != nullptr
? arrowSchema.children[i]->name
: "");
}
return ROW(std::move(childNames), std::move(childTypes));
}

default:
break;
}
} break;

default:
break;
}
VELOX_USER_FAIL(
"Unable to convert '{}' ArrowSchema format type to Velox.", format);
}

} // namespace

void exportToArrow(
Expand Down Expand Up @@ -1006,112 +1116,15 @@ void exportToArrow(const VectorPtr& vec, ArrowSchema& arrowSchema) {
}

TypePtr importFromArrow(const ArrowSchema& arrowSchema) {
const char* format = arrowSchema.format;
VELOX_CHECK_NOT_NULL(format);

switch (format[0]) {
case 'b':
return BOOLEAN();
case 'c':
return TINYINT();
case 's':
return SMALLINT();
case 'i':
return INTEGER();
case 'l':
return BIGINT();
case 'f':
return REAL();
case 'g':
return DOUBLE();

// Map both utf-8 and large utf-8 string to varchar.
case 'u':
case 'U':
return VARCHAR();

// Same for binary.
case 'z':
case 'Z':
return VARBINARY();

case 't': // temporal types.
// Mapping it to ttn for now.
if (format[1] == 't' && format[2] == 'n') {
return TIMESTAMP();
}
if (format[1] == 'd' && format[2] == 'D') {
return DATE();
}
break;

case 'd': { // decimal types.
try {
std::string::size_type sz;
// Parse "d:".
int precision = std::stoi(&format[2], &sz);
// Parse ",".
int scale = std::stoi(&format[2 + sz + 1], &sz);
return DECIMAL(precision, scale);
} catch (std::invalid_argument& err) {
VELOX_USER_FAIL(
"Unable to convert '{}' ArrowSchema decimal format to Velox decimal",
format);
}
}

// Complex types.
case '+': {
switch (format[1]) {
// Array/list.
case 'l':
VELOX_CHECK_EQ(arrowSchema.n_children, 1);
VELOX_CHECK_NOT_NULL(arrowSchema.children[0]);
return ARRAY(importFromArrow(*arrowSchema.children[0]));

// Map.
case 'm': {
VELOX_CHECK_EQ(arrowSchema.n_children, 1);
VELOX_CHECK_NOT_NULL(arrowSchema.children[0]);
auto& child = *arrowSchema.children[0];
VELOX_CHECK_EQ(strcmp(child.format, "+s"), 0);
VELOX_CHECK_EQ(child.n_children, 2);
VELOX_CHECK_NOT_NULL(child.children[0]);
VELOX_CHECK_NOT_NULL(child.children[1]);
return MAP(
importFromArrow(*child.children[0]),
importFromArrow(*child.children[1]));
}

// Struct/rows.
case 's': {
// Loop collecting the child types and names.
std::vector<TypePtr> childTypes;
std::vector<std::string> childNames;
childTypes.reserve(arrowSchema.n_children);
childNames.reserve(arrowSchema.n_children);

for (size_t i = 0; i < arrowSchema.n_children; ++i) {
VELOX_CHECK_NOT_NULL(arrowSchema.children[i]);
childTypes.emplace_back(importFromArrow(*arrowSchema.children[i]));
childNames.emplace_back(
arrowSchema.children[i]->name != nullptr
? arrowSchema.children[i]->name
: "");
}
return ROW(std::move(childNames), std::move(childTypes));
}

default:
break;
}
} break;

default:
break;
}
VELOX_USER_FAIL(
"Unable to convert '{}' ArrowSchema format type to Velox.", format);
// For dictionaries, format encodes the index type, while the dictionary value
// is encoded in the dictionary member, as per
// https://arrow.apache.org/docs/format/CDataInterface.html#dictionary-encoded-arrays.

const char* format = arrowSchema.dictionary ? arrowSchema.dictionary->format
: arrowSchema.format;
ArrowSchema schema =
arrowSchema.dictionary ? *arrowSchema.dictionary : arrowSchema;
return importFromArrowImpl(format, schema);
}

namespace {
Expand Down Expand Up @@ -1314,8 +1327,15 @@ VectorPtr importFromArrowImpl(
}

if (arrowSchema.dictionary) {
auto indexType = importFromArrowImpl(arrowSchema.format, arrowSchema);
return createDictionaryVector(
pool, type, nulls, arrowSchema, arrowArray, isViewer, wrapInBufferView);
pool,
indexType,
nulls,
arrowSchema,
arrowArray,
isViewer,
wrapInBufferView);
}

// String data types (VARCHAR and VARBINARY).
Expand Down
23 changes: 23 additions & 0 deletions velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "velox/core/QueryCtx.h"
#include "velox/vector/arrow/Bridge.h"
#include "velox/vector/tests/utils/VectorMaker.h"
#include "velox/vector/tests/utils/VectorTestBase.h"

namespace facebook::velox::test {
namespace {
Expand Down Expand Up @@ -663,6 +664,28 @@ TEST_F(ArrowBridgeArrayExportTest, arrayCrossValidate) {
}
}

TEST_F(ArrowBridgeArrayExportTest, arrayDictionary) {
auto vec = ({
auto indices = makeBuffer<vector_size_t>({1, 2, 0});
auto wrapped = vectorMaker_.flatVector<int64_t>({1, 2, 3});
auto inner = BaseVector::wrapInDictionary(nullptr, indices, 3, wrapped);
auto offsets = makeBuffer<vector_size_t>({2, 0});
auto sizes = makeBuffer<vector_size_t>({1, 1});
std::make_shared<ArrayVector>(
pool_.get(), ARRAY(inner->type()), nullptr, 2, offsets, sizes, inner);
});

ArrowSchema schema;
ArrowArray data;
velox::exportToArrow(vec, schema);
velox::exportToArrow(vec, data, vec->pool());

auto result = importFromArrowAsViewer(schema, data, vec->pool());
test::assertEqualVectors(result, vec);
schema.release(&schema);
data.release(&data);
}

TEST_F(ArrowBridgeArrayExportTest, arrayGap) {
auto elements = vectorMaker_.flatVector<int64_t>({1, 2, 3, 4, 5});
elements->setNull(3, true);
Expand Down
Loading

0 comments on commit a86cd7b

Please sign in to comment.