diff --git a/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala b/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala
deleted file mode 100644
index 9b92b3db..00000000
--- a/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala
+++ /dev/null
@@ -1,285 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package datafu.spark
-
-import scala.collection.{mutable, Map}
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
-import org.apache.spark.sql.types.{ArrayType, _}
-
-
-/**
- * UserDefineAggregateFunction is deprecated and will be removed in DataFu 2.1.0 in order to allow compilation with Spark 3.2 and up.
- * Please use the methods in @Aggregators instead
- */
-@Deprecated
-object SparkUDAFs {
-
-  /**
-   * Like Google's MultiSets.
-   * Aggregate function that creates a map of key to its count.
-   */
-  @Deprecated
-  class MultiSet() extends UserDefinedAggregateFunction {
-
-    def inputSchema: StructType = new StructType().add("key", StringType)
-
-    def bufferSchema: StructType =
-      new StructType().add("mp", MapType(StringType, IntegerType))
-
-    def dataType: DataType = MapType(StringType, IntegerType, false)
-
-    def deterministic: Boolean = true
-
-    // This function is called whenever key changes
-    def initialize(buffer: MutableAggregationBuffer): Unit = {
-      buffer(0) = mutable.Map()
-    }
-
-    // Iterate over each entry of a group
-    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
-      val key = input.getString(0)
-      if (key != null) {
-        buffer(0) = buffer.getMap(0) + (key -> (buffer
-          .getMap(0)
-          .getOrElse(key, 0) + 1))
-      }
-    }
-
-    // Merge two partial aggregates
-    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
-      val mp = mutable.Map[String, Int]() ++= buffer1.getMap(0)
-      buffer2
-        .getMap(0)
-        .keys
-        .foreach((key: String) =>
-          if (key != null) {
-            mp.put(key,
-                   mp.getOrElse(key, 0) + buffer2.getMap(0).getOrElse(key, 0))
-          })
-      buffer1(0) = mp
-    }
-
-    // Called after all the entries are exhausted.
-    def evaluate(buffer: Row): Any = {
-      buffer(0)
-    }
-
-  }
-
-  /**
-   * Essentially the same as MultiSet, but gets an Array for input.
-   * There is an extra option to limit the number of keys (like CountDistinctUpTo)
-   */
-  @Deprecated
-  class MultiArraySet[T: Ordering](dt: DataType = StringType, maxKeys: Int = -1)
-      extends UserDefinedAggregateFunction {
-
-    def inputSchema: StructType = new StructType().add("key", ArrayType(dt))
-
-    def bufferSchema: StructType = new StructType().add("mp", dataType)
-
-    def dataType: DataType = MapType(dt, IntegerType, false)
-
-    def deterministic: Boolean = true
-
-    // This function is called whenever key changes
-    def initialize(buffer: MutableAggregationBuffer): Unit = {
-      buffer(0) = mutable.Map()
-    }
-
-    // Iterate over each entry of a group
-    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
-      val mp = mutable.Map[T, Int]() ++= buffer.getMap(0)
-      val keyArr: Seq[T] = Option(input.getAs[Seq[T]](0)).getOrElse(Nil)
-      for (key <- keyArr; if key != null)
-        mp.put(key, mp.getOrElse(key, 0) + 1)
-
-      buffer(0) = limitKeys(mp, 3)
-    }
-
-    // Merge two partial aggregates
-    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
-      val mp = mutable.Map[T, Int]() ++= buffer1.getMap(0)
-      buffer2
-        .getMap(0)
-        .keys
-        .foreach((key: T) =>
-          if (key != null) {
-            mp.put(key,
-                   mp.getOrElse(key, 0) + buffer2.getMap(0).getOrElse(key, 0))
-          })
-
-      buffer1(0) = limitKeys(mp, 3)
-    }
-
-    private def limitKeys(mp: Map[T, Int], factor: Int = 1): Map[T, Int] = {
-      if (maxKeys > 0 && maxKeys * factor < mp.size) {
-        val k = mp.toList.map(_.swap).sorted.reverse(maxKeys - 1)._1
-        var mp2 = mutable.Map[T, Int]() ++= mp.filter((t: (T, Int)) =>
-          t._2 >= k)
-        var toRemove = mp2.size - maxKeys
-        if (toRemove > 0) {
-          mp2 = mp2.filter((t: (T, Int)) => {
-            if (t._2 > k) {
-              true
-            } else {
-              if (toRemove >= 0) {
-                toRemove = toRemove - 1
-              }
-              toRemove < 0
-            }
-          })
-        }
-        mp2
-      } else {
-        mp
-      }
-    }
-
-    // Called after all the entries are exhausted.
-    def evaluate(buffer: Row): Map[T, Int] = {
-      limitKeys(buffer.getMap(0).asInstanceOf[Map[T, Int]])
-    }
-
-  }
-
-  /**
-   * Merge maps of kind string -> set
-   */
-  @Deprecated
-  class MapSetMerge extends UserDefinedAggregateFunction {
-
-    def inputSchema: StructType = new StructType().add("key", dataType)
-
-    def bufferSchema: StructType = inputSchema
-
-    def dataType: DataType = MapType(StringType, ArrayType(StringType))
-
-    def deterministic: Boolean = true
-
-    // This function is called whenever key changes
-    def initialize(buffer: MutableAggregationBuffer): Unit = {
-      buffer(0) = mutable.Map()
-    }
-
-    // Iterate over each entry of a group
-    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
-      val mp0 = input.getMap(0)
-      if (mp0 != null) {
-        val mp = mutable.Map[String, mutable.WrappedArray[String]]() ++= input
-          .getMap(0)
-        buffer(0) =
-          merge(mp, buffer.getMap[String, mutable.WrappedArray[String]](0))
-      }
-    }
-
-    // Merge two partial aggregates
-    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
-      val mp = mutable.Map[String, mutable.WrappedArray[String]]() ++= buffer1
-        .getMap(0)
-      buffer1(0) =
-        merge(mp, buffer2.getMap[String, mutable.WrappedArray[String]](0))
-    }
-
-    def merge(mpBuffer: mutable.Map[String, mutable.WrappedArray[String]],
-              mp: Map[String, mutable.WrappedArray[String]])
-      : mutable.Map[String, mutable.WrappedArray[String]] = {
-      if (mp != null) {
-        mp.keys.foreach((key: String) => {
-          val blah1: mutable.WrappedArray[String] =
-            mpBuffer.getOrElse(key, mutable.WrappedArray.empty)
-          val blah2: mutable.WrappedArray[String] =
-            mp.getOrElse(key, mutable.WrappedArray.empty)
-          mpBuffer.put(
-            key,
-            mutable.WrappedArray.make(
-              (Option(blah1).getOrElse(mutable.WrappedArray.empty) ++ Option(
-                blah2).getOrElse(mutable.WrappedArray.empty)).toSet.toArray)
-          )
-        })
-      }
-
-      mpBuffer
-    }
-
-    // Called after all the entries are exhausted.
-    def evaluate(buffer: Row): Any = {
-      buffer(0)
-    }
-
-  }
-
-  /**
-   * Counts number of distinct records, but only up to a preset amount -
-   * more efficient than an unbounded count
-   */
-  @Deprecated
-  class CountDistinctUpTo(maxItems: Int = -1)
-      extends UserDefinedAggregateFunction {
-
-    def inputSchema: StructType = new StructType().add("key", StringType)
-
-    def bufferSchema: StructType =
-      new StructType().add("mp", MapType(StringType, BooleanType))
-
-    def dataType: DataType = IntegerType
-
-    def deterministic: Boolean = true
-
-    // This function is called whenever key changes
-    def initialize(buffer: MutableAggregationBuffer): Unit = {
-      buffer(0) = mutable.Map()
-    }
-
-    // Iterate over each entry of a group
-    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
-      if (buffer.getMap(0).size < maxItems) {
-        val key = input.getString(0)
-        if (key != null) {
-          buffer(0) = buffer.getMap(0) + (key -> true)
-        }
-      }
-    }
-
-    // Merge two partial aggregates
-    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
-      if (buffer1.getMap(0).size < maxItems) {
-        val mp = mutable.Map[String, Boolean]() ++= buffer1.getMap(0)
-        buffer2
-          .getMap(0)
-          .keys
-          .foreach((key: String) =>
-            if (key != null) {
-              mp.put(key, true)
-            })
-        buffer1(0) = mp
-      }
-
-    }
-
-    // Called after all the entries are exhausted.
-    def evaluate(buffer: Row): Int = {
-      math.min(buffer.getMap(0).size, maxItems)
-    }
-
-  }
-
-}
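
The removed UDAFs have Aggregator-based counterparts in datafu.spark.Aggregators, which are used through org.apache.spark.sql.functions.udaf, as the updated tests below show for MultiArraySet. A minimal migration sketch follows; the object name, SparkSession setup, and toy DataFrame are illustrative only, not part of the patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udaf
import datafu.spark.Aggregators

object MultiArraySetMigration {
  def main(args: Array[String]): Unit = {
    // Illustrative local session; in the tests this comes from DataFrameSuiteBase.
    val spark = SparkSession.builder().master("local[*]").appName("udaf-migration").getOrCreate()
    import spark.implicits._

    // Toy input: a single array<string> column, like the multiarrayset tests use.
    val df = spark.sql("select array('asd', 'tre', 'asd') as arr")

    // Before (removed): val mas = new SparkUDAFs.MultiArraySet[String](maxKeys = 2)
    // After: wrap the Aggregator with functions.udaf to get a column function usable in agg().
    val mas = udaf(new Aggregators.MultiArraySet[String](maxKeys = 2))

    // The call site stays the same as with the old UDAF.
    df.groupBy().agg(mas($"arr").as("map_col")).show(false)

    spark.stop()
  }
}
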
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestAggregators.scala b/datafu-spark/src/test/scala/datafu/spark/TestAggregators.scala
index d2046ac1..e2f30dcb 100644
--- a/datafu-spark/src/test/scala/datafu/spark/TestAggregators.scala
+++ b/datafu-spark/src/test/scala/datafu/spark/TestAggregators.scala
@@ -158,7 +158,7 @@ class TestAggregators extends FunSuite with DataFrameSuiteBase {
       sqlContext.createDataFrame(List(mapExp(Map("dsa" -> 1, "asd" -> 5)))),
       spark.table("mas_table2").groupBy().agg(mas2($"arr").as("map_col")))
 
-    val mas1 = new SparkUDAFs.MultiArraySet[String](maxKeys = 1)
+    val mas1 = udaf(new Aggregators.MultiArraySet[String](maxKeys = 1))
     assertDataFrameEquals(
       sqlContext.createDataFrame(List(mapExp(Map("asd" -> 5)))),
       spark.table("mas_table2").groupBy().agg(mas1($"arr").as("map_col")))
@@ -247,17 +247,17 @@ class TestAggregators extends FunSuite with DataFrameSuiteBase {
         Exp6(Option(2), Option(1))
       ))
 
-    assertDataFrameEquals(results3DF,
+    assertDataFrameNoOrderEquals(results3DF,
                           inputDF
                             .groupBy("col_grp")
                             .agg(countDistinctUpTo3($"col_ord").as("col_ord")))
 
-    assertDataFrameEquals(results6DF,
+    assertDataFrameNoOrderEquals(results6DF,
                           inputDF
                             .groupBy("col_grp")
                             .agg(countDistinctUpTo6($"col_ord").as("col_ord")))
 
-    assertDataFrameEquals(results2DF, inputDF
+    assertDataFrameNoOrderEquals(results2DF, inputDF
       .groupBy("col_ord")
       .agg(countDistinctUpTo2($"col_grp").as("col_grp")))
   }
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
index 3e086a42..1abe09d3 100644
--- a/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
+++ b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
@@ -87,126 +87,8 @@ class UdafTests extends FunSuite with DataFrameSuiteBase {
     )
   }
 
-  test("test multiset simple") {
-    val ms = new SparkUDAFs.MultiSet()
-    val expected: DataFrame =
-      sqlContext.createDataFrame(List(mapExp(Map("b" -> 1, "a" -> 3))))
-    assertDataFrameEquals(expected, df.agg(ms($"col_grp").as("map_col")))
-  }
-
-  val mas = new SparkUDAFs.MultiArraySet[String]()
-
-  test("test multiarrayset simple") {
-    assertDataFrameEquals(
-      sqlContext.createDataFrame(List(mapExp(Map("tre" -> 1, "asd" -> 2)))),
-      spark
-        .sql("select array('asd','tre','asd') arr")
-        .groupBy()
-        .agg(mas($"arr").as("map_col"))
-    )
-  }
-
-  test("test multiarrayset all nulls") {
-    // end case
-    spark.sql("drop table if exists mas_table")
-    deleteLeftoverFiles("mas_table")
-
-    spark.sql("create table mas_table (arr array<string>)")
-    spark.sql(
-      "insert overwrite table mas_table select case when 1=2 then array('asd') end " +
-        "from (select 1)")
-    spark.sql(
-      "insert into table mas_table select case when 1=2 then array('asd') end from (select 1)")
-    spark.sql(
-      "insert into table mas_table select case when 1=2 then array('asd') end from (select 1)")
-    spark.sql(
-      "insert into table mas_table select case when 1=2 then array('asd') end from (select 1)")
-    spark.sql(
-      "insert into table mas_table select case when 1=2 then array('asd') end from (select 1)")
-
-    val expected = sqlContext.createDataFrame(List(mapExp(Map())))
-
-    val actual =
-      spark.table("mas_table").groupBy().agg(mas($"arr").as("map_col"))
-
-    assertDataFrameEquals(expected, actual)
-  }
-
-  test("test multiarrayset max keys") {
-    // max keys case
-    spark.sql("drop table if exists mas_table2")
-    deleteLeftoverFiles("mas_table2")
-
-    spark.sql("create table mas_table2 (arr array<string>)")
-    spark.sql(
-      "insert overwrite table mas_table2 select array('asd','dsa') from (select 1)")
-    spark.sql(
-      "insert into table mas_table2 select array('asd','abc') from (select 1)")
-    spark.sql(
-      "insert into table mas_table2 select array('asd') from (select 1)")
-    spark.sql(
-      "insert into table mas_table2 select array('asd') from (select 1)")
-    spark.sql(
-      "insert into table mas_table2 select array('asd') from (select 1)")
-    spark.sql(
-      "insert into table mas_table2 select array('asd2') from (select 1)")
-
-    val mas2 = new SparkUDAFs.MultiArraySet[String](maxKeys = 2)
-
-    assertDataFrameEquals(
-      sqlContext.createDataFrame(List(mapExp(Map("dsa" -> 1, "asd" -> 5)))),
-      spark.table("mas_table2").groupBy().agg(mas2($"arr").as("map_col")))
-
-    val mas1 = new SparkUDAFs.MultiArraySet[String](maxKeys = 1)
-    assertDataFrameEquals(
-      sqlContext.createDataFrame(List(mapExp(Map("asd" -> 5)))),
-      spark.table("mas_table2").groupBy().agg(mas1($"arr").as("map_col")))
-  }
-
-  test("test multiarrayset big input") {
-    val N = 100000
-    val blah = spark.sparkContext
-      .parallelize(1 to N, 20)
-      .toDF("num")
-      .selectExpr("array('asd',concat('dsa',num)) as arr")
-    val mas = new SparkUDAFs.MultiArraySet[String](maxKeys = 3)
-    val time1 = System.currentTimeMillis()
-    val mp = blah
-      .groupBy()
-      .agg(mas($"arr"))
-      .collect()
-      .map(_.getMap[String, Int](0))
-      .head
-    Assert.assertEquals(3, mp.size)
-    Assert.assertEquals("asd", mp.maxBy(_._2)._1)
-    Assert.assertEquals(N, mp.maxBy(_._2)._2)
-    val time2 = System.currentTimeMillis()
-    logger.info("time took: " + (time2 - time1) / 1000 + " secs")
-  }
-
-  test("test mapmerge") {
-    val mapMerge = new SparkUDAFs.MapSetMerge()
-
-    spark.sql("drop table if exists mapmerge_table")
-    deleteLeftoverFiles("mapmerge_table")
-
-    spark.sql("create table mapmerge_table (c map<string,array<string>>)")
-    spark.sql(
-      "insert overwrite table mapmerge_table select map('k1', array('v1')) from (select 1) z")
-    spark.sql(
-      "insert into table mapmerge_table select map('k1', array('v1')) from (select 1) z")
-    spark.sql(
-      "insert into table mapmerge_table select map('k2', array('v3')) from (select 1) z")
-
-    assertDataFrameEquals(
-      sqlContext.createDataFrame(
-        List(mapArrExp(Map("k1" -> Array("v1"), "k2" -> Array("v3"))))),
-      spark.table("mapmerge_table").groupBy().agg(mapMerge($"c").as("map_col"))
-    )
-  }
-
   test("minKeyValue") {
-    assertDataFrameEquals(
+    assertDataFrameNoOrderEquals(
       sqlContext.createDataFrame(List(("b", "asd4"), ("a", "asd1"))),
       df.groupBy($"col_grp".as("_1"))
         .agg(SparkOverwriteUDAFs.minValueByKey($"col_ord", $"col_str").as("_2"))
     )
@@ -223,7 +105,7 @@ class UdafTests extends FunSuite with DataFrameSuiteBase {
   )
 
   test("minKeyValue window") {
-    assertDataFrameEquals(
+    assertDataFrameNoOrderEquals(
       sqlContext.createDataFrame(
         sc.parallelize(
           Seq(
@@ -241,63 +123,6 @@ class UdafTests extends FunSuite with DataFrameSuiteBase {
     )
   }
 
-  case class Exp5(col_grp: String, col_ord: Option[Int])
-  case class Exp6(col_ord: Option[Int], col_grp: Option[Int])
-
-  test("countDistinctUpTo") {
-    import datafu.spark.SparkUDAFs.CountDistinctUpTo
-
-    val countDistinctUpTo2 = new CountDistinctUpTo(2)
-    val countDistinctUpTo3 = new CountDistinctUpTo(3)
-    val countDistinctUpTo6 = new CountDistinctUpTo(6)
-
-    val inputDF = sqlContext.createDataFrame(
-      List(
-        Exp5("c", Option(1)),
-        Exp5("b", Option(1)),
-        Exp5("a", Option(1)),
-        Exp5("a", Option(2)),
-        Exp5("a", Option(3)),
-        Exp5("a", Option(4))
-      ))
-
-    val results3DF = sqlContext.createDataFrame(
-      List(
-        Exp5("c", Option(1)),
-        Exp5("b", Option(1)),
-        Exp5("a", Option(3))
-      ))
-
-    val results6DF = sqlContext.createDataFrame(
-      List(
-        Exp5("c", Option(1)),
-        Exp5("b", Option(1)),
-        Exp5("a", Option(4))
-      ))
-
-    val results2DF = sqlContext.createDataFrame(
-      List(
-        Exp6(Option(1), Option(2)),
-        Exp6(Option(3), Option(1)),
-        Exp6(Option(4), Option(1)),
-        Exp6(Option(2), Option(1))
-      ))
-
-    assertDataFrameEquals(results3DF,
-                          inputDF
-                            .groupBy("col_grp")
-                            .agg(countDistinctUpTo3($"col_ord").as("col_ord")))
-
-    assertDataFrameEquals(results6DF,
-                          inputDF
-                            .groupBy("col_grp")
-                            .agg(countDistinctUpTo6($"col_ord").as("col_ord")))
-
-    assertDataFrameEquals(results2DF,inputDF
-      .groupBy("col_ord")
-      .agg(countDistinctUpTo2($"col_grp").as("col_grp")))
-  }
-
   test("test_limited_collect_list") {
     val maxSize = 10