[SPARK-40876][SQL][FOLLOWUP] Widening type promotion from integers to decimal in Parquet vectorized reader #44803

Closed
ParquetVectorUpdaterFactory.java
@@ -1406,7 +1406,11 @@ private static class IntegerToDecimalUpdater extends DecimalUpdater {
super(sparkType);
LogicalTypeAnnotation typeAnnotation =
descriptor.getPrimitiveType().getLogicalTypeAnnotation();
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
} else {
this.parquetScale = 0;
}
}

@Override
@@ -1435,14 +1439,18 @@ public void decodeSingleDictionaryId(
}
}

private static class LongToDecimalUpdater extends DecimalUpdater {
private final int parquetScale;

LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
super(sparkType);
LogicalTypeAnnotation typeAnnotation =
descriptor.getPrimitiveType().getLogicalTypeAnnotation();
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
} else {
this.parquetScale = 0;
}
}

@Override
@@ -1640,6 +1648,12 @@ private static boolean isDateTypeMatched(ColumnDescriptor descriptor) {
return typeAnnotation instanceof DateLogicalTypeAnnotation;
}

private static boolean isSignedIntAnnotation(LogicalTypeAnnotation typeAnnotation) {
if (!(typeAnnotation instanceof IntLogicalTypeAnnotation)) return false;
IntLogicalTypeAnnotation intAnnotation = (IntLogicalTypeAnnotation) typeAnnotation;
return intAnnotation.isSigned();
}

private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) {
DecimalType requestedType = (DecimalType) dt;
LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
@@ -1651,6 +1665,20 @@ private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) {
int scaleIncrease = requestedType.scale() - parquetType.getScale();
int precisionIncrease = requestedType.precision() - parquetType.getPrecision();
return scaleIncrease >= 0 && precisionIncrease >= scaleIncrease;
} else if (typeAnnotation == null || isSignedIntAnnotation(typeAnnotation)) {
// Allow reading signed integers (which may be un-annotated) as decimal as long as the
// requested decimal type is large enough to represent all possible values.
PrimitiveType.PrimitiveTypeName typeName =
descriptor.getPrimitiveType().getPrimitiveTypeName();
int integerPrecision = requestedType.precision() - requestedType.scale();
switch (typeName) {
case INT32:
return integerPrecision >= DecimalType$.MODULE$.IntDecimal().precision();
case INT64:
return integerPrecision >= DecimalType$.MODULE$.LongDecimal().precision();
default:
return false;
}
}
return false;
}
@@ -1661,6 +1689,9 @@ private static boolean isSameDecimalScale(ColumnDescriptor descriptor, DataType d) {
if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
return decimalType.getScale() == d.scale();
} else if (typeAnnotation == null || isSignedIntAnnotation(typeAnnotation)) {
// Consider signed integers (which may be un-annotated) as having scale 0.
return d.scale() == 0;
}
return false;
}
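A usage-level sketch of what the new matching rules allow, not taken from the PR itself; the SparkSession setup, the path, and the column names are illustrative. A signed (possibly un-annotated) INT32 or INT64 parquet column can now be read as a decimal whose integer part (precision minus scale) covers the full integer range: at least 10 digits for INT32 and at least 20 for INT64, with the stored integer treated as having scale 0.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
val path = "/tmp/int-to-decimal.parquet" // illustrative path

// Column a is written as INT32, column b as INT64.
spark.range(5).selectExpr("CAST(id AS INT) AS a", "id AS b").write.mode("overwrite").parquet(path)

// Wide enough: DECIMAL(10, 0) can hold any INT32 value and DECIMAL(20, 0) any INT64 value.
spark.read.schema("a DECIMAL(10, 0), b DECIMAL(20, 0)").parquet(path).show()

// Extra integer digits and a non-zero scale are also accepted, matching the widening tests below.
spark.read.schema("a DECIMAL(11, 1), b DECIMAL(21, 1)").parquet(path).show()

// Too narrow: DECIMAL(9, 0) cannot represent every INT32, so this read is expected to fail
// under the vectorized reader.
// spark.read.schema("a DECIMAL(9, 0)").parquet(path).show()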
VectorizedColumnReader.java
@@ -151,18 +151,17 @@ private boolean isLazyDecodingSupported(
// rebasing.
switch (typeName) {
case INT32: {
boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation;
boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
boolean isDecimal = sparkType instanceof DecimalType;
boolean needsUpcast = sparkType == LongType || sparkType == DoubleType ||
(isDate && sparkType == TimestampNTZType) ||
Contributor: why do we remove isDate?

Contributor (Author): This was redundant since reading an INT32 as TimestampNTZType necessarily requires converting the value. The fact that this only happens for parquet dates isn't really relevant here and with the current change this would be the only case where we look at the parquet type annotation which is a bit confusing.

Contributor: Oh I see, this is inside isLazyDecodingSupported

sparkType == TimestampNTZType ||
(isDecimal && !DecimalType.is32BitDecimalType(sparkType));
boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation &&
!"CORRECTED".equals(datetimeRebaseMode);
isSupported = !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
break;
}
case INT64: {
boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
boolean isDecimal = sparkType instanceof DecimalType;
boolean needsUpcast = (isDecimal && !DecimalType.is64BitDecimalType(sparkType)) ||
updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
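A minimal Scala restatement of the reasoning in the exchange above, not the actual implementation (which is the Java code in this diff): for an INT32 parquet column, lazy dictionary decoding only stays enabled when the requested Spark type can reuse the stored 4-byte value as-is. Reading as TimestampNTZType always needs a conversion, so a separate isDate check added nothing.

import org.apache.spark.sql.types._

def int32NeedsUpcast(sparkType: DataType): Boolean = sparkType match {
  case LongType | DoubleType | TimestampNTZType => true
  case d: DecimalType => d.precision > 9 // no longer fits the 4-byte decimal representation
  case _ => false
}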
ParquetQuerySuite.scala
@@ -1037,8 +1037,10 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSparkSession

withAllParquetReaders {
// We can read the decimal parquet field with a larger precision, if scale is the same.
val schema = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
checkAnswer(readParquet(schema, path), df)
val schema1 = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
checkAnswer(readParquet(schema1, path), df)
val schema2 = "a DECIMAL(18, 1), b DECIMAL(38, 2), c DECIMAL(38, 2)"
checkAnswer(readParquet(schema2, path), df)
}

withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
@@ -1067,10 +1069,12 @@

withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00"))
checkAnswer(readParquet("a DECIMAL(11, 2)", path), sql("SELECT 1.00"))
checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null))
checkAnswer(readParquet("b DECIMAL(11, 1)", path), sql("SELECT 123456.0"))
checkAnswer(readParquet("c DECIMAL(11, 1)", path), Row(null))
checkAnswer(readParquet("c DECIMAL(13, 0)", path), df.select("c"))
checkAnswer(readParquet("c DECIMAL(22, 0)", path), df.select("c"))
val e = intercept[SparkException] {
readParquet("d DECIMAL(3, 2)", path).collect()
}.getCause
ParquetTypeWideningSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.io.File

import org.apache.hadoop.fs.Path
import org.apache.parquet.column.{Encoding, ParquetProperties}
import org.apache.parquet.format.converter.ParquetMetadataConverter
import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat}

@@ -31,6 +32,7 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DecimalType.{ByteDecimal, IntDecimal, LongDecimal, ShortDecimal}

class ParquetTypeWideningSuite
extends QueryTest
@@ -121,6 +123,19 @@ class ParquetTypeWideningSuite
if (dictionaryEnabled && !DecimalType.isByteArrayDecimalType(dataType)) {
assertAllParquetFilesDictionaryEncoded(dir)
}

// Check which encoding was used when writing Parquet V2 files.
val isParquetV2 = spark.conf.getOption(ParquetOutputFormat.WRITER_VERSION)
.contains(ParquetProperties.WriterVersion.PARQUET_2_0.toString)
if (isParquetV2) {
if (dictionaryEnabled) {
assertParquetV2Encoding(dir, Encoding.PLAIN)
} else if (DecimalType.is64BitDecimalType(dataType)) {
assertParquetV2Encoding(dir, Encoding.DELTA_BINARY_PACKED)
} else if (DecimalType.isByteArrayDecimalType(dataType)) {
assertParquetV2Encoding(dir, Encoding.DELTA_BYTE_ARRAY)
}
}
df
}

@@ -145,6 +160,27 @@
}
}

/**
* Asserts that all parquet files in the given directory have all their columns encoded with the
* given encoding.
*/
private def assertParquetV2Encoding(dir: File, expected_encoding: Encoding): Unit = {
dir.listFiles(_.getName.endsWith(".parquet")).foreach { file =>
val parquetMetadata = ParquetFileReader.readFooter(
spark.sessionState.newHadoopConf(),
new Path(dir.toString, file.getName),
ParquetMetadataConverter.NO_FILTER)
parquetMetadata.getBlocks.forEach { block =>
block.getColumns.forEach { col =>
assert(
col.getEncodings.contains(expected_encoding),
s"Expected column '${col.getPath.toDotString}' to use encoding $expected_encoding " +
s"but found ${col.getEncodings}.")
}
}
}
}

for {
(values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
(Seq("1", "2", Short.MinValue.toString), ShortType, IntegerType),
@@ -157,24 +193,77 @@
(Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType)
)
}
test(s"parquet widening conversion $fromType -> $toType") {
checkAllParquetReaders(values, fromType, toType, expectError = false)
}
test(s"parquet widening conversion $fromType -> $toType") {
checkAllParquetReaders(values, fromType, toType, expectError = false)
}

for {
(values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
(Seq("1", Byte.MaxValue.toString), ByteType, IntDecimal),
(Seq("1", Byte.MaxValue.toString), ByteType, LongDecimal),
(Seq("1", Short.MaxValue.toString), ShortType, IntDecimal),
(Seq("1", Short.MaxValue.toString), ShortType, LongDecimal),
(Seq("1", Short.MaxValue.toString), ShortType, DecimalType(DecimalType.MAX_PRECISION, 0)),
(Seq("1", Int.MaxValue.toString), IntegerType, IntDecimal),
(Seq("1", Int.MaxValue.toString), IntegerType, LongDecimal),
(Seq("1", Int.MaxValue.toString), IntegerType, DecimalType(DecimalType.MAX_PRECISION, 0)),
(Seq("1", Long.MaxValue.toString), LongType, LongDecimal),
(Seq("1", Long.MaxValue.toString), LongType, DecimalType(DecimalType.MAX_PRECISION, 0)),
(Seq("1", Byte.MaxValue.toString), ByteType, DecimalType(IntDecimal.precision + 1, 1)),
(Seq("1", Short.MaxValue.toString), ShortType, DecimalType(IntDecimal.precision + 1, 1)),
(Seq("1", Int.MaxValue.toString), IntegerType, DecimalType(IntDecimal.precision + 1, 1)),
(Seq("1", Long.MaxValue.toString), LongType, DecimalType(LongDecimal.precision + 1, 1))
)
}
test(s"parquet widening conversion $fromType -> $toType") {
checkAllParquetReaders(values, fromType, toType, expectError = false)
}

for {
(values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
(Seq("1", "2", Int.MinValue.toString), LongType, IntegerType),
(Seq("1.23", "10.34"), DoubleType, FloatType),
(Seq("1.23", "10.34"), FloatType, LongType),
(Seq("1", "10"), LongType, DoubleType),
(Seq("1", "10"), LongType, DateType),
(Seq("1", "10"), IntegerType, TimestampType),
(Seq("1", "10"), IntegerType, TimestampNTZType),
(Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampType)
)
}
test(s"unsupported parquet conversion $fromType -> $toType") {
checkAllParquetReaders(values, fromType, toType, expectError = true)
}
test(s"unsupported parquet conversion $fromType -> $toType") {
checkAllParquetReaders(values, fromType, toType, expectError = true)
}

for {
(values: Seq[String], fromType: DataType, toType: DecimalType) <- Seq(
// Parquet stores byte, short, int values as INT32, which then requires using a decimal that
// can hold at least 4 byte integers.
(Seq("1", "2"), ByteType, DecimalType(1, 0)),
(Seq("1", "2"), ByteType, ByteDecimal),
(Seq("1", "2"), ShortType, ByteDecimal),
(Seq("1", "2"), ShortType, ShortDecimal),
(Seq("1", "2"), IntegerType, ShortDecimal),
(Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision + 1, 1)),
(Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision + 1, 1)),
(Seq("1", "2"), LongType, IntDecimal),
(Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision - 1, 0)),
(Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision - 1, 0)),
(Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision - 1, 0)),
(Seq("1", "2"), LongType, DecimalType(LongDecimal.precision - 1, 0)),
(Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision, 1)),
(Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision, 1)),
(Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision, 1)),
(Seq("1", "2"), LongType, DecimalType(LongDecimal.precision, 1))
)
}
test(s"unsupported parquet conversion $fromType -> $toType") {
checkAllParquetReaders(values, fromType, toType,
expectError =
// parquet-mr allows reading decimals into a smaller precision decimal type without
// checking for overflows. See test below checking for the overflow case in parquet-mr.
spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
}

for {
(values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
@@ -201,17 +290,17 @@ class ParquetTypeWideningSuite
Seq(5 -> 7, 5 -> 10, 5 -> 20, 10 -> 12, 10 -> 20, 20 -> 22) ++
Seq(7 -> 5, 10 -> 5, 20 -> 5, 12 -> 10, 20 -> 10, 22 -> 20)
}
test(
s"parquet decimal precision change Decimal($fromPrecision, 2) -> Decimal($toPrecision, 2)") {
checkAllParquetReaders(
values = Seq("1.23", "10.34"),
fromType = DecimalType(fromPrecision, 2),
toType = DecimalType(toPrecision, 2),
expectError = fromPrecision > toPrecision &&
// parquet-mr allows reading decimals into a smaller precision decimal type without
// checking for overflows. See test below checking for the overflow case in parquet-mr.
Contributor: for non-vectorized parquet reader, what's the behavior? silent overflow?

Contributor (Author, @johanl-db, Jan 23, 2024): [reply not captured]

Contributor: and vectorized reader just doesn't allow it?

Contributor (Author): Yes the vectorized reader throws an exception which this test is checking

spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
}
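A repro sketch for the behavior discussed above, with an illustrative SparkSession and path; it mirrors the expectError condition in this test. The value 1.23 fits in DECIMAL(5, 2), yet the vectorized reader rejects the narrower declared precision up front, while parquet-mr is only expected to complain about values that actually overflow (see the ParquetQuerySuite changes earlier in this diff).

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
val path = "/tmp/decimal-precision.parquet" // illustrative path

spark.sql("SELECT CAST(1.23 AS DECIMAL(7, 2)) AS a").write.mode("overwrite").parquet(path)

// Vectorized reader: expected to throw because DECIMAL(5, 2) is narrower than the file's
// DECIMAL(7, 2), regardless of the stored values.
spark.conf.set("spark.sql.parquet.enableVectorizedReader", "true")
// spark.read.schema("a DECIMAL(5, 2)").parquet(path).collect() // expected to fail

// parquet-mr reader: the same read is expected to return 1.23 since the value itself fits.
spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false")
spark.read.schema("a DECIMAL(5, 2)").parquet(path).show()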

for {
((fromPrecision, fromScale), (toPrecision, toScale)) <-