Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
commit
  • Loading branch information
dtenedor committed Dec 18, 2024
1 parent 576caec commit 1585219
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ trait RDG extends Expression with ExpressionWithRandomSeed {
/**
 * Seed for the random generator, computed lazily by evaluating `seedExpression` and converting
 * the result to a Long based on the expression's data type.
 *
 * Integer and Long seeds are used directly; Float, Double and Decimal seeds are converted with
 * `toLong`. A seed expression whose data type matches none of the cases below raises a
 * MatchError.
 */
@transient protected lazy val seed: Long = seedExpression match {
  case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
  case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
  case e if e.dataType == FloatType => e.eval().asInstanceOf[Float].toLong
  case e if e.dataType == DoubleType => e.eval().asInstanceOf[Double].toLong
  case e if e.dataType.isInstanceOf[DecimalType] =>
    // Unboxing a null result to Int/Long/Float/Double yields zero in the branches above, but
    // Decimal is a reference type, so calling toLong on a null result would throw a
    // NullPointerException. Map a null Decimal seed to 0 to stay consistent with them.
    Option(e.eval().asInstanceOf[Decimal]).fold(0L)(_.toLong)
}

override def nullable: Boolean = false
Expand Down Expand Up @@ -231,6 +234,8 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression,
if Seq(first, second).forall(integer) => ShortType
case (_, DoubleType) | (DoubleType, _) => DoubleType
case (_, FloatType) | (FloatType, _) => FloatType
case (_, d: DecimalType) => d
case (d: DecimalType, _) => d
case _ =>
throw SparkException.internalError(
s"Unexpected argument data types: ${min.dataType}, ${max.dataType}")
Expand Down Expand Up @@ -263,7 +268,7 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression,
"inputExpr" -> toSQLExpr(expr)))
} else expr.dataType match {
case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType |
_: NullType =>
_: DecimalType | _: NullType =>
case _ =>
result = DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,31 @@ SELECT uniform(10.0F, 20.0F, 0) AS result
[Analyzer test output redacted due to nondeterminism]


-- !query
SELECT uniform(10.0F, 20.0F, 0.0F) AS result
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
SELECT uniform(10.0F, 20.0F, 0.0D) AS result
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
SELECT uniform(cast(10 as decimal(10, 3)), cast(20 as decimal(10, 3)), 0.0D) AS result
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
SELECT uniform(cast(10 as decimal(10, 3)), cast(20 as decimal(10, 3)), cast(0 as decimal(10, 3)))
AS result
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
SELECT uniform(10.0D, 20.0D, CAST(3 / 7 AS LONG)) AS result
-- !query analysis
Expand Down Expand Up @@ -271,6 +296,78 @@ org.apache.spark.sql.AnalysisException
}


-- !query
SELECT uniform('abc', 10, 0) AS result
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
"sqlState" : "42K09",
"messageParameters" : {
"inputSql" : "\"abc\"",
"inputType" : "\"STRING\"",
"paramIndex" : "first",
"requiredType" : "integer or floating-point",
"sqlExpr" : "\"uniform(abc, 10, 0)\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 28,
"fragment" : "uniform('abc', 10, 0)"
} ]
}


-- !query
SELECT uniform(0, 'def', 0) AS result
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
"sqlState" : "42K09",
"messageParameters" : {
"inputSql" : "\"def\"",
"inputType" : "\"STRING\"",
"paramIndex" : "second",
"requiredType" : "integer or floating-point",
"sqlExpr" : "\"uniform(0, def, 0)\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 27,
"fragment" : "uniform(0, 'def', 0)"
} ]
}


-- !query
SELECT uniform(0, 10, 'ghi') AS result
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
"sqlState" : "42K09",
"messageParameters" : {
"inputSql" : "\"ghi\"",
"inputType" : "\"STRING\"",
"paramIndex" : "third",
"requiredType" : "integer or floating-point",
"sqlExpr" : "\"uniform(0, 10, ghi)\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 28,
"fragment" : "uniform(0, 10, 'ghi')"
} ]
}


