diff --git a/velox/functions/sparksql/fuzzer/tests/SparkQueryRunnerTest.cpp b/velox/functions/sparksql/fuzzer/tests/SparkQueryRunnerTest.cpp index 974c9c3e69ac..ad51305d7c52 100644 --- a/velox/functions/sparksql/fuzzer/tests/SparkQueryRunnerTest.cpp +++ b/velox/functions/sparksql/fuzzer/tests/SparkQueryRunnerTest.cpp @@ -18,6 +18,7 @@ #include #include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/parquet/RegisterParquetWriter.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/functions/sparksql/Register.h" @@ -37,6 +38,7 @@ class SparkQueryRunnerTest : public ::testing::Test, protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance({}); + parquet::registerParquetWriterFactory(); } void SetUp() override { @@ -80,6 +82,21 @@ TEST_F(SparkQueryRunnerTest, DISABLED_basic) { exec::test::assertEqualResults(sparkResults, outputType, {expected}); } +// This test requires a Spark coordinator running at localhost, so disable it +// by default. +TEST_F(SparkQueryRunnerTest, DISABLED_decimal) { + auto aggregatePool = rootPool_->addAggregateChild("decimal"); + auto queryRunner = std::make_unique( + aggregatePool.get(), "localhost:15002", "test", "decimal"); + auto input = makeRowVector({ + makeConstant(123456789, 25, DECIMAL(34, 2)), + }); + auto outputType = ROW({"a"}, {DECIMAL(34, 2)}); + auto sparkResults = + queryRunner->execute("SELECT abs(c0) FROM tmp", {input}, outputType); + exec::test::assertEqualResults(sparkResults, outputType, {input}); +} + // This test requires a Spark coordinator running at localhost, so disable it // by default. TEST_F(SparkQueryRunnerTest, DISABLED_fuzzer) { diff --git a/velox/vector/arrow/Bridge.cpp b/velox/vector/arrow/Bridge.cpp index 8c5e15a4bc07..808430e2e366 100644 --- a/velox/vector/arrow/Bridge.cpp +++ b/velox/vector/arrow/Bridge.cpp @@ -1866,6 +1866,24 @@ VectorPtr createShortDecimalVector( pool, type, nulls, length, values, nullCount); } +// Arrow uses two uint64_t values to represent a 128-bit decimal value. The +// memory allocated by Arrow might not be 16-byte aligned, so we need to copy +// the values to a new buffer to ensure 16-byte alignment. +VectorPtr createLongDecimalVector( + memory::MemoryPool* pool, + const TypePtr& type, + BufferPtr nulls, + const int128_t* input, + size_t length, + int64_t nullCount) { + auto values = AlignedBuffer::allocate(length, pool); + auto rawValues = values->asMutable(); + memcpy(rawValues, input, length * sizeof(int128_t)); + + return createFlatVector( + pool, type, nulls, length, values, nullCount); +} + bool isREE(const ArrowSchema& arrowSchema) { return arrowSchema.format[0] == '+' && arrowSchema.format[1] == 'r'; } @@ -1960,6 +1978,14 @@ VectorPtr importFromArrowImpl( static_cast(arrowArray.buffers[1]), arrowArray.length, arrowArray.null_count); + } else if (type->isLongDecimal()) { + return createLongDecimalVector( + pool, + type, + nulls, + static_cast(arrowArray.buffers[1]), + arrowArray.length, + arrowArray.null_count); } else if (type->isRow()) { // Row/structs. return createRowVector( diff --git a/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp b/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp index bd07abdfcc71..4dd14966417f 100644 --- a/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp +++ b/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp @@ -1270,6 +1270,14 @@ class ArrowBridgeArrayImportTest : public ArrowBridgeArrayExportTest { testArrowImport( "d:5,2", {1, -1, 0, 12345, -12345, std::nullopt}); + testArrowImport( + "d:36,2", + {HugeInt::parse("20000000000000000"), + HugeInt::parse("50000000000000000"), + 0, + HugeInt::parse("50000000000000000000"), + HugeInt::parse("-40000000000000000000"), + std::nullopt}); } template