[SPARK-50644][SQL] Read variant struct in Parquet reader. #49263

Closed
@@ -49,9 +49,8 @@ public static Variant rebuild(ShreddedRow row, VariantSchema schema) {
throw malformedVariant();
}
byte[] metadata = row.getBinary(schema.topLevelMetadataIdx);
- if (schema.variantIdx >= 0 && schema.typedIdx < 0) {
- // The variant is unshredded. We are not required to do anything special, but we can have an
- // optimization to avoid `rebuild`.
+ if (schema.isUnshredded()) {
+ // `rebuild` is unnecessary for unshredded variant.
if (row.isNullAt(schema.variantIdx)) {
throw malformedVariant();
}
@@ -65,8 +64,8 @@ public static Variant rebuild(ShreddedRow row, VariantSchema schema) {
// Rebuild a variant value from the shredded data according to the reconstruction algorithm in
// https://github.com/apache/parquet-format/blob/master/VariantShredding.md.
// Append the result to `builder`.
- private static void rebuild(ShreddedRow row, byte[] metadata, VariantSchema schema,
- VariantBuilder builder) {
+ public static void rebuild(ShreddedRow row, byte[] metadata, VariantSchema schema,
+ VariantBuilder builder) {
int typedIdx = schema.typedIdx;
int variantIdx = schema.variantIdx;
if (typedIdx >= 0 && !row.isNullAt(typedIdx)) {
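For context, the reconstruction rule that `rebuild` implements (per the VariantShredding spec linked above) is a per-group three-way branch: prefer `typed_value`, fall back to the variant-encoded `value`, and otherwise treat the field as missing. A minimal Scala sketch of that branch, with hypothetical `writeTyped` and `writeVariantValue` helpers standing in for the real builder calls:

// Simplified sketch only; the real Java `rebuild` above also recurses into shredded
// objects and arrays and rejects malformed combinations.
def rebuildSketch(row: ShreddedRow, schema: VariantSchema, builder: VariantBuilder): Unit = {
  val typedIdx = schema.typedIdx
  val variantIdx = schema.variantIdx
  if (typedIdx >= 0 && !row.isNullAt(typedIdx)) {
    writeTyped(row, schema, builder) // hypothetical helper: rebuild from the typed column
  } else if (variantIdx >= 0 && !row.isNullAt(variantIdx)) {
    writeVariantValue(row.getBinary(variantIdx), builder) // hypothetical helper: copy the encoded value
  } else {
    // Neither column is set: the field is missing, which is only legal inside a shredded object.
  }
}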
@@ -138,6 +138,12 @@ public VariantSchema(int typedIdx, int variantIdx, int topLevelMetadataIdx, int
this.arraySchema = arraySchema;
}

+ // Return whether the variant column is unshredded. The user is not required to do anything
+ // special, but can have certain optimizations for unshredded variant.
+ public boolean isUnshredded() {
+ return topLevelMetadataIdx >= 0 && variantIdx >= 0 && typedIdx < 0;
+ }

@Override
public String toString() {
return "VariantSchema{" +
@@ -35,7 +35,6 @@
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.VariantType;
import org.apache.spark.types.variant.VariantSchema;
- import org.apache.spark.unsafe.types.VariantVal;

/**
* Contains necessary information representing a Parquet column, either of primitive or nested type.
@@ -49,6 +48,9 @@ final class ParquetColumnVector {
// contains only one child that reads the underlying file content. This `ParquetColumnVector`
// should assemble Spark variant values from the file content.
private VariantSchema variantSchema;
+ // Only meaningful if `variantSchema` is not null. See `SparkShreddingUtils.getFieldsToExtract`
+ // for its meaning.
+ private FieldToExtract[] fieldsToExtract;

/**
* Repetition & Definition levels
@@ -117,6 +119,7 @@ final class ParquetColumnVector {
fileContent, capacity, memoryMode, missingColumns, false, null);
children.add(contentVector);
variantSchema = SparkShreddingUtils.buildVariantSchema(fileContentCol.sparkType());
+ fieldsToExtract = SparkShreddingUtils.getFieldsToExtract(column.sparkType(), variantSchema);
repetitionLevels = contentVector.repetitionLevels;
definitionLevels = contentVector.definitionLevels;
} else if (isPrimitive) {
@@ -188,20 +191,11 @@ void assemble() {
if (variantSchema != null) {
children.get(0).assemble();
WritableColumnVector fileContent = children.get(0).getValueVector();
- int numRows = fileContent.getElementsAppended();
- vector.reset();
- vector.reserve(numRows);
- WritableColumnVector valueChild = vector.getChild(0);
- WritableColumnVector metadataChild = vector.getChild(1);
- for (int i = 0; i < numRows; ++i) {
- if (fileContent.isNullAt(i)) {
- vector.appendStruct(true);
- } else {
- vector.appendStruct(false);
- VariantVal v = SparkShreddingUtils.rebuild(fileContent.getStruct(i), variantSchema);
- valueChild.appendByteArray(v.getValue(), 0, v.getValue().length);
- metadataChild.appendByteArray(v.getMetadata(), 0, v.getMetadata().length);
- }
+ if (fieldsToExtract == null) {
+ SparkShreddingUtils.assembleVariantBatch(fileContent, vector, variantSchema);
+ } else {
+ SparkShreddingUtils.assembleVariantStructBatch(fileContent, vector, variantSchema,
+ fieldsToExtract);
}
return;
}
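The per-row loop deleted above presumably moves behind `SparkShreddingUtils.assembleVariantBatch`; below is a Scala sketch of that full-variant path, mirroring the deleted Java (the shredded-struct path, `assembleVariantStructBatch`, is not shown in this diff, and the method name here is hypothetical):

import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.types.variant.VariantSchema

// Sketch of batch assembly for a plain variant column: rebuild each row and append the
// resulting value/metadata binaries to the two children of the output struct vector.
def assembleVariantBatchSketch(
    fileContent: WritableColumnVector,
    out: WritableColumnVector,
    schema: VariantSchema): Unit = {
  val numRows = fileContent.getElementsAppended
  out.reset()
  out.reserve(numRows)
  val valueChild = out.getChild(0)
  val metadataChild = out.getChild(1)
  for (i <- 0 until numRows) {
    if (fileContent.isNullAt(i)) {
      out.appendStruct(true) // null variant row
    } else {
      out.appendStruct(false)
      val v = SparkShreddingUtils.rebuild(fileContent.getStruct(i), schema)
      valueChild.appendByteArray(v.getValue, 0, v.getValue.length)
      metadataChild.appendByteArray(v.getMetadata, 0, v.getMetadata.length)
    }
  }
}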
@@ -35,6 +35,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.errors.QueryExecutionErrors
+ import org.apache.spark.sql.execution.datasources.VariantMetadata
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.types._

@@ -221,6 +222,9 @@ object ParquetReadSupport extends Logging {
clipParquetMapType(
parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive, useFieldId)

+ case t: StructType if VariantMetadata.isVariantStruct(t) =>
+ clipVariantSchema(parquetType.asGroupType(), t)

case t: StructType =>
clipParquetGroup(parquetType.asGroupType(), t, caseSensitive, useFieldId)

@@ -390,6 +394,11 @@ object ParquetReadSupport extends Logging {
.named(parquetRecord.getName)
}

+ private def clipVariantSchema(parquetType: GroupType, variantStruct: StructType): GroupType = {
+ // TODO(SHREDDING): clip `parquetType` to retain the necessary columns.
+ parquetType
+ }

Contributor:
So this requires the new parquet version to support column pruning for variant?

Contributor:
I don't think it requires a new Parquet version - it should be possible to clip it in sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala in the same way that unused struct fields are clipped. The logic of deciding which fields can be clipped is more complicated, though.

Contributor Author (@chenhao-db), Dec 23, 2024:
It doesn't. In this function, we will have custom logic to clip `parquetType` to retain the necessary columns for reading `variantStruct`. But this part will be in a future PR to avoid making the single PR too big.

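As a rough illustration of the clipping discussed above (explicitly deferred to a follow-up PR), the future logic could filter the shredded group in the same spirit as the struct-field clipping elsewhere in this file. Everything below is an assumption, and `neededByVariantStruct` is a placeholder for the real path-matching logic:

import scala.jdk.CollectionConverters._
import org.apache.parquet.schema.GroupType
import org.apache.spark.sql.types.StructType

// Hypothetical shape of the future clipping (not part of this PR): always keep the
// spec-defined `metadata` and `value` columns, and keep only the parts of `typed_value`
// that some requested field of `variantStruct` actually reads.
def clipVariantSchemaSketch(parquetType: GroupType, variantStruct: StructType): GroupType = {
  val kept = parquetType.getFields.asScala.filter { field =>
    field.getName == "metadata" || field.getName == "value" ||
      neededByVariantStruct(field, variantStruct) // placeholder predicate
  }
  parquetType.withNewFields(kept.asJava)
}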
/**
* Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]].
*
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.QueryExecutionErrors
- import org.apache.spark.sql.execution.datasources.DataSourceUtils
+ import org.apache.spark.sql.execution.datasources.{DataSourceUtils, VariantMetadata}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
@@ -498,6 +498,9 @@ private[parquet] class ParquetRowConverter(
case t: MapType =>
new ParquetMapConverter(parquetType.asGroupType(), t, updater)

+ case t: StructType if VariantMetadata.isVariantStruct(t) =>
+ new ParquetVariantConverter(t, parquetType.asGroupType(), updater)

case t: StructType =>
val wrappedUpdater = {
// SPARK-30338: avoid unnecessary InternalRow copying for nested structs:
@@ -536,12 +539,7 @@ private[parquet] class ParquetRowConverter(

case t: VariantType =>
if (SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED)) {
- // Infer a Spark type from `parquetType`. This piece of code is copied from
- // `ParquetArrayConverter`.
- val messageType = Types.buildMessage().addField(parquetType).named("foo")
- val column = new ColumnIOFactory().getColumnIO(messageType)
- val parquetSparkType = schemaConverter.convertField(column.getChild(0)).sparkType
- new ParquetVariantConverter(parquetType.asGroupType(), parquetSparkType, updater)
+ new ParquetVariantConverter(t, parquetType.asGroupType(), updater)
} else {
new ParquetUnshreddedVariantConverter(parquetType.asGroupType(), updater)
}
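The type-inference code removed above appears to have been folded into the new `SparkShreddingUtils.parquetTypeToSparkType` helper used by `ParquetVariantConverter` below; here is a hedged guess at its body, reusing the deleted lines (the method name and use of a default-configured converter are assumptions):

import org.apache.parquet.io.ColumnIOFactory
import org.apache.parquet.schema.{Type, Types}
import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter
import org.apache.spark.sql.types.DataType

// Presumed implementation (not confirmed by this diff): wrap the single field in a
// message type and let the schema converter infer its Spark type.
def parquetTypeToSparkTypeSketch(parquetType: Type): DataType = {
  val messageType = Types.buildMessage().addField(parquetType).named("foo")
  val column = new ColumnIOFactory().getColumnIO(messageType)
  new ParquetToSparkSchemaConverter().convertField(column.getChild(0)).sparkType
}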
@@ -909,13 +907,14 @@ private[parquet] class ParquetRowConverter(

/** Parquet converter for Variant (shredded or unshredded) */
private final class ParquetVariantConverter(
- parquetType: GroupType,
- parquetSparkType: DataType,
- updater: ParentContainerUpdater)
+ targetType: DataType, parquetType: GroupType, updater: ParentContainerUpdater)
extends ParquetGroupConverter(updater) {

private[this] var currentRow: Any = _
+ private[this] val parquetSparkType = SparkShreddingUtils.parquetTypeToSparkType(parquetType)
private[this] val variantSchema = SparkShreddingUtils.buildVariantSchema(parquetSparkType)
+ private[this] val fieldsToExtract =
+ SparkShreddingUtils.getFieldsToExtract(targetType, variantSchema)
// A struct converter that reads the underlying file data.
private[this] val fileConverter = new ParquetRowConverter(
schemaConverter,
@@ -932,7 +931,12 @@ private[parquet] class ParquetRowConverter(

override def end(): Unit = {
fileConverter.end()
- val v = SparkShreddingUtils.rebuild(currentRow.asInstanceOf[InternalRow], variantSchema)
+ val row = currentRow.asInstanceOf[InternalRow]
+ val v = if (fieldsToExtract == null) {
+ SparkShreddingUtils.assembleVariant(row, variantSchema)
+ } else {
+ SparkShreddingUtils.assembleVariantStruct(row, variantSchema, fieldsToExtract)
+ }
updater.set(v)
}

@@ -28,6 +28,7 @@ import org.apache.parquet.schema.Type.Repetition._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.errors.QueryCompilationErrors
+ import org.apache.spark.sql.execution.datasources.VariantMetadata
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -185,6 +186,9 @@ class ParquetToSparkSchemaConverter(
} else {
convertVariantField(groupColumn)
}
+ case groupColumn: GroupColumnIO if targetType.exists(VariantMetadata.isVariantStruct) =>
+ val col = convertGroupField(groupColumn)
+ col.copy(sparkType = targetType.get, variantFileType = Some(col))
case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType)
}
}