[SPARK-46791][SQL] Support Java Set in JavaTypeInference

### What changes were proposed in this pull request?

This patch adds support for Java `Set` as a bean field type in `JavaTypeInference`.

### Why are the changes needed?

Scala `Set` (`scala.collection.Set`) is supported in `ScalaReflection`, so users can encode a Scala `Set` in a Dataset. But Java `Set` is not supported by the bean encoder (i.e., `JavaTypeInference`). This inconsistency prevents Java users from using `Set` the way Scala users can.
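
For contrast, the working Scala side of this asymmetry can be seen in a spark-shell session (a minimal sketch; assumes an active `SparkSession` named `spark`):

```scala
// assumes an active SparkSession named `spark`, e.g. in spark-shell
import spark.implicits._

// Scala: scala.collection.Set fields already round-trip through ScalaReflection.
case class ScalaHolder(values: Set[String])
Seq(ScalaHolder(Set("a", "b"))).toDS().show()  // worked before this patch

// Java beans: before this patch, a bean with a java.util.Set field was rejected
// during encoder inference by JavaTypeInference; the changes below add that case.
```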

### Does this PR introduce _any_ user-facing change?

Yes. Java `Set` fields are now supported in Java beans when encoding with the bean encoder.
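
A sketch of the newly supported usage from Scala, using the bean encoder (the `SetBean` class here is hypothetical and mirrors the `BeanWithSet` test added below):

```scala
import java.util.{Collections, HashSet => JHashSet, Set => JSet}
import scala.beans.BeanProperty
import org.apache.spark.sql.{Encoders, SparkSession}

// Hypothetical JavaBean-style class with a java.util.Set field.
class SetBean extends Serializable {
  @BeanProperty var fields: JSet[java.lang.Long] = _
}

val spark = SparkSession.builder().master("local[1]").getOrCreate()
val bean = new SetBean
val values = new JHashSet[java.lang.Long]()
values.add(1L); values.add(2L); values.add(3L)
bean.setFields(values)

// Encoder inference now succeeds; the Set field is stored as an array column.
val ds = spark.createDataset(Collections.singletonList(bean))(Encoders.bean(classOf[SetBean]))
ds.printSchema()  // shows `fields` as an array column of longs
```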

### How was this patch tested?

Added tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #44828 from viirya/java_set.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
viirya authored and dongjoon-hyun committed Jan 22, 2024
1 parent 02533d7 commit 667c0a9
Showing 6 changed files with 136 additions and 5 deletions.
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst

import java.beans.{Introspector, PropertyDescriptor}
import java.lang.reflect.{ParameterizedType, Type, TypeVariable}
-import java.util.{List => JList, Map => JMap}
+import java.util.{List => JList, Map => JMap, Set => JSet}
import javax.annotation.Nonnull

import scala.jdk.CollectionConverters._
@@ -112,6 +112,10 @@ object JavaTypeInference {
      val element = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables)
      IterableEncoder(ClassTag(c), element, element.nullable, lenientSerialization = false)

    case c: Class[_] if classOf[JSet[_]].isAssignableFrom(c) =>
      val element = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables)
      IterableEncoder(ClassTag(c), element, element.nullable, lenientSerialization = false)

    case c: Class[_] if classOf[JMap[_, _]].isAssignableFrom(c) =>
      val keyEncoder = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables)
      val valueEncoder = encoderFor(c.getTypeParameters.array(1), seenTypeSet, typeVariables)
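
The new `JSet` case mirrors the `JList` case directly above it: the collection's single type parameter resolves to the element encoder, and the result is the same `IterableEncoder`, so a `Set` bean field serializes to the same array layout as a `List`. A hedged sketch of what inference now produces (`JavaTypeInference` is internal, so this would live in Spark's own catalyst test scope; `TagsBean` is hypothetical):

```scala
import scala.beans.BeanProperty
import org.apache.spark.sql.catalyst.JavaTypeInference

// Hypothetical bean with a single java.util.Set-typed property.
class TagsBean {
  @BeanProperty var tags: java.util.Set[String] = _
}

// Before this patch, inference failed on the Set field; now it yields a
// JavaBeanEncoder whose "tags" field is an IterableEncoder over StringEncoder.
val encoder = JavaTypeInference.encoderFor(classOf[TagsBean])
println(encoder)
```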
@@ -907,6 +907,8 @@ case class MapObjects private(
        _.asInstanceOf[Array[_]].toImmutableArraySeq
      case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
        _.asInstanceOf[java.util.List[_]].asScala.toSeq
      case ObjectType(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
        _.asInstanceOf[java.util.Set[_]].asScala.toSeq
      case ObjectType(cls) if cls == classOf[Object] =>
        (inputCollection) => {
          if (inputCollection.getClass.isArray) {
@@ -982,6 +984,34 @@ case class MapObjects private(
            builder
          }
        }
      case Some(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
        // Java set
        if (cls == classOf[java.util.Set[_]] || cls == classOf[java.util.AbstractSet[_]]) {
          // Specifying non concrete implementations of `java.util.Set`
          executeFuncOnCollection(_).toSet.asJava
        } else {
          val constructors = cls.getConstructors()
          val intParamConstructor = constructors.find { constructor =>
            constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
          }
          val noParamConstructor = constructors.find { constructor =>
            constructor.getParameterCount == 0
          }

          val constructor = intParamConstructor.map { intConstructor =>
            (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
          }.getOrElse {
            (_: Int) => noParamConstructor.get.newInstance()
          }

          // Specifying concrete implementations of `java.util.Set`
          (inputs) => {
            val results = executeFuncOnCollection(inputs)
            val builder = constructor(inputs.length).asInstanceOf[java.util.Set[Any]]
            results.foreach(builder.add(_))
            builder
          }
        }
      case None =>
        // array
        x => new GenericArrayData(executeFuncOnCollection(x).toArray)
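
The interpreted path above builds concrete `Set` implementations reflectively: it prefers a single-`Int` constructor (used as a capacity hint) and falls back to a no-arg constructor. A standalone sketch of that lookup against stock JDK classes (not Spark code; the helper name is mine):

```scala
// Prefer a one-Int constructor (capacity hint), else fall back to the no-arg one.
def setBuilder(cls: Class[_]): Int => java.util.Set[Any] = {
  val constructors = cls.getConstructors
  val intParam = constructors.find { c =>
    c.getParameterCount == 1 && c.getParameterTypes()(0) == classOf[Int]
  }
  val noParam = constructors.find(_.getParameterCount == 0)
  intParam
    .map(c => (len: Int) => c.newInstance(len.asInstanceOf[Object]).asInstanceOf[java.util.Set[Any]])
    .getOrElse((_: Int) => noParam.get.newInstance().asInstanceOf[java.util.Set[Any]])
}

val hs = setBuilder(classOf[java.util.HashSet[_]])(8)  // uses HashSet(int initialCapacity)
val ts = setBuilder(classOf[java.util.TreeSet[_]])(8)  // TreeSet has no int ctor; uses TreeSet()
```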
@@ -1067,6 +1097,13 @@ case class MapObjects private(
s"java.util.Iterator $it = ${genInputData.value}.iterator();",
s"$it.next()"
)
case ObjectType(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
val it = ctx.freshName("it")
(
s"${genInputData.value}.size()",
s"java.util.Iterator $it = ${genInputData.value}.iterator();",
s"$it.next()"
)
case ArrayType(et, _) =>
(
s"${genInputData.value}.numElements()",
@@ -1158,6 +1195,19 @@ case class MapObjects private(
          (genValue: String) => s"$builder.add($genValue);",
          s"$builder;"
        )
      case Some(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
        // Java set
        val builder = ctx.freshName("collectionBuilder")
        (
          if (cls == classOf[java.util.Set[_]] || cls == classOf[java.util.AbstractSet[_]]) {
            s"${cls.getName} $builder = new java.util.HashSet($dataLength);"
          } else {
            val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("")
            s"${cls.getName} $builder = new ${cls.getName}($param);"
          },
          (genValue: String) => s"$builder.add($genValue);",
          s"$builder;"
        )
      case _ =>
        // array
        (
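
The codegen branch expresses the same constructor preference more compactly: `Try(cls.getConstructor(Integer.TYPE))` probes for an `int` constructor and, only if it exists, splices the data length into the generated `new` expression as a capacity hint. The probe in isolation (a sketch; `dataLength` stands in for the codegen variable name):

```scala
import scala.util.Try

val dataLength = "dataLength"  // placeholder for the generated variable name

// HashSet(int) exists, so codegen emits `new java.util.HashSet(dataLength);`
val p1 = Try(classOf[java.util.HashSet[_]].getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("")
// TreeSet has no int constructor, so the probe fails and codegen emits `new java.util.TreeSet();`
val p2 = Try(classOf[java.util.TreeSet[_]].getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("")
assert(p1 == "dataLength" && p2 == "")
```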
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst

import java.math.BigInteger
-import java.util.{LinkedList, List => JList, Map => JMap}
+import java.util.{HashSet, LinkedList, List => JList, Map => JMap, Set => JSet}

import scala.beans.{BeanProperty, BooleanBeanProperty}
import scala.reflect.{classTag, ClassTag}
@@ -37,6 +37,8 @@ class GenericCollectionBean {
  @BeanProperty var listOfListOfStrings: JList[JList[String]] = _
  @BeanProperty var mapOfDummyBeans: JMap[String, DummyBean] = _
  @BeanProperty var linkedListOfStrings: LinkedList[String] = _
  @BeanProperty var hashSetOfString: HashSet[String] = _
  @BeanProperty var setOfSetOfStrings: JSet[JSet[String]] = _
}

class LeafBean {
@@ -139,9 +141,16 @@ class JavaTypeInferenceSuite extends SparkFunSuite {
    assert(schema === expected)
  }

-  test("resolve type parameters for map and list") {
+  test("resolve type parameters for map, list and set") {
    val encoder = JavaTypeInference.encoderFor(classOf[GenericCollectionBean])
    val expected = JavaBeanEncoder(ClassTag(classOf[GenericCollectionBean]), Seq(
      encoderField(
        "hashSetOfString",
        IterableEncoder(
          ClassTag(classOf[HashSet[_]]),
          StringEncoder,
          containsNull = true,
          lenientSerialization = false)),
      encoderField(
        "linkedListOfStrings",
        IterableEncoder(
@@ -166,7 +175,18 @@
          ClassTag(classOf[JMap[_, _]]),
          StringEncoder,
          expectedDummyBeanEncoder,
-          valueContainsNull = true))))
+          valueContainsNull = true)),
      encoderField(
        "setOfSetOfStrings",
        IterableEncoder(
          ClassTag(classOf[JSet[_]]),
          IterableEncoder(
            ClassTag(classOf[JSet[_]]),
            StringEncoder,
            containsNull = true,
            lenientSerialization = false),
          containsNull = true,
          lenientSerialization = false))))
    assert(encoder === expected)
  }

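A side note on the expected encoder above: the fields are listed alphabetically (`hashSetOfString` first, `setOfSetOfStrings` last) rather than in declaration order, because bean properties are discovered via `java.beans.Introspector`, which in practice returns property descriptors sorted by name. A quick way to observe this (an observation of JDK behavior, not a documented contract):

```scala
import java.beans.Introspector

val props = Introspector.getBeanInfo(classOf[GenericCollectionBean]).getPropertyDescriptors
// Prints names alphabetically (including the synthetic "class" property,
// which encoder inference ignores).
props.map(_.getName).foreach(println)
```
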
@@ -362,6 +362,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
          assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
        case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
          assert(result.asInstanceOf[java.util.List[_]].asScala == expected)
        case s if classOf[java.util.Set[_]].isAssignableFrom(s) =>
          assert(result.asInstanceOf[java.util.Set[_]].asScala == expected.toSet)
        case a if classOf[mutable.ArraySeq[Int]].isAssignableFrom(a) =>
          assert(result == mutable.ArraySeq.make[Int](expected.toArray))
        case a if classOf[immutable.ArraySeq[Int]].isAssignableFrom(a) =>
@@ -379,7 +381,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
      classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
      classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
      classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
-      classOf[java.util.Stack[Int]], null)
+      classOf[java.util.Stack[Int]], null,
+      classOf[java.util.Set[Int]])

    val list = new java.util.ArrayList[Int]()
    list.add(1)
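
Note the `expected.toSet` on the new branch above: a `java.util.Set` round trip preserves membership but not element order, so the assertion compares set contents rather than a sequence.
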
@@ -111,6 +111,26 @@ public void testTypedFilterPreservingSchema() {
    Assertions.assertEquals(ds.schema(), ds2.schema());
  }

  @Test
  public void testBeanWithSet() {
    BeanWithSet bean = new BeanWithSet();
    Set<Long> fields = asSet(1L, 2L, 3L);
    bean.setFields(fields);
    List<BeanWithSet> objects = Collections.singletonList(bean);

    Dataset<BeanWithSet> ds = spark.createDataset(objects, Encoders.bean(BeanWithSet.class));
    Dataset<Row> df = ds.toDF();

    Dataset<BeanWithSet> mapped =
      df.map((MapFunction<Row, BeanWithSet>) row -> {
        BeanWithSet obj = new BeanWithSet();
        obj.setFields(new HashSet<>(row.<Long>getList(row.fieldIndex("fields"))));
        return obj;
      }, Encoders.bean(BeanWithSet.class));

    Assertions.assertEquals(objects, mapped.collectAsList());
  }

  @Test
  public void testCommonOperation() {
    List<String> data = Arrays.asList("hello", "world");
@@ -1989,6 +2009,31 @@ public Row call(Long i) {
    Assertions.assertEquals(expected, df.collectAsList());
  }

  public static class BeanWithSet implements Serializable {
    private Set<Long> fields;

    public Set<Long> getFields() {
      return fields;
    }

    public void setFields(Set<Long> fields) {
      this.fields = fields;
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) return true;
      if (o == null || getClass() != o.getClass()) return false;
      BeanWithSet that = (BeanWithSet) o;
      return Objects.equal(fields, that.fields);
    }

    @Override
    public int hashCode() {
      return Objects.hashCode(fields);
    }
  }

  public static class SpecificListsBean implements Serializable {
    private ArrayList<Integer> arrayList;
    private LinkedList<Integer> linkedList;
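
The final `assertEquals` in `testBeanWithSet` holds even if Spark hands back a different `Set` implementation than the `HashSet` the test stored, because `java.util.Set` equality is defined purely by membership. For instance:

```scala
val a = new java.util.HashSet[java.lang.Long]()
a.add(1L); a.add(2L); a.add(3L)
val b = new java.util.TreeSet[java.lang.Long]()
b.add(3L); b.add(2L); b.add(1L)
assert(a == b)  // true: Set#equals compares elements only, not implementation or order
```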
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.sql.{Date, Timestamp}

import scala.collection.immutable.HashSet
import scala.reflect.ClassTag
import scala.util.Random

@@ -2706,6 +2707,12 @@ class DatasetSuite extends QueryTest
      assert(exception.context.head.asInstanceOf[DataFrameQueryContext].stackTrace.length == 2)
    }
  }

  test("SPARK-46791: Dataset with set field") {
    val ds = Seq(WithSet(0, HashSet("foo", "bar")), WithSet(1, HashSet("bar", "zoo"))).toDS()
    checkDataset(ds.map(t => t),
      WithSet(0, HashSet("foo", "bar")), WithSet(1, HashSet("bar", "zoo")))
  }
}

class DatasetLargeResultCollectingSuite extends QueryTest
@@ -2759,6 +2766,8 @@ case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map
case class WithMap(id: String, map_test: scala.collection.Map[Long, String])
case class WithMapInOption(m: Option[scala.collection.Map[Int, Int]])

case class WithSet(id: Int, values: Set[String])

case class Generic[T](id: T, value: Double)

case class OtherTuple(_1: String, _2: Int)
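
The identity `ds.map(t => t)` in the test above is deliberate: it forces each row through a full deserialize/serialize round trip, so the `Set` field must survive both encoder directions before `checkDataset` compares results.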