-- !query
SELECT randstr(1, 0) AS result
-- !query analysis
Expand Down
8 changes: 8 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/random.sql
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ SELECT uniform(0, 10L, 0) AS result;
SELECT uniform(0, 10S, 0) AS result;
SELECT uniform(10, 20, 0) AS result;
SELECT uniform(10.0F, 20.0F, 0) AS result;
SELECT uniform(10.0F, 20.0F, 0.0F) AS result;
SELECT uniform(10.0F, 20.0F, 0.0D) AS result;
SELECT uniform(cast(10 as decimal(10, 3)), cast(20 as decimal(10, 3)), 0.0D) AS result;
SELECT uniform(cast(10 as decimal(10, 3)), cast(20 as decimal(10, 3)), cast(0 as decimal(10, 3)))
AS result;
SELECT uniform(10.0D, 20.0D, CAST(3 / 7 AS LONG)) AS result;
SELECT uniform(10, 20.0F, 0) AS result;
SELECT uniform(10, 20, 0) AS result FROM VALUES (0), (1), (2) tab(col);
Expand All @@ -37,6 +42,9 @@ SELECT uniform(10, 20, col) AS result FROM VALUES (0), (1), (2) tab(col);
SELECT uniform(col, 10, 0) AS result FROM VALUES (0), (1), (2) tab(col);
SELECT uniform(10) AS result;
SELECT uniform(10, 20, 30, 40) AS result;
SELECT uniform('abc', 10, 0) AS result;
SELECT uniform(0, 'def', 0) AS result;
SELECT uniform(0, 10, 'ghi') AS result;

-- The randstr random string generation function supports generating random strings within a
-- specified length. We use a seed of zero for these queries to keep tests deterministic.
Expand Down
111 changes: 111 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/random.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,39 @@ struct<result:float>
17.604954


-- !query
SELECT uniform(10.0F, 20.0F, 0.0F) AS result
-- !query schema
struct<result:float>
-- !query output
17.604954


-- !query
SELECT uniform(10.0F, 20.0F, 0.0D) AS result
-- !query schema
struct<result:float>
-- !query output
17.604954


-- !query
SELECT uniform(cast(10 as decimal(10, 3)), cast(20 as decimal(10, 3)), 0.0D) AS result
-- !query schema
struct<result:decimal(10,3)>
-- !query output
17.605


-- !query
SELECT uniform(cast(10 as decimal(10, 3)), cast(20 as decimal(10, 3)), cast(0 as decimal(10, 3)))
AS result
-- !query schema
struct<result:decimal(10,3)>
-- !query output
17.605


-- !query
SELECT uniform(10.0D, 20.0D, CAST(3 / 7 AS LONG)) AS result
-- !query schema
Expand Down Expand Up @@ -329,6 +362,84 @@ org.apache.spark.sql.AnalysisException
}


-- !query
SELECT uniform('abc', 10, 0) AS result
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
"sqlState" : "42K09",
"messageParameters" : {
"inputSql" : "\"abc\"",
"inputType" : "\"STRING\"",
"paramIndex" : "first",
"requiredType" : "integer or floating-point",
"sqlExpr" : "\"uniform(abc, 10, 0)\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 28,
"fragment" : "uniform('abc', 10, 0)"
} ]
}


-- !query
SELECT uniform(0, 'def', 0) AS result
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
"sqlState" : "42K09",
"messageParameters" : {
"inputSql" : "\"def\"",
"inputType" : "\"STRING\"",
"paramIndex" : "second",
"requiredType" : "integer or floating-point",
"sqlExpr" : "\"uniform(0, def, 0)\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 27,
"fragment" : "uniform(0, 'def', 0)"
} ]
}


-- !query
SELECT uniform(0, 10, 'ghi') AS result
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
"sqlState" : "42K09",
"messageParameters" : {
"inputSql" : "\"ghi\"",
"inputType" : "\"STRING\"",
"paramIndex" : "third",
"requiredType" : "integer or floating-point",
"sqlExpr" : "\"uniform(0, 10, ghi)\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 28,
"fragment" : "uniform(0, 10, 'ghi')"
} ]
}


-- !query
SELECT randstr(1, 0) AS result
-- !query schema
Expand Down

0 comments on commit 1585219

Please sign in to comment.