[SPARK-17528][SQL] data should be copied properly before saving into InternalRow #18483

Closed
2 commits; showing changes from 1 commit.

@@ -1088,6 +1088,12 @@ public UTF8String clone() {
return fromBytes(getBytes());
}

+ public UTF8String copy() {

Contributor

This is because clone() doesn't always make a copy, right?

Contributor

Perhaps we should just make clone() make an actual copy...

Contributor Author

UTF8String is public to users, so I'm hesitant to change the clone() method.

Contributor

OK, let's leave it then.

+ byte[] bytes = new byte[numBytes];
+ copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes);
+ return fromBytes(bytes);
+ }

@Override
public int compareTo(@Nonnull final UTF8String other) {
int len = Math.min(numBytes, other.numBytes);
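
The difference between a clone that may reuse its bytes and a copy that always allocates, as discussed in the review thread above, can be illustrated outside Spark. This is a minimal sketch under stated assumptions: `ByteView`, `cheapClone()` and the "view spans the whole array" sharing condition are hypothetical stand-ins, not Spark's `UTF8String` API.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

// Hypothetical stand-in for a string backed by a byte range; not Spark's UTF8String.
final class ByteView {
  private final byte[] base;
  private final int offset;
  private final int numBytes;

  ByteView(byte[] base, int offset, int numBytes) {
    this.base = base;
    this.offset = offset;
    this.numBytes = numBytes;
  }

  // Assumed "cheap clone": reuse the backing array when the view spans all of it.
  ByteView cheapClone() {
    if (offset == 0 && numBytes == base.length) {
      return new ByteView(base, 0, numBytes); // shares the backing array
    }
    return copy();
  }

  // Always allocates a fresh array, so later writes to `base` cannot leak through.
  ByteView copy() {
    byte[] bytes = Arrays.copyOfRange(base, offset, offset + numBytes);
    return new ByteView(bytes, 0, numBytes);
  }

  @Override
  public String toString() {
    return new String(base, offset, numBytes, StandardCharsets.UTF_8);
  }

  public static void main(String[] args) {
    byte[] buffer = "a".getBytes(StandardCharsets.UTF_8);
    ByteView original = new ByteView(buffer, 0, buffer.length);
    ByteView cloned = original.cheapClone();
    ByteView copied = original.copy();

    buffer[0] = (byte) 'b'; // the producer overwrites its buffer in place

    System.out.println(cloned); // prints "b": the clone aliases the buffer
    System.out.println(copied); // prints "a": the copy is independent
  }
}
```

In this sketch only `copy()` protects the holder from later writes to the original buffer, which is the property the new method in the hunk above is meant to provide.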

@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StructType}
+ import org.apache.spark.unsafe.types.UTF8String

/**
* An abstract class for row used internally in Spark SQL, which only contains the columns as
@@ -33,6 +35,10 @@ abstract class InternalRow extends SpecializedGetters with Serializable {

def setNullAt(i: Int): Unit

+ /**
+ * Updates the value at column `i`. Note that after updating, the given value will be kept in this
+ * row, and the caller side should guarantee that this value won't be changed afterwards.
+ */
def update(i: Int, value: Any): Unit

// default implementation (slow)
@@ -58,7 +64,15 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
def copy(): InternalRow

/** Returns true if there are any NULL values in this row. */
- def anyNull: Boolean
+ def anyNull: Boolean = {
+ val len = numFields
+ var i = 0
+ while (i < len) {
+ if (isNullAt(i)) { return true }
+ i += 1
+ }
+ false
+ }

/* ---------------------- utility methods for Scala ---------------------- */

@@ -94,4 +108,21 @@ object InternalRow {

/** Returns an empty [[InternalRow]]. */
val empty = apply()

+ /**
+ * Copies the given value if it's string/struct/array/map type.
+ */
+ def copyValue(value: Any): Any = {

Contributor

Should we have some marker trait for objects that need to be copied? It might make the code a bit more concise, and it will also save some typing. The only downside would be that the trait will probably live in a very weird place because UTF8String lives in the common/unsafe project.

Contributor Author

There are only 4 types, so maybe it's not worth the bother?

Contributor

It is fine for now.

+ if (value.isInstanceOf[UTF8String]) {

Contributor

Use pattern matching?

Contributor Author

This method may be called many times for nested complex types, so I'm worried about performance here.

Contributor

??? - pattern matching should yield the same performance as if (value.isInstanceOf[...]) ... else if (value.isInstanceOf[...]) ...

Contributor Author

Hmm, I was told not to use pattern matching on performance-critical paths... Anyway, it seems the copy should take most of the execution time, so pattern matching should be fine. Let me update.

+ value.asInstanceOf[UTF8String].copy()
+ } else if (value.isInstanceOf[InternalRow]) {
+ value.asInstanceOf[InternalRow].copy()
+ } else if (value.isInstanceOf[ArrayData]) {
+ value.asInstanceOf[ArrayData].copy()
+ } else if (value.isInstanceOf[MapData]) {
+ value.asInstanceOf[MapData].copy()
+ } else {

Contributor

Should we also support Cloneable? This might be an escape hatch for developers who are putting their own objects in InternalRow.

Contributor Author

The internal values are totally internal; do we really need an escape hatch?

+ value
+ }
+ }
}
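
The type-dispatched copy in `copyValue` can be sketched in plain Java. This is an illustration under stated assumptions, not Spark code: `byte[]` stands in for the string type, `Object[]` for struct and array values, and `java.util.Map` for the map type; the class name `DeepCopy` is hypothetical, and map keys are assumed immutable here.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical model of a "copy if mutable" helper; the types are stand-ins.
final class DeepCopy {
  static Object copyValue(Object value) {
    if (value instanceof byte[]) {
      // Stand-in for a string backed by shared bytes.
      return ((byte[]) value).clone();
    } else if (value instanceof Object[]) {
      // Stand-in for struct/array: copy every element recursively.
      Object[] src = (Object[]) value;
      Object[] dst = new Object[src.length];
      for (int i = 0; i < src.length; i++) {
        dst[i] = copyValue(src[i]);
      }
      return dst;
    } else if (value instanceof Map) {
      // Keys are assumed immutable in this sketch; only values are copied.
      Map<Object, Object> dst = new LinkedHashMap<>();
      for (Map.Entry<?, ?> e : ((Map<?, ?>) value).entrySet()) {
        dst.put(e.getKey(), copyValue(e.getValue()));
      }
      return dst;
    } else {
      // Everything else is treated as immutable and passed through.
      return value;
    }
  }

  public static void main(String[] args) {
    byte[] name = {'a'};
    Object[] row = {name, 1};
    Object[] copied = (Object[]) copyValue(row);

    name[0] = 'b'; // mutate the original after copying

    System.out.println((char) ((byte[]) copied[0])[0]); // prints "a"
  }
}
```

Because the recursion descends into nested containers, a value buried several levels deep is still detached from its source, which is why the review thread raised the cost of calling this on nested complex types.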

@@ -1047,7 +1047,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
final $rowClass $result = new $rowClass(${fieldsCasts.length});
final InternalRow $tmpRow = $c;
$fieldsEvalCode
- $evPrim = $result.copy();

Contributor Author

This copy is not needed because we already copy values when setting columns on this row.

+ $evPrim = $result;
"""
}


@@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.expressions

- import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

/**
@@ -220,17 +219,6 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen

override def isNullAt(i: Int): Boolean = values(i).isNull

- override def copy(): InternalRow = {
- val newValues = new Array[Any](values.length)
- var i = 0
- while (i < values.length) {
- newValues(i) = values(i).boxed
- i += 1
- }
-
- new GenericInternalRow(newValues)
- }

override protected def genericGet(i: Int): Any = values(i).boxed

override def update(ordinal: Int, value: Any) {

@@ -52,7 +52,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
// Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
// See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
if (value != null) {
- buffer += value
+ buffer += InternalRow.copyValue(value)
}
buffer
}
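
Why the collect buffer needs a copy can be shown with a producer that overwrites one shared buffer per record. The following is a minimal sketch, assuming a hypothetical `CollectDemo` class in which a reused `StringBuilder` plays the role of a reused input row; it is not Spark's `Collect` implementation.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration of collecting values from a reused input buffer.
final class CollectDemo {
  static List<String> collect(String[] inputs, boolean copyBeforeStore) {
    StringBuilder reused = new StringBuilder(); // stands in for a reused input row
    List<CharSequence> buffer = new ArrayList<>();
    for (String in : inputs) {
      // The producer overwrites the same object for every record.
      reused.setLength(0);
      reused.append(in);
      if (copyBeforeStore) {
        buffer.add(reused.toString()); // detached snapshot
      } else {
        buffer.add(reused); // keeps a reference to the shared buffer
      }
    }
    List<String> out = new ArrayList<>();
    for (CharSequence value : buffer) {
      out.add(value.toString());
    }
    return out;
  }

  public static void main(String[] args) {
    String[] inputs = {"a", "b", "c"};
    System.out.println(collect(inputs, false)); // prints [c, c, c]
    System.out.println(collect(inputs, true));  // prints [a, b, c]
  }
}
```

Without the copy, every collected element ends up reflecting the last record written, which mirrors the kind of corruption the one-line change above guards against.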

@@ -317,6 +317,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac
* Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`.
*
* Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
+ *
+ * Note that, the input row may be produced by unsafe projection and it may not be safe to cache
+ * some fields of the input row, as the values can be changed unexpectedly.
*/
def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit

@@ -326,6 +329,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac
*
* Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
* Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
+ *
+ * Note that, the input row may be produced by unsafe projection and it may not be safe to cache
+ * some fields of the input row, as the values can be changed unexpectedly.
*/
def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit
}

@@ -408,9 +408,11 @@ class CodegenContext {
dataType match {
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
- // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes)
- case StringType => s"$row.update($ordinal, $value.clone())"
case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
+ // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy
+ // it to avoid keeping a "pointer" to a memory region which may get updated afterwards.
+ case StringType | _: StructType | _: ArrayType | _: MapType =>
+ s"$row.update($ordinal, $value.copy())"
case _ => s"$row.update($ordinal, $value)"
}
}
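
The shape of the code-generation rule above, where the emitted setter depends on the column type, can be sketched as a small string-building function. This is an illustration only: the `Kind` enum, the class name `SetColumnGen` and the emitted method names are assumptions modeled on the hunk, not Spark's `CodegenContext` API.

```java
// Hypothetical sketch of a setter generator that appends .copy() for mutable types.
final class SetColumnGen {
  enum Kind { INT, LONG, STRING, STRUCT, ARRAY, MAP, OTHER }

  // Returns the source text of a statement that stores `value` at `ordinal` in `row`.
  static String setColumn(String row, Kind kind, int ordinal, String value) {
    switch (kind) {
      case INT:
        return row + ".setInt(" + ordinal + ", " + value + ")";
      case LONG:
        return row + ".setLong(" + ordinal + ", " + value + ")";
      case STRING:
      case STRUCT:
      case ARRAY:
      case MAP:
        // These values may point into memory that is reused later, so copy first.
        return row + ".update(" + ordinal + ", " + value + ".copy())";
      default:
        return row + ".update(" + ordinal + ", " + value + ")";
    }
  }

  public static void main(String[] args) {
    System.out.println(setColumn("row", Kind.INT, 1, "value"));    // row.setInt(1, value)
    System.out.println(setColumn("row", Kind.STRING, 0, "value")); // row.update(0, value.copy())
  }
}
```

Centralizing the copy in the generated setter is what lets call sites such as the Cast and GenerateSafeProjection hunks drop their own defensive copies.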

@@ -131,8 +131,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
case s: StructType => createCodeForStruct(ctx, input, s)
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
- // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
- case StringType => ExprCode("", "false", s"$input.clone()")

Contributor Author

This copy is not needed, as we copy the value before writing it to the row.

case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
case _ => ExprCode("", "false", input)
}

@@ -50,16 +50,6 @@ trait BaseGenericInternalRow extends InternalRow {
override def getMap(ordinal: Int): MapData = getAs(ordinal)
override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)

- override def anyNull: Boolean = {
- val len = numFields
- var i = 0
- while (i < len) {
- if (isNullAt(i)) { return true }
- i += 1
- }
- false
- }

override def toString: String = {
if (numFields == 0) {
"[empty row]"
@@ -79,6 +69,17 @@ trait BaseGenericInternalRow extends InternalRow {
}
}

+ override def copy(): GenericInternalRow = {
+ val len = numFields
+ val newValues = new Array[Any](len)
+ var i = 0
+ while (i < len) {
+ newValues(i) = InternalRow.copyValue(genericGet(i))
+ i += 1
+ }
+ new GenericInternalRow(newValues)
+ }
override def equals(o: Any): Boolean = {
if (!o.isInstanceOf[BaseGenericInternalRow]) {
return false
@@ -206,6 +207,4 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow
override def setNullAt(i: Int): Unit = { values(i) = null}

override def update(i: Int, value: Any): Unit = { values(i) = value }

- override def copy(): GenericInternalRow = this
}

@@ -49,7 +49,15 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {

def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray))

- override def copy(): ArrayData = new GenericArrayData(array.clone())
+ override def copy(): ArrayData = {
+ val newValues = new Array[Any](array.length)
+ var i = 0
+ while (i < array.length) {
+ newValues(i) = InternalRow.copyValue(array(i))
+ i += 1
+ }
+ new GenericArrayData(newValues)
+ }

override def numElements(): Int = array.length
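
The change from `array.clone()` to an element-wise copy matters because cloning an array of references is shallow. A minimal Java sketch of the difference, assuming a hypothetical `ArrayCopyDemo` class in which `int[]` elements stand in for mutable nested values:

```java
// Hypothetical contrast between a shallow array clone and an element-wise copy.
final class ArrayCopyDemo {
  // Copies only the outer array; elements are still shared with the source.
  static Object[] shallowCopy(Object[] array) {
    return array.clone();
  }

  // Copies the outer array and each mutable element (here: int[] only).
  static Object[] deepCopy(Object[] array) {
    Object[] out = new Object[array.length];
    for (int i = 0; i < array.length; i++) {
      Object element = array[i];
      out[i] = (element instanceof int[]) ? ((int[]) element).clone() : element;
    }
    return out;
  }

  public static void main(String[] args) {
    int[] inner = {1};
    Object[] outer = {inner};
    Object[] shallow = shallowCopy(outer);
    Object[] deep = deepCopy(outer);

    inner[0] = 2; // mutate the nested value after copying

    System.out.println(((int[]) shallow[0])[0]); // prints 2
    System.out.println(((int[]) deep[0])[0]);    // prints 1
  }
}
```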


@@ -121,10 +121,6 @@ class RowTest extends FunSpec with Matchers {
externalRow should be theSameInstanceAs externalRow.copy()
}

- it("copy should return same ref for internal rows") {
- internalRow should be theSameInstanceAs internalRow.copy()
- }

it("toSeq should not expose internal state for external rows") {
val modifiedValues = modifyValues(externalRow.toSeq)
externalRow.toSeq should not equal modifiedValues

This file was deleted.

@@ -172,4 +172,40 @@ class GeneratedProjectionSuite extends SparkFunSuite {
assert(unsafe1 === unsafe3)
assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7))
}

+ test("MutableProjection should not cache content from the input row") {
+ val mutableProj = GenerateMutableProjection.generate(
+ Seq(BoundReference(0, new StructType().add("i", StringType), true)))
+ val row = new GenericInternalRow(1)
+ mutableProj.target(row)
+
+ val unsafeProj = GenerateUnsafeProjection.generate(
+ Seq(BoundReference(0, new StructType().add("i", StringType), true)))
+ val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a"))))
+
+ mutableProj.apply(unsafeRow)
+ assert(row.getStruct(0, 1).getString(0) == "a")
+
+ // Even if the input row of the mutable projection has been changed, the target mutable row
+ // should keep same.
+ unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b"))))
+ assert(row.getStruct(0, 1).getString(0).toString == "a")
+ }

+ test("SafeProjection should not cache content from the input row") {

Contributor

This has always worked, right?

Contributor Author

Yes, I added this test to make sure my change to GenerateSafeProjection is safe.

+ val safeProj = GenerateSafeProjection.generate(
+ Seq(BoundReference(0, new StructType().add("i", StringType), true)))
+
+ val unsafeProj = GenerateUnsafeProjection.generate(
+ Seq(BoundReference(0, new StructType().add("i", StringType), true)))
+ val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a"))))
+
+ val row = safeProj.apply(unsafeRow)
+ assert(row.getStruct(0, 1).getString(0) == "a")
+
+ // Even if the input row of the mutable projection has been changed, the target mutable row
+ // should keep same.
+ unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b"))))
+ assert(row.getStruct(0, 1).getString(0).toString == "a")
+ }
}