[SPARK-31808][SQL] Makes struct function's output name and class name pretty #28633

@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util._
@@ -311,7 +311,12 @@ case object NamePlaceholder extends LeafExpression with Unevaluable {
 /**
  * Returns a Row containing the evaluation of all children expressions.
  */
-object CreateStruct extends FunctionBuilder {
+object CreateStruct {
+  /**
+   * Returns a named struct, generating field names or using the existing names
+   * when available. It should not be used for `struct` expressions or functions
+   * explicitly called by users.
+   */
   def apply(children: Seq[Expression]): CreateNamedStruct = {
     CreateNamedStruct(children.zipWithIndex.flatMap {
       case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e)
@@ -320,12 +325,23 @@ object CreateStruct extends FunctionBuilder {
     })
   }

+  /**
+   * Returns a named struct with a pretty SQL name, so that its output column name
+   * reads as if `struct(...)` had been called. It should be used for `struct`
+   * expressions or functions explicitly called by users.
+   */
+  def create(children: Seq[Expression]): CreateNamedStruct = {
+    val expr = CreateStruct(children)
+    expr.setTagValue(FUNC_ALIAS, "struct")
+    expr
+  }
+
   /**
    * Entry to use in the function registry.
    */
   val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = {
     val info: ExpressionInfo = new ExpressionInfo(
-      "org.apache.spark.sql.catalyst.expressions.NamedStruct",
+      classOf[CreateNamedStruct].getCanonicalName,
       null,
       "struct",
       "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.",
@@ -335,7 +351,7 @@ object CreateStruct extends FunctionBuilder {
"",
"",
"")
("struct", (info, this))
("struct", (info, this.create))
}
}

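For orientation, here is a minimal spark-shell sketch (not part of the diff) contrasting the two factory methods above; it assumes the catalyst classes are on the classpath:

    import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS
    import org.apache.spark.sql.catalyst.expressions.{CreateStruct, Literal}

    // apply() serves internally generated structs: no alias tag is set,
    // so the expression keeps the generic named_struct rendering.
    val internal = CreateStruct(Seq(Literal(1), Literal(2)))
    internal.getTagValue(FUNC_ALIAS)    // None

    // create() serves user-facing `struct` calls: the FUNC_ALIAS tag makes
    // the expression render as struct(...).
    val userFacing = CreateStruct.create(Seq(Literal(1), Literal(2)))
    userFacing.getTagValue(FUNC_ALIAS)  // Some(struct)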
@@ -433,7 +449,12 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
     """.stripMargin, isNull = FalseLiteral)
   }

-  override def prettyName: String = "named_struct"
+  override def prettyName: String = getTagValue(FUNC_ALIAS).getOrElse("named_struct")
+
+  override def sql: String = getTagValue(FUNC_ALIAS).map { alias =>
+    val childrenSQL = children.indices.filter(_ % 2 == 1).map(children(_).sql).mkString(", ")
+    s"$alias($childrenSQL)"
+  }.getOrElse(super.sql)
 }

 /**
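A hedged sketch of what the new overrides produce; the value `e` is illustrative only:

    import org.apache.spark.sql.catalyst.expressions.{CreateStruct, Literal}

    val e = CreateStruct.create(Seq(Literal(1), Literal(2)))
    // children alternate name/value literals, so filtering the odd indices
    // keeps only the values and drops the auto-generated field names:
    e.prettyName  // struct
    e.sql         // struct(1, 2), instead of the named_struct(...) form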
@@ -1534,7 +1534,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    * Create a [[CreateStruct]] expression.
    */
   override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) {
-    CreateStruct(ctx.argument.asScala.map(expression))
+    CreateStruct.create(ctx.argument.asScala.map(expression))
   }

   /**
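Illustrative effect at the SQL layer (a sketch assuming a running SparkSession named `spark`):

    // The parser now routes a user-written STRUCT(...) through CreateStruct.create,
    // so the auto-generated output column keeps the user-facing name:
    spark.sql("SELECT struct(1, 2)").columns.head
    // before this change: named_struct(col1, 1, col2, 2)
    // after this change:  struct(1, 2)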
@@ -1306,7 +1306,7 @@ object functions {
    * @since 1.4.0
    */
   @scala.annotation.varargs
-  def struct(cols: Column*): Column = withExpr { CreateStruct(cols.map(_.expr)) }
+  def struct(cols: Column*): Column = withExpr { CreateStruct.create(cols.map(_.expr)) }

   /**
    * Creates a new struct column that composes multiple input columns.
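The matching effect through the Scala DataFrame API (again a sketch assuming `spark` and its implicits):

    import org.apache.spark.sql.functions.struct
    import spark.implicits._

    spark.range(1).select(struct($"id")).columns.head
    // before this change: named_struct(id, id)
    // after this change:  struct(id)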
@@ -2,7 +2,7 @@
 ## Summary
 - Number of queries: 336
 - Number of expressions that missing example: 34
-- Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,struct,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch
+- Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,struct,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch
 ## Schema of Built-in Functions
 | Class name | Function name or alias | Query example | Output schema |
 | ---------- | ---------------------- | ------------- | ------------- |
@@ -79,6 +79,7 @@
 | org.apache.spark.sql.catalyst.expressions.CreateArray | array | SELECT array(1, 2, 3) | struct<array(1, 2, 3):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.CreateMap | map | SELECT map(1.0, '2', 3.0, '4') | struct<map(1.0, 2, 3.0, 4):map<decimal(2,1),string>> |
 | org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | named_struct | SELECT named_struct("a", 1, "b", 2, "c", 3) | struct<named_struct(a, 1, b, 2, c, 3):struct<a:int,b:int,c:int>> |
+| org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | struct | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.CsvToStructs | from_csv | SELECT from_csv('1, 0.8', 'a INT, b DOUBLE') | struct<from_csv(1, 0.8):struct<a:int,b:double>> |
 | org.apache.spark.sql.catalyst.expressions.Cube | cube | SELECT name, age, count(*) FROM VALUES (2, 'Alice'), (5, 'Bob') people(age, name) GROUP BY cube(name, age) | struct<name:string,age:int,count(1):bigint> |
 | org.apache.spark.sql.catalyst.expressions.CumeDist | cume_dist | N/A | N/A |
@@ -170,7 +171,7 @@
 | org.apache.spark.sql.catalyst.expressions.MapEntries | map_entries | SELECT map_entries(map(1, 'a', 2, 'b')) | struct<map_entries(map(1, a, 2, b)):array<struct<key:int,value:string>>> |
 | org.apache.spark.sql.catalyst.expressions.MapFilter | map_filter | SELECT map_filter(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v) | struct<map_filter(map(1, 0, 2, 2, 3, -1), lambdafunction((namedlambdavariable() > namedlambdavariable()), namedlambdavariable(), namedlambdavariable())):map<int,int>> |
 | org.apache.spark.sql.catalyst.expressions.MapFromArrays | map_from_arrays | SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')) | struct<map_from_arrays(array(1.0, 3.0), array(2, 4)):map<decimal(2,1),string>> |
-| org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct<map_from_entries(array(named_struct(col1, 1, col2, a), named_struct(col1, 2, col2, b))):map<int,string>> |
+| org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct<map_from_entries(array(struct(1, a), struct(2, b))):map<int,string>> |
 | org.apache.spark.sql.catalyst.expressions.MapKeys | map_keys | SELECT map_keys(map(1, 'a', 2, 'b')) | struct<map_keys(map(1, a, 2, b)):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.MapValues | map_values | SELECT map_values(map(1, 'a', 2, 'b')) | struct<map_values(map(1, a, 2, b)):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.MapZipWith | map_zip_with | SELECT map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)) | struct<map_zip_with(map(1, a, 2, b), map(1, x, 2, y), lambdafunction(concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable(), namedlambdavariable())):map<int,string>> |
@@ -185,7 +186,6 @@
 | org.apache.spark.sql.catalyst.expressions.Murmur3Hash | hash | SELECT hash('Spark', array(123), 2) | struct<hash(Spark, array(123), 2):int> |
 | org.apache.spark.sql.catalyst.expressions.NTile | ntile | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.NaNvl | nanvl | SELECT nanvl(cast('NaN' as double), 123) | struct<nanvl(CAST(NaN AS DOUBLE), CAST(123 AS DOUBLE)):double> |
-| org.apache.spark.sql.catalyst.expressions.NamedStruct | struct | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.NextDay | next_day | SELECT next_day('2015-01-14', 'TU') | struct<next_day(CAST(2015-01-14 AS DATE), TU):date> |
 | org.apache.spark.sql.catalyst.expressions.Not | ! | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.Not | not | N/A | N/A |
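The table row moved because the registry entry now reports the expression's real class rather than the nonexistent `NamedStruct`; a trivial check (sketch):

    import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct

    classOf[CreateNamedStruct].getCanonicalName
    // org.apache.spark.sql.catalyst.expressions.CreateNamedStruct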
@@ -272,7 +272,7 @@ struct<foo:string,approx_count_distinct(a) FILTER (WHERE (b >= 0)):bigint>
 -- !query
 SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1
 -- !query schema
-struct<foo:string,max(named_struct(a, a)) FILTER (WHERE (b >= 1)):struct<a:int>>
+struct<foo:string,max(struct(a)) FILTER (WHERE (b >= 1)):struct<a:int>>
 -- !query output


@@ -87,7 +87,7 @@ struct<foo:string,approx_count_distinct(a):bigint>
 -- !query
 SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1
 -- !query schema
-struct<foo:string,max(named_struct(a, a)):struct<a:int>>
+struct<foo:string,max(struct(a)):struct<a:int>>
 -- !query output


@@ -83,7 +83,7 @@ struct<ID:int,NST:string>
 -- !query
 SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x
 -- !query schema
-struct<ID:int,named_struct(STC, ST.C AS `C` AS `STC`, STD, ST.D AS `D` AS `STD`).STD:string>
+struct<ID:int,struct(ST.C AS `C` AS `STC`, ST.D AS `D` AS `STD`).STD:string>
 -- !query output
 1	delta
 2	eta
@@ -85,7 +85,7 @@ FROM various_maps
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7


 -- !query
@@ -113,7 +113,7 @@ FROM various_maps
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7


 -- !query
@@ -87,7 +87,7 @@ struct<foo:string,CAST(udf(cast(approx_count_distinct(cast(udf(cast(a as string)
 -- !query
 SELECT 'foo', MAX(STRUCT(udf(a))) FROM testData WHERE a = 0 GROUP BY udf(1)
 -- !query schema
-struct<foo:string,max(named_struct(col1, CAST(udf(cast(a as string)) AS INT))):struct<col1:int>>
+struct<foo:string,max(struct(CAST(udf(cast(a as string)) AS INT))):struct<col1:int>>
 -- !query output

