From b97c52701819e1eea1da2699917b5d3620be1367 Mon Sep 17 00:00:00 2001 From: Krishna Pai Date: Thu, 16 Nov 2023 12:54:00 -0800 Subject: [PATCH] Fix Arrow convertor to honor dictionary encoding inside complex types. (#7171) Summary: Arrow converter currently doesnt check for dictionary encoding etc when inside a complex type. Reviewed By: pedroerp Differential Revision: D50515953 --- velox/vector/arrow/Bridge.cpp | 234 ++++++++++-------- .../arrow/tests/ArrowBridgeArrayTest.cpp | 21 ++ .../arrow/tests/ArrowBridgeSchemaTest.cpp | 152 +++++++++++- 3 files changed, 289 insertions(+), 118 deletions(-) diff --git a/velox/vector/arrow/Bridge.cpp b/velox/vector/arrow/Bridge.cpp index c5ac55948c1c2..d84d1a875c1e3 100644 --- a/velox/vector/arrow/Bridge.cpp +++ b/velox/vector/arrow/Bridge.cpp @@ -871,6 +871,116 @@ void exportToArrowImpl( out.release = releaseArrowArray; } +TypePtr importFromArrowImpl( + const char* format, + const ArrowSchema& arrowSchema) { + VELOX_CHECK_NOT_NULL(format); + + switch (format[0]) { + case 'b': + return BOOLEAN(); + case 'c': + return TINYINT(); + case 's': + return SMALLINT(); + case 'i': + return INTEGER(); + case 'l': + return BIGINT(); + case 'f': + return REAL(); + case 'g': + return DOUBLE(); + + // Map both utf-8 and large utf-8 string to varchar. + case 'u': + case 'U': + return VARCHAR(); + + // Same for binary. + case 'z': + case 'Z': + return VARBINARY(); + + case 't': // temporal types. + // Mapping it to ttn for now. + if (format[1] == 't' && format[2] == 'n') { + return TIMESTAMP(); + } + if (format[1] == 'd' && format[2] == 'D') { + return DATE(); + } + break; + + case 'd': { // decimal types. + try { + std::string::size_type sz; + // Parse "d:". + int precision = std::stoi(&format[2], &sz); + // Parse ",". + int scale = std::stoi(&format[2 + sz + 1], &sz); + return DECIMAL(precision, scale); + } catch (std::invalid_argument& err) { + VELOX_USER_FAIL( + "Unable to convert '{}' ArrowSchema decimal format to Velox decimal", + format); + } + } + + // Complex types. + case '+': { + switch (format[1]) { + // Array/list. + case 'l': + VELOX_CHECK_EQ(arrowSchema.n_children, 1); + VELOX_CHECK_NOT_NULL(arrowSchema.children[0]); + return ARRAY(importFromArrow(*arrowSchema.children[0])); + + // Map. + case 'm': { + VELOX_CHECK_EQ(arrowSchema.n_children, 1); + VELOX_CHECK_NOT_NULL(arrowSchema.children[0]); + auto& child = *arrowSchema.children[0]; + VELOX_CHECK_EQ(strcmp(child.format, "+s"), 0); + VELOX_CHECK_EQ(child.n_children, 2); + VELOX_CHECK_NOT_NULL(child.children[0]); + VELOX_CHECK_NOT_NULL(child.children[1]); + return MAP( + importFromArrow(*child.children[0]), + importFromArrow(*child.children[1])); + } + + // Struct/rows. + case 's': { + // Loop collecting the child types and names. + std::vector childTypes; + std::vector childNames; + childTypes.reserve(arrowSchema.n_children); + childNames.reserve(arrowSchema.n_children); + + for (size_t i = 0; i < arrowSchema.n_children; ++i) { + VELOX_CHECK_NOT_NULL(arrowSchema.children[i]); + childTypes.emplace_back(importFromArrow(*arrowSchema.children[i])); + childNames.emplace_back( + arrowSchema.children[i]->name != nullptr + ? arrowSchema.children[i]->name + : ""); + } + return ROW(std::move(childNames), std::move(childTypes)); + } + + default: + break; + } + } break; + + default: + break; + } + VELOX_USER_FAIL( + "Unable to convert '{}' ArrowSchema format type to Velox.", format); +} + } // namespace void exportToArrow( @@ -1006,112 +1116,15 @@ void exportToArrow(const VectorPtr& vec, ArrowSchema& arrowSchema) { } TypePtr importFromArrow(const ArrowSchema& arrowSchema) { - const char* format = arrowSchema.format; - VELOX_CHECK_NOT_NULL(format); - - switch (format[0]) { - case 'b': - return BOOLEAN(); - case 'c': - return TINYINT(); - case 's': - return SMALLINT(); - case 'i': - return INTEGER(); - case 'l': - return BIGINT(); - case 'f': - return REAL(); - case 'g': - return DOUBLE(); - - // Map both utf-8 and large utf-8 string to varchar. - case 'u': - case 'U': - return VARCHAR(); - - // Same for binary. - case 'z': - case 'Z': - return VARBINARY(); - - case 't': // temporal types. - // Mapping it to ttn for now. - if (format[1] == 't' && format[2] == 'n') { - return TIMESTAMP(); - } - if (format[1] == 'd' && format[2] == 'D') { - return DATE(); - } - break; - - case 'd': { // decimal types. - try { - std::string::size_type sz; - // Parse "d:". - int precision = std::stoi(&format[2], &sz); - // Parse ",". - int scale = std::stoi(&format[2 + sz + 1], &sz); - return DECIMAL(precision, scale); - } catch (std::invalid_argument& err) { - VELOX_USER_FAIL( - "Unable to convert '{}' ArrowSchema decimal format to Velox decimal", - format); - } - } - - // Complex types. - case '+': { - switch (format[1]) { - // Array/list. - case 'l': - VELOX_CHECK_EQ(arrowSchema.n_children, 1); - VELOX_CHECK_NOT_NULL(arrowSchema.children[0]); - return ARRAY(importFromArrow(*arrowSchema.children[0])); - - // Map. - case 'm': { - VELOX_CHECK_EQ(arrowSchema.n_children, 1); - VELOX_CHECK_NOT_NULL(arrowSchema.children[0]); - auto& child = *arrowSchema.children[0]; - VELOX_CHECK_EQ(strcmp(child.format, "+s"), 0); - VELOX_CHECK_EQ(child.n_children, 2); - VELOX_CHECK_NOT_NULL(child.children[0]); - VELOX_CHECK_NOT_NULL(child.children[1]); - return MAP( - importFromArrow(*child.children[0]), - importFromArrow(*child.children[1])); - } - - // Struct/rows. - case 's': { - // Loop collecting the child types and names. - std::vector childTypes; - std::vector childNames; - childTypes.reserve(arrowSchema.n_children); - childNames.reserve(arrowSchema.n_children); - - for (size_t i = 0; i < arrowSchema.n_children; ++i) { - VELOX_CHECK_NOT_NULL(arrowSchema.children[i]); - childTypes.emplace_back(importFromArrow(*arrowSchema.children[i])); - childNames.emplace_back( - arrowSchema.children[i]->name != nullptr - ? arrowSchema.children[i]->name - : ""); - } - return ROW(std::move(childNames), std::move(childTypes)); - } - - default: - break; - } - } break; - - default: - break; - } - VELOX_USER_FAIL( - "Unable to convert '{}' ArrowSchema format type to Velox.", format); + // As per + // https://arrow.apache.org/docs/format/CDataInterface.html#dictionary-encoded-arrays + // format encodes the index type, and the value type is encoded in the + // dictionary. + const char* format = arrowSchema.dictionary ? arrowSchema.dictionary->format + : arrowSchema.format; + ArrowSchema schema = + arrowSchema.dictionary ? *arrowSchema.dictionary : arrowSchema; + return importFromArrowImpl(format, schema); } namespace { @@ -1314,8 +1327,15 @@ VectorPtr importFromArrowImpl( } if (arrowSchema.dictionary) { + auto indexType = importFromArrowImpl(arrowSchema.format, arrowSchema); return createDictionaryVector( - pool, type, nulls, arrowSchema, arrowArray, isViewer, wrapInBufferView); + pool, + indexType, + nulls, + arrowSchema, + arrowArray, + isViewer, + wrapInBufferView); } // String data types (VARCHAR and VARBINARY). diff --git a/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp b/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp index 68bd69bae12e6..a1eb7a3a0dbd2 100644 --- a/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp +++ b/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp @@ -24,6 +24,7 @@ #include "velox/core/QueryCtx.h" #include "velox/vector/arrow/Bridge.h" #include "velox/vector/tests/utils/VectorMaker.h" +#include "velox/vector/tests/utils/VectorTestBase.h" namespace facebook::velox::test { namespace { @@ -663,6 +664,26 @@ TEST_F(ArrowBridgeArrayExportTest, arrayCrossValidate) { } } +TEST_F(ArrowBridgeArrayExportTest, arrayDictionary) { + auto vec = ({ + auto indices = makeBuffer({1, 2, 0}); + auto wrapped = vectorMaker_.flatVector({1, 2, 3}); + auto inner = BaseVector::wrapInDictionary(nullptr, indices, 3, wrapped); + auto offsets = makeBuffer({2, 0}); + auto sizes = makeBuffer({1, 1}); + std::make_shared( + pool_.get(), ARRAY(inner->type()), nullptr, 2, offsets, sizes, inner); + }); + + ArrowSchema schema; + ArrowArray data; + velox::exportToArrow(vec, schema); + velox::exportToArrow(vec, data, vec->pool()); + + auto result = importFromArrowAsViewer(schema, data, vec->pool()); + test::assertEqualVectors(result, vec); +} + TEST_F(ArrowBridgeArrayExportTest, arrayGap) { auto elements = vectorMaker_.flatVector({1, 2, 3, 4, 5}); elements->setNull(3, true); diff --git a/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp b/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp index 2a93c8f825ed7..7277bd2e060db 100644 --- a/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp +++ b/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp @@ -257,13 +257,27 @@ class ArrowBridgeSchemaImportTest : public ArrowBridgeSchemaExportTest { return type; } - TypePtr testSchemaImportComplex( + TypePtr testSchemaDictionaryImport(const char* indexFmt, ArrowSchema schema) { + auto dictionarySchema = makeArrowSchema(indexFmt); + dictionarySchema.dictionary = &schema; + + auto type = importFromArrow(dictionarySchema); + dictionarySchema.release(&dictionarySchema); + return type; + } + + ArrowSchema makeComplexArrowSchema( + std::vector& schemas, + std::vector& schemaPtrs, + std::vector& mapSchemas, + std::vector& mapSchemaPtrs, const char* mainFormat, const std::vector& childrenFormat, const std::vector& colNames = {}) { - std::vector schemas; - std::vector schemaPtrs; - + schemas.clear(); + schemaPtrs.clear(); + mapSchemas.clear(); + mapSchemaPtrs.clear(); schemas.resize(childrenFormat.size()); schemaPtrs.resize(childrenFormat.size()); @@ -278,18 +292,40 @@ class ArrowBridgeSchemaImportTest : public ArrowBridgeSchemaExportTest { auto mainSchema = makeArrowSchema(mainFormat); if (strcmp(mainFormat, "+m") == 0) { // Arrow wraps key and value in a struct. - auto child = makeArrowSchema("+s"); - auto children = &child; - child.n_children = schemaPtrs.size(); - child.children = schemaPtrs.data(); + mapSchemas.resize(1); + mapSchemaPtrs.resize(1); + mapSchemas[0] = makeArrowSchema("+s"); + auto* child = &mapSchemas[0]; + mapSchemaPtrs[0] = &mapSchemas[0]; + child->n_children = schemaPtrs.size(); + child->children = schemaPtrs.data(); mainSchema.n_children = 1; - mainSchema.children = &children; - return importFromArrow(mainSchema); + mainSchema.children = mapSchemaPtrs.data(); } else { mainSchema.n_children = (int64_t)schemaPtrs.size(); mainSchema.children = schemaPtrs.data(); - return importFromArrow(mainSchema); } + + return mainSchema; + } + + TypePtr testSchemaImportComplex( + const char* mainFormat, + const std::vector& childrenFormat, + const std::vector& colNames = {}) { + std::vector schemas; + std::vector mapSchemas; + std::vector schemaPtrs; + std::vector mapSchemaPtrs; + auto type = importFromArrow(makeComplexArrowSchema( + schemas, + schemaPtrs, + mapSchemas, + mapSchemaPtrs, + mainFormat, + childrenFormat, + colNames)); + return type; } }; @@ -433,5 +469,99 @@ TEST_F(ArrowBridgeSchemaTest, validateInArrow) { } } +TEST_F(ArrowBridgeSchemaImportTest, dictionaryTypeTest) { + // Primitive types + EXPECT_EQ(DOUBLE(), testSchemaDictionaryImport("i", makeArrowSchema("g"))); + EXPECT_EQ(BOOLEAN(), testSchemaDictionaryImport("i", makeArrowSchema("b"))); + EXPECT_EQ(TINYINT(), testSchemaDictionaryImport("i", makeArrowSchema("c"))); + EXPECT_EQ(INTEGER(), testSchemaDictionaryImport("i", makeArrowSchema("i"))); + EXPECT_EQ(SMALLINT(), testSchemaDictionaryImport("i", makeArrowSchema("s"))); + EXPECT_EQ(BIGINT(), testSchemaDictionaryImport("i", makeArrowSchema("l"))); + EXPECT_EQ(REAL(), testSchemaDictionaryImport("i", makeArrowSchema("f"))); + EXPECT_EQ(VARCHAR(), testSchemaDictionaryImport("i", makeArrowSchema("u"))); + + std::vector schemas; + std::vector mapSchemas; + std::vector mapSchemaPtrs; + std::vector schemaPtrs; + + // Arrays + EXPECT_EQ( + *ARRAY(BIGINT()), + *testSchemaDictionaryImport( + "i", + makeComplexArrowSchema( + schemas, schemaPtrs, mapSchemas, mapSchemaPtrs, "+l", {"l"}))); + EXPECT_EQ( + *ARRAY(TIMESTAMP()), + *testSchemaDictionaryImport( + "i", + makeComplexArrowSchema( + schemas, schemaPtrs, mapSchemas, mapSchemaPtrs, "+l", {"ttn"}))); + EXPECT_EQ( + *ARRAY(DATE()), + *testSchemaDictionaryImport( + "i", + makeComplexArrowSchema( + schemas, schemaPtrs, mapSchemas, mapSchemaPtrs, "+l", {"tdD"}))); + EXPECT_EQ( + *ARRAY(VARCHAR()), + *testSchemaDictionaryImport( + "i", + makeComplexArrowSchema( + schemas, schemaPtrs, mapSchemas, mapSchemaPtrs, "+l", {"U"}))); + + // Maps + EXPECT_EQ( + *MAP(VARCHAR(), BOOLEAN()), + *testSchemaDictionaryImport( + "i", + makeComplexArrowSchema( + schemas, + schemaPtrs, + mapSchemas, + mapSchemaPtrs, + "+m", + {"U", "b"}))); + EXPECT_EQ( + *MAP(SMALLINT(), REAL()), + *testSchemaDictionaryImport( + "i", + makeComplexArrowSchema( + schemas, + schemaPtrs, + mapSchemas, + mapSchemaPtrs, + "+m", + {"s", "f"}))); + + // Rows + EXPECT_EQ( + *ROW({SMALLINT(), REAL()}), + *testSchemaDictionaryImport( + "i", + makeComplexArrowSchema( + schemas, + schemaPtrs, + mapSchemas, + mapSchemaPtrs, + "+s", + {"s", "f"}))); + + // Named Row + EXPECT_EQ( + *ROW({"col1", "col2"}, {SMALLINT(), REAL()}), + *testSchemaDictionaryImport( + "i", + makeComplexArrowSchema( + schemas, + schemaPtrs, + mapSchemas, + mapSchemaPtrs, + "+s", + {"s", "f"}, + {"col1", "col2"}))); +} + } // namespace } // namespace facebook::velox::test