From 4a2c375be2bcd98cc7e00bea920fd6a0f68a4e14 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Tue, 16 Aug 2016 21:35:39 -0700
Subject: [PATCH 001/270] [SPARK-17084][SQL] Rename ParserUtils.assert to
validate
## What changes were proposed in this pull request?
This PR renames `ParserUtils.assert` to `ParserUtils.validate`, because the method is used to check requirements on the parsed input, not to detect an invalid program state.
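For reference, a minimal sketch of the helper after the rename (the body is exactly the old `assert`; only the name changes), together with a typical call site from `AstBuilder`:
```scala
import org.antlr.v4.runtime.ParserRuleContext
import org.apache.spark.sql.catalyst.parser.ParseException

// Check a requirement on the parsed input; failures surface as a user-facing
// ParseException rather than an internal assertion error.
def validate(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
  if (!f) {
    throw new ParseException(message, ctx)
  }
}

// Typical call site (from AstBuilder.visitInterval):
// validate(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx)
```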
## How was this patch tested?
This is a simple rename; successful compilation should suffice.
Author: Herman van Hovell
Closes #14665 from hvanhovell/SPARK-17084.
---
.../spark/sql/catalyst/parser/AstBuilder.scala | 14 +++++++-------
.../spark/sql/catalyst/parser/ParserUtils.scala | 4 ++--
.../spark/sql/execution/SparkSqlParser.scala | 5 ++---
3 files changed, 11 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 25c8445b4d33f..09b650ce18790 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -132,7 +132,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// Build the insert clauses.
val inserts = ctx.multiInsertQueryBody.asScala.map {
body =>
- assert(body.querySpecification.fromClause == null,
+ validate(body.querySpecification.fromClause == null,
"Multi-Insert queries cannot have a FROM clause in their individual SELECT statements",
body)
@@ -596,7 +596,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// function takes X PERCENT as the input and the range of X is [0, 100], we need to
// adjust the fraction.
val eps = RandomSampler.roundingEpsilon
- assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
+ validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
s"Sampling fraction ($fraction) must be on interval [0, 1]",
ctx)
Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true)
@@ -664,7 +664,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// Get the backing expressions.
val expressions = ctx.expression.asScala.map { eCtx =>
val e = expression(eCtx)
- assert(e.foldable, "All expressions in an inline table must be constants.", eCtx)
+ validate(e.foldable, "All expressions in an inline table must be constants.", eCtx)
e
}
@@ -686,7 +686,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
val baseAttributes = structType.toAttributes.map(_.withNullability(true))
val attributes = if (ctx.identifierList != null) {
val aliases = visitIdentifierList(ctx.identifierList)
- assert(aliases.size == baseAttributes.size,
+ validate(aliases.size == baseAttributes.size,
"Number of aliases must match the number of fields in an inline table.", ctx)
baseAttributes.zip(aliases).map(p => p._1.withName(p._2))
} else {
@@ -1094,7 +1094,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// We currently only allow foldable integers.
def value: Int = {
val e = expression(ctx.expression)
- assert(e.resolved && e.foldable && e.dataType == IntegerType,
+ validate(e.resolved && e.foldable && e.dataType == IntegerType,
"Frame bound value must be a constant integer.",
ctx)
e.eval().asInstanceOf[Int]
@@ -1347,7 +1347,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
*/
override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) {
val intervals = ctx.intervalField.asScala.map(visitIntervalField)
- assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx)
+ validate(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx)
Literal(intervals.reduce(_.add(_)))
}
@@ -1374,7 +1374,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case (from, Some(t)) =>
throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx)
}
- assert(interval != null, "No interval can be constructed", ctx)
+ validate(interval != null, "No interval can be constructed", ctx)
interval
} catch {
// Handle Exceptions thrown by CalendarInterval
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index b04ce58e233aa..bc35ae2f55409 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -77,8 +77,8 @@ object ParserUtils {
Origin(Option(token.getLine), Option(token.getCharPositionInLine))
}
- /** Assert if a condition holds. If it doesn't throw a parse exception. */
- def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
+ /** Validate the condition. If it doesn't throw a parse exception. */
+ def validate(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
if (!f) {
throw new ParseException(message, ctx)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 9da2b5a254e28..71c3bd31e02e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution
import scala.collection.JavaConverters._
-import scala.util.Try
import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.tree.TerminalNode
@@ -799,7 +798,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
}
/**
- * Create an [[AlterTableDiscoverPartitionsCommand]] command
+ * Create an [[AlterTableRecoverPartitionsCommand]] command
*
* For example:
* {{{
@@ -1182,7 +1181,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
entry("mapkey.delim", ctx.keysTerminatedBy) ++
Option(ctx.linesSeparatedBy).toSeq.map { token =>
val value = string(token)
- assert(
+ validate(
value == "\n",
s"LINES TERMINATED BY only supports newline '\\n' right now: $value",
ctx)
From f7c9ff57c17a950cccdc26aadf8768c899a4d572 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Tue, 16 Aug 2016 23:09:53 -0700
Subject: [PATCH 002/270] [SPARK-17068][SQL] Make view-usage visible during
analysis
## What changes were proposed in this pull request?
This PR adds a field to `SubqueryAlias` to make the use of views in a resolved `LogicalPlan` visible (and easier to understand).
For example, the following view and query:
```sql
create view constants as select 1 as id union all select 1 union all select 42
select * from constants;
```
...now yields the following analyzed plan:
```
Project [id#39]
+- SubqueryAlias c, `default`.`constants`
+- Project [gen_attr_0#36 AS id#39]
+- SubqueryAlias gen_subquery_0
+- Union
:- Union
: :- Project [1 AS gen_attr_0#36]
: : +- OneRowRelation$
: +- Project [1 AS gen_attr_1#37]
: +- OneRowRelation$
+- Project [42 AS gen_attr_2#38]
+- OneRowRelation$
```
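The mechanism is a new third field on `SubqueryAlias`; a sketch restating the `basicLogicalOperators.scala` change below, with an illustrative instance matching the plan above:
```scala
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}

// `view` is None for ordinary aliases and Some(identifier) when the aliased subtree
// came from expanding a view, so the analyzed plan keeps the view's name visible.
case class SubqueryAlias(
    alias: String,
    child: LogicalPlan,
    view: Option[TableIdentifier])
  extends UnaryNode {
  override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias)))
}

// The analyzed plan above therefore contains:
//   SubqueryAlias("c", viewPlan, Some(TableIdentifier("constants", Some("default"))))
```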
## How was this patch tested?
Added tests for the two code paths in `SessionCatalogSuite` (sql/catalyst) and `HiveMetastoreCatalogSuite` (sql/hive).
Author: Herman van Hovell
Closes #14657 from hvanhovell/SPARK-17068.
---
.../sql/catalyst/analysis/Analyzer.scala | 4 +--
.../sql/catalyst/analysis/CheckAnalysis.scala | 4 +--
.../sql/catalyst/catalog/SessionCatalog.scala | 30 ++++++++++---------
.../spark/sql/catalyst/dsl/package.scala | 4 +--
.../sql/catalyst/expressions/subquery.scala | 8 ++---
.../sql/catalyst/optimizer/Optimizer.scala | 8 ++---
.../sql/catalyst/parser/AstBuilder.scala | 4 +--
.../plans/logical/basicLogicalOperators.scala | 7 ++++-
.../sql/catalyst/analysis/AnalysisSuite.scala | 4 +--
.../catalog/SessionCatalogSuite.scala | 19 ++++++++----
.../optimizer/ColumnPruningSuite.scala | 8 ++---
.../EliminateSubqueryAliasesSuite.scala | 6 ++--
.../optimizer/JoinOptimizationSuite.scala | 8 ++---
.../sql/catalyst/parser/PlanParserSuite.scala | 2 +-
.../scala/org/apache/spark/sql/Dataset.scala | 2 +-
.../spark/sql/catalyst/SQLBuilder.scala | 6 ++--
.../sql/execution/datasources/rules.scala | 2 +-
.../spark/sql/hive/HiveMetastoreCatalog.scala | 21 ++++++-------
.../spark/sql/hive/HiveSessionCatalog.scala | 4 +--
.../sql/hive/HiveMetastoreCatalogSuite.scala | 14 ++++++++-
20 files changed, 94 insertions(+), 71 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a2a022c2476fb..bd4c19181f647 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -138,7 +138,7 @@ class Analyzer(
case u : UnresolvedRelation =>
val substituted = cteRelations.find(x => resolver(x._1, u.tableIdentifier.table))
.map(_._2).map { relation =>
- val withAlias = u.alias.map(SubqueryAlias(_, relation))
+ val withAlias = u.alias.map(SubqueryAlias(_, relation, None))
withAlias.getOrElse(relation)
}
substituted.getOrElse(u)
@@ -2057,7 +2057,7 @@ class Analyzer(
*/
object EliminateSubqueryAliases extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case SubqueryAlias(_, child) => child
+ case SubqueryAlias(_, child, _) => child
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 41b7e62d8ccea..e07e9194bee9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -141,8 +141,8 @@ trait CheckAnalysis extends PredicateHelper {
// Skip projects and subquery aliases added by the Analyzer and the SQLBuilder.
def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
- case SubqueryAlias(_, child) => cleanQuery(child)
- case Project(_, child) => cleanQuery(child)
+ case s: SubqueryAlias => cleanQuery(s.child)
+ case p: Project => cleanQuery(p.child)
case child => child
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 00c3db0aac1ac..62d0da076b5a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -411,27 +411,29 @@ class SessionCatalog(
}
/**
- * Return a [[LogicalPlan]] that represents the given table.
+ * Return a [[LogicalPlan]] that represents the given table or view.
*
- * If a database is specified in `name`, this will return the table from that database.
- * If no database is specified, this will first attempt to return a temporary table with
- * the same name, then, if that does not exist, return the table from the current database.
+ * If a database is specified in `name`, this will return the table/view from that database.
+ * If no database is specified, this will first attempt to return a temporary table/view with
+ * the same name, then, if that does not exist, return the table/view from the current database.
+ *
+ * If the relation is a view, the relation will be wrapped in a [[SubqueryAlias]] which will
+ * track the name of the view.
*/
def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = {
synchronized {
val db = formatDatabaseName(name.database.getOrElse(currentDb))
val table = formatTableName(name.table)
- val relation =
- if (name.database.isDefined || !tempTables.contains(table)) {
- val metadata = externalCatalog.getTable(db, table)
- SimpleCatalogRelation(db, metadata)
- } else {
- tempTables(table)
+ val relationAlias = alias.getOrElse(table)
+ if (name.database.isDefined || !tempTables.contains(table)) {
+ val metadata = externalCatalog.getTable(db, table)
+ val view = Option(metadata.tableType).collect {
+ case CatalogTableType.VIEW => name
}
- val qualifiedTable = SubqueryAlias(table, relation)
- // If an alias was specified by the lookup, wrap the plan in a subquery so that
- // attributes are properly qualified with this alias.
- alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
+ SubqueryAlias(relationAlias, SimpleCatalogRelation(db, metadata), view)
+ } else {
+ SubqueryAlias(relationAlias, tempTables(table), Option(name))
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5181dcc786a3d..9f54d709a022d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -343,7 +343,7 @@ package object dsl {
orderSpec: Seq[SortOrder]): LogicalPlan =
Window(windowExpressions, partitionSpec, orderSpec, logicalPlan)
- def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan)
+ def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan, None)
def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan)
@@ -367,7 +367,7 @@ package object dsl {
def as(alias: String): LogicalPlan = logicalPlan match {
case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias))
- case plan => SubqueryAlias(alias, plan)
+ case plan => SubqueryAlias(alias, plan, None)
}
def repartition(num: Integer): LogicalPlan =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index ac44f08897cbd..ddbe937cba9bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -72,7 +72,7 @@ case class ScalarSubquery(
override def dataType: DataType = query.schema.fields.head.dataType
override def foldable: Boolean = false
override def nullable: Boolean = true
- override def plan: LogicalPlan = SubqueryAlias(toString, query)
+ override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan)
override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
}
@@ -100,7 +100,7 @@ case class PredicateSubquery(
override lazy val resolved = childrenResolved && query.resolved
override lazy val references: AttributeSet = super.references -- query.outputSet
override def nullable: Boolean = nullAware
- override def plan: LogicalPlan = SubqueryAlias(toString, query)
+ override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
override def semanticEquals(o: Expression): Boolean = o match {
case p: PredicateSubquery =>
@@ -153,7 +153,7 @@ case class ListQuery(query: LogicalPlan, exprId: ExprId = NamedExpression.newExp
override def dataType: DataType = ArrayType(NullType)
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(query = plan)
- override def plan: LogicalPlan = SubqueryAlias(toString, query)
+ override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
override def toString: String = s"list#${exprId.id}"
}
@@ -174,6 +174,6 @@ case class Exists(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId
override def children: Seq[Expression] = Seq.empty
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): Exists = copy(query = plan)
- override def plan: LogicalPlan = SubqueryAlias(toString, query)
+ override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
override def toString: String = s"exists#${exprId.id}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e34a478818e98..f97a78b411597 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1862,7 +1862,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
// and Project operators, followed by an optional Filter, followed by an
// Aggregate. Traverse the operators recursively.
def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
- case SubqueryAlias(_, child) => evalPlan(child)
+ case SubqueryAlias(_, child, _) => evalPlan(child)
case Filter(condition, child) =>
val bindings = evalPlan(child)
if (bindings.isEmpty) bindings
@@ -1920,7 +1920,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
topPart += p
bottomPart = child
- case s @ SubqueryAlias(_, child) =>
+ case s @ SubqueryAlias(_, child, _) =>
topPart += s
bottomPart = child
@@ -1991,8 +1991,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
topPart.reverse.foreach {
case Project(projList, _) =>
subqueryRoot = Project(projList ++ havingInputs, subqueryRoot)
- case s @ SubqueryAlias(alias, _) =>
- subqueryRoot = SubqueryAlias(alias, subqueryRoot)
+ case s @ SubqueryAlias(alias, _, None) =>
+ subqueryRoot = SubqueryAlias(alias, subqueryRoot, None)
case op => sys.error(s"Unexpected operator $op in corelated subquery")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 09b650ce18790..adf78396d7fc0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -107,7 +107,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* This is only used for Common Table Expressions.
*/
override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) {
- SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith))
+ SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith), None)
}
/**
@@ -723,7 +723,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* Create an alias (SubqueryAlias) for a LogicalPlan.
*/
private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = {
- SubqueryAlias(alias.getText, plan)
+ SubqueryAlias(alias.getText, plan, None)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 2917d8d2a97aa..af1736e60799b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -693,7 +694,11 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
}
}
-case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode {
+case class SubqueryAlias(
+ alias: String,
+ child: LogicalPlan,
+ view: Option[TableIdentifier])
+ extends UnaryNode {
override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias)))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 22e1c9be0573d..8971edc7d3b9a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -339,8 +339,8 @@ class AnalysisSuite extends AnalysisTest {
val query =
Project(Seq($"x.key", $"y.key"),
Join(
- Project(Seq($"x.key"), SubqueryAlias("x", input)),
- Project(Seq($"y.key"), SubqueryAlias("y", input)),
+ Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
+ Project(Seq($"y.key"), SubqueryAlias("y", input, None)),
Inner, None))
assertAnalysisSuccess(query)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index b31b4406ae600..c9d4fef8056ca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -395,31 +395,38 @@ class SessionCatalogSuite extends SparkFunSuite {
sessionCatalog.setCurrentDatabase("db2")
// If we explicitly specify the database, we'll look up the relation in that database
assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2")))
- == SubqueryAlias("tbl1", SimpleCatalogRelation("db2", metastoreTable1)))
+ == SubqueryAlias("tbl1", SimpleCatalogRelation("db2", metastoreTable1), None))
// Otherwise, we'll first look up a temporary table with the same name
assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1"))
- == SubqueryAlias("tbl1", tempTable1))
+ == SubqueryAlias("tbl1", tempTable1, Some(TableIdentifier("tbl1"))))
// Then, if that does not exist, look up the relation in the current database
sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false)
assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1"))
- == SubqueryAlias("tbl1", SimpleCatalogRelation("db2", metastoreTable1)))
+ == SubqueryAlias("tbl1", SimpleCatalogRelation("db2", metastoreTable1), None))
}
test("lookup table relation with alias") {
val catalog = new SessionCatalog(newBasicCatalog())
val alias = "monster"
val tableMetadata = catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2")))
- val relation = SubqueryAlias("tbl1", SimpleCatalogRelation("db2", tableMetadata))
+ val relation = SubqueryAlias("tbl1", SimpleCatalogRelation("db2", tableMetadata), None)
val relationWithAlias =
SubqueryAlias(alias,
- SubqueryAlias("tbl1",
- SimpleCatalogRelation("db2", tableMetadata)))
+ SimpleCatalogRelation("db2", tableMetadata), None)
assert(catalog.lookupRelation(
TableIdentifier("tbl1", Some("db2")), alias = None) == relation)
assert(catalog.lookupRelation(
TableIdentifier("tbl1", Some("db2")), alias = Some(alias)) == relationWithAlias)
}
+ test("lookup view with view name in alias") {
+ val catalog = new SessionCatalog(newBasicCatalog())
+ val tmpView = Range(1, 10, 2, 10)
+ catalog.createTempView("vw1", tmpView, overrideIfExists = false)
+ val plan = catalog.lookupRelation(TableIdentifier("vw1"), Option("range"))
+ assert(plan == SubqueryAlias("range", tmpView, Option(TableIdentifier("vw1"))))
+ }
+
test("table exists") {
val catalog = new SessionCatalog(newBasicCatalog())
assert(catalog.tableExists(TableIdentifier("tbl1", Some("db2"))))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 589607e3ad5cb..5bd1bc80c3b8a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -320,16 +320,16 @@ class ColumnPruningSuite extends PlanTest {
val query =
Project(Seq($"x.key", $"y.key"),
Join(
- SubqueryAlias("x", input),
- BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze
+ SubqueryAlias("x", input, None),
+ BroadcastHint(SubqueryAlias("y", input, None)), Inner, None)).analyze
val optimized = Optimize.execute(query)
val expected =
Join(
- Project(Seq($"x.key"), SubqueryAlias("x", input)),
+ Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
BroadcastHint(
- Project(Seq($"y.key"), SubqueryAlias("y", input))),
+ Project(Seq($"y.key"), SubqueryAlias("y", input, None))),
Inner, None).analyze
comparePlans(optimized, expected)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala
index 9b6d68aee803a..a8aeedbd62759 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala
@@ -46,13 +46,13 @@ class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper {
test("eliminate top level subquery") {
val input = LocalRelation('a.int, 'b.int)
- val query = SubqueryAlias("a", input)
+ val query = SubqueryAlias("a", input, None)
comparePlans(afterOptimization(query), input)
}
test("eliminate mid-tree subquery") {
val input = LocalRelation('a.int, 'b.int)
- val query = Filter(TrueLiteral, SubqueryAlias("a", input))
+ val query = Filter(TrueLiteral, SubqueryAlias("a", input, None))
comparePlans(
afterOptimization(query),
Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
@@ -61,7 +61,7 @@ class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper {
test("eliminate multiple subqueries") {
val input = LocalRelation('a.int, 'b.int)
val query = Filter(TrueLiteral,
- SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input))))
+ SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input, None), None), None))
comparePlans(
afterOptimization(query),
Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index c1ebf8b09e08d..dbb3e6a5272ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -97,15 +97,15 @@ class JoinOptimizationSuite extends PlanTest {
val query =
Project(Seq($"x.key", $"y.key"),
Join(
- SubqueryAlias("x", input),
- BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze
+ SubqueryAlias("x", input, None),
+ BroadcastHint(SubqueryAlias("y", input, None)), Inner, None)).analyze
val optimized = Optimize.execute(query)
val expected =
Join(
- Project(Seq($"x.key"), SubqueryAlias("x", input)),
- BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
+ Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
+ BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input, None))),
Inner, None).analyze
comparePlans(optimized, expected)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 34d52c75e0af2..7af333b34f723 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -80,7 +80,7 @@ class PlanParserSuite extends PlanTest {
def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = {
val ctes = namedPlans.map {
case (name, cte) =>
- name -> SubqueryAlias(name, cte)
+ name -> SubqueryAlias(name, cte, None)
}
With(plan, ctes)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c119df83b3d71..6da99ce0dd683 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -967,7 +967,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def as(alias: String): Dataset[T] = withTypedPlan {
- SubqueryAlias(alias, logicalPlan)
+ SubqueryAlias(alias, logicalPlan, None)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
index 5d93419f357ef..ff8e0f2642055 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -75,7 +75,7 @@ class SQLBuilder private (
val aliasedOutput = canonicalizedPlan.output.zip(outputNames).map {
case (attr, name) => Alias(attr.withQualifier(None), name)()
}
- val finalPlan = Project(aliasedOutput, SubqueryAlias(finalName, canonicalizedPlan))
+ val finalPlan = Project(aliasedOutput, SubqueryAlias(finalName, canonicalizedPlan, None))
try {
val replaced = finalPlan.transformAllExpressions {
@@ -440,7 +440,7 @@ class SQLBuilder private (
object RemoveSubqueriesAboveSQLTable extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case SubqueryAlias(_, t @ ExtractSQLTable(_)) => t
+ case SubqueryAlias(_, t @ ExtractSQLTable(_), _) => t
}
}
@@ -557,7 +557,7 @@ class SQLBuilder private (
}
private def addSubquery(plan: LogicalPlan): SubqueryAlias = {
- SubqueryAlias(newSubqueryName(), plan)
+ SubqueryAlias(newSubqueryName(), plan, None)
}
private def addSubqueryIfNeeded(plan: LogicalPlan): LogicalPlan = plan match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index fc8d8c3667901..5eb2f0a9ff034 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -55,7 +55,7 @@ class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] {
s"${u.tableIdentifier.database.get}")
}
val plan = LogicalRelation(dataSource.resolveRelation())
- u.alias.map(a => SubqueryAlias(u.alias.get, plan)).getOrElse(plan)
+ u.alias.map(a => SubqueryAlias(u.alias.get, plan, None)).getOrElse(plan)
} catch {
case e: ClassNotFoundException => u
case e: Exception =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index c7c1acda25db2..7118edabb83cf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -162,24 +162,21 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
if (table.properties.get(DATASOURCE_PROVIDER).isDefined) {
val dataSourceTable = cachedDataSourceTables(qualifiedTableName)
- val qualifiedTable = SubqueryAlias(qualifiedTableName.name, dataSourceTable)
+ val qualifiedTable = SubqueryAlias(qualifiedTableName.name, dataSourceTable, None)
// Then, if alias is specified, wrap the table with a Subquery using the alias.
// Otherwise, wrap the table with a Subquery using the table name.
- alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
+ alias.map(a => SubqueryAlias(a, qualifiedTable, None)).getOrElse(qualifiedTable)
} else if (table.tableType == CatalogTableType.VIEW) {
val viewText = table.viewText.getOrElse(sys.error("Invalid view without text."))
- alias match {
- case None =>
- SubqueryAlias(table.identifier.table,
- sparkSession.sessionState.sqlParser.parsePlan(viewText))
- case Some(aliasText) =>
- SubqueryAlias(aliasText, sessionState.sqlParser.parsePlan(viewText))
- }
+ SubqueryAlias(
+ alias.getOrElse(table.identifier.table),
+ sparkSession.sessionState.sqlParser.parsePlan(viewText),
+ Option(table.identifier))
} else {
val qualifiedTable =
MetastoreRelation(
qualifiedTableName.database, qualifiedTableName.name)(table, client, sparkSession)
- alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
+ alias.map(a => SubqueryAlias(a, qualifiedTable, None)).getOrElse(qualifiedTable)
}
}
@@ -383,7 +380,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
// Read path
case relation: MetastoreRelation if shouldConvertMetastoreParquet(relation) =>
val parquetRelation = convertToParquetRelation(relation)
- SubqueryAlias(relation.tableName, parquetRelation)
+ SubqueryAlias(relation.tableName, parquetRelation, None)
}
}
}
@@ -421,7 +418,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
// Read path
case relation: MetastoreRelation if shouldConvertMetastoreOrc(relation) =>
val orcRelation = convertToOrcRelation(relation)
- SubqueryAlias(relation.tableName, orcRelation)
+ SubqueryAlias(relation.tableName, orcRelation, None)
}
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index c59ac3dcafea4..ebed9eb6e7dca 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -68,10 +68,10 @@ private[sql] class HiveSessionCatalog(
metastoreCatalog.lookupRelation(newName, alias)
} else {
val relation = tempTables(table)
- val tableWithQualifiers = SubqueryAlias(table, relation)
+ val tableWithQualifiers = SubqueryAlias(table, relation, None)
// If an alias was specified by the lookup, wrap the plan in a subquery so that
// attributes are properly qualified with this alias.
- alias.map(a => SubqueryAlias(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
+ alias.map(a => SubqueryAlias(a, tableWithQualifiers, None)).getOrElse(tableWithQualifiers)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 9d72367f437bf..0477ea4d4c380 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -23,12 +23,13 @@ import org.apache.spark.sql.{QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType, StructField, StructType}
-class HiveMetastoreCatalogSuite extends TestHiveSingleton {
+class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils {
import spark.implicits._
test("struct field should accept underscore in sub-column name") {
@@ -57,6 +58,17 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton {
val dataType = StructType((1 to 100).map(field))
assert(CatalystSqlParser.parseDataType(dataType.catalogString) == dataType)
}
+
+ test("view relation") {
+ withView("vw1") {
+ spark.sql("create view vw1 as select 1 as id")
+ val plan = spark.sql("select id from vw1").queryExecution.analyzed
+ val aliases = plan.collect {
+ case x @ SubqueryAlias("vw1", _, Some(TableIdentifier("vw1", Some("default")))) => x
+ }
+ assert(aliases.size == 1)
+ }
+ }
}
class DataSourceWithHiveMetastoreCatalogSuite
From 0f6aa8afaacdf0ceca9c2c1650ca26a5c167ae69 Mon Sep 17 00:00:00 2001
From: mvervuurt
Date: Tue, 16 Aug 2016 23:12:59 -0700
Subject: [PATCH 003/270] [MINOR][DOC] Fix the descriptions for `properties`
argument in the documentation for jdbc APIs
## What changes were proposed in this pull request?
This should be credited to mvervuurt. The main purpose of this PR is:
- to include the same change for `DataFrameReader` so the two docstrings match, and
- to avoid verifying the PR twice (as I already did).
The documentation for both should be the same because both expect `properties` to be the same `dict` for the same option.
## How was this patch tested?
Manually building the Python documentation. This produces the updated docstrings for:
- `DataFrameReader`
- `DataFrameWriter`
Closes #14624
Author: hyukjinkwon
Author: mvervuurt
Closes #14677 from HyukjinKwon/typo-python.
---
python/pyspark/sql/readwriter.py | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 4020bb3fa45b0..64de33e8ec0a8 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -401,8 +401,9 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar
:param numPartitions: the number of partitions
:param predicates: a list of expressions suitable for inclusion in WHERE clauses;
each one defines one partition of the :class:`DataFrame`
- :param properties: a dictionary of JDBC database connection arguments; normally,
- at least a "user" and "password" property should be included
+ :param properties: a dictionary of JDBC database connection arguments. Normally at
+ least properties "user" and "password" with their corresponding values.
+ For example { 'user' : 'SYSTEM', 'password' : 'mypassword' }
:return: a DataFrame
"""
if properties is None:
@@ -716,9 +717,9 @@ def jdbc(self, url, table, mode=None, properties=None):
* ``overwrite``: Overwrite existing data.
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
- :param properties: JDBC database connection arguments, a list of
- arbitrary string tag/value. Normally at least a
- "user" and "password" property should be included.
+ :param properties: a dictionary of JDBC database connection arguments. Normally at
+ least properties "user" and "password" with their corresponding values.
+ For example { 'user' : 'SYSTEM', 'password' : 'mypassword' }
"""
if properties is None:
properties = dict()
From 4d0cc84afca9efd4541a2e8d583e3e0f2df37c0d Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Wed, 17 Aug 2016 14:22:36 +0200
Subject: [PATCH 004/270] [SPARK-17032][SQL] Add test cases for methods in
ParserUtils.
## What changes were proposed in this pull request?
Currently, methods in `ParserUtils` are only tested indirectly; this PR adds test cases to `ParserUtilsSuite` to verify them directly.
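The suite drives the generated ANTLR parser directly; a minimal sketch of the context-building helper it relies on (mirroring `buildContext` in the diff below — `SqlBaseLexer`, `SqlBaseParser`, and `ANTLRNoCaseStringStream` are the generated/parser classes in the same `org.apache.spark.sql.catalyst.parser` package):
```scala
import org.antlr.v4.runtime.CommonTokenStream

// Lex and parse a SQL command, then hand the parser to a function that extracts the
// specific ParserRuleContext under test.
def buildContext[T](command: String)(toResult: SqlBaseParser => T): T = {
  val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command))
  val tokenStream = new CommonTokenStream(lexer)
  val parser = new SqlBaseParser(tokenStream)
  toResult(parser)
}

// Example, matching the new tests:
//   command(buildContext("show functions foo.bar")(_.statement())) == "show functions foo.bar"
```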
## How was this patch tested?
New test cases in `ParserUtilsSuite`
Author: jiangxingbo
Closes #14620 from jiangxb1987/parserUtils.
---
.../sql/catalyst/parser/ParserUtils.scala | 9 +-
.../catalyst/parser/ParserUtilsSuite.scala | 126 +++++++++++++++++-
2 files changed, 128 insertions(+), 7 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index bc35ae2f55409..cb89a9679a8cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -31,11 +31,7 @@ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
object ParserUtils {
/** Get the command which created the token. */
def command(ctx: ParserRuleContext): String = {
- command(ctx.getStart.getInputStream)
- }
-
- /** Get the command which created the token. */
- def command(stream: CharStream): String = {
+ val stream = ctx.getStart.getInputStream
stream.getText(Interval.of(0, stream.size()))
}
@@ -74,7 +70,8 @@ object ParserUtils {
/** Get the origin (line and position) of the token. */
def position(token: Token): Origin = {
- Origin(Option(token.getLine), Option(token.getCharPositionInLine))
+ val opt = Option(token)
+ Origin(opt.map(_.getLine), opt.map(_.getCharPositionInLine))
}
/** Validate the condition. If it doesn't throw a parse exception. */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
index d090daf7b41eb..d5748a4ff18f8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
@@ -16,12 +16,53 @@
*/
package org.apache.spark.sql.catalyst.parser
+import org.antlr.v4.runtime.{CommonTokenStream, ParserRuleContext}
+
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
class ParserUtilsSuite extends SparkFunSuite {
import ParserUtils._
+ val setConfContext = buildContext("set example.setting.name=setting.value") { parser =>
+ parser.statement().asInstanceOf[SetConfigurationContext]
+ }
+
+ val showFuncContext = buildContext("show functions foo.bar") { parser =>
+ parser.statement().asInstanceOf[ShowFunctionsContext]
+ }
+
+ val descFuncContext = buildContext("describe function extended bar") { parser =>
+ parser.statement().asInstanceOf[DescribeFunctionContext]
+ }
+
+ val showDbsContext = buildContext("show databases like 'identifier_with_wildcards'") { parser =>
+ parser.statement().asInstanceOf[ShowDatabasesContext]
+ }
+
+ val createDbContext = buildContext(
+ """
+ |CREATE DATABASE IF NOT EXISTS database_name
+ |COMMENT 'database_comment' LOCATION '/home/user/db'
+ |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')
+ """.stripMargin
+ ) { parser =>
+ parser.statement().asInstanceOf[CreateDatabaseContext]
+ }
+
+ val emptyContext = buildContext("") { parser =>
+ parser.statement
+ }
+
+ private def buildContext[T](command: String)(toResult: SqlBaseParser => T): T = {
+ val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command))
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new SqlBaseParser(tokenStream)
+ toResult(parser)
+ }
+
test("unescapeSQLString") {
// scalastyle:off nonascii
@@ -61,5 +102,88 @@ class ParserUtilsSuite extends SparkFunSuite {
// scalastyle:on nonascii
}
- // TODO: Add test cases for other methods in ParserUtils
+ test("command") {
+ assert(command(setConfContext) == "set example.setting.name=setting.value")
+ assert(command(showFuncContext) == "show functions foo.bar")
+ assert(command(descFuncContext) == "describe function extended bar")
+ assert(command(showDbsContext) == "show databases like 'identifier_with_wildcards'")
+ }
+
+ test("operationNotAllowed") {
+ val errorMessage = "parse.fail.operation.not.allowed.error.message"
+ val e = intercept[ParseException] {
+ operationNotAllowed(errorMessage, showFuncContext)
+ }.getMessage
+ assert(e.contains("Operation not allowed"))
+ assert(e.contains(errorMessage))
+ }
+
+ test("checkDuplicateKeys") {
+ val properties = Seq(("a", "a"), ("b", "b"), ("c", "c"))
+ checkDuplicateKeys[String](properties, createDbContext)
+
+ val properties2 = Seq(("a", "a"), ("b", "b"), ("a", "c"))
+ val e = intercept[ParseException] {
+ checkDuplicateKeys(properties2, createDbContext)
+ }.getMessage
+ assert(e.contains("Found duplicate keys"))
+ }
+
+ test("source") {
+ assert(source(setConfContext) == "set example.setting.name=setting.value")
+ assert(source(showFuncContext) == "show functions foo.bar")
+ assert(source(descFuncContext) == "describe function extended bar")
+ assert(source(showDbsContext) == "show databases like 'identifier_with_wildcards'")
+ }
+
+ test("remainder") {
+ assert(remainder(setConfContext) == "")
+ assert(remainder(showFuncContext) == "")
+ assert(remainder(descFuncContext) == "")
+ assert(remainder(showDbsContext) == "")
+
+ assert(remainder(setConfContext.SET.getSymbol) == " example.setting.name=setting.value")
+ assert(remainder(showFuncContext.FUNCTIONS.getSymbol) == " foo.bar")
+ assert(remainder(descFuncContext.EXTENDED.getSymbol) == " bar")
+ assert(remainder(showDbsContext.LIKE.getSymbol) == " 'identifier_with_wildcards'")
+ }
+
+ test("string") {
+ assert(string(showDbsContext.pattern) == "identifier_with_wildcards")
+ assert(string(createDbContext.comment) == "database_comment")
+
+ assert(string(createDbContext.locationSpec.STRING) == "/home/user/db")
+ }
+
+ test("position") {
+ assert(position(setConfContext.start) == Origin(Some(1), Some(0)))
+ assert(position(showFuncContext.stop) == Origin(Some(1), Some(19)))
+ assert(position(descFuncContext.describeFuncName.start) == Origin(Some(1), Some(27)))
+ assert(position(createDbContext.locationSpec.start) == Origin(Some(3), Some(27)))
+ assert(position(emptyContext.stop) == Origin(None, None))
+ }
+
+ test("validate") {
+ val f1 = { ctx: ParserRuleContext =>
+ ctx.children != null && !ctx.children.isEmpty
+ }
+ val message = "ParserRuleContext should not be empty."
+ validate(f1(showFuncContext), message, showFuncContext)
+
+ val e = intercept[ParseException] {
+ validate(f1(emptyContext), message, emptyContext)
+ }.getMessage
+ assert(e.contains(message))
+ }
+
+ test("withOrigin") {
+ val ctx = createDbContext.locationSpec
+ val current = CurrentOrigin.get
+ val (location, origin) = withOrigin(ctx) {
+ (string(ctx.STRING), CurrentOrigin.get)
+ }
+ assert(location == "/home/user/db")
+ assert(origin == Origin(Some(3), Some(27)))
+ assert(CurrentOrigin.get == current)
+ }
}
From 363793f2bf57205f1d753d4705583aaf441849b5 Mon Sep 17 00:00:00 2001
From: "wm624@hotmail.com"
Date: Wed, 17 Aug 2016 06:15:04 -0700
Subject: [PATCH 005/270] [SPARK-16444][SPARKR] Isotonic Regression wrapper in
SparkR
## What changes were proposed in this pull request?
Add an Isotonic Regression wrapper to SparkR, similar to R's `isoreg()`. Wrappers in both R and Scala are added, along with unit tests and documentation.
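The R API is a thin layer over a `private[r]` JVM wrapper; the sketch below (with hypothetical names `IsoregCallSketch`/`fitFromR`) only illustrates the call that `spark.isoreg(df, label ~ feature, isotonic = FALSE)` issues through `callJStatic`, following the `fit()` signature in the diff:
```scala
package org.apache.spark.ml.r  // the wrapper is private[r], so a direct call must live in this package

import org.apache.spark.sql.DataFrame

object IsoregCallSketch {
  // Mirrors the arguments SparkR marshals: the formula deparsed to a string,
  // isotonic/featureIndex coerced to Boolean/Int, and "" when no weight column is supplied.
  def fitFromR(df: DataFrame): IsotonicRegressionWrapper = {
    IsotonicRegressionWrapper.fit(
      data = df,
      formula = "label ~ feature",
      isotonic = false,
      featureIndex = 0,
      weightCol = "")
  }
  // summary() in R then returns the fitted model's boundaries and predictions
  // (wrapper.boundaries and wrapper.predictions, both Array[Double]) as lists.
}
```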
## How was this patch tested?
Manually tested with `sudo ./R/run-tests.sh`.
Author: wm624@hotmail.com
Closes #14182 from wangmiao1981/isoR.
---
R/pkg/NAMESPACE | 3 +-
R/pkg/R/generics.R | 4 +
R/pkg/R/mllib.R | 118 +++++++++++++++++
R/pkg/inst/tests/testthat/test_mllib.R | 32 +++++
.../ml/r/IsotonicRegressionWrapper.scala | 119 ++++++++++++++++++
.../org/apache/spark/ml/r/RWrappers.scala | 2 +
6 files changed, 277 insertions(+), 1 deletion(-)
create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index aaab92f5cfc7b..1e23b233c1116 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -24,7 +24,8 @@ exportMethods("glm",
"spark.kmeans",
"fitted",
"spark.naiveBayes",
- "spark.survreg")
+ "spark.survreg",
+ "spark.isoreg")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 52ab730e215c2..ebacc11741812 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1304,6 +1304,10 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s
#' @export
setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
+#' @rdname spark.isoreg
+#' @export
+setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
+
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 6f6e2fc255c3f..0dcc54d7af09b 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -53,6 +53,13 @@ setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))
#' @note KMeansModel since 2.0.0
setClass("KMeansModel", representation(jobj = "jobj"))
+#' S4 class that represents an IsotonicRegressionModel
+#'
+#' @param jobj a Java object reference to the backing Scala IsotonicRegressionModel
+#' @export
+#' @note IsotonicRegressionModel since 2.1.0
+setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
+
#' Saves the MLlib model to the input path
#'
#' Saves the MLlib model to the input path. For more information, see the specific
@@ -62,6 +69,7 @@ setClass("KMeansModel", representation(jobj = "jobj"))
#' @export
#' @seealso \link{spark.glm}, \link{glm}
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.isoreg}
#' @seealso \link{read.ml}
NULL
@@ -74,6 +82,7 @@ NULL
#' @export
#' @seealso \link{spark.glm}, \link{glm}
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.isoreg}
NULL
#' Generalized Linear Models
@@ -299,6 +308,94 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
return(list(apriori = apriori, tables = tables))
})
+#' Isotonic Regression Model
+#'
+#' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg().
+#' Users can print, make predictions on the produced model and save the model to the input path.
+#'
+#' @param data SparkDataFrame for training
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param isotonic Whether the output sequence should be isotonic/increasing (TRUE) or
+#' antitonic/decreasing (FALSE)
+#' @param featureIndex The index of the feature if \code{featuresCol} is a vector column (default: `0`),
+#' no effect otherwise
+#' @param weightCol The weight column name.
+#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model
+#' @rdname spark.isoreg
+#' @aliases spark.isoreg,SparkDataFrame,formula-method
+#' @name spark.isoreg
+#' @export
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' data <- list(list(7.0, 0.0), list(5.0, 1.0), list(3.0, 2.0),
+#' list(5.0, 3.0), list(1.0, 4.0))
+#' df <- createDataFrame(data, c("label", "feature"))
+#' model <- spark.isoreg(df, label ~ feature, isotonic = FALSE)
+#' # return model boundaries and prediction as lists
+#' result <- summary(model, df)
+#' # prediction based on fitted model
+#' predict_data <- list(list(-2.0), list(-1.0), list(0.5),
+#' list(0.75), list(1.0), list(2.0), list(9.0))
+#' predict_df <- createDataFrame(predict_data, c("feature"))
+#' # get prediction column
+#' predict_result <- collect(select(predict(model, predict_df), "prediction"))
+#'
+#' # save fitted model to input path
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#'
+#' # can also read back the saved model and print
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.isoreg since 2.1.0
+setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"),
+ function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) {
+ formula <- paste0(deparse(formula), collapse = "")
+
+ if (is.null(weightCol)) {
+ weightCol <- ""
+ }
+
+ jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
+ data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
+ as.character(weightCol))
+ return(new("IsotonicRegressionModel", jobj = jobj))
+ })
+
+# Predicted values based on an isotonicRegression model
+
+#' @param object a fitted IsotonicRegressionModel
+#' @param newData SparkDataFrame for testing
+#' @return \code{predict} returns a SparkDataFrame containing predicted values
+#' @rdname spark.isoreg
+#' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method
+#' @export
+#' @note predict(IsotonicRegressionModel) since 2.1.0
+setMethod("predict", signature(object = "IsotonicRegressionModel"),
+ function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ })
+
+# Get the summary of an IsotonicRegressionModel model
+
+#' @param object a fitted IsotonicRegressionModel
+#' @param ... Other optional arguments to summary of an IsotonicRegressionModel
+#' @return \code{summary} returns the model's boundaries and prediction as lists
+#' @rdname spark.isoreg
+#' @aliases summary,IsotonicRegressionModel-method
+#' @export
+#' @note summary(IsotonicRegressionModel) since 2.1.0
+setMethod("summary", signature(object = "IsotonicRegressionModel"),
+ function(object, ...) {
+ jobj <- object@jobj
+ boundaries <- callJMethod(jobj, "boundaries")
+ predictions <- callJMethod(jobj, "predictions")
+ return(list(boundaries = boundaries, predictions = predictions))
+ })
+
#' K-Means Clustering Model
#'
#' Fits a k-means clustering model against a Spark DataFrame, similarly to R's kmeans().
@@ -533,6 +630,25 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"),
invisible(callJMethod(writer, "save", path))
})
+# Save fitted IsotonicRegressionModel to the input path
+
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @rdname spark.isoreg
+#' @aliases write.ml,IsotonicRegressionModel,character-method
+#' @export
+#' @note write.ml(IsotonicRegression, character) since 2.1.0
+setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+ })
+
#' Load a fitted MLlib model from the input path.
#'
#' @param path Path of the model to read.
@@ -558,6 +674,8 @@ read.ml <- function(path) {
return(new("GeneralizedLinearRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
return(new("KMeansModel", jobj = jobj))
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
+ return(new("IsotonicRegressionModel", jobj = jobj))
} else {
stop(paste("Unsupported model: ", jobj))
}
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index bc18224680586..b759b28927365 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -476,4 +476,36 @@ test_that("spark.survreg", {
}
})
+test_that("spark.isotonicRegression", {
+ label <- c(7.0, 5.0, 3.0, 5.0, 1.0)
+ feature <- c(0.0, 1.0, 2.0, 3.0, 4.0)
+ weight <- c(1.0, 1.0, 1.0, 1.0, 1.0)
+ data <- as.data.frame(cbind(label, feature, weight))
+ df <- suppressWarnings(createDataFrame(data))
+
+ model <- spark.isoreg(df, label ~ feature, isotonic = FALSE,
+ weightCol = "weight")
+ # only allow one variable on the right hand side of the formula
+ expect_error(model2 <- spark.isoreg(df, ~., isotonic = FALSE))
+ result <- summary(model, df)
+ expect_equal(result$predictions, list(7, 5, 4, 4, 1))
+
+ # Test model prediction
+ predict_data <- list(list(-2.0), list(-1.0), list(0.5),
+ list(0.75), list(1.0), list(2.0), list(9.0))
+ predict_df <- createDataFrame(predict_data, c("feature"))
+ predict_result <- collect(select(predict(model, predict_df), "prediction"))
+ expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-isotonicRegression", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ expect_equal(result, summary(model2, df))
+
+ unlink(modelPath)
+})
+
sparkR.session.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
new file mode 100644
index 0000000000000..1ea80cb46ab7b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.{AttributeGroup}
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.regression.{IsotonicRegression, IsotonicRegressionModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class IsotonicRegressionWrapper private (
+ val pipeline: PipelineModel,
+ val features: Array[String]) extends MLWritable {
+
+ private val isotonicRegressionModel: IsotonicRegressionModel =
+ pipeline.stages(1).asInstanceOf[IsotonicRegressionModel]
+
+ lazy val boundaries: Array[Double] = isotonicRegressionModel.boundaries.toArray
+
+ lazy val predictions: Array[Double] = isotonicRegressionModel.predictions.toArray
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset).drop(isotonicRegressionModel.getFeaturesCol)
+ }
+
+ override def write: MLWriter = new IsotonicRegressionWrapper.IsotonicRegressionWrapperWriter(this)
+}
+
+private[r] object IsotonicRegressionWrapper
+ extends MLReadable[IsotonicRegressionWrapper] {
+
+ def fit(
+ data: DataFrame,
+ formula: String,
+ isotonic: Boolean,
+ featureIndex: Int,
+ weightCol: String): IsotonicRegressionWrapper = {
+
+ val rFormulaModel = new RFormula()
+ .setFormula(formula)
+ .setFeaturesCol("features")
+ .fit(data)
+
+ // get feature names from output schema
+ val schema = rFormulaModel.transform(data).schema
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+ .attributes.get
+ val features = featureAttrs.map(_.name.get)
+ require(features.size == 1)
+
+ // assemble and fit the pipeline
+ val isotonicRegression = new IsotonicRegression()
+ .setIsotonic(isotonic)
+ .setFeatureIndex(featureIndex)
+ .setWeightCol(weightCol)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, isotonicRegression))
+ .fit(data)
+
+ new IsotonicRegressionWrapper(pipeline, features)
+ }
+
+ override def read: MLReader[IsotonicRegressionWrapper] = new IsotonicRegressionWrapperReader
+
+ override def load(path: String): IsotonicRegressionWrapper = super.load(path)
+
+ class IsotonicRegressionWrapperWriter(instance: IsotonicRegressionWrapper) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("features" -> instance.features.toSeq)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class IsotonicRegressionWrapperReader extends MLReader[IsotonicRegressionWrapper] {
+
+ override def load(path: String): IsotonicRegressionWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val features = (rMetadata \ "features").extract[Array[String]]
+
+ val pipeline = PipelineModel.load(pipelinePath)
+ new IsotonicRegressionWrapper(pipeline, features)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 568c160ee50d7..f9a44d60e691a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -44,6 +44,8 @@ private[r] object RWrappers extends MLReader[Object] {
GeneralizedLinearRegressionWrapper.load(path)
case "org.apache.spark.ml.r.KMeansWrapper" =>
KMeansWrapper.load(path)
+ case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
+ IsotonicRegressionWrapper.load(path)
case _ =>
throw new SparkException(s"SparkR read.ml does not support load $className")
}
From 56d86742d2600b8426d75bd87ab3c73332dca1d2 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Wed, 17 Aug 2016 21:34:57 +0800
Subject: [PATCH 006/270] [SPARK-15285][SQL] Generated
SpecificSafeProjection.apply method grows beyond 64 KB
## What changes were proposed in this pull request?
This PR splits the generated code for ```SafeProjection.apply``` by using ```ctx.splitExpressions()```. This is needed because the large code body generated for ```NewInstance``` may push the ```apply()``` method beyond the 64 KB bytecode size limit.
Here is [the original PR](https://github.com/apache/spark/pull/13243) for SPARK-15285. However, it broke the build with Scala 2.10, since Scala 2.10 does not support case classes with a large number of members. It was therefore reverted by [this commit](https://github.com/apache/spark/commit/fa244e5a90690d6a31be50f2aa203ae1a2e9a1cf).
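To illustrate the splitting idea outside Spark's codegen internals, here is a minimal standalone Scala sketch; the chunk size, helper names, and generated statements are made up for illustration, and the real patch delegates the splitting to ```ctx.splitExpressions()```.
```scala
object SplitCodegenSketch {
  // Group generated statements into fixed-size chunks, emit one helper method
  // per chunk, and return the helper definitions plus the calls for apply().
  // Keeping every generated method small avoids the JVM's 64 KB bytecode limit.
  def splitStatements(statements: Seq[String], chunkSize: Int = 100): (Seq[String], String) = {
    val helpers = statements.grouped(chunkSize).zipWithIndex.map { case (chunk, idx) =>
      s"private void applyPart$idx(InternalRow row) {\n  ${chunk.mkString("\n  ")}\n}"
    }.toSeq
    val calls = helpers.indices.map(idx => s"applyPart$idx(row);").mkString("\n")
    (helpers, calls)
  }

  def main(args: Array[String]): Unit = {
    // 300 tiny assignments stand in for the per-argument snippets of NewInstance.
    val statements = (0 until 300).map(i => s"argValues[$i] = children[$i].eval(row);")
    val (helpers, applyBody) = splitStatements(statements)
    println(s"${helpers.size} helper methods generated; apply() just calls:\n$applyBody")
  }
}
```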
## How was this patch tested?
Added new tests that use `DefinedByConstructorParams` instead of a case class, so they also compile with Scala 2.10.
Author: Kazuaki Ishizaki
Closes #14670 from kiszk/SPARK-15285-2.
---
.../expressions/objects/objects.scala | 32 ++++++++++++---
.../spark/sql/DataFrameComplexTypeSuite.scala | 40 +++++++++++++++++++
2 files changed, 66 insertions(+), 6 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 7cb94a7942885..31ed485317487 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -245,27 +245,47 @@ case class NewInstance(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
- val argGen = arguments.map(_.genCode(ctx))
- val argString = argGen.map(_.value).mkString(", ")
+ val argIsNulls = ctx.freshName("argIsNulls")
+ ctx.addMutableState("boolean[]", argIsNulls,
+ s"$argIsNulls = new boolean[${arguments.size}];")
+ val argValues = arguments.zipWithIndex.map { case (e, i) =>
+ val argValue = ctx.freshName("argValue")
+ ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
+ argValue
+ }
+
+ val argCodes = arguments.zipWithIndex.map { case (e, i) =>
+ val expr = e.genCode(ctx)
+ expr.code + s"""
+ $argIsNulls[$i] = ${expr.isNull};
+ ${argValues(i)} = ${expr.value};
+ """
+ }
+ val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))
var isNull = ev.isNull
val setIsNull = if (propagateNull && arguments.nonEmpty) {
- s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};"
+ s"""
+ boolean $isNull = false;
+ for (int idx = 0; idx < ${arguments.length}; idx++) {
+ if ($argIsNulls[idx]) { $isNull = true; break; }
+ }
+ """
} else {
isNull = "false"
""
}
val constructorCall = outer.map { gen =>
- s"""${gen.value}.new ${cls.getSimpleName}($argString)"""
+ s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})"""
}.getOrElse {
- s"new $className($argString)"
+ s"new $className(${argValues.mkString(", ")})"
}
val code = s"""
- ${argGen.map(_.code).mkString("\n")}
+ $argCode
${outer.map(_.code).getOrElse("")}
$setIsNull
final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index 72f676e6225ee..1230b921aa279 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.DefinedByConstructorParams
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -58,4 +59,43 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
val nullIntRow = df.selectExpr("i[1]").collect()(0)
assert(nullIntRow == org.apache.spark.sql.Row(null))
}
+
+ test("SPARK-15285 Generated SpecificSafeProjection.apply method grows beyond 64KB") {
+ val ds100_5 = Seq(S100_5()).toDS()
+ ds100_5.rdd.count
+ }
}
+
+class S100(
+ val s1: String = "1", val s2: String = "2", val s3: String = "3", val s4: String = "4",
+ val s5: String = "5", val s6: String = "6", val s7: String = "7", val s8: String = "8",
+ val s9: String = "9", val s10: String = "10", val s11: String = "11", val s12: String = "12",
+ val s13: String = "13", val s14: String = "14", val s15: String = "15", val s16: String = "16",
+ val s17: String = "17", val s18: String = "18", val s19: String = "19", val s20: String = "20",
+ val s21: String = "21", val s22: String = "22", val s23: String = "23", val s24: String = "24",
+ val s25: String = "25", val s26: String = "26", val s27: String = "27", val s28: String = "28",
+ val s29: String = "29", val s30: String = "30", val s31: String = "31", val s32: String = "32",
+ val s33: String = "33", val s34: String = "34", val s35: String = "35", val s36: String = "36",
+ val s37: String = "37", val s38: String = "38", val s39: String = "39", val s40: String = "40",
+ val s41: String = "41", val s42: String = "42", val s43: String = "43", val s44: String = "44",
+ val s45: String = "45", val s46: String = "46", val s47: String = "47", val s48: String = "48",
+ val s49: String = "49", val s50: String = "50", val s51: String = "51", val s52: String = "52",
+ val s53: String = "53", val s54: String = "54", val s55: String = "55", val s56: String = "56",
+ val s57: String = "57", val s58: String = "58", val s59: String = "59", val s60: String = "60",
+ val s61: String = "61", val s62: String = "62", val s63: String = "63", val s64: String = "64",
+ val s65: String = "65", val s66: String = "66", val s67: String = "67", val s68: String = "68",
+ val s69: String = "69", val s70: String = "70", val s71: String = "71", val s72: String = "72",
+ val s73: String = "73", val s74: String = "74", val s75: String = "75", val s76: String = "76",
+ val s77: String = "77", val s78: String = "78", val s79: String = "79", val s80: String = "80",
+ val s81: String = "81", val s82: String = "82", val s83: String = "83", val s84: String = "84",
+ val s85: String = "85", val s86: String = "86", val s87: String = "87", val s88: String = "88",
+ val s89: String = "89", val s90: String = "90", val s91: String = "91", val s92: String = "92",
+ val s93: String = "93", val s94: String = "94", val s95: String = "95", val s96: String = "96",
+ val s97: String = "97", val s98: String = "98", val s99: String = "99", val s100: String = "100")
+extends DefinedByConstructorParams
+
+case class S100_5(
+ s1: S100 = new S100(), s2: S100 = new S100(), s3: S100 = new S100(),
+ s4: S100 = new S100(), s5: S100 = new S100())
+
+
From 0b0c8b95e3594db36d87ef0e59a30eefe8508ac1 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Wed, 17 Aug 2016 07:03:24 -0700
Subject: [PATCH 007/270] [SPARK-17106] [SQL] Simplify the SubqueryExpression
interface
## What changes were proposed in this pull request?
The current subquery expression interface carries a bit of technical debt: there are several different access paths for getting and setting the query contained by the expression, which is confusing to anyone reading this code.
This PR unifies these access paths.
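For orientation, the shape of the unified interface after this patch looks roughly like the standalone sketch below; the stub plan types stand in for ```QueryPlan```, ```LogicalPlan``` and ```SubqueryExec```, and the real definitions are in the diff that follows.
```scala
// Stub plan types; placeholders for Spark's QueryPlan / LogicalPlan / SubqueryExec.
trait QueryPlanStub
class LogicalPlanStub extends QueryPlanStub
class SubqueryExecStub extends QueryPlanStub

// Single access path: every plan-carrying expression exposes `plan` and
// `withNewPlan`, parameterized by the kind of plan it wraps.
trait PlanExpressionSketch[T <: QueryPlanStub] {
  def plan: T
  def withNewPlan(plan: T): PlanExpressionSketch[T]
}

// Logical subqueries (used during analysis and optimization) ...
trait SubqueryExpressionSketch extends PlanExpressionSketch[LogicalPlanStub] {
  override def withNewPlan(plan: LogicalPlanStub): SubqueryExpressionSketch
}

// ... and physical subqueries (used inside SparkPlan) share the same two
// methods, so callers no longer juggle `query`, `plan`, and `executedPlan`.
trait ExecSubqueryExpressionSketch extends PlanExpressionSketch[SubqueryExecStub] {
  override def withNewPlan(plan: SubqueryExecStub): ExecSubqueryExpressionSketch
}
```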
## How was this patch tested?
(Existing tests)
Author: Herman van Hovell
Closes #14685 from hvanhovell/SPARK-17106.
---
.../sql/catalyst/analysis/Analyzer.scala | 4 +-
.../sql/catalyst/expressions/subquery.scala | 60 +++++++++----------
.../sql/catalyst/optimizer/Optimizer.scala | 6 +-
.../spark/sql/catalyst/plans/QueryPlan.scala | 4 +-
.../spark/sql/catalyst/SQLBuilder.scala | 2 +-
.../apache/spark/sql/execution/subquery.scala | 49 ++++++---------
.../org/apache/spark/sql/QueryTest.scala | 4 +-
.../benchmark/TPCDSQueryBenchmark.scala | 1 -
8 files changed, 56 insertions(+), 74 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bd4c19181f647..f540816366ca8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -146,7 +146,7 @@ class Analyzer(
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
other transformExpressions {
case e: SubqueryExpression =>
- e.withNewPlan(substituteCTE(e.query, cteRelations))
+ e.withNewPlan(substituteCTE(e.plan, cteRelations))
}
}
}
@@ -1091,7 +1091,7 @@ class Analyzer(
f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
// Step 1: Resolve the outer expressions.
var previous: LogicalPlan = null
- var current = e.query
+ var current = e.plan
do {
// Try to resolve the subquery plan using the regular analyzer.
previous = current
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index ddbe937cba9bd..e2e7d98e33459 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -17,33 +17,33 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
/**
- * An interface for subquery that is used in expressions.
+ * An interface for expressions that contain a [[QueryPlan]].
*/
-abstract class SubqueryExpression extends Expression {
+abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
/** The id of the subquery expression. */
def exprId: ExprId
- /** The logical plan of the query. */
- def query: LogicalPlan
+ /** The plan being wrapped in the query. */
+ def plan: T
- /**
- * Either a logical plan or a physical plan. The generated tree string (explain output) uses this
- * field to explain the subquery.
- */
- def plan: QueryPlan[_]
-
- /** Updates the query with new logical plan. */
- def withNewPlan(plan: LogicalPlan): SubqueryExpression
+ /** Updates the expression with a new plan. */
+ def withNewPlan(plan: T): PlanExpression[T]
protected def conditionString: String = children.mkString("[", " && ", "]")
}
+/**
+ * A base interface for expressions that contain a [[LogicalPlan]].
+ */
+abstract class SubqueryExpression extends PlanExpression[LogicalPlan] {
+ override def withNewPlan(plan: LogicalPlan): SubqueryExpression
+}
+
object SubqueryExpression {
def hasCorrelatedSubquery(e: Expression): Boolean = {
e.find {
@@ -60,20 +60,19 @@ object SubqueryExpression {
* Note: `exprId` is used to have a unique name in explain string output.
*/
case class ScalarSubquery(
- query: LogicalPlan,
+ plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Unevaluable {
- override lazy val resolved: Boolean = childrenResolved && query.resolved
+ override lazy val resolved: Boolean = childrenResolved && plan.resolved
override lazy val references: AttributeSet = {
- if (query.resolved) super.references -- query.outputSet
+ if (plan.resolved) super.references -- plan.outputSet
else super.references
}
- override def dataType: DataType = query.schema.fields.head.dataType
+ override def dataType: DataType = plan.schema.fields.head.dataType
override def foldable: Boolean = false
override def nullable: Boolean = true
- override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
- override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan)
+ override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
}
@@ -92,19 +91,18 @@ object ScalarSubquery {
* be rewritten into a left semi/anti join during analysis.
*/
case class PredicateSubquery(
- query: LogicalPlan,
+ plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
nullAware: Boolean = false,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Predicate with Unevaluable {
- override lazy val resolved = childrenResolved && query.resolved
- override lazy val references: AttributeSet = super.references -- query.outputSet
+ override lazy val resolved = childrenResolved && plan.resolved
+ override lazy val references: AttributeSet = super.references -- plan.outputSet
override def nullable: Boolean = nullAware
- override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
- override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
+ override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(plan = plan)
override def semanticEquals(o: Expression): Boolean = o match {
case p: PredicateSubquery =>
- query.sameResult(p.query) && nullAware == p.nullAware &&
+ plan.sameResult(p.plan) && nullAware == p.nullAware &&
children.length == p.children.length &&
children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
case _ => false
@@ -146,14 +144,13 @@ object PredicateSubquery {
* FROM b)
* }}}
*/
-case class ListQuery(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
+case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Unevaluable {
override lazy val resolved = false
override def children: Seq[Expression] = Seq.empty
override def dataType: DataType = ArrayType(NullType)
override def nullable: Boolean = false
- override def withNewPlan(plan: LogicalPlan): ListQuery = copy(query = plan)
- override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
+ override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
override def toString: String = s"list#${exprId.id}"
}
@@ -168,12 +165,11 @@ case class ListQuery(query: LogicalPlan, exprId: ExprId = NamedExpression.newExp
* WHERE b.id = a.id)
* }}}
*/
-case class Exists(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
+case class Exists(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Predicate with Unevaluable {
override lazy val resolved = false
override def children: Seq[Expression] = Seq.empty
override def nullable: Boolean = false
- override def withNewPlan(plan: LogicalPlan): Exists = copy(query = plan)
- override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
+ override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
override def toString: String = s"exists#${exprId.id}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f97a78b411597..aa15f4a82383c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -127,7 +127,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
object OptimizeSubqueries extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case s: SubqueryExpression =>
- s.withNewPlan(Optimizer.this.execute(s.query))
+ s.withNewPlan(Optimizer.this.execute(s.plan))
}
}
}
@@ -1814,7 +1814,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
val newExpression = expression transform {
case s: ScalarSubquery if s.children.nonEmpty =>
subqueries += s
- s.query.output.head
+ s.plan.output.head
}
newExpression.asInstanceOf[E]
}
@@ -2029,7 +2029,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
// grouping expressions. As a result we need to replace all the scalar subqueries in the
// grouping expressions by their result.
val newGrouping = grouping.map { e =>
- subqueries.find(_.semanticEquals(e)).map(_.query.output.head).getOrElse(e)
+ subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
}
Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index becf6945a2f2b..8ee31f42ad88e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -263,7 +263,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* All the subqueries of current plan.
*/
def subqueries: Seq[PlanType] = {
- expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]})
+ expressions.flatMap(_.collect {
+ case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType]
+ })
}
override protected def innerChildren: Seq[QueryPlan[_]] = subqueries
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
index ff8e0f2642055..0f51aa58d63ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -80,7 +80,7 @@ class SQLBuilder private (
try {
val replaced = finalPlan.transformAllExpressions {
case s: SubqueryExpression =>
- val query = new SQLBuilder(s.query, nextSubqueryId, nextGenAttrId, exprIdMap).toSQL
+ val query = new SQLBuilder(s.plan, nextSubqueryId, nextGenAttrId, exprIdMap).toSQL
val sql = s match {
case _: ListQuery => query
case _: Exists => s"EXISTS($query)"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index c730bee6ae050..730ca27f82bac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -22,9 +22,8 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
@@ -32,18 +31,7 @@ import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
/**
* The base class for subquery that is used in SparkPlan.
*/
-trait ExecSubqueryExpression extends SubqueryExpression {
-
- val executedPlan: SubqueryExec
- def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression
-
- // does not have logical plan
- override def query: LogicalPlan = throw new UnsupportedOperationException
- override def withNewPlan(plan: LogicalPlan): SubqueryExpression =
- throw new UnsupportedOperationException
-
- override def plan: SparkPlan = executedPlan
-
+abstract class ExecSubqueryExpression extends PlanExpression[SubqueryExec] {
/**
* Fill the expression with collected result from executed plan.
*/
@@ -56,30 +44,29 @@ trait ExecSubqueryExpression extends SubqueryExpression {
* This is the physical copy of ScalarSubquery to be used inside SparkPlan.
*/
case class ScalarSubquery(
- executedPlan: SubqueryExec,
+ plan: SubqueryExec,
exprId: ExprId)
extends ExecSubqueryExpression {
- override def dataType: DataType = executedPlan.schema.fields.head.dataType
+ override def dataType: DataType = plan.schema.fields.head.dataType
override def children: Seq[Expression] = Nil
override def nullable: Boolean = true
- override def toString: String = executedPlan.simpleString
-
- def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
+ override def toString: String = plan.simpleString
+ override def withNewPlan(query: SubqueryExec): ScalarSubquery = copy(plan = query)
override def semanticEquals(other: Expression): Boolean = other match {
- case s: ScalarSubquery => executedPlan.sameResult(executedPlan)
+ case s: ScalarSubquery => plan.sameResult(s.plan)
case _ => false
}
// the first column in first row from `query`.
- @volatile private var result: Any = null
+ @volatile private var result: Any = _
@volatile private var updated: Boolean = false
def updateResult(): Unit = {
val rows = plan.executeCollect()
if (rows.length > 1) {
- sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
+ sys.error(s"more than one row returned by a subquery used as an expression:\n$plan")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
@@ -108,7 +95,7 @@ case class ScalarSubquery(
*/
case class InSubquery(
child: Expression,
- executedPlan: SubqueryExec,
+ plan: SubqueryExec,
exprId: ExprId,
private var result: Array[Any] = null,
private var updated: Boolean = false) extends ExecSubqueryExpression {
@@ -116,13 +103,11 @@ case class InSubquery(
override def dataType: DataType = BooleanType
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = child.nullable
- override def toString: String = s"$child IN ${executedPlan.name}"
-
- def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
+ override def toString: String = s"$child IN ${plan.name}"
+ override def withNewPlan(plan: SubqueryExec): InSubquery = copy(plan = plan)
override def semanticEquals(other: Expression): Boolean = other match {
- case in: InSubquery => child.semanticEquals(in.child) &&
- executedPlan.sameResult(in.executedPlan)
+ case in: InSubquery => child.semanticEquals(in.child) && plan.sameResult(in.plan)
case _ => false
}
@@ -159,8 +144,8 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
ScalarSubquery(
SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan),
subquery.exprId)
- case expressions.PredicateSubquery(plan, Seq(e: Expression), _, exprId) =>
- val executedPlan = new QueryExecution(sparkSession, plan).executedPlan
+ case expressions.PredicateSubquery(query, Seq(e: Expression), _, exprId) =>
+ val executedPlan = new QueryExecution(sparkSession, query).executedPlan
InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
}
}
@@ -184,9 +169,9 @@ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {
val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
val sameResult = sameSchema.find(_.sameResult(sub.plan))
if (sameResult.isDefined) {
- sub.withExecutedPlan(sameResult.get)
+ sub.withNewPlan(sameResult.get)
} else {
- sameSchema += sub.executedPlan
+ sameSchema += sub.plan
sub
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 304881d4a4bdd..cff9d22d089c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -292,7 +292,7 @@ abstract class QueryTest extends PlanTest {
p.expressions.foreach {
_.foreach {
case s: SubqueryExpression =>
- s.query.foreach(collectData)
+ s.plan.foreach(collectData)
case _ =>
}
}
@@ -334,7 +334,7 @@ abstract class QueryTest extends PlanTest {
case p =>
p.transformExpressions {
case s: SubqueryExpression =>
- s.withNewPlan(s.query.transformDown(renormalize))
+ s.withNewPlan(s.plan.transformDown(renormalize))
}
}
val normalized2 = jsonBackPlan.transformDown(renormalize)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
index 957a1d6426e87..3988d9750b585 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Benchmark
/**
From 928ca1c6d12b23d84f9b6205e22d2e756311f072 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 17 Aug 2016 09:31:22 -0700
Subject: [PATCH 008/270] [SPARK-17102][SQL] bypass UserDefinedGenerator for
json format check
## What changes were proposed in this pull request?
We use reflection to convert a `TreeNode` to a JSON string, and arbitrary objects are currently not supported. `UserDefinedGenerator` takes a function object, so we should skip the JSON format check for it; otherwise the tests can be flaky. For example, `DataFrameSuite.simple explode` always fails with Scala 2.10 (branch-1.6 builds with Scala 2.10 by default) but passes with Scala 2.11 (master builds with Scala 2.11 by default).
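As a rough standalone sketch of why such nodes are bypassed (the `jsonValue` helper below is illustrative, not Spark's actual `TreeNode` JSON machinery): a plain value has an obvious JSON form, while a function object does not, so a plan node carrying one is simply skipped by the format check.
```scala
object JsonCheckSketch {
  // Plain values have an obvious JSON form; a function value only has an
  // opaque toString such as "<function1>", so a node carrying one is skipped
  // rather than asserted on.
  def jsonValue(field: Any): Option[String] = field match {
    case s: String          => Some("\"" + s + "\"")
    case n: Int             => Some(n.toString)
    case _: Function1[_, _] => None  // no stable JSON representation
    case other              => Some("\"" + other.toString + "\"")
  }

  def main(args: Array[String]): Unit = {
    println(jsonValue("col"))              // Some("col")
    println(jsonValue(42))                 // Some(42)
    println(jsonValue((x: Int) => x + 1))  // None -> bypass the format check
  }
}
```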
## How was this patch tested?
N/A
Author: Wenchen Fan
Closes #14679 from cloud-fan/json.
---
sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index cff9d22d089c3..484e4380331f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -249,9 +249,10 @@ abstract class QueryTest extends PlanTest {
}
p
}.transformAllExpressions {
- case a: ImperativeAggregate => return
+ case _: ImperativeAggregate => return
case _: TypedAggregateExpression => return
case Literal(_, _: ObjectType) => return
+ case _: UserDefinedGenerator => return
}
// bypass hive tests before we fix all corner cases in hive module.
From e3fec51fa1ed161789ab7aa32ed36efe357b5d31 Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin
Date: Wed, 17 Aug 2016 11:12:21 -0700
Subject: [PATCH 009/270] [SPARK-16930][YARN] Fix a couple of races in cluster
app initialization.
There are two narrow races that could cause the ApplicationMaster to miss the
moment when the user application instantiates the SparkContext, which could
fail the app even though nothing was wrong with it. It was also possible for
a failing application to get stuck for a long time in the loop that waits for
the context, instead of failing quickly.
The change uses a promise to track the SparkContext instance, which gets
rid of the races and allows for some simplification of the code.
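A standalone Scala sketch of the handoff is below; the names and the string standing in for the SparkContext are illustrative rather than the actual ApplicationMaster fields, and the real code waits via `ThreadUtils.awaitResult` bounded by `AM_MAX_WAIT_TIME`.
```scala
import java.util.concurrent.TimeoutException
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._

object SparkContextHandoffSketch {
  // Completed exactly once by the user-class thread; the AM thread blocks on
  // the future with a timeout instead of polling a shared reference.
  private val contextPromise = Promise[Option[String]]()  // stands in for Promise[SparkContext]

  private def runUserClass(): Unit = {
    try {
      val ctx = "SparkContext"              // user code creates the context
      contextPromise.trySuccess(Some(ctx))
    } catch {
      case e: Throwable =>
        contextPromise.tryFailure(e)        // fail fast instead of waiting out the timeout
    } finally {
      contextPromise.trySuccess(None)       // no-op if already completed above
    }
  }

  def main(args: Array[String]): Unit = {
    val userThread = new Thread(new Runnable {
      override def run(): Unit = runUserClass()
    })
    userThread.start()
    try {
      Await.result(contextPromise.future, 10.seconds) match {
        case Some(ctx) => println(s"got $ctx, registering the AM")
        case None      => println("user app finished without creating a context")
      }
    } catch {
      case _: TimeoutException =>
        println("timed out waiting for the context; failing the app")
    }
    userThread.join()
  }
}
```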
Tested with existing unit tests, plus a new one added to exercise the
timeout code.
Author: Marcelo Vanzin
Closes #14542 from vanzin/SPARK-16930.
---
.../spark/deploy/yarn/ApplicationMaster.scala | 98 +++++++++----------
.../cluster/YarnClusterScheduler.scala | 5 -
.../spark/deploy/yarn/YarnClusterSuite.scala | 22 +++++
3 files changed, 66 insertions(+), 59 deletions(-)
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 614278c8b2d22..a4b575c85d5fb 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -20,9 +20,11 @@ package org.apache.spark.deploy.yarn
import java.io.{File, IOException}
import java.lang.reflect.InvocationTargetException
import java.net.{Socket, URI, URL}
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.{TimeoutException, TimeUnit}
import scala.collection.mutable.HashMap
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
@@ -106,12 +108,11 @@ private[spark] class ApplicationMaster(
// Next wait interval before allocator poll.
private var nextAllocationInterval = initialAllocationInterval
- // Fields used in client mode.
private var rpcEnv: RpcEnv = null
private var amEndpoint: RpcEndpointRef = _
- // Fields used in cluster mode.
- private val sparkContextRef = new AtomicReference[SparkContext](null)
+ // In cluster mode, used to tell the AM when the user's SparkContext has been initialized.
+ private val sparkContextPromise = Promise[SparkContext]()
private var credentialRenewer: AMCredentialRenewer = _
@@ -316,23 +317,15 @@ private[spark] class ApplicationMaster(
}
private def sparkContextInitialized(sc: SparkContext) = {
- sparkContextRef.synchronized {
- sparkContextRef.compareAndSet(null, sc)
- sparkContextRef.notifyAll()
- }
- }
-
- private def sparkContextStopped(sc: SparkContext) = {
- sparkContextRef.compareAndSet(sc, null)
+ sparkContextPromise.success(sc)
}
private def registerAM(
+ _sparkConf: SparkConf,
_rpcEnv: RpcEnv,
driverRef: RpcEndpointRef,
uiAddress: String,
securityMgr: SecurityManager) = {
- val sc = sparkContextRef.get()
-
val appId = client.getAttemptId().getApplicationId().toString()
val attemptId = client.getAttemptId().getAttemptId().toString()
val historyAddress =
@@ -341,7 +334,6 @@ private[spark] class ApplicationMaster(
.map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" }
.getOrElse("")
- val _sparkConf = if (sc != null) sc.getConf else sparkConf
val driverUrl = RpcEndpointAddress(
_sparkConf.get("spark.driver.host"),
_sparkConf.get("spark.driver.port").toInt,
@@ -385,21 +377,35 @@ private[spark] class ApplicationMaster(
// This a bit hacky, but we need to wait until the spark.driver.port property has
// been set by the Thread executing the user class.
- val sc = waitForSparkContextInitialized()
-
- // If there is no SparkContext at this point, just fail the app.
- if (sc == null) {
- finish(FinalApplicationStatus.FAILED,
- ApplicationMaster.EXIT_SC_NOT_INITED,
- "Timed out waiting for SparkContext.")
- } else {
- rpcEnv = sc.env.rpcEnv
- val driverRef = runAMEndpoint(
- sc.getConf.get("spark.driver.host"),
- sc.getConf.get("spark.driver.port"),
- isClusterMode = true)
- registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr)
+ logInfo("Waiting for spark context initialization...")
+ val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
+ try {
+ val sc = ThreadUtils.awaitResult(sparkContextPromise.future,
+ Duration(totalWaitTime, TimeUnit.MILLISECONDS))
+ if (sc != null) {
+ rpcEnv = sc.env.rpcEnv
+ val driverRef = runAMEndpoint(
+ sc.getConf.get("spark.driver.host"),
+ sc.getConf.get("spark.driver.port"),
+ isClusterMode = true)
+ registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""),
+ securityMgr)
+ } else {
+ // Sanity check; should never happen in normal operation, since sc should only be null
+ // if the user app did not create a SparkContext.
+ if (!finished) {
+ throw new IllegalStateException("SparkContext is null but app is still running!")
+ }
+ }
userClassThread.join()
+ } catch {
+ case e: SparkException if e.getCause().isInstanceOf[TimeoutException] =>
+ logError(
+ s"SparkContext did not initialize after waiting for $totalWaitTime ms. " +
+ "Please check earlier log output for errors. Failing the application.")
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_SC_NOT_INITED,
+ "Timed out waiting for SparkContext.")
}
}
@@ -409,7 +415,8 @@ private[spark] class ApplicationMaster(
clientMode = true)
val driverRef = waitForSparkDriver()
addAmIpFilter()
- registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
+ registerAM(sparkConf, rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""),
+ securityMgr)
// In client mode the actor will stop the reporter thread.
reporterThread.join()
@@ -525,26 +532,6 @@ private[spark] class ApplicationMaster(
}
}
- private def waitForSparkContextInitialized(): SparkContext = {
- logInfo("Waiting for spark context initialization")
- sparkContextRef.synchronized {
- val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
- val deadline = System.currentTimeMillis() + totalWaitTime
-
- while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) {
- logInfo("Waiting for spark context initialization ... ")
- sparkContextRef.wait(10000L)
- }
-
- val sparkContext = sparkContextRef.get()
- if (sparkContext == null) {
- logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier"
- + " log output for errors. Failing the application.").format(totalWaitTime))
- }
- sparkContext
- }
- }
-
private def waitForSparkDriver(): RpcEndpointRef = {
logInfo("Waiting for Spark driver to be reachable.")
var driverUp = false
@@ -647,6 +634,13 @@ private[spark] class ApplicationMaster(
ApplicationMaster.EXIT_EXCEPTION_USER_CLASS,
"User class threw exception: " + cause)
}
+ sparkContextPromise.tryFailure(e.getCause())
+ } finally {
+ // Notify the thread waiting for the SparkContext, in case the application did not
+ // instantiate one. This will do nothing when the user code instantiates a SparkContext
+ // (with the correct master), or when the user code throws an exception (due to the
+ // tryFailure above).
+ sparkContextPromise.trySuccess(null)
}
}
}
@@ -759,10 +753,6 @@ object ApplicationMaster extends Logging {
master.sparkContextInitialized(sc)
}
- private[spark] def sparkContextStopped(sc: SparkContext): Boolean = {
- master.sparkContextStopped(sc)
- }
-
private[spark] def getAttemptId(): ApplicationAttemptId = {
master.getAttemptId
}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 72ec4d6b34af6..96c9151fc351d 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -34,9 +34,4 @@ private[spark] class YarnClusterScheduler(sc: SparkContext) extends YarnSchedule
logInfo("YarnClusterScheduler.postStartHook done")
}
- override def stop() {
- super.stop()
- ApplicationMaster.sparkContextStopped(sc)
- }
-
}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 8ab7b21c22139..fb7926f6a1e28 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.launcher._
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
@@ -192,6 +193,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
}
}
+ test("timeout to get SparkContext in cluster mode triggers failure") {
+ val timeout = 2000
+ val finalState = runSpark(false, mainClassName(SparkContextTimeoutApp.getClass),
+ appArgs = Seq((timeout * 4).toString),
+ extraConf = Map(AM_MAX_WAIT_TIME.key -> timeout.toString))
+ finalState should be (SparkAppHandle.State.FAILED)
+ }
+
private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = {
val result = File.createTempFile("result", null, tempDir)
val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
@@ -469,3 +478,16 @@ private object YarnLauncherTestApp {
}
}
+
+/**
+ * Used to test code in the AM that detects the SparkContext instance. Expects a single argument
+ * with the duration to sleep for, in ms.
+ */
+private object SparkContextTimeoutApp {
+
+ def main(args: Array[String]): Unit = {
+ val Array(sleepTime) = args
+ Thread.sleep(java.lang.Long.parseLong(sleepTime))
+ }
+
+}
From 4d92af310ad29ade039e4130f91f2a3d9180deef Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Wed, 17 Aug 2016 11:18:33 -0700
Subject: [PATCH 010/270] [SPARK-16446][SPARKR][ML] Gaussian Mixture Model
wrapper in SparkR
## What changes were proposed in this pull request?
Gaussian Mixture Model wrapper in SparkR, similar to R's ```mvnormalmixEM```.
## How was this patch tested?
Unit test.
Author: Yanbo Liang
Closes #14392 from yanboliang/spark-16446.
---
R/pkg/NAMESPACE | 3 +-
R/pkg/R/generics.R | 7 +
R/pkg/R/mllib.R | 139 +++++++++++++++++-
R/pkg/inst/tests/testthat/test_mllib.R | 62 ++++++++
.../spark/ml/r/GaussianMixtureWrapper.scala | 128 ++++++++++++++++
.../org/apache/spark/ml/r/RWrappers.scala | 2 +
6 files changed, 338 insertions(+), 3 deletions(-)
create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 1e23b233c1116..c71eec5ce0437 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -25,7 +25,8 @@ exportMethods("glm",
"fitted",
"spark.naiveBayes",
"spark.survreg",
- "spark.isoreg")
+ "spark.isoreg",
+ "spark.gaussianMixture")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index ebacc11741812..06bb25d62d34d 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1308,6 +1308,13 @@ setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spar
#' @export
setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
+#' @rdname spark.gaussianMixture
+#' @export
+setGeneric("spark.gaussianMixture",
+ function(data, formula, ...) {
+ standardGeneric("spark.gaussianMixture")
+ })
+
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 0dcc54d7af09b..db74046056a99 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -60,6 +60,13 @@ setClass("KMeansModel", representation(jobj = "jobj"))
#' @note IsotonicRegressionModel since 2.1.0
setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
+#' S4 class that represents a GaussianMixtureModel
+#'
+#' @param jobj a Java object reference to the backing Scala GaussianMixtureModel
+#' @export
+#' @note GaussianMixtureModel since 2.1.0
+setClass("GaussianMixtureModel", representation(jobj = "jobj"))
+
#' Saves the MLlib model to the input path
#'
#' Saves the MLlib model to the input path. For more information, see the specific
@@ -67,7 +74,7 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
#' @rdname write.ml
#' @name write.ml
#' @export
-#' @seealso \link{spark.glm}, \link{glm}
+#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.isoreg}
#' @seealso \link{read.ml}
@@ -80,7 +87,7 @@ NULL
#' @rdname predict
#' @name predict
#' @export
-#' @seealso \link{spark.glm}, \link{glm}
+#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.isoreg}
NULL
@@ -649,6 +656,25 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
invisible(callJMethod(writer, "save", path))
})
+# Save fitted MLlib model to the input path
+
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @aliases write.ml,GaussianMixtureModel,character-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note write.ml(GaussianMixtureModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+ })
+
#' Load a fitted MLlib model from the input path.
#'
#' @param path Path of the model to read.
@@ -676,6 +702,8 @@ read.ml <- function(path) {
return(new("KMeansModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
return(new("IsotonicRegressionModel", jobj = jobj))
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
+ return(new("GaussianMixtureModel", jobj = jobj))
} else {
stop(paste("Unsupported model: ", jobj))
}
@@ -757,3 +785,110 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})
+
+#' Multivariate Gaussian Mixture Model (GMM)
+#'
+#' Fits multivariate gaussian mixture model against a Spark DataFrame, similarly to R's
+#' mvnormalmixEM(). Users can call \code{summary} to print a summary of the fitted model,
+#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml}
+#' to save/load fitted models.
+#'
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', '.', ':', '+', and '-'.
+#' Note that the response variable of formula is empty in spark.gaussianMixture.
+#' @param k number of independent Gaussians in the mixture model.
+#' @param maxIter maximum iteration number.
+#' @param tol the convergence tolerance.
+#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method
+#' @return \code{spark.gaussianMixture} returns a fitted multivariate gaussian mixture model.
+#' @rdname spark.gaussianMixture
+#' @name spark.gaussianMixture
+#' @seealso mixtools: \url{https://cran.r-project.org/web/packages/mixtools/}
+#' @export
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' library(mvtnorm)
+#' set.seed(100)
+#' a <- rmvnorm(4, c(0, 0))
+#' b <- rmvnorm(6, c(3, 4))
+#' data <- rbind(a, b)
+#' df <- createDataFrame(as.data.frame(data))
+#' model <- spark.gaussianMixture(df, ~ V1 + V2, k = 2)
+#' summary(model)
+#'
+#' # fitted values on training data
+#' fitted <- predict(model, df)
+#' head(select(fitted, "V1", "prediction"))
+#'
+#' # save fitted model to input path
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#'
+#' # can also read back the saved model and print
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.gaussianMixture since 2.1.0
+#' @seealso \link{predict}, \link{read.ml}, \link{write.ml}
+setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = "formula"),
+ function(data, formula, k = 2, maxIter = 100, tol = 0.01) {
+ formula <- paste(deparse(formula), collapse = "")
+ jobj <- callJStatic("org.apache.spark.ml.r.GaussianMixtureWrapper", "fit", data@sdf,
+ formula, as.integer(k), as.integer(maxIter), as.numeric(tol))
+ return(new("GaussianMixtureModel", jobj = jobj))
+ })
+
+# Get the summary of a multivariate gaussian mixture model
+
+#' @param object a fitted gaussian mixture model.
+#' @param ... currently not used argument(s) passed to the method.
+#' @return \code{summary} returns the model's lambda, mu, sigma and posterior.
+#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note summary(GaussianMixtureModel) since 2.1.0
+setMethod("summary", signature(object = "GaussianMixtureModel"),
+ function(object, ...) {
+ jobj <- object@jobj
+ is.loaded <- callJMethod(jobj, "isLoaded")
+ lambda <- unlist(callJMethod(jobj, "lambda"))
+ muList <- callJMethod(jobj, "mu")
+ sigmaList <- callJMethod(jobj, "sigma")
+ k <- callJMethod(jobj, "k")
+ dim <- callJMethod(jobj, "dim")
+ mu <- c()
+ for (i in 1 : k) {
+ start <- (i - 1) * dim + 1
+ end <- i * dim
+ mu[[i]] <- unlist(muList[start : end])
+ }
+ sigma <- c()
+ for (i in 1 : k) {
+ start <- (i - 1) * dim * dim + 1
+ end <- i * dim * dim
+ sigma[[i]] <- t(matrix(sigmaList[start : end], ncol = dim))
+ }
+ posterior <- if (is.loaded) {
+ NULL
+ } else {
+ dataFrame(callJMethod(jobj, "posterior"))
+ }
+ return(list(lambda = lambda, mu = mu, sigma = sigma,
+ posterior = posterior, is.loaded = is.loaded))
+ })
+
+# Predicted values based on a gaussian mixture model
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#' "prediction".
+#' @aliases predict,GaussianMixtureModel,SparkDataFrame-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note predict(GaussianMixtureModel) since 2.1.0
+setMethod("predict", signature(object = "GaussianMixtureModel"),
+ function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ })
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index b759b28927365..96179864a88bf 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -508,4 +508,66 @@ test_that("spark.isotonicRegression", {
unlink(modelPath)
})
+test_that("spark.gaussianMixture", {
+ # R code to reproduce the result.
+ # nolint start
+ #' library(mvtnorm)
+ #' set.seed(100)
+ #' a <- rmvnorm(4, c(0, 0))
+ #' b <- rmvnorm(6, c(3, 4))
+ #' data <- rbind(a, b)
+ #' model <- mvnormalmixEM(data, k = 2)
+ #' model$lambda
+ #
+ # [1] 0.4 0.6
+ #
+ #' model$mu
+ #
+ # [1] -0.2614822 0.5128697
+ # [1] 2.647284 4.544682
+ #
+ #' model$sigma
+ #
+ # [[1]]
+ # [,1] [,2]
+ # [1,] 0.08427399 0.00548772
+ # [2,] 0.00548772 0.09090715
+ #
+ # [[2]]
+ # [,1] [,2]
+ # [1,] 0.1641373 -0.1673806
+ # [2,] -0.1673806 0.7508951
+ # nolint end
+ data <- list(list(-0.50219235, 0.1315312), list(-0.07891709, 0.8867848),
+ list(0.11697127, 0.3186301), list(-0.58179068, 0.7145327),
+ list(2.17474057, 3.6401379), list(3.08988614, 4.0962745),
+ list(2.79836605, 4.7398405), list(3.12337950, 3.9706833),
+ list(2.61114575, 4.5108563), list(2.08618581, 6.3102968))
+ df <- createDataFrame(data, c("x1", "x2"))
+ model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2)
+ stats <- summary(model)
+ rLambda <- c(0.4, 0.6)
+ rMu <- c(-0.2614822, 0.5128697, 2.647284, 4.544682)
+ rSigma <- c(0.08427399, 0.00548772, 0.00548772, 0.09090715,
+ 0.1641373, -0.1673806, -0.1673806, 0.7508951)
+ expect_equal(stats$lambda, rLambda)
+ expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
+ expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
+ p <- collect(select(predict(model, df), "prediction"))
+ expect_equal(p$prediction, c(0, 0, 0, 0, 1, 1, 1, 1, 1, 1))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$lambda, stats2$lambda)
+ expect_equal(unlist(stats$mu), unlist(stats2$mu))
+ expect_equal(unlist(stats$sigma), unlist(stats2$sigma))
+
+ unlink(modelPath)
+})
+
sparkR.session.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
new file mode 100644
index 0000000000000..1e8b3bbab6655
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.clustering.{GaussianMixture, GaussianMixtureModel}
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+
+private[r] class GaussianMixtureWrapper private (
+ val pipeline: PipelineModel,
+ val dim: Int,
+ val isLoaded: Boolean = false) extends MLWritable {
+
+ private val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel]
+
+ lazy val k: Int = gmm.getK
+
+ lazy val lambda: Array[Double] = gmm.weights
+
+ lazy val mu: Array[Double] = gmm.gaussians.flatMap(_.mean.toArray)
+
+ lazy val sigma: Array[Double] = gmm.gaussians.flatMap(_.cov.toArray)
+
+ lazy val vectorToArray = udf { probability: Vector => probability.toArray }
+ lazy val posterior: DataFrame = gmm.summary.probability
+ .withColumn("posterior", vectorToArray(col(gmm.summary.probabilityCol)))
+ .drop(gmm.summary.probabilityCol)
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset).drop(gmm.getFeaturesCol)
+ }
+
+ override def write: MLWriter = new GaussianMixtureWrapper.GaussianMixtureWrapperWriter(this)
+
+}
+
+private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapper] {
+
+ def fit(
+ data: DataFrame,
+ formula: String,
+ k: Int,
+ maxIter: Int,
+ tol: Double): GaussianMixtureWrapper = {
+
+ val rFormulaModel = new RFormula()
+ .setFormula(formula)
+ .setFeaturesCol("features")
+ .fit(data)
+
+ // get feature names from output schema
+ val schema = rFormulaModel.transform(data).schema
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+ .attributes.get
+ val features = featureAttrs.map(_.name.get)
+ val dim = features.length
+
+ val gm = new GaussianMixture()
+ .setK(k)
+ .setMaxIter(maxIter)
+ .setTol(tol)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, gm))
+ .fit(data)
+
+ new GaussianMixtureWrapper(pipeline, dim)
+ }
+
+ override def read: MLReader[GaussianMixtureWrapper] = new GaussianMixtureWrapperReader
+
+ override def load(path: String): GaussianMixtureWrapper = super.load(path)
+
+ class GaussianMixtureWrapperWriter(instance: GaussianMixtureWrapper) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("dim" -> instance.dim)
+ val rMetadataJson: String = compact(render(rMetadata))
+
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class GaussianMixtureWrapperReader extends MLReader[GaussianMixtureWrapper] {
+
+ override def load(path: String): GaussianMixtureWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+ val pipeline = PipelineModel.load(pipelinePath)
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val dim = (rMetadata \ "dim").extract[Int]
+ new GaussianMixtureWrapper(pipeline, dim, isLoaded = true)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index f9a44d60e691a..88ac26bc5e351 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -46,6 +46,8 @@ private[r] object RWrappers extends MLReader[Object] {
KMeansWrapper.load(path)
case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
IsotonicRegressionWrapper.load(path)
+ case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
+ GaussianMixtureWrapper.load(path)
case _ =>
throw new SparkException(s"SparkR read.ml does not support load $className")
}
From cc97ea188e1d5b8e851d1a8438b8af092783ec04 Mon Sep 17 00:00:00 2001
From: Steve Loughran
Date: Wed, 17 Aug 2016 11:42:57 -0700
Subject: [PATCH 011/270] [SPARK-16736][CORE][SQL] purge superfluous fs calls
This patch is a review of the code, working back from Hadoop's `FileSystem.exists()` and `FileSystem.isDirectory()` implementations, that removes uses of those calls where they are superfluous.
1. `delete()` is harmless if called on a nonexistent path, so don't do any checks before deletes.
2. Any `FileSystem.exists()` check before `getFileStatus()` or `open()` is superfluous, as the operation itself performs the check; instead, the `FileNotFoundException` is caught and triggers the downgraded path. Where a `FileNotFoundException` was thrown before, the code still creates a new FNFE with the same error message, but the inner exception is now nested for easier diagnostics. A minimal sketch of the pattern follows.
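The sketch below is illustrative only (the method names are made up, not lifted from the diffs in this patch); it shows the two patterns on the Hadoop `FileSystem` API: deleting without a prior `exists()` check, and opening a file while catching `FileNotFoundException` and nesting it in the exception that would otherwise have been thrown after an `exists()` check.
```scala
import java.io.FileNotFoundException
import org.apache.hadoop.fs.{FSDataInputStream, FileSystem, Path}

// delete() is a no-op on a nonexistent path, so no exists() guard is needed.
def cleanup(fs: FileSystem, path: Path): Unit = {
  fs.delete(path, true)
}

// open() already checks existence: catch FileNotFoundException and nest it
// inside the exception we would have constructed after an exists() check.
def openLog(fs: FileSystem, log: Path): FSDataInputStream = {
  try {
    fs.open(log)
  } catch {
    case e: FileNotFoundException =>
      throw new FileNotFoundException(s"File $log does not exist.").initCause(e)
  }
}
```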
Initially, this relies on Jenkins test runs.
One trouble spot here is that some of the code paths are clearly error situations; it's not clear that they had coverage anyway. Creating the failure conditions in tests would be ideal, but it will also be hard.
Author: Steve Loughran
Closes #14371 from steveloughran/cloud/SPARK-16736-superfluous-fs-calls.
---
.../scala/org/apache/spark/SparkContext.scala | 3 --
.../deploy/history/FsHistoryProvider.scala | 27 +++++++---------
.../spark/rdd/ReliableCheckpointRDD.scala | 31 ++++++++----------
.../spark/rdd/ReliableRDDCheckpointData.scala | 7 +---
.../scheduler/EventLoggingListener.scala | 13 ++------
.../spark/repl/ExecutorClassLoader.scala | 9 +++---
.../state/HDFSBackedStateStoreProvider.scala | 32 ++++++++++---------
.../hive/JavaMetastoreDataSourcesSuite.java | 4 +--
.../sql/hive/MetastoreDataSourcesSuite.scala | 2 +-
.../apache/spark/streaming/Checkpoint.scala | 17 ++++------
.../util/FileBasedWriteAheadLog.scala | 27 ++++++++++++----
.../spark/streaming/util/HdfsUtils.scala | 24 +++++++-------
.../org/apache/spark/deploy/yarn/Client.scala | 5 ++-
13 files changed, 92 insertions(+), 109 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index a6853fe3989a8..60f042f1e07c5 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1410,9 +1410,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
val scheme = new URI(schemeCorrectedPath).getScheme
if (!Array("http", "https", "ftp").contains(scheme)) {
val fs = hadoopPath.getFileSystem(hadoopConfiguration)
- if (!fs.exists(hadoopPath)) {
- throw new FileNotFoundException(s"Added file $hadoopPath does not exist.")
- }
val isDir = fs.getFileStatus(hadoopPath).isDirectory
if (!isLocal && scheme == "file" && isDir) {
throw new SparkException(s"addFile does not support local directories when not running " +
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index bc09935f93f80..6874aa5f938ac 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -193,16 +193,18 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
private def startPolling(): Unit = {
// Validate the log directory.
val path = new Path(logDir)
- if (!fs.exists(path)) {
- var msg = s"Log directory specified does not exist: $logDir"
- if (logDir == DEFAULT_LOG_DIR) {
- msg += " Did you configure the correct one through spark.history.fs.logDirectory?"
+ try {
+ if (!fs.getFileStatus(path).isDirectory) {
+ throw new IllegalArgumentException(
+ "Logging directory specified is not a directory: %s".format(logDir))
}
- throw new IllegalArgumentException(msg)
- }
- if (!fs.getFileStatus(path).isDirectory) {
- throw new IllegalArgumentException(
- "Logging directory specified is not a directory: %s".format(logDir))
+ } catch {
+ case f: FileNotFoundException =>
+ var msg = s"Log directory specified does not exist: $logDir"
+ if (logDir == DEFAULT_LOG_DIR) {
+ msg += " Did you configure the correct one through spark.history.fs.logDirectory?"
+ }
+ throw new FileNotFoundException(msg).initCause(f)
}
// Disable the background thread during tests.
@@ -495,12 +497,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
val leftToClean = new mutable.ListBuffer[FsApplicationAttemptInfo]
attemptsToClean.foreach { attempt =>
try {
- val path = new Path(logDir, attempt.logPath)
- if (fs.exists(path)) {
- if (!fs.delete(path, true)) {
- logWarning(s"Error deleting ${path}")
- }
- }
+ fs.delete(new Path(logDir, attempt.logPath), true)
} catch {
case e: AccessControlException =>
logInfo(s"No permission to delete ${attempt.logPath}, ignoring.")
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
index fddb9353018a8..ab6554fd8a7e7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -17,7 +17,7 @@
package org.apache.spark.rdd
-import java.io.IOException
+import java.io.{FileNotFoundException, IOException}
import scala.reflect.ClassTag
import scala.util.control.NonFatal
@@ -166,9 +166,6 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val tempOutputPath =
new Path(outputDir, s".$finalOutputName-attempt-${ctx.attemptNumber()}")
- if (fs.exists(tempOutputPath)) {
- throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists")
- }
val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
val fileOutputStream = if (blockSize < 0) {
@@ -240,22 +237,20 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
- if (fs.exists(partitionerFilePath)) {
- val fileInputStream = fs.open(partitionerFilePath, bufferSize)
- val serializer = SparkEnv.get.serializer.newInstance()
- val deserializeStream = serializer.deserializeStream(fileInputStream)
- val partitioner = Utils.tryWithSafeFinally[Partitioner] {
- deserializeStream.readObject[Partitioner]
- } {
- deserializeStream.close()
- }
- logDebug(s"Read partitioner from $partitionerFilePath")
- Some(partitioner)
- } else {
- logDebug("No partitioner file")
- None
+ val fileInputStream = fs.open(partitionerFilePath, bufferSize)
+ val serializer = SparkEnv.get.serializer.newInstance()
+ val deserializeStream = serializer.deserializeStream(fileInputStream)
+ val partitioner = Utils.tryWithSafeFinally[Partitioner] {
+ deserializeStream.readObject[Partitioner]
+ } {
+ deserializeStream.close()
}
+ logDebug(s"Read partitioner from $partitionerFilePath")
+ Some(partitioner)
} catch {
+ case e: FileNotFoundException =>
+ logDebug("No partitioner file", e)
+ None
case NonFatal(e) =>
logWarning(s"Error reading partitioner from $checkpointDirPath, " +
s"partitioner will not be recovered which may lead to performance loss", e)
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
index 74f187642af21..b6d723c682796 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
@@ -80,12 +80,7 @@ private[spark] object ReliableRDDCheckpointData extends Logging {
/** Clean up the files associated with the checkpoint data for this RDD. */
def cleanCheckpoint(sc: SparkContext, rddId: Int): Unit = {
checkpointPath(sc, rddId).foreach { path =>
- val fs = path.getFileSystem(sc.hadoopConfiguration)
- if (fs.exists(path)) {
- if (!fs.delete(path, true)) {
- logWarning(s"Error deleting ${path.toString()}")
- }
- }
+ path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index a7d06391176d2..ce7877469f03f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -91,7 +91,7 @@ private[spark] class EventLoggingListener(
*/
def start() {
if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDirectory) {
- throw new IllegalArgumentException(s"Log directory $logBaseDir does not exist.")
+ throw new IllegalArgumentException(s"Log directory $logBaseDir is not a directory.")
}
val workingPath = logPath + IN_PROGRESS
@@ -100,11 +100,8 @@ private[spark] class EventLoggingListener(
val defaultFs = FileSystem.getDefaultUri(hadoopConf).getScheme
val isDefaultLocal = defaultFs == null || defaultFs == "file"
- if (shouldOverwrite && fileSystem.exists(path)) {
+ if (shouldOverwrite && fileSystem.delete(path, true)) {
logWarning(s"Event log $path already exists. Overwriting...")
- if (!fileSystem.delete(path, true)) {
- logWarning(s"Error deleting $path")
- }
}
/* The Hadoop LocalFileSystem (r1.0.4) has known issues with syncing (HADOOP-7844).
@@ -301,12 +298,6 @@ private[spark] object EventLoggingListener extends Logging {
* @return input stream that holds one JSON record per line.
*/
def openEventLog(log: Path, fs: FileSystem): InputStream = {
- // It's not clear whether FileSystem.open() throws FileNotFoundException or just plain
- // IOException when a file does not exist, so try our best to throw a proper exception.
- if (!fs.exists(log)) {
- throw new FileNotFoundException(s"File $log does not exist.")
- }
-
val in = new BufferedInputStream(fs.open(log))
// Compression codec is encoded as an extension, e.g. app_123.lzf
diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
index 2f07395edf8d1..df13b32451af2 100644
--- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -17,7 +17,7 @@
package org.apache.spark.repl
-import java.io.{ByteArrayOutputStream, FilterInputStream, InputStream, IOException}
+import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream, IOException}
import java.net.{HttpURLConnection, URI, URL, URLEncoder}
import java.nio.channels.Channels
@@ -147,10 +147,11 @@ class ExecutorClassLoader(
private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)(
pathInDirectory: String): InputStream = {
val path = new Path(directory, pathInDirectory)
- if (fileSystem.exists(path)) {
+ try {
fileSystem.open(path)
- } else {
- throw new ClassNotFoundException(s"Class file not found at path $path")
+ } catch {
+ case _: FileNotFoundException =>
+ throw new ClassNotFoundException(s"Class file not found at path $path")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 3335755fd3b67..bec966b15ed0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.streaming.state
-import java.io.{DataInputStream, DataOutputStream, IOException}
+import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException}
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -171,7 +171,7 @@ private[state] class HDFSBackedStateStoreProvider(
if (tempDeltaFileStream != null) {
tempDeltaFileStream.close()
}
- if (tempDeltaFile != null && fs.exists(tempDeltaFile)) {
+ if (tempDeltaFile != null) {
fs.delete(tempDeltaFile, true)
}
logInfo("Aborted")
@@ -278,14 +278,12 @@ private[state] class HDFSBackedStateStoreProvider(
/** Initialize the store provider */
private def initialize(): Unit = {
- if (!fs.exists(baseDir)) {
+ try {
fs.mkdirs(baseDir)
- } else {
- if (!fs.isDirectory(baseDir)) {
+ } catch {
+ case e: IOException =>
throw new IllegalStateException(
- s"Cannot use ${id.checkpointLocation} for storing state data for $this as " +
- s"$baseDir already exists and is not a directory")
- }
+ s"Cannot use ${id.checkpointLocation} for storing state data for $this: $e ", e)
}
}
@@ -340,13 +338,16 @@ private[state] class HDFSBackedStateStoreProvider(
private def updateFromDeltaFile(version: Long, map: MapType): Unit = {
val fileToRead = deltaFile(version)
- if (!fs.exists(fileToRead)) {
- throw new IllegalStateException(
- s"Error reading delta file $fileToRead of $this: $fileToRead does not exist")
- }
var input: DataInputStream = null
+ val sourceStream = try {
+ fs.open(fileToRead)
+ } catch {
+ case f: FileNotFoundException =>
+ throw new IllegalStateException(
+ s"Error reading delta file $fileToRead of $this: $fileToRead does not exist", f)
+ }
try {
- input = decompressStream(fs.open(fileToRead))
+ input = decompressStream(sourceStream)
var eof = false
while(!eof) {
@@ -405,8 +406,6 @@ private[state] class HDFSBackedStateStoreProvider(
private def readSnapshotFile(version: Long): Option[MapType] = {
val fileToRead = snapshotFile(version)
- if (!fs.exists(fileToRead)) return None
-
val map = new MapType()
var input: DataInputStream = null
@@ -443,6 +442,9 @@ private[state] class HDFSBackedStateStoreProvider(
}
logInfo(s"Read snapshot file for version $version of $this from $fileToRead")
Some(map)
+ } catch {
+ case _: FileNotFoundException =>
+ None
} finally {
if (input != null) input.close()
}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
index e73117c8144ce..061c7431a6362 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
@@ -75,9 +75,7 @@ public void setUp() throws IOException {
hiveManagedPath = new Path(
catalog.hiveDefaultTableFilePath(new TableIdentifier("javaSavedTable")));
fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration());
- if (fs.exists(hiveManagedPath)){
- fs.delete(hiveManagedPath, true);
- }
+ fs.delete(hiveManagedPath, true);
List jsonObjects = new ArrayList<>(10);
for (int i = 0; i < 10; i++) {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index c36b0275f4161..3892fe87e2a80 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -375,7 +375,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable"))
val filesystemPath = new Path(expectedPath)
val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf())
- if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true)
+ fs.delete(filesystemPath, true)
// It is a managed table when we do not specify the location.
sql(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 398fa6500f093..5cbad8bf3ce6e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -117,7 +117,7 @@ object Checkpoint extends Logging {
val path = new Path(checkpointDir)
val fs = fsOption.getOrElse(path.getFileSystem(SparkHadoopUtil.get.conf))
- if (fs.exists(path)) {
+ try {
val statuses = fs.listStatus(path)
if (statuses != null) {
val paths = statuses.map(_.getPath)
@@ -127,9 +127,10 @@ object Checkpoint extends Logging {
logWarning(s"Listing $path returned null")
Seq.empty
}
- } else {
- logWarning(s"Checkpoint directory $path does not exist")
- Seq.empty
+ } catch {
+ case _: FileNotFoundException =>
+ logWarning(s"Checkpoint directory $path does not exist")
+ Seq.empty
}
}
@@ -229,9 +230,7 @@ class CheckpointWriter(
logInfo(s"Saving checkpoint for time $checkpointTime to file '$checkpointFile'")
// Write checkpoint to temp file
- if (fs.exists(tempFile)) {
- fs.delete(tempFile, true) // just in case it exists
- }
+ fs.delete(tempFile, true) // just in case it exists
val fos = fs.create(tempFile)
Utils.tryWithSafeFinally {
fos.write(bytes)
@@ -242,9 +241,7 @@ class CheckpointWriter(
// If the checkpoint file exists, back it up
// If the backup exists as well, just delete it, otherwise rename will fail
if (fs.exists(checkpointFile)) {
- if (fs.exists(backupFile)) {
- fs.delete(backupFile, true) // just in case it exists
- }
+ fs.delete(backupFile, true) // just in case it exists
if (!fs.rename(checkpointFile, backupFile)) {
logWarning(s"Could not rename $checkpointFile to $backupFile")
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index 9b689f01b8d39..845f554308c43 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.streaming.util
+import java.io.FileNotFoundException
import java.nio.ByteBuffer
import java.util.{Iterator => JIterator}
import java.util.concurrent.RejectedExecutionException
@@ -231,13 +232,25 @@ private[streaming] class FileBasedWriteAheadLog(
val logDirectoryPath = new Path(logDirectory)
val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf)
- if (fileSystem.exists(logDirectoryPath) &&
- fileSystem.getFileStatus(logDirectoryPath).isDirectory) {
- val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath })
- pastLogs.clear()
- pastLogs ++= logFileInfo
- logInfo(s"Recovered ${logFileInfo.size} write ahead log files from $logDirectory")
- logDebug(s"Recovered files are:\n${logFileInfo.map(_.path).mkString("\n")}")
+ try {
+ // If you call listStatus(file) it returns a stat of the file in the array,
+ // rather than an array listing all the children.
+ // This makes it hard to differentiate listStatus(file) and
+ // listStatus(dir-with-one-child) except by examining the name of the returned status,
+ // and once you've got symlinks in the mix that differentiation isn't easy.
+ // Checking for the path being a directory is one more call to the filesystem, but
+ // leads to much clearer code.
+ if (fileSystem.getFileStatus(logDirectoryPath).isDirectory) {
+ val logFileInfo = logFilesTologInfo(
+ fileSystem.listStatus(logDirectoryPath).map { _.getPath })
+ pastLogs.clear()
+ pastLogs ++= logFileInfo
+ logInfo(s"Recovered ${logFileInfo.size} write ahead log files from $logDirectory")
+ logDebug(s"Recovered files are:\n${logFileInfo.map(_.path).mkString("\n")}")
+ }
+ } catch {
+ case _: FileNotFoundException =>
+ // there is no log directory, hence nothing to recover
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
index 13a765d035ee8..6a3b3200dccdb 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.streaming.util
-import java.io.IOException
+import java.io.{FileNotFoundException, IOException}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs._
@@ -44,18 +44,16 @@ private[streaming] object HdfsUtils {
def getInputStream(path: String, conf: Configuration): FSDataInputStream = {
val dfsPath = new Path(path)
val dfs = getFileSystemForPath(dfsPath, conf)
- if (dfs.isFile(dfsPath)) {
- try {
- dfs.open(dfsPath)
- } catch {
- case e: IOException =>
- // If we are really unlucky, the file may be deleted as we're opening the stream.
- // This can happen as clean up is performed by daemon threads that may be left over from
- // previous runs.
- if (!dfs.isFile(dfsPath)) null else throw e
- }
- } else {
- null
+ try {
+ dfs.open(dfsPath)
+ } catch {
+ case _: FileNotFoundException =>
+ null
+ case e: IOException =>
+ // If we are really unlucky, the file may be deleted as we're opening the stream.
+ // This can happen as clean up is performed by daemon threads that may be left over from
+ // previous runs.
+ if (!dfs.isFile(dfsPath)) null else throw e
}
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index e3572d781b0db..93684005f1cc0 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -189,9 +189,8 @@ private[spark] class Client(
try {
val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES)
val fs = stagingDirPath.getFileSystem(hadoopConf)
- if (!preserveFiles && fs.exists(stagingDirPath)) {
- logInfo("Deleting staging directory " + stagingDirPath)
- fs.delete(stagingDirPath, true)
+ if (!preserveFiles && fs.delete(stagingDirPath, true)) {
+ logInfo(s"Deleted staging directory $stagingDirPath")
}
} catch {
case ioe: IOException =>
From d60af8f6aa53373de1333cc642cf2a9d7b39d912 Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Wed, 17 Aug 2016 13:31:34 -0700
Subject: [PATCH 012/270] [SPARK-17096][SQL][STREAMING] Improve exception
string reported through the StreamingQueryListener
## What changes were proposed in this pull request?
Currently, the stackTrace (an `Array[StackTraceElement]`) reported through StreamingQueryListener.onQueryTerminated is useless, as it holds the stack trace of where StreamingQueryException is defined, not the stack trace of the underlying exception. For example, if a streaming query fails because of a / by zero exception in a task, `QueryTerminated.stackTrace` will have
```
org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runBatches(StreamExecution.scala:211)
org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.run(StreamExecution.scala:124)
```
This is basically useless, as it is just the location where the StreamingQueryException was defined; what we want is the stack trace of the underlying exception that actually failed the query.
Here is the right way to reason about what should be posted through StreamingQueryListener.onQueryTerminated:
- The actual exception could either be a SparkException, or an arbitrary exception.
- SparkException reports the relevant executor stack trace of a failed task as a string in the exception message. The `Array[StackTraceElement]` returned by `SparkException.stackTrace()` is mostly irrelevant.
- For any arbitrary exception, the `Array[StackTraceElement]` returned by `exception.stackTrace()` may be relevant.
- When there is an error in a streaming query, it's hard to reason about whether the `Array[StackTraceElement]` is useful or not. In fact, it is not clear whether it is even useful to report the stack trace as an array of Java objects. It may be sufficient to report the stack trace as a string, along with the message; this is how Spark already reports executor stack traces inside SparkException messages.
- Hence, this PR simplifies the API by removing the array `stackTrace` from `QueryTerminated`. Instead, `exception` returns a string containing the message and the stack trace of the actual underlying exception that failed the streaming query (i.e. not that of the StreamingQueryException). Anyone interested in the actual stack trace as an array can still access it through `streamingQuery.exception`, which returns the exception object. A sketch of this string conversion follows.
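As a rough sketch, "report the exception as a string" amounts to the message plus the full stack trace of the underlying cause. The patch uses Spark's internal `Utils.exceptionString` helper for this; the standalone version below, built on plain Java APIs, is only an illustration of the same idea.
```scala
import java.io.{PrintWriter, StringWriter}

// Render a throwable as "message + full stack trace", the shape of the string
// now carried by QueryTerminated.exception.
def exceptionAsString(t: Throwable): String = {
  val writer = new StringWriter()
  t.printStackTrace(new PrintWriter(writer))
  writer.toString
}
```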
With this change, if a streaming query fails because of a / by zero exception in a task, the `QueryTerminated.exception` will be
```
org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 0.0 failed 1 times, most recent failure: Lost task 1.0 in stage 0.0 (TID 1, localhost): java.lang.ArithmeticException: / by zero
at org.apache.spark.sql.streaming.StreamingQueryListenerSuite$$anonfun$5$$anonfun$apply$mcV$sp$4$$anonfun$apply$mcV$sp$5.apply$mcII$sp(StreamingQueryListenerSuite.scala:153)
at org.apache.spark.sql.streaming.StreamingQueryListenerSuite$$anonfun$5$$anonfun$apply$mcV$sp$4$$anonfun$apply$mcV$sp$5.apply(StreamingQueryListenerSuite.scala:153)
at org.apache.spark.sql.streaming.StreamingQueryListenerSuite$$anonfun$5$$anonfun$apply$mcV$sp$4$$anonfun$apply$mcV$sp$5.apply(StreamingQueryListenerSuite.scala:153)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:232)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:226)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:283)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)
at org.apache.spark.scheduler.Task.run(Task.scala:86)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
at java.lang.Thread.run(Thread.java:744)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1429)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1417)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1416)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1416)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
...
```
It contains the relevant executor stack trace. In the non-SparkException case, if the streaming source MemoryStream throws an exception, the exception message will have the relevant stack trace.
```
java.lang.RuntimeException: this is the exception message
at org.apache.spark.sql.execution.streaming.MemoryStream.getBatch(memory.scala:103)
at org.apache.spark.sql.execution.streaming.StreamExecution$$anonfun$5.apply(StreamExecution.scala:316)
at org.apache.spark.sql.execution.streaming.StreamExecution$$anonfun$5.apply(StreamExecution.scala:313)
at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
at scala.collection.Iterator$class.foreach(Iterator.scala:893)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1336)
at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
at org.apache.spark.sql.execution.streaming.StreamProgress.foreach(StreamProgress.scala:25)
at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241)
at org.apache.spark.sql.execution.streaming.StreamProgress.flatMap(StreamProgress.scala:25)
at org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runBatch(StreamExecution.scala:313)
at org.apache.spark.sql.execution.streaming.StreamExecution$$anonfun$org$apache$spark$sql$execution$streaming$StreamExecution$$runBatches$1.apply$mcZ$sp(StreamExecution.scala:197)
at org.apache.spark.sql.execution.streaming.ProcessingTimeExecutor.execute(TriggerExecutor.scala:43)
at org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runBatches(StreamExecution.scala:187)
at org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.run(StreamExecution.scala:124)
```
Note that this change in the public `QueryTerminated` class is okay as the APIs are still experimental.
## How was this patch tested?
Unit tests that verify whether the right information is present in the exception message reported through the QueryTerminated object.
Author: Tathagata Das
Closes #14675 from tdas/SPARK-17096.
---
.../sql/execution/streaming/StreamExecution.scala | 5 +----
.../sql/streaming/StreamingQueryException.scala | 3 ++-
.../sql/streaming/StreamingQueryListener.scala | 3 +--
.../sql/streaming/StreamingQueryListenerSuite.scala | 13 ++++++-------
4 files changed, 10 insertions(+), 14 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 66fb5a4bdeb7f..4d05af0b60358 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -217,10 +217,7 @@ class StreamExecution(
} finally {
state = TERMINATED
sparkSession.streams.notifyQueryTermination(StreamExecution.this)
- postEvent(new QueryTerminated(
- this.toInfo,
- exception.map(_.getMessage),
- exception.map(_.getStackTrace.toSeq).getOrElse(Nil)))
+ postEvent(new QueryTerminated(this.toInfo, exception.map(_.cause).map(Utils.exceptionString)))
terminationLatch.countDown()
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
index 90f95ca9d4229..bd3e5a5618ec4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution}
/**
* :: Experimental ::
- * Exception that stopped a [[StreamingQuery]].
+ * Exception that stopped a [[StreamingQuery]]. Use `cause` get the actual exception
+ * that caused the failure.
* @param query Query that caused the exception
* @param message Message of this exception
* @param cause Internal cause of this exception
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
index 3b3cead3a66de..db606abb8ce43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
@@ -108,6 +108,5 @@ object StreamingQueryListener {
@Experimental
class QueryTerminated private[sql](
val queryInfo: StreamingQueryInfo,
- val exception: Option[String],
- val stackTrace: Seq[StackTraceElement]) extends Event
+ val exception: Option[String]) extends Event
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
index 7f4d28cf0598f..77602e8167fa3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
@@ -94,7 +94,6 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
assert(status.id === query.id)
assert(status.sourceStatuses(0).offsetDesc === Some(LongOffset(0).toString))
assert(status.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(0)).toString)
- assert(listener.terminationStackTrace.isEmpty)
assert(listener.terminationException === None)
}
listener.checkAsyncErrors()
@@ -147,7 +146,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
}
}
- test("exception should be reported in QueryTerminated") {
+ testQuietly("exception should be reported in QueryTerminated") {
val listener = new QueryStatusCollector
withListenerAdded(listener) {
val input = MemoryStream[Int]
@@ -159,8 +158,11 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
spark.sparkContext.listenerBus.waitUntilEmpty(10000)
assert(listener.terminationStatus !== null)
assert(listener.terminationException.isDefined)
+ // Make sure that the exception message reported through listener
+ // contains the actual exception and relevant stack trace
+ assert(!listener.terminationException.get.contains("StreamingQueryException"))
assert(listener.terminationException.get.contains("java.lang.ArithmeticException"))
- assert(listener.terminationStackTrace.nonEmpty)
+ assert(listener.terminationException.get.contains("StreamingQueryListenerSuite"))
}
)
}
@@ -205,8 +207,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
val exception = new RuntimeException("exception")
val queryQueryTerminated = new StreamingQueryListener.QueryTerminated(
queryTerminatedInfo,
- Some(exception.getMessage),
- exception.getStackTrace)
+ Some(exception.getMessage))
val json =
JsonProtocol.sparkEventToJson(queryQueryTerminated)
val newQueryTerminated = JsonProtocol.sparkEventFromJson(json)
@@ -262,7 +263,6 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
@volatile var startStatus: StreamingQueryInfo = null
@volatile var terminationStatus: StreamingQueryInfo = null
@volatile var terminationException: Option[String] = null
- @volatile var terminationStackTrace: Seq[StackTraceElement] = null
val progressStatuses = new ConcurrentLinkedQueue[StreamingQueryInfo]
@@ -296,7 +296,6 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
assert(startStatus != null, "onQueryTerminated called before onQueryStarted")
terminationStatus = queryTerminated.queryInfo
terminationException = queryTerminated.exception
- terminationStackTrace = queryTerminated.stackTrace
}
asyncTestWaiter.dismiss()
}
From e6bef7d52f0e19ec771fb0f3e96c7ddbd1a6a19b Mon Sep 17 00:00:00 2001
From: Xin Ren
Date: Wed, 17 Aug 2016 16:31:42 -0700
Subject: [PATCH 013/270] [SPARK-17038][STREAMING] fix metrics retrieval source
of 'lastReceivedBatch'
https://issues.apache.org/jira/browse/SPARK-17038
## What changes were proposed in this pull request?
StreamingSource's lastReceivedBatch_submissionTime, lastReceivedBatch_processingTimeStart, and lastReceivedBatch_processingTimeEnd all use data from lastCompletedBatch instead of lastReceivedBatch.
In particular, this makes it impossible to match lastReceivedBatch_records with a batchID/submission time.
This is apparent when looking at StreamingSource.scala, lines 89-94.
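For context, these gauges go through a helper that reads an `Option` from the listener and falls back to a default (-1) when the value is absent. The snippet below is only a sketch of that mechanism, not the Spark code; the helper name merely mirrors the one used in StreamingSource.
```scala
import com.codahale.metrics.{Gauge, MetricRegistry}

// Register a gauge that reads an Option-valued metric and reports a default
// (e.g. -1L) while no batch information is available yet.
def registerGaugeWithOption[T](
    registry: MetricRegistry,
    name: String,
    value: () => Option[T],
    default: T): Unit = {
  registry.register(name, new Gauge[T] {
    override def getValue: T = value().getOrElse(default)
  })
}
```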
## How was this patch tested?
Manually running unit tests on a local laptop.
Author: Xin Ren
Closes #14681 from keypointt/SPARK-17038.
---
.../scala/org/apache/spark/streaming/StreamingSource.scala | 6 +++---
.../streaming/ui/StreamingJobProgressListenerSuite.scala | 3 +++
2 files changed, 6 insertions(+), 3 deletions(-)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala
index 9697437dd2fe5..0b306a28d1a59 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala
@@ -87,11 +87,11 @@ private[streaming] class StreamingSource(ssc: StreamingContext) extends Source {
// Gauge for last received batch, useful for monitoring the streaming job's running status,
// displayed data -1 for any abnormal condition.
registerGaugeWithOption("lastReceivedBatch_submissionTime",
- _.lastCompletedBatch.map(_.submissionTime), -1L)
+ _.lastReceivedBatch.map(_.submissionTime), -1L)
registerGaugeWithOption("lastReceivedBatch_processingStartTime",
- _.lastCompletedBatch.flatMap(_.processingStartTime), -1L)
+ _.lastReceivedBatch.flatMap(_.processingStartTime), -1L)
registerGaugeWithOption("lastReceivedBatch_processingEndTime",
- _.lastCompletedBatch.flatMap(_.processingEndTime), -1L)
+ _.lastReceivedBatch.flatMap(_.processingEndTime), -1L)
// Gauge for last received batch records.
registerGauge("lastReceivedBatch_records", _.lastReceivedBatchRecords.values.sum, 0L)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
index 26b757cc2d535..46ab3ac8de3d4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
@@ -68,6 +68,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted)))
listener.runningBatches should be (Nil)
listener.retainedCompletedBatches should be (Nil)
+ listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoSubmitted)))
listener.lastCompletedBatch should be (None)
listener.numUnprocessedBatches should be (1)
listener.numTotalCompletedBatches should be (0)
@@ -81,6 +82,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
listener.waitingBatches should be (Nil)
listener.runningBatches should be (List(BatchUIData(batchInfoStarted)))
listener.retainedCompletedBatches should be (Nil)
+ listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoStarted)))
listener.lastCompletedBatch should be (None)
listener.numUnprocessedBatches should be (1)
listener.numTotalCompletedBatches should be (0)
@@ -123,6 +125,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
listener.waitingBatches should be (Nil)
listener.runningBatches should be (Nil)
listener.retainedCompletedBatches should be (List(BatchUIData(batchInfoCompleted)))
+ listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoCompleted)))
listener.lastCompletedBatch should be (Some(BatchUIData(batchInfoCompleted)))
listener.numUnprocessedBatches should be (0)
listener.numTotalCompletedBatches should be (1)
From 10204b9d29cd69895f5a606e75510dc64cf2e009 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 18 Aug 2016 13:24:12 +0800
Subject: [PATCH 014/270] [SPARK-16995][SQL] TreeNodeException when flat
mapping RelationalGroupedDataset created from DataFrame containing a column
created with lit/expr
## What changes were proposed in this pull request?
A TreeNodeException is thrown when executing the following minimal example in Spark 2.0.
```
import spark.implicits._

case class test(x: Int, q: Int)

val d = Seq(1).toDF("x")
d.withColumn("q", lit(0)).as[test].groupByKey(_.x).flatMapGroups { case (x, iter) => List[Int]() }.show
d.withColumn("q", expr("0")).as[test].groupByKey(_.x).flatMapGroups { case (x, iter) => List[Int]() }.show
```
The problem is in `FoldablePropagation`. The rule will do `transformExpressions` on the `LogicalPlan`. The query above contains a `MapGroups` which has a parameter `dataAttributes: Seq[Attribute]`. One attribute in `dataAttributes` will be transformed to an `Alias(literal(0), _)` in `FoldablePropagation`. `Alias` is not an `Attribute` and causes the error.
We can't easily detect such type inconsistency while transforming expressions. A direct approach to this problem is to skip `FoldablePropagation` on object operators, as they should not contain such expressions.
## How was this patch tested?
Jenkins tests.
Author: Liang-Chi Hsieh
Closes #14648 from viirya/flat-mapping.
---
.../spark/sql/catalyst/optimizer/Optimizer.scala | 13 +++++++++++++
.../scala/org/apache/spark/sql/DatasetSuite.scala | 13 +++++++++++++
2 files changed, 26 insertions(+)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index aa15f4a82383c..b53c0b5beccf2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -727,6 +727,19 @@ object FoldablePropagation extends Rule[LogicalPlan] {
case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) =>
stop = true
j
+
+ // These 3 operators take attributes as constructor parameters, and these attributes
+ // can't be replaced by alias.
+ case m: MapGroups =>
+ stop = true
+ m
+ case f: FlatMapGroupsInR =>
+ stop = true
+ f
+ case c: CoGroup =>
+ stop = true
+ c
+
case p: LogicalPlan if !stop => p.transformExpressions {
case a: AttributeReference if foldableMap.contains(a) =>
foldableMap(a)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 88fb1472b668b..8ce6ea66b6bbf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -878,6 +878,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds = spark.createDataset(data)(enc)
checkDataset(ds, (("a", "b"), "c"), (null, "d"))
}
+
+ test("SPARK-16995: flat mapping on Dataset containing a column created with lit/expr") {
+ val df = Seq("1").toDF("a")
+
+ import df.sparkSession.implicits._
+
+ checkDataset(
+ df.withColumn("b", lit(0)).as[ClassData]
+ .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
+ checkDataset(
+ df.withColumn("b", expr("0")).as[ClassData]
+ .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
+ }
}
case class Generic[T](id: T, value: Double)
From 3e6ef2e8a435a91b6a76876e9833917e5aa0945e Mon Sep 17 00:00:00 2001
From: petermaxlee
Date: Thu, 18 Aug 2016 16:17:01 +0800
Subject: [PATCH 015/270] [SPARK-17034][SQL] Minor code cleanup for
UnresolvedOrdinal
## What changes were proposed in this pull request?
I was looking at the code for UnresolvedOrdinal and made a few small changes to make it slightly clearer:
1. Renamed the rule to SubstituteUnresolvedOrdinals, which is more consistent with other rules that start with verbs. Note that this is still inconsistent with CTESubstitution and WindowsSubstitution.
2. Broke the test suite down from a single test case into three test cases.
## How was this patch tested?
This is a minor cleanup.
Author: petermaxlee
Closes #14672 from petermaxlee/SPARK-17034.
---
.../sql/catalyst/analysis/Analyzer.scala | 2 +-
...ala => SubstituteUnresolvedOrdinals.scala} | 26 ++++++++++---------
.../sql/catalyst/planning/patterns.scala | 13 ----------
...> SubstituteUnresolvedOrdinalsSuite.scala} | 24 +++++++++--------
4 files changed, 28 insertions(+), 37 deletions(-)
rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/{UnresolvedOrdinalSubstitution.scala => SubstituteUnresolvedOrdinals.scala} (69%)
rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/{UnresolvedOrdinalSubstitutionSuite.scala => SubstituteUnresolvedOrdinalsSuite.scala} (76%)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f540816366ca8..cfab6ae7bd02b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -84,7 +84,7 @@ class Analyzer(
CTESubstitution,
WindowsSubstitution,
EliminateUnions,
- new UnresolvedOrdinalSubstitution(conf)),
+ new SubstituteUnresolvedOrdinals(conf)),
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
similarity index 69%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
index e21cd08af8b0d..6d8dc8628229a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
@@ -18,32 +18,34 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.CatalystConf
-import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
-import org.apache.spark.sql.catalyst.planning.IntegerIndex
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
+import org.apache.spark.sql.types.IntegerType
/**
* Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression.
*/
-class UnresolvedOrdinalSubstitution(conf: CatalystConf) extends Rule[LogicalPlan] {
- private def isIntegerLiteral(sorter: Expression) = IntegerIndex.unapply(sorter).nonEmpty
+class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] {
+ private def isIntLiteral(e: Expression) = e match {
+ case Literal(_, IntegerType) => true
+ case _ => false
+ }
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case s @ Sort(orders, global, child) if conf.orderByOrdinal &&
- orders.exists(o => isIntegerLiteral(o.child)) =>
- val newOrders = orders.map {
- case order @ SortOrder(ordinal @ IntegerIndex(index: Int), _) =>
+ case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
+ val newOrders = s.order.map {
+ case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _) =>
val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
withOrigin(order.origin)(order.copy(child = newOrdinal))
case other => other
}
withOrigin(s.origin)(s.copy(order = newOrders))
- case a @ Aggregate(groups, aggs, child) if conf.groupByOrdinal &&
- groups.exists(isIntegerLiteral(_)) =>
- val newGroups = groups.map {
- case ordinal @ IntegerIndex(index) =>
+
+ case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
+ val newGroups = a.groupingExpressions.map {
+ case ordinal @ Literal(index: Int, IntegerType) =>
withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
case other => other
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index f42e67ca6ec20..476c66af76b29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -208,19 +208,6 @@ object Unions {
}
}
-/**
- * Extractor for retrieving Int value.
- */
-object IntegerIndex {
- def unapply(a: Any): Option[Int] = a match {
- case Literal(a: Int, IntegerType) => Some(a)
- // When resolving ordinal in Sort and Group By, negative values are extracted
- // for issuing error messages.
- case UnaryMinus(IntegerLiteral(v)) => Some(-v)
- case _ => None
- }
-}
-
/**
* An extractor used when planning the physical execution of an aggregation. Compared with a logical
* aggregation, the following transformations are performed:
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala
similarity index 76%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala
index 23995e96e1d2b..3c429ebce1a8d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala
@@ -23,20 +23,21 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.SimpleCatalystConf
-class UnresolvedOrdinalSubstitutionSuite extends AnalysisTest {
-
- test("test rule UnresolvedOrdinalSubstitution, replaces ordinal in order by or group by") {
- val a = testRelation2.output(0)
- val b = testRelation2.output(1)
- val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true)
+class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest {
+ private lazy val conf = SimpleCatalystConf(caseSensitiveAnalysis = true)
+ private lazy val a = testRelation2.output(0)
+ private lazy val b = testRelation2.output(1)
+ test("unresolved ordinal should not be unresolved") {
// Expression OrderByOrdinal is unresolved.
assert(!UnresolvedOrdinal(0).resolved)
+ }
+ test("order by ordinal") {
// Tests order by ordinal, apply single rule.
val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc)
comparePlans(
- new UnresolvedOrdinalSubstitution(conf).apply(plan),
+ new SubstituteUnresolvedOrdinals(conf).apply(plan),
testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc))
// Tests order by ordinal, do full analysis
@@ -44,14 +45,15 @@ class UnresolvedOrdinalSubstitutionSuite extends AnalysisTest {
// order by ordinal can be turned off by config
comparePlans(
- new UnresolvedOrdinalSubstitution(conf.copy(orderByOrdinal = false)).apply(plan),
+ new SubstituteUnresolvedOrdinals(conf.copy(orderByOrdinal = false)).apply(plan),
testRelation2.orderBy(Literal(1).asc, Literal(2).asc))
+ }
-
+ test("group by ordinal") {
// Tests group by ordinal, apply single rule.
val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)
comparePlans(
- new UnresolvedOrdinalSubstitution(conf).apply(plan2),
+ new SubstituteUnresolvedOrdinals(conf).apply(plan2),
testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b))
// Tests group by ordinal, do full analysis
@@ -59,7 +61,7 @@ class UnresolvedOrdinalSubstitutionSuite extends AnalysisTest {
// group by ordinal can be turned off by config
comparePlans(
- new UnresolvedOrdinalSubstitution(conf.copy(groupByOrdinal = false)).apply(plan2),
+ new SubstituteUnresolvedOrdinals(conf.copy(groupByOrdinal = false)).apply(plan2),
testRelation2.groupBy(Literal(1), Literal(2))('a, 'b))
}
}
From 1748f824101870b845dbbd118763c6885744f98a Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Thu, 18 Aug 2016 16:37:25 +0800
Subject: [PATCH 016/270] [SPARK-16391][SQL] Support partial aggregation for
reduceGroups
## What changes were proposed in this pull request?
This patch introduces a new private ReduceAggregator class, a subclass of Aggregator that only requires a single associative and commutative reduce function. ReduceAggregator is also used to implement KeyValueGroupedDataset.reduceGroups in order to support partial aggregation.
Note that the pull request was initially done by viirya.
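A minimal usage sketch (assuming a SparkSession named `spark`; the data is made up): the reduce below is now planned through ReduceAggregator, so Spark can partially aggregate on the map side instead of shipping whole groups to the reducers.
```scala
import spark.implicits._

val ds = Seq(("a", 1), ("a", 2), ("b", 3)).toDS()

// reduceGroups keeps the same API but now compiles to an Aggregator-based
// aggregate, enabling map-side partial aggregation.
val reduced = ds
  .groupByKey(_._1)
  .reduceGroups((x, y) => (x._1, x._2 + y._2)) // Dataset[(String, (String, Int))]

reduced.show()
```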
## How was this patch tested?
Covered by original tests for reduceGroups, as well as a new test suite for ReduceAggregator.
Author: Reynold Xin
Author: Liang-Chi Hsieh
Closes #14576 from rxin/reduceAggregator.
---
.../spark/sql/KeyValueGroupedDataset.scala | 10 +--
.../sql/expressions/ReduceAggregator.scala | 68 +++++++++++++++++
.../expressions/ReduceAggregatorSuite.scala | 73 +++++++++++++++++++
3 files changed, 146 insertions(+), 5 deletions(-)
create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 65a725f3d4a81..61a3e6e0bc4f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -21,10 +21,11 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.expressions.ReduceAggregator
/**
* :: Experimental ::
@@ -177,10 +178,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @since 1.6.0
*/
def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
- val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
-
- implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc)
- flatMapGroups(func)
+ val vEncoder = encoderFor[V]
+ val aggregator: TypedColumn[V, V] = new ReduceAggregator[V](f)(vEncoder).toColumn
+ agg(aggregator)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
new file mode 100644
index 0000000000000..174378304d4a5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+
+/**
+ * An aggregator that uses a single associative and commutative reduce function. This reduce
+ * function can be used to go through all input values and reduces them to a single value.
+ * If there is no input, a null value is returned.
+ *
+ * This class currently assumes there is at least one input row.
+ */
+private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
+ extends Aggregator[T, (Boolean, T), T] {
+
+ private val encoder = implicitly[Encoder[T]]
+
+ override def zero: (Boolean, T) = (false, null.asInstanceOf[T])
+
+ override def bufferEncoder: Encoder[(Boolean, T)] =
+ ExpressionEncoder.tuple(
+ ExpressionEncoder[Boolean](),
+ encoder.asInstanceOf[ExpressionEncoder[T]])
+
+ override def outputEncoder: Encoder[T] = encoder
+
+ override def reduce(b: (Boolean, T), a: T): (Boolean, T) = {
+ if (b._1) {
+ (true, func(b._2, a))
+ } else {
+ (true, a)
+ }
+ }
+
+ override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) = {
+ if (!b1._1) {
+ b2
+ } else if (!b2._1) {
+ b1
+ } else {
+ (true, func(b1._2, b2._2))
+ }
+ }
+
+ override def finish(reduction: (Boolean, T)): T = {
+ if (!reduction._1) {
+ throw new IllegalStateException("ReduceAggregator requires at least one input row")
+ }
+ reduction._2
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
new file mode 100644
index 0000000000000..d826d3f54d922
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+
+class ReduceAggregatorSuite extends SparkFunSuite {
+
+ test("zero value") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+ assert(aggregator.zero == (false, null))
+ }
+
+ test("reduce, merge and finish") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+
+ val firstReduce = aggregator.reduce(aggregator.zero, 1)
+ assert(firstReduce == (true, 1))
+
+ val secondReduce = aggregator.reduce(firstReduce, 2)
+ assert(secondReduce == (true, 3))
+
+ val thirdReduce = aggregator.reduce(secondReduce, 3)
+ assert(thirdReduce == (true, 6))
+
+ val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce)
+ assert(mergeWithZero1 == (true, 1))
+
+ val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero)
+ assert(mergeWithZero2 == (true, 3))
+
+ val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce)
+ assert(mergeTwoReduced == (true, 4))
+
+    assert(aggregator.finish(firstReduce) == 1)
+ assert(aggregator.finish(secondReduce) == 3)
+ assert(aggregator.finish(thirdReduce) == 6)
+ assert(aggregator.finish(mergeWithZero1) == 1)
+ assert(aggregator.finish(mergeWithZero2) == 3)
+ assert(aggregator.finish(mergeTwoReduced) == 4)
+ }
+
+ test("requires at least one input row") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+
+ intercept[IllegalStateException] {
+ aggregator.finish(aggregator.zero)
+ }
+ }
+}
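
As a brief illustration of what the patch above changes (not part of the patch itself): `Dataset.reduceGroups` is now expressed through the new `ReduceAggregator` rather than `flatMapGroups`. A minimal usage sketch, assuming an existing `SparkSession` named `spark` with its implicits imported:
```
// Sketch only: `spark` and `import spark.implicits._` are assumed to be in scope.
val ds = Seq("apple", "avocado", "banana", "blueberry").toDS()

// Keep the longest word per initial letter; reduceGroups is now backed by
// ReduceAggregator, so the reduce runs as a typed aggregation.
val longestPerInitial = ds
  .groupByKey(_.substring(0, 1))
  .reduceGroups((a, b) => if (a.length >= b.length) a else b)

longestPerInitial.show()  // e.g. ("a", "avocado"), ("b", "blueberry")
```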
From e82dbe600e0d36d76cd5607a77c3243a26777b77 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 18 Aug 2016 12:45:56 +0200
Subject: [PATCH 017/270] [SPARK-17107][SQL] Remove redundant pushdown rule for
Union
## What changes were proposed in this pull request?
The `Optimizer` rules `PushThroughSetOperations` and `PushDownPredicate` both contain logic to push a `Filter` down through a `Union`; the copy in `PushThroughSetOperations` is redundant and should be removed.
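For reference, a rough sketch of the rewrite that was duplicated, written with Catalyst's test DSL (the relations here are made up for illustration):
```
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Union}

// Two relations with the same schema, combined with UNION ALL.
val left = LocalRelation('a.int, 'b.int)
val right = LocalRelation('a.int, 'b.int)

// A Filter sitting above the Union ...
val plan = Union(left :: right :: Nil).where('a > 1)

// ... is rewritten into per-child Filters, i.e. roughly
// Union(left.where('a > 1), right.where('a > 1)).
// After this patch only PushDownPredicate performs that rewrite; the renamed
// PushProjectionThroughUnion handles Project only.
```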
## How was this patch tested?
Jenkins tests.
Author: Liang-Chi Hsieh
Closes #14687 from viirya/remove-extra-pushdown.
---
.../sql/catalyst/optimizer/Optimizer.scala | 21 +++++--------------
.../optimizer/SetOperationSuite.scala | 3 ++-
2 files changed, 7 insertions(+), 17 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b53c0b5beccf2..f7aa6da0a5bdc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -75,7 +75,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
RemoveRepetitionFromGroupExpressions) ::
Batch("Operator Optimizations", fixedPoint,
// Operator push down
- PushThroughSetOperations,
+ PushProjectionThroughUnion,
ReorderJoin,
EliminateOuterJoin,
PushPredicateThroughJoin,
@@ -302,14 +302,14 @@ object LimitPushDown extends Rule[LogicalPlan] {
}
/**
- * Pushes certain operations to both sides of a Union operator.
+ * Pushes the Project operator to both sides of a Union operator.
* Operations that are safe to pushdown are listed as follows.
* Union:
* Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is
- * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT,
- * we will not be able to pushdown Projections.
+ * safe to push down Filters and Projections through it. Filter pushdown is handled by the
+ * PushDownPredicate rule. Once we add UNION DISTINCT, we will not be able to push down Projections.
*/
-object PushThroughSetOperations extends Rule[LogicalPlan] with PredicateHelper {
+object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper {
/**
* Maps Attributes from the left side to the corresponding Attribute on the right side.
@@ -364,17 +364,6 @@ object PushThroughSetOperations extends Rule[LogicalPlan] with PredicateHelper {
} else {
p
}
-
- // Push down filter into union
- case Filter(condition, Union(children)) =>
- assert(children.nonEmpty)
- val (deterministic, nondeterministic) = partitionByDeterministic(condition)
- val newFirstChild = Filter(deterministic, children.head)
- val newOtherChildren = children.tail.map { child =>
- val rewrites = buildRewrites(children.head, child)
- Filter(pushToRight(deterministic, rewrites), child)
- }
- Filter(nondeterministic, Union(newFirstChild +: newOtherChildren))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index dab45a6b166be..7227706ab2b36 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -31,7 +31,8 @@ class SetOperationSuite extends PlanTest {
EliminateSubqueryAliases) ::
Batch("Union Pushdown", Once,
CombineUnions,
- PushThroughSetOperations,
+ PushProjectionThroughUnion,
+ PushDownPredicate,
PruneFilters) :: Nil
}
From b81421afb04959bb22b53653be0a09c1f1c5845f Mon Sep 17 00:00:00 2001
From: Stavros Kontopoulos
Date: Thu, 18 Aug 2016 12:19:19 +0100
Subject: [PATCH 018/270] [SPARK-17087][MESOS] Documentation for Making Spark
on Mesos honor port restrictions
## What changes were proposed in this pull request?
- adds documentation for https://issues.apache.org/jira/browse/SPARK-11714
## How was this patch tested?
Documentation only; no tests needed.
Author: Stavros Kontopoulos
Closes #14667 from skonto/add_doc.
---
docs/running-on-mesos.md | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index a6ce34c761c82..173961deaadcb 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -207,6 +207,16 @@ The scheduler will start executors round-robin on the offers Mesos
gives it, but there are no spread guarantees, as Mesos does not
provide such guarantees on the offer stream.
+In this mode Spark executors will honor port allocation if it is
+provided by the user. Specifically, if the user defines
+`spark.executor.port` or `spark.blockManager.port` in the Spark configuration,
+the Mesos scheduler checks the available offers for a valid port
+range containing the requested port numbers. If no such range is available,
+it does not launch any task. If the user imposes no restriction on port
+numbers, ephemeral ports are used as usual. This port-honoring implementation
+implies one task per host when the user defines a port. Network isolation
+is expected to be supported in the future.
+
The benefit of coarse-grained mode is much lower startup overhead, but
at the cost of reserving Mesos resources for the complete duration of
the application. To configure your job to dynamically adjust to its
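
To make the documented behavior above concrete, here is a hedged configuration sketch (not part of the patch); the port values are hypothetical and must be free on the Mesos agents:
```
import org.apache.spark.SparkConf

// Pinning these ports makes the Mesos scheduler accept only offers whose
// port ranges contain them; otherwise no task is launched on that offer.
val conf = new SparkConf()
  .setAppName("port-restricted-app")
  .set("spark.executor.port", "40000")      // hypothetical value
  .set("spark.blockManager.port", "40010")  // hypothetical value
```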
From 412dba63b511474a6db3c43c8618d803e604bc6b Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Thu, 18 Aug 2016 13:33:55 +0200
Subject: [PATCH 019/270] [SPARK-17069] Expose spark.range() as table-valued
function in SQL
## What changes were proposed in this pull request?
This adds analyzer rules for resolving table-valued functions, and adds one builtin implementation for range(). The arguments for range() are the same as those of `spark.range()`.
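For illustration, the resulting SQL surface looks like the following (a minimal sketch assuming an existing `SparkSession` named `spark`); these mirror the queries in the new test file:
```
// Sketch only: `spark` is assumed to be an existing SparkSession.
spark.sql("SELECT * FROM range(5)")             // range(end)
spark.sql("SELECT * FROM range(5, 10)")         // range(start, end)
spark.sql("SELECT * FROM range(0, 10, 2)")      // range(start, end, step)
spark.sql("SELECT * FROM range(0, 10, 1, 200)") // range(start, end, step, numPartitions)
```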
## How was this patch tested?
Unit tests.
cc hvanhovell
Author: Eric Liang
Closes #14656 from ericl/sc-4309.
---
.../spark/sql/catalyst/parser/SqlBase.g4 | 1 +
.../sql/catalyst/analysis/Analyzer.scala | 1 +
.../ResolveTableValuedFunctions.scala | 132 ++++++++++++++++++
.../sql/catalyst/analysis/unresolved.scala | 11 ++
.../sql/catalyst/parser/AstBuilder.scala | 8 ++
.../sql/catalyst/parser/PlanParserSuite.scala | 8 +-
.../inputs/table-valued-functions.sql | 20 +++
.../results/table-valued-functions.sql.out | 87 ++++++++++++
8 files changed, 267 insertions(+), 1 deletion(-)
create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
create mode 100644 sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
create mode 100644 sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 6122bcdef8f07..cab7c3ff5a8f7 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -433,6 +433,7 @@ relationPrimary
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
| inlineTable #inlineTableDefault2
+ | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction
;
inlineTable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cfab6ae7bd02b..333dd4d9a4f2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -86,6 +86,7 @@ class Analyzer(
EliminateUnions,
new SubstituteUnresolvedOrdinals(conf)),
Batch("Resolution", fixedPoint,
+ ResolveTableValuedFunctions ::
ResolveRelations ::
ResolveReferences ::
ResolveDeserializer ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
new file mode 100644
index 0000000000000..7fdf7fa0c06a3
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
+
+/**
+ * Rule that resolves table-valued function references.
+ */
+object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
+ private lazy val defaultParallelism =
+ SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism
+
+ /**
+ * List of argument names and their types, used to declare a function.
+ */
+ private case class ArgumentList(args: (String, DataType)*) {
+ /**
+     * Try to cast the expressions to satisfy the expected types of this argument list. If any
+     * type cannot be cast, None is returned.
+ */
+ def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = {
+ if (args.length == values.length) {
+ val casted = values.zip(args).map { case (value, (_, expectedType)) =>
+ TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType)
+ }
+ if (casted.forall(_.isDefined)) {
+ return Some(casted.map(_.get))
+ }
+ }
+ None
+ }
+
+ override def toString: String = {
+ args.map { a =>
+ s"${a._1}: ${a._2.typeName}"
+ }.mkString(", ")
+ }
+ }
+
+ /**
+ * A TVF maps argument lists to resolver functions that accept those arguments. Using a map
+ * here allows for function overloading.
+ */
+ private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan]
+
+ /**
+ * TVF builder.
+ */
+ private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan])
+ : (ArgumentList, Seq[Any] => LogicalPlan) = {
+ (ArgumentList(args: _*),
+ pf orElse {
+ case args =>
+ throw new IllegalArgumentException(
+ "Invalid arguments for resolved function: " + args.mkString(", "))
+ })
+ }
+
+ /**
+ * Internal registry of table-valued functions.
+ */
+ private val builtinFunctions: Map[String, TVF] = Map(
+ "range" -> Map(
+ /* range(end) */
+ tvf("end" -> LongType) { case Seq(end: Long) =>
+ Range(0, end, 1, defaultParallelism)
+ },
+
+ /* range(start, end) */
+ tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) =>
+ Range(start, end, 1, defaultParallelism)
+ },
+
+ /* range(start, end, step) */
+ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) {
+ case Seq(start: Long, end: Long, step: Long) =>
+ Range(start, end, step, defaultParallelism)
+ },
+
+ /* range(start, end, step, numPartitions) */
+ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
+ "numPartitions" -> IntegerType) {
+ case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
+ Range(start, end, step, numPartitions)
+ })
+ )
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
+ builtinFunctions.get(u.functionName) match {
+ case Some(tvf) =>
+ val resolved = tvf.flatMap { case (argList, resolver) =>
+ argList.implicitCast(u.functionArgs) match {
+ case Some(casted) =>
+ Some(resolver(casted.map(_.eval())))
+ case _ =>
+ None
+ }
+ }
+ resolved.headOption.getOrElse {
+ val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ")
+ u.failAnalysis(
+ s"""error: table-valued function ${u.functionName} with alternatives:
+ |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")}
+ |cannot be applied to: (${argTypes})""".stripMargin)
+ }
+ case _ =>
+ u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 42e7aae0b6b05..3735a1501cbfa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -49,6 +49,17 @@ case class UnresolvedRelation(
override lazy val resolved = false
}
+/**
+ * Holds a table-valued function call that has yet to be resolved.
+ */
+case class UnresolvedTableValuedFunction(
+ functionName: String, functionArgs: Seq[Expression]) extends LeafNode {
+
+ override def output: Seq[Attribute] = Nil
+
+ override lazy val resolved = false
+}
+
/**
* Holds the name of an attribute that has yet to be resolved.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index adf78396d7fc0..01322ae327e4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -657,6 +657,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
table.optionalMap(ctx.sample)(withSample)
}
+ /**
+ * Create a table-valued function call with arguments, e.g. range(1000)
+ */
+ override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
+ : LogicalPlan = withOrigin(ctx) {
+ UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression))
+ }
+
/**
* Create an inline table (a virtual table in Hive parlance).
*/
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 7af333b34f723..cbe4a022e730d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedTableValuedFunction}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -426,6 +426,12 @@ class PlanParserSuite extends PlanTest {
assertEqual("table d.t", table("d", "t"))
}
+ test("table valued function") {
+ assertEqual(
+ "select * from range(2)",
+ UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star()))
+ }
+
test("inline table") {
assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
Seq('col1.int),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
new file mode 100644
index 0000000000000..2e6dcd538b7ac
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
@@ -0,0 +1,20 @@
+-- unresolved function
+select * from dummy(3);
+
+-- range call with end
+select * from range(6 + cos(3));
+
+-- range call with start and end
+select * from range(5, 10);
+
+-- range call with step
+select * from range(0, 10, 2);
+
+-- range call with numPartitions
+select * from range(0, 10, 1, 200);
+
+-- range call error
+select * from range(1, 1, 1, 1, 1);
+
+-- range call with null
+select * from range(1, null);
diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
new file mode 100644
index 0000000000000..d769bcef0aca7
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
@@ -0,0 +1,87 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 7
+
+
+-- !query 0
+select * from dummy(3)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+org.apache.spark.sql.AnalysisException
+could not resolve `dummy` to a table-valued function; line 1 pos 14
+
+
+-- !query 1
+select * from range(6 + cos(3))
+-- !query 1 schema
+struct<id:bigint>
+-- !query 1 output
+0
+1
+2
+3
+4
+
+
+-- !query 2
+select * from range(5, 10)
+-- !query 2 schema
+struct<id:bigint>
+-- !query 2 output
+5
+6
+7
+8
+9
+
+
+-- !query 3
+select * from range(0, 10, 2)
+-- !query 3 schema
+struct<id:bigint>
+-- !query 3 output
+0
+2
+4
+6
+8
+
+
+-- !query 4
+select * from range(0, 10, 1, 200)
+-- !query 4 schema
+struct<id:bigint>
+-- !query 4 output
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+
+
+-- !query 5
+select * from range(1, 1, 1, 1, 1)
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+error: table-valued function range with alternatives:
+ (end: long)
+ (start: long, end: long)
+ (start: long, end: long, step: long)
+ (start: long, end: long, step: long, numPartitions: integer)
+cannot be applied to: (integer, integer, integer, integer, integer); line 1 pos 14
+
+
+-- !query 6
+select * from range(1, null)
+-- !query 6 schema
+struct<>
+-- !query 6 output
+java.lang.IllegalArgumentException
+Invalid arguments for resolved function: 1, null
From 68f5087d2107d6afec5d5745f0cb0e9e3bdd6a0b Mon Sep 17 00:00:00 2001
From: petermaxlee
Date: Thu, 18 Aug 2016 13:44:13 +0200
Subject: [PATCH 020/270] [SPARK-17117][SQL] 1 / NULL should not fail analysis
## What changes were proposed in this pull request?
This patch fixes the problem described in SPARK-17117, i.e. "SELECT 1 / NULL" throws an analysis exception:
```
org.apache.spark.sql.AnalysisException: cannot resolve '(1 / NULL)' due to data type mismatch: differing types in '(1 / NULL)' (int and null).
```
The problem is that division type coercion did not take null type into account.
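After the fix, the null operand is coerced to double just like a numeric one, so the query analyzes and evaluates to NULL. A small sketch, assuming an existing `SparkSession` named `spark`:
```
// Sketch only: `spark` is assumed to be an existing SparkSession.
// Both operands are cast to double, so the expression resolves and yields NULL.
spark.sql("SELECT 1 / NULL").collect()   // Array(Row(null)), column type double
spark.sql("SELECT NULL / 5").collect()   // Array(Row(null)), column type double
```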
## How was this patch tested?
A unit test for the type coercion, and a few end-to-end test cases using SQLQueryTestSuite.
Author: petermaxlee
Closes #14695 from petermaxlee/SPARK-17117.
---
.../sql/catalyst/analysis/TypeCoercion.scala | 7 +-
.../catalyst/analysis/TypeCoercionSuite.scala | 9 +-
.../resources/sql-tests/inputs/arithmetic.sql | 12 ++-
.../sql-tests/results/arithmetic.sql.out | 84 +++++++++++++++----
4 files changed, 89 insertions(+), 23 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 021952e7166f9..21e96aaf53844 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -543,11 +543,14 @@ object TypeCoercion {
// Decimal and Double remain the same
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
- case Divide(left, right) if isNumeric(left) && isNumeric(right) =>
+ case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) =>
Divide(Cast(left, DoubleType), Cast(right, DoubleType))
}
- private def isNumeric(ex: Expression): Boolean = ex.dataType.isInstanceOf[NumericType]
+ private def isNumericOrNull(ex: Expression): Boolean = {
+ // We need to handle null types in case a query contains null literals.
+ ex.dataType.isInstanceOf[NumericType] || ex.dataType == NullType
+ }
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index a13c45fe2ffee..9560563a8ca56 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import java.sql.Timestamp
-import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{Division, FunctionArgumentConversion}
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -730,6 +730,13 @@ class TypeCoercionSuite extends PlanTest {
// the right expression to Decimal.
ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3)))
}
+
+ test("SPARK-17117 null type coercion in divide") {
+ val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts)
+ val nullLit = Literal.create(null, NullType)
+ ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
+ ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))
+ }
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql
index cbe40410cdc10..f62b10ca0037b 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql
@@ -16,11 +16,19 @@ select + + 100;
select - - max(key) from testdata;
select + - key from testdata where key = 33;
+-- div
+select 5 / 2;
+select 5 / 0;
+select 5 / null;
+select null / 5;
+select 5 div 2;
+select 5 div 0;
+select 5 div null;
+select null div 5;
+
-- other arithmetics
select 1 + 2;
select 1 - 2;
select 2 * 5;
-select 5 / 2;
-select 5 div 2;
select 5 % 3;
select pmod(-7, 3);
diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out
index f2b40a00d062d..6abe048af477d 100644
--- a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 22
+-- Number of queries: 28
-- !query 0
@@ -123,35 +123,35 @@ struct<(- key):int>
-- !query 15
-select 1 + 2
+select 5 / 2
-- !query 15 schema
-struct<(1 + 2):int>
+struct<(CAST(5 AS DOUBLE) / CAST(2 AS DOUBLE)):double>
-- !query 15 output
-3
+2.5
-- !query 16
-select 1 - 2
+select 5 / 0
-- !query 16 schema
-struct<(1 - 2):int>
+struct<(CAST(5 AS DOUBLE) / CAST(0 AS DOUBLE)):double>
-- !query 16 output
--1
+NULL
-- !query 17
-select 2 * 5
+select 5 / null
-- !query 17 schema
-struct<(2 * 5):int>
+struct<(CAST(5 AS DOUBLE) / CAST(NULL AS DOUBLE)):double>
-- !query 17 output
-10
+NULL
-- !query 18
-select 5 / 2
+select null / 5
-- !query 18 schema
-struct<(CAST(5 AS DOUBLE) / CAST(2 AS DOUBLE)):double>
+struct<(CAST(NULL AS DOUBLE) / CAST(5 AS DOUBLE)):double>
-- !query 18 output
-2.5
+NULL
-- !query 19
@@ -163,16 +163,64 @@ struct
-- !query 20
-select 5 % 3
+select 5 div 0
-- !query 20 schema
-struct<(5 % 3):int>
+struct
-- !query 20 output
-2
+NULL
-- !query 21
-select pmod(-7, 3)
+select 5 div null
-- !query 21 schema
-struct
+struct
-- !query 21 output
+NULL
+
+
+-- !query 22
+select null div 5
+-- !query 22 schema
+struct
+-- !query 22 output
+NULL
+
+
+-- !query 23
+select 1 + 2
+-- !query 23 schema
+struct<(1 + 2):int>
+-- !query 23 output
+3
+
+
+-- !query 24
+select 1 - 2
+-- !query 24 schema
+struct<(1 - 2):int>
+-- !query 24 output
+-1
+
+
+-- !query 25
+select 2 * 5
+-- !query 25 schema
+struct<(2 * 5):int>
+-- !query 25 output
+10
+
+
+-- !query 26
+select 5 % 3
+-- !query 26 schema
+struct<(5 % 3):int>
+-- !query 26 output
+2
+
+
+-- !query 27
+select pmod(-7, 3)
+-- !query 27 schema
+struct
+-- !query 27 output
2
From b72bb62d421840f82d663c6b8e3922bd14383fbb Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Thu, 18 Aug 2016 05:33:52 -0700
Subject: [PATCH 021/270] [SPARK-16447][ML][SPARKR] LDA wrapper in SparkR
## What changes were proposed in this pull request?
Add LDA Wrapper in SparkR with the following interfaces:
- spark.lda(data, ...)
- spark.posterior(object, newData, ...)
- spark.perplexity(object, ...)
- summary(object)
- write.ml(object)
- read.ml(path)
## How was this patch tested?
Tested with SparkR unit tests.
Author: Xusen Yin
Closes #14229 from yinxusen/SPARK-16447.
---
R/pkg/NAMESPACE | 3 +
R/pkg/R/generics.R | 14 ++
R/pkg/R/mllib.R | 166 +++++++++++++-
R/pkg/inst/tests/testthat/test_mllib.R | 87 +++++++
.../org/apache/spark/ml/clustering/LDA.scala | 4 +
.../org/apache/spark/ml/r/LDAWrapper.scala | 216 ++++++++++++++++++
.../org/apache/spark/ml/r/RWrappers.scala | 2 +
7 files changed, 490 insertions(+), 2 deletions(-)
create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index c71eec5ce0437..4404cffc292aa 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -25,6 +25,9 @@ exportMethods("glm",
"fitted",
"spark.naiveBayes",
"spark.survreg",
+ "spark.lda",
+ "spark.posterior",
+ "spark.perplexity",
"spark.isoreg",
"spark.gaussianMixture")
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 06bb25d62d34d..fe04bcfc7d14d 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1304,6 +1304,19 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s
#' @export
setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
+#' @rdname spark.lda
+#' @param ... Additional parameters to tune LDA.
+#' @export
+setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
+
+#' @rdname spark.lda
+#' @export
+setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") })
+
+#' @rdname spark.lda
+#' @export
+setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") })
+
#' @rdname spark.isoreg
#' @export
setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
@@ -1315,6 +1328,7 @@ setGeneric("spark.gaussianMixture",
standardGeneric("spark.gaussianMixture")
})
+#' write.ml
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index db74046056a99..b9527410a9853 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -39,6 +39,13 @@ setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj"))
#' @note NaiveBayesModel since 2.0.0
setClass("NaiveBayesModel", representation(jobj = "jobj"))
+#' S4 class that represents an LDAModel
+#'
+#' @param jobj a Java object reference to the backing Scala LDAWrapper
+#' @export
+#' @note LDAModel since 2.1.0
+setClass("LDAModel", representation(jobj = "jobj"))
+
#' S4 class that represents a AFTSurvivalRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper
@@ -75,7 +82,7 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj"))
#' @name write.ml
#' @export
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.lda}
#' @seealso \link{spark.isoreg}
#' @seealso \link{read.ml}
NULL
@@ -315,6 +322,94 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
return(list(apriori = apriori, tables = tables))
})
+# Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda()
+
+#' @param newData A SparkDataFrame for testing
+#' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities
+#' vectors named "topicDistribution"
+#' @rdname spark.lda
+#' @aliases spark.posterior,LDAModel,SparkDataFrame-method
+#' @export
+#' @note spark.posterior(LDAModel) since 2.1.0
+setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"),
+ function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ })
+
+# Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda}
+
+#' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}.
+#' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10.
+#' @return \code{summary} returns a list containing
+#' \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for
+#' the prior placed on documents distributions over topics \code{theta}}
+#' \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or
+#' \code{eta} for the prior placed on topic distributions over terms}
+#' \item{\code{logLikelihood}}{log likelihood of the entire corpus}
+#' \item{\code{logPerplexity}}{log perplexity}
+#' \item{\code{isDistributed}}{TRUE for distributed model while FALSE for local model}
+#' \item{\code{vocabSize}}{number of terms in the corpus}
+#' \item{\code{topics}}{top 10 terms and their weights of all topics}
+#' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file
+#' used as training set}
+#' @rdname spark.lda
+#' @aliases summary,LDAModel-method
+#' @export
+#' @note summary(LDAModel) since 2.1.0
+setMethod("summary", signature(object = "LDAModel"),
+ function(object, maxTermsPerTopic) {
+ maxTermsPerTopic <- as.integer(ifelse(missing(maxTermsPerTopic), 10, maxTermsPerTopic))
+ jobj <- object@jobj
+ docConcentration <- callJMethod(jobj, "docConcentration")
+ topicConcentration <- callJMethod(jobj, "topicConcentration")
+ logLikelihood <- callJMethod(jobj, "logLikelihood")
+ logPerplexity <- callJMethod(jobj, "logPerplexity")
+ isDistributed <- callJMethod(jobj, "isDistributed")
+ vocabSize <- callJMethod(jobj, "vocabSize")
+ topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic))
+ vocabulary <- callJMethod(jobj, "vocabulary")
+ return(list(docConcentration = unlist(docConcentration),
+ topicConcentration = topicConcentration,
+ logLikelihood = logLikelihood, logPerplexity = logPerplexity,
+ isDistributed = isDistributed, vocabSize = vocabSize,
+ topics = topics,
+ vocabulary = unlist(vocabulary)))
+ })
+
+# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda}
+
+#' @return \code{spark.perplexity} returns the log perplexity of the given SparkDataFrame, or the
+#' log perplexity of the training data if the argument "data" is missing.
+#' @rdname spark.lda
+#' @aliases spark.perplexity,LDAModel-method
+#' @export
+#' @note spark.perplexity(LDAModel) since 2.1.0
+setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"),
+ function(object, data) {
+ return(ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"),
+ callJMethod(object@jobj, "computeLogPerplexity", data@sdf)))
+ })
+
+# Saves the Latent Dirichlet Allocation model to the input path.
+
+#' @param path The directory where the model is saved
+#' @param overwrite Whether to overwrite the output path if it already exists. Default is FALSE,
+#' which means an exception is thrown if the output path exists.
+#'
+#' @rdname spark.lda
+#' @aliases write.ml,LDAModel,character-method
+#' @export
+#' @seealso \link{read.ml}
+#' @note write.ml(LDAModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "LDAModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+ })
+
#' Isotonic Regression Model
#'
#' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg().
@@ -700,6 +795,8 @@ read.ml <- function(path) {
return(new("GeneralizedLinearRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
return(new("KMeansModel", jobj = jobj))
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) {
+ return(new("LDAModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
return(new("IsotonicRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
@@ -751,6 +848,71 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula
return(new("AFTSurvivalRegressionModel", jobj = jobj))
})
+#' Latent Dirichlet Allocation
+#'
+#' \code{spark.lda} fits a Latent Dirichlet Allocation model on a SparkDataFrame. Users can call
+#' \code{summary} to get a summary of the fitted LDA model, \code{spark.posterior} to compute
+#' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new
+#' data and \code{write.ml}/\code{read.ml} to save/load fitted models.
+#'
+#' @param data A SparkDataFrame for training
+#' @param features Features column name, default "features". Either libSVM-format column or
+#' character-format column is valid.
+#' @param k Number of topics, default 10
+#' @param maxIter Maximum iterations, default 20
+#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online"
+#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in
+#' each iteration of mini-batch gradient descent, in range (0, 1], default 0.05
+#' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for
+#' the prior placed on topic distributions over terms, default -1 to set automatically on the
+#' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size
+#' numeric is accepted.
+#' @param docConcentration concentration parameter (commonly named \code{alpha}) for the
+#' prior placed on documents distributions over topics (\code{theta}), default -1 to set
+#' automatically on the Spark side. Use \code{summary} to retrieve the effective
+#' docConcentration. Only 1-size or \code{k}-size numeric is accepted.
+#' @param customizedStopWords stopwords that need to be removed from the given corpus. Ignore the
+#' parameter if libSVM-format column is used as the features column.
+#' @param maxVocabSize maximum vocabulary size, default 1 << 18
+#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model
+#' @rdname spark.lda
+#' @aliases spark.lda,SparkDataFrame-method
+#' @seealso topicmodels: \url{https://cran.r-project.org/web/packages/topicmodels/}
+#' @export
+#' @examples
+#' \dontrun{
+#' text <- read.df("path/to/data", source = "libsvm")
+#' model <- spark.lda(data = text, optimizer = "em")
+#'
+#' # get a summary of the model
+#' summary(model)
+#'
+#' # compute posterior probabilities
+#' posterior <- spark.posterior(model, df)
+#' showDF(posterior)
+#'
+#' # compute perplexity
+#' perplexity <- spark.perplexity(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.lda since 2.1.0
+setMethod("spark.lda", signature(data = "SparkDataFrame"),
+ function(data, features = "features", k = 10, maxIter = 20, optimizer = c("online", "em"),
+ subsamplingRate = 0.05, topicConcentration = -1, docConcentration = -1,
+ customizedStopWords = "", maxVocabSize = bitwShiftL(1, 18)) {
+ optimizer <- match.arg(optimizer)
+ jobj <- callJStatic("org.apache.spark.ml.r.LDAWrapper", "fit", data@sdf, features,
+ as.integer(k), as.integer(maxIter), optimizer,
+ as.numeric(subsamplingRate), topicConcentration,
+ as.array(docConcentration), as.array(customizedStopWords),
+ maxVocabSize)
+ return(new("LDAModel", jobj = jobj))
+ })
# Returns a summary of the AFT survival regression model produced by spark.survreg,
# similarly to R's summary().
@@ -891,4 +1053,4 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
setMethod("predict", signature(object = "GaussianMixtureModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
- })
+ })
\ No newline at end of file
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 96179864a88bf..8c380fbf150f4 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -570,4 +570,91 @@ test_that("spark.gaussianMixture", {
unlink(modelPath)
})
+test_that("spark.lda with libsvm", {
+ text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
+ model <- spark.lda(text, optimizer = "em")
+
+ stats <- summary(model, 10)
+ isDistributed <- stats$isDistributed
+ logLikelihood <- stats$logLikelihood
+ logPerplexity <- stats$logPerplexity
+ vocabSize <- stats$vocabSize
+ topics <- stats$topicTopTerms
+ weights <- stats$topicTopTermsWeights
+ vocabulary <- stats$vocabulary
+
+ expect_false(isDistributed)
+ expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+ expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+ expect_equal(vocabSize, 11)
+ expect_true(is.null(vocabulary))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+
+ expect_false(stats2$isDistributed)
+ expect_equal(logLikelihood, stats2$logLikelihood)
+ expect_equal(logPerplexity, stats2$logPerplexity)
+ expect_equal(vocabSize, stats2$vocabSize)
+ expect_equal(vocabulary, stats2$vocabulary)
+
+ unlink(modelPath)
+})
+
+test_that("spark.lda with text input", {
+ text <- read.text("data/mllib/sample_lda_data.txt")
+ model <- spark.lda(text, optimizer = "online", features = "value")
+
+ stats <- summary(model)
+ isDistributed <- stats$isDistributed
+ logLikelihood <- stats$logLikelihood
+ logPerplexity <- stats$logPerplexity
+ vocabSize <- stats$vocabSize
+ topics <- stats$topicTopTerms
+ weights <- stats$topicTopTermsWeights
+ vocabulary <- stats$vocabulary
+
+ expect_false(isDistributed)
+ expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+ expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+ expect_equal(vocabSize, 10)
+ expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+
+ expect_false(stats2$isDistributed)
+ expect_equal(logLikelihood, stats2$logLikelihood)
+ expect_equal(logPerplexity, stats2$logPerplexity)
+ expect_equal(vocabSize, stats2$vocabSize)
+ expect_true(all.equal(vocabulary, stats2$vocabulary))
+
+ unlink(modelPath)
+})
+
+test_that("spark.posterior and spark.perplexity", {
+ text <- read.text("data/mllib/sample_lda_data.txt")
+ model <- spark.lda(text, features = "value", k = 3)
+
+ # Assert perplexities are equal
+ stats <- summary(model)
+ logPerplexity <- spark.perplexity(model, text)
+ expect_equal(logPerplexity, stats$logPerplexity)
+
+ # Assert the sum of every topic distribution is equal to 1
+ posterior <- spark.posterior(model, text)
+ local.posterior <- collect(posterior)$topicDistribution
+ expect_equal(length(local.posterior), sum(unlist(local.posterior)))
+})
+
sparkR.session.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 034f2c3fa2fd9..b5a764b5863f1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -386,6 +386,10 @@ sealed abstract class LDAModel private[ml] (
@Since("1.6.0")
protected def getModel: OldLDAModel
+ private[ml] def getEffectiveDocConcentration: Array[Double] = getModel.docConcentration.toArray
+
+ private[ml] def getEffectiveTopicConcentration: Double = getModel.topicConcentration
+
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
* The vector should be of length vocabSize, with counts for each term (word).
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
new file mode 100644
index 0000000000000..cbe6a705007d1
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
+import org.apache.spark.ml.clustering.{LDA, LDAModel}
+import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.param.ParamPair
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StringType
+
+
+private[r] class LDAWrapper private (
+ val pipeline: PipelineModel,
+ val logLikelihood: Double,
+ val logPerplexity: Double,
+ val vocabulary: Array[String]) extends MLWritable {
+
+ import LDAWrapper._
+
+ private val lda: LDAModel = pipeline.stages.last.asInstanceOf[LDAModel]
+ private val preprocessor: PipelineModel =
+ new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1))
+
+ def transform(data: Dataset[_]): DataFrame = {
+ val vec2ary = udf { vec: Vector => vec.toArray }
+ val outputCol = lda.getTopicDistributionCol
+ val tempCol = s"${Identifiable.randomUID(outputCol)}"
+ val preprocessed = preprocessor.transform(data)
+ lda.transform(preprocessed, ParamPair(lda.topicDistributionCol, tempCol))
+ .withColumn(outputCol, vec2ary(col(tempCol)))
+ .drop(TOKENIZER_COL, STOPWORDS_REMOVER_COL, COUNT_VECTOR_COL, tempCol)
+ }
+
+ def computeLogPerplexity(data: Dataset[_]): Double = {
+ lda.logPerplexity(preprocessor.transform(data))
+ }
+
+ def topics(maxTermsPerTopic: Int): DataFrame = {
+ val topicIndices: DataFrame = lda.describeTopics(maxTermsPerTopic)
+ if (vocabulary.isEmpty || vocabulary.length < vocabSize) {
+ topicIndices
+ } else {
+ val index2term = udf { indices: mutable.WrappedArray[Int] => indices.map(i => vocabulary(i)) }
+ topicIndices
+ .select(col("topic"), index2term(col("termIndices")).as("term"), col("termWeights"))
+ }
+ }
+
+ lazy val isDistributed: Boolean = lda.isDistributed
+ lazy val vocabSize: Int = lda.vocabSize
+ lazy val docConcentration: Array[Double] = lda.getEffectiveDocConcentration
+ lazy val topicConcentration: Double = lda.getEffectiveTopicConcentration
+
+ override def write: MLWriter = new LDAWrapper.LDAWrapperWriter(this)
+}
+
+private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
+
+ val TOKENIZER_COL = s"${Identifiable.randomUID("rawTokens")}"
+ val STOPWORDS_REMOVER_COL = s"${Identifiable.randomUID("tokens")}"
+ val COUNT_VECTOR_COL = s"${Identifiable.randomUID("features")}"
+
+ private def getPreStages(
+ features: String,
+ customizedStopWords: Array[String],
+ maxVocabSize: Int): Array[PipelineStage] = {
+ val tokenizer = new RegexTokenizer()
+ .setInputCol(features)
+ .setOutputCol(TOKENIZER_COL)
+ val stopWordsRemover = new StopWordsRemover()
+ .setInputCol(TOKENIZER_COL)
+ .setOutputCol(STOPWORDS_REMOVER_COL)
+ stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
+ val countVectorizer = new CountVectorizer()
+ .setVocabSize(maxVocabSize)
+ .setInputCol(STOPWORDS_REMOVER_COL)
+ .setOutputCol(COUNT_VECTOR_COL)
+
+ Array(tokenizer, stopWordsRemover, countVectorizer)
+ }
+
+ def fit(
+ data: DataFrame,
+ features: String,
+ k: Int,
+ maxIter: Int,
+ optimizer: String,
+ subsamplingRate: Double,
+ topicConcentration: Double,
+ docConcentration: Array[Double],
+ customizedStopWords: Array[String],
+ maxVocabSize: Int): LDAWrapper = {
+
+ val lda = new LDA()
+ .setK(k)
+ .setMaxIter(maxIter)
+ .setSubsamplingRate(subsamplingRate)
+
+ val featureSchema = data.schema(features)
+ val stages = featureSchema.dataType match {
+ case d: StringType =>
+ getPreStages(features, customizedStopWords, maxVocabSize) ++
+ Array(lda.setFeaturesCol(COUNT_VECTOR_COL))
+ case d: VectorUDT =>
+ Array(lda.setFeaturesCol(features))
+ case _ =>
+ throw new SparkException(
+ s"Unsupported input features type of ${featureSchema.dataType.typeName}," +
+ s" only String type and Vector type are supported now.")
+ }
+
+ if (topicConcentration != -1) {
+ lda.setTopicConcentration(topicConcentration)
+ } else {
+ // Auto-set topicConcentration
+ }
+
+ if (docConcentration.length == 1) {
+ if (docConcentration.head != -1) {
+ lda.setDocConcentration(docConcentration.head)
+ } else {
+ // Auto-set docConcentration
+ }
+ } else {
+ lda.setDocConcentration(docConcentration)
+ }
+
+ val pipeline = new Pipeline().setStages(stages)
+ val model = pipeline.fit(data)
+
+ val vocabulary: Array[String] = featureSchema.dataType match {
+ case d: StringType =>
+ val countVectorModel = model.stages(2).asInstanceOf[CountVectorizerModel]
+ countVectorModel.vocabulary
+ case _ => Array.empty[String]
+ }
+
+ val ldaModel: LDAModel = model.stages.last.asInstanceOf[LDAModel]
+ val preprocessor: PipelineModel =
+ new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", model.stages.dropRight(1))
+
+ val preprocessedData = preprocessor.transform(data)
+
+ new LDAWrapper(
+ model,
+ ldaModel.logLikelihood(preprocessedData),
+ ldaModel.logPerplexity(preprocessedData),
+ vocabulary)
+ }
+
+ override def read: MLReader[LDAWrapper] = new LDAWrapperReader
+
+ override def load(path: String): LDAWrapper = super.load(path)
+
+ class LDAWrapperWriter(instance: LDAWrapper) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("logLikelihood" -> instance.logLikelihood) ~
+ ("logPerplexity" -> instance.logPerplexity) ~
+ ("vocabulary" -> instance.vocabulary.toList)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class LDAWrapperReader extends MLReader[LDAWrapper] {
+
+ override def load(path: String): LDAWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val logLikelihood = (rMetadata \ "logLikelihood").extract[Double]
+ val logPerplexity = (rMetadata \ "logPerplexity").extract[Double]
+ val vocabulary = (rMetadata \ "vocabulary").extract[List[String]].toArray
+
+ val pipeline = PipelineModel.load(pipelinePath)
+ new LDAWrapper(pipeline, logLikelihood, logPerplexity, vocabulary)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 88ac26bc5e351..e23af51df5718 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -44,6 +44,8 @@ private[r] object RWrappers extends MLReader[Object] {
GeneralizedLinearRegressionWrapper.load(path)
case "org.apache.spark.ml.r.KMeansWrapper" =>
KMeansWrapper.load(path)
+ case "org.apache.spark.ml.r.LDAWrapper" =>
+ LDAWrapper.load(path)
case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
IsotonicRegressionWrapper.load(path)
case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
From f5472dda51b980a726346587257c22873ff708e3 Mon Sep 17 00:00:00 2001
From: petermaxlee
Date: Fri, 19 Aug 2016 09:19:47 +0800
Subject: [PATCH 022/270] [SPARK-16947][SQL] Support type coercion and foldable
expression for inline tables
## What changes were proposed in this pull request?
This patch improves inline table support with the following (a short usage sketch follows the list):
1. Support type coercion.
2. Support using foldable expressions. Previously only literals were supported.
3. Improve error message handling.
4. Improve test coverage.
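To illustrate items 1 and 2, a hedged sketch of the inline-table queries this enables (assumes an existing `SparkSession` named `spark`):
```
// Sketch only: `spark` is assumed to be an existing SparkSession.
// Type coercion: the first column is widened to a common type across rows.
spark.sql("SELECT * FROM VALUES (1, 'one'), (2.5, 'two') AS t(num, name)")

// Foldable expressions, not just literals, are accepted as values.
spark.sql("SELECT * FROM VALUES (1 + 1, 'two'), (3, 'three') AS t(x, y)")
```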
## How was this patch tested?
Added a new unit test suite ResolveInlineTablesSuite and a new file-based end-to-end test inline-table.sql.
Author: petermaxlee
Closes #14676 from petermaxlee/SPARK-16947.
---
.../sql/catalyst/analysis/Analyzer.scala | 1 +
.../analysis/ResolveInlineTables.scala | 112 ++++++++++++++
.../sql/catalyst/analysis/TypeCoercion.scala | 2 +-
.../sql/catalyst/analysis/unresolved.scala | 26 +++-
.../sql/catalyst/parser/AstBuilder.scala | 41 ++---
.../analysis/ResolveInlineTablesSuite.scala | 101 ++++++++++++
.../sql/catalyst/parser/PlanParserSuite.scala | 22 +--
.../sql-tests/inputs/inline-table.sql | 48 ++++++
.../sql-tests/results/inline-table.sql.out | 145 ++++++++++++++++++
9 files changed, 452 insertions(+), 46 deletions(-)
create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
create mode 100644 sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
create mode 100644 sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 333dd4d9a4f2a..41e0e6d65e9ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -108,6 +108,7 @@ class Analyzer(
GlobalAggregates ::
ResolveAggregateFunctions ::
TimeWindowing ::
+ ResolveInlineTables ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
new file mode 100644
index 0000000000000..7323197b10f6e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.{StructField, StructType}
+
+/**
+ * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
+ */
+object ResolveInlineTables extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case table: UnresolvedInlineTable if table.expressionsResolved =>
+ validateInputDimension(table)
+ validateInputEvaluable(table)
+ convert(table)
+ }
+
+ /**
+ * Validates the input data dimension:
+ * 1. All rows have the same cardinality.
+ * 2. The number of column aliases defined is consistent with the number of columns in data.
+ *
+ * This is package visible for unit testing.
+ */
+ private[analysis] def validateInputDimension(table: UnresolvedInlineTable): Unit = {
+ if (table.rows.nonEmpty) {
+ val numCols = table.names.size
+ table.rows.zipWithIndex.foreach { case (row, ri) =>
+ if (row.size != numCols) {
+ table.failAnalysis(s"expected $numCols columns but found ${row.size} columns in row $ri")
+ }
+ }
+ }
+ }
+
+ /**
+ * Validates that all inline table data are valid expressions that can be evaluated
+ * (in this case, they must be foldable).
+ *
+ * This is package visible for unit testing.
+ */
+ private[analysis] def validateInputEvaluable(table: UnresolvedInlineTable): Unit = {
+ table.rows.foreach { row =>
+ row.foreach { e =>
+ // Note that nondeterministic expressions are not supported since they are not foldable.
+ if (!e.resolved || !e.foldable) {
+ e.failAnalysis(s"cannot evaluate expression ${e.sql} in inline table definition")
+ }
+ }
+ }
+ }
+
+ /**
+ * Convert a valid (with right shape and foldable inputs) [[UnresolvedInlineTable]]
+ * into a [[LocalRelation]].
+ *
+ * This function attempts to coerce inputs into consistent types.
+ *
+ * This is package visible for unit testing.
+ */
+ private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
+ // For each column, traverse all the values and find a common data type and nullability.
+ val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
+ val inputTypes = column.map(_.dataType)
+ val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
+ table.failAnalysis(s"incompatible types found in column $name for inline table")
+ }
+ StructField(name, tpe, nullable = column.exists(_.nullable))
+ }
+ val attributes = StructType(fields).toAttributes
+ assert(fields.size == table.names.size)
+
+ val newRows: Seq[InternalRow] = table.rows.map { row =>
+ InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) =>
+ val targetType = fields(ci).dataType
+ try {
+ if (e.dataType.sameType(targetType)) {
+ e.eval()
+ } else {
+ Cast(e, targetType).eval()
+ }
+ } catch {
+ case NonFatal(ex) =>
+ table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
+ }
+ })
+ }
+
+ LocalRelation(attributes, newRows)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 21e96aaf53844..193c3ec4e585a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -150,7 +150,7 @@ object TypeCoercion {
* [[findTightestCommonType]], but can handle decimal types. If the wider decimal type exceeds
* system limitation, this rule will truncate the decimal type before return it.
*/
- private def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
+ def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findTightestCommonTypeOfTwo(d, c).orElse((d, c) match {
case (t1: DecimalType, t2: DecimalType) =>
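The hunk above only makes `findWiderTypeWithoutStringPromotion` public so the new rule can reuse it. Below is a hedged sketch of a direct call; the expected result follows from the int/long coercion exercised in the tests added later in this patch.
```scala
// Hedged sketch, not part of the patch: calling the now-public coercion helper directly.
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.types.{IntegerType, LongType}

object WiderTypeDemo extends App {
  // Int and long widen to long, consistent with the inline-table end-to-end tests below.
  val widened = TypeCoercion.findWiderTypeWithoutStringPromotion(Seq(IntegerType, LongType))
  assert(widened.contains(LongType))
}
```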
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 3735a1501cbfa..235ae04782455 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -50,10 +50,30 @@ case class UnresolvedRelation(
}
/**
- * Holds a table-valued function call that has yet to be resolved.
+ * An inline table that has not been resolved yet. Once resolved, it is turned by the analyzer into
+ * a [[org.apache.spark.sql.catalyst.plans.logical.LocalRelation]].
+ *
+ * @param names list of column names
+ * @param rows expressions for the data
+ */
+case class UnresolvedInlineTable(
+ names: Seq[String],
+ rows: Seq[Seq[Expression]])
+ extends LeafNode {
+
+ lazy val expressionsResolved: Boolean = rows.forall(_.forall(_.resolved))
+ override lazy val resolved = false
+ override def output: Seq[Attribute] = Nil
+}
+
+/**
+ * A table-valued function, e.g.
+ * {{{
+ * select * from range(10);
+ * }}}
*/
-case class UnresolvedTableValuedFunction(
- functionName: String, functionArgs: Seq[Expression]) extends LeafNode {
+case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression])
+ extends LeafNode {
override def output: Seq[Attribute] = Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 01322ae327e4a..283e4d43ba2b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -670,39 +670,24 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
*/
override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
// Get the backing expressions.
- val expressions = ctx.expression.asScala.map { eCtx =>
- val e = expression(eCtx)
- validate(e.foldable, "All expressions in an inline table must be constants.", eCtx)
- e
- }
-
- // Validate and evaluate the rows.
- val (structType, structConstructor) = expressions.head.dataType match {
- case st: StructType =>
- (st, (e: Expression) => e)
- case dt =>
- val st = CreateStruct(Seq(expressions.head)).dataType
- (st, (e: Expression) => CreateStruct(Seq(e)))
- }
- val rows = expressions.map {
- case expression =>
- val safe = Cast(structConstructor(expression), structType)
- safe.eval().asInstanceOf[InternalRow]
+ val rows = ctx.expression.asScala.map { e =>
+ expression(e) match {
+ // inline table comes in two styles:
+ // style 1: values (1), (2), (3) -- multiple columns are supported
+ // style 2: values 1, 2, 3 -- only a single column is supported here
+ case CreateStruct(children) => children // style 1
+ case child => Seq(child) // style 2
+ }
}
- // Construct attributes.
- val baseAttributes = structType.toAttributes.map(_.withNullability(true))
- val attributes = if (ctx.identifierList != null) {
- val aliases = visitIdentifierList(ctx.identifierList)
- validate(aliases.size == baseAttributes.size,
- "Number of aliases must match the number of fields in an inline table.", ctx)
- baseAttributes.zip(aliases).map(p => p._1.withName(p._2))
+ val aliases = if (ctx.identifierList != null) {
+ visitIdentifierList(ctx.identifierList)
} else {
- baseAttributes
+ Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
}
- // Create plan and add an alias if a name has been defined.
- LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan)
+ val table = UnresolvedInlineTable(aliases, rows)
+ table.optionalMap(ctx.identifier)(aliasPlan)
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
new file mode 100644
index 0000000000000..920c6ea50f4ba
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.types.{LongType, NullType}
+
+/**
+ * Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in
+ * end-to-end tests (in sql/core module) for verifying the correct error messages are shown
+ * in negative cases.
+ */
+class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {
+
+ private def lit(v: Any): Literal = Literal(v)
+
+ test("validate inputs are foldable") {
+ ResolveInlineTables.validateInputEvaluable(
+ UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))
+
+ // nondeterministic (rand) should not work
+ intercept[AnalysisException] {
+ ResolveInlineTables.validateInputEvaluable(
+ UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
+ }
+
+ // aggregate should not work
+ intercept[AnalysisException] {
+ ResolveInlineTables.validateInputEvaluable(
+ UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
+ }
+
+ // unresolved attribute should not work
+ intercept[AnalysisException] {
+ ResolveInlineTables.validateInputEvaluable(
+ UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
+ }
+ }
+
+ test("validate input dimensions") {
+ ResolveInlineTables.validateInputDimension(
+ UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))
+
+ // num alias != data dimension
+ intercept[AnalysisException] {
+ ResolveInlineTables.validateInputDimension(
+ UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
+ }
+
+ // num alias == data dimension, but data themselves are inconsistent
+ intercept[AnalysisException] {
+ ResolveInlineTables.validateInputDimension(
+ UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
+ }
+ }
+
+ test("do not fire the rule if not all expressions are resolved") {
+ val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
+ assert(ResolveInlineTables(table) == table)
+ }
+
+ test("convert") {
+ val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
+ val converted = ResolveInlineTables.convert(table)
+
+ assert(converted.output.map(_.dataType) == Seq(LongType))
+ assert(converted.data.size == 2)
+ assert(converted.data(0).getLong(0) == 1L)
+ assert(converted.data(1).getLong(0) == 2L)
+ }
+
+ test("nullability inference in convert") {
+ val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
+ val converted1 = ResolveInlineTables.convert(table1)
+ assert(!converted1.schema.fields(0).nullable)
+
+ val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
+ val converted2 = ResolveInlineTables.convert(table2)
+ assert(converted2.schema.fields(0).nullable)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index cbe4a022e730d..2fcbfc7067a13 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -17,9 +17,8 @@
package org.apache.spark.sql.catalyst.parser
-import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedTableValuedFunction}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedInlineTable, UnresolvedTableValuedFunction}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -433,19 +432,14 @@ class PlanParserSuite extends PlanTest {
}
test("inline table") {
- assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
- Seq('col1.int),
- Seq(1, 2, 3, 4).map(x => Row(x))))
+ assertEqual("values 1, 2, 3, 4",
+ UnresolvedInlineTable(Seq("col1"), Seq(1, 2, 3, 4).map(x => Seq(Literal(x)))))
+
assertEqual(
- "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)",
- LocalRelation.fromExternalRows(
- Seq('a.int, 'b.string),
- Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl"))
- intercept("values (a, 'a'), (b, 'b')",
- "All expressions in an inline table must be constants.")
- intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)",
- "Number of aliases must match the number of fields in an inline table.")
- intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)"))
+ "values (1, 'a'), (2, 'b') as tbl(a, b)",
+ UnresolvedInlineTable(
+ Seq("a", "b"),
+ Seq(Literal(1), Literal("a")) :: Seq(Literal(2), Literal("b")) :: Nil).as("tbl"))
}
test("simple select query with !> and !<") {
diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
new file mode 100644
index 0000000000000..5107fa4d55537
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
@@ -0,0 +1,48 @@
+
+-- single row, without table and column alias
+select * from values ("one", 1);
+
+-- single row, without column alias
+select * from values ("one", 1) as data;
+
+-- single row
+select * from values ("one", 1) as data(a, b);
+
+-- single column multiple rows
+select * from values 1, 2, 3 as data(a);
+
+-- three rows
+select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b);
+
+-- null type
+select * from values ("one", null), ("two", null) as data(a, b);
+
+-- int and long coercion
+select * from values ("one", 1), ("two", 2L) as data(a, b);
+
+-- foldable expressions
+select * from values ("one", 1 + 0), ("two", 1 + 3L) as data(a, b);
+
+-- complex types
+select * from values ("one", array(0, 1)), ("two", array(2, 3)) as data(a, b);
+
+-- decimal and double coercion
+select * from values ("one", 2.0), ("two", 3.0D) as data(a, b);
+
+-- error reporting: nondeterministic function rand
+select * from values ("one", rand(5)), ("two", 3.0D) as data(a, b);
+
+-- error reporting: different number of columns
+select * from values ("one", 2.0), ("two") as data(a, b);
+
+-- error reporting: types that are incompatible
+select * from values ("one", array(0, 1)), ("two", struct(1, 2)) as data(a, b);
+
+-- error reporting: number aliases different from number data values
+select * from values ("one"), ("two") as data(a, b);
+
+-- error reporting: unresolved expression
+select * from values ("one", random_not_exist_func(1)), ("two", 2) as data(a, b);
+
+-- error reporting: aggregate expression
+select * from values ("one", count(1)), ("two", 2) as data(a, b);
diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
new file mode 100644
index 0000000000000..de6f01b8de772
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
@@ -0,0 +1,145 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 16
+
+
+-- !query 0
+select * from values ("one", 1)
+-- !query 0 schema
+struct
+-- !query 0 output
+one 1
+
+
+-- !query 1
+select * from values ("one", 1) as data
+-- !query 1 schema
+struct
+-- !query 1 output
+one 1
+
+
+-- !query 2
+select * from values ("one", 1) as data(a, b)
+-- !query 2 schema
+struct
+-- !query 2 output
+one 1
+
+
+-- !query 3
+select * from values 1, 2, 3 as data(a)
+-- !query 3 schema
+struct
+-- !query 3 output
+1
+2
+3
+
+
+-- !query 4
+select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b)
+-- !query 4 schema
+struct
+-- !query 4 output
+one 1
+three NULL
+two 2
+
+
+-- !query 5
+select * from values ("one", null), ("two", null) as data(a, b)
+-- !query 5 schema
+struct
+-- !query 5 output
+one NULL
+two NULL
+
+
+-- !query 6
+select * from values ("one", 1), ("two", 2L) as data(a, b)
+-- !query 6 schema
+struct
+-- !query 6 output
+one 1
+two 2
+
+
+-- !query 7
+select * from values ("one", 1 + 0), ("two", 1 + 3L) as data(a, b)
+-- !query 7 schema
+struct
+-- !query 7 output
+one 1
+two 4
+
+
+-- !query 8
+select * from values ("one", array(0, 1)), ("two", array(2, 3)) as data(a, b)
+-- !query 8 schema
+struct>
+-- !query 8 output
+one [0,1]
+two [2,3]
+
+
+-- !query 9
+select * from values ("one", 2.0), ("two", 3.0D) as data(a, b)
+-- !query 9 schema
+struct
+-- !query 9 output
+one 2.0
+two 3.0
+
+
+-- !query 10
+select * from values ("one", rand(5)), ("two", 3.0D) as data(a, b)
+-- !query 10 schema
+struct<>
+-- !query 10 output
+org.apache.spark.sql.AnalysisException
+cannot evaluate expression rand(5) in inline table definition; line 1 pos 29
+
+
+-- !query 11
+select * from values ("one", 2.0), ("two") as data(a, b)
+-- !query 11 schema
+struct<>
+-- !query 11 output
+org.apache.spark.sql.AnalysisException
+expected 2 columns but found 1 columns in row 1; line 1 pos 14
+
+
+-- !query 12
+select * from values ("one", array(0, 1)), ("two", struct(1, 2)) as data(a, b)
+-- !query 12 schema
+struct<>
+-- !query 12 output
+org.apache.spark.sql.AnalysisException
+incompatible types found in column b for inline table; line 1 pos 14
+
+
+-- !query 13
+select * from values ("one"), ("two") as data(a, b)
+-- !query 13 schema
+struct<>
+-- !query 13 output
+org.apache.spark.sql.AnalysisException
+expected 2 columns but found 1 columns in row 0; line 1 pos 14
+
+
+-- !query 14
+select * from values ("one", random_not_exist_func(1)), ("two", 2) as data(a, b)
+-- !query 14 schema
+struct<>
+-- !query 14 output
+org.apache.spark.sql.AnalysisException
+Undefined function: 'random_not_exist_func'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 29
+
+
+-- !query 15
+select * from values ("one", count(1)), ("two", 2) as data(a, b)
+-- !query 15 schema
+struct<>
+-- !query 15 output
+org.apache.spark.sql.AnalysisException
+cannot evaluate expression count(1) in inline table definition; line 1 pos 29
From b482c09fa22c5762a355f95820e4ba3e2517fb77 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Thu, 18 Aug 2016 19:02:32 -0700
Subject: [PATCH 023/270] HOTFIX: compilation broken due to protected ctor.
---
.../org/apache/spark/sql/catalyst/expressions/literals.scala | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 95ed68fbb0528..7040008769a32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -163,8 +163,7 @@ object DecimalLiteral {
/**
* In order to do type checking, use Literal.create() instead of constructor
*/
-case class Literal protected (value: Any, dataType: DataType)
- extends LeafExpression with CodegenFallback {
+case class Literal (value: Any, dataType: DataType) extends LeafExpression with CodegenFallback {
override def foldable: Boolean = true
override def nullable: Boolean = value == null
From 287bea13050b8eedc3b8b6b3491f1b5e5bc24d7a Mon Sep 17 00:00:00 2001
From: sethah
Date: Thu, 18 Aug 2016 22:16:48 -0700
Subject: [PATCH 024/270] [SPARK-7159][ML] Add multiclass logistic regression
to Spark ML
## What changes were proposed in this pull request?
This patch adds a new estimator/transformer `MultinomialLogisticRegression` to Spark ML.
JIRA: [SPARK-7159](https://issues.apache.org/jira/browse/SPARK-7159)
## How was this patch tested?
Added new test suite `MultinomialLogisticRegressionSuite`.
## Approach
### Do not use a "pivot" class in the algorithm formulation
Many implementations of multinomial logistic regression treat the problem as K - 1 independent binary logistic regression models where K is the number of possible outcomes in the output variable. In this case, one outcome is chosen as a "pivot" and the other K - 1 outcomes are regressed against the pivot. This is somewhat undesirable since the coefficients returned will be different for different choices of pivot variables. An alternative approach to the problem models class conditional probabilities using the softmax function and will return uniquely identifiable coefficients (assuming regularization is applied). This second approach is used in R's glmnet and was also recommended by dbtsai.
### Separate multinomial logistic regression and binary logistic regression
The initial design makes multinomial logistic regression a separate estimator/transformer from the existing LogisticRegression estimator/transformer. An alternative design would be to merge them into one.
**Arguments for:**
* The multinomial case without pivot is distinctly different than the current binary case since the binary case uses a pivot class.
* The current logistic regression model in ML uses a vector of coefficients and a scalar intercept. In the multinomial case, we require a matrix of coefficients and a vector of intercepts. There are potential workarounds for this issue if we were to merge the two estimators, but none are particularly elegant.
**Arguments against:**
* It may be inconvenient for users to have to switch the estimator class when transitioning between binary and multiclass (although the new multinomial estimator can be used for two-class outcomes).
* Some portions of the code are repeated.
This is a major design point and warrants more discussion.
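As a concrete illustration of the separate-estimator design, here is a hedged usage sketch. The training DataFrame `df`, with `label` and `features` columns, is assumed and not part of the patch; the setters are the ones added in `MultinomialLogisticRegression.scala` below.
```scala
// Hedged usage sketch only; `df` is an assumed DataFrame with "label" and "features" columns.
import org.apache.spark.ml.classification.MultinomialLogisticRegression
import org.apache.spark.sql.DataFrame

def fitExample(df: DataFrame) = {
  val mlr = new MultinomialLogisticRegression()
    .setMaxIter(100)          // default, per the param docs below
    .setRegParam(0.1)
    .setElasticNetParam(0.0)  // alpha = 0 gives an L2 penalty
    .setFitIntercept(true)
    .setStandardization(true)

  // Per the design discussion above, the fitted model carries a matrix of coefficients
  // (one set per class) and a vector of intercepts.
  mlr.fit(df)
}
```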
### Mean centering
When no regularization is applied, the coefficients will not be uniquely identifiable. This is not hard to show and is discussed in further detail [here](https://core.ac.uk/download/files/153/6287975.pdf). R's glmnet deals with this by choosing the minimum L2 regularized solution (i.e. mean centering). Additionally, the intercepts are never regularized, so they are always mean centered. This is the approach taken in this PR as well.
### Feature scaling
In current ML logistic regression, the features are always standardized when running the optimization algorithm. The coefficients are always returned to the user in the original feature space, however. This same approach is maintained in this patch as well, but the implementation details are different. In ML logistic regression, the unstandardized feature values are divided by the column standard deviation in every gradient update iteration. In contrast, MLlib transforms the entire input dataset to the scaled space _before_ optimization. In ML, this means that `numFeatures * numClasses` extra scalar divisions are required in every iteration. Performance testing shows that this causes significant (4x in some cases) slowdowns in each iteration. This can be avoided by transforming the input to the scaled space, as MLlib does, once before iteration begins. This does add some overhead initially, but can yield significant time savings in some cases.
One issue with this approach is that if the input data is already cached, there may not be enough memory to cache the transformed data, which would make the algorithm _much_ slower. The tradeoffs here merit more discussion.
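A minimal sketch of the "scale once, before iteration begins" alternative discussed above (illustrative only, not the patch's implementation; a real implementation would preserve sparsity):
```scala
// Illustrative sketch: standardize the features of each Instance a single time so that
// per-iteration gradient code needs no extra divisions.
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD

def scaleOnce(instances: RDD[Instance], featuresStd: Array[Double]): RDD[Instance] = {
  instances.map { case Instance(label, weight, features) =>
    // Densifies for brevity; a real implementation would keep sparse vectors sparse.
    val scaled = features.toArray.zipWithIndex.map { case (v, i) =>
      if (featuresStd(i) != 0.0) v / featuresStd(i) else 0.0
    }
    Instance(label, weight, Vectors.dense(scaled))
  }
}
```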
### Specifying and inferring the number of outcome classes
The estimator checks the DataFrame label column for metadata which specifies the number of values. If it is not specified, the length of the `histogram` variable is used, which is essentially one more than the maximum label value found in the column. The assumption, then, is that the labels are zero-indexed when they are provided to the algorithm.
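A hedged sketch of that inference rule (illustrative names only; the real code path goes through the ML attribute metadata utilities):
```scala
// Hedged sketch of the numClasses inference described above; not the patch's exact code.
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.sql.DataFrame

def inferNumClasses(dataset: DataFrame, labelCol: String, histogram: Array[Double]): Int =
  Attribute.fromStructField(dataset.schema(labelCol)) match {
    // Metadata on the label column specifies the number of values.
    case nominal: NominalAttribute if nominal.getNumValues.isDefined => nominal.getNumValues.get
    // No metadata: fall back to the histogram length, assuming zero-indexed labels.
    case _ => histogram.length
  }
```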
## Performance
Below are some performance tests I have run so far. I am happy to add more cases or trials if we deem them necessary.
Test cluster: 4 bare metal nodes, 128 GB RAM each, 48 cores each
Notes:
* Time in units of seconds
* Metric is classification accuracy
| algo | elasticNetParam | fitIntercept | metric | maxIter | numPoints | numClasses | numFeatures | time | standardization | regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml | 0 | true | 0.746415 | 30 | 100000 | 3 | 100000 | 327.923 | true | 0 |
| mllib | 0 | true | 0.743785 | 30 | 100000 | 3 | 100000 | 390.217 | true | 0 |
| algo | elasticNetParam | fitIntercept | metric | maxIter | numPoints | numClasses | numFeatures | time | standardization | regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml | 0 | true | 0.973238 | 30 | 2000000 | 3 | 10000 | 385.476 | true | 0 |
| mllib | 0 | true | 0.949828 | 30 | 2000000 | 3 | 10000 | 550.403 | true | 0 |
| algo | elasticNetParam | fitIntercept | metric | maxIter | numPoints | numClasses | numFeatures | time | standardization | regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| mllib | 0 | true | 0.864358 | 30 | 2000000 | 3 | 10000 | 543.359 | true | 0.1 |
| ml | 0 | true | 0.867418 | 30 | 2000000 | 3 | 10000 | 401.955 | true | 0.1 |
| algo | elasticNetParam | fitIntercept | metric | maxIter | numPoints | numClasses | numFeatures | time | standardization | regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml | 1 | true | 0.807449 | 30 | 2000000 | 3 | 10000 | 334.892 | true | 0.05 |
| algo | elasticNetParam | fitIntercept | metric | maxIter | numPoints | numClasses | numFeatures | time | standardization | regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml | 0 | true | 0.602006 | 30 | 2000000 | 500 | 100 | 112.319 | true | 0 |
| mllib | 0 | true | 0.567226 | 30 | 2000000 | 500 | 100 | 263.768 | true | 0 |
## References
Friedman, et al. ["Regularization Paths for Generalized Linear Models via Coordinate Descent"](https://core.ac.uk/download/files/153/6287975.pdf)
[http://web.stanford.edu/~hastie/glmnet/glmnet_alpha.html](http://web.stanford.edu/~hastie/glmnet/glmnet_alpha.html)
## Follow up items
* Consider using level 2 BLAS routines in the gradient computations - [SPARK-17134](https://issues.apache.org/jira/browse/SPARK-17134)
* Add model summary for MLOR - [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139)
* Add initial model to MLOR and add test for intercept priors - [SPARK-17140](https://issues.apache.org/jira/browse/SPARK-17140)
* Python API - [SPARK-17138](https://issues.apache.org/jira/browse/SPARK-17138)
* Consider changing the tree aggregation level for MLOR/BLOR or making it user configurable to avoid memory problems with high dimensional data - [SPARK-17090](https://issues.apache.org/jira/browse/SPARK-17090)
* Refactor helper classes out of `LogisticRegression.scala` - [SPARK-17135](https://issues.apache.org/jira/browse/SPARK-17135)
* Design optimizer interface for added flexibility in ML algos - [SPARK-17136](https://issues.apache.org/jira/browse/SPARK-17136)
* Support compressing the coefficients and intercepts for MLOR models - [SPARK-17137](https://issues.apache.org/jira/browse/SPARK-17137)
Author: sethah
Closes #13796 from sethah/SPARK-7159_M.
---
.../classification/LogisticRegression.scala | 425 +++++--
.../MultinomialLogisticRegression.scala | 620 ++++++++++
.../MultinomialLogisticRegressionSuite.scala | 1056 +++++++++++++++++
.../apache/spark/ml/util/MLTestingUtils.scala | 49 +-
4 files changed, 2062 insertions(+), 88 deletions(-)
create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index fce3935d396fe..ea31c68e4c943 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -63,6 +63,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
* equivalent.
*
* Default is 0.5.
+ *
* @group setParam
*/
def setThreshold(value: Double): this.type = {
@@ -131,6 +132,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
/**
* If [[threshold]] and [[thresholds]] are both set, ensures they are consistent.
+ *
* @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent
*/
protected def checkThresholdConsistency(): Unit = {
@@ -153,8 +155,8 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
/**
* Logistic regression.
- * Currently, this class only supports binary classification. It will support multiclass
- * in the future.
+ * Currently, this class only supports binary classification. For multiclass classification,
+ * use [[MultinomialLogisticRegression]]
*/
@Since("1.2.0")
class LogisticRegression @Since("1.2.0") (
@@ -168,6 +170,7 @@ class LogisticRegression @Since("1.2.0") (
/**
* Set the regularization parameter.
* Default is 0.0.
+ *
* @group setParam
*/
@Since("1.2.0")
@@ -179,6 +182,7 @@ class LogisticRegression @Since("1.2.0") (
* For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
* For 0 < alpha < 1, the penalty is a combination of L1 and L2.
* Default is 0.0 which is an L2 penalty.
+ *
* @group setParam
*/
@Since("1.4.0")
@@ -188,6 +192,7 @@ class LogisticRegression @Since("1.2.0") (
/**
* Set the maximum number of iterations.
* Default is 100.
+ *
* @group setParam
*/
@Since("1.2.0")
@@ -198,6 +203,7 @@ class LogisticRegression @Since("1.2.0") (
* Set the convergence tolerance of iterations.
* Smaller value will lead to higher accuracy with the cost of more iterations.
* Default is 1E-6.
+ *
* @group setParam
*/
@Since("1.4.0")
@@ -207,6 +213,7 @@ class LogisticRegression @Since("1.2.0") (
/**
* Whether to fit an intercept term.
* Default is true.
+ *
* @group setParam
*/
@Since("1.4.0")
@@ -220,6 +227,7 @@ class LogisticRegression @Since("1.2.0") (
* the models should be always converged to the same solution when no regularization
* is applied. In R's GLMNET package, the default behavior is true as well.
* Default is true.
+ *
* @group setParam
*/
@Since("1.5.0")
@@ -233,9 +241,10 @@ class LogisticRegression @Since("1.2.0") (
override def getThreshold: Double = super.getThreshold
/**
- * Whether to over-/under-sample training instances according to the given weights in weightCol.
- * If not set or empty String, all instances are treated equally (weight 1.0).
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
+ *
* @group setParam
*/
@Since("1.6.0")
@@ -310,12 +319,15 @@ class LogisticRegression @Since("1.2.0") (
throw new SparkException(msg)
}
+ val isConstantLabel = histogram.count(_ != 0) == 1
+
if (numClasses > 2) {
- val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " +
- s"binary classification. Found $numClasses in the input dataset."
+ val msg = s"LogisticRegression with ElasticNet in ML package only supports " +
+ s"binary classification. Found $numClasses in the input dataset. Consider using " +
+ s"MultinomialLogisticRegression instead."
logError(msg)
throw new SparkException(msg)
- } else if ($(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) {
+ } else if ($(fitIntercept) && numClasses == 2 && isConstantLabel) {
logWarning(s"All labels are one and fitIntercept=true, so the coefficients will be " +
s"zeros and the intercept will be positive infinity; as a result, " +
s"training is not needed.")
@@ -326,12 +338,9 @@ class LogisticRegression @Since("1.2.0") (
s"training is not needed.")
(Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double])
} else {
- if (!$(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) {
- logWarning(s"All labels are one and fitIntercept=false. It's a dangerous ground, " +
- s"so the algorithm may not converge.")
- } else if (!$(fitIntercept) && numClasses == 1) {
- logWarning(s"All labels are zero and fitIntercept=false. It's a dangerous ground, " +
- s"so the algorithm may not converge.")
+ if (!$(fitIntercept) && isConstantLabel) {
+ logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
+ s"dangerous ground, so the algorithm may not converge.")
}
val featuresMean = summarizer.mean.toArray
@@ -349,7 +358,7 @@ class LogisticRegression @Since("1.2.0") (
val bcFeaturesStd = instances.context.broadcast(featuresStd)
val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
- $(standardization), bcFeaturesStd, regParamL2)
+ $(standardization), bcFeaturesStd, regParamL2, multinomial = false)
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -416,7 +425,7 @@ class LogisticRegression @Since("1.2.0") (
/*
Note that in Logistic Regression, the objective history (loss + regularization)
- is log-likelihood which is invariance under feature standardization. As a result,
+ is log-likelihood which is invariant under feature standardization. As a result,
the objective history from optimizer is the same as the one in the original space.
*/
val arrayBuilder = mutable.ArrayBuilder.make[Double]
@@ -559,6 +568,7 @@ class LogisticRegressionModel private[spark] (
/**
* Evaluates the model on a test dataset.
+ *
* @param dataset Test dataset to evaluate model on.
*/
@Since("2.0.0")
@@ -681,6 +691,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
val data = sparkSession.read.format("parquet").load(dataPath)
// We will need numClasses, numFeatures in the future for multinomial logreg support.
+ // TODO: remove numClasses and numFeatures fields?
val Row(numClasses: Int, numFeatures: Int, intercept: Double, coefficients: Vector) =
MLUtils.convertVectorColumnsToML(data, "coefficients")
.select("numClasses", "numFeatures", "intercept", "coefficients")
@@ -710,6 +721,7 @@ private[classification] class MultiClassSummarizer extends Serializable {
/**
* Add a new label into this MultilabelSummarizer, and update the distinct map.
+ *
* @param label The label for this data point.
* @param weight The weight of this instances.
* @return This MultilabelSummarizer
@@ -933,32 +945,310 @@ class BinaryLogisticRegressionSummary private[classification] (
}
/**
- * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
- * in binary classification for instances in sparse or dense vector in an online fashion.
- *
- * Note that multinomial logistic loss is not supported yet!
+ * LogisticAggregator computes the gradient and loss for binary or multinomial logistic (softmax)
+ * loss function, as used in classification for instances in sparse or dense vector in an online
+ * fashion.
*
- * Two LogisticAggregator can be merged together to have a summary of loss and gradient of
+ * Two LogisticAggregators can be merged together to have a summary of loss and gradient of
* the corresponding joint dataset.
*
+ * For improving the convergence rate during the optimization process and also to prevent against
+ * features with very large variances exerting an overly large influence during model training,
+ * packages like R's GLMNET perform the scaling to unit variance and remove the mean in order to
+ * reduce the condition number. The model is then trained in this scaled space, but returns the
+ * coefficients in the original scale. See page 9 in
+ * http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
+ *
+ * However, we don't want to apply the [[org.apache.spark.ml.feature.StandardScaler]] on the
+ * training dataset, and then cache the standardized dataset since it will create a lot of overhead.
+ * As a result, we perform the scaling implicitly when we compute the objective function (though
+ * we do not subtract the mean).
+ *
+ * Note that there is a difference between multinomial (softmax) and binary loss. The binary case
+ * uses one outcome class as a "pivot" and regresses the other class against the pivot. In the
+ * multinomial case, the softmax loss function is used to model each class probability
+ * independently. Using softmax loss produces `K` sets of coefficients, while using a pivot class
+ * produces `K - 1` sets of coefficients (a single coefficient vector in the binary case). In the
+ * binary case, we can say that the coefficients are shared between the positive and negative
+ * classes. When regularization is applied, multinomial (softmax) loss will produce a result
+ * different from binary loss since the positive and negative don't share the coefficients while the
+ * binary regression shares the coefficients between positive and negative.
+ *
+ * The following is a mathematical derivation for the multinomial (softmax) loss.
+ *
+ * The probability of the multinomial outcome $y$ taking on any of the K possible outcomes is:
+ *
+ *
+ * $$
+ * P(y_i=0|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_0}}{\sum_{k=0}^{K-1}
+ * e^{\vec{x}_i^T \vec{\beta}_k}} \\
+ * P(y_i=1|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_1}}{\sum_{k=0}^{K-1}
+ * e^{\vec{x}_i^T \vec{\beta}_k}}\\
+ * P(y_i=K-1|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_{K-1}}\,}{\sum_{k=0}^{K-1}
+ * e^{\vec{x}_i^T \vec{\beta}_k}}
+ * $$
+ *
+ *
+ * The model coefficients $\beta = (\beta_0, \beta_1, \beta_2, ..., \beta_{K-1})$ become a matrix
+ * which has dimension of $K \times (N+1)$ if the intercepts are added. If the intercepts are not
+ * added, the dimension will be $K \times N$.
+ *
+ * Note that the coefficients in the model above lack identifiability. That is, any constant scalar
+ * can be added to all of the coefficients and the probabilities remain the same.
+ *
+ *
+ * $$
+ * \begin{align}
+ * \frac{e^{\vec{x}_i^T \left(\vec{\beta}_0 + \vec{c}\right)}}{\sum_{k=0}^{K-1}
+ * e^{\vec{x}_i^T \left(\vec{\beta}_k + \vec{c}\right)}}
+ * = \frac{e^{\vec{x}_i^T \vec{\beta}_0}e^{\vec{x}_i^T \vec{c}}\,}{e^{\vec{x}_i^T \vec{c}}
+ * \sum_{k=0}^{K-1} e^{\vec{x}_i^T \vec{\beta}_k}}
+ * = \frac{e^{\vec{x}_i^T \vec{\beta}_0}}{\sum_{k=0}^{K-1} e^{\vec{x}_i^T \vec{\beta}_k}}
+ * \end{align}
+ * $$
+ *
+ *
+ * However, when regularization is added to the loss function, the coefficients are indeed
+ * identifiable because there is only one set of coefficients which minimizes the regularization
+ * term. When no regularization is applied, we choose the coefficients with the minimum L2
+ * penalty for consistency and reproducibility. For further discussion see:
+ *
+ * Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent"
+ *
+ * The loss of objective function for a single instance of data (we do not include the
+ * regularization term here for simplicity) can be written as
+ *
+ *
+ * $$
+ * \begin{align}
+ * \ell\left(\beta, x_i\right) &= -log{P\left(y_i \middle| \vec{x}_i, \beta\right)} \\
+ * &= log\left(\sum_{k=0}^{K-1}e^{\vec{x}_i^T \vec{\beta}_k}\right) - \vec{x}_i^T \vec{\beta}_y\\
+ * &= log\left(\sum_{k=0}^{K-1} e^{margins_k}\right) - margins_y
+ * \end{align}
+ * $$
+ *
+ *
+ * where ${margins}_k = \vec{x}_i^T \vec{\beta}_k$.
+ *
+ * For optimization, we have to calculate the first derivative of the loss function, and a simple
+ * calculation shows that
+ *
+ *
+ * $$
+ * \begin{align}
+ * \frac{\partial \ell(\beta, \vec{x}_i, w_i)}{\partial \beta_{j, k}}
+ * &= x_{i,j} \cdot w_i \cdot \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k}}{\sum_{k'=0}^{K-1}
+ * e^{\vec{x}_i \cdot \vec{\beta}_{k'}}\,} - I_{y=k}\right) \\
+ * &= x_{i, j} \cdot w_i \cdot multiplier_k
+ * \end{align}
+ * $$
+ *
+ *
+ * where $w_i$ is the sample weight, $I_{y=k}$ is an indicator function
+ *
+ *
+ * $$
+ * I_{y=k} = \begin{cases}
+ * 1 & y = k \\
+ * 0 & else
+ * \end{cases}
+ * $$
+ *
+ *
+ * and
+ *
+ *
+ * $$
+ * multiplier_k = \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k}}{\sum_{k=0}^{K-1}
+ * e^{\vec{x}_i \cdot \vec{\beta}_k}} - I_{y=k}\right)
+ * $$
+ *
+ *
+ * If any of margins is larger than 709.78, the numerical computation of multiplier and loss
+ * function will suffer from arithmetic overflow. This issue occurs when there are outliers in
+ * data which are far away from the hyperplane, and this will cause the failing of training once
+ * infinity is introduced. Note that this is only a concern when max(margins) > 0.
+ *
+ * Fortunately, when max(margins) = maxMargin > 0, the loss function and the multiplier can easily
+ * be rewritten into the following equivalent numerically stable formula.
+ *
+ *
+ * $$
+ * \ell\left(\beta, x\right) = log\left(\sum_{k=0}^{K-1} e^{margins_k - maxMargin}\right) -
+ * margins_{y} + maxMargin
+ * $$
+ *
+ *
+ * Note that each term, $(margins_k - maxMargin)$ in the exponential is no greater than zero; as a
+ * result, overflow will not happen with this formula.
+ *
+ * For $multiplier$, a similar trick can be applied as the following,
+ *
+ *
+ * $$
+ * multiplier_k = \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k - maxMargin}}{\sum_{k'=0}^{K-1}
+ * e^{\vec{x}_i \cdot \vec{\beta}_{k'} - maxMargin}} - I_{y=k}\right)
+ * $$
+ *
+ *
* @param bcCoefficients The broadcast coefficients corresponding to the features.
* @param bcFeaturesStd The broadcast standard deviation values of the features.
* @param numClasses the number of possible outcomes for k classes classification problem in
* Multinomial Logistic Regression.
* @param fitIntercept Whether to fit an intercept term.
+ * @param multinomial Whether to use multinomial (softmax) or binary loss
*/
private class LogisticAggregator(
- val bcCoefficients: Broadcast[Vector],
- val bcFeaturesStd: Broadcast[Array[Double]],
- private val numFeatures: Int,
+ bcCoefficients: Broadcast[Vector],
+ bcFeaturesStd: Broadcast[Array[Double]],
numClasses: Int,
- fitIntercept: Boolean) extends Serializable {
+ fitIntercept: Boolean,
+ multinomial: Boolean) extends Serializable with Logging {
+
+ private val numFeatures = bcFeaturesStd.value.length
+ private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures
+ private val coefficientSize = bcCoefficients.value.size
+ if (multinomial) {
+ require(numClasses == coefficientSize / numFeaturesPlusIntercept, s"The number of " +
+ s"coefficients should be ${numClasses * numFeaturesPlusIntercept} but was $coefficientSize")
+ } else {
+ require(coefficientSize == numFeaturesPlusIntercept, s"Expected $numFeaturesPlusIntercept " +
+ s"coefficients but got $coefficientSize")
+ require(numClasses == 1 || numClasses == 2, s"Binary logistic aggregator requires numClasses " +
+ s"in {1, 2} but found $numClasses.")
+ }
private var weightSum = 0.0
private var lossSum = 0.0
- private val gradientSumArray =
- Array.ofDim[Double](if (fitIntercept) numFeatures + 1 else numFeatures)
+ private val gradientSumArray = Array.ofDim[Double](coefficientSize)
+
+ if (multinomial && numClasses <= 2) {
+ logInfo(s"Multinomial logistic regression for binary classification yields separate " +
+ s"coefficients for positive and negative classes. When no regularization is applied, the" +
+ s"result will be effectively the same as binary logistic regression. When regularization" +
+ s"is applied, multinomial loss will produce a result different from binary loss.")
+ }
+
+ /** Update gradient and loss using binary loss function. */
+ private def binaryUpdateInPlace(
+ features: Vector,
+ weight: Double,
+ label: Double): Unit = {
+
+ val localFeaturesStd = bcFeaturesStd.value
+ val localCoefficients = bcCoefficients.value
+ val localGradientArray = gradientSumArray
+ val margin = - {
+ var sum = 0.0
+ features.foreachActive { (index, value) =>
+ if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+ sum += localCoefficients(index) * value / localFeaturesStd(index)
+ }
+ }
+ if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
+ sum
+ }
+
+ val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
+
+ features.foreachActive { (index, value) =>
+ if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+ localGradientArray(index) += multiplier * value / localFeaturesStd(index)
+ }
+ }
+
+ if (fitIntercept) {
+ localGradientArray(numFeaturesPlusIntercept - 1) += multiplier
+ }
+
+ if (label > 0) {
+ // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+ lossSum += weight * MLUtils.log1pExp(margin)
+ } else {
+ lossSum += weight * (MLUtils.log1pExp(margin) - margin)
+ }
+ }
+
+ /** Update gradient and loss using multinomial (softmax) loss function. */
+ private def multinomialUpdateInPlace(
+ features: Vector,
+ weight: Double,
+ label: Double): Unit = {
+ // TODO: use level 2 BLAS operations
+ /*
+ Note: this can still be used when numClasses = 2 for binary
+ logistic regression without pivoting.
+ */
+ val localFeaturesStd = bcFeaturesStd.value
+ val localCoefficients = bcCoefficients.value
+ val localGradientArray = gradientSumArray
+
+ // marginOfLabel is margins(label) in the formula
+ var marginOfLabel = 0.0
+ var maxMargin = Double.NegativeInfinity
+
+ val margins = Array.tabulate(numClasses) { i =>
+ var margin = 0.0
+ features.foreachActive { (index, value) =>
+ if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+ margin += localCoefficients(i * numFeaturesPlusIntercept + index) *
+ value / localFeaturesStd(index)
+ }
+ }
+
+ if (fitIntercept) {
+ margin += localCoefficients(i * numFeaturesPlusIntercept + numFeatures)
+ }
+ if (i == label.toInt) marginOfLabel = margin
+ if (margin > maxMargin) {
+ maxMargin = margin
+ }
+ margin
+ }
+
+ /**
+ * When maxMargin > 0, the original formula could cause overflow.
+ * We address this by subtracting maxMargin from all the margins, so it's guaranteed
+ * that all of the new margins will be smaller than zero to prevent arithmetic overflow.
+ */
+ val sum = {
+ var temp = 0.0
+ if (maxMargin > 0) {
+ for (i <- 0 until numClasses) {
+ margins(i) -= maxMargin
+ temp += math.exp(margins(i))
+ }
+ } else {
+ for (i <- 0 until numClasses) {
+ temp += math.exp(margins(i))
+ }
+ }
+ temp
+ }
+
+ for (i <- 0 until numClasses) {
+ val multiplier = math.exp(margins(i)) / sum - {
+ if (label == i) 1.0 else 0.0
+ }
+ features.foreachActive { (index, value) =>
+ if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+ localGradientArray(i * numFeaturesPlusIntercept + index) +=
+ weight * multiplier * value / localFeaturesStd(index)
+ }
+ }
+ if (fitIntercept) {
+ localGradientArray(i * numFeaturesPlusIntercept + numFeatures) += weight * multiplier
+ }
+ }
+
+ val loss = if (maxMargin > 0) {
+ math.log(sum) - marginOfLabel + maxMargin
+ } else {
+ math.log(sum) - marginOfLabel
+ }
+ lossSum += weight * loss
+ }
/**
* Add a new training instance to this LogisticAggregator, and update the loss and gradient
@@ -975,52 +1265,10 @@ private class LogisticAggregator(
if (weight == 0.0) return this
- val coefficientsArray = bcCoefficients.value match {
- case dv: DenseVector => dv.values
- case _ =>
- throw new IllegalArgumentException(
- "coefficients only supports dense vector" +
- s"but got type ${bcCoefficients.value.getClass}.")
- }
- val localGradientSumArray = gradientSumArray
-
- val featuresStd = bcFeaturesStd.value
- numClasses match {
- case 2 =>
- // For Binary Logistic Regression.
- val margin = - {
- var sum = 0.0
- features.foreachActive { (index, value) =>
- if (featuresStd(index) != 0.0 && value != 0.0) {
- sum += coefficientsArray(index) * (value / featuresStd(index))
- }
- }
- sum + {
- if (fitIntercept) coefficientsArray(numFeatures) else 0.0
- }
- }
-
- val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
-
- features.foreachActive { (index, value) =>
- if (featuresStd(index) != 0.0 && value != 0.0) {
- localGradientSumArray(index) += multiplier * (value / featuresStd(index))
- }
- }
-
- if (fitIntercept) {
- localGradientSumArray(numFeatures) += multiplier
- }
-
- if (label > 0) {
- // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
- lossSum += weight * MLUtils.log1pExp(margin)
- } else {
- lossSum += weight * (MLUtils.log1pExp(margin) - margin)
- }
- case _ =>
- new NotImplementedError("LogisticRegression with ElasticNet in ML package " +
- "only supports binary classification for now.")
+ if (multinomial) {
+ multinomialUpdateInPlace(features, weight, label)
+ } else {
+ binaryUpdateInPlace(features, weight, label)
}
weightSum += weight
this
@@ -1071,8 +1319,8 @@ private class LogisticAggregator(
}
/**
- * LogisticCostFun implements Breeze's DiffFunction[T] for a multinomial logistic loss function,
- * as used in multi-class classification (it is also used in binary logistic regression).
+ * LogisticCostFun implements Breeze's DiffFunction[T] for a multinomial (softmax) logistic loss
+ * function, as used in multi-class classification (it is also used in binary logistic regression).
* It returns the loss and gradient with L2 regularization at a particular point (coefficients).
* It's used in Breeze's convex optimization routines.
*/
@@ -1082,36 +1330,36 @@ private class LogisticCostFun(
fitIntercept: Boolean,
standardization: Boolean,
bcFeaturesStd: Broadcast[Array[Double]],
- regParamL2: Double) extends DiffFunction[BDV[Double]] {
+ regParamL2: Double,
+ multinomial: Boolean) extends DiffFunction[BDV[Double]] {
- val featuresStd = bcFeaturesStd.value
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
- val numFeatures = featuresStd.length
val coeffs = Vectors.fromBreeze(coefficients)
val bcCoeffs = instances.context.broadcast(coeffs)
- val n = coeffs.size
+ val featuresStd = bcFeaturesStd.value
+ val numFeatures = featuresStd.length
val logisticAggregator = {
val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
instances.treeAggregate(
- new LogisticAggregator(bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept)
+ new LogisticAggregator(bcCoeffs, bcFeaturesStd, numClasses, fitIntercept,
+ multinomial)
)(seqOp, combOp)
}
val totalGradientArray = logisticAggregator.gradient.toArray
-
// regVal is the sum of coefficients squares excluding intercept for L2 regularization.
val regVal = if (regParamL2 == 0.0) {
0.0
} else {
var sum = 0.0
- coeffs.foreachActive { (index, value) =>
- // If `fitIntercept` is true, the last term which is intercept doesn't
- // contribute to the regularization.
- if (index != numFeatures) {
+ coeffs.foreachActive { case (index, value) =>
+ // We do not apply regularization to the intercepts
+ val isIntercept = fitIntercept && ((index + 1) % (numFeatures + 1) == 0)
+ if (!isIntercept) {
// The following code will compute the loss of the regularization; also
// the gradient of the regularization, and add back to totalGradientArray.
sum += {
@@ -1119,13 +1367,18 @@ private class LogisticCostFun(
totalGradientArray(index) += regParamL2 * value
value * value
} else {
- if (featuresStd(index) != 0.0) {
+ val featureIndex = if (fitIntercept) {
+ index % (numFeatures + 1)
+ } else {
+ index % numFeatures
+ }
+ if (featuresStd(featureIndex) != 0.0) {
// If `standardization` is false, we still standardize the data
// to improve the rate of convergence; as a result, we have to
// perform this reverse standardization by penalizing each component
// differently to get effectively the same objective function when
// the training dataset is not standardized.
- val temp = value / (featuresStd(index) * featuresStd(index))
+ val temp = value / (featuresStd(featureIndex) * featuresStd(featureIndex))
totalGradientArray(index) += regParamL2 * temp
value * temp
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
new file mode 100644
index 0000000000000..dfadd68c5f476
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
@@ -0,0 +1,620 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import scala.collection.mutable
+
+import breeze.linalg.{DenseVector => BDV}
+import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for multinomial logistic (softmax) regression.
+ */
+private[classification] trait MultinomialLogisticRegressionParams
+ extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter
+ with HasFitIntercept with HasTol with HasStandardization with HasWeightCol {
+
+ /**
+ * Set thresholds in multiclass (or binary) classification to adjust the probability of
+ * predicting each class. Array must have length equal to the number of classes, with values >= 0.
+ * The class with largest value p/t is predicted, where p is the original probability of that
+ * class and t is the class' threshold.
+ *
+ * @group setParam
+ */
+ def setThresholds(value: Array[Double]): this.type = {
+ set(thresholds, value)
+ }
+
+ /**
+ * Get thresholds for binary or multiclass classification.
+ *
+ * @group getParam
+ */
+ override def getThresholds: Array[Double] = {
+ $(thresholds)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Multinomial Logistic (softmax) regression.
+ */
+@Since("2.1.0")
+@Experimental
+class MultinomialLogisticRegression @Since("2.1.0") (
+ @Since("2.1.0") override val uid: String)
+ extends ProbabilisticClassifier[Vector,
+ MultinomialLogisticRegression, MultinomialLogisticRegressionModel]
+ with MultinomialLogisticRegressionParams with DefaultParamsWritable with Logging {
+
+ @Since("2.1.0")
+ def this() = this(Identifiable.randomUID("mlogreg"))
+
+ /**
+ * Set the regularization parameter.
+ * Default is 0.0.
+ *
+ * @group setParam
+ */
+ @Since("2.1.0")
+ def setRegParam(value: Double): this.type = set(regParam, value)
+ setDefault(regParam -> 0.0)
+
+ /**
+ * Set the ElasticNet mixing parameter.
+ * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
+ * For 0 < alpha < 1, the penalty is a combination of L1 and L2.
+ * Default is 0.0 which is an L2 penalty.
+ *
+ * @group setParam
+ */
+ @Since("2.1.0")
+ def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
+ setDefault(elasticNetParam -> 0.0)
+
+ /**
+ * Set the maximum number of iterations.
+ * Default is 100.
+ *
+ * @group setParam
+ */
+ @Since("2.1.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+ setDefault(maxIter -> 100)
+
+ /**
+ * Set the convergence tolerance of iterations.
+ * Smaller value will lead to higher accuracy with the cost of more iterations.
+ * Default is 1E-6.
+ *
+ * @group setParam
+ */
+ @Since("2.1.0")
+ def setTol(value: Double): this.type = set(tol, value)
+ setDefault(tol -> 1E-6)
+
+ /**
+ * Whether to fit an intercept term.
+ * Default is true.
+ *
+ * @group setParam
+ */
+ @Since("2.1.0")
+ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+ setDefault(fitIntercept -> true)
+
+ /**
+ * Whether to standardize the training features before fitting the model.
+   * The coefficients of the model will always be returned on the original scale,
+ * so it will be transparent for users. Note that with/without standardization,
+ * the models should always converge to the same solution when no regularization
+ * is applied. In R's GLMNET package, the default behavior is true as well.
+ * Default is true.
+ *
+ * @group setParam
+ */
+ @Since("2.1.0")
+ def setStandardization(value: Boolean): this.type = set(standardization, value)
+ setDefault(standardization -> true)
+
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is not set, so all instances have weight one.
+ *
+ * @group setParam
+ */
+ @Since("2.1.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ @Since("2.1.0")
+ override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
+
+ override protected[spark] def train(dataset: Dataset[_]): MultinomialLogisticRegressionModel = {
+ val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+ val instances: RDD[Instance] =
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
+ }
+
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val instr = Instrumentation.create(this, instances)
+ instr.logParams(regParam, elasticNetParam, standardization, thresholds,
+ maxIter, tol, fitIntercept)
+
+ val (summarizer, labelSummarizer) = {
+ val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
+ instance: Instance) =>
+ (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
+
+ val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
+ c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
+ (c1._1.merge(c2._1), c1._2.merge(c2._2))
+
+ instances.treeAggregate(
+ new MultivariateOnlineSummarizer, new MultiClassSummarizer)(seqOp, combOp)
+ }
+
+ val histogram = labelSummarizer.histogram
+ val numInvalid = labelSummarizer.countInvalid
+ val numFeatures = summarizer.mean.size
+ val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else numFeatures
+
+ val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
+ case Some(n: Int) =>
+ require(n >= histogram.length, s"Specified number of classes $n was " +
+ s"less than the number of unique labels ${histogram.length}")
+ n
+ case None => histogram.length
+ }
+
+ instr.logNumClasses(numClasses)
+ instr.logNumFeatures(numFeatures)
+
+ val (coefficients, intercepts, objectiveHistory) = {
+ if (numInvalid != 0) {
+        val msg = s"Classification labels should be in {0 to ${numClasses - 1}}. " +
+          s"Found $numInvalid invalid labels."
+ logError(msg)
+ throw new SparkException(msg)
+ }
+
+ val isConstantLabel = histogram.count(_ != 0) == 1
+
+ if ($(fitIntercept) && isConstantLabel) {
+ // we want to produce a model that will always predict the constant label so all the
+ // coefficients will be zero, and the constant label class intercept will be +inf
+ val constantLabelIndex = Vectors.dense(histogram).argmax
+ (Matrices.sparse(numClasses, numFeatures, Array.fill(numFeatures + 1)(0),
+ Array.empty[Int], Array.empty[Double]),
+ Vectors.sparse(numClasses, Seq((constantLabelIndex, Double.PositiveInfinity))),
+ Array.empty[Double])
+ } else {
+ if (!$(fitIntercept) && isConstantLabel) {
+          logWarning(s"All labels belong to a single class and fitIntercept=false. This is " +
+            s"dangerous ground, so the algorithm may not converge.")
+ }
+
+ val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+ val featuresMean = summarizer.mean.toArray
+ if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
+ featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
+ logWarning("Fitting MultinomialLogisticRegressionModel without intercept on dataset " +
+ "with constant nonzero column, Spark MLlib outputs zero coefficients for constant " +
+ "nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.")
+ }
+
+ val regParamL1 = $(elasticNetParam) * $(regParam)
+ val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
+
+ val bcFeaturesStd = instances.context.broadcast(featuresStd)
+ val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
+ $(standardization), bcFeaturesStd, regParamL2, multinomial = true)
+
+ val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
+ new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
+ } else {
+ val standardizationParam = $(standardization)
+ def regParamL1Fun = (index: Int) => {
+ // Remove the L1 penalization on the intercept
+ val isIntercept = $(fitIntercept) && ((index + 1) % numFeaturesPlusIntercept == 0)
+ if (isIntercept) {
+ 0.0
+ } else {
+ if (standardizationParam) {
+ regParamL1
+ } else {
+ val featureIndex = if ($(fitIntercept)) {
+ index % numFeaturesPlusIntercept
+ } else {
+ index % numFeatures
+ }
+ // If `standardization` is false, we still standardize the data
+ // to improve the rate of convergence; as a result, we have to
+ // perform this reverse standardization by penalizing each component
+ // differently to get effectively the same objective function when
+ // the training dataset is not standardized.
+ if (featuresStd(featureIndex) != 0.0) {
+ regParamL1 / featuresStd(featureIndex)
+ } else {
+ 0.0
+ }
+ }
+ }
+ }
+ new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
+ }
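
For illustration, a minimal standalone sketch of the index arithmetic used by regParamL1Fun above: coefficients are laid out class-major with the intercept stored last in each row, intercepts are never L1-penalized, and when standardization is off the penalty is rescaled by each feature's standard deviation. The values below are hypothetical and the sketch is independent of Breeze:

    // Sketch of the per-index L1 penalty, assuming a hypothetical 3-feature problem.
    object L1PenaltySketch {
      def main(args: Array[String]): Unit = {
        val numFeatures = 3
        val numFeaturesPlusIntercept = numFeatures + 1
        val featuresStd = Array(2.0, 0.0, 5.0)   // hypothetical per-feature standard deviations
        val regParamL1 = 0.1

        def l1Penalty(index: Int, fitIntercept: Boolean, standardize: Boolean): Double = {
          val stride = if (fitIntercept) numFeaturesPlusIntercept else numFeatures
          val isIntercept = fitIntercept && (index + 1) % stride == 0
          if (isIntercept) 0.0                   // intercepts are never L1-penalized
          else if (standardize) regParamL1       // penalty applied directly in the scaled space
          else {
            val featureIndex = index % stride
            // reverse standardization: constant features (std == 0) get no penalty
            if (featuresStd(featureIndex) != 0.0) regParamL1 / featuresStd(featureIndex) else 0.0
          }
        }

        // With fitIntercept = true, indices 3 and 7 are intercepts and get zero penalty.
        println((0 until 8).map(i => l1Penalty(i, fitIntercept = true, standardize = false)))
        // Vector(0.05, 0.0, 0.02, 0.0, 0.05, 0.0, 0.02, 0.0)
      }
    }
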
+
+ val initialCoefficientsWithIntercept = Vectors.zeros(numClasses * numFeaturesPlusIntercept)
+
+ if ($(fitIntercept)) {
+ /*
+ For multinomial logistic regression, when we initialize the coefficients as zeros,
+ it will converge faster if we initialize the intercepts such that
+             they follow the distribution of the labels.
+ {{{
+ P(1) = \exp(b_1) / Z
+ ...
+ P(K) = \exp(b_K) / Z
+ where Z = \sum_{k=1}^{K} \exp(b_k)
+ }}}
+ Since this doesn't have a unique solution, one of the solutions that satisfies the
+ above equations is
+ {{{
+ \exp(b_k) = count_k * \exp(\lambda)
+             b_k = \log(count_k) + \lambda
+ }}}
+             \lambda is a free parameter, so we choose \lambda such that the resulting
+             intercepts have zero mean. This yields
+ {{{
+ b_k = \log(count_k)
+ b_k' = b_k - \mean(b_k)
+ }}}
+ */
+ val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing
+ val rawMean = rawIntercepts.sum / rawIntercepts.length
+ rawIntercepts.indices.foreach { i =>
+ initialCoefficientsWithIntercept.toArray(i * numFeaturesPlusIntercept + numFeatures) =
+ rawIntercepts(i) - rawMean
+ }
+ }
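
The centered log-count initialization above can be checked in isolation: with all coefficients at zero, intercepts equal to the log of the smoothed class counts reproduce the empirical label distribution, and subtracting their mean removes the free constant \lambda. A small sketch with a hypothetical histogram:

    object InterceptInitSketch {
      def main(args: Array[String]): Unit = {
        val histogram = Array(10.0, 30.0, 60.0)            // hypothetical weighted label counts
        val raw = histogram.map(c => math.log(c + 1))       // add 1 for smoothing
        val centered = raw.map(_ - raw.sum / raw.length)    // remove the free constant \lambda
        // Softmax of the centered intercepts recovers (approximately) the label frequencies.
        val expSum = centered.map(math.exp).sum
        val probs = centered.map(b => math.exp(b) / expSum)
        println(probs.mkString(", "))                       // ~ 0.11, 0.30, 0.59
      }
    }
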
+
+ val states = optimizer.iterations(new CachedDiffFunction(costFun),
+ initialCoefficientsWithIntercept.asBreeze.toDenseVector)
+
+ /*
+             Note that in multinomial logistic regression, the objective history
+             (loss + regularization) is based on the log-likelihood, which is invariant under
+             feature standardization. As a result, the objective history from the optimizer
+             is the same as the one in the original space.
+ */
+ val arrayBuilder = mutable.ArrayBuilder.make[Double]
+ var state: optimizer.State = null
+ while (states.hasNext) {
+ state = states.next()
+ arrayBuilder += state.adjustedValue
+ }
+
+ if (state == null) {
+ val msg = s"${optimizer.getClass.getName} failed."
+ logError(msg)
+ throw new SparkException(msg)
+ }
+ bcFeaturesStd.destroy(blocking = false)
+
+ /*
+ The coefficients are trained in the scaled space; we're converting them back to
+ the original space.
+          Note that the intercepts in the scaled space and the original space are the same;
+ as a result, no scaling is needed.
+ */
+ val rawCoefficients = state.x.toArray
+ val interceptsArray: Array[Double] = if ($(fitIntercept)) {
+ Array.tabulate(numClasses) { i =>
+ val coefIndex = (i + 1) * numFeaturesPlusIntercept - 1
+ rawCoefficients(coefIndex)
+ }
+ } else {
+ Array[Double]()
+ }
+
+ val coefficientArray: Array[Double] = Array.tabulate(numClasses * numFeatures) { i =>
+          // flatIndex will loop through rawCoefficients, and skip the intercept terms.
+ val flatIndex = if ($(fitIntercept)) i + i / numFeatures else i
+ val featureIndex = i % numFeatures
+ if (featuresStd(featureIndex) != 0.0) {
+ rawCoefficients(flatIndex) / featuresStd(featureIndex)
+ } else {
+ 0.0
+ }
+ }
+ val coefficientMatrix =
+ new DenseMatrix(numClasses, numFeatures, coefficientArray, isTransposed = true)
+
+ /*
+ When no regularization is applied, the coefficients lack identifiability because
+ we do not use a pivot class. We can add any constant value to the coefficients and
+ get the same likelihood. So here, we choose the mean centered coefficients for
+ reproducibility. This method follows the approach in glmnet, described here:
+
+ Friedman, et al. "Regularization Paths for Generalized Linear Models via
+ Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf
+ */
+ if ($(regParam) == 0.0) {
+ val coefficientMean = coefficientMatrix.values.sum / (numClasses * numFeatures)
+ coefficientMatrix.update(_ - coefficientMean)
+ }
+ /*
+ The intercepts are never regularized, so we always center the mean.
+ */
+ val interceptVector = if (interceptsArray.nonEmpty) {
+ val interceptMean = interceptsArray.sum / numClasses
+ interceptsArray.indices.foreach { i => interceptsArray(i) -= interceptMean }
+ Vectors.dense(interceptsArray)
+ } else {
+ Vectors.sparse(numClasses, Seq())
+ }
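
The mean-centering above relies on the fact that adding the same constant to every class margin leaves the softmax probabilities unchanged, which is exactly why the unregularized coefficients are only identified up to a constant shift. A quick numerical check of that invariance, using hypothetical margins:

    object SoftmaxShiftInvariance {
      def softmax(m: Array[Double]): Array[Double] = {
        val e = m.map(math.exp)
        val s = e.sum
        e.map(_ / s)
      }
      def main(args: Array[String]): Unit = {
        val margins = Array(1.0, 2.0, 3.0)
        val shifted = margins.map(_ + 7.5)   // add any constant to every class margin
        val same = softmax(margins).zip(softmax(shifted))
          .forall { case (a, b) => math.abs(a - b) < 1e-12 }
        println(same)                        // true: the probabilities are unchanged
      }
    }
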
+
+ (coefficientMatrix, interceptVector, arrayBuilder.result())
+ }
+ }
+
+ if (handlePersistence) instances.unpersist()
+
+ val model = copyValues(
+ new MultinomialLogisticRegressionModel(uid, coefficients, intercepts, numClasses))
+ instr.logSuccess(model)
+ model
+ }
+
+ @Since("2.1.0")
+ override def copy(extra: ParamMap): MultinomialLogisticRegression = defaultCopy(extra)
+}
+
+@Since("2.1.0")
+object MultinomialLogisticRegression extends DefaultParamsReadable[MultinomialLogisticRegression] {
+
+ @Since("2.1.0")
+ override def load(path: String): MultinomialLogisticRegression = super.load(path)
+}
+
+/**
+ * :: Experimental ::
+ * Model produced by [[MultinomialLogisticRegression]].
+ */
+@Since("2.1.0")
+@Experimental
+class MultinomialLogisticRegressionModel private[spark] (
+ @Since("2.1.0") override val uid: String,
+ @Since("2.1.0") val coefficients: Matrix,
+ @Since("2.1.0") val intercepts: Vector,
+ @Since("2.1.0") val numClasses: Int)
+ extends ProbabilisticClassificationModel[Vector, MultinomialLogisticRegressionModel]
+ with MultinomialLogisticRegressionParams with MLWritable {
+
+ @Since("2.1.0")
+ override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
+
+ @Since("2.1.0")
+ override def getThresholds: Array[Double] = super.getThresholds
+
+ @Since("2.1.0")
+ override val numFeatures: Int = coefficients.numCols
+
+ /** Margin (rawPrediction) for each class label. */
+ private val margins: Vector => Vector = (features) => {
+ val m = intercepts.toDense.copy
+ BLAS.gemv(1.0, coefficients, features, 1.0, m)
+ m
+ }
+
+ /** Score (probability) for each class label. */
+ private val scores: Vector => Vector = (features) => {
+ val m = margins(features)
+ val maxMarginIndex = m.argmax
+ val marginArray = m.toArray
+ val maxMargin = marginArray(maxMarginIndex)
+
+ // adjust margins for overflow
+ val sum = {
+ var temp = 0.0
+ var k = 0
+ while (k < numClasses) {
+ marginArray(k) = if (maxMargin > 0) {
+ math.exp(marginArray(k) - maxMargin)
+ } else {
+ math.exp(marginArray(k))
+ }
+ temp += marginArray(k)
+ k += 1
+ }
+ temp
+ }
+
+ val scores = Vectors.dense(marginArray)
+ BLAS.scal(1 / sum, scores)
+ scores
+ }
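
The max-margin subtraction in `scores` is the usual log-sum-exp stabilization: exponentiating large raw margins directly would overflow to infinity, while subtracting the maximum yields the same probabilities. A minimal illustration with hypothetical margins (matching the overflow case exercised in the test suite):

    object StableSoftmaxSketch {
      def main(args: Array[String]): Unit = {
        val margins = Array(1000.0, 2000.0, 3000.0)
        // Naive: math.exp(3000.0) is Infinity, so the probabilities would become NaN.
        val max = margins.max
        val shifted = margins.map(m => math.exp(m - max))   // exp(-2000), exp(-1000), exp(0)
        val sum = shifted.sum
        println(shifted.map(_ / sum).mkString(", "))        // 0.0, 0.0, 1.0
      }
    }
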
+
+ /**
+ * Predict label for the given feature vector.
+ * The behavior of this can be adjusted using [[thresholds]].
+ */
+ override protected def predict(features: Vector): Double = {
+ if (isDefined(thresholds)) {
+ val thresholds: Array[Double] = getThresholds
+ val probabilities = scores(features).toArray
+ var argMax = 0
+ var max = Double.NegativeInfinity
+ var i = 0
+ while (i < numClasses) {
+ if (thresholds(i) == 0.0) {
+ max = Double.PositiveInfinity
+ argMax = i
+ } else {
+ val scaled = probabilities(i) / thresholds(i)
+ if (scaled > max) {
+ max = scaled
+ argMax = i
+ }
+ }
+ i += 1
+ }
+ argMax
+ } else {
+ scores(features).argmax
+ }
+ }
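
`predict` scales each class probability by the inverse of its threshold and takes the argmax; a zero threshold forces that class to win. A brief sketch of the p/t rule with hypothetical values:

    object ThresholdPredictSketch {
      def main(args: Array[String]): Unit = {
        val probabilities = Array(0.2, 0.5, 0.3)
        val thresholds = Array(0.1, 1.0, 0.5)              // hypothetical per-class thresholds
        val scaled = probabilities.zip(thresholds).map { case (p, t) => p / t }
        println(scaled.indexOf(scaled.max))                // 0: class 0 wins because 0.2 / 0.1 = 2.0
      }
    }
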
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ val size = dv.size
+ val values = dv.values
+
+ // get the maximum margin
+ val maxMarginIndex = rawPrediction.argmax
+ val maxMargin = rawPrediction(maxMarginIndex)
+
+ if (maxMargin == Double.PositiveInfinity) {
+ var k = 0
+ while (k < size) {
+ values(k) = if (k == maxMarginIndex) 1.0 else 0.0
+ k += 1
+ }
+ } else {
+ val sum = {
+ var temp = 0.0
+ var k = 0
+ while (k < numClasses) {
+ values(k) = if (maxMargin > 0) {
+ math.exp(values(k) - maxMargin)
+ } else {
+ math.exp(values(k))
+ }
+ temp += values(k)
+ k += 1
+ }
+ temp
+ }
+ BLAS.scal(1 / sum, dv)
+ }
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in MultinomialLogisticRegressionModel:" +
+          " raw2probabilityInPlace encountered SparseVector")
+ }
+ }
+
+ override protected def predictRaw(features: Vector): Vector = margins(features)
+
+ @Since("2.1.0")
+ override def copy(extra: ParamMap): MultinomialLogisticRegressionModel = {
+ val newModel =
+ copyValues(
+ new MultinomialLogisticRegressionModel(uid, coefficients, intercepts, numClasses), extra)
+ newModel.setParent(parent)
+ }
+
+ /**
+ * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
+ *
+ * This does not save the [[parent]] currently.
+ */
+ @Since("2.1.0")
+ override def write: MLWriter =
+ new MultinomialLogisticRegressionModel.MultinomialLogisticRegressionModelWriter(this)
+}
+
+
+@Since("2.1.0")
+object MultinomialLogisticRegressionModel extends MLReadable[MultinomialLogisticRegressionModel] {
+
+ @Since("2.1.0")
+ override def read: MLReader[MultinomialLogisticRegressionModel] =
+ new MultinomialLogisticRegressionModelReader
+
+ @Since("2.1.0")
+ override def load(path: String): MultinomialLogisticRegressionModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[MultinomialLogisticRegressionModel]] */
+ private[MultinomialLogisticRegressionModel]
+ class MultinomialLogisticRegressionModelWriter(instance: MultinomialLogisticRegressionModel)
+ extends MLWriter with Logging {
+
+ private case class Data(
+ numClasses: Int,
+ numFeatures: Int,
+ intercepts: Vector,
+ coefficients: Matrix)
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: numClasses, numFeatures, intercept, coefficients
+ val data = Data(instance.numClasses, instance.numFeatures, instance.intercepts,
+ instance.coefficients)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class MultinomialLogisticRegressionModelReader
+ extends MLReader[MultinomialLogisticRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[MultinomialLogisticRegressionModel].getName
+
+ override def load(path: String): MultinomialLogisticRegressionModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.format("parquet").load(dataPath)
+ .select("numClasses", "numFeatures", "intercepts", "coefficients").head()
+ val numClasses = data.getAs[Int](data.fieldIndex("numClasses"))
+ val intercepts = data.getAs[Vector](data.fieldIndex("intercepts"))
+ val coefficients = data.getAs[Matrix](data.fieldIndex("coefficients"))
+ val model =
+ new MultinomialLogisticRegressionModel(metadata.uid, coefficients, intercepts, numClasses)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+}
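
For context, the estimator added in this file follows the usual spark.ml fit/transform pattern. A hedged usage sketch on a tiny hypothetical dataset (local mode, three classes):

    import org.apache.spark.ml.classification.MultinomialLogisticRegression
    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.SparkSession

    object MultinomialLRUsageSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("mlr-sketch").getOrCreate()
        import spark.implicits._
        // Tiny hypothetical three-class dataset; "label" and "features" columns come from LabeledPoint.
        val training = Seq(
          LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
          LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
          LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
        ).toDF()

        val mlr = new MultinomialLogisticRegression()
          .setMaxIter(50)
          .setRegParam(0.1)
          .setElasticNetParam(0.5)
        val model = mlr.fit(training)
        model.transform(training).select("features", "probability", "prediction").show()
        spark.stop()
      }
    }
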
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala
new file mode 100644
index 0000000000000..0913fe559c562
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala
@@ -0,0 +1,1056 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import scala.language.existentials
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.classification.LogisticRegressionSuite._
+import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+
+class MultinomialLogisticRegressionSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+ @transient var dataset: Dataset[_] = _
+ @transient var multinomialDataset: DataFrame = _
+ private val eps: Double = 1e-5
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ dataset = {
+ val nPoints = 100
+ val coefficients = Array(
+ -0.57997, 0.912083, -0.371077,
+ -0.16624, -0.84355, -0.048509)
+
+ val xMean = Array(5.843, 3.057)
+ val xVariance = Array(0.6856, 0.1899)
+
+ val testData = generateMultinomialLogisticInput(
+ coefficients, xMean, xVariance, addIntercept = true, nPoints, 42)
+
+ val df = spark.createDataFrame(sc.parallelize(testData, 4))
+ df.cache()
+ df
+ }
+
+ multinomialDataset = {
+ val nPoints = 10000
+ val coefficients = Array(
+ -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+ -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
+
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+
+ val testData = generateMultinomialLogisticInput(
+ coefficients, xMean, xVariance, addIntercept = true, nPoints, 42)
+
+ val df = spark.createDataFrame(sc.parallelize(testData, 4))
+ df.cache()
+ df
+ }
+ }
+
+ /**
+ * Enable the ignored test to export the dataset into CSV format,
+   * so we can validate the training accuracy against R's glmnet package.
+ */
+ ignore("export test data into CSV format") {
+ val rdd = multinomialDataset.rdd.map { case Row(label: Double, features: Vector) =>
+ label + "," + features.toArray.mkString(",")
+ }.repartition(1)
+ rdd.saveAsTextFile("target/tmp/MultinomialLogisticRegressionSuite/multinomialDataset")
+ }
+
+ test("params") {
+ ParamsSuite.checkParams(new MultinomialLogisticRegression)
+ val model = new MultinomialLogisticRegressionModel("mLogReg",
+ Matrices.dense(2, 1, Array(0.0, 0.0)), Vectors.dense(0.0, 0.0), 2)
+ ParamsSuite.checkParams(model)
+ }
+
+ test("multinomial logistic regression: default params") {
+ val mlr = new MultinomialLogisticRegression
+ assert(mlr.getLabelCol === "label")
+ assert(mlr.getFeaturesCol === "features")
+ assert(mlr.getPredictionCol === "prediction")
+ assert(mlr.getRawPredictionCol === "rawPrediction")
+ assert(mlr.getProbabilityCol === "probability")
+ assert(!mlr.isDefined(mlr.weightCol))
+ assert(!mlr.isDefined(mlr.thresholds))
+ assert(mlr.getFitIntercept)
+ assert(mlr.getStandardization)
+ val model = mlr.fit(dataset)
+ model.transform(dataset)
+ .select("label", "probability", "prediction", "rawPrediction")
+ .collect()
+ assert(model.getFeaturesCol === "features")
+ assert(model.getPredictionCol === "prediction")
+ assert(model.getRawPredictionCol === "rawPrediction")
+ assert(model.getProbabilityCol === "probability")
+ assert(model.intercepts !== Vectors.dense(0.0, 0.0))
+ assert(model.hasParent)
+ }
+
+ test("multinomial logistic regression with intercept without regularization") {
+
+ val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setMaxIter(100)
+ val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false)
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+
+ /*
+      Use the following R code to load the data and train the model using glmnet package.
+ > library("glmnet")
+ > data <- read.csv("path", header=FALSE)
+ > label = as.factor(data$V1)
+ > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ > coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, lambda = 0))
+ > coefficients
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -2.24493379
+ V2 0.25096771
+ V3 -0.03915938
+ V4 0.14766639
+ V5 0.36810817
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 0.3778931
+ V2 -0.3327489
+ V3 0.8893666
+ V4 -0.2306948
+ V5 -0.4442330
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 1.86704066
+ V2 0.08178121
+ V3 -0.85020722
+ V4 0.08302840
+ V5 0.07612480
+ */
+
+ val coefficientsR = new DenseMatrix(3, 4, Array(
+ 0.2509677, -0.0391594, 0.1476664, 0.3681082,
+ -0.3327489, 0.8893666, -0.2306948, -0.4442330,
+ 0.0817812, -0.8502072, 0.0830284, 0.0761248), isTransposed = true)
+ val interceptsR = Vectors.dense(-2.2449338, 0.3778931, 1.8670407)
+
+ assert(model1.coefficients ~== coefficientsR relTol 0.05)
+ assert(model1.coefficients.toArray.sum ~== 0.0 absTol eps)
+ assert(model1.intercepts ~== interceptsR relTol 0.05)
+ assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
+ assert(model2.coefficients ~== coefficientsR relTol 0.05)
+ assert(model2.coefficients.toArray.sum ~== 0.0 absTol eps)
+ assert(model2.intercepts ~== interceptsR relTol 0.05)
+ assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
+ }
+
+ test("multinomial logistic regression without intercept without regularization") {
+
+ val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true)
+ val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false)
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+
+ /*
+      Use the following R code to load the data and train the model using glmnet package.
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = as.factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, lambda = 0,
+ intercept=F))
+ > coefficients
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 0.06992464
+ V3 -0.36562784
+ V4 0.12142680
+ V5 0.32052211
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 -0.3036269
+ V3 0.9449630
+ V4 -0.2271038
+ V5 -0.4364839
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 0.2337022
+ V3 -0.5793351
+ V4 0.1056770
+ V5 0.1159618
+ */
+
+ val coefficientsR = new DenseMatrix(3, 4, Array(
+ 0.0699246, -0.3656278, 0.1214268, 0.3205221,
+ -0.3036269, 0.9449630, -0.2271038, -0.4364839,
+ 0.2337022, -0.5793351, 0.1056770, 0.1159618), isTransposed = true)
+
+ assert(model1.coefficients ~== coefficientsR relTol 0.05)
+ assert(model1.coefficients.toArray.sum ~== 0.0 absTol eps)
+ assert(model1.intercepts.toArray === Array.fill(3)(0.0))
+ assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
+ assert(model2.coefficients ~== coefficientsR relTol 0.05)
+ assert(model2.coefficients.toArray.sum ~== 0.0 absTol eps)
+ assert(model2.intercepts.toArray === Array.fill(3)(0.0))
+ assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
+ }
+
+ test("multinomial logistic regression with intercept with L1 regularization") {
+
+ // use tighter constraints because OWL-QN solver takes longer to converge
+ val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true)
+ .setMaxIter(300).setTol(1e-10)
+ val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(false)
+ .setMaxIter(300).setTol(1e-10)
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+
+ /*
+ Use the following R code to load the data and train the model using glmnet package.
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = as.factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 1,
+      lambda = 0.05, standardize=T))
+ coefficients = coef(glmnet(features, label, family="multinomial", alpha = 1, lambda = 0.05,
+      standardize=F))
+ > coefficientsStd
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -0.68988825
+ V2 .
+ V3 .
+ V4 .
+ V5 0.09404023
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -0.2303499
+ V2 -0.1232443
+ V3 0.3258380
+ V4 -0.1564688
+ V5 -0.2053965
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 0.9202381
+ V2 .
+ V3 -0.4803856
+ V4 .
+ V5 .
+
+ > coefficients
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -0.44893320
+ V2 .
+ V3 .
+ V4 0.01933812
+ V5 0.03666044
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 0.7376760
+ V2 -0.0577182
+ V3 .
+ V4 -0.2081718
+ V5 -0.1304592
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -0.2887428
+ V2 .
+ V3 .
+ V4 .
+ V5 .
+ */
+
+ val coefficientsRStd = new DenseMatrix(3, 4, Array(
+ 0.0, 0.0, 0.0, 0.09404023,
+ -0.1232443, 0.3258380, -0.1564688, -0.2053965,
+ 0.0, -0.4803856, 0.0, 0.0), isTransposed = true)
+ val interceptsRStd = Vectors.dense(-0.68988825, -0.2303499, 0.9202381)
+
+ val coefficientsR = new DenseMatrix(3, 4, Array(
+ 0.0, 0.0, 0.01933812, 0.03666044,
+ -0.0577182, 0.0, -0.2081718, -0.1304592,
+ 0.0, 0.0, 0.0, 0.0), isTransposed = true)
+ val interceptsR = Vectors.dense(-0.44893320, 0.7376760, -0.2887428)
+
+ assert(model1.coefficients ~== coefficientsRStd absTol 0.02)
+ assert(model1.intercepts ~== interceptsRStd relTol 0.1)
+ assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
+ assert(model2.coefficients ~== coefficientsR absTol 0.02)
+ assert(model2.intercepts ~== interceptsR relTol 0.1)
+ assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
+ }
+
+ test("multinomial logistic regression without intercept with L1 regularization") {
+ val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true)
+ val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(false)
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+ /*
+ Use the following R code to load the data and train the model using glmnet package.
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = as.factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 1,
+      lambda = 0.05, intercept=F, standardize=T))
+ coefficients = coef(glmnet(features, label, family="multinomial", alpha = 1, lambda = 0.05,
+      intercept=F, standardize=F))
+ > coefficientsStd
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 .
+ V3 .
+ V4 .
+ V5 0.01525105
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 -0.1502410
+ V3 0.5134658
+ V4 -0.1601146
+ V5 -0.2500232
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 0.003301875
+ V3 .
+ V4 .
+ V5 .
+
+ > coefficients
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 .
+ V3 .
+ V4 .
+ V5 .
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 .
+ V3 0.1943624
+ V4 -0.1902577
+ V5 -0.1028789
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 .
+ V3 .
+ V4 .
+ V5 .
+ */
+
+ val coefficientsRStd = new DenseMatrix(3, 4, Array(
+ 0.0, 0.0, 0.0, 0.01525105,
+ -0.1502410, 0.5134658, -0.1601146, -0.2500232,
+ 0.003301875, 0.0, 0.0, 0.0), isTransposed = true)
+
+ val coefficientsR = new DenseMatrix(3, 4, Array(
+ 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.1943624, -0.1902577, -0.1028789,
+ 0.0, 0.0, 0.0, 0.0), isTransposed = true)
+
+ assert(model1.coefficients ~== coefficientsRStd absTol 0.01)
+ assert(model1.intercepts.toArray === Array.fill(3)(0.0))
+ assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
+ assert(model2.coefficients ~== coefficientsR absTol 0.01)
+ assert(model2.intercepts.toArray === Array.fill(3)(0.0))
+ assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
+ }
+
+ test("multinomial logistic regression with intercept with L2 regularization") {
+ val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true)
+ val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false)
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+ /*
+ Use the following R code to load the data and train the model using glmnet package.
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = as.factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0,
+      lambda = 0.1, intercept=T, standardize=T))
+ coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0,
+      lambda = 0.1, intercept=T, standardize=F))
+ > coefficientsStd
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -1.70040424
+ V2 0.17576070
+ V3 0.01527894
+ V4 0.10216108
+ V5 0.26099531
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 0.2438590
+ V2 -0.2238875
+ V3 0.5967610
+ V4 -0.1555496
+ V5 -0.3010479
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 1.45654525
+ V2 0.04812679
+ V3 -0.61203992
+ V4 0.05338850
+ V5 0.04005258
+
+ > coefficients
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -1.65488543
+ V2 0.15715048
+ V3 0.01992903
+ V4 0.12428858
+ V5 0.22130317
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 1.1297533
+ V2 -0.1974768
+ V3 0.2776373
+ V4 -0.1869445
+ V5 -0.2510320
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 0.52513212
+ V2 0.04032627
+ V3 -0.29756637
+ V4 0.06265594
+ V5 0.02972883
+ */
+
+ val coefficientsRStd = new DenseMatrix(3, 4, Array(
+ 0.17576070, 0.01527894, 0.10216108, 0.26099531,
+ -0.2238875, 0.5967610, -0.1555496, -0.3010479,
+ 0.04812679, -0.61203992, 0.05338850, 0.04005258), isTransposed = true)
+ val interceptsRStd = Vectors.dense(-1.70040424, 0.2438590, 1.45654525)
+
+ val coefficientsR = new DenseMatrix(3, 4, Array(
+ 0.15715048, 0.01992903, 0.12428858, 0.22130317,
+ -0.1974768, 0.2776373, -0.1869445, -0.2510320,
+ 0.04032627, -0.29756637, 0.06265594, 0.02972883), isTransposed = true)
+ val interceptsR = Vectors.dense(-1.65488543, 1.1297533, 0.52513212)
+
+ assert(model1.coefficients ~== coefficientsRStd relTol 0.05)
+ assert(model1.intercepts ~== interceptsRStd relTol 0.05)
+ assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
+ assert(model2.coefficients ~== coefficientsR relTol 0.05)
+ assert(model2.intercepts ~== interceptsR relTol 0.05)
+ assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
+ }
+
+ test("multinomial logistic regression without intercept with L2 regularization") {
+ val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true)
+ val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false)
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+ /*
+ Use the following R code to load the data and train the model using glmnet package.
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = as.factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0,
+      lambda = 0.1, intercept=F, standardize=T))
+ coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0,
+      lambda = 0.1, intercept=F, standardize=F))
+ > coefficientsStd
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 0.03904171
+ V3 -0.23354322
+ V4 0.08288096
+ V5 0.22706393
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 -0.2061848
+ V3 0.6341398
+ V4 -0.1530059
+ V5 -0.2958455
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 0.16714312
+ V3 -0.40059658
+ V4 0.07012496
+ V5 0.06878158
+ > coefficients
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 -0.005704542
+ V3 -0.144466409
+ V4 0.092080736
+ V5 0.182927657
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 -0.08469036
+ V3 0.38996748
+ V4 -0.16468436
+ V5 -0.22522976
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 0.09039490
+ V3 -0.24550107
+ V4 0.07260362
+ V5 0.04230210
+ */
+ val coefficientsRStd = new DenseMatrix(3, 4, Array(
+ 0.03904171, -0.23354322, 0.08288096, 0.2270639,
+ -0.2061848, 0.6341398, -0.1530059, -0.2958455,
+ 0.16714312, -0.40059658, 0.07012496, 0.06878158), isTransposed = true)
+
+ val coefficientsR = new DenseMatrix(3, 4, Array(
+ -0.005704542, -0.144466409, 0.092080736, 0.182927657,
+ -0.08469036, 0.38996748, -0.16468436, -0.22522976,
+ 0.0903949, -0.24550107, 0.07260362, 0.0423021), isTransposed = true)
+
+ assert(model1.coefficients ~== coefficientsRStd absTol 0.01)
+ assert(model1.intercepts.toArray === Array.fill(3)(0.0))
+ assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
+ assert(model2.coefficients ~== coefficientsR absTol 0.01)
+ assert(model2.intercepts.toArray === Array.fill(3)(0.0))
+ assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
+ }
+
+ test("multinomial logistic regression with intercept with elasticnet regularization") {
+ val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true)
+ .setMaxIter(300).setTol(1e-10)
+ val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(true)
+ .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(false)
+ .setMaxIter(300).setTol(1e-10)
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+ /*
+ Use the following R code to load the data and train the model using glmnet package.
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = as.factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0.5,
+      lambda = 0.1, intercept=T, standardize=T))
+ coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0.5,
+      lambda = 0.1, intercept=T, standardize=F))
+ > coefficientsStd
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -0.5521819483
+ V2 0.0003092611
+ V3 .
+ V4 .
+ V5 0.0913818490
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -0.27531989
+ V2 -0.09790029
+ V3 0.28502034
+ V4 -0.12416487
+ V5 -0.16513373
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 0.8275018
+ V2 .
+ V3 -0.4044859
+ V4 .
+ V5 .
+
+ > coefficients
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -0.39876213
+ V2 .
+ V3 .
+ V4 0.02547520
+ V5 0.03893991
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 0.61089869
+ V2 -0.04224269
+ V3 .
+ V4 -0.18923970
+ V5 -0.09104249
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -0.2121366
+ V2 .
+ V3 .
+ V4 .
+ V5 .
+ */
+
+ val coefficientsRStd = new DenseMatrix(3, 4, Array(
+ 0.0003092611, 0.0, 0.0, 0.091381849,
+ -0.09790029, 0.28502034, -0.12416487, -0.16513373,
+ 0.0, -0.4044859, 0.0, 0.0), isTransposed = true)
+ val interceptsRStd = Vectors.dense(-0.5521819483, -0.27531989, 0.8275018)
+
+ val coefficientsR = new DenseMatrix(3, 4, Array(
+ 0.0, 0.0, 0.0254752, 0.03893991,
+ -0.04224269, 0.0, -0.1892397, -0.09104249,
+ 0.0, 0.0, 0.0, 0.0), isTransposed = true)
+ val interceptsR = Vectors.dense(-0.39876213, 0.61089869, -0.2121366)
+
+ assert(model1.coefficients ~== coefficientsRStd absTol 0.01)
+ assert(model1.intercepts ~== interceptsRStd absTol 0.01)
+ assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
+ assert(model2.coefficients ~== coefficientsR absTol 0.01)
+ assert(model2.intercepts ~== interceptsR absTol 0.01)
+ assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
+ }
+
+ test("multinomial logistic regression without intercept with elasticnet regularization") {
+ val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true)
+ .setMaxIter(300).setTol(1e-10)
+ val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(false)
+ .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(false)
+ .setMaxIter(300).setTol(1e-10)
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+ /*
+ Use the following R code to load the data and train the model using glmnet package.
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = as.factor(data$V1)
+ features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0.5,
+      lambda = 0.1, intercept=F, standardize=T))
+ coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0.5,
+      lambda = 0.1, intercept=F, standardize=F))
+ > coefficientsStd
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 .
+ V3 .
+ V4 .
+ V5 0.03543706
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 -0.1187387
+ V3 0.4025482
+ V4 -0.1270969
+ V5 -0.1918386
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 0.00774365
+ V3 .
+ V4 .
+ V5 .
+
+ > coefficients
+ $`0`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 .
+ V3 .
+ V4 .
+ V5 .
+
+ $`1`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 .
+ V3 0.14666497
+ V4 -0.16570638
+ V5 -0.05982875
+
+ $`2`
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ .
+ V2 .
+ V3 .
+ V4 .
+ V5 .
+ */
+ val coefficientsRStd = new DenseMatrix(3, 4, Array(
+ 0.0, 0.0, 0.0, 0.03543706,
+ -0.1187387, 0.4025482, -0.1270969, -0.1918386,
+ 0.0, 0.0, 0.0, 0.00774365), isTransposed = true)
+
+ val coefficientsR = new DenseMatrix(3, 4, Array(
+ 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.14666497, -0.16570638, -0.05982875,
+ 0.0, 0.0, 0.0, 0.0), isTransposed = true)
+
+ assert(model1.coefficients ~== coefficientsRStd absTol 0.01)
+ assert(model1.intercepts.toArray === Array.fill(3)(0.0))
+ assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
+ assert(model2.coefficients ~== coefficientsR absTol 0.01)
+ assert(model2.intercepts.toArray === Array.fill(3)(0.0))
+ assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
+ }
+
+ /*
+ test("multinomial logistic regression with intercept with strong L1 regularization") {
+ // TODO: implement this test to check that the priors on the intercepts are correct
+ // TODO: when initial model becomes available
+ }
+ */
+
+ test("prediction") {
+ val model = new MultinomialLogisticRegressionModel("mLogReg",
+ Matrices.dense(3, 2, Array(0.0, 0.0, 0.0, 1.0, 2.0, 3.0)),
+ Vectors.dense(0.0, 0.0, 0.0), 3)
+ val overFlowData = spark.createDataFrame(Seq(
+ LabeledPoint(1.0, Vectors.dense(0.0, 1000.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, -1.0))
+ ))
+ val results = model.transform(overFlowData).select("rawPrediction", "probability").collect()
+
+ // probabilities are correct when margins have to be adjusted
+ val raw1 = results(0).getAs[Vector](0)
+ val prob1 = results(0).getAs[Vector](1)
+ assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0))
+ assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps)
+
+ // probabilities are correct when margins don't have to be adjusted
+ val raw2 = results(1).getAs[Vector](0)
+ val prob2 = results(1).getAs[Vector](1)
+ assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0))
+ assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps)
+ }
+
+ test("multinomial logistic regression: Predictor, Classifier methods") {
+ val mlr = new MultinomialLogisticRegression
+
+ val model = mlr.fit(dataset)
+ assert(model.numClasses === 3)
+ val numFeatures = dataset.select("features").first().getAs[Vector](0).size
+ assert(model.numFeatures === numFeatures)
+
+ val results = model.transform(dataset)
+ // check that raw prediction is coefficients dot features + intercept
+ results.select("rawPrediction", "features").collect().foreach {
+ case Row(raw: Vector, features: Vector) =>
+ assert(raw.size === 3)
+ val margins = Array.tabulate(3) { k =>
+ var margin = 0.0
+ features.foreachActive { (index, value) =>
+ margin += value * model.coefficients(k, index)
+ }
+ margin += model.intercepts(k)
+ margin
+ }
+ assert(raw ~== Vectors.dense(margins) relTol eps)
+ }
+
+ // Compare rawPrediction with probability
+ results.select("rawPrediction", "probability").collect().foreach {
+ case Row(raw: Vector, prob: Vector) =>
+ assert(raw.size === 3)
+ assert(prob.size === 3)
+ val max = raw.toArray.max
+ val subtract = if (max > 0) max else 0.0
+ val sum = raw.toArray.map(x => math.exp(x - subtract)).sum
+ val probFromRaw0 = math.exp(raw(0) - subtract) / sum
+ val probFromRaw1 = math.exp(raw(1) - subtract) / sum
+ assert(prob(0) ~== probFromRaw0 relTol eps)
+ assert(prob(1) ~== probFromRaw1 relTol eps)
+ assert(prob(2) ~== 1.0 - probFromRaw1 - probFromRaw0 relTol eps)
+ }
+
+ // Compare prediction with probability
+ results.select("prediction", "probability").collect().foreach {
+ case Row(pred: Double, prob: Vector) =>
+ val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
+ assert(pred == predFromProb)
+ }
+ }
+
+ test("multinomial logistic regression coefficients should be centered") {
+ val mlr = new MultinomialLogisticRegression().setMaxIter(1)
+ val model = mlr.fit(dataset)
+ assert(model.intercepts.toArray.sum ~== 0.0 absTol 1e-6)
+ assert(model.coefficients.toArray.sum ~== 0.0 absTol 1e-6)
+ }
+
+ test("numClasses specified in metadata/inferred") {
+ val mlr = new MultinomialLogisticRegression().setMaxIter(1)
+
+ // specify more classes than unique label values
+ val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(4).toMetadata()
+ val df = dataset.select(dataset("label").as("label", labelMeta), dataset("features"))
+ val model1 = mlr.fit(df)
+ assert(model1.numClasses === 4)
+ assert(model1.intercepts.size === 4)
+
+ // specify two classes when there are really three
+ val labelMeta1 = NominalAttribute.defaultAttr.withName("label").withNumValues(2).toMetadata()
+ val df1 = dataset.select(dataset("label").as("label", labelMeta1), dataset("features"))
+ val thrown = intercept[IllegalArgumentException] {
+ mlr.fit(df1)
+ }
+ assert(thrown.getMessage.contains("less than the number of unique labels"))
+
+ // mlr should infer the number of classes if not specified
+ val model3 = mlr.fit(dataset)
+ assert(model3.numClasses === 3)
+ }
+
+ test("all labels the same") {
+ val constantData = spark.createDataFrame(Seq(
+ LabeledPoint(4.0, Vectors.dense(0.0)),
+ LabeledPoint(4.0, Vectors.dense(1.0)),
+ LabeledPoint(4.0, Vectors.dense(2.0)))
+ )
+ val mlr = new MultinomialLogisticRegression
+ val model = mlr.fit(constantData)
+ val results = model.transform(constantData)
+ results.select("rawPrediction", "probability", "prediction").collect().foreach {
+ case Row(raw: Vector, prob: Vector, pred: Double) =>
+ assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity)))
+ assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0)))
+ assert(pred === 4.0)
+ }
+
+ // force the model to be trained with only one class
+ val constantZeroData = spark.createDataFrame(Seq(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)))
+ )
+ val modelZeroLabel = mlr.setFitIntercept(false).fit(constantZeroData)
+ val resultsZero = modelZeroLabel.transform(constantZeroData)
+ resultsZero.select("rawPrediction", "probability", "prediction").collect().foreach {
+ case Row(raw: Vector, prob: Vector, pred: Double) =>
+ assert(prob === Vectors.dense(Array(1.0)))
+ assert(pred === 0.0)
+ }
+
+ // ensure that the correct value is predicted when numClasses passed through metadata
+ val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata()
+ val constantDataWithMetadata = constantData
+ .select(constantData("label").as("label", labelMeta), constantData("features"))
+ val modelWithMetadata = mlr.setFitIntercept(true).fit(constantDataWithMetadata)
+ val resultsWithMetadata = modelWithMetadata.transform(constantDataWithMetadata)
+ resultsWithMetadata.select("rawPrediction", "probability", "prediction").collect().foreach {
+ case Row(raw: Vector, prob: Vector, pred: Double) =>
+ assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity, 0.0)))
+ assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0)))
+ assert(pred === 4.0)
+ }
+ // TODO: check num iters is zero when it become available in the model
+ }
+
+ test("weighted data") {
+ val numClasses = 5
+ val numPoints = 40
+ val outlierData = MLTestingUtils.genClassificationInstancesWithWeightedOutliers(spark,
+ numClasses, numPoints)
+ val testData = spark.createDataFrame(Array.tabulate[LabeledPoint](numClasses) { i =>
+ LabeledPoint(i.toDouble, Vectors.dense(i.toDouble))
+ })
+ val mlr = new MultinomialLogisticRegression().setWeightCol("weight")
+ val model = mlr.fit(outlierData)
+ val results = model.transform(testData).select("label", "prediction").collect()
+
+    // check that the predictions follow the one-to-one feature-to-label mapping
+ results.foreach { case Row(label: Double, pred: Double) =>
+ assert(label === pred)
+ }
+ val (overSampledData, weightedData) =
+ MLTestingUtils.genEquivalentOversampledAndWeightedInstances(outlierData, "label", "features",
+ 42L)
+ val weightedModel = mlr.fit(weightedData)
+ val overSampledModel = mlr.setWeightCol("").fit(overSampledData)
+ assert(weightedModel.coefficients ~== overSampledModel.coefficients relTol 0.01)
+ }
+
+ test("thresholds prediction") {
+ val mlr = new MultinomialLogisticRegression
+ val model = mlr.fit(dataset)
+ val basePredictions = model.transform(dataset).select("prediction").collect()
+
+ // should predict all zeros
+ model.setThresholds(Array(1, 1000, 1000))
+ val zeroPredictions = model.transform(dataset).select("prediction").collect()
+ assert(zeroPredictions.forall(_.getDouble(0) === 0.0))
+
+ // should predict all ones
+ model.setThresholds(Array(1000, 1, 1000))
+ val onePredictions = model.transform(dataset).select("prediction").collect()
+ assert(onePredictions.forall(_.getDouble(0) === 1.0))
+
+ // should predict all twos
+ model.setThresholds(Array(1000, 1000, 1))
+ val twoPredictions = model.transform(dataset).select("prediction").collect()
+ assert(twoPredictions.forall(_.getDouble(0) === 2.0))
+
+ // constant threshold scaling is the same as no thresholds
+ model.setThresholds(Array(1000, 1000, 1000))
+ val scaledPredictions = model.transform(dataset).select("prediction").collect()
+ assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
+ scaled.getDouble(0) === base.getDouble(0)
+ })
+ }
+
+ test("read/write") {
+ def checkModelData(
+ model: MultinomialLogisticRegressionModel,
+ model2: MultinomialLogisticRegressionModel): Unit = {
+ assert(model.intercepts === model2.intercepts)
+ assert(model.coefficients.toArray === model2.coefficients.toArray)
+ assert(model.numClasses === model2.numClasses)
+ assert(model.numFeatures === model2.numFeatures)
+ }
+ val mlr = new MultinomialLogisticRegression()
+ testEstimatorAndModelReadWrite(mlr, dataset,
+ MultinomialLogisticRegressionSuite.allParamSettings,
+ checkModelData)
+ }
+
+ test("should support all NumericType labels and not support other types") {
+ val mlr = new MultinomialLogisticRegression().setMaxIter(1)
+ MLTestingUtils
+ .checkNumericTypes[MultinomialLogisticRegressionModel, MultinomialLogisticRegression](
+ mlr, spark) { (expected, actual) =>
+ assert(expected.intercepts === actual.intercepts)
+ assert(expected.coefficients.toArray === actual.coefficients.toArray)
+ }
+ }
+}
+
+object MultinomialLogisticRegressionSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = ProbabilisticClassifierSuite.allParamSettings ++ Map(
+ "probabilityCol" -> "myProbability",
+ "thresholds" -> Array(0.4, 0.6),
+ "regParam" -> 0.01,
+ "elasticNetParam" -> 0.1,
+ "maxIter" -> 2, // intentionally small
+ "fitIntercept" -> true,
+ "tol" -> 0.8,
+ "standardization" -> false
+ )
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 80b976914cbdf..472a5af06e7a2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -19,12 +19,14 @@ package org.apache.spark.ml.util
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -179,4 +181,47 @@ object MLTestingUtils extends SparkFunSuite {
.map(t => t -> df.select(col(labelColName).cast(t), col(predictionColName)))
.toMap
}
+
+ def genClassificationInstancesWithWeightedOutliers(
+ spark: SparkSession,
+ numClasses: Int,
+ numInstances: Int): DataFrame = {
+ val data = Array.tabulate[Instance](numInstances) { i =>
+ val feature = i % numClasses
+ if (i < numInstances / 3) {
+        // give large weights to a minority of the data with a one-to-one feature-to-label mapping
+ Instance(feature, 1.0, Vectors.dense(feature))
+ } else {
+        // give small weights to the majority of the data points with the reverse mapping
+ Instance(numClasses - feature - 1, 0.01, Vectors.dense(feature))
+ }
+ }
+ val labelMeta =
+ NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses).toMetadata()
+ spark.createDataFrame(data).select(col("label").as("label", labelMeta), col("weight"),
+ col("features"))
+ }
+
+ def genEquivalentOversampledAndWeightedInstances(
+ data: DataFrame,
+ labelCol: String,
+ featuresCol: String,
+ seed: Long): (DataFrame, DataFrame) = {
+ import data.sparkSession.implicits._
+ val rng = scala.util.Random
+ rng.setSeed(seed)
+ val sample: () => Int = () => rng.nextInt(10) + 1
+ val sampleUDF = udf(sample)
+ val rawData = data.select(labelCol, featuresCol).withColumn("samples", sampleUDF())
+ val overSampledData = rawData.rdd.flatMap {
+ case Row(label: Double, features: Vector, n: Int) =>
+ Iterator.fill(n)(Instance(label, 1.0, features))
+ }.toDF()
+ rng.setSeed(seed)
+ val weightedData = rawData.rdd.map {
+ case Row(label: Double, features: Vector, n: Int) =>
+ Instance(label, n.toDouble, features)
+ }.toDF()
+ (overSampledData, weightedData)
+ }
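
The helper above relies on the equivalence between weighting an instance by n and repeating it n times in the objective. A minimal sketch of that equivalence for a simple weighted mean, using hypothetical values:

    object WeightEquivalenceSketch {
      def main(args: Array[String]): Unit = {
        // Weighting a point by n is equivalent (for the objective) to repeating it n times.
        val oversampled = Seq(2.0, 2.0, 2.0, 5.0)                  // the value 2.0 repeated 3 times
        val weighted = Seq((2.0, 3.0), (5.0, 1.0))                 // (value, weight)
        val meanOversampled = oversampled.sum / oversampled.size
        val meanWeighted = weighted.map { case (v, w) => v * w }.sum / weighted.map(_._2).sum
        println(meanOversampled == meanWeighted)                   // true: 2.75 == 2.75
      }
    }
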
}
From 5377fc62360d5e9b5c94078e41d10a96e0e8a535 Mon Sep 17 00:00:00 2001
From: Nick Lavers
Date: Fri, 19 Aug 2016 10:11:59 +0100
Subject: [PATCH 025/270] [SPARK-16961][CORE] Fixed off-by-one error that
biased randomizeInPlace
JIRA issue link:
https://issues.apache.org/jira/browse/SPARK-16961
Changed one line of Utils.randomizeInPlace to allow elements to stay in place.
Created a unit test that runs a Pearson's chi squared test to determine whether the output diverges significantly from a uniform distribution.
Author: Nick Lavers
Closes #14551 from nicklavers/SPARK-16961-randomizeInPlace.
---
R/pkg/inst/tests/testthat/test_mllib.R | 12 +++----
.../scala/org/apache/spark/util/Utils.scala | 2 +-
.../org/apache/spark/util/UtilsSuite.scala | 35 +++++++++++++++++++
python/pyspark/ml/clustering.py | 12 +++----
python/pyspark/mllib/clustering.py | 2 +-
python/pyspark/mllib/tests.py | 2 +-
6 files changed, 50 insertions(+), 15 deletions(-)
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 8c380fbf150f4..dfb7a185cd5a3 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -546,15 +546,15 @@ test_that("spark.gaussianMixture", {
df <- createDataFrame(data, c("x1", "x2"))
model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2)
stats <- summary(model)
- rLambda <- c(0.4, 0.6)
- rMu <- c(-0.2614822, 0.5128697, 2.647284, 4.544682)
- rSigma <- c(0.08427399, 0.00548772, 0.00548772, 0.09090715,
- 0.1641373, -0.1673806, -0.1673806, 0.7508951)
- expect_equal(stats$lambda, rLambda)
+ rLambda <- c(0.50861, 0.49139)
+ rMu <- c(0.267, 1.195, 2.743, 4.730)
+ rSigma <- c(1.099, 1.339, 1.339, 1.798,
+ 0.145, -0.309, -0.309, 0.716)
+ expect_equal(stats$lambda, rLambda, tolerance = 1e-3)
expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
p <- collect(select(predict(model, df), "prediction"))
- expect_equal(p$prediction, c(0, 0, 0, 0, 1, 1, 1, 1, 1, 1))
+ expect_equal(p$prediction, c(0, 0, 0, 0, 0, 1, 1, 1, 1, 1))
# Test model save/load
modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp")
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 0ae44a2ed7865..9b4274a27b3be 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -824,7 +824,7 @@ private[spark] object Utils extends Logging {
*/
def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
for (i <- (arr.length - 1) to 1 by -1) {
- val j = rand.nextInt(i)
+ val j = rand.nextInt(i + 1)
val tmp = arr(j)
arr(j) = arr(i)
arr(i) = tmp
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 30952a9458345..4715fd29375d6 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -31,6 +31,7 @@ import scala.util.Random
import com.google.common.io.Files
import org.apache.commons.lang3.SystemUtils
+import org.apache.commons.math3.stat.inference.ChiSquareTest
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -874,4 +875,38 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
}
}
}
+
+ test("chi square test of randomizeInPlace") {
+ // Parameters
+ val arraySize = 10
+ val numTrials = 1000
+ val threshold = 0.05
+ val seed = 1L
+
+ // results(i)(j): how many times Utils.randomize moves an element from position j to position i
+ val results = Array.ofDim[Long](arraySize, arraySize)
+
+ // This must be seeded because even a fair random process will fail this test with
+ // probability equal to the value of `threshold`, which is inconvenient for a unit test.
+ val rand = new java.util.Random(seed)
+ val range = 0 until arraySize
+
+ for {
+ _ <- 0 until numTrials
+ trial = Utils.randomizeInPlace(range.toArray, rand)
+ i <- range
+ } results(i)(trial(i)) += 1L
+
+ val chi = new ChiSquareTest()
+
+ // We expect an even distribution; this array will be rescaled by `chiSquareTest`
+ val expected = Array.fill(arraySize * arraySize)(1.0)
+ val observed = results.flatten
+
+ // Performs Pearson's chi-squared test. Using the sum-of-squares as the test statistic, gives
+ // the probability of a uniform distribution producing results as extreme as `observed`
+ val pValue = chi.chiSquareTest(expected, observed)
+
+ assert(pValue > threshold)
+ }
}
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 75d9a0e8cac18..4dab83362a0a4 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -99,9 +99,9 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
+--------------------+--------------------+
| mean| cov|
+--------------------+--------------------+
- |[-0.0550000000000...|0.002025000000000...|
- |[0.82499999999999...|0.005625000000000...|
- |[-0.87,-0.7200000...|0.001600000000000...|
+ |[0.82500000140229...|0.005625000000006...|
+ |[-0.4777098016092...|0.167969502720916...|
+ |[-0.4472625243352...|0.167304119758233...|
+--------------------+--------------------+
...
>>> transformed = model.transform(df).select("features", "prediction")
@@ -124,9 +124,9 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
+--------------------+--------------------+
| mean| cov|
+--------------------+--------------------+
- |[-0.0550000000000...|0.002025000000000...|
- |[0.82499999999999...|0.005625000000000...|
- |[-0.87,-0.7200000...|0.001600000000000...|
+ |[0.82500000140229...|0.005625000000006...|
+ |[-0.4777098016092...|0.167969502720916...|
+ |[-0.4472625243352...|0.167304119758233...|
+--------------------+--------------------+
...
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index c8c3c42774f21..29aa615125770 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -416,7 +416,7 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
... 4.5605, 5.2043, 6.2734])
>>> clusterdata_2 = sc.parallelize(data.reshape(5,3))
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
- ... maxIterations=150, seed=10)
+ ... maxIterations=150, seed=4)
>>> labels = model.predict(clusterdata_2).collect()
>>> labels[0]==labels[1]
True
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 99bf50b5a1640..3f3dfd186c10d 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -550,7 +550,7 @@ def test_gmm(self):
[-6, -7],
])
clusters = GaussianMixture.train(data, 2, convergenceTol=0.001,
- maxIterations=10, seed=56)
+ maxIterations=10, seed=1)
labels = clusters.predict(data).collect()
self.assertEqual(labels[0], labels[1])
self.assertEqual(labels[2], labels[3])
From 864be9359ae2f8409e6dbc38a7a18593f9cc5692 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Fri, 19 Aug 2016 03:23:16 -0700
Subject: [PATCH 026/270] [SPARK-17141][ML] MinMaxScaler should remain NaN
value.
## What changes were proposed in this pull request?
In the existing code, ```MinMaxScaler``` handles ```NaN``` values inconsistently.
* If a column is constant, that is ```max == min```, the ```MinMaxScalerModel``` transformation outputs ```0.5``` for all rows, even when the original value is ```NaN```.
* Otherwise, the value remains ```NaN``` after transformation.
We should unify the behavior by keeping ```NaN``` values under all conditions, since we don't know how to transform a ```NaN``` value. Python's scikit-learn throws an exception when the dataset contains ```NaN```.
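As a rough sketch of the intended behavior (a hypothetical helper, not the `MinMaxScalerModel` code itself): constant columns map to the midpoint, and ```NaN``` entries pass through unchanged instead of being rescaled.
```
// Hypothetical per-row rescaling helper mirroring the behavior described above.
def rescale(
    values: Array[Double],
    originalMin: Array[Double],
    originalRange: Array[Double],
    min: Double,
    max: Double): Array[Double] = {
  val scale = max - min
  var i = 0
  while (i < values.length) {
    if (!values(i).isNaN) {
      val raw = if (originalRange(i) != 0) (values(i) - originalMin(i)) / originalRange(i) else 0.5
      values(i) = raw * scale + min
    }
    i += 1
  }
  values
}
```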
## How was this patch tested?
Unit tests.
Author: Yanbo Liang
Closes #14716 from yanboliang/spark-17141.
---
.../spark/ml/feature/MinMaxScaler.scala | 6 +++--
.../spark/ml/feature/MinMaxScalerSuite.scala | 27 +++++++++++++++++++
2 files changed, 31 insertions(+), 2 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 9f3d2ca6db0c1..28cbe1cb01e9a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -186,8 +186,10 @@ class MinMaxScalerModel private[ml] (
val size = values.length
var i = 0
while (i < size) {
- val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5
- values(i) = raw * scale + $(min)
+ if (!values(i).isNaN) {
+ val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5
+ values(i) = raw * scale + $(min)
+ }
i += 1
}
Vectors.dense(values)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 5da84711758c6..9f376b70035c5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -90,4 +90,31 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
assert(newInstance.originalMin === instance.originalMin)
assert(newInstance.originalMax === instance.originalMax)
}
+
+ test("MinMaxScaler should remain NaN value") {
+ val data = Array(
+ Vectors.dense(1, Double.NaN, 2.0, 2.0),
+ Vectors.dense(2, 2.0, 0.0, 3.0),
+ Vectors.dense(3, Double.NaN, 0.0, 1.0),
+ Vectors.dense(6, 2.0, 2.0, Double.NaN))
+
+ val expected: Array[Vector] = Array(
+ Vectors.dense(-5.0, Double.NaN, 5.0, 0.0),
+ Vectors.dense(-3.0, 0.0, -5.0, 5.0),
+ Vectors.dense(-1.0, Double.NaN, -5.0, -5.0),
+ Vectors.dense(5.0, 0.0, 5.0, Double.NaN))
+
+ val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected")
+ val scaler = new MinMaxScaler()
+ .setInputCol("features")
+ .setOutputCol("scaled")
+ .setMin(-5)
+ .setMax(5)
+
+ val model = scaler.fit(df)
+ model.transform(df).select("expected", "scaled").collect()
+ .foreach { case Row(vector1: Vector, vector2: Vector) =>
+ assert(vector1.equals(vector2), "Transformed vector is different with expected.")
+ }
+ }
}
From 072acf5e1460d66d4b60b536d5b2ccddeee80794 Mon Sep 17 00:00:00 2001
From: Jeff Zhang
Date: Fri, 19 Aug 2016 12:38:15 +0100
Subject: [PATCH 027/270] [SPARK-16965][MLLIB][PYSPARK] Fix bound checking for
SparseVector.
## What changes were proposed in this pull request?
1. In Scala, add a check for negative indices and consolidate all of the lower/upper bound checks in one place, as sketched below.
2. In Python, add lower/upper bound checks on indices.
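A minimal sketch of the consolidated Scala-side checks (a hypothetical function, not the `SparseVector` constructor): indices must be non-negative, strictly increasing, and smaller than the declared size.
```
// Hypothetical validation helper mirroring the checks described above.
def validateSparseIndices(size: Int, indices: Array[Int], values: Array[Double]): Unit = {
  require(size >= 0, s"Vector size must be non-negative, got $size.")
  require(indices.length == values.length,
    s"Got ${indices.length} indices but ${values.length} values.")
  require(indices.length <= size,
    s"Got ${indices.length} indices for a vector of size $size.")
  if (indices.nonEmpty) {
    require(indices(0) >= 0, s"Found negative index: ${indices(0)}.")
  }
  var prev = -1
  indices.foreach { i =>
    require(prev < i, s"Index $i follows $prev and is not strictly increasing.")
    prev = i
  }
  require(prev < size, s"Index $prev out of bounds for vector of size $size.")
}
```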
## How was this patch tested?
unit test added
Author: Jeff Zhang
Closes #14555 from zjffdu/SPARK-16965.
---
.../org/apache/spark/ml/linalg/Vectors.scala | 34 +++++++++++--------
.../apache/spark/ml/linalg/VectorsSuite.scala | 6 ++++
python/pyspark/ml/linalg/__init__.py | 15 ++++++++
3 files changed, 40 insertions(+), 15 deletions(-)
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
index 0659324aad1fa..2e4a58dc6291c 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
@@ -208,17 +208,7 @@ object Vectors {
*/
@Since("2.0.0")
def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = {
- require(size > 0, "The size of the requested sparse vector must be greater than 0.")
-
val (indices, values) = elements.sortBy(_._1).unzip
- var prev = -1
- indices.foreach { i =>
- require(prev < i, s"Found duplicate indices: $i.")
- prev = i
- }
- require(prev < size, s"You may not write an element to index $prev because the declared " +
- s"size of your vector is $size")
-
new SparseVector(size, indices.toArray, values.toArray)
}
@@ -560,11 +550,25 @@ class SparseVector @Since("2.0.0") (
@Since("2.0.0") val indices: Array[Int],
@Since("2.0.0") val values: Array[Double]) extends Vector {
- require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
- s" indices match the dimension of the values. You provided ${indices.length} indices and " +
- s" ${values.length} values.")
- require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
- s"which exceeds the specified vector size ${size}.")
+ // validate the data
+ {
+ require(size >= 0, "The size of the requested sparse vector must be greater than 0.")
+ require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
+ s" indices match the dimension of the values. You provided ${indices.length} indices and " +
+ s" ${values.length} values.")
+ require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
+ s"which exceeds the specified vector size ${size}.")
+
+ if (indices.nonEmpty) {
+ require(indices(0) >= 0, s"Found negative index: ${indices(0)}.")
+ }
+ var prev = -1
+ indices.foreach { i =>
+ require(prev < i, s"Index $i follows $prev and is not strictly increasing")
+ prev = i
+ }
+ require(prev < size, s"Index $prev out of bounds for vector of size $size")
+ }
override def toString: String =
s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
index 614be460a414a..ea22c2787fb3c 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
@@ -72,6 +72,12 @@ class VectorsSuite extends SparkMLFunSuite {
}
}
+ test("sparse vector construction with negative indices") {
+ intercept[IllegalArgumentException] {
+ Vectors.sparse(3, Array(-1, 1), Array(3.0, 5.0))
+ }
+ }
+
test("dense to array") {
val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
assert(vec.toArray.eq(arr))
diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py
index f42c589b92255..05c0ac862fb7f 100644
--- a/python/pyspark/ml/linalg/__init__.py
+++ b/python/pyspark/ml/linalg/__init__.py
@@ -478,6 +478,14 @@ def __init__(self, size, *args):
SparseVector(4, {1: 1.0, 3: 5.5})
>>> SparseVector(4, [1, 3], [1.0, 5.5])
SparseVector(4, {1: 1.0, 3: 5.5})
+ >>> SparseVector(4, {1:1.0, 6:2.0})
+ Traceback (most recent call last):
+ ...
+ AssertionError: Index 6 is out of the the size of vector with size=4
+ >>> SparseVector(4, {-1:1.0})
+ Traceback (most recent call last):
+ ...
+ AssertionError: Contains negative index -1
"""
self.size = int(size)
""" Size of the vector. """
@@ -511,6 +519,13 @@ def __init__(self, size, *args):
"Indices %s and %s are not strictly increasing"
% (self.indices[i], self.indices[i + 1]))
+ if self.indices.size > 0:
+ assert np.max(self.indices) < self.size, \
+ "Index %d is out of the the size of vector with size=%d" \
+ % (np.max(self.indices), self.size)
+ assert np.min(self.indices) >= 0, \
+ "Contains negative index %d" % (np.min(self.indices))
+
def numNonzeros(self):
"""
Number of nonzero elements. This scans all active values and count non zeros.
From 67e59d464f782ff5f509234212aa072a7653d7bf Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Fri, 19 Aug 2016 21:11:35 +0800
Subject: [PATCH 028/270] [SPARK-16994][SQL] Whitelist operators for predicate
pushdown
## What changes were proposed in this pull request?
This patch changes the predicate pushdown optimization rule (PushDownPredicate) from using a blacklist to a whitelist; that is, operators must now be explicitly allowed. This approach is more future-proof: previously, introducing a new operator could silently render the optimization rule incorrect.
This also fixes a bug: we previously allowed pushing a filter beneath a limit, which is incorrect. Before this patch, the optimizer would rewrite
```
select * from (select * from range(10) limit 5) where id > 3
```
into
```
select * from range(10) where id > 3 limit 5
```
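A quick illustration of why that rewrite changes results (a sketch against a local SparkSession, not part of the patch): `range(10).limit(5)` yields ids 0 through 4, so only id 4 survives the outer filter, whereas filtering first keeps up to five ids greater than 3.
```
import org.apache.spark.sql.SparkSession

// Local session used only for this illustration.
val spark = SparkSession.builder().master("local[1]").appName("limit-pushdown").getOrCreate()

spark.range(10).limit(5).where("id > 3").show()  // one row: 4
spark.range(10).where("id > 3").limit(5).show()  // up to five rows greater than 3, e.g. 4..8
```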
## How was this patch tested?
- a unit test case in FilterPushdownSuite
- an end-to-end test in limit.sql
Author: Reynold Xin
Closes #14713 from rxin/SPARK-16994.
---
.../sql/catalyst/optimizer/Optimizer.scala | 23 ++++++++++++++-----
.../optimizer/FilterPushdownSuite.scala | 6 +++++
.../test/resources/sql-tests/inputs/limit.sql | 3 +++
.../resources/sql-tests/results/limit.sql.out | 10 +++++++-
4 files changed, 35 insertions(+), 7 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f7aa6da0a5bdc..ce57f05868fe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1208,17 +1208,28 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
filter
}
- // two filters should be combine together by other rules
- case filter @ Filter(_, _: Filter) => filter
- // should not push predicates through sample, or will generate different results.
- case filter @ Filter(_, _: Sample) => filter
-
- case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
+ case filter @ Filter(condition, u: UnaryNode)
+ if canPushThrough(u) && u.expressions.forall(_.deterministic) =>
pushDownPredicate(filter, u.child) { predicate =>
u.withNewChildren(Seq(Filter(predicate, u.child)))
}
}
+ private def canPushThrough(p: UnaryNode): Boolean = p match {
+ // Note that some operators (e.g. project, aggregate, union) are being handled separately
+ // (earlier in this rule).
+ case _: AppendColumns => true
+ case _: BroadcastHint => true
+ case _: Distinct => true
+ case _: Generate => true
+ case _: Pivot => true
+ case _: RedistributeData => true
+ case _: Repartition => true
+ case _: ScriptTransformation => true
+ case _: Sort => true
+ case _ => false
+ }
+
private def pushDownPredicate(
filter: Filter,
grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 596b8fcea194b..9f25e9d8e9ac8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -111,6 +111,12 @@ class FilterPushdownSuite extends PlanTest {
assert(optimized == correctAnswer)
}
+ test("SPARK-16994: filter should not be pushed through limit") {
+ val originalQuery = testRelation.limit(10).where('a === 1).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, originalQuery)
+ }
+
test("can't push without rewrite") {
val originalQuery =
testRelation
diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql
index 892a1bb4b559f..2ea35f7f3a5c8 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql
@@ -18,3 +18,6 @@ select * from testdata limit key > 3;
-- limit must be integer
select * from testdata limit true;
select * from testdata limit 'a';
+
+-- limit within a subquery
+select * from (select * from range(10) limit 5) where id > 3;
diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out
index b71b05886986c..cb4e4d04810d0 100644
--- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 9
+-- Number of queries: 10
-- !query 0
@@ -81,3 +81,11 @@ struct<>
-- !query 8 output
org.apache.spark.sql.AnalysisException
The limit expression must be integer type, but got string;
+
+
+-- !query 9
+select * from (select * from range(10) limit 5) where id > 3
+-- !query 9 schema
+struct
+-- !query 9 output
+4
From e98eb2146f1363956bfc3e5adcc11c246182d617 Mon Sep 17 00:00:00 2001
From: Alex Bozarth
Date: Fri, 19 Aug 2016 10:04:20 -0500
Subject: [PATCH 029/270] [SPARK-16673][WEB UI] New Executor Page removed
conditional for Logs and Thread Dump columns
## What changes were proposed in this pull request?
When #13670 switched `ExecutorsPage` to use jQuery DataTables, it inadvertently removed the conditional for the Logs and Thread Dump columns. I reimplemented the conditional display of the Logs and Thread Dump columns as it behaved before the switch.
## How was this patch tested?
Manually tested and dev/run-tests
Author: Alex Bozarth
Closes #14382 from ajbozarth/spark16673.
---
.../apache/spark/ui/static/executorspage.js | 38 +++++++++++++++----
.../apache/spark/ui/exec/ExecutorsPage.scala | 7 ++--
2 files changed, 34 insertions(+), 11 deletions(-)
diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js
index b2b2363d3ac69..1df67337ea031 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js
@@ -15,6 +15,16 @@
* limitations under the License.
*/
+var threadDumpEnabled = false;
+
+function setThreadDumpEnabled(val) {
+ threadDumpEnabled = val;
+}
+
+function getThreadDumpEnabled() {
+ return threadDumpEnabled;
+}
+
function formatStatus(status, type) {
if (type !== 'display') return status;
if (status) {
@@ -116,6 +126,12 @@ function formatLogsCells(execLogs, type) {
return result;
}
+function logsExist(execs) {
+ return execs.some(function(exec) {
+ return !($.isEmptyObject(exec["executorLogs"]));
+ });
+}
+
// Determine Color Opacity from 0.5-1
// activeTasks range from 0 to maxTasks
function activeTasksAlpha(activeTasks, maxTasks) {
@@ -143,18 +159,16 @@ function totalDurationAlpha(totalGCTime, totalDuration) {
(Math.min(totalGCTime / totalDuration + 0.5, 1)) : 1;
}
+// When GCTimePercent is edited change ToolTips.TASK_TIME to match
+var GCTimePercent = 0.1;
+
function totalDurationStyle(totalGCTime, totalDuration) {
// Red if GC time over GCTimePercent of total time
- // When GCTimePercent is edited change ToolTips.TASK_TIME to match
- var GCTimePercent = 0.1;
return (totalGCTime > GCTimePercent * totalDuration) ?
("hsla(0, 100%, 50%, " + totalDurationAlpha(totalGCTime, totalDuration) + ")") : "";
}
function totalDurationColor(totalGCTime, totalDuration) {
- // Red if GC time over GCTimePercent of total time
- // When GCTimePercent is edited change ToolTips.TASK_TIME to match
- var GCTimePercent = 0.1;
return (totalGCTime > GCTimePercent * totalDuration) ? "white" : "black";
}
@@ -392,8 +406,18 @@ $(document).ready(function () {
{data: 'executorLogs', render: formatLogsCells},
{
data: 'id', render: function (data, type) {
- return type === 'display' ? ("Thread Dump" ) : data;
+ return type === 'display' ? ("Thread Dump" ) : data;
+ }
}
+ ],
+ "columnDefs": [
+ {
+ "targets": [ 15 ],
+ "visible": logsExist(response)
+ },
+ {
+ "targets": [ 16 ],
+ "visible": getThreadDumpEnabled()
}
],
"order": [[0, "asc"]]
@@ -458,7 +482,7 @@ $(document).ready(function () {
"paging": false,
"searching": false,
"info": false
-
+
};
$(sumSelector).DataTable(sumConf);
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
index 287390b87bd73..982e8915a8ded 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
@@ -50,16 +50,15 @@ private[ui] class ExecutorsPage(
threadDumpEnabled: Boolean)
extends WebUIPage("") {
private val listener = parent.listener
- // When GCTimePercent is edited change ToolTips.TASK_TIME to match
- private val GCTimePercent = 0.1
def render(request: HttpServletRequest): Seq[Node] = {
val content =
{
-
++
+
++
++
-
+ ++
+
}
;
From 071eaaf9d2b63589f2e66e5279a16a5a484de6f5 Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Fri, 19 Aug 2016 10:11:25 -0500
Subject: [PATCH 030/270] [SPARK-11227][CORE] UnknownHostException can be
thrown when NameNode HA is enabled.
## What changes were proposed in this pull request?
If the following conditions are satisfied, executors don't load properties in `hdfs-site.xml` and UnknownHostException can be thrown.
(1) NameNode HA is enabled
(2) spark.eventLogging is disabled or logging path is NOT on HDFS
(3) Using Standalone or Mesos for the cluster manager
(4) There is no code that loads the `HdfsConfiguration` class in the driver, either directly or indirectly.
(5) The tasks access HDFS.
(There might be some more conditions...)
For example, the following code causes an UnknownHostException when the conditions above are satisfied.
```
sc.textFile("").collect
```
```
java.lang.IllegalArgumentException: java.net.UnknownHostException: hacluster
at org.apache.hadoop.security.SecurityUtil.buildTokenService(SecurityUtil.java:378)
at org.apache.hadoop.hdfs.NameNodeProxies.createNonHAProxy(NameNodeProxies.java:310)
at org.apache.hadoop.hdfs.NameNodeProxies.createProxy(NameNodeProxies.java:176)
at org.apache.hadoop.hdfs.DFSClient.(DFSClient.java:678)
at org.apache.hadoop.hdfs.DFSClient.(DFSClient.java:619)
at org.apache.hadoop.hdfs.DistributedFileSystem.initialize(DistributedFileSystem.java:149)
at org.apache.hadoop.fs.FileSystem.createFileSystem(FileSystem.java:2653)
at org.apache.hadoop.fs.FileSystem.access$200(FileSystem.java:92)
at org.apache.hadoop.fs.FileSystem$Cache.getInternal(FileSystem.java:2687)
at org.apache.hadoop.fs.FileSystem$Cache.get(FileSystem.java:2669)
at org.apache.hadoop.fs.FileSystem.get(FileSystem.java:371)
at org.apache.hadoop.fs.FileSystem.get(FileSystem.java:170)
at org.apache.hadoop.mapred.JobConf.getWorkingDirectory(JobConf.java:656)
at org.apache.hadoop.mapred.FileInputFormat.setInputPaths(FileInputFormat.java:438)
at org.apache.hadoop.mapred.FileInputFormat.setInputPaths(FileInputFormat.java:411)
at org.apache.spark.SparkContext$$anonfun$hadoopFile$1$$anonfun$32.apply(SparkContext.scala:986)
at org.apache.spark.SparkContext$$anonfun$hadoopFile$1$$anonfun$32.apply(SparkContext.scala:986)
at org.apache.spark.rdd.HadoopRDD$$anonfun$getJobConf$6.apply(HadoopRDD.scala:177)
at org.apache.spark.rdd.HadoopRDD$$anonfun$getJobConf$6.apply(HadoopRDD.scala:177)
at scala.Option.map(Option.scala:146)
at org.apache.spark.rdd.HadoopRDD.getJobConf(HadoopRDD.scala:177)
at org.apache.spark.rdd.HadoopRDD$$anon$1.(HadoopRDD.scala:213)
at org.apache.spark.rdd.HadoopRDD.compute(HadoopRDD.scala:209)
at org.apache.spark.rdd.HadoopRDD.compute(HadoopRDD.scala:102)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:318)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:282)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:318)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:282)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)
at org.apache.spark.scheduler.Task.run(Task.scala:85)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
Caused by: java.net.UnknownHostException: hacluster
```
But the following code doesn't cause the exception, because the `textFile` method loads `HdfsConfiguration` indirectly.
```
sc.textFile("").collect
```
When a job includes operations that access HDFS, the `org.apache.hadoop.conf.Configuration` object is wrapped in a `SerializableConfiguration`, serialized, and broadcast from the driver to the executors. Each executor deserializes the object with `loadDefaults` set to false, so HDFS-related properties must be set before the configuration is broadcast.
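A hedged sketch of the workaround pattern used here (a standalone example, not the `SparkContext` methods themselves): touching a Hadoop `FileSystem` before the configuration is serialized forces `hdfs-site.xml` to be loaded into the defaults, so the HA nameservice mapping travels with the broadcast copy.
```
import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem

// Standalone illustration: the FileSystem lookup is done only for its side effect of
// loading hdfs-site.xml defaults before `conf` is serialized and broadcast.
def prepareForBroadcast(conf: Configuration, path: String): Configuration = {
  FileSystem.get(new URI(path), conf)
  conf
}
```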
## How was this patch tested?
Tested manually on my standalone cluster.
Author: Kousuke Saruta
Closes #13738 from sarutak/SPARK-11227.
---
.../scala/org/apache/spark/SparkContext.scala | 22 ++++++++++++++++++-
1 file changed, 21 insertions(+), 1 deletion(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 60f042f1e07c5..2eaeab1d807b4 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -35,7 +35,7 @@ import scala.util.control.NonFatal
import com.google.common.collect.MapMaker
import org.apache.commons.lang3.SerializationUtils
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable,
FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat,
@@ -961,6 +961,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope {
assertNotStopped()
+
+ // This is a hack to enforce loading hdfs-site.xml.
+ // See SPARK-11227 for details.
+ FileSystem.getLocal(conf)
+
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions)
@@ -981,6 +986,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope {
assertNotStopped()
+
+ // This is a hack to enforce loading hdfs-site.xml.
+ // See SPARK-11227 for details.
+ FileSystem.get(new URI(path), hadoopConfiguration)
+
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
@@ -1065,6 +1075,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = withScope {
assertNotStopped()
+
+ // This is a hack to enforce loading hdfs-site.xml.
+ // See SPARK-11227 for details.
+ FileSystem.get(new URI(path), hadoopConfiguration)
+
// The call to NewHadoopJob automatically adds security credentials to conf,
// so we don't need to explicitly add them ourselves
val job = NewHadoopJob.getInstance(conf)
@@ -1099,6 +1114,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = withScope {
assertNotStopped()
+
+ // This is a hack to enforce loading hdfs-site.xml.
+ // See SPARK-11227 for details.
+ FileSystem.getLocal(conf)
+
// Add necessary security credentials to the JobConf. Required to access secure HDFS.
val jconf = new JobConf(conf)
SparkHadoopUtil.get.addCredentials(jconf)
From cf0cce90364d17afe780ff9a5426dfcefa298535 Mon Sep 17 00:00:00 2001
From: Sital Kedia
Date: Fri, 19 Aug 2016 11:27:30 -0700
Subject: [PATCH 031/270] [SPARK-17113] [SHUFFLE] Job failure due to Executor
OOM in offheap mode
## What changes were proposed in this pull request?
This PR fixes an executor OOM in off-heap mode caused by a bug in Cooperative Memory Management for `UnsafeExternalSorter`. `UnsafeExternalSorter` checked whether a memory page was still in use by the upstream iterator by comparing the page's base object address with the upstream iterator's base object address. However, with off-heap memory allocation the base object addresses are always null, so no spilling ever happened and the operator would eventually OOM.
The following is the stack trace this issue addresses:
java.lang.OutOfMemoryError: Unable to acquire 1220 bytes of memory, got 0
at org.apache.spark.memory.MemoryConsumer.allocatePage(MemoryConsumer.java:120)
at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.acquireNewPageIfNecessary(UnsafeExternalSorter.java:341)
at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.insertRecord(UnsafeExternalSorter.java:362)
at org.apache.spark.sql.execution.UnsafeExternalRowSorter.insertRow(UnsafeExternalRowSorter.java:93)
at org.apache.spark.sql.execution.UnsafeExternalRowSorter.sort(UnsafeExternalRowSorter.java:170)
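Conceptually (hypothetical types below, not Spark's `MemoryBlock` or `SortedIterator`), the fix replaces a base-object comparison that is vacuous off-heap with a comparison of stable page numbers:
```
// With off-heap allocation every page's base object is null, so a check like
// page.baseObject != current.baseObject never identifies the in-use page.
// Comparing page numbers works for both on-heap and off-heap pages.
final case class Page(pageNumber: Long, baseObject: AnyRef, size: Long)

def pagesSafeToFree(pages: Seq[Page], currentPageNumber: Long): Seq[Page] =
  pages.filter(_.pageNumber != currentPageNumber)
```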
## How was this patch tested?
Tested by running the failing job.
Author: Sital Kedia
Closes #14693 from sitalkedia/fix_offheap_oom.
---
.../util/collection/unsafe/sort/UnsafeExternalSorter.java | 2 +-
.../util/collection/unsafe/sort/UnsafeInMemorySorter.java | 7 +++++++
2 files changed, 8 insertions(+), 1 deletion(-)
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 8d596f87d213b..ccf76643db2b4 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -522,7 +522,7 @@ public long spill() throws IOException {
// is accessing the current record. We free this page in that caller's next loadNext()
// call.
for (MemoryBlock page : allocatedPages) {
- if (!loaded || page.getBaseObject() != upstream.getBaseObject()) {
+ if (!loaded || page.pageNumber != ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) {
released += page.size();
freePage(page);
} else {
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 78da38927878b..30d0f3006a04e 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -248,6 +248,7 @@ public final class SortedIterator extends UnsafeSorterIterator implements Clonea
private long baseOffset;
private long keyPrefix;
private int recordLength;
+ private long currentPageNumber;
private SortedIterator(int numRecords, int offset) {
this.numRecords = numRecords;
@@ -262,6 +263,7 @@ public SortedIterator clone() {
iter.baseOffset = baseOffset;
iter.keyPrefix = keyPrefix;
iter.recordLength = recordLength;
+ iter.currentPageNumber = currentPageNumber;
return iter;
}
@@ -279,6 +281,7 @@ public boolean hasNext() {
public void loadNext() {
// This pointer points to a 4-byte record length, followed by the record's bytes
final long recordPointer = array.get(offset + position);
+ currentPageNumber = memoryManager.decodePageNumber(recordPointer);
baseObject = memoryManager.getPage(recordPointer);
baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
recordLength = Platform.getInt(baseObject, baseOffset - 4);
@@ -292,6 +295,10 @@ public void loadNext() {
@Override
public long getBaseOffset() { return baseOffset; }
+ public long getCurrentPageNumber() {
+ return currentPageNumber;
+ }
+
@Override
public int getRecordLength() { return recordLength; }
From acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8 Mon Sep 17 00:00:00 2001
From: Junyang Qian
Date: Fri, 19 Aug 2016 14:24:09 -0700
Subject: [PATCH 032/270] [SPARK-16443][SPARKR] Alternating Least Squares (ALS)
wrapper
## What changes were proposed in this pull request?
Add an Alternating Least Squares (ALS) wrapper in SparkR. Unit tests have been updated.
## How was this patch tested?
SparkR unit tests.
Author: Junyang Qian
Closes #14384 from junyangq/SPARK-16443.
---
R/pkg/NAMESPACE | 3 +-
R/pkg/R/generics.R | 4 +
R/pkg/R/mllib.R | 159 +++++++++++++++++-
R/pkg/inst/tests/testthat/test_mllib.R | 40 +++++
.../org/apache/spark/ml/r/ALSWrapper.scala | 119 +++++++++++++
.../org/apache/spark/ml/r/RWrappers.scala | 2 +
6 files changed, 322 insertions(+), 5 deletions(-)
create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 4404cffc292aa..e1b87b28d35ae 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -29,7 +29,8 @@ exportMethods("glm",
"spark.posterior",
"spark.perplexity",
"spark.isoreg",
- "spark.gaussianMixture")
+ "spark.gaussianMixture",
+ "spark.als")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index fe04bcfc7d14d..693aa31d3ecab 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1332,3 +1332,7 @@ setGeneric("spark.gaussianMixture",
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
+
+#' @rdname spark.als
+#' @export
+setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index b9527410a9853..36f38fc73a510 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -74,6 +74,13 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
#' @note GaussianMixtureModel since 2.1.0
setClass("GaussianMixtureModel", representation(jobj = "jobj"))
+#' S4 class that represents an ALSModel
+#'
+#' @param jobj a Java object reference to the backing Scala ALSWrapper
+#' @export
+#' @note ALSModel since 2.1.0
+setClass("ALSModel", representation(jobj = "jobj"))
+
#' Saves the MLlib model to the input path
#'
#' Saves the MLlib model to the input path. For more information, see the specific
@@ -82,8 +89,8 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj"))
#' @name write.ml
#' @export
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.lda}
-#' @seealso \link{spark.isoreg}
+#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.lda}, \link{spark.naiveBayes}
+#' @seealso \link{spark.survreg}, \link{spark.isoreg}
#' @seealso \link{read.ml}
NULL
@@ -95,10 +102,11 @@ NULL
#' @name predict
#' @export
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.isoreg}
NULL
+
#' Generalized Linear Models
#'
#' Fits generalized linear model against a Spark DataFrame.
@@ -801,6 +809,8 @@ read.ml <- function(path) {
return(new("IsotonicRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
return(new("GaussianMixtureModel", jobj = jobj))
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
+ return(new("ALSModel", jobj = jobj))
} else {
stop(paste("Unsupported model: ", jobj))
}
@@ -1053,4 +1063,145 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
setMethod("predict", signature(object = "GaussianMixtureModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
- })
\ No newline at end of file
+ })
+
+#' Alternating Least Squares (ALS) for Collaborative Filtering
+#'
+#' \code{spark.als} learns latent factors in collaborative filtering via alternating least
+#' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict}
+#' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
+#'
+#' For more details, see
+#' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib:
+#' Collaborative Filtering}.
+#'
+#' @param data a SparkDataFrame for training.
+#' @param ratingCol column name for ratings.
+#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers.
+#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers.
+#' @param rank rank of the matrix factorization (> 0).
+#' @param reg regularization parameter (>= 0).
+#' @param maxIter maximum number of iterations (>= 0).
+#' @param nonnegative logical value indicating whether to apply nonnegativity constraints.
+#' @param implicitPrefs logical value indicating whether to use implicit preference.
+#' @param alpha alpha parameter in the implicit preference formulation (>= 0).
+#' @param seed integer seed for random number generation.
+#' @param numUserBlocks number of user blocks used to parallelize computation (> 0).
+#' @param numItemBlocks number of item blocks used to parallelize computation (> 0).
+#' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1).
+#'
+#' @return \code{spark.als} returns a fitted ALS model
+#' @rdname spark.als
+#' @aliases spark.als,SparkDataFrame-method
+#' @name spark.als
+#' @export
+#' @examples
+#' \dontrun{
+#' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
+#' list(2, 1, 1.0), list(2, 2, 5.0))
+#' df <- createDataFrame(ratings, c("user", "item", "rating"))
+#' model <- spark.als(df, "rating", "user", "item")
+#'
+#' # extract latent factors
+#' stats <- summary(model)
+#' userFactors <- stats$userFactors
+#' itemFactors <- stats$itemFactors
+#'
+#' # make predictions
+#' predicted <- predict(model, df)
+#' showDF(predicted)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # set other arguments
+#' modelS <- spark.als(df, "rating", "user", "item", rank = 20,
+#' reg = 0.1, nonnegative = TRUE)
+#' statsS <- summary(modelS)
+#' }
+#' @note spark.als since 2.1.0
+setMethod("spark.als", signature(data = "SparkDataFrame"),
+ function(data, ratingCol = "rating", userCol = "user", itemCol = "item",
+ rank = 10, reg = 1.0, maxIter = 10, nonnegative = FALSE,
+ implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10,
+ checkpointInterval = 10, seed = 0) {
+
+ if (!is.numeric(rank) || rank <= 0) {
+ stop("rank should be a positive number.")
+ }
+ if (!is.numeric(reg) || reg < 0) {
+ stop("reg should be a nonnegative number.")
+ }
+ if (!is.numeric(maxIter) || maxIter <= 0) {
+ stop("maxIter should be a positive number.")
+ }
+
+ jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper",
+ "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank),
+ reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
+ as.integer(numUserBlocks), as.integer(numItemBlocks),
+ as.integer(checkpointInterval), as.integer(seed))
+ return(new("ALSModel", jobj = jobj))
+ })
+
+# Returns a summary of the ALS model produced by spark.als.
+
+#' @param object a fitted ALS model.
+#' @return \code{summary} returns a list containing the names of the user column,
+#' the item column and the rating column, the estimated user and item factors,
+#' rank, regularization parameter and maximum number of iterations used in training.
+#' @rdname spark.als
+#' @aliases summary,ALSModel-method
+#' @export
+#' @note summary(ALSModel) since 2.1.0
+setMethod("summary", signature(object = "ALSModel"),
+function(object, ...) {
+ jobj <- object@jobj
+ user <- callJMethod(jobj, "userCol")
+ item <- callJMethod(jobj, "itemCol")
+ rating <- callJMethod(jobj, "ratingCol")
+ userFactors <- dataFrame(callJMethod(jobj, "userFactors"))
+ itemFactors <- dataFrame(callJMethod(jobj, "itemFactors"))
+ rank <- callJMethod(jobj, "rank")
+ return(list(user = user, item = item, rating = rating, userFactors = userFactors,
+ itemFactors = itemFactors, rank = rank))
+})
+
+
+# Makes predictions from an ALS model or a model produced by spark.als.
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted values.
+#' @rdname spark.als
+#' @aliases predict,ALSModel-method
+#' @export
+#' @note predict(ALSModel) since 2.1.0
+setMethod("predict", signature(object = "ALSModel"),
+function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+})
+
+
+# Saves the ALS model to the input path.
+
+#' @param path the directory where the model is saved.
+#' @param overwrite logical value indicating whether to overwrite if the output path
+#' already exists. Default is FALSE which means throw exception
+#' if the output path exists.
+#'
+#' @rdname spark.als
+#' @aliases write.ml,ALSModel,character-method
+#' @export
+#' @seealso \link{read.ml}
+#' @note write.ml(ALSModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "ALSModel", path = "character"),
+function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+})
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index dfb7a185cd5a3..67a3099101cf1 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -657,4 +657,44 @@ test_that("spark.posterior and spark.perplexity", {
expect_equal(length(local.posterior), sum(unlist(local.posterior)))
})
+test_that("spark.als", {
+ data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
+ list(2, 1, 1.0), list(2, 2, 5.0))
+ df <- createDataFrame(data, c("user", "item", "score"))
+ model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item",
+ rank = 10, maxIter = 5, seed = 0, reg = 0.1)
+ stats <- summary(model)
+ expect_equal(stats$rank, 10)
+ test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item"))
+ predictions <- collect(predict(model, test))
+
+ expect_equal(predictions$prediction, c(-0.1380762, 2.6258414, -1.5018409),
+ tolerance = 1e-4)
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats2$rating, "score")
+ userFactors <- collect(stats$userFactors)
+ itemFactors <- collect(stats$itemFactors)
+ userFactors2 <- collect(stats2$userFactors)
+ itemFactors2 <- collect(stats2$itemFactors)
+
+ orderUser <- order(userFactors$id)
+ orderUser2 <- order(userFactors2$id)
+ expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2])
+ expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2])
+
+ orderItem <- order(itemFactors$id)
+ orderItem2 <- order(itemFactors2$id)
+ expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2])
+ expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2])
+
+ unlink(modelPath)
+})
+
sparkR.session.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
new file mode 100644
index 0000000000000..ad13cced4667b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.recommendation.{ALS, ALSModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class ALSWrapper private (
+ val alsModel: ALSModel,
+ val ratingCol: String) extends MLWritable {
+
+ lazy val userCol: String = alsModel.getUserCol
+ lazy val itemCol: String = alsModel.getItemCol
+ lazy val userFactors: DataFrame = alsModel.userFactors
+ lazy val itemFactors: DataFrame = alsModel.itemFactors
+ lazy val rank: Int = alsModel.rank
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ alsModel.transform(dataset)
+ }
+
+ override def write: MLWriter = new ALSWrapper.ALSWrapperWriter(this)
+}
+
+private[r] object ALSWrapper extends MLReadable[ALSWrapper] {
+
+ def fit( // scalastyle:ignore
+ data: DataFrame,
+ ratingCol: String,
+ userCol: String,
+ itemCol: String,
+ rank: Int,
+ regParam: Double,
+ maxIter: Int,
+ implicitPrefs: Boolean,
+ alpha: Double,
+ nonnegative: Boolean,
+ numUserBlocks: Int,
+ numItemBlocks: Int,
+ checkpointInterval: Int,
+ seed: Int): ALSWrapper = {
+
+ val als = new ALS()
+ .setRatingCol(ratingCol)
+ .setUserCol(userCol)
+ .setItemCol(itemCol)
+ .setRank(rank)
+ .setRegParam(regParam)
+ .setMaxIter(maxIter)
+ .setImplicitPrefs(implicitPrefs)
+ .setAlpha(alpha)
+ .setNonnegative(nonnegative)
+ .setNumBlocks(numUserBlocks)
+ .setNumItemBlocks(numItemBlocks)
+ .setCheckpointInterval(checkpointInterval)
+ .setSeed(seed.toLong)
+
+ val alsModel: ALSModel = als.fit(data)
+
+ new ALSWrapper(alsModel, ratingCol)
+ }
+
+ override def read: MLReader[ALSWrapper] = new ALSWrapperReader
+
+ override def load(path: String): ALSWrapper = super.load(path)
+
+ class ALSWrapperWriter(instance: ALSWrapper) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val modelPath = new Path(path, "model").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("ratingCol" -> instance.ratingCol)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.alsModel.save(modelPath)
+ }
+ }
+
+ class ALSWrapperReader extends MLReader[ALSWrapper] {
+
+ override def load(path: String): ALSWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val modelPath = new Path(path, "model").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val ratingCol = (rMetadata \ "ratingCol").extract[String]
+ val alsModel = ALSModel.load(modelPath)
+
+ new ALSWrapper(alsModel, ratingCol)
+ }
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index e23af51df5718..51a65f7fc4fe8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -50,6 +50,8 @@ private[r] object RWrappers extends MLReader[Object] {
IsotonicRegressionWrapper.load(path)
case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
GaussianMixtureWrapper.load(path)
+ case "org.apache.spark.ml.r.ALSWrapper" =>
+ ALSWrapper.load(path)
case _ =>
throw new SparkException(s"SparkR read.ml does not support load $className")
}
From a117afa7c2d94f943106542ec53d74ba2b5f1058 Mon Sep 17 00:00:00 2001
From: petermaxlee
Date: Fri, 19 Aug 2016 18:14:45 -0700
Subject: [PATCH 033/270] [SPARK-17149][SQL] array.sql for testing array
related functions
## What changes were proposed in this pull request?
This patch creates array.sql in SQLQueryTestSuite for testing array-related functions, including:
- indexing
- array creation
- size
- array_contains
- sort_array
## How was this patch tested?
The patch itself is about adding tests.
Author: petermaxlee
Closes #14708 from petermaxlee/SPARK-17149.
---
.../catalyst/analysis/FunctionRegistry.scala | 12 +-
.../test/resources/sql-tests/inputs/array.sql | 86 +++++++++++
.../resources/sql-tests/results/array.sql.out | 144 ++++++++++++++++++
.../org/apache/spark/sql/SQLQuerySuite.scala | 16 --
.../apache/spark/sql/SQLQueryTestSuite.scala | 10 ++
.../execution/HiveCompatibilitySuite.scala | 4 +-
.../sql/hive/execution/HiveQuerySuite.scala | 9 --
7 files changed, 248 insertions(+), 33 deletions(-)
create mode 100644 sql/core/src/test/resources/sql-tests/inputs/array.sql
create mode 100644 sql/core/src/test/resources/sql-tests/results/array.sql.out
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index c5f91c1590542..35fd800df4a4f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -161,7 +161,6 @@ object FunctionRegistry {
val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
// misc non-aggregate functions
expression[Abs]("abs"),
- expression[CreateArray]("array"),
expression[Coalesce]("coalesce"),
expression[Explode]("explode"),
expression[Greatest]("greatest"),
@@ -172,10 +171,6 @@ object FunctionRegistry {
expression[IsNull]("isnull"),
expression[IsNotNull]("isnotnull"),
expression[Least]("least"),
- expression[CreateMap]("map"),
- expression[MapKeys]("map_keys"),
- expression[MapValues]("map_values"),
- expression[CreateNamedStruct]("named_struct"),
expression[NaNvl]("nanvl"),
expression[NullIf]("nullif"),
expression[Nvl]("nvl"),
@@ -184,7 +179,6 @@ object FunctionRegistry {
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[Stack]("stack"),
- expression[CreateStruct]("struct"),
expression[CaseWhen]("when"),
// math functions
@@ -354,9 +348,15 @@ object FunctionRegistry {
expression[TimeWindow]("window"),
// collection functions
+ expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
+ expression[CreateMap]("map"),
+ expression[CreateNamedStruct]("named_struct"),
+ expression[MapKeys]("map_keys"),
+ expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
+ expression[CreateStruct]("struct"),
// misc functions
expression[AssertTrue]("assert_true"),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql
new file mode 100644
index 0000000000000..4038a0da41d2b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql
@@ -0,0 +1,86 @@
+-- test cases for array functions
+
+create temporary view data as select * from values
+ ("one", array(11, 12, 13), array(array(111, 112, 113), array(121, 122, 123))),
+ ("two", array(21, 22, 23), array(array(211, 212, 213), array(221, 222, 223)))
+ as data(a, b, c);
+
+select * from data;
+
+-- index into array
+select a, b[0], b[0] + b[1] from data;
+
+-- index into array of arrays
+select a, c[0][0] + c[0][0 + 1] from data;
+
+
+create temporary view primitive_arrays as select * from values (
+ array(true),
+ array(2Y, 1Y),
+ array(2S, 1S),
+ array(2, 1),
+ array(2L, 1L),
+ array(9223372036854775809, 9223372036854775808),
+ array(2.0D, 1.0D),
+ array(float(2.0), float(1.0)),
+ array(date '2016-03-14', date '2016-03-13'),
+ array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000')
+) as primitive_arrays(
+ boolean_array,
+ tinyint_array,
+ smallint_array,
+ int_array,
+ bigint_array,
+ decimal_array,
+ double_array,
+ float_array,
+ date_array,
+ timestamp_array
+);
+
+select * from primitive_arrays;
+
+-- array_contains on all primitive types: result should alternate between true and false
+select
+ array_contains(boolean_array, true), array_contains(boolean_array, false),
+ array_contains(tinyint_array, 2Y), array_contains(tinyint_array, 0Y),
+ array_contains(smallint_array, 2S), array_contains(smallint_array, 0S),
+ array_contains(int_array, 2), array_contains(int_array, 0),
+ array_contains(bigint_array, 2L), array_contains(bigint_array, 0L),
+ array_contains(decimal_array, 9223372036854775809), array_contains(decimal_array, 1),
+ array_contains(double_array, 2.0D), array_contains(double_array, 0.0D),
+ array_contains(float_array, float(2.0)), array_contains(float_array, float(0.0)),
+ array_contains(date_array, date '2016-03-14'), array_contains(date_array, date '2016-01-01'),
+ array_contains(timestamp_array, timestamp '2016-11-15 20:54:00.000'), array_contains(timestamp_array, timestamp '2016-01-01 20:54:00.000')
+from primitive_arrays;
+
+-- array_contains on nested arrays
+select array_contains(b, 11), array_contains(c, array(111, 112, 113)) from data;
+
+-- sort_array
+select
+ sort_array(boolean_array),
+ sort_array(tinyint_array),
+ sort_array(smallint_array),
+ sort_array(int_array),
+ sort_array(bigint_array),
+ sort_array(decimal_array),
+ sort_array(double_array),
+ sort_array(float_array),
+ sort_array(date_array),
+ sort_array(timestamp_array)
+from primitive_arrays;
+
+-- size
+select
+ size(boolean_array),
+ size(tinyint_array),
+ size(smallint_array),
+ size(int_array),
+ size(bigint_array),
+ size(decimal_array),
+ size(double_array),
+ size(float_array),
+ size(date_array),
+ size(timestamp_array)
+from primitive_arrays;
diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out
new file mode 100644
index 0000000000000..4a1d149c1f362
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out
@@ -0,0 +1,144 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 10
+
+
+-- !query 0
+create temporary view data as select * from values
+ ("one", array(11, 12, 13), array(array(111, 112, 113), array(121, 122, 123))),
+ ("two", array(21, 22, 23), array(array(211, 212, 213), array(221, 222, 223)))
+ as data(a, b, c)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+select * from data
+-- !query 1 schema
+struct,c:array>>
+-- !query 1 output
+one [11,12,13] [[111,112,113],[121,122,123]]
+two [21,22,23] [[211,212,213],[221,222,223]]
+
+
+-- !query 2
+select a, b[0], b[0] + b[1] from data
+-- !query 2 schema
+struct<a:string,b[0]:int,(b[0] + b[1]):int>
+-- !query 2 output
+one 11 23
+two 21 43
+
+
+-- !query 3
+select a, c[0][0] + c[0][0 + 1] from data
+-- !query 3 schema
+struct
+-- !query 3 output
+one 223
+two 423
+
+
+-- !query 4
+create temporary view primitive_arrays as select * from values (
+ array(true),
+ array(2Y, 1Y),
+ array(2S, 1S),
+ array(2, 1),
+ array(2L, 1L),
+ array(9223372036854775809, 9223372036854775808),
+ array(2.0D, 1.0D),
+ array(float(2.0), float(1.0)),
+ array(date '2016-03-14', date '2016-03-13'),
+ array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000')
+) as primitive_arrays(
+ boolean_array,
+ tinyint_array,
+ smallint_array,
+ int_array,
+ bigint_array,
+ decimal_array,
+ double_array,
+ float_array,
+ date_array,
+ timestamp_array
+)
+-- !query 4 schema
+struct<>
+-- !query 4 output
+
+
+
+-- !query 5
+select * from primitive_arrays
+-- !query 5 schema
+struct<boolean_array:array<boolean>,tinyint_array:array<tinyint>,smallint_array:array<smallint>,int_array:array<int>,bigint_array:array<bigint>,decimal_array:array<decimal(19,0)>,double_array:array<double>,float_array:array<float>,date_array:array<date>,timestamp_array:array<timestamp>>
+-- !query 5 output
+[true] [2,1] [2,1] [2,1] [2,1] [9223372036854775809,9223372036854775808] [2.0,1.0] [2.0,1.0] [2016-03-14,2016-03-13] [2016-11-15 20:54:00.0,2016-11-12 20:54:00.0]
+
+
+-- !query 6
+select
+ array_contains(boolean_array, true), array_contains(boolean_array, false),
+ array_contains(tinyint_array, 2Y), array_contains(tinyint_array, 0Y),
+ array_contains(smallint_array, 2S), array_contains(smallint_array, 0S),
+ array_contains(int_array, 2), array_contains(int_array, 0),
+ array_contains(bigint_array, 2L), array_contains(bigint_array, 0L),
+ array_contains(decimal_array, 9223372036854775809), array_contains(decimal_array, 1),
+ array_contains(double_array, 2.0D), array_contains(double_array, 0.0D),
+ array_contains(float_array, float(2.0)), array_contains(float_array, float(0.0)),
+ array_contains(date_array, date '2016-03-14'), array_contains(date_array, date '2016-01-01'),
+ array_contains(timestamp_array, timestamp '2016-11-15 20:54:00.000'), array_contains(timestamp_array, timestamp '2016-01-01 20:54:00.000')
+from primitive_arrays
+-- !query 6 schema
+struct
+-- !query 6 output
+true false true false true false true false true false true false true false true false true false true false
+
+
+-- !query 7
+select array_contains(b, 11), array_contains(c, array(111, 112, 113)) from data
+-- !query 7 schema
+struct
+-- !query 7 output
+false false
+true true
+
+
+-- !query 8
+select
+ sort_array(boolean_array),
+ sort_array(tinyint_array),
+ sort_array(smallint_array),
+ sort_array(int_array),
+ sort_array(bigint_array),
+ sort_array(decimal_array),
+ sort_array(double_array),
+ sort_array(float_array),
+ sort_array(date_array),
+ sort_array(timestamp_array)
+from primitive_arrays
+-- !query 8 schema
+struct<sort_array(boolean_array, true):array<boolean>,sort_array(tinyint_array, true):array<tinyint>,sort_array(smallint_array, true):array<smallint>,sort_array(int_array, true):array<int>,sort_array(bigint_array, true):array<bigint>,sort_array(decimal_array, true):array<decimal(19,0)>,sort_array(double_array, true):array<double>,sort_array(float_array, true):array<float>,sort_array(date_array, true):array<date>,sort_array(timestamp_array, true):array<timestamp>>
+-- !query 8 output
+[true] [1,2] [1,2] [1,2] [1,2] [9223372036854775808,9223372036854775809] [1.0,2.0] [1.0,2.0] [2016-03-13,2016-03-14] [2016-11-12 20:54:00.0,2016-11-15 20:54:00.0]
+
+
+-- !query 9
+select
+ size(boolean_array),
+ size(tinyint_array),
+ size(smallint_array),
+ size(int_array),
+ size(bigint_array),
+ size(decimal_array),
+ size(double_array),
+ size(float_array),
+ size(date_array),
+ size(timestamp_array)
+from primitive_arrays
+-- !query 9 schema
+struct<size(boolean_array):int,size(tinyint_array):int,size(smallint_array):int,size(int_array):int,size(bigint_array):int,size(decimal_array):int,size(double_array):int,size(float_array):int,size(date_array):int,size(timestamp_array):int>
+-- !query 9 output
+1 2 2 2 2 2 2 2 2 2
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4fcde58833d76..eac266cba55b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -445,12 +445,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Nil)
}
- test("index into array") {
- checkAnswer(
- sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"),
- arrayData.map(d => Row(d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect())
- }
-
test("left semi greater than predicate") {
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer(
@@ -472,16 +466,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
- test("index into array of arrays") {
- checkAnswer(
- sql(
- "SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"),
- arrayData.map(d =>
- Row(d.nestedData,
- d.nestedData(0)(0),
- d.nestedData(0)(0) + d.nestedData(0)(1))).collect().toSeq)
- }
-
test("agg") {
checkAnswer(
sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index 069a9b665eb36..55d5a56f1040a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -35,6 +35,16 @@ import org.apache.spark.sql.types.StructType
* Each case is loaded from a file in "spark/sql/core/src/test/resources/sql-tests/inputs".
* Each case has a golden result file in "spark/sql/core/src/test/resources/sql-tests/results".
*
+ * To run the entire test suite:
+ * {{{
+ * build/sbt "sql/test-only *SQLQueryTestSuite"
+ * }}}
+ *
+ * To run a single test file upon change:
+ * {{{
+ * build/sbt "~sql/test-only *SQLQueryTestSuite -- -z inline-table.sql"
+ * }}}
+ *
* To re-generate golden files, run:
* {{{
* SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/test-only *SQLQueryTestSuite"
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 13d18fdec0e9d..a54d234876256 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -979,8 +979,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_PI",
"udf_acos",
"udf_add",
- "udf_array",
- "udf_array_contains",
+ // "udf_array", -- done in array.sql
+ // "udf_array_contains", -- done in array.sql
"udf_ascii",
"udf_asin",
"udf_atan",
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 6785167d3dfba..3c7dbb449c521 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -216,15 +216,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
assert(new Timestamp(1000) == r1.getTimestamp(0))
}
- createQueryTest("constant array",
- """
- |SELECT sort_array(
- | sort_array(
- | array("hadoop distributed file system",
- | "enterprise databases", "hadoop map-reduce")))
- |FROM src LIMIT 1;
- """.stripMargin)
-
createQueryTest("null case",
"SELECT case when(true) then 1 else null end FROM src LIMIT 1")
From ba1737c21aab91ff3f1a1737aa2d6b07575e36a3 Mon Sep 17 00:00:00 2001
From: Srinath Shankar
Date: Fri, 19 Aug 2016 19:54:26 -0700
Subject: [PATCH 034/270] [SPARK-17158][SQL] Change error message for out of
range numeric literals
## What changes were proposed in this pull request?
Modifies the error message for out-of-range numeric literals to the form:
Numeric literal <value> does not fit in range [min, max] for type <type>
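For reference, a minimal standalone sketch of the bounds check introduced here (illustrative names only, not the actual `AstBuilder` code): the raw literal is parsed as an exact `BigDecimal` and rejected before conversion when it falls outside the target type's range.
```scala
object LiteralRangeCheckSketch {
  // Sketch of the idea only; the real parser throws ParseException with parser context.
  def toLiteral[T](raw: String, min: BigDecimal, max: BigDecimal, typeName: String)
                  (convert: String => T): T = {
    val value = BigDecimal(raw) // exact decimal, so the check happens before any overflow
    if (value < min || value > max) {
      throw new IllegalArgumentException(
        s"Numeric literal $raw does not fit in range [$min, $max] for type $typeName")
    }
    convert(raw)
  }

  def main(args: Array[String]): Unit = {
    val ok = toLiteral("127", BigDecimal(Byte.MinValue), BigDecimal(Byte.MaxValue), "tinyint")(_.toByte)
    println(ok) // 127
    // toLiteral("128", BigDecimal(Byte.MinValue), BigDecimal(Byte.MaxValue), "tinyint")(_.toByte)
    // would fail with: Numeric literal 128 does not fit in range [-128, 127] for type tinyint
  }
}
```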
## How was this patch tested?
Fixed up the error messages for literals.sql in SQLQueryTestSuite and re-ran via sbt. Also fixed up the error messages in ExpressionParserSuite.
Author: Srinath Shankar
Closes #14721 from srinathshankar/sc4296.
---
.../sql/catalyst/parser/AstBuilder.scala | 29 ++++++++++++-------
.../parser/ExpressionParserSuite.scala | 9 ++++--
.../sql-tests/results/literals.sql.out | 6 ++--
3 files changed, 27 insertions(+), 17 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 283e4d43ba2b9..8b98efcbf33c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1278,10 +1278,17 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
/** Create a numeric literal expression. */
- private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) {
- val raw = ctx.getText
+ private def numericLiteral
+ (ctx: NumberContext, minValue: BigDecimal, maxValue: BigDecimal, typeName: String)
+ (converter: String => Any): Literal = withOrigin(ctx) {
+ val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
try {
- Literal(f(raw.substring(0, raw.length - 1)))
+ val rawBigDecimal = BigDecimal(rawStrippedQualifier)
+ if (rawBigDecimal < minValue || rawBigDecimal > maxValue) {
+ throw new ParseException(s"Numeric literal ${rawStrippedQualifier} does not " +
+ s"fit in range [${minValue}, ${maxValue}] for type ${typeName}", ctx)
+ }
+ Literal(converter(rawStrippedQualifier))
} catch {
case e: NumberFormatException =>
throw new ParseException(e.getMessage, ctx)
@@ -1291,29 +1298,29 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
/**
* Create a Byte Literal expression.
*/
- override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) {
- _.toByte
+ override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = {
+ numericLiteral(ctx, Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte)
}
/**
* Create a Short Literal expression.
*/
- override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) {
- _.toShort
+ override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = {
+ numericLiteral(ctx, Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort)
}
/**
* Create a Long Literal expression.
*/
- override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) {
- _.toLong
+ override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = {
+ numericLiteral(ctx, Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong)
}
/**
* Create a Double Literal expression.
*/
- override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) {
- _.toDouble
+ override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = {
+ numericLiteral(ctx, Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble)
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 849d96212822c..401d9cd9d288c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -375,18 +375,21 @@ class ExpressionParserSuite extends PlanTest {
// Tiny Int Literal
assertEqual("10Y", Literal(10.toByte))
- intercept("-1000Y")
+ intercept("-1000Y", s"does not fit in range [${Byte.MinValue}, ${Byte.MaxValue}]")
// Small Int Literal
assertEqual("10S", Literal(10.toShort))
- intercept("40000S")
+ intercept("40000S", s"does not fit in range [${Short.MinValue}, ${Short.MaxValue}]")
// Long Int Literal
assertEqual("10L", Literal(10L))
- intercept("78732472347982492793712334L")
+ intercept("78732472347982492793712334L",
+ s"does not fit in range [${Long.MinValue}, ${Long.MaxValue}]")
// Double Literal
assertEqual("10.0D", Literal(10.0D))
+ intercept("-1.8E308D", s"does not fit in range")
+ intercept("1.8E308D", s"does not fit in range")
// TODO we need to figure out if we should throw an exception here!
assertEqual("1E309", Literal(Double.PositiveInfinity))
}
diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out
index b964a6fc0921f..67e6d78dfbf24 100644
--- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out
@@ -41,7 +41,7 @@ struct<>
-- !query 4 output
org.apache.spark.sql.catalyst.parser.ParseException
-Value out of range. Value:"128" Radix:10(line 1, pos 7)
+Numeric literal 128 does not fit in range [-128, 127] for type tinyint(line 1, pos 7)
== SQL ==
select 128Y
@@ -71,7 +71,7 @@ struct<>
-- !query 7 output
org.apache.spark.sql.catalyst.parser.ParseException
-Value out of range. Value:"32768" Radix:10(line 1, pos 7)
+Numeric literal 32768 does not fit in range [-32768, 32767] for type smallint(line 1, pos 7)
== SQL ==
select 32768S
@@ -101,7 +101,7 @@ struct<>
-- !query 10 output
org.apache.spark.sql.catalyst.parser.ParseException
-For input string: "9223372036854775808"(line 1, pos 7)
+Numeric literal 9223372036854775808 does not fit in range [-9223372036854775808, 9223372036854775807] for type bigint(line 1, pos 7)
== SQL ==
select 9223372036854775808L
From 45d40d9f66c666eec6df926db23937589d67225d Mon Sep 17 00:00:00 2001
From: petermaxlee
Date: Sat, 20 Aug 2016 13:19:38 +0800
Subject: [PATCH 035/270] [SPARK-17150][SQL] Support SQL generation for inline
tables
## What changes were proposed in this pull request?
This patch adds support for SQL generation for inline tables. With this, it would be possible to create a view that depends on inline tables.
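The generated clause is a plain string-building step over the relation's rows; a minimal sketch of the idea (pre-rendered literal fragments stand in for `Literal(...).sql`, and the names are illustrative):
```scala
object InlineTableSqlSketch {
  // Render an inline table as: VALUES (...), (...) AS name(col1, col2, ...)
  def toValuesSql(name: String, colNames: Seq[String], rows: Seq[Seq[String]]): String = {
    require(rows.nonEmpty, "an inline table must have at least one row")
    val tuples = rows.map(_.mkString("(", ", ", ")"))
    "VALUES " + tuples.mkString(", ") + " AS " + name + colNames.mkString("(", ", ", ")")
  }

  def main(args: Array[String]): Unit = {
    val sql = toValuesSql("data", Seq("a", "b"),
      Seq(Seq("\"one\"", "1"), Seq("\"two\"", "2"), Seq("\"three\"", "CAST(NULL AS INT)")))
    println(sql)
    // VALUES ("one", 1), ("two", 2), ("three", CAST(NULL AS INT)) AS data(a, b)
  }
}
```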
## How was this patch tested?
Added a test case in LogicalPlanToSQLSuite.
Author: petermaxlee
Closes #14709 from petermaxlee/SPARK-17150.
---
.../catalyst/plans/logical/LocalRelation.scala | 17 +++++++++++++++--
.../apache/spark/sql/catalyst/SQLBuilder.scala | 3 +++
.../src/test/resources/sqlgen/inline_tables.sql | 4 ++++
.../sql/catalyst/LogicalPlanToSQLSuite.scala | 8 ++++++++
4 files changed, 30 insertions(+), 2 deletions(-)
create mode 100644 sql/hive/src/test/resources/sqlgen/inline_tables.sql
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index 9d64f35efcc6a..890865d177845 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -18,8 +18,9 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.{analysis, CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.types.{StructField, StructType}
object LocalRelation {
@@ -75,4 +76,16 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
override lazy val statistics =
Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length)
+
+ def toSQL(inlineTableName: String): String = {
+ require(data.nonEmpty)
+ val types = output.map(_.dataType)
+ val rows = data.map { row =>
+ val cells = row.toSeq(types).zip(types).map { case (v, tpe) => Literal(v, tpe).sql }
+ cells.mkString("(", ", ", ")")
+ }
+ "VALUES " + rows.mkString(", ") +
+ " AS " + inlineTableName +
+ output.map(_.name).mkString("(", ", ", ")")
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
index 0f51aa58d63ba..af1de511da060 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -205,6 +205,9 @@ class SQLBuilder private (
case p: ScriptTransformation =>
scriptTransformationToSQL(p)
+ case p: LocalRelation =>
+ p.toSQL(newSubqueryName())
+
case OneRowRelation =>
""
diff --git a/sql/hive/src/test/resources/sqlgen/inline_tables.sql b/sql/hive/src/test/resources/sqlgen/inline_tables.sql
new file mode 100644
index 0000000000000..602551e69da6e
--- /dev/null
+++ b/sql/hive/src/test/resources/sqlgen/inline_tables.sql
@@ -0,0 +1,4 @@
+-- This file is automatically generated by LogicalPlanToSQLSuite.
+select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b) where b > 1
+--------------------------------------------------------------------------------
+SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (VALUES ("one", 1), ("two", 2), ("three", CAST(NULL AS INT)) AS gen_subquery_0(gen_attr_0, gen_attr_1)) AS data WHERE (`gen_attr_1` > 1)) AS data
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
index 4e5a51155defd..742b065891a8e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
@@ -1102,4 +1102,12 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkSQL("select * from orc_t", "select_orc_table")
}
}
+
+ test("inline tables") {
+ checkSQL(
+ """
+ |select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b) where b > 1
+ """.stripMargin,
+ "inline_tables")
+ }
}
From 39f328ba3519b01940a7d1cdee851ba4e75ef31f Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Fri, 19 Aug 2016 23:46:36 -0700
Subject: [PATCH 036/270] [SPARK-15018][PYSPARK][ML] Improve handling of
PySpark Pipeline when used without stages
## What changes were proposed in this pull request?
When fitting a PySpark Pipeline without the `stages` param set, a confusing NoneType error is raised as it attempts to iterate over the pipeline stages. A pipeline with no stages should act as an identity transform; however, the `stages` param still needs to be set to an empty list. This change improves the error output when the `stages` param is not set and adds a better description of what the API expects as input. It also includes minor cleanup of related code.
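For comparison, the Scala ML `Pipeline` already treats an empty stage list as an identity transformer; a small sketch of that behavior (assumes a local `SparkSession`, not part of this patch):
```scala
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.SparkSession

object EmptyPipelineSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("empty-pipeline").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("id", "value")

    // With no stages, the fitted PipelineModel applies no transformations.
    val model = new Pipeline().setStages(Array.empty[PipelineStage]).fit(df)
    assert(model.transform(df).collect().sameElements(df.collect()))

    spark.stop()
  }
}
```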
## How was this patch tested?
Added new unit tests to verify an empty Pipeline acts as an identity transformer
Author: Bryan Cutler
Closes #12790 from BryanCutler/pipeline-identity-SPARK-15018.
---
python/pyspark/ml/pipeline.py | 11 +++--------
python/pyspark/ml/tests.py | 11 +++++++++++
2 files changed, 14 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a48f4bb2ad1ba..4307ad02a0ebd 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -44,21 +44,19 @@ class Pipeline(Estimator, MLReadable, MLWritable):
the dataset for the next stage. The fitted model from a
:py:class:`Pipeline` is a :py:class:`PipelineModel`, which
consists of fitted models and transformers, corresponding to the
- pipeline stages. If there are no stages, the pipeline acts as an
+ pipeline stages. If stages is an empty list, the pipeline acts as an
identity transformer.
.. versionadded:: 1.3.0
"""
- stages = Param(Params._dummy(), "stages", "pipeline stages")
+ stages = Param(Params._dummy(), "stages", "a list of pipeline stages")
@keyword_only
def __init__(self, stages=None):
"""
__init__(self, stages=None)
"""
- if stages is None:
- stages = []
super(Pipeline, self).__init__()
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -78,8 +76,7 @@ def getStages(self):
"""
Get pipeline stages.
"""
- if self.stages in self._paramMap:
- return self._paramMap[self.stages]
+ return self.getOrDefault(self.stages)
@keyword_only
@since("1.3.0")
@@ -88,8 +85,6 @@ def setParams(self, stages=None):
setParams(self, stages=None)
Sets params for Pipeline.
"""
- if stages is None:
- stages = []
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4bcb2c400c4aa..6886ed321ee82 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -230,6 +230,17 @@ def test_pipeline(self):
self.assertEqual(5, transformer3.dataset_index)
self.assertEqual(6, dataset.index)
+ def test_identity_pipeline(self):
+ dataset = MockDataset()
+
+ def doTransform(pipeline):
+ pipeline_model = pipeline.fit(dataset)
+ return pipeline_model.transform(dataset)
+ # check that empty pipeline did not perform any transformation
+ self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index)
+ # check that failure to set stages param will raise KeyError for missing param
+ self.assertRaises(KeyError, lambda: doTransform(Pipeline()))
+
class TestParams(HasMaxIter, HasInputCol, HasSeed):
"""
From 01401e965b58f7e8ab615764a452d7d18f1d4bf0 Mon Sep 17 00:00:00 2001
From: Junyang Qian
Date: Sat, 20 Aug 2016 06:59:23 -0700
Subject: [PATCH 037/270] [SPARK-16508][SPARKR] Fix CRAN
undocumented/duplicated arguments warnings.
## What changes were proposed in this pull request?
This PR tries to fix all the remaining "undocumented/duplicated arguments" warnings given by CRAN-check.
The one warning left is the doc for R `stats::glm` exported in SparkR. To mute that warning, we would also have to provide documentation for all arguments of that non-SparkR function.
Some previous conversation is in #14558.
## How was this patch tested?
R unit test and `check-cran.sh` script (with no-test).
Author: Junyang Qian
Closes #14705 from junyangq/SPARK-16508-master.
---
R/pkg/R/DataFrame.R | 221 +++++++++++++++++++++++++------------------
R/pkg/R/SQLContext.R | 30 +++---
R/pkg/R/WindowSpec.R | 11 ++-
R/pkg/R/column.R | 18 +++-
R/pkg/R/functions.R | 173 +++++++++++++++++++++------------
R/pkg/R/generics.R | 62 +++++++++---
R/pkg/R/group.R | 7 +-
R/pkg/R/mllib.R | 113 +++++++++++-----------
R/pkg/R/schema.R | 5 +-
R/pkg/R/sparkR.R | 21 ++--
R/pkg/R/stats.R | 25 +++--
11 files changed, 419 insertions(+), 267 deletions(-)
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 09be06de06b52..540dc3122dd6d 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -120,8 +120,9 @@ setMethod("schema",
#'
#' Print the logical and physical Catalyst plans to the console for debugging.
#'
-#' @param x A SparkDataFrame
+#' @param x a SparkDataFrame.
#' @param extended Logical. If extended is FALSE, explain() only prints the physical plan.
+#' @param ... further arguments to be passed to or from other methods.
#' @family SparkDataFrame functions
#' @aliases explain,SparkDataFrame-method
#' @rdname explain
@@ -177,11 +178,13 @@ setMethod("isLocal",
#'
#' Print the first numRows rows of a SparkDataFrame
#'
-#' @param x A SparkDataFrame
-#' @param numRows The number of rows to print. Defaults to 20.
-#' @param truncate Whether truncate long strings. If true, strings more than 20 characters will be
-#' truncated. However, if set greater than zero, truncates strings longer than `truncate`
-#' characters and all cells will be aligned right.
+#' @param x a SparkDataFrame.
+#' @param numRows the number of rows to print. Defaults to 20.
+#' @param truncate whether to truncate long strings. If \code{TRUE}, strings more than
+#' 20 characters will be truncated. However, if set greater than zero,
+#' truncates strings longer than `truncate` characters and all cells
+#' will be aligned right.
+#' @param ... further arguments to be passed to or from other methods.
#' @family SparkDataFrame functions
#' @aliases showDF,SparkDataFrame-method
#' @rdname showDF
@@ -211,7 +214,7 @@ setMethod("showDF",
#'
#' Print the SparkDataFrame column names and types
#'
-#' @param x A SparkDataFrame
+#' @param object a SparkDataFrame.
#'
#' @family SparkDataFrame functions
#' @rdname show
@@ -262,11 +265,11 @@ setMethod("dtypes",
})
})
-#' Column names
+#' Column Names of SparkDataFrame
#'
-#' Return all column names as a list
+#' Return all column names as a list.
#'
-#' @param x A SparkDataFrame
+#' @param x a SparkDataFrame.
#'
#' @family SparkDataFrame functions
#' @rdname columns
@@ -323,6 +326,8 @@ setMethod("colnames",
columns(x)
})
+#' @param value a character vector. Must have the same length as the number
+#' of columns in the SparkDataFrame.
#' @rdname columns
#' @aliases colnames<-,SparkDataFrame-method
#' @name colnames<-
@@ -514,9 +519,10 @@ setMethod("registerTempTable",
#'
#' Insert the contents of a SparkDataFrame into a table registered in the current SparkSession.
#'
-#' @param x A SparkDataFrame
-#' @param tableName A character vector containing the name of the table
-#' @param overwrite A logical argument indicating whether or not to overwrite
+#' @param x a SparkDataFrame.
+#' @param tableName a character vector containing the name of the table.
+#' @param overwrite a logical argument indicating whether or not to overwrite.
+#' @param ... further arguments to be passed to or from other methods.
#' the existing rows in the table.
#'
#' @family SparkDataFrame functions
@@ -575,7 +581,9 @@ setMethod("cache",
#' supported storage levels, refer to
#' \url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}.
#'
-#' @param x The SparkDataFrame to persist
+#' @param x the SparkDataFrame to persist.
+#' @param newLevel storage level chosen for the persistence. See available options in
+#' the description.
#'
#' @family SparkDataFrame functions
#' @rdname persist
@@ -603,8 +611,9 @@ setMethod("persist",
#' Mark this SparkDataFrame as non-persistent, and remove all blocks for it from memory and
#' disk.
#'
-#' @param x The SparkDataFrame to unpersist
-#' @param blocking Whether to block until all blocks are deleted
+#' @param x the SparkDataFrame to unpersist.
+#' @param blocking whether to block until all blocks are deleted.
+#' @param ... further arguments to be passed to or from other methods.
#'
#' @family SparkDataFrame functions
#' @rdname unpersist-methods
@@ -638,9 +647,10 @@ setMethod("unpersist",
#' \item{3.} {Return a new SparkDataFrame partitioned by the given column(s),
#' using `spark.sql.shuffle.partitions` as number of partitions.}
#'}
-#' @param x A SparkDataFrame
-#' @param numPartitions The number of partitions to use.
-#' @param col The column by which the partitioning will be performed.
+#' @param x a SparkDataFrame.
+#' @param numPartitions the number of partitions to use.
+#' @param col the column by which the partitioning will be performed.
+#' @param ... additional column(s) to be used in the partitioning.
#'
#' @family SparkDataFrame functions
#' @rdname repartition
@@ -919,11 +929,10 @@ setMethod("sample_frac",
#' Returns the number of rows in a SparkDataFrame
#'
-#' @param x A SparkDataFrame
-#'
+#' @param x a SparkDataFrame.
#' @family SparkDataFrame functions
#' @rdname nrow
-#' @name count
+#' @name nrow
#' @aliases count,SparkDataFrame-method
#' @export
#' @examples
@@ -999,9 +1008,10 @@ setMethod("dim",
#' Collects all the elements of a SparkDataFrame and coerces them into an R data.frame.
#'
-#' @param x A SparkDataFrame
-#' @param stringsAsFactors (Optional) A logical indicating whether or not string columns
+#' @param x a SparkDataFrame.
+#' @param stringsAsFactors (Optional) a logical indicating whether or not string columns
#' should be converted to factors. FALSE by default.
+#' @param ... further arguments to be passed to or from other methods.
#'
#' @family SparkDataFrame functions
#' @rdname collect
@@ -1096,8 +1106,10 @@ setMethod("limit",
dataFrame(res)
})
-#' Take the first NUM rows of a SparkDataFrame and return a the results as a R data.frame
+#' Take the first NUM rows of a SparkDataFrame and return the results as a R data.frame
#'
+#' @param x a SparkDataFrame.
+#' @param num number of rows to take.
#' @family SparkDataFrame functions
#' @rdname take
#' @name take
@@ -1124,9 +1136,9 @@ setMethod("take",
#' then head() returns the first 6 rows in keeping with the current data.frame
#' convention in R.
#'
-#' @param x A SparkDataFrame
-#' @param num The number of rows to return. Default is 6.
-#' @return A data.frame
+#' @param x a SparkDataFrame.
+#' @param num the number of rows to return. Default is 6.
+#' @return A data.frame.
#'
#' @family SparkDataFrame functions
#' @aliases head,SparkDataFrame-method
@@ -1150,7 +1162,8 @@ setMethod("head",
#' Return the first row of a SparkDataFrame
#'
-#' @param x A SparkDataFrame
+#' @param x a SparkDataFrame or a column used in aggregation function.
+#' @param ... further arguments to be passed to or from other methods.
#'
#' @family SparkDataFrame functions
#' @aliases first,SparkDataFrame-method
@@ -1201,8 +1214,9 @@ setMethod("toRDD",
#'
#' Groups the SparkDataFrame using the specified columns, so we can run aggregation on them.
#'
-#' @param x a SparkDataFrame
-#' @return a GroupedData
+#' @param x a SparkDataFrame.
+#' @param ... variable(s) (character name(s) or Column(s)) to group on.
+#' @return A GroupedData.
#' @family SparkDataFrame functions
#' @aliases groupBy,SparkDataFrame-method
#' @rdname groupBy
@@ -1244,7 +1258,6 @@ setMethod("group_by",
#'
#' Compute aggregates by specifying a list of columns
#'
-#' @param x a SparkDataFrame
#' @family SparkDataFrame functions
#' @aliases agg,SparkDataFrame-method
#' @rdname summarize
@@ -1391,16 +1404,15 @@ setMethod("dapplyCollect",
#' Groups the SparkDataFrame using the specified columns and applies the R function to each
#' group.
#'
-#' @param x A SparkDataFrame
-#' @param cols Grouping columns
-#' @param func A function to be applied to each group partition specified by grouping
+#' @param cols grouping columns.
+#' @param func a function to be applied to each group partition specified by grouping
#' column of the SparkDataFrame. The function `func` takes as argument
#' a key - grouping columns and a data frame - a local R data.frame.
#' The output of `func` is a local R data.frame.
-#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
+#' @param schema the schema of the resulting SparkDataFrame after the function is applied.
#' The schema must match to output of `func`. It has to be defined for each
#' output column with preferred output column name and corresponding data type.
-#' @return a SparkDataFrame
+#' @return A SparkDataFrame.
#' @family SparkDataFrame functions
#' @aliases gapply,SparkDataFrame-method
#' @rdname gapply
@@ -1483,13 +1495,12 @@ setMethod("gapply",
#' Groups the SparkDataFrame using the specified columns, applies the R function to each
#' group and collects the result back to R as data.frame.
#'
-#' @param x A SparkDataFrame
-#' @param cols Grouping columns
-#' @param func A function to be applied to each group partition specified by grouping
+#' @param cols grouping columns.
+#' @param func a function to be applied to each group partition specified by grouping
#' column of the SparkDataFrame. The function `func` takes as argument
#' a key - grouping columns and a data frame - a local R data.frame.
#' The output of `func` is a local R data.frame.
-#' @return a data.frame
+#' @return A data.frame.
#' @family SparkDataFrame functions
#' @aliases gapplyCollect,SparkDataFrame-method
#' @rdname gapplyCollect
@@ -1636,6 +1647,7 @@ getColumn <- function(x, c) {
column(callJMethod(x@sdf, "col", c))
}
+#' @param name name of a Column (without being wrapped by \code{""}).
#' @rdname select
#' @name $
#' @aliases $,SparkDataFrame-method
@@ -1645,6 +1657,7 @@ setMethod("$", signature(x = "SparkDataFrame"),
getColumn(x, name)
})
+#' @param value a Column or NULL. If NULL, the specified Column is dropped.
#' @rdname select
#' @name $<-
#' @aliases $<-,SparkDataFrame-method
@@ -1719,12 +1732,13 @@ setMethod("[", signature(x = "SparkDataFrame"),
#' Subset
#'
#' Return subsets of SparkDataFrame according to given conditions
-#' @param x A SparkDataFrame
-#' @param subset (Optional) A logical expression to filter on rows
-#' @param select expression for the single Column or a list of columns to select from the SparkDataFrame
+#' @param x a SparkDataFrame.
+#' @param i,subset (Optional) a logical expression to filter on rows.
+#' @param j,select expression for the single Column or a list of columns to select from the SparkDataFrame.
#' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column.
-#' Otherwise, a SparkDataFrame will always be returned.
-#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns
+#' Otherwise, a SparkDataFrame will always be returned.
+#' @param ... currently not used.
+#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns.
#' @export
#' @family SparkDataFrame functions
#' @aliases subset,SparkDataFrame-method
@@ -1759,9 +1773,12 @@ setMethod("subset", signature(x = "SparkDataFrame"),
#' Select
#'
#' Selects a set of columns with names or Column expressions.
-#' @param x A SparkDataFrame
-#' @param col A list of columns or single Column or name
-#' @return A new SparkDataFrame with selected columns
+#' @param x a SparkDataFrame.
+#' @param col a list of columns or single Column or name.
+#' @param ... additional column(s) if only one column is specified in \code{col}.
+#' If more than one column is assigned in \code{col}, \code{...}
+#' should be left empty.
+#' @return A new SparkDataFrame with selected columns.
#' @export
#' @family SparkDataFrame functions
#' @rdname select
@@ -1858,9 +1875,9 @@ setMethod("selectExpr",
#' Return a new SparkDataFrame by adding a column or replacing the existing column
#' that has the same name.
#'
-#' @param x A SparkDataFrame
-#' @param colName A column name.
-#' @param col A Column expression.
+#' @param x a SparkDataFrame.
+#' @param colName a column name.
+#' @param col a Column expression.
#' @return A SparkDataFrame with the new column added or the existing column replaced.
#' @family SparkDataFrame functions
#' @aliases withColumn,SparkDataFrame,character,Column-method
@@ -1889,8 +1906,8 @@ setMethod("withColumn",
#'
#' Return a new SparkDataFrame with the specified columns added or replaced.
#'
-#' @param .data A SparkDataFrame
-#' @param col a named argument of the form name = col
+#' @param .data a SparkDataFrame.
+#' @param ... additional column argument(s) each in the form name = col.
#' @return A new SparkDataFrame with the new columns added or replaced.
#' @family SparkDataFrame functions
#' @aliases mutate,SparkDataFrame-method
@@ -1967,6 +1984,7 @@ setMethod("mutate",
do.call(select, c(x, colList, deDupCols))
})
+#' @param _data a SparkDataFrame.
#' @export
#' @rdname mutate
#' @aliases transform,SparkDataFrame-method
@@ -2278,11 +2296,18 @@ setMethod("join",
#' specified, the common column names in \code{x} and \code{y} will be used.
#' @param by.x a character vector specifying the joining columns for x.
#' @param by.y a character vector specifying the joining columns for y.
+#' @param all a boolean value setting \code{all.x} and \code{all.y}
+#' if any of them are unset.
#' @param all.x a boolean value indicating whether all the rows in x should
#' be including in the join
#' @param all.y a boolean value indicating whether all the rows in y should
#' be including in the join
#' @param sort a logical argument indicating whether the resulting columns should be sorted
+#' @param suffixes a string vector of length 2 used to make colnames of
+#' \code{x} and \code{y} unique.
+#' The first element is appended to each colname of \code{x}.
+#' The second element is appended to each colname of \code{y}.
+#' @param ... additional argument(s) passed to the method.
#' @details If all.x and all.y are set to FALSE, a natural join will be returned. If
#' all.x is set to TRUE and all.y is set to FALSE, a left outer join will
#' be returned. If all.x is set to FALSE and all.y is set to TRUE, a right
@@ -2311,7 +2336,7 @@ setMethod("merge",
signature(x = "SparkDataFrame", y = "SparkDataFrame"),
function(x, y, by = intersect(names(x), names(y)), by.x = by, by.y = by,
all = FALSE, all.x = all, all.y = all,
- sort = TRUE, suffixes = c("_x", "_y"), ... ) {
+ sort = TRUE, suffixes = c("_x", "_y"), ...) {
if (length(suffixes) != 2) {
stop("suffixes must have length 2")
@@ -2464,8 +2489,10 @@ setMethod("unionAll",
#' Union two or more SparkDataFrames. This is equivalent to `UNION ALL` in SQL.
#' Note that this does not remove duplicate rows across the two SparkDataFrames.
#'
-#' @param x A SparkDataFrame
-#' @param ... Additional SparkDataFrame
+#' @param x a SparkDataFrame.
+#' @param ... additional SparkDataFrame(s).
+#' @param deparse.level currently not used (put here to match the signature of
+#' the base implementation).
#' @return A SparkDataFrame containing the result of the union.
#' @family SparkDataFrame functions
#' @aliases rbind,SparkDataFrame-method
@@ -2522,8 +2549,8 @@ setMethod("intersect",
#' Return a new SparkDataFrame containing rows in this SparkDataFrame
#' but not in another SparkDataFrame. This is equivalent to `EXCEPT` in SQL.
#'
-#' @param x A SparkDataFrame
-#' @param y A SparkDataFrame
+#' @param x a SparkDataFrame.
+#' @param y a SparkDataFrame.
#' @return A SparkDataFrame containing the result of the except operation.
#' @family SparkDataFrame functions
#' @aliases except,SparkDataFrame,SparkDataFrame-method
@@ -2564,10 +2591,11 @@ setMethod("except",
#' and to not change the existing data.
#' }
#'
-#' @param df A SparkDataFrame
-#' @param path A name for the table
-#' @param source A name for external data source
-#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default)
+#' @param df a SparkDataFrame.
+#' @param path a name for the table.
+#' @param source a name for external data source.
+#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default)
+#' @param ... additional argument(s) passed to the method.
#'
#' @family SparkDataFrame functions
#' @aliases write.df,SparkDataFrame,character-method
@@ -2626,10 +2654,11 @@ setMethod("saveDF",
#' ignore: The save operation is expected to not save the contents of the SparkDataFrame
#' and to not change the existing data. \cr
#'
-#' @param df A SparkDataFrame
-#' @param tableName A name for the table
-#' @param source A name for external data source
-#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default)
+#' @param df a SparkDataFrame.
+#' @param tableName a name for the table.
+#' @param source a name for external data source.
+#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default).
+#' @param ... additional option(s) passed to the method.
#'
#' @family SparkDataFrame functions
#' @aliases saveAsTable,SparkDataFrame,character-method
@@ -2665,10 +2694,10 @@ setMethod("saveAsTable",
#' Computes statistics for numeric and string columns.
#' If no columns are given, this function computes statistics for all numerical or string columns.
#'
-#' @param x A SparkDataFrame to be computed.
-#' @param col A string of name
-#' @param ... Additional expressions
-#' @return A SparkDataFrame
+#' @param x a SparkDataFrame to be computed.
+#' @param col a string of name.
+#' @param ... additional expressions.
+#' @return A SparkDataFrame.
#' @family SparkDataFrame functions
#' @aliases describe,SparkDataFrame,character-method describe,SparkDataFrame,ANY-method
#' @rdname summary
@@ -2703,6 +2732,7 @@ setMethod("describe",
dataFrame(sdf)
})
+#' @param object a SparkDataFrame to be summarized.
#' @rdname summary
#' @name summary
#' @aliases summary,SparkDataFrame-method
@@ -2718,16 +2748,20 @@ setMethod("summary",
#'
#' dropna, na.omit - Returns a new SparkDataFrame omitting rows with null values.
#'
-#' @param x A SparkDataFrame.
+#' @param x a SparkDataFrame.
#' @param how "any" or "all".
#' if "any", drop a row if it contains any nulls.
#' if "all", drop a row only if all its values are null.
#' if minNonNulls is specified, how is ignored.
-#' @param minNonNulls If specified, drop rows that have less than
+#' @param minNonNulls if specified, drop rows that have less than
#' minNonNulls non-null values.
#' This overwrites the how parameter.
-#' @param cols Optional list of column names to consider.
-#' @return A SparkDataFrame
+#' @param cols optional list of column names to consider. In `fillna`,
+#' columns specified in cols that do not have matching data
+#' type are ignored. For example, if value is a character, and
+#' subset contains a non-character column, then the non-character
+#' column is simply ignored.
+#' @return A SparkDataFrame.
#'
#' @family SparkDataFrame functions
#' @rdname nafunctions
@@ -2759,6 +2793,8 @@ setMethod("dropna",
dataFrame(sdf)
})
+#' @param object a SparkDataFrame.
+#' @param ... further arguments to be passed to or from other methods.
#' @rdname nafunctions
#' @name na.omit
#' @aliases na.omit,SparkDataFrame-method
@@ -2772,18 +2808,12 @@ setMethod("na.omit",
#' fillna - Replace null values.
#'
-#' @param x A SparkDataFrame.
-#' @param value Value to replace null values with.
+#' @param value value to replace null values with.
#' Should be an integer, numeric, character or named list.
#' If the value is a named list, then cols is ignored and
#' value must be a mapping from column name (character) to
#' replacement value. The replacement value must be an
#' integer, numeric or character.
-#' @param cols optional list of column names to consider.
-#' Columns specified in cols that do not have matching data
-#' type are ignored. For example, if value is a character, and
-#' subset contains a non-character column, then the non-character
-#' column is simply ignored.
#'
#' @rdname nafunctions
#' @name fillna
@@ -2848,8 +2878,11 @@ setMethod("fillna",
#' Since data.frames are held in memory, ensure that you have enough memory
#' in your system to accommodate the contents.
#'
-#' @param x a SparkDataFrame
-#' @return a data.frame
+#' @param x a SparkDataFrame.
+#' @param row.names NULL or a character vector giving the row names for the data frame.
+#' @param optional If `TRUE`, converting column names is optional.
+#' @param ... additional arguments to pass to base::as.data.frame.
+#' @return A data.frame.
#' @family SparkDataFrame functions
#' @aliases as.data.frame,SparkDataFrame-method
#' @rdname as.data.frame
@@ -3003,9 +3036,10 @@ setMethod("str",
#' Returns a new SparkDataFrame with columns dropped.
#' This is a no-op if schema doesn't contain column name(s).
#'
-#' @param x A SparkDataFrame.
-#' @param cols A character vector of column names or a Column.
-#' @return A SparkDataFrame
+#' @param x a SparkDataFrame.
+#' @param col a character vector of column names or a Column.
+#' @param ... further arguments to be passed to or from other methods.
+#' @return A SparkDataFrame.
#'
#' @family SparkDataFrame functions
#' @rdname drop
@@ -3024,7 +3058,7 @@ setMethod("str",
#' @note drop since 2.0.0
setMethod("drop",
signature(x = "SparkDataFrame"),
- function(x, col) {
+ function(x, col, ...) {
stopifnot(class(col) == "character" || class(col) == "Column")
if (class(col) == "Column") {
@@ -3052,8 +3086,8 @@ setMethod("drop",
#'
#' @name histogram
#' @param nbins the number of bins (optional). Default value is 10.
+#' @param col the column as a character string or a Column to build the histogram from.
#' @param df the SparkDataFrame containing the Column to build the histogram from.
-#' @param colname the name of the column to build the histogram from.
#' @return a data.frame with the histogram statistics, i.e., counts and centroids.
#' @rdname histogram
#' @aliases histogram,SparkDataFrame,characterOrColumn-method
@@ -3184,10 +3218,11 @@ setMethod("histogram",
#' and to not change the existing data.
#' }
#'
-#' @param x A SparkDataFrame
-#' @param url JDBC database url of the form `jdbc:subprotocol:subname`
-#' @param tableName The name of the table in the external database
-#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default)
+#' @param x a SparkDataFrame.
+#' @param url JDBC database url of the form `jdbc:subprotocol:subname`.
+#' @param tableName the name of the table in the external database.
+#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default).
+#' @param ... additional JDBC database connection properties.
#' @family SparkDataFrame functions
#' @rdname write.jdbc
#' @name write.jdbc
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 0c06bba639d9b..a9cd2d85f898c 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -165,9 +165,9 @@ getDefaultSqlSource <- function() {
#'
#' Converts R data.frame or list into SparkDataFrame.
#'
-#' @param data An RDD or list or data.frame
-#' @param schema a list of column names or named list (StructType), optional
-#' @return a SparkDataFrame
+#' @param data an RDD or list or data.frame.
+#' @param schema a list of column names or named list (StructType), optional.
+#' @return A SparkDataFrame.
#' @rdname createDataFrame
#' @export
#' @examples
@@ -257,23 +257,25 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
}
createDataFrame <- function(x, ...) {
- dispatchFunc("createDataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...)
+ dispatchFunc("createDataFrame(data, schema = NULL)", x, ...)
}
+#' @param samplingRatio Currently not used.
#' @rdname createDataFrame
#' @aliases createDataFrame
#' @export
#' @method as.DataFrame default
#' @note as.DataFrame since 1.6.0
as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
- createDataFrame(data, schema, samplingRatio)
+ createDataFrame(data, schema)
}
+#' @param ... additional argument(s).
#' @rdname createDataFrame
#' @aliases as.DataFrame
#' @export
-as.DataFrame <- function(x, ...) {
- dispatchFunc("as.DataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...)
+as.DataFrame <- function(data, ...) {
+ dispatchFunc("as.DataFrame(data, schema = NULL)", data, ...)
}
#' toDF
@@ -398,7 +400,7 @@ read.orc <- function(path) {
#'
#' Loads a Parquet file, returning the result as a SparkDataFrame.
#'
-#' @param path Path of file to read. A vector of multiple paths is allowed.
+#' @param path path of file to read. A vector of multiple paths is allowed.
#' @return SparkDataFrame
#' @rdname read.parquet
#' @export
@@ -418,6 +420,7 @@ read.parquet <- function(x, ...) {
dispatchFunc("read.parquet(...)", x, ...)
}
+#' @param ... argument(s) passed to the method.
#' @rdname read.parquet
#' @name parquetFile
#' @export
@@ -727,6 +730,7 @@ dropTempView <- function(viewName) {
#' @param source The name of external data source
#' @param schema The data schema defined in structType
#' @param na.strings Default string value for NA when source is "csv"
+#' @param ... additional external data source specific named properties.
#' @return SparkDataFrame
#' @rdname read.df
#' @name read.df
@@ -791,10 +795,11 @@ loadDF <- function(x, ...) {
#' If `source` is not specified, the default data source configured by
#' "spark.sql.sources.default" will be used.
#'
-#' @param tableName A name of the table
-#' @param path The path of files to load
-#' @param source the name of external data source
-#' @return SparkDataFrame
+#' @param tableName a name of the table.
+#' @param path the path of files to load.
+#' @param source the name of external data source.
+#' @param ... additional argument(s) passed to the method.
+#' @return A SparkDataFrame.
#' @rdname createExternalTable
#' @export
#' @examples
@@ -840,6 +845,7 @@ createExternalTable <- function(x, ...) {
#' clause expressions used to split the column `partitionColumn` evenly.
#' This defaults to SparkContext.defaultParallelism when unset.
#' @param predicates a list of conditions in the where clause; each one defines one partition
+#' @param ... additional JDBC database connection named properties.
#' @return SparkDataFrame
#' @rdname read.jdbc
#' @name read.jdbc
diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R
index 751ba3fde954d..b55356b07d5e3 100644
--- a/R/pkg/R/WindowSpec.R
+++ b/R/pkg/R/WindowSpec.R
@@ -54,8 +54,10 @@ setMethod("show", "WindowSpec",
#'
#' Defines the partitioning columns in a WindowSpec.
#'
-#' @param x a WindowSpec
-#' @return a WindowSpec
+#' @param x a WindowSpec.
+#' @param col a column to partition on (described by the name or Column).
+#' @param ... additional column(s) to partition on.
+#' @return A WindowSpec.
#' @rdname partitionBy
#' @name partitionBy
#' @aliases partitionBy,WindowSpec-method
@@ -86,7 +88,7 @@ setMethod("partitionBy",
#'
#' Defines the ordering columns in a WindowSpec.
#' @param x a WindowSpec
-#' @param col a character or Column object indicating an ordering column
+#' @param col a character or Column indicating an ordering column
#' @param ... additional sorting fields
#' @return A WindowSpec.
#' @name orderBy
@@ -192,6 +194,9 @@ setMethod("rangeBetween",
#'
#' Define a windowing column.
#'
+#' @param x a Column, usually one returned by window function(s).
+#' @param window a WindowSpec object. Can be created by `windowPartitionBy` or
+#' `windowOrderBy` and configured by other WindowSpec methods.
#' @rdname over
#' @name over
#' @aliases over,Column,WindowSpec-method
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 0edb9d2ae5c45..af486e1ce212d 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -163,8 +163,9 @@ setMethod("alias",
#' @family colum_func
#' @aliases substr,Column-method
#'
-#' @param start starting position
-#' @param stop ending position
+#' @param x a Column.
+#' @param start starting position.
+#' @param stop ending position.
#' @note substr since 1.4.0
setMethod("substr", signature(x = "Column"),
function(x, start, stop) {
@@ -219,6 +220,7 @@ setMethod("endsWith", signature(x = "Column"),
#' @family colum_func
#' @aliases between,Column-method
#'
+#' @param x a Column
#' @param bounds lower and upper bounds
#' @note between since 1.5.0
setMethod("between", signature(x = "Column"),
@@ -233,6 +235,11 @@ setMethod("between", signature(x = "Column"),
#' Casts the column to a different data type.
#'
+#' @param x a Column.
+#' @param dataType a character object describing the target data type.
+#' See
+#' \href{https://spark.apache.org/docs/latest/sparkr.html#data-type-mapping-between-r-and-spark}{
+#' Spark Data Types} for available data types.
#' @rdname cast
#' @name cast
#' @family colum_func
@@ -254,10 +261,12 @@ setMethod("cast",
#' Match a column with given values.
#'
+#' @param x a Column.
+#' @param table a collection of values (coercible to list) to compare with.
#' @rdname match
#' @name %in%
#' @aliases %in%,Column-method
-#' @return a matched values as a result of comparing with given values.
+#' @return The matched values as a result of comparing with the given values.
#' @export
#' @examples
#' \dontrun{
@@ -277,6 +286,9 @@ setMethod("%in%",
#' If values in the specified column are null, returns the value.
#' Can be used in conjunction with `when` to specify a default value for expressions.
#'
+#' @param x a Column.
+#' @param value value to replace when the corresponding entry in \code{x} is NA.
+#' Can be a single value or a Column.
#' @rdname otherwise
#' @name otherwise
#' @family colum_func
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 573c915a5c67a..b3c10de71f3fe 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -23,6 +23,7 @@ NULL
#' A new \linkS4class{Column} is created to represent the literal value.
#' If the parameter is a \linkS4class{Column}, it is returned unchanged.
#'
+#' @param x a literal value or a Column.
#' @family normal_funcs
#' @rdname lit
#' @name lit
@@ -89,8 +90,6 @@ setMethod("acos",
#' Returns the approximate number of distinct items in a group. This is a column
#' aggregate function.
#'
-#' @param x Column to compute on.
-#'
#' @rdname approxCountDistinct
#' @name approxCountDistinct
#' @return the approximate number of distinct items in a group.
@@ -171,8 +170,6 @@ setMethod("atan",
#'
#' Aggregate function: returns the average of the values in a group.
#'
-#' @param x Column to compute on.
-#'
#' @rdname avg
#' @name avg
#' @family agg_funcs
@@ -319,7 +316,7 @@ setMethod("column",
#'
#' Computes the Pearson Correlation Coefficient for two Columns.
#'
-#' @param x Column to compute on.
+#' @param col2 a (second) Column.
#'
#' @rdname corr
#' @name corr
@@ -339,8 +336,6 @@ setMethod("corr", signature(x = "Column"),
#'
#' Compute the sample covariance between two expressions.
#'
-#' @param x Column to compute on.
-#'
#' @rdname cov
#' @name cov
#' @family math_funcs
@@ -362,8 +357,8 @@ setMethod("cov", signature(x = "characterOrColumn"),
#' @rdname cov
#'
-#' @param col1 First column to compute cov_samp.
-#' @param col2 Second column to compute cov_samp.
+#' @param col1 the first Column.
+#' @param col2 the second Column.
#' @name covar_samp
#' @aliases covar_samp,characterOrColumn,characterOrColumn-method
#' @note covar_samp since 2.0.0
@@ -451,9 +446,7 @@ setMethod("cosh",
#'
#' Returns the number of items in a group. This is a column aggregate function.
#'
-#' @param x Column to compute on.
-#'
-#' @rdname nrow
+#' @rdname count
#' @name count
#' @family agg_funcs
#' @aliases count,Column-method
@@ -493,6 +486,7 @@ setMethod("crc32",
#' Calculates the hash code of given columns, and returns the result as a int column.
#'
#' @param x Column to compute on.
+#' @param ... additional Column(s) to be included.
#'
#' @rdname hash
#' @name hash
@@ -663,7 +657,8 @@ setMethod("factorial",
#' The function by default returns the first values it sees. It will return the first non-missing
#' value it sees when na.rm is set to true. If all values are missing, then NA is returned.
#'
-#' @param x Column to compute on.
+#' @param na.rm a logical value indicating whether NA values should be stripped
+#' before the computation proceeds.
#'
#' @rdname first
#' @name first
@@ -832,7 +827,10 @@ setMethod("kurtosis",
#' The function by default returns the last values it sees. It will return the last non-missing
#' value it sees when na.rm is set to true. If all values are missing, then NA is returned.
#'
-#' @param x Column to compute on.
+#' @param x column to compute on.
+#' @param na.rm a logical value indicating whether NA values should be stripped
+#' before the computation proceeds.
+#' @param ... further arguments to be passed to or from other methods.
#'
#' @rdname last
#' @name last
@@ -1143,7 +1141,7 @@ setMethod("minute",
#' @export
#' @examples \dontrun{select(df, monotonically_increasing_id())}
setMethod("monotonically_increasing_id",
- signature(x = "missing"),
+ signature("missing"),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "monotonically_increasing_id")
column(jc)
@@ -1272,13 +1270,16 @@ setMethod("round",
#' bround
#'
-#' Returns the value of the column `e` rounded to `scale` decimal places using HALF_EVEN rounding
-#' mode if `scale` >= 0 or at integral part when `scale` < 0.
+#' Returns the value of the column \code{e} rounded to \code{scale} decimal places using HALF_EVEN rounding
+#' mode if \code{scale} >= 0 or at integer part when \code{scale} < 0.
#' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number.
#' bround(2.5, 0) = 2, bround(3.5, 0) = 4.
#'
#' @param x Column to compute on.
-#'
+#' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0,
+#' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left
+#' of the decimal point when \code{scale} < 0.
+#' @param ... further arguments to be passed to or from other methods.
#' @rdname bround
#' @name bround
#' @family math_funcs
@@ -1319,7 +1320,7 @@ setMethod("rtrim",
#' Aggregate function: alias for \link{stddev_samp}
#'
#' @param x Column to compute on.
-#'
+#' @param na.rm currently not used.
#' @rdname sd
#' @name sd
#' @family agg_funcs
@@ -1497,7 +1498,7 @@ setMethod("soundex",
#' \dontrun{select(df, spark_partition_id())}
#' @note spark_partition_id since 2.0.0
setMethod("spark_partition_id",
- signature(x = "missing"),
+ signature("missing"),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "spark_partition_id")
column(jc)
@@ -1560,7 +1561,8 @@ setMethod("stddev_samp",
#'
#' Creates a new struct column that composes multiple input columns.
#'
-#' @param x Column to compute on.
+#' @param x a column to compute on.
+#' @param ... optional column(s) to be included.
#'
#' @rdname struct
#' @name struct
@@ -1831,8 +1833,8 @@ setMethod("upper",
#'
#' Aggregate function: alias for \link{var_samp}.
#'
-#' @param x Column to compute on.
-#'
+#' @param x a Column to compute on.
+#' @param y,na.rm,use currently not used.
#' @rdname var
#' @name var
#' @family agg_funcs
@@ -2114,7 +2116,9 @@ setMethod("pmod", signature(y = "Column"),
#' @rdname approxCountDistinct
#' @name approxCountDistinct
#'
+#' @param x Column to compute on.
#' @param rsd maximum estimation error allowed (default = 0.05)
+#' @param ... further arguments to be passed to or from other methods.
#'
#' @aliases approxCountDistinct,Column-method
#' @export
@@ -2127,7 +2131,7 @@ setMethod("approxCountDistinct",
column(jc)
})
-#' Count Distinct
+#' Count Distinct Values
#'
#' @param x Column to compute on
#' @param ... other columns
@@ -2156,7 +2160,7 @@ setMethod("countDistinct",
#' concat
#'
#' Concatenates multiple input string columns together into a single string column.
-#'
+#'
#' @param x Column to compute on
#' @param ... other columns
#'
@@ -2246,7 +2250,6 @@ setMethod("ceiling",
})
#' @rdname sign
-#' @param x Column to compute on
#'
#' @name sign
#' @aliases sign,Column-method
@@ -2262,9 +2265,6 @@ setMethod("sign", signature(x = "Column"),
#'
#' Aggregate function: returns the number of distinct items in a group.
#'
-#' @param x Column to compute on
-#' @param ... other columns
-#'
#' @rdname countDistinct
#' @name n_distinct
#' @aliases n_distinct,Column-method
@@ -2276,9 +2276,7 @@ setMethod("n_distinct", signature(x = "Column"),
countDistinct(x, ...)
})
-#' @rdname nrow
-#' @param x Column to compute on
-#'
+#' @rdname count
#' @name n
#' @aliases n,Column-method
#' @export
@@ -2300,8 +2298,8 @@ setMethod("n", signature(x = "Column"),
#' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a
#' specialized implementation.
#'
-#' @param y Column to compute on
-#' @param x date format specification
+#' @param y Column to compute on.
+#' @param x date format specification.
#'
#' @family datetime_funcs
#' @rdname date_format
@@ -2320,8 +2318,8 @@ setMethod("date_format", signature(y = "Column", x = "character"),
#'
#' Assumes given timestamp is UTC and converts to given timezone.
#'
-#' @param y Column to compute on
-#' @param x time zone to use
+#' @param y Column to compute on.
+#' @param x time zone to use.
#'
#' @family datetime_funcs
#' @rdname from_utc_timestamp
@@ -2370,8 +2368,8 @@ setMethod("instr", signature(y = "Column", x = "character"),
#' Day of the week parameter is case insensitive, and accepts first three or two characters:
#' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun".
#'
-#' @param y Column to compute on
-#' @param x Day of the week string
+#' @param y Column to compute on.
+#' @param x Day of the week string.
#'
#' @family datetime_funcs
#' @rdname next_day
@@ -2637,6 +2635,7 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri
#' Parses the expression string into the column that it represents, similar to
#' SparkDataFrame.selectExpr
#'
+#' @param x an expression character object to be parsed.
#' @family normal_funcs
#' @rdname expr
#' @aliases expr,character-method
@@ -2654,6 +2653,9 @@ setMethod("expr", signature(x = "character"),
#'
#' Formats the arguments in printf-style and returns the result as a string column.
#'
+#' @param format a character object of format strings.
+#' @param x a Column.
+#' @param ... additional Column(s).
#' @family string_funcs
#' @rdname format_string
#' @name format_string
@@ -2676,6 +2678,11 @@ setMethod("format_string", signature(format = "character", x = "Column"),
#' representing the timestamp of that moment in the current system time zone in the given
#' format.
#'
+#' @param x a Column of unix timestamp.
+#' @param format the target format. See
+#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{
+#' Customizing Formats} for available options.
+#' @param ... further arguments to be passed to or from other methods.
#' @family datetime_funcs
#' @rdname from_unixtime
#' @name from_unixtime
@@ -2702,19 +2709,21 @@ setMethod("from_unixtime", signature(x = "Column"),
#' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
#' the order of months are not supported.
#'
-#' The time column must be of TimestampType.
-#'
-#' Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid
-#' interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.
-#' If the `slideDuration` is not provided, the windows will be tumbling windows.
-#'
-#' The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start
-#' window intervals. For example, in order to have hourly tumbling windows that start 15 minutes
-#' past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`.
-#'
-#' The output column will be a struct called 'window' by default with the nested columns 'start'
-#' and 'end'.
-#'
+#' @param x a time Column. Must be of TimestampType.
+#' @param windowDuration a string specifying the width of the window, e.g. '1 second',
+#' '1 day 12 hours', '2 minutes'. Valid interval strings are 'week',
+#' 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.
+#' @param slideDuration a string specifying the sliding interval of the window. Same format as
+#' \code{windowDuration}. A new window will be generated every
+#' \code{slideDuration}. Must be less than or equal to
+#' the \code{windowDuration}.
+#' @param startTime the offset with respect to 1970-01-01 00:00:00 UTC with which to start
+#' window intervals. For example, in order to have hourly tumbling windows
+#' that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide
+#' \code{startTime} as \code{"15 minutes"}.
+#' @param ... further arguments to be passed to or from other methods.
+#' @return An output column of struct called 'window' by default with the nested columns 'start'
+#' and 'end'.
#' @family datetime_funcs
#' @rdname window
#' @name window
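The same semantics can be sketched against the underlying Scala API (`org.apache.spark.sql.functions.window`, which the SparkR wrapper calls through callJStatic); the data and column names below are illustrative only:
```
// Assumes an active SparkSession named `spark`.
import spark.implicits._
import org.apache.spark.sql.functions.{count, window}

// Illustrative events with a TimestampType column named "time".
val events = Seq(
  ("2016-08-20 12:03:00", "a"),
  ("2016-08-20 12:07:00", "b"),
  ("2016-08-20 12:16:00", "c")
).toDF("time", "id").withColumn("time", $"time".cast("timestamp"))

// Hourly tumbling windows that start 15 minutes past the hour (12:15-13:15, 13:15-14:15, ...).
// Omitting slideDuration would likewise give tumbling windows of width windowDuration.
val counts = events
  .groupBy(window($"time", "1 hour", "1 hour", "15 minutes"))
  .agg(count("*").as("n"))

// The grouping column is a struct named "window" with nested "start" and "end" fields.
counts.select($"window.start", $"window.end", $"n").show(false)
```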
@@ -2766,6 +2775,10 @@ setMethod("window", signature(x = "Column"),
#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr
#' could not be found in str.
#'
+#' @param substr a character string to be matched.
+#' @param str a Column where matches are sought for each entry.
+#' @param pos start position of search.
+#' @param ... further arguments to be passed to or from other methods.
#' @family string_funcs
#' @rdname locate
#' @aliases locate,character,Column-method
@@ -2785,6 +2798,9 @@ setMethod("locate", signature(substr = "character", str = "Column"),
#'
#' Left-pad the string column with
#'
+#' @param x the string Column to be left-padded.
+#' @param len maximum length of each output result.
+#' @param pad a character string to be padded with.
#' @family string_funcs
#' @rdname lpad
#' @aliases lpad,Column,numeric,character-method
@@ -2804,6 +2820,7 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"),
#'
#' Generate a random column with i.i.d. samples from U[0.0, 1.0].
#'
+#' @param seed a random seed. Can be missing.
#' @family normal_funcs
#' @rdname rand
#' @name rand
@@ -2832,6 +2849,7 @@ setMethod("rand", signature(seed = "numeric"),
#'
#' Generate a column with i.i.d. samples from the standard normal distribution.
#'
+#' @param seed a random seed. Can be missing.
#' @family normal_funcs
#' @rdname randn
#' @name randn
@@ -2860,6 +2878,9 @@ setMethod("randn", signature(seed = "numeric"),
#'
#' Extract a specific(idx) group identified by a java regex, from the specified string column.
#'
+#' @param x a string Column.
+#' @param pattern a regular expression.
+#' @param idx a group index.
#' @family string_funcs
#' @rdname regexp_extract
#' @name regexp_extract
@@ -2880,6 +2901,9 @@ setMethod("regexp_extract",
#'
#' Replace all substrings of the specified string value that match regexp with rep.
#'
+#' @param x a string Column.
+#' @param pattern a regular expression.
+#' @param replacement a character string that a matched \code{pattern} is replaced with.
#' @family string_funcs
#' @rdname regexp_replace
#' @name regexp_replace
@@ -2900,6 +2924,9 @@ setMethod("regexp_replace",
#'
#' Right-padded with pad to a length of len.
#'
+#' @param x the string Column to be right-padded.
+#' @param len maximum length of each output result.
+#' @param pad a character string to be padded with.
#' @family string_funcs
#' @rdname rpad
#' @name rpad
@@ -2922,6 +2949,11 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"),
#' returned. If count is negative, every to the right of the final delimiter (counting from the
#' right) is returned. substring_index performs a case-sensitive match when searching for delim.
#'
+#' @param x a Column.
+#' @param delim a delimiter string.
+#' @param count number of occurrences of \code{delim} before the substring is returned.
+#' A positive number means counting from the left, while negative means
+#' counting from the right.
#' @family string_funcs
#' @rdname substring_index
#' @aliases substring_index,Column,character,numeric-method
@@ -2949,6 +2981,11 @@ setMethod("substring_index",
#' The translate will happen when any character in the string matching with the character
#' in the matchingString.
#'
+#' @param x a string Column.
+#' @param matchingString a source string where each character will be translated.
+#' @param replaceString a target string where each \code{matchingString} character will
+#' be replaced by the character in \code{replaceString}
+#' at the same location, if any.
#' @family string_funcs
#' @rdname translate
#' @name translate
@@ -2997,6 +3034,10 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"),
column(jc)
})
+#' @param x a Column of date, in string, date or timestamp type.
+#' @param format the target format. See
+#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{
+#' Customizing Formats} for available options.
#' @rdname unix_timestamp
#' @name unix_timestamp
#' @aliases unix_timestamp,Column,character-method
@@ -3012,6 +3053,8 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"),
#' Evaluates a list of conditions and returns one of multiple possible result expressions.
#' For unmatched expressions null is returned.
#'
+#' @param condition the condition to test on. Must be a Column expression.
+#' @param value result expression.
#' @family normal_funcs
#' @rdname when
#' @name when
@@ -3033,6 +3076,9 @@ setMethod("when", signature(condition = "Column", value = "ANY"),
#' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied.
#' Otherwise \code{no} is returned for unmatched conditions.
#'
+#' @param test a Column expression that describes the condition.
+#' @param yes return values for \code{TRUE} elements of test.
+#' @param no return values for \code{FALSE} elements of test.
#' @family normal_funcs
#' @rdname ifelse
#' @name ifelse
@@ -3074,10 +3120,14 @@ setMethod("ifelse",
#' @family window_funcs
#' @aliases cume_dist,missing-method
#' @export
-#' @examples \dontrun{cume_dist()}
+#' @examples \dontrun{
+#' df <- createDataFrame(iris)
+#' ws <- orderBy(windowPartitionBy("Species"), "Sepal_Length")
+#' out <- select(df, over(cume_dist(), ws), df$Sepal_Length, df$Species)
+#' }
#' @note cume_dist since 1.6.0
setMethod("cume_dist",
- signature(x = "missing"),
+ signature("missing"),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "cume_dist")
column(jc)
@@ -3101,7 +3151,7 @@ setMethod("cume_dist",
#' @examples \dontrun{dense_rank()}
#' @note dense_rank since 1.6.0
setMethod("dense_rank",
- signature(x = "missing"),
+ signature("missing"),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "dense_rank")
column(jc)
@@ -3115,6 +3165,11 @@ setMethod("dense_rank",
#'
#' This is equivalent to the LAG function in SQL.
#'
+#' @param x the column as a character string or a Column to compute on.
+#' @param offset the number of rows back from the current row from which to obtain a value.
+#' If not specified, the default is 1.
+#' @param defaultValue default to use when the offset row does not exist.
+#' @param ... further arguments to be passed to or from other methods.
#' @rdname lag
#' @name lag
#' @aliases lag,characterOrColumn-method
@@ -3143,7 +3198,7 @@ setMethod("lag",
#' an `offset` of one will return the next row at any given point in the window partition.
#'
#' This is equivalent to the LEAD function in SQL.
-#'
+#'
#' @param x Column to compute on
#' @param offset Number of rows to offset
#' @param defaultValue (Optional) default value to use
@@ -3211,7 +3266,7 @@ setMethod("ntile",
#' @examples \dontrun{percent_rank()}
#' @note percent_rank since 1.6.0
setMethod("percent_rank",
- signature(x = "missing"),
+ signature("missing"),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "percent_rank")
column(jc)
@@ -3243,6 +3298,8 @@ setMethod("rank",
})
# Expose rank() in the R base package
+#' @param x a numeric, complex, character or logical vector.
+#' @param ... additional argument(s) passed to the method.
#' @name rank
#' @rdname rank
#' @aliases rank,ANY-method
@@ -3267,7 +3324,7 @@ setMethod("rank",
#' @examples \dontrun{row_number()}
#' @note row_number since 1.6.0
setMethod("row_number",
- signature(x = "missing"),
+ signature("missing"),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "row_number")
column(jc)
@@ -3318,7 +3375,7 @@ setMethod("explode",
#' size
#'
#' Returns length of array or map.
-#'
+#'
#' @param x Column to compute on
#'
#' @rdname size
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 693aa31d3ecab..6610a25c8c05a 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -380,6 +380,9 @@ setGeneric("value", function(bcast) { standardGeneric("value") })
#################### SparkDataFrame Methods ########################
+#' @param x a SparkDataFrame or GroupedData.
+#' @param ... further arguments to be passed to or from other methods.
+#' @return A SparkDataFrame.
#' @rdname summarize
#' @export
setGeneric("agg", function (x, ...) { standardGeneric("agg") })
@@ -407,6 +410,8 @@ setGeneric("cache", function(x) { standardGeneric("cache") })
#' @export
setGeneric("collect", function(x, ...) { standardGeneric("collect") })
+#' @param do.NULL currently not used.
+#' @param prefix currently not used.
#' @rdname columns
#' @export
setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") })
@@ -427,15 +432,24 @@ setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") })
#' @export
setGeneric("columns", function(x) {standardGeneric("columns") })
-#' @rdname nrow
+#' @param x a GroupedData or Column.
+#' @rdname count
#' @export
setGeneric("count", function(x) { standardGeneric("count") })
#' @rdname cov
+#' @param x a Column object or a SparkDataFrame.
+#' @param ... additional argument(s). If `x` is a Column object, a Column object
+#' should be provided. If `x` is a SparkDataFrame, two column names should
+#' be provided.
#' @export
setGeneric("cov", function(x, ...) {standardGeneric("cov") })
#' @rdname corr
+#' @param x a Column object or a SparkDataFrame.
+#' @param ... additional argument(s). If `x` is a Column object, a Column object
+#' should be provided. If `x` is a SparkDataFrame, two column names should
+#' be provided.
#' @export
setGeneric("corr", function(x, ...) {standardGeneric("corr") })
@@ -462,10 +476,14 @@ setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })
#' @export
setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") })
+#' @param x a SparkDataFrame or GroupedData.
+#' @param ... additional argument(s) passed to the method.
#' @rdname gapply
#' @export
setGeneric("gapply", function(x, ...) { standardGeneric("gapply") })
+#' @param x a SparkDataFrame or GroupedData.
+#' @param ... additional argument(s) passed to the method.
#' @rdname gapplyCollect
#' @export
setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") })
@@ -667,8 +685,8 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr")
#' @export
setGeneric("showDF", function(x, ...) { standardGeneric("showDF") })
-# @rdname subset
-# @export
+#' @rdname subset
+#' @export
setGeneric("subset", function(x, ...) { standardGeneric("subset") })
#' @rdname summarize
@@ -735,6 +753,8 @@ setGeneric("between", function(x, bounds) { standardGeneric("between") })
setGeneric("cast", function(x, dataType) { standardGeneric("cast") })
#' @rdname columnfunctions
+#' @param x a Column object.
+#' @param ... additional argument(s).
#' @export
setGeneric("contains", function(x, ...) { standardGeneric("contains") })
@@ -830,6 +850,8 @@ setGeneric("array_contains", function(x, value) { standardGeneric("array_contain
#' @export
setGeneric("ascii", function(x) { standardGeneric("ascii") })
+#' @param x Column to compute on or a GroupedData object.
+#' @param ... additional argument(s) when `x` is a GroupedData object.
#' @rdname avg
#' @export
setGeneric("avg", function(x, ...) { standardGeneric("avg") })
@@ -886,9 +908,10 @@ setGeneric("crc32", function(x) { standardGeneric("crc32") })
#' @export
setGeneric("hash", function(x, ...) { standardGeneric("hash") })
+#' @param x empty. Should be used with no argument.
#' @rdname cume_dist
#' @export
-setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") })
+setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") })
#' @rdname datediff
#' @export
@@ -918,9 +941,10 @@ setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") })
#' @export
setGeneric("decode", function(x, charset) { standardGeneric("decode") })
+#' @param x empty. Should be used with no argument.
#' @rdname dense_rank
#' @export
-setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") })
+setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") })
#' @rdname encode
#' @export
@@ -1034,10 +1058,11 @@ setGeneric("md5", function(x) { standardGeneric("md5") })
#' @export
setGeneric("minute", function(x) { standardGeneric("minute") })
+#' @param x empty. Should be used with no argument.
#' @rdname monotonically_increasing_id
#' @export
setGeneric("monotonically_increasing_id",
- function(x) { standardGeneric("monotonically_increasing_id") })
+ function(x = "missing") { standardGeneric("monotonically_increasing_id") })
#' @rdname month
#' @export
@@ -1047,7 +1072,7 @@ setGeneric("month", function(x) { standardGeneric("month") })
#' @export
setGeneric("months_between", function(y, x) { standardGeneric("months_between") })
-#' @rdname nrow
+#' @rdname count
#' @export
setGeneric("n", function(x) { standardGeneric("n") })
@@ -1071,9 +1096,10 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") })
#' @export
setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
+#' @param x empty. Should be used with no argument.
#' @rdname percent_rank
#' @export
-setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") })
+setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") })
#' @rdname pmod
#' @export
@@ -1114,11 +1140,12 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") })
#' @rdname rint
#' @export
-setGeneric("rint", function(x, ...) { standardGeneric("rint") })
+setGeneric("rint", function(x) { standardGeneric("rint") })
+#' @param x empty. Should be used with no argument.
#' @rdname row_number
#' @export
-setGeneric("row_number", function(x) { standardGeneric("row_number") })
+setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") })
#' @rdname rpad
#' @export
@@ -1176,9 +1203,10 @@ setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array")
#' @export
setGeneric("soundex", function(x) { standardGeneric("soundex") })
+#' @param x empty. Should be used with no argument.
#' @rdname spark_partition_id
#' @export
-setGeneric("spark_partition_id", function(x) { standardGeneric("spark_partition_id") })
+setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") })
#' @rdname sd
#' @export
@@ -1276,10 +1304,16 @@ setGeneric("year", function(x) { standardGeneric("year") })
#' @export
setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })
+#' @param x,y For \code{glm}: logical values indicating whether the response vector
+#' and model matrix used in the fitting process should be returned as
+#' components of the returned value.
+#' @inheritParams stats::glm
#' @rdname glm
#' @export
setGeneric("glm")
+#' @param object a fitted ML model object.
+#' @param ... additional argument(s) passed to the method.
#' @rdname predict
#' @export
setGeneric("predict", function(object, ...) { standardGeneric("predict") })
@@ -1302,7 +1336,7 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s
#' @rdname spark.survreg
#' @export
-setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
+setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
#' @rdname spark.lda
#' @param ... Additional parameters to tune LDA.
@@ -1328,7 +1362,9 @@ setGeneric("spark.gaussianMixture",
standardGeneric("spark.gaussianMixture")
})
-#' write.ml
+#' @param object a fitted ML model object.
+#' @param path the directory where the model is saved.
+#' @param ... additional argument(s) passed to the method.
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 85348ae76baa7..3c85ada91a444 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -59,8 +59,7 @@ setMethod("show", "GroupedData",
#' Count the number of rows for each group.
#' The resulting SparkDataFrame will also contain the grouping columns.
#'
-#' @param x a GroupedData
-#' @return a SparkDataFrame
+#' @return A SparkDataFrame.
#' @rdname count
#' @aliases count,GroupedData-method
#' @export
@@ -83,8 +82,6 @@ setMethod("count",
#' df2 <- agg(df, <column> = <aggFunction(column)>)
#' df2 <- agg(df, newColName = aggFunction(column))
#'
-#' @param x a GroupedData
-#' @return a SparkDataFrame
#' @rdname summarize
#' @aliases agg,GroupedData-method
#' @name agg
@@ -201,7 +198,6 @@ createMethods()
#' gapply
#'
-#' @param x A GroupedData
#' @rdname gapply
#' @aliases gapply,GroupedData-method
#' @name gapply
@@ -216,7 +212,6 @@ setMethod("gapply",
#' gapplyCollect
#'
-#' @param x A GroupedData
#' @rdname gapplyCollect
#' @aliases gapplyCollect,GroupedData-method
#' @name gapplyCollect
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 36f38fc73a510..9a53c80aecded 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -113,17 +113,18 @@ NULL
#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#'
-#' @param data SparkDataFrame for training.
-#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param family A description of the error distribution and link function to be used in the model.
+#' @param family a description of the error distribution and link function to be used in the model.
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
-#' @param tol Positive convergence tolerance of iterations.
-#' @param maxIter Integer giving the maximal number of IRLS iterations.
-#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance
+#' @param weightCol the weight column name. If this is not set or NULL, we treat all instance
#' weights as 1.0.
+#' @param tol positive convergence tolerance of iterations.
+#' @param maxIter integer giving the maximal number of IRLS iterations.
+#' @param ... additional arguments passed to the method.
#' @aliases spark.glm,SparkDataFrame,formula-method
#' @return \code{spark.glm} returns a fitted generalized linear model
#' @rdname spark.glm
@@ -178,17 +179,17 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' Generalized Linear Models (R-compliant)
#'
#' Fits a generalized linear model, similarly to R's glm().
-#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param data SparkDataFrame for training.
-#' @param family A description of the error distribution and link function to be used in the model.
+#' @param data a SparkDataFrame or R's glm data for training.
+#' @param family a description of the error distribution and link function to be used in the model.
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
-#' @param epsilon Positive convergence tolerance of iterations.
-#' @param maxit Integer giving the maximal number of IRLS iterations.
-#' @param weightCol The weight column name. If this is not set or NULL, we treat all instance
+#' @param weightCol the weight column name. If this is not set or NULL, we treat all instance
#' weights as 1.0.
+#' @param epsilon positive convergence tolerance of iterations.
+#' @param maxit integer giving the maximal number of IRLS iterations.
#' @return \code{glm} returns a fitted generalized linear model.
#' @rdname glm
#' @export
@@ -209,7 +210,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat
# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary().
-#' @param object A fitted generalized linear model
+#' @param object a fitted generalized linear model.
#' @return \code{summary} returns a summary object of the fitted model, a list of components
#' including at least the coefficients, null/residual deviance, null/residual degrees
#' of freedom, AIC and number of iterations IRLS takes.
@@ -250,7 +251,7 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
# Prints the summary of GeneralizedLinearRegressionModel
#' @rdname spark.glm
-#' @param x Summary object of fitted generalized linear model returned by \code{summary} function
+#' @param x summary object of fitted generalized linear model returned by \code{summary} function
#' @export
#' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0
print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
@@ -282,7 +283,7 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
# Makes predictions from a generalized linear model produced by glm() or spark.glm(),
# similarly to R's predict().
-#' @param newData SparkDataFrame for testing
+#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
#' "prediction"
#' @rdname spark.glm
@@ -296,7 +297,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
# Makes predictions from a naive Bayes model or a model produced by spark.naiveBayes(),
# similarly to R package e1071's predict.
-#' @param newData A SparkDataFrame for testing
+#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named
#' "prediction"
#' @rdname spark.naiveBayes
@@ -309,9 +310,9 @@ setMethod("predict", signature(object = "NaiveBayesModel"),
# Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes}
-#' @param object A naive Bayes model fitted by \code{spark.naiveBayes}
+#' @param object a naive Bayes model fitted by \code{spark.naiveBayes}.
#' @return \code{summary} returns a list containing \code{apriori}, the label distribution, and
-#' \code{tables}, conditional probabilities given the target label
+#' \code{tables}, conditional probabilities given the target label.
#' @rdname spark.naiveBayes
#' @export
#' @note summary(NaiveBayesModel) since 2.0.0
@@ -491,7 +492,6 @@ setMethod("predict", signature(object = "IsotonicRegressionModel"),
# Get the summary of an IsotonicRegressionModel model
-#' @param object a fitted IsotonicRegressionModel
#' @param ... Other optional arguments to summary of an IsotonicRegressionModel
#' @return \code{summary} returns the model's boundaries and prediction as lists
#' @rdname spark.isoreg
@@ -512,14 +512,15 @@ setMethod("summary", signature(object = "IsotonicRegressionModel"),
#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#'
-#' @param data SparkDataFrame for training
-#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' Note that the response variable of formula is empty in spark.kmeans.
-#' @param k Number of centers
-#' @param maxIter Maximum iteration number
-#' @param initMode The initialization algorithm choosen to fit the model
-#' @return \code{spark.kmeans} returns a fitted k-means model
+#' @param k number of centers.
+#' @param maxIter maximum iteration number.
+#' @param initMode the initialization algorithm chosen to fit the model.
+#' @param ... additional argument(s) passed to the method.
+#' @return \code{spark.kmeans} returns a fitted k-means model.
#' @rdname spark.kmeans
#' @aliases spark.kmeans,SparkDataFrame,formula-method
#' @name spark.kmeans
@@ -560,8 +561,11 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"
#' Get fitted result from a k-means model, similarly to R's fitted().
#' Note: A saved-loaded model does not support this method.
#'
-#' @param object A fitted k-means model
-#' @return \code{fitted} returns a SparkDataFrame containing fitted values
+#' @param object a fitted k-means model.
+#' @param method type of fitted results, \code{"centers"} for cluster centers
+#' or \code{"classes"} for assigned classes.
+#' @param ... additional argument(s) passed to the method.
+#' @return \code{fitted} returns a SparkDataFrame containing fitted values.
#' @rdname fitted
#' @export
#' @examples
@@ -585,8 +589,8 @@ setMethod("fitted", signature(object = "KMeansModel"),
# Get the summary of a k-means model
-#' @param object A fitted k-means model
-#' @return \code{summary} returns the model's coefficients, size and cluster
+#' @param object a fitted k-means model.
+#' @return \code{summary} returns the model's coefficients, size and cluster.
#' @rdname spark.kmeans
#' @export
#' @note summary(KMeansModel) since 2.0.0
@@ -612,7 +616,8 @@ setMethod("summary", signature(object = "KMeansModel"),
# Predicted values based on a k-means model
-#' @return \code{predict} returns the predicted values based on a k-means model
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns the predicted values based on a k-means model.
#' @rdname spark.kmeans
#' @export
#' @note predict(KMeansModel) since 2.0.0
@@ -628,11 +633,12 @@ setMethod("predict", signature(object = "KMeansModel"),
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#' Only categorical data is supported.
#'
-#' @param data A \code{SparkDataFrame} of observations and labels for model fitting
-#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' @param data a \code{SparkDataFrame} of observations and labels for model fitting.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param smoothing Smoothing parameter
-#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model
+#' @param smoothing smoothing parameter.
+#' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}.
+#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model.
#' @rdname spark.naiveBayes
#' @aliases spark.naiveBayes,SparkDataFrame,formula-method
#' @name spark.naiveBayes
@@ -668,8 +674,8 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form
# Saves the Bernoulli naive Bayes model to the input path.
-#' @param path The directory where the model is saved
-#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' @param path the directory where the model is saved
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @rdname spark.naiveBayes
@@ -687,10 +693,9 @@ setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
# Saves the AFT survival regression model to the input path.
-#' @param path The directory where the model is saved
-#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
-#'
#' @rdname spark.survreg
#' @export
#' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0
@@ -706,8 +711,8 @@ setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "c
# Saves the generalized linear model to the input path.
-#' @param path The directory where the model is saved
-#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @rdname spark.glm
@@ -724,8 +729,8 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat
# Save fitted MLlib model to the input path
-#' @param path The directory where the model is saved
-#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @rdname spark.kmeans
@@ -780,8 +785,8 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact
#' Load a fitted MLlib model from the input path.
#'
-#' @param path Path of the model to read.
-#' @return a fitted MLlib model
+#' @param path path of the model to read.
+#' @return A fitted MLlib model.
#' @rdname read.ml
#' @name read.ml
#' @export
@@ -823,11 +828,11 @@ read.ml <- function(path) {
#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#'
-#' @param data A SparkDataFrame for training
-#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
-#' Note that operator '.' is not supported currently
-#' @return \code{spark.survreg} returns a fitted AFT survival regression model
+#' Note that operator '.' is not supported currently.
+#' @return \code{spark.survreg} returns a fitted AFT survival regression model.
#' @rdname spark.survreg
#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
#' @export
@@ -851,7 +856,7 @@ read.ml <- function(path) {
#' }
#' @note spark.survreg since 2.0.0
setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula, ...) {
+ function(data, formula) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
"fit", formula, data@sdf)
@@ -927,14 +932,14 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"),
# Returns a summary of the AFT survival regression model produced by spark.survreg,
# similarly to R's summary().
-#' @param object A fitted AFT survival regression model
+#' @param object a fitted AFT survival regression model.
#' @return \code{summary} returns a list containing the model's coefficients,
#' intercept and log(scale)
#' @rdname spark.survreg
#' @export
#' @note summary(AFTSurvivalRegressionModel) since 2.0.0
setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
- function(object, ...) {
+ function(object) {
jobj <- object@jobj
features <- callJMethod(jobj, "rFeatures")
coefficients <- callJMethod(jobj, "rCoefficients")
@@ -947,9 +952,9 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
# Makes predictions from an AFT survival regression model or a model produced by
# spark.survreg, similarly to R package survival's predict.
-#' @param newData A SparkDataFrame for testing
+#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted values
-#' on the original scale of the data (mean predicted value at scale = 1.0)
+#' on the original scale of the data (mean predicted value at scale = 1.0).
#' @rdname spark.survreg
#' @export
#' @note predict(AFTSurvivalRegressionModel) since 2.0.0
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index b429f5de13b87..cb5bdb90175bf 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -92,8 +92,9 @@ print.structType <- function(x, ...) {
#'
#' Create a structField object that contains the metadata for a single field in a schema.
#'
-#' @param x The name of the field
-#' @return a structField object
+#' @param x the name of the field.
+#' @param ... additional argument(s) passed to the method.
+#' @return A structField object.
#' @rdname structField
#' @export
#' @examples
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index f8bdee739ef02..85815af1f3639 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -320,14 +320,15 @@ sparkRHive.init <- function(jsc = NULL) {
#' For details on how to initialize and use SparkR, refer to SparkR programming guide at
#' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession}.
#'
-#' @param master The Spark master URL
-#' @param appName Application name to register with cluster manager
-#' @param sparkHome Spark Home directory
-#' @param sparkConfig Named list of Spark configuration to set on worker nodes
-#' @param sparkJars Character vector of jar files to pass to the worker nodes
-#' @param sparkPackages Character vector of packages from spark-packages.org
-#' @param enableHiveSupport Enable support for Hive, fallback if not built with Hive support; once
+#' @param master the Spark master URL.
+#' @param appName application name to register with cluster manager.
+#' @param sparkHome Spark Home directory.
+#' @param sparkConfig named list of Spark configuration to set on worker nodes.
+#' @param sparkJars character vector of jar files to pass to the worker nodes.
+#' @param sparkPackages character vector of packages from spark-packages.org
+#' @param enableHiveSupport enable support for Hive, fallback if not built with Hive support; once
#' set, this cannot be turned off on an existing session
+#' @param ... named Spark properties passed to the method.
#' @export
#' @examples
#'\dontrun{
@@ -413,9 +414,9 @@ sparkR.session <- function(
#' Assigns a group ID to all the jobs started by this thread until the group ID is set to a
#' different value or cleared.
#'
-#' @param groupid the ID to be assigned to job groups
-#' @param description description for the job group ID
-#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation
+#' @param groupId the ID to be assigned to job groups.
+#' @param description description for the job group ID.
+#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation.
#' @rdname setJobGroup
#' @name setJobGroup
#' @examples
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
index 2b4ce195cbddb..8ea24d81729ec 100644
--- a/R/pkg/R/stats.R
+++ b/R/pkg/R/stats.R
@@ -25,6 +25,7 @@ setOldClass("jobj")
#' table. The number of distinct values for each column should be less than 1e4. At most 1e6
#' non-zero pair frequencies will be returned.
#'
+#' @param x a SparkDataFrame
#' @param col1 name of the first column. Distinct items will make the first item of each row.
#' @param col2 name of the second column. Distinct items will make the column names of the output.
#' @return a local R data.frame representing the contingency table. The first column of each row
@@ -53,10 +54,9 @@ setMethod("crosstab",
#' Calculate the sample covariance of two numerical columns of a SparkDataFrame.
#'
-#' @param x A SparkDataFrame
-#' @param col1 the name of the first column
-#' @param col2 the name of the second column
-#' @return the covariance of the two columns.
+#' @param colName1 the name of the first column
+#' @param colName2 the name of the second column
+#' @return The covariance of the two columns.
#'
#' @rdname cov
#' @name cov
@@ -71,19 +71,18 @@ setMethod("crosstab",
#' @note cov since 1.6.0
setMethod("cov",
signature(x = "SparkDataFrame"),
- function(x, col1, col2) {
- stopifnot(class(col1) == "character" && class(col2) == "character")
+ function(x, colName1, colName2) {
+ stopifnot(class(colName1) == "character" && class(colName2) == "character")
statFunctions <- callJMethod(x@sdf, "stat")
- callJMethod(statFunctions, "cov", col1, col2)
+ callJMethod(statFunctions, "cov", colName1, colName2)
})
#' Calculates the correlation of two columns of a SparkDataFrame.
#' Currently only supports the Pearson Correlation Coefficient.
#' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics.
#'
-#' @param x A SparkDataFrame
-#' @param col1 the name of the first column
-#' @param col2 the name of the second column
+#' @param colName1 the name of the first column
+#' @param colName2 the name of the second column
#' @param method Optional. A character specifying the method for calculating the correlation.
#' only "pearson" is allowed now.
#' @return The Pearson Correlation Coefficient as a Double.
@@ -102,10 +101,10 @@ setMethod("cov",
#' @note corr since 1.6.0
setMethod("corr",
signature(x = "SparkDataFrame"),
- function(x, col1, col2, method = "pearson") {
- stopifnot(class(col1) == "character" && class(col2) == "character")
+ function(x, colName1, colName2, method = "pearson") {
+ stopifnot(class(colName1) == "character" && class(colName2) == "character")
statFunctions <- callJMethod(x@sdf, "stat")
- callJMethod(statFunctions, "corr", col1, col2, method)
+ callJMethod(statFunctions, "corr", colName1, colName2, method)
})
From 3e5fdeb3fb084cc9d25ce2f3f8cbf07a0aa2c573 Mon Sep 17 00:00:00 2001
From: "wm624@hotmail.com"
Date: Sat, 20 Aug 2016 07:00:51 -0700
Subject: [PATCH 038/270] [SPARKR][EXAMPLE] change example APP name
## What changes were proposed in this pull request?
In the R SQL example the appName is "MyApp", while the Scala, Java, and Python examples use the pattern "x Spark SQL basic example".
I made the R example consistent with the other examples.
## How was this patch tested?
Manual test
Author: wm624@hotmail.com
Closes #14703 from wangmiao1981/example.
---
examples/src/main/r/RSparkSQLExample.R | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R
index de489e1bda2c3..4e0267a03851b 100644
--- a/examples/src/main/r/RSparkSQLExample.R
+++ b/examples/src/main/r/RSparkSQLExample.R
@@ -18,7 +18,7 @@
library(SparkR)
# $example on:init_session$
-sparkR.session(appName = "MyApp", sparkConfig = list(spark.some.config.option = "some-value"))
+sparkR.session(appName = "R Spark SQL basic example", sparkConfig = list(spark.some.config.option = "some-value"))
# $example off:init_session$
From 31a015572024046f4deaa6cec66bb6fab110f31d Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sat, 20 Aug 2016 23:29:48 +0800
Subject: [PATCH 039/270] [SPARK-17104][SQL] LogicalRelation.newInstance should
follow the semantics of MultiInstanceRelation
## What changes were proposed in this pull request?
Currently `LogicalRelation.newInstance()` simply creates another `LogicalRelation` object with the same parameters. However, the `newInstance()` method inherited from `MultiInstanceRelation` should return a copy of the object with unique expression ids. The current `LogicalRelation.newInstance()` can therefore cause failures when doing a self-join.
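A minimal sketch of the failing shape (mirroring the regression test added below; the table name is taken from that test):
```
// Both sides of the join are backed by the same LogicalRelation; without fresh
// expression ids from newInstance(), their attribute ids collide and the
// analyzer cannot tell the two sides apart.
val table = spark.table("normal_parquet")
val selfJoin = table.as("t1").join(table.as("t2"))
selfJoin.show()
```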
## How was this patch tested?
Jenkins tests.
Author: Liang-Chi Hsieh
Closes #14682 from viirya/fix-localrelation.
---
.../sql/execution/datasources/LogicalRelation.scala | 11 +++++++++--
.../org/apache/spark/sql/hive/parquetSuites.scala | 7 +++++++
2 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
index 90711f2b1dde4..2a8e147011f55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
@@ -79,11 +79,18 @@ case class LogicalRelation(
/** Used to lookup original attribute capitalization */
val attributeMap: AttributeMap[AttributeReference] = AttributeMap(output.map(o => (o, o)))
- def newInstance(): this.type =
+ /**
+ * Returns a new instance of this LogicalRelation. According to the semantics of
+ * MultiInstanceRelation, this method returns a copy of this object with
+ * unique expression ids. We respect the `expectedOutputAttributes` and create
+ * new instances of attributes in it.
+ */
+ override def newInstance(): this.type = {
LogicalRelation(
relation,
- expectedOutputAttributes,
+ expectedOutputAttributes.map(_.map(_.newInstance())),
metastoreTableIdentifier).asInstanceOf[this.type]
+ }
override def refresh(): Unit = relation match {
case fs: HadoopFsRelation => fs.refresh()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 31b6197d56fc7..e92bbdea75a7b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -589,6 +589,13 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
}
}
}
+
+ test("self-join") {
+ val table = spark.table("normal_parquet")
+ val selfJoin = table.as("t1").join(table.as("t2"))
+ checkAnswer(selfJoin,
+ sql("SELECT * FROM normal_parquet x JOIN normal_parquet y"))
+ }
}
/**
From 9560c8d29542a5dcaaa07b7af9ef5ddcdbb5d14d Mon Sep 17 00:00:00 2001
From: petermaxlee
Date: Sun, 21 Aug 2016 00:25:55 +0800
Subject: [PATCH 040/270] [SPARK-17124][SQL] RelationalGroupedDataset.agg
should preserve order and allow multiple aggregates per column
## What changes were proposed in this pull request?
This patch fixes a longstanding issue with one of the RelationalGroupedDataset.agg overloads. Even though the signature accepts a vararg of pairs, the underlying implementation turned the sequence into a map, so it was neither order preserving nor able to keep multiple aggregates per column.
This change also lets users run multiple different aggregations on a single column, e.g.
```
agg("age" -> "max", "age" -> "count")
```
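Roughly, with the fix the following sketch (mirroring the new test) preserves both the ordering and the repeated column:
```
val df = spark.range(2)
val ret = df.groupBy("id").agg("id" -> "sum", "id" -> "count", "id" -> "min")
// The result schema now keeps the given order and all three aggregates of "id":
// id, sum(id), count(id), min(id).
ret.printSchema()
```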
## How was this patch tested?
Added a test case in DataFrameAggregateSuite.
Author: petermaxlee
Closes #14697 from petermaxlee/SPARK-17124.
---
.../apache/spark/sql/RelationalGroupedDataset.scala | 6 ++++--
.../org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++
2 files changed, 14 insertions(+), 2 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 7cfd1cdc7d5d1..53d732403f979 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -128,7 +128,7 @@ class RelationalGroupedDataset protected[sql](
}
/**
- * (Scala-specific) Compute aggregates by specifying a map from column name to
+ * (Scala-specific) Compute aggregates by specifying the column names and
* aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns.
*
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
@@ -143,7 +143,9 @@ class RelationalGroupedDataset protected[sql](
* @since 1.3.0
*/
def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
- agg((aggExpr +: aggExprs).toMap)
+ toDF((aggExpr +: aggExprs).map { case (colName, expr) =>
+ strToExpr(expr)(df(colName).expr)
+ })
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 92aa7b95434dc..69a3b5f278fd8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -87,6 +87,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
+ test("SPARK-17124 agg should be ordering preserving") {
+ val df = spark.range(2)
+ val ret = df.groupBy("id").agg("id" -> "sum", "id" -> "count", "id" -> "min")
+ assert(ret.schema.map(_.name) == Seq("id", "sum(id)", "count(id)", "min(id)"))
+ checkAnswer(
+ ret,
+ Row(0, 0, 1, 0) :: Row(1, 1, 1, 1) :: Nil
+ )
+ }
+
test("rollup") {
checkAnswer(
courseSales.rollup("course", "year").sum("earnings"),
From 9f37d4eac28dd179dd523fa7d645be97bb52af9c Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Sat, 20 Aug 2016 13:45:26 -0700
Subject: [PATCH 041/270] [SPARK-12666][CORE] SparkSubmit packages fix for when
'default' conf doesn't exist in dependent module
## What changes were proposed in this pull request?
Adding a "(runtime)" to the dependency configuration will set a fallback configuration to be used if the requested one is not found. E.g. with the setting "default(runtime)", Ivy will look for the conf "default" in the module ivy file and if not found will look for the conf "runtime". This can help with the case when using "sbt publishLocal" which does not write a "default" conf in the published ivy.xml file.
## How was this patch tested?
Used spark-submit with the --packages option for a package published locally with no "default" conf and for a package resolved from Maven Central.
Author: Bryan Cutler
Closes #13428 from BryanCutler/fallback-package-conf-SPARK-12666.
---
.../scala/org/apache/spark/deploy/SparkSubmit.scala | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 9feafc99ac07f..7b6d5a394bc35 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -897,9 +897,12 @@ private[spark] object SparkSubmitUtils {
val localIvyRoot = new File(ivySettings.getDefaultIvyUserDir, "local")
localIvy.setLocal(true)
localIvy.setRepository(new FileRepository(localIvyRoot))
- val ivyPattern = Seq("[organisation]", "[module]", "[revision]", "[type]s",
- "[artifact](-[classifier]).[ext]").mkString(File.separator)
- localIvy.addIvyPattern(localIvyRoot.getAbsolutePath + File.separator + ivyPattern)
+ val ivyPattern = Seq(localIvyRoot.getAbsolutePath, "[organisation]", "[module]", "[revision]",
+ "ivys", "ivy.xml").mkString(File.separator)
+ localIvy.addIvyPattern(ivyPattern)
+ val artifactPattern = Seq(localIvyRoot.getAbsolutePath, "[organisation]", "[module]",
+ "[revision]", "[type]s", "[artifact](-[classifier]).[ext]").mkString(File.separator)
+ localIvy.addArtifactPattern(artifactPattern)
localIvy.setName("local-ivy-cache")
cr.add(localIvy)
@@ -944,7 +947,7 @@ private[spark] object SparkSubmitUtils {
artifacts.foreach { mvn =>
val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version)
val dd = new DefaultDependencyDescriptor(ri, false, false)
- dd.addDependencyConfiguration(ivyConfName, ivyConfName)
+ dd.addDependencyConfiguration(ivyConfName, ivyConfName + "(runtime)")
// scalastyle:off println
printStream.println(s"${dd.getDependencyId} added as a dependency")
// scalastyle:on println
From 61ef74f2272faa7ce8f2badc7e00039908e3551f Mon Sep 17 00:00:00 2001
From: hqzizania
Date: Sat, 20 Aug 2016 18:52:44 -0700
Subject: [PATCH 042/270] [SPARK-17090][ML] Make tree aggregation level in
linear/logistic regression configurable
## What changes were proposed in this pull request?
Linear/logistic regression use treeAggregate with the default depth (always 2) to collect coefficient gradient updates on the driver. For high-dimensional problems this can cause an OOM error on the driver. This patch makes the depth configurable so the problem can be avoided when the input data has many features. It adds a HasAggregationDepth shared param in `sharedParams.scala` and extends it to both linear regression and logistic regression in .ml
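The knob being exposed is the depth argument of RDD.treeAggregate; a rough sketch of why it matters (the helper and gradient layout below are illustrative only):
```
import org.apache.spark.rdd.RDD

// Sum per-partition gradient contributions with a configurable aggregation tree depth.
// A depth greater than 2 adds intermediate reduce stages, so no single task (and
// ultimately not the driver) has to merge too many large dense vectors at once.
def sumGradients(grads: RDD[Array[Double]], dim: Int, depth: Int): Array[Double] = {
  grads.treeAggregate(new Array[Double](dim))(
    (acc, g) => { var i = 0; while (i < dim) { acc(i) += g(i); i += 1 }; acc },
    (a, b) => { var i = 0; while (i < dim) { a(i) += b(i); i += 1 }; a },
    depth)
}
```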
Author: hqzizania
Closes #14717 from hqzizania/SPARK-17090.
---
.../classification/LogisticRegression.scala | 24 +++++++++++++-----
.../MultinomialLogisticRegression.scala | 16 ++++++++++--
.../ml/param/shared/SharedParamsCodeGen.scala | 4 ++-
.../spark/ml/param/shared/sharedParams.scala | 25 ++++++++++++++++---
.../ml/regression/LinearRegression.scala | 22 +++++++++++++---
5 files changed, 74 insertions(+), 17 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index ea31c68e4c943..757d52052d87f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -48,7 +48,7 @@ import org.apache.spark.storage.StorageLevel
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
- with HasStandardization with HasWeightCol with HasThreshold {
+ with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth {
/**
* Set threshold in binary classification, in range [0, 1].
@@ -256,6 +256,17 @@ class LogisticRegression @Since("1.2.0") (
@Since("1.5.0")
override def getThresholds: Array[Double] = super.getThresholds
+ /**
+ * Suggested depth for treeAggregate (>= 2).
+ * If the dimensions of features or the number of partitions are large,
+ * this param could be adjusted to a larger size.
+ * Default is 2.
+ * @group expertSetParam
+ */
+ @Since("2.1.0")
+ def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+ setDefault(aggregationDepth -> 2)
+
private var optInitialModel: Option[LogisticRegressionModel] = None
/** @group setParam */
@@ -294,7 +305,8 @@ class LogisticRegression @Since("1.2.0") (
(c1._1.merge(c2._1), c1._2.merge(c2._2))
instances.treeAggregate(
- new MultivariateOnlineSummarizer, new MultiClassSummarizer)(seqOp, combOp)
+ new MultivariateOnlineSummarizer, new MultiClassSummarizer
+ )(seqOp, combOp, $(aggregationDepth))
}
val histogram = labelSummarizer.histogram
@@ -358,7 +370,7 @@ class LogisticRegression @Since("1.2.0") (
val bcFeaturesStd = instances.context.broadcast(featuresStd)
val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
- $(standardization), bcFeaturesStd, regParamL2, multinomial = false)
+ $(standardization), bcFeaturesStd, regParamL2, multinomial = false, $(aggregationDepth))
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -1331,8 +1343,8 @@ private class LogisticCostFun(
standardization: Boolean,
bcFeaturesStd: Broadcast[Array[Double]],
regParamL2: Double,
- multinomial: Boolean) extends DiffFunction[BDV[Double]] {
-
+ multinomial: Boolean,
+ aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val coeffs = Vectors.fromBreeze(coefficients)
@@ -1347,7 +1359,7 @@ private class LogisticCostFun(
instances.treeAggregate(
new LogisticAggregator(bcCoeffs, bcFeaturesStd, numClasses, fitIntercept,
multinomial)
- )(seqOp, combOp)
+ )(seqOp, combOp, aggregationDepth)
}
val totalGradientArray = logisticAggregator.gradient.toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
index dfadd68c5f476..f85ac76a8d129 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
@@ -44,7 +44,8 @@ import org.apache.spark.storage.StorageLevel
*/
private[classification] trait MultinomialLogisticRegressionParams
extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter
- with HasFitIntercept with HasTol with HasStandardization with HasWeightCol {
+ with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
+ with HasAggregationDepth {
/**
* Set thresholds in multiclass (or binary) classification to adjust the probability of
@@ -163,6 +164,17 @@ class MultinomialLogisticRegression @Since("2.1.0") (
@Since("2.1.0")
override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
+ /**
+ * Suggested depth for treeAggregate (>= 2).
+ * If the dimensions of features or the number of partitions are large,
+ * this param could be adjusted to a larger size.
+ * Default is 2.
+ * @group expertSetParam
+ */
+ @Since("2.1.0")
+ def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+ setDefault(aggregationDepth -> 2)
+
override protected[spark] def train(dataset: Dataset[_]): MultinomialLogisticRegressionModel = {
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
@@ -245,7 +257,7 @@ class MultinomialLogisticRegression @Since("2.1.0") (
val bcFeaturesStd = instances.context.broadcast(featuresStd)
val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
- $(standardization), bcFeaturesStd, regParamL2, multinomial = true)
+ $(standardization), bcFeaturesStd, regParamL2, multinomial = true, $(aggregationDepth))
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 4ab0c16a1b4d0..0f48a16a429ff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -78,7 +78,9 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
"all instance weights as 1.0"),
ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " +
- "empty, default value is 'auto'", Some("\"auto\"")))
+ "empty, default value is 'auto'", Some("\"auto\"")),
+ ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
+ isValid = "ParamValidators.gtEq(2)"))
val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 64d6af2766ca9..6803772c63d62 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -334,10 +334,10 @@ private[ml] trait HasElasticNetParam extends Params {
private[ml] trait HasTol extends Params {
/**
- * Param for the convergence tolerance for iterative algorithms.
+ * Param for the convergence tolerance for iterative algorithms (>= 0).
* @group param
*/
- final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms")
+ final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms (>= 0)", ParamValidators.gtEq(0))
/** @group getParam */
final def getTol: Double = $(tol)
@@ -349,10 +349,10 @@ private[ml] trait HasTol extends Params {
private[ml] trait HasStepSize extends Params {
/**
- * Param for Step size to be used for each iteration of optimization.
+ * Param for Step size to be used for each iteration of optimization (> 0).
* @group param
*/
- final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization")
+ final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization (> 0)", ParamValidators.gt(0))
/** @group getParam */
final def getStepSize: Double = $(stepSize)
@@ -389,4 +389,21 @@ private[ml] trait HasSolver extends Params {
/** @group getParam */
final def getSolver: String = $(solver)
}
+
+/**
+ * Trait for shared param aggregationDepth (default: 2).
+ */
+private[ml] trait HasAggregationDepth extends Params {
+
+ /**
+ * Param for suggested depth for treeAggregate (>= 2).
+ * @group param
+ */
+ final val aggregationDepth: IntParam = new IntParam(this, "aggregationDepth", "suggested depth for treeAggregate (>= 2)", ParamValidators.gtEq(2))
+
+ setDefault(aggregationDepth, 2)
+
+ /** @group getParam */
+ final def getAggregationDepth: Int = $(aggregationDepth)
+}
// scalastyle:on
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 76be4204e9050..b1bb9b9fe0058 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -53,6 +53,7 @@ import org.apache.spark.storage.StorageLevel
private[regression] trait LinearRegressionParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
+ with HasAggregationDepth
/**
* Linear regression.
@@ -172,6 +173,17 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "auto")
+ /**
+ * Suggested depth for treeAggregate (>= 2).
+ * If the dimensions of features or the number of partitions are large,
+ * this param could be adjusted to a larger size.
+ * Default is 2.
+ * @group expertSetParam
+ */
+ @Since("2.1.0")
+ def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+ setDefault(aggregationDepth -> 2)
+
override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
// Extract the number of features before deciding optimization solver.
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
@@ -230,7 +242,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
(c1._1.merge(c2._1), c1._2.merge(c2._2))
instances.treeAggregate(
- new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp)
+ new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer
+ )(seqOp, combOp, $(aggregationDepth))
}
val yMean = ySummarizer.mean(0)
@@ -296,7 +309,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
- $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam)
+ $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam, $(aggregationDepth))
val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -1016,7 +1029,8 @@ private class LeastSquaresCostFun(
standardization: Boolean,
bcFeaturesStd: Broadcast[Array[Double]],
bcFeaturesMean: Broadcast[Array[Double]],
- effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
+ effectiveL2regParam: Double,
+ aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val coeffs = Vectors.fromBreeze(coefficients)
@@ -1029,7 +1043,7 @@ private class LeastSquaresCostFun(
instances.treeAggregate(
new LeastSquaresAggregator(bcCoeffs, labelStd, labelMean, fitIntercept, bcFeaturesStd,
- bcFeaturesMean))(seqOp, combOp)
+ bcFeaturesMean))(seqOp, combOp, aggregationDepth)
}
val totalGradientArray = leastSquaresAggregator.gradient.toArray
From 7f08a60b6e9acb89482fa0e268b192250d9ba6e4 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Sun, 21 Aug 2016 02:23:31 -0700
Subject: [PATCH 043/270] [SPARK-16961][FOLLOW-UP][SPARKR] More robust test
case for spark.gaussianMixture.
## What changes were proposed in this pull request?
#14551 fixed an off-by-one bug in `randomizeInPlace` and some test failures caused by that fix.
For the SparkR `spark.gaussianMixture` test case, however, the fix was inappropriate: it only changed the expected output from native R that SparkR is compared against, but it did not update the R code in the comment that is used to reproduce the result in native R. This would confuse users who cannot reproduce the same result in native R. This PR provides a more robust test case that produces the same result in SparkR and in native R.
## How was this patch tested?
Unit test update.
Author: Yanbo Liang
Closes #14730 from yanboliang/spark-16961-followup.
---
R/pkg/inst/tests/testthat/test_mllib.R | 47 ++++++++++++++------------
1 file changed, 25 insertions(+), 22 deletions(-)
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 67a3099101cf1..d15c2393b94ac 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -512,49 +512,52 @@ test_that("spark.gaussianMixture", {
# R code to reproduce the result.
# nolint start
#' library(mvtnorm)
- #' set.seed(100)
- #' a <- rmvnorm(4, c(0, 0))
- #' b <- rmvnorm(6, c(3, 4))
+ #' set.seed(1)
+ #' a <- rmvnorm(7, c(0, 0))
+ #' b <- rmvnorm(8, c(10, 10))
#' data <- rbind(a, b)
#' model <- mvnormalmixEM(data, k = 2)
#' model$lambda
#
- # [1] 0.4 0.6
+ # [1] 0.4666667 0.5333333
#
#' model$mu
#
- # [1] -0.2614822 0.5128697
- # [1] 2.647284 4.544682
+ # [1] 0.11731091 -0.06192351
+ # [1] 10.363673 9.897081
#
#' model$sigma
#
# [[1]]
- # [,1] [,2]
- # [1,] 0.08427399 0.00548772
- # [2,] 0.00548772 0.09090715
+ # [,1] [,2]
+ # [1,] 0.62049934 0.06880802
+ # [2,] 0.06880802 1.27431874
#
# [[2]]
- # [,1] [,2]
- # [1,] 0.1641373 -0.1673806
- # [2,] -0.1673806 0.7508951
+ # [,1] [,2]
+ # [1,] 0.2961543 0.160783
+ # [2,] 0.1607830 1.008878
# nolint end
- data <- list(list(-0.50219235, 0.1315312), list(-0.07891709, 0.8867848),
- list(0.11697127, 0.3186301), list(-0.58179068, 0.7145327),
- list(2.17474057, 3.6401379), list(3.08988614, 4.0962745),
- list(2.79836605, 4.7398405), list(3.12337950, 3.9706833),
- list(2.61114575, 4.5108563), list(2.08618581, 6.3102968))
+ data <- list(list(-0.6264538, 0.1836433), list(-0.8356286, 1.5952808),
+ list(0.3295078, -0.8204684), list(0.4874291, 0.7383247),
+ list(0.5757814, -0.3053884), list(1.5117812, 0.3898432),
+ list(-0.6212406, -2.2146999), list(11.1249309, 9.9550664),
+ list(9.9838097, 10.9438362), list(10.8212212, 10.5939013),
+ list(10.9189774, 10.7821363), list(10.0745650, 8.0106483),
+ list(10.6198257, 9.9438713), list(9.8442045, 8.5292476),
+ list(9.5218499, 10.4179416))
df <- createDataFrame(data, c("x1", "x2"))
model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2)
stats <- summary(model)
- rLambda <- c(0.50861, 0.49139)
- rMu <- c(0.267, 1.195, 2.743, 4.730)
- rSigma <- c(1.099, 1.339, 1.339, 1.798,
- 0.145, -0.309, -0.309, 0.716)
+ rLambda <- c(0.4666667, 0.5333333)
+ rMu <- c(0.11731091, -0.06192351, 10.363673, 9.897081)
+ rSigma <- c(0.62049934, 0.06880802, 0.06880802, 1.27431874,
+ 0.2961543, 0.160783, 0.1607830, 1.008878)
expect_equal(stats$lambda, rLambda, tolerance = 1e-3)
expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
p <- collect(select(predict(model, df), "prediction"))
- expect_equal(p$prediction, c(0, 0, 0, 0, 0, 1, 1, 1, 1, 1))
+ expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1))
# Test model save/load
modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp")
From e328f577e81363f6b3f892931f20dbf68f7d29cf Mon Sep 17 00:00:00 2001
From: "wm624@hotmail.com"
Date: Sun, 21 Aug 2016 11:51:46 +0100
Subject: [PATCH 044/270] [SPARK-17002][CORE] Document that spark.ssl.protocol
is required for SSL
## What changes were proposed in this pull request?
Setting `spark.ssl.enabled=true` but failing to set `spark.ssl.protocol` fails with a meaningless exception; `spark.ssl.protocol` is required whenever `spark.ssl.enabled` is true.
Improvement: require `spark.ssl.protocol` when initializing the SSLContext, and otherwise throw an exception that says so.
Remove the `getOrElse("Default")` fallback.
Document this requirement in configuration.md.
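A minimal sketch of the required pairing (the protocol value is only an example; the property names come from the docs change below):
```scala
import org.apache.spark.SparkConf

// Enabling SSL now fails fast unless a protocol is also configured,
// so the two properties must be set together.
val conf = new SparkConf()
  .set("spark.ssl.enabled", "true")
  .set("spark.ssl.protocol", "TLSv1.2") // example protocol value
```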
## How was this patch tested?
Manual tests:
Build the documentation and check it.
Configure `spark.ssl.enabled` only; it throws the exception below:
6/08/16 16:04:37 INFO SecurityManager: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(mwang); groups with view permissions: Set(); users with modify permissions: Set(mwang); groups with modify permissions: Set()
Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: spark.ssl.protocol is required when enabling SSL connections.
at scala.Predef$.require(Predef.scala:224)
at org.apache.spark.SecurityManager.(SecurityManager.scala:285)
at org.apache.spark.deploy.master.Master$.startRpcEnvAndEndpoint(Master.scala:1026)
at org.apache.spark.deploy.master.Master$.main(Master.scala:1011)
at org.apache.spark.deploy.master.Master.main(Master.scala)
Configure both `spark.ssl.enabled` and `spark.ssl.protocol`; it works fine.
Author: wm624@hotmail.com
Closes #14674 from wangmiao1981/ssl.
---
core/src/main/scala/org/apache/spark/SecurityManager.scala | 5 ++++-
docs/configuration.md | 3 +++
2 files changed, 7 insertions(+), 1 deletion(-)
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index f72c7ded5ea52..a6550b6ca8c94 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -282,7 +282,10 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
}: TrustManager
})
- val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.getOrElse("Default"))
+ require(fileServerSSLOptions.protocol.isDefined,
+ "spark.ssl.protocol is required when enabling SSL connections.")
+
+ val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.get)
sslContext.init(null, trustStoreManagers.getOrElse(credulousTrustStoreManagers), null)
val hostVerifier = new HostnameVerifier {
diff --git a/docs/configuration.md b/docs/configuration.md
index 96e8c6d08a1e3..4bda464b98bf6 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1472,6 +1472,9 @@ Apart from these, the following properties are also available, and may be useful
Whether to enable SSL connections on all supported protocols.
+ When spark.ssl.enabled is configured, spark.ssl.protocol
+ is required.
+
All the SSL settings like spark.ssl.xxx where xxx is a
particular configuration property, denote the global configuration for all the supported
protocols. In order to override the global configuration for the particular protocol,
From ab7143463daf2056736c85e3a943c826b5992623 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Sun, 21 Aug 2016 10:31:25 -0700
Subject: [PATCH 045/270] [MINOR][R] add SparkR.Rcheck/ and SparkR_*.tar.gz to
R/.gitignore
## What changes were proposed in this pull request?
Ignore temp files generated by `check-cran.sh`.
Author: Xiangrui Meng
Closes #14740 from mengxr/R-gitignore.
---
R/.gitignore | 2 ++
1 file changed, 2 insertions(+)
diff --git a/R/.gitignore b/R/.gitignore
index 9a5889ba28b2a..c98504ab07781 100644
--- a/R/.gitignore
+++ b/R/.gitignore
@@ -4,3 +4,5 @@
lib
pkg/man
pkg/html
+SparkR.Rcheck/
+SparkR_*.tar.gz
From 91c2397684ab791572ac57ffb2a924ff058bb64f Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Sun, 21 Aug 2016 22:07:47 +0200
Subject: [PATCH 046/270] [SPARK-17098][SQL] Fix `NullPropagation` optimizer to
handle `COUNT(NULL) OVER` correctly
## What changes were proposed in this pull request?
Currently, the `NullPropagation` optimizer replaces `COUNT` on null literals in a bottom-up fashion, but `WindowExpression` is not handled properly during that traversal. This PR adds the missing propagation logic.
**Before**
```scala
scala> sql("SELECT COUNT(1 + NULL) OVER ()").show
java.lang.UnsupportedOperationException: Cannot evaluate expression: cast(0 as bigint) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
```
**After**
```scala
scala> sql("SELECT COUNT(1 + NULL) OVER ()").show
+----------------------------------------------------------------------------------------------+
|count((1 + CAST(NULL AS INT))) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)|
+----------------------------------------------------------------------------------------------+
| 0|
+----------------------------------------------------------------------------------------------+
```
## How was this patch tested?
Pass the Jenkins test with a new test case.
Author: Dongjoon Hyun
Closes #14689 from dongjoon-hyun/SPARK-17098.
---
.../sql/catalyst/optimizer/Optimizer.scala | 2 +
.../sql-tests/inputs/null-propagation.sql | 9 +++++
.../results/null-propagation.sql.out | 38 +++++++++++++++++++
3 files changed, 49 insertions(+)
create mode 100644 sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql
create mode 100644 sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ce57f05868fe1..9a0ff8a9b3211 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -619,6 +619,8 @@ object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
+ case e @ WindowExpression(Cast(Literal(0L, _), _), _) =>
+ Cast(Literal(0L), e.dataType)
case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
diff --git a/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql b/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql
new file mode 100644
index 0000000000000..66549da7971d3
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql
@@ -0,0 +1,9 @@
+
+-- count(null) should be 0
+SELECT COUNT(NULL) FROM VALUES 1, 2, 3;
+SELECT COUNT(1 + NULL) FROM VALUES 1, 2, 3;
+
+-- count(null) on window should be 0
+SELECT COUNT(NULL) OVER () FROM VALUES 1, 2, 3;
+SELECT COUNT(1 + NULL) OVER () FROM VALUES 1, 2, 3;
+
diff --git a/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out b/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out
new file mode 100644
index 0000000000000..ed3a651aa6614
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out
@@ -0,0 +1,38 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 4
+
+
+-- !query 0
+SELECT COUNT(NULL) FROM VALUES 1, 2, 3
+-- !query 0 schema
+struct
+-- !query 0 output
+0
+
+
+-- !query 1
+SELECT COUNT(1 + NULL) FROM VALUES 1, 2, 3
+-- !query 1 schema
+struct
+-- !query 1 output
+0
+
+
+-- !query 2
+SELECT COUNT(NULL) OVER () FROM VALUES 1, 2, 3
+-- !query 2 schema
+struct
+-- !query 2 output
+0
+0
+0
+
+
+-- !query 3
+SELECT COUNT(1 + NULL) OVER () FROM VALUES 1, 2, 3
+-- !query 3 schema
+struct
+-- !query 3 output
+0
+0
+0
From b2074b664a9c269c4103760d40c4a14e7aeb1e83 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sun, 21 Aug 2016 22:23:14 -0700
Subject: [PATCH 047/270] [SPARK-16498][SQL] move hive hack for data source
table into HiveExternalCatalog
## What changes were proposed in this pull request?
Spark SQL doesn't have its own metastore yet and currently uses Hive's. However, Hive's metastore has some limitations (e.g. the number of columns can't be too large, it is not case-preserving, it has poor decimal type support, etc.), so we have some hacks to successfully store data source table metadata in the Hive metastore, i.e. we put all the information in table properties.
This PR moves these hacks into `HiveExternalCatalog` and tries to isolate the Hive-specific logic in one place.
Changes overview:
1. **before this PR**: we need to put the metadata (schema, partition columns, etc.) of data source tables into table properties before saving it to the external catalog, even when the external catalog doesn't use the Hive metastore (e.g. `InMemoryCatalog`).
**after this PR**: the table properties tricks live only in `HiveExternalCatalog`; the caller side no longer needs to take care of them.
2. **before this PR**: because the table properties tricks are done outside of the external catalog, we also need to undo them when we read the table metadata back from the external catalog and use it, e.g. `DescribeTableCommand` reads the schema and partition columns from table properties.
**after this PR**: the table metadata read from the external catalog is exactly what we saved to it.
Bonus: we can now create a data source table using `SessionCatalog` if a schema is specified; see the sketch after this overview.
Breaking changes: `schemaStringLengthThreshold` and `hive.default.rcfile.serde` are no longer configurable.
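A condensed sketch of the new code path described above (table name, path, and schema are illustrative; the construction mirrors the `CreateDataSourceTableCommand` change in the diff below and assumes the code lives inside Spark's own sql module, since `sessionState` is not public API):
```scala
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Build a plain CatalogTable and let the external catalog (HiveExternalCatalog
// or InMemoryCatalog) handle any metastore-specific encoding of the schema,
// partition columns, and bucketing information.
object CreateDataSourceTableSketch {
  def createExampleTable(sparkSession: SparkSession): Unit = {
    val table = CatalogTable(
      identifier = TableIdentifier("people"),               // illustrative table name
      tableType = CatalogTableType.MANAGED,
      storage = CatalogStorageFormat.empty.copy(
        properties = Map("path" -> "/tmp/people")),         // illustrative path option
      schema = new StructType()
        .add("id", IntegerType)
        .add("name", StringType),
      provider = Some("parquet"),
      partitionColumnNames = Nil,
      bucketSpec = None)

    sparkSession.sessionState.catalog.createTable(table, ignoreIfExists = false)
  }
}
```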
## How was this patch tested?
existing tests.
Author: Wenchen Fan
Closes #14155 from cloud-fan/catalog-table.
---
.../ml/source/libsvm/LibSVMRelation.scala | 3 +-
.../spark/sql/execution/SparkSqlParser.scala | 4 +-
.../command/createDataSourceTables.scala | 255 ++------------
.../spark/sql/execution/command/ddl.scala | 94 +----
.../spark/sql/execution/command/tables.scala | 59 +---
.../datasources/DataSourceStrategy.scala | 22 +-
.../datasources/WriterContainer.scala | 16 +-
.../datasources/csv/CSVRelation.scala | 5 +-
.../datasources/json/JsonFileFormat.scala | 3 +-
.../parquet/ParquetFileFormat.scala | 4 +-
.../datasources/text/TextFileFormat.scala | 3 +-
.../apache/spark/sql/internal/HiveSerDe.scala | 6 +-
.../execution/command/DDLCommandSuite.scala | 6 +-
.../sql/execution/command/DDLSuite.scala | 110 +-----
.../sources/CreateTableAsSelectSuite.scala | 5 +-
.../spark/sql/hive/HiveExternalCatalog.scala | 328 +++++++++++++++++-
.../spark/sql/hive/HiveMetastoreCatalog.scala | 67 +---
.../sql/hive/client/HiveClientImpl.scala | 16 +-
.../spark/sql/hive/orc/OrcFileFormat.scala | 3 +-
.../sql/hive/MetastoreDataSourcesSuite.scala | 110 +++---
.../sql/hive/execution/HiveCommandSuite.scala | 40 ++-
.../sql/hive/execution/HiveDDLSuite.scala | 23 ++
.../sql/hive/execution/SQLQuerySuite.scala | 4 +-
.../sql/sources/SimpleTextRelation.scala | 3 +-
24 files changed, 536 insertions(+), 653 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 034223e115389..5c79c6905801c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -51,7 +50,7 @@ private[libsvm] class LibSVMOutputWriter(
new TextOutputFormat[NullWritable, Text]() {
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
val configuration = context.getConfiguration
- val uniqueWriteJobId = configuration.get(CreateDataSourceTableUtils.DATASOURCE_WRITEJOBUUID)
+ val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
val taskAttemptId = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 71c3bd31e02e4..e32d30178eeb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -971,7 +971,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
// Storage format
val defaultStorage: CatalogStorageFormat = {
val defaultStorageType = conf.getConfString("hive.default.fileformat", "textfile")
- val defaultHiveSerde = HiveSerDe.sourceToSerDe(defaultStorageType, conf)
+ val defaultHiveSerde = HiveSerDe.sourceToSerDe(defaultStorageType)
CatalogStorageFormat(
locationUri = None,
inputFormat = defaultHiveSerde.flatMap(_.inputFormat)
@@ -1115,7 +1115,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
override def visitGenericFileFormat(
ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) {
val source = ctx.identifier.getText
- HiveSerDe.sourceToSerDe(source, conf) match {
+ HiveSerDe.sourceToSerDe(source) match {
case Some(s) =>
CatalogStorageFormat.empty.copy(
inputFormat = s.inputFormat,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 7b028e72ed0a8..7400a0e7bb1f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql.execution.command
-import scala.collection.mutable
-import scala.util.control.NonFatal
-
-import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
@@ -28,7 +24,6 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.internal.HiveSerDe
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
import org.apache.spark.sql.types._
@@ -97,16 +92,19 @@ case class CreateDataSourceTableCommand(
}
}
- CreateDataSourceTableUtils.createDataSourceTable(
- sparkSession = sparkSession,
- tableIdent = tableIdent,
+ val table = CatalogTable(
+ identifier = tableIdent,
+ tableType = if (isExternal) CatalogTableType.EXTERNAL else CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty.copy(properties = optionsWithPath),
schema = dataSource.schema,
- partitionColumns = partitionColumns,
- bucketSpec = bucketSpec,
- provider = provider,
- options = optionsWithPath,
- isExternal = isExternal)
-
+ provider = Some(provider),
+ partitionColumnNames = partitionColumns,
+ bucketSpec = bucketSpec
+ )
+
+ // We will return Nil or throw exception at the beginning if the table already exists, so when
+ // we reach here, the table should not exist and we should set `ignoreIfExists` to false.
+ sessionState.catalog.createTable(table, ignoreIfExists = false)
Seq.empty[Row]
}
}
@@ -193,7 +191,7 @@ case class CreateDataSourceTableAsSelectCommand(
}
existingSchema = Some(l.schema)
case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) =>
- existingSchema = Some(DDLUtils.getSchemaFromTableProperties(s.metadata))
+ existingSchema = Some(s.metadata.schema)
case o =>
throw new AnalysisException(s"Saving data in ${o.toString} is not supported.")
}
@@ -233,15 +231,17 @@ case class CreateDataSourceTableAsSelectCommand(
// We will use the schema of resolved.relation as the schema of the table (instead of
// the schema of df). It is important since the nullability may be changed by the relation
// provider (for example, see org.apache.spark.sql.parquet.DefaultSource).
- CreateDataSourceTableUtils.createDataSourceTable(
- sparkSession = sparkSession,
- tableIdent = tableIdent,
- schema = result.schema,
- partitionColumns = partitionColumns,
- bucketSpec = bucketSpec,
- provider = provider,
- options = optionsWithPath,
- isExternal = isExternal)
+ val schema = result.schema
+ val table = CatalogTable(
+ identifier = tableIdent,
+ tableType = if (isExternal) CatalogTableType.EXTERNAL else CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty.copy(properties = optionsWithPath),
+ schema = schema,
+ provider = Some(provider),
+ partitionColumnNames = partitionColumns,
+ bucketSpec = bucketSpec
+ )
+ sessionState.catalog.createTable(table, ignoreIfExists = false)
}
// Refresh the cache of the table in the catalog.
@@ -249,210 +249,3 @@ case class CreateDataSourceTableAsSelectCommand(
Seq.empty[Row]
}
}
-
-
-object CreateDataSourceTableUtils extends Logging {
-
- val DATASOURCE_PREFIX = "spark.sql.sources."
- val DATASOURCE_PROVIDER = DATASOURCE_PREFIX + "provider"
- val DATASOURCE_WRITEJOBUUID = DATASOURCE_PREFIX + "writeJobUUID"
- val DATASOURCE_OUTPUTPATH = DATASOURCE_PREFIX + "output.path"
- val DATASOURCE_SCHEMA = DATASOURCE_PREFIX + "schema"
- val DATASOURCE_SCHEMA_PREFIX = DATASOURCE_SCHEMA + "."
- val DATASOURCE_SCHEMA_NUMPARTS = DATASOURCE_SCHEMA_PREFIX + "numParts"
- val DATASOURCE_SCHEMA_NUMPARTCOLS = DATASOURCE_SCHEMA_PREFIX + "numPartCols"
- val DATASOURCE_SCHEMA_NUMSORTCOLS = DATASOURCE_SCHEMA_PREFIX + "numSortCols"
- val DATASOURCE_SCHEMA_NUMBUCKETS = DATASOURCE_SCHEMA_PREFIX + "numBuckets"
- val DATASOURCE_SCHEMA_NUMBUCKETCOLS = DATASOURCE_SCHEMA_PREFIX + "numBucketCols"
- val DATASOURCE_SCHEMA_PART_PREFIX = DATASOURCE_SCHEMA_PREFIX + "part."
- val DATASOURCE_SCHEMA_PARTCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "partCol."
- val DATASOURCE_SCHEMA_BUCKETCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "bucketCol."
- val DATASOURCE_SCHEMA_SORTCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "sortCol."
-
- def createDataSourceTable(
- sparkSession: SparkSession,
- tableIdent: TableIdentifier,
- schema: StructType,
- partitionColumns: Array[String],
- bucketSpec: Option[BucketSpec],
- provider: String,
- options: Map[String, String],
- isExternal: Boolean): Unit = {
- val tableProperties = new mutable.HashMap[String, String]
- tableProperties.put(DATASOURCE_PROVIDER, provider)
-
- // Serialized JSON schema string may be too long to be stored into a single metastore table
- // property. In this case, we split the JSON string and store each part as a separate table
- // property.
- val threshold = sparkSession.sessionState.conf.schemaStringLengthThreshold
- val schemaJsonString = schema.json
- // Split the JSON string.
- val parts = schemaJsonString.grouped(threshold).toSeq
- tableProperties.put(DATASOURCE_SCHEMA_NUMPARTS, parts.size.toString)
- parts.zipWithIndex.foreach { case (part, index) =>
- tableProperties.put(s"$DATASOURCE_SCHEMA_PART_PREFIX$index", part)
- }
-
- if (partitionColumns.length > 0) {
- tableProperties.put(DATASOURCE_SCHEMA_NUMPARTCOLS, partitionColumns.length.toString)
- partitionColumns.zipWithIndex.foreach { case (partCol, index) =>
- tableProperties.put(s"$DATASOURCE_SCHEMA_PARTCOL_PREFIX$index", partCol)
- }
- }
-
- if (bucketSpec.isDefined) {
- val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get
-
- tableProperties.put(DATASOURCE_SCHEMA_NUMBUCKETS, numBuckets.toString)
- tableProperties.put(DATASOURCE_SCHEMA_NUMBUCKETCOLS, bucketColumnNames.length.toString)
- bucketColumnNames.zipWithIndex.foreach { case (bucketCol, index) =>
- tableProperties.put(s"$DATASOURCE_SCHEMA_BUCKETCOL_PREFIX$index", bucketCol)
- }
-
- if (sortColumnNames.nonEmpty) {
- tableProperties.put(DATASOURCE_SCHEMA_NUMSORTCOLS, sortColumnNames.length.toString)
- sortColumnNames.zipWithIndex.foreach { case (sortCol, index) =>
- tableProperties.put(s"$DATASOURCE_SCHEMA_SORTCOL_PREFIX$index", sortCol)
- }
- }
- }
-
- val tableType = if (isExternal) {
- tableProperties.put("EXTERNAL", "TRUE")
- CatalogTableType.EXTERNAL
- } else {
- tableProperties.put("EXTERNAL", "FALSE")
- CatalogTableType.MANAGED
- }
-
- val maybeSerDe = HiveSerDe.sourceToSerDe(provider, sparkSession.sessionState.conf)
- val dataSource =
- DataSource(
- sparkSession,
- userSpecifiedSchema = Some(schema),
- partitionColumns = partitionColumns,
- bucketSpec = bucketSpec,
- className = provider,
- options = options)
-
- def newSparkSQLSpecificMetastoreTable(): CatalogTable = {
- CatalogTable(
- identifier = tableIdent,
- tableType = tableType,
- schema = new StructType,
- provider = Some(provider),
- storage = CatalogStorageFormat(
- locationUri = None,
- inputFormat = None,
- outputFormat = None,
- serde = None,
- compressed = false,
- properties = options
- ),
- properties = tableProperties.toMap)
- }
-
- def newHiveCompatibleMetastoreTable(
- relation: HadoopFsRelation,
- serde: HiveSerDe): CatalogTable = {
- assert(partitionColumns.isEmpty)
- assert(relation.partitionSchema.isEmpty)
-
- CatalogTable(
- identifier = tableIdent,
- tableType = tableType,
- storage = CatalogStorageFormat(
- locationUri = Some(relation.location.paths.map(_.toUri.toString).head),
- inputFormat = serde.inputFormat,
- outputFormat = serde.outputFormat,
- serde = serde.serde,
- compressed = false,
- properties = options
- ),
- schema = relation.schema,
- provider = Some(provider),
- properties = tableProperties.toMap,
- viewText = None)
- }
-
- // TODO: Support persisting partitioned data source relations in Hive compatible format
- val qualifiedTableName = tableIdent.quotedString
- val skipHiveMetadata = options.getOrElse("skipHiveMetadata", "false").toBoolean
- val resolvedRelation = dataSource.resolveRelation(checkPathExist = false)
- val (hiveCompatibleTable, logMessage) = (maybeSerDe, resolvedRelation) match {
- case _ if skipHiveMetadata =>
- val message =
- s"Persisting partitioned data source relation $qualifiedTableName into " +
- "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive."
- (None, message)
-
- case (Some(serde), relation: HadoopFsRelation) if relation.location.paths.length == 1 &&
- relation.partitionSchema.isEmpty && relation.bucketSpec.isEmpty =>
- val hiveTable = newHiveCompatibleMetastoreTable(relation, serde)
- val message =
- s"Persisting data source relation $qualifiedTableName with a single input path " +
- s"into Hive metastore in Hive compatible format. Input path: " +
- s"${relation.location.paths.head}."
- (Some(hiveTable), message)
-
- case (Some(serde), relation: HadoopFsRelation) if relation.partitionSchema.nonEmpty =>
- val message =
- s"Persisting partitioned data source relation $qualifiedTableName into " +
- "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " +
- "Input path(s): " + relation.location.paths.mkString("\n", "\n", "")
- (None, message)
-
- case (Some(serde), relation: HadoopFsRelation) if relation.bucketSpec.nonEmpty =>
- val message =
- s"Persisting bucketed data source relation $qualifiedTableName into " +
- "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " +
- "Input path(s): " + relation.location.paths.mkString("\n", "\n", "")
- (None, message)
-
- case (Some(serde), relation: HadoopFsRelation) =>
- val message =
- s"Persisting data source relation $qualifiedTableName with multiple input paths into " +
- "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " +
- s"Input paths: " + relation.location.paths.mkString("\n", "\n", "")
- (None, message)
-
- case (Some(serde), _) =>
- val message =
- s"Data source relation $qualifiedTableName is not a " +
- s"${classOf[HadoopFsRelation].getSimpleName}. Persisting it into Hive metastore " +
- "in Spark SQL specific format, which is NOT compatible with Hive."
- (None, message)
-
- case _ =>
- val message =
- s"Couldn't find corresponding Hive SerDe for data source provider $provider. " +
- s"Persisting data source relation $qualifiedTableName into Hive metastore in " +
- s"Spark SQL specific format, which is NOT compatible with Hive."
- (None, message)
- }
-
- (hiveCompatibleTable, logMessage) match {
- case (Some(table), message) =>
- // We first try to save the metadata of the table in a Hive compatible way.
- // If Hive throws an error, we fall back to save its metadata in the Spark SQL
- // specific way.
- try {
- logInfo(message)
- sparkSession.sessionState.catalog.createTable(table, ignoreIfExists = false)
- } catch {
- case NonFatal(e) =>
- val warningMessage =
- s"Could not persist $qualifiedTableName in a Hive compatible way. Persisting " +
- s"it into Hive metastore in Spark SQL specific format."
- logWarning(warningMessage, e)
- val table = newSparkSQLSpecificMetastoreTable()
- sparkSession.sessionState.catalog.createTable(table, ignoreIfExists = false)
- }
-
- case (None, message) =>
- logWarning(message)
- val table = newSparkSQLSpecificMetastoreTable()
- sparkSession.sessionState.catalog.createTable(table, ignoreIfExists = false)
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 2eff9337bc14a..3817f919f3a5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -27,10 +27,9 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, CatalogTable, CatalogTablePartition, CatalogTableType, SessionCatalog}
-import org.apache.spark.sql.catalyst.catalog.CatalogTypes._
+import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTablePartition, CatalogTableType, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.types._
@@ -234,10 +233,8 @@ case class AlterTableSetPropertiesCommand(
extends RunnableCommand {
override def run(sparkSession: SparkSession): Seq[Row] = {
- val ident = if (isView) "VIEW" else "TABLE"
val catalog = sparkSession.sessionState.catalog
DDLUtils.verifyAlterTableType(catalog, tableName, isView)
- DDLUtils.verifyTableProperties(properties.keys.toSeq, s"ALTER $ident")
val table = catalog.getTableMetadata(tableName)
// This overrides old properties
val newTable = table.copy(properties = table.properties ++ properties)
@@ -264,10 +261,8 @@ case class AlterTableUnsetPropertiesCommand(
extends RunnableCommand {
override def run(sparkSession: SparkSession): Seq[Row] = {
- val ident = if (isView) "VIEW" else "TABLE"
val catalog = sparkSession.sessionState.catalog
DDLUtils.verifyAlterTableType(catalog, tableName, isView)
- DDLUtils.verifyTableProperties(propKeys, s"ALTER $ident")
val table = catalog.getTableMetadata(tableName)
if (!ifExists) {
propKeys.foreach { k =>
@@ -445,11 +440,11 @@ case class AlterTableRecoverPartitionsCommand(
if (!catalog.tableExists(tableName)) {
throw new AnalysisException(s"Table $tableName in $cmd does not exist.")
}
- val table = catalog.getTableMetadata(tableName)
if (catalog.isTemporaryTable(tableName)) {
throw new AnalysisException(
s"Operation not allowed: $cmd on temporary tables: $tableName")
}
+ val table = catalog.getTableMetadata(tableName)
if (DDLUtils.isDatasourceTable(table)) {
throw new AnalysisException(
s"Operation not allowed: $cmd on datasource tables: $tableName")
@@ -458,7 +453,7 @@ case class AlterTableRecoverPartitionsCommand(
throw new AnalysisException(
s"Operation not allowed: $cmd only works on external tables: $tableName")
}
- if (!DDLUtils.isTablePartitioned(table)) {
+ if (table.partitionColumnNames.isEmpty) {
throw new AnalysisException(
s"Operation not allowed: $cmd only works on partitioned tables: $tableName")
}
@@ -584,13 +579,8 @@ case class AlterTableSetLocationCommand(
object DDLUtils {
-
- def isDatasourceTable(props: Map[String, String]): Boolean = {
- props.contains(DATASOURCE_PROVIDER)
- }
-
def isDatasourceTable(table: CatalogTable): Boolean = {
- isDatasourceTable(table.properties)
+ table.provider.isDefined && table.provider.get != "hive"
}
/**
@@ -611,78 +601,4 @@ object DDLUtils {
case _ =>
})
}
-
- /**
- * If the given table properties (or SerDe properties) contains datasource properties,
- * throw an exception.
- */
- def verifyTableProperties(propKeys: Seq[String], operation: String): Unit = {
- val datasourceKeys = propKeys.filter(_.startsWith(DATASOURCE_PREFIX))
- if (datasourceKeys.nonEmpty) {
- throw new AnalysisException(s"Operation not allowed: $operation property keys may not " +
- s"start with '$DATASOURCE_PREFIX': ${datasourceKeys.mkString("[", ", ", "]")}")
- }
- }
-
- def isTablePartitioned(table: CatalogTable): Boolean = {
- table.partitionColumnNames.nonEmpty || table.properties.contains(DATASOURCE_SCHEMA_NUMPARTCOLS)
- }
-
- // A persisted data source table always store its schema in the catalog.
- def getSchemaFromTableProperties(metadata: CatalogTable): StructType = {
- require(isDatasourceTable(metadata))
- val msgSchemaCorrupted = "Could not read schema from the metastore because it is corrupted."
- val props = metadata.properties
- props.get(DATASOURCE_SCHEMA).map { schema =>
- // Originally, we used spark.sql.sources.schema to store the schema of a data source table.
- // After SPARK-6024, we removed this flag.
- // Although we are not using spark.sql.sources.schema any more, we need to still support.
- DataType.fromJson(schema).asInstanceOf[StructType]
- } getOrElse {
- props.get(DATASOURCE_SCHEMA_NUMPARTS).map { numParts =>
- val parts = (0 until numParts.toInt).map { index =>
- val part = metadata.properties.get(s"$DATASOURCE_SCHEMA_PART_PREFIX$index").orNull
- if (part == null) {
- throw new AnalysisException(msgSchemaCorrupted +
- s" (missing part $index of the schema, $numParts parts are expected).")
- }
- part
- }
- // Stick all parts back to a single schema string.
- DataType.fromJson(parts.mkString).asInstanceOf[StructType]
- } getOrElse(throw new AnalysisException(msgSchemaCorrupted))
- }
- }
-
- private def getColumnNamesByType(
- props: Map[String, String], colType: String, typeName: String): Seq[String] = {
- require(isDatasourceTable(props))
-
- for {
- numCols <- props.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").toSeq
- index <- 0 until numCols.toInt
- } yield props.getOrElse(
- s"$DATASOURCE_SCHEMA_PREFIX${colType}Col.$index",
- throw new AnalysisException(
- s"Corrupted $typeName in catalog: $numCols parts expected, but part $index is missing."
- )
- )
- }
-
- def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = {
- getColumnNamesByType(metadata.properties, "part", "partitioning columns")
- }
-
- def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = {
- if (isDatasourceTable(metadata)) {
- metadata.properties.get(DATASOURCE_SCHEMA_NUMBUCKETS).map { numBuckets =>
- BucketSpec(
- numBuckets.toInt,
- getColumnNamesByType(metadata.properties, "bucket", "bucketing columns"),
- getColumnNamesByType(metadata.properties, "sort", "sorting columns"))
- }
- } else {
- None
- }
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 720399ecc596a..af2b5ffd1c427 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -119,11 +119,9 @@ case class CreateTableLikeCommand(
case class CreateTableCommand(table: CatalogTable, ifNotExists: Boolean) extends RunnableCommand {
override def run(sparkSession: SparkSession): Seq[Row] = {
- DDLUtils.verifyTableProperties(table.properties.keys.toSeq, "CREATE TABLE")
sparkSession.sessionState.catalog.createTable(table, ifNotExists)
Seq.empty[Row]
}
-
}
@@ -414,8 +412,8 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF
describeSchema(catalog.lookupRelation(table).schema, result)
} else {
val metadata = catalog.getTableMetadata(table)
+ describeSchema(metadata.schema, result)
- describeSchema(metadata, result)
if (isExtended) {
describeExtended(metadata, result)
} else if (isFormatted) {
@@ -429,20 +427,10 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF
}
private def describePartitionInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = {
- if (DDLUtils.isDatasourceTable(table)) {
- val partColNames = DDLUtils.getPartitionColumnsFromTableProperties(table)
- if (partColNames.nonEmpty) {
- val userSpecifiedSchema = DDLUtils.getSchemaFromTableProperties(table)
- append(buffer, "# Partition Information", "", "")
- append(buffer, s"# ${output.head.name}", output(1).name, output(2).name)
- describeSchema(StructType(partColNames.map(userSpecifiedSchema(_))), buffer)
- }
- } else {
- if (table.partitionColumnNames.nonEmpty) {
- append(buffer, "# Partition Information", "", "")
- append(buffer, s"# ${output.head.name}", output(1).name, output(2).name)
- describeSchema(table.partitionSchema, buffer)
- }
+ if (table.partitionColumnNames.nonEmpty) {
+ append(buffer, "# Partition Information", "", "")
+ append(buffer, s"# ${output.head.name}", output(1).name, output(2).name)
+ describeSchema(table.partitionSchema, buffer)
}
}
@@ -466,11 +454,7 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF
append(buffer, "Table Type:", table.tableType.name, "")
append(buffer, "Table Parameters:", "", "")
- table.properties.filterNot {
- // Hides schema properties that hold user-defined schema, partition columns, and bucketing
- // information since they are already extracted and shown in other parts.
- case (key, _) => key.startsWith(CreateDataSourceTableUtils.DATASOURCE_SCHEMA)
- }.foreach { case (key, value) =>
+ table.properties.foreach { case (key, value) =>
append(buffer, s" $key", value, "")
}
@@ -493,7 +477,7 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF
}
private def describeBucketingInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = {
- def appendBucketInfo(bucketSpec: Option[BucketSpec]) = bucketSpec match {
+ metadata.bucketSpec match {
case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) =>
append(buffer, "Num Buckets:", numBuckets.toString, "")
append(buffer, "Bucket Columns:", bucketColumnNames.mkString("[", ", ", "]"), "")
@@ -501,23 +485,6 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF
case _ =>
}
-
- if (DDLUtils.isDatasourceTable(metadata)) {
- appendBucketInfo(DDLUtils.getBucketSpecFromTableProperties(metadata))
- } else {
- appendBucketInfo(metadata.bucketSpec)
- }
- }
-
- private def describeSchema(
- tableDesc: CatalogTable,
- buffer: ArrayBuffer[Row]): Unit = {
- if (DDLUtils.isDatasourceTable(tableDesc)) {
- val schema = DDLUtils.getSchemaFromTableProperties(tableDesc)
- describeSchema(schema, buffer)
- } else {
- describeSchema(tableDesc.schema, buffer)
- }
}
private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = {
@@ -678,7 +645,7 @@ case class ShowPartitionsCommand(
s"SHOW PARTITIONS is not allowed on a view or index table: ${tab.qualifiedName}")
}
- if (!DDLUtils.isTablePartitioned(tab)) {
+ if (tab.partitionColumnNames.isEmpty) {
throw new AnalysisException(
s"SHOW PARTITIONS is not allowed on a table that is not partitioned: ${tab.qualifiedName}")
}
@@ -729,6 +696,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman
val tableMetadata = catalog.getTableMetadata(table)
+ // TODO: unify this after we unify the CREATE TABLE syntax for hive serde and data source table.
val stmt = if (DDLUtils.isDatasourceTable(tableMetadata)) {
showCreateDataSourceTable(tableMetadata)
} else {
@@ -872,15 +840,14 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman
private def showDataSourceTableDataColumns(
metadata: CatalogTable, builder: StringBuilder): Unit = {
- val schema = DDLUtils.getSchemaFromTableProperties(metadata)
- val columns = schema.fields.map(f => s"${quoteIdentifier(f.name)} ${f.dataType.sql}")
+ val columns = metadata.schema.fields.map(f => s"${quoteIdentifier(f.name)} ${f.dataType.sql}")
builder ++= columns.mkString("(", ", ", ")\n")
}
private def showDataSourceTableOptions(metadata: CatalogTable, builder: StringBuilder): Unit = {
val props = metadata.properties
- builder ++= s"USING ${props(CreateDataSourceTableUtils.DATASOURCE_PROVIDER)}\n"
+ builder ++= s"USING ${metadata.provider.get}\n"
val dataSourceOptions = metadata.storage.properties.filterNot {
case (key, value) =>
@@ -900,12 +867,12 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman
private def showDataSourceTableNonDataColumns(
metadata: CatalogTable, builder: StringBuilder): Unit = {
- val partCols = DDLUtils.getPartitionColumnsFromTableProperties(metadata)
+ val partCols = metadata.partitionColumnNames
if (partCols.nonEmpty) {
builder ++= s"PARTITIONED BY ${partCols.mkString("(", ", ", ")")}\n"
}
- DDLUtils.getBucketSpecFromTableProperties(metadata).foreach { spec =>
+ metadata.bucketSpec.foreach { spec =>
if (spec.bucketColumnNames.nonEmpty) {
builder ++= s"CLUSTERED BY ${spec.bucketColumnNames.mkString("(", ", ", ")")}\n"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 733ba185287e1..5eba7df060c4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.command.{CreateDataSourceTableUtils, DDLUtils, ExecutedCommandExec}
+import org.apache.spark.sql.execution.command.{DDLUtils, ExecutedCommandExec}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -204,24 +204,14 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
*/
class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] {
private def readDataSourceTable(sparkSession: SparkSession, table: CatalogTable): LogicalPlan = {
- val schema = DDLUtils.getSchemaFromTableProperties(table)
-
- // We only need names at here since userSpecifiedSchema we loaded from the metastore
- // contains partition columns. We can always get datatypes of partitioning columns
- // from userSpecifiedSchema.
- val partitionColumns = DDLUtils.getPartitionColumnsFromTableProperties(table)
-
- val bucketSpec = DDLUtils.getBucketSpecFromTableProperties(table)
-
- val options = table.storage.properties
val dataSource =
DataSource(
sparkSession,
- userSpecifiedSchema = Some(schema),
- partitionColumns = partitionColumns,
- bucketSpec = bucketSpec,
- className = table.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER),
- options = options)
+ userSpecifiedSchema = Some(table.schema),
+ partitionColumns = table.partitionColumnNames,
+ bucketSpec = table.bucketSpec,
+ className = table.provider.get,
+ options = table.storage.properties)
LogicalRelation(
dataSource.resolveRelation(),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index 447c237e3a1b0..7880c7cfa16f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -48,6 +47,11 @@ private[datasources] case class WriteRelation(
prepareJobForWrite: Job => OutputWriterFactory,
bucketSpec: Option[BucketSpec])
+object WriterContainer {
+ val DATASOURCE_WRITEJOBUUID = "spark.sql.sources.writeJobUUID"
+ val DATASOURCE_OUTPUTPATH = "spark.sql.sources.output.path"
+}
+
private[datasources] abstract class BaseWriterContainer(
@transient val relation: WriteRelation,
@transient private val job: Job,
@@ -94,7 +98,7 @@ private[datasources] abstract class BaseWriterContainer(
// This UUID is sent to executor side together with the serialized `Configuration` object within
// the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate
// unique task output files.
- job.getConfiguration.set(DATASOURCE_WRITEJOBUUID, uniqueWriteJobId.toString)
+ job.getConfiguration.set(WriterContainer.DATASOURCE_WRITEJOBUUID, uniqueWriteJobId.toString)
// Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor
// clones the Configuration object passed in. If we initialize the TaskAttemptContext first,
@@ -244,7 +248,7 @@ private[datasources] class DefaultWriterContainer(
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
executorSideSetup(taskContext)
val configuration = taskAttemptContext.getConfiguration
- configuration.set(DATASOURCE_OUTPUTPATH, outputPath)
+ configuration.set(WriterContainer.DATASOURCE_OUTPUTPATH, outputPath)
var writer = newOutputWriter(getWorkPath)
writer.initConverter(dataSchema)
@@ -352,10 +356,12 @@ private[datasources] class DynamicPartitionWriterContainer(
val configuration = taskAttemptContext.getConfiguration
val path = if (partitionColumns.nonEmpty) {
val partitionPath = getPartitionString(key).getString(0)
- configuration.set(DATASOURCE_OUTPUTPATH, new Path(outputPath, partitionPath).toString)
+ configuration.set(
+ WriterContainer.DATASOURCE_OUTPUTPATH,
+ new Path(outputPath, partitionPath).toString)
new Path(getWorkPath, partitionPath).toString
} else {
- configuration.set(DATASOURCE_OUTPUTPATH, outputPath)
+ configuration.set(WriterContainer.DATASOURCE_OUTPUTPATH, outputPath)
getWorkPath
}
val bucketId = getBucketIdFromKey(key)
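
A minimal sketch (not part of this patch, assuming hadoop-common on the classpath) of how the write-job UUID keyed by the new WriterContainer companion object travels from the driver to executor-side output writers through the Hadoop Configuration; the object name and the surrounding workflow are illustrative only.

import java.util.UUID
import org.apache.hadoop.conf.Configuration

object WriteJobUuidSketch {
  // Mirrors WriterContainer.DATASOURCE_WRITEJOBUUID from this patch.
  val DATASOURCE_WRITEJOBUUID = "spark.sql.sources.writeJobUUID"

  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    // Driver side: stamp the job configuration with a unique id before submitting the write job.
    conf.set(DATASOURCE_WRITEJOBUUID, UUID.randomUUID().toString)
    // Executor side: an OutputWriter reads the id back to build collision-free file names.
    val uniqueWriteJobId = conf.get(DATASOURCE_WRITEJOBUUID)
    val split = 3 // normally taken from context.getTaskAttemptID.getTaskID.getId
    println(f"part-r-$split%05d-$uniqueWriteJobId.parquet")
  }
}
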
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 6b2f9fc61e677..de2d633c0bcf4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -30,8 +30,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
-import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile, WriterContainer}
import org.apache.spark.sql.types._
object CSVRelation extends Logging {
@@ -192,7 +191,7 @@ private[csv] class CsvOutputWriter(
new TextOutputFormat[NullWritable, Text]() {
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
val configuration = context.getConfiguration
- val uniqueWriteJobId = configuration.get(CreateDataSourceTableUtils.DATASOURCE_WRITEJOBUUID)
+ val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
val taskAttemptId = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.csv$extension")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 27910e2cddad8..16150b91d6452 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -31,7 +31,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
@@ -164,7 +163,7 @@ private[json] class JsonOutputWriter(
new TextOutputFormat[NullWritable, Text]() {
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
val configuration = context.getConfiguration
- val uniqueWriteJobId = configuration.get(CreateDataSourceTableUtils.DATASOURCE_WRITEJOBUUID)
+ val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
val taskAttemptId = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 9c4778acf53d7..9208c82179d8d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -44,7 +44,6 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
@@ -547,8 +546,7 @@ private[parquet] class ParquetOutputWriter(
// partitions in the case of dynamic partitioning.
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
val configuration = context.getConfiguration
- val uniqueWriteJobId = configuration.get(
- CreateDataSourceTableUtils.DATASOURCE_WRITEJOBUUID)
+ val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
val taskAttemptId = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index abb6059f75ba8..a0c3fd53fb53b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
@@ -131,7 +130,7 @@ class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemp
new TextOutputFormat[NullWritable, Text]() {
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
val configuration = context.getConfiguration
- val uniqueWriteJobId = configuration.get(CreateDataSourceTableUtils.DATASOURCE_WRITEJOBUUID)
+ val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
val taskAttemptId = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.txt$extension")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
index ad69137f7401b..52e648a917d8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
@@ -28,10 +28,9 @@ object HiveSerDe {
*
* @param source Currently the source abbreviation can be one of the following:
* SequenceFile, RCFile, ORC, PARQUET, and case insensitive.
- * @param conf SQLConf
* @return HiveSerDe associated with the specified source
*/
- def sourceToSerDe(source: String, conf: SQLConf): Option[HiveSerDe] = {
+ def sourceToSerDe(source: String): Option[HiveSerDe] = {
val serdeMap = Map(
"sequencefile" ->
HiveSerDe(
@@ -42,8 +41,7 @@ object HiveSerDe {
HiveSerDe(
inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"),
outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"),
- serde = Option(conf.getConfString("hive.default.rcfile.serde",
- "org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe"))),
+ serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")),
"orc" ->
HiveSerDe(
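
A hedged usage sketch of the simplified lookup: `HiveSerDe.sourceToSerDe` is an internal Spark helper, so this snippet assumes it runs inside Spark's own build (for example a test in sql/core); the match on the returned Option is the only point being illustrated.

import org.apache.spark.sql.internal.HiveSerDe

object SourceToSerDeSketch {
  def main(args: Array[String]): Unit = {
    // The source abbreviation is matched case-insensitively, so "RCFile" and "rcfile" behave the same.
    HiveSerDe.sourceToSerDe("rcfile") match {
      case Some(serde) =>
        println(s"input format:  ${serde.inputFormat.getOrElse("<none>")}")
        println(s"output format: ${serde.outputFormat.getOrElse("<none>")}")
        println(s"serde class:   ${serde.serde.getOrElse("<none>")}")
      case None =>
        println("no Hive serde registered for this source")
    }
  }
}
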
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
index be1bccbd990a0..8dd883b37bde0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
@@ -243,7 +243,7 @@ class DDLCommandSuite extends PlanTest {
allSources.foreach { s =>
val query = s"CREATE TABLE my_tab STORED AS $s"
val ct = parseAs[CreateTable](query)
- val hiveSerde = HiveSerDe.sourceToSerDe(s, new SQLConf)
+ val hiveSerde = HiveSerDe.sourceToSerDe(s)
assert(hiveSerde.isDefined)
assert(ct.tableDesc.storage.serde == hiveSerde.get.serde)
assert(ct.tableDesc.storage.inputFormat == hiveSerde.get.inputFormat)
@@ -276,7 +276,7 @@ class DDLCommandSuite extends PlanTest {
val query = s"CREATE TABLE my_tab ROW FORMAT SERDE 'anything' STORED AS $s"
if (supportedSources.contains(s)) {
val ct = parseAs[CreateTable](query)
- val hiveSerde = HiveSerDe.sourceToSerDe(s, new SQLConf)
+ val hiveSerde = HiveSerDe.sourceToSerDe(s)
assert(hiveSerde.isDefined)
assert(ct.tableDesc.storage.serde == Some("anything"))
assert(ct.tableDesc.storage.inputFormat == hiveSerde.get.inputFormat)
@@ -295,7 +295,7 @@ class DDLCommandSuite extends PlanTest {
val query = s"CREATE TABLE my_tab ROW FORMAT DELIMITED FIELDS TERMINATED BY ' ' STORED AS $s"
if (supportedSources.contains(s)) {
val ct = parseAs[CreateTable](query)
- val hiveSerde = HiveSerDe.sourceToSerDe(s, new SQLConf)
+ val hiveSerde = HiveSerDe.sourceToSerDe(s)
assert(hiveSerde.isDefined)
assert(ct.tableDesc.storage.serde == hiveSerde.get.serde)
assert(ct.tableDesc.storage.inputFormat == hiveSerde.get.inputFormat)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 0f7fda7666a3b..e6ae42258d4c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, Catal
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -93,7 +92,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
.add("col2", "string")
.add("a", "int")
.add("b", "int"),
- provider = Some("parquet"),
+ provider = Some("hive"),
partitionColumnNames = Seq("a", "b"),
createTime = 0L)
}
@@ -277,10 +276,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
""".stripMargin)
val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tabName))
- assert(expectedSchema ==
- DDLUtils.getSchemaFromTableProperties(tableMetadata))
- assert(expectedPartitionCols ==
- DDLUtils.getPartitionColumnsFromTableProperties(tableMetadata))
+ assert(expectedSchema == tableMetadata.schema)
+ assert(expectedPartitionCols == tableMetadata.partitionColumnNames)
}
}
@@ -399,41 +396,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
assert(e.message == "Found duplicate column(s) in bucket: a")
}
- test("Describe Table with Corrupted Schema") {
- import testImplicits._
-
- val tabName = "tab1"
- withTempPath { dir =>
- val path = dir.getCanonicalPath
- val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("col1", "col2")
- df.write.format("json").save(path)
- val uri = dir.toURI
-
- withTable(tabName) {
- sql(
- s"""
- |CREATE TABLE $tabName
- |USING json
- |OPTIONS (
- | path '$uri'
- |)
- """.stripMargin)
-
- val catalog = spark.sessionState.catalog
- val table = catalog.getTableMetadata(TableIdentifier(tabName))
- val newProperties = table.properties.filterKeys(key =>
- key != CreateDataSourceTableUtils.DATASOURCE_SCHEMA_NUMPARTS)
- val newTable = table.copy(properties = newProperties)
- catalog.alterTable(newTable)
-
- val e = intercept[AnalysisException] {
- sql(s"DESC $tabName")
- }.getMessage
- assert(e.contains(s"Could not read schema from the metastore because it is corrupted"))
- }
- }
- }
-
test("Refresh table after changing the data source table partitioning") {
import testImplicits._
@@ -460,10 +422,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
|)
""".stripMargin)
val tableMetadata = catalog.getTableMetadata(TableIdentifier(tabName))
- val tableSchema = DDLUtils.getSchemaFromTableProperties(tableMetadata)
- assert(tableSchema == schema)
- val partCols = DDLUtils.getPartitionColumnsFromTableProperties(tableMetadata)
- assert(partCols == partitionCols)
+ assert(tableMetadata.schema == schema)
+ assert(tableMetadata.partitionColumnNames == partitionCols)
// Change the schema
val newDF = sparkContext.parallelize(1 to 10).map(i => (i, i.toString))
@@ -472,23 +432,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
// No change on the schema
val tableMetadataBeforeRefresh = catalog.getTableMetadata(TableIdentifier(tabName))
- val tableSchemaBeforeRefresh =
- DDLUtils.getSchemaFromTableProperties(tableMetadataBeforeRefresh)
- assert(tableSchemaBeforeRefresh == schema)
- val partColsBeforeRefresh =
- DDLUtils.getPartitionColumnsFromTableProperties(tableMetadataBeforeRefresh)
- assert(partColsBeforeRefresh == partitionCols)
+ assert(tableMetadataBeforeRefresh.schema == schema)
+ assert(tableMetadataBeforeRefresh.partitionColumnNames == partitionCols)
// Refresh does not affect the schema
spark.catalog.refreshTable(tabName)
val tableMetadataAfterRefresh = catalog.getTableMetadata(TableIdentifier(tabName))
- val tableSchemaAfterRefresh =
- DDLUtils.getSchemaFromTableProperties(tableMetadataAfterRefresh)
- assert(tableSchemaAfterRefresh == schema)
- val partColsAfterRefresh =
- DDLUtils.getPartitionColumnsFromTableProperties(tableMetadataAfterRefresh)
- assert(partColsAfterRefresh == partitionCols)
+ assert(tableMetadataAfterRefresh.schema == schema)
+ assert(tableMetadataAfterRefresh.partitionColumnNames == partitionCols)
}
}
}
@@ -641,7 +593,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
val table = catalog.getTableMetadata(TableIdentifier("tbl"))
assert(table.tableType == CatalogTableType.MANAGED)
assert(table.schema == new StructType().add("a", "int").add("b", "int"))
- assert(table.properties(DATASOURCE_PROVIDER) == "parquet")
+ assert(table.provider == Some("parquet"))
}
}
@@ -651,12 +603,9 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
sql("CREATE TABLE tbl(a INT, b INT) USING parquet PARTITIONED BY (a)")
val table = catalog.getTableMetadata(TableIdentifier("tbl"))
assert(table.tableType == CatalogTableType.MANAGED)
- assert(table.schema.isEmpty) // partitioned datasource table is not hive-compatible
- assert(table.properties(DATASOURCE_PROVIDER) == "parquet")
- assert(DDLUtils.getSchemaFromTableProperties(table) ==
- new StructType().add("a", IntegerType).add("b", IntegerType))
- assert(DDLUtils.getPartitionColumnsFromTableProperties(table) ==
- Seq("a"))
+ assert(table.provider == Some("parquet"))
+ assert(table.schema == new StructType().add("a", IntegerType).add("b", IntegerType))
+ assert(table.partitionColumnNames == Seq("a"))
}
}
@@ -667,12 +616,9 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
"CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS")
val table = catalog.getTableMetadata(TableIdentifier("tbl"))
assert(table.tableType == CatalogTableType.MANAGED)
- assert(table.schema.isEmpty) // partitioned datasource table is not hive-compatible
- assert(table.properties(DATASOURCE_PROVIDER) == "parquet")
- assert(DDLUtils.getSchemaFromTableProperties(table) ==
- new StructType().add("a", IntegerType).add("b", IntegerType))
- assert(DDLUtils.getBucketSpecFromTableProperties(table) ==
- Some(BucketSpec(5, Seq("a"), Seq("b"))))
+ assert(table.provider == Some("parquet"))
+ assert(table.schema == new StructType().add("a", IntegerType).add("b", IntegerType))
+ assert(table.bucketSpec == Some(BucketSpec(5, Seq("a"), Seq("b"))))
}
}
@@ -1096,7 +1042,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
catalog: SessionCatalog,
tableIdent: TableIdentifier): Unit = {
catalog.alterTable(catalog.getTableMetadata(tableIdent).copy(
- properties = Map(DATASOURCE_PROVIDER -> "csv")))
+ provider = Some("csv")))
}
private def testSetProperties(isDatasourceTable: Boolean): Unit = {
@@ -1108,9 +1054,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
convertToDatasourceTable(catalog, tableIdent)
}
def getProps: Map[String, String] = {
- catalog.getTableMetadata(tableIdent).properties.filterKeys { k =>
- !isDatasourceTable || !k.startsWith(DATASOURCE_PREFIX)
- }
+ catalog.getTableMetadata(tableIdent).properties
}
assert(getProps.isEmpty)
// set table properties
@@ -1124,11 +1068,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
intercept[AnalysisException] {
sql("ALTER TABLE does_not_exist SET TBLPROPERTIES ('winner' = 'loser')")
}
- // datasource table property keys are not allowed
- val e = intercept[AnalysisException] {
- sql(s"ALTER TABLE tab1 SET TBLPROPERTIES ('${DATASOURCE_PREFIX}foo' = 'loser')")
- }
- assert(e.getMessage.contains(DATASOURCE_PREFIX + "foo"))
}
private def testUnsetProperties(isDatasourceTable: Boolean): Unit = {
@@ -1140,9 +1079,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
convertToDatasourceTable(catalog, tableIdent)
}
def getProps: Map[String, String] = {
- catalog.getTableMetadata(tableIdent).properties.filterKeys { k =>
- !isDatasourceTable || !k.startsWith(DATASOURCE_PREFIX)
- }
+ catalog.getTableMetadata(tableIdent).properties
}
// unset table properties
sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan', 'x' = 'y')")
@@ -1164,11 +1101,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
// property to unset does not exist, but "IF EXISTS" is specified
sql("ALTER TABLE tab1 UNSET TBLPROPERTIES IF EXISTS ('c', 'xyz')")
assert(getProps == Map("x" -> "y"))
- // datasource table property keys are not allowed
- val e2 = intercept[AnalysisException] {
- sql(s"ALTER TABLE tab1 UNSET TBLPROPERTIES ('${DATASOURCE_PREFIX}foo')")
- }
- assert(e2.getMessage.contains(DATASOURCE_PREFIX + "foo"))
}
private def testSetLocation(isDatasourceTable: Boolean): Unit = {
@@ -1573,10 +1505,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- test("create table with datasource properties (not allowed)") {
- assertUnsupported("CREATE TABLE my_tab TBLPROPERTIES ('spark.sql.sources.me'='anything')")
- }
-
test("Create Hive Table As Select") {
import testImplicits._
withTable("t", "t1") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 49153f77362b7..729c9fdda543e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -201,7 +201,7 @@ class CreateTableAsSelectSuite
""".stripMargin
)
val table = catalog.getTableMetadata(TableIdentifier("t"))
- assert(DDLUtils.getPartitionColumnsFromTableProperties(table) == Seq("a"))
+ assert(table.partitionColumnNames == Seq("a"))
}
}
@@ -217,8 +217,7 @@ class CreateTableAsSelectSuite
""".stripMargin
)
val table = catalog.getTableMetadata(TableIdentifier("t"))
- assert(DDLUtils.getBucketSpecFromTableProperties(table) ==
- Option(BucketSpec(5, Seq("a"), Seq("b"))))
+ assert(table.bucketSpec == Option(BucketSpec(5, Seq("a"), Seq("b"))))
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 8302e3e98ad34..de3e60a44d920 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -30,7 +30,11 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.execution.command.DDLUtils
+import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap
import org.apache.spark.sql.hive.client.HiveClient
+import org.apache.spark.sql.internal.HiveSerDe
+import org.apache.spark.sql.types.{DataType, StructType}
/**
@@ -41,6 +45,8 @@ private[spark] class HiveExternalCatalog(client: HiveClient, hadoopConf: Configu
extends ExternalCatalog with Logging {
import CatalogTypes.TablePartitionSpec
+ import HiveExternalCatalog._
+ import CatalogTableType._
// Exceptions thrown by the hive client that we would like to wrap
private val clientExceptions = Set(
@@ -81,6 +87,20 @@ private[spark] class HiveExternalCatalog(client: HiveClient, hadoopConf: Configu
withClient { getTable(db, table) }
}
+ /**
+ * If the given table properties contain datasource properties, throw an exception. We do this
+ * check when creating or altering a table, i.e. when we try to write table metadata to the Hive
+ * metastore.
+ */
+ private def verifyTableProperties(table: CatalogTable): Unit = {
+ val datasourceKeys = table.properties.keys.filter(_.startsWith(DATASOURCE_PREFIX))
+ if (datasourceKeys.nonEmpty) {
+ throw new AnalysisException(s"Cannot persist ${table.qualifiedName} into Hive metastore " +
+ s"as table property keys may not start with '$DATASOURCE_PREFIX': " +
+ datasourceKeys.mkString("[", ", ", "]"))
+ }
+ }
+
// --------------------------------------------------------------------------
// Databases
// --------------------------------------------------------------------------
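
A self-contained sketch of the reservation check performed by `verifyTableProperties` above, assuming a plain IllegalArgumentException in place of Spark's AnalysisException and a copied prefix constant instead of the real import.

object VerifyTablePropertiesSketch {
  // Mirrors HiveExternalCatalog.DATASOURCE_PREFIX from this patch.
  val DATASOURCE_PREFIX = "spark.sql.sources."

  def verify(qualifiedName: String, properties: Map[String, String]): Unit = {
    val datasourceKeys = properties.keys.filter(_.startsWith(DATASOURCE_PREFIX))
    if (datasourceKeys.nonEmpty) {
      throw new IllegalArgumentException(
        s"Cannot persist $qualifiedName into Hive metastore as table property keys " +
          s"may not start with '$DATASOURCE_PREFIX': " + datasourceKeys.mkString("[", ", ", "]"))
    }
  }

  def main(args: Array[String]): Unit = {
    verify("default.ok_table", Map("owner" -> "alice"))                       // passes
    verify("default.bad_table", Map("spark.sql.sources.provider" -> "json"))  // throws
  }
}
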
@@ -144,16 +164,162 @@ private[spark] class HiveExternalCatalog(client: HiveClient, hadoopConf: Configu
assert(tableDefinition.identifier.database.isDefined)
val db = tableDefinition.identifier.database.get
requireDbExists(db)
+ verifyTableProperties(tableDefinition)
+
+ // Before saving data source table metadata into Hive metastore, we should:
+ // 1. Put table schema, partition column names and bucket specification in table properties.
+ // 2. Check if this table is Hive compatible.
+ // 2.1 If it's not Hive compatible, set schema, partition columns and bucket spec to empty
+ // and save table metadata to Hive.
+ // 2.2 If it's Hive compatible, set serde information in table metadata and try to save
+ // it to Hive. If that fails, treat it as not Hive compatible and go back to 2.1.
+ if (DDLUtils.isDatasourceTable(tableDefinition)) {
+ // A data source table always has a provider; this is guaranteed by `DDLUtils.isDatasourceTable`.
+ val provider = tableDefinition.provider.get
+ val partitionColumns = tableDefinition.partitionColumnNames
+ val bucketSpec = tableDefinition.bucketSpec
+
+ val tableProperties = new scala.collection.mutable.HashMap[String, String]
+ tableProperties.put(DATASOURCE_PROVIDER, provider)
+
+ // The serialized JSON schema string may be too long to be stored in a single metastore table
+ // property. In this case, we split the JSON string and store each part as a separate table
+ // property.
+ // TODO: the threshold should be set by `spark.sql.sources.schemaStringLengthThreshold`,
+ // however the current SQLConf is session-isolated, which is not applicable to the external
+ // catalog. We should re-enable this conf instead of hard-coding the value here, once we have
+ // a global SQLConf.
+ val threshold = 4000
+ val schemaJsonString = tableDefinition.schema.json
+ // Split the JSON string.
+ val parts = schemaJsonString.grouped(threshold).toSeq
+ tableProperties.put(DATASOURCE_SCHEMA_NUMPARTS, parts.size.toString)
+ parts.zipWithIndex.foreach { case (part, index) =>
+ tableProperties.put(s"$DATASOURCE_SCHEMA_PART_PREFIX$index", part)
+ }
+
+ if (partitionColumns.nonEmpty) {
+ tableProperties.put(DATASOURCE_SCHEMA_NUMPARTCOLS, partitionColumns.length.toString)
+ partitionColumns.zipWithIndex.foreach { case (partCol, index) =>
+ tableProperties.put(s"$DATASOURCE_SCHEMA_PARTCOL_PREFIX$index", partCol)
+ }
+ }
+
+ if (bucketSpec.isDefined) {
+ val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get
+
+ tableProperties.put(DATASOURCE_SCHEMA_NUMBUCKETS, numBuckets.toString)
+ tableProperties.put(DATASOURCE_SCHEMA_NUMBUCKETCOLS, bucketColumnNames.length.toString)
+ bucketColumnNames.zipWithIndex.foreach { case (bucketCol, index) =>
+ tableProperties.put(s"$DATASOURCE_SCHEMA_BUCKETCOL_PREFIX$index", bucketCol)
+ }
+
+ if (sortColumnNames.nonEmpty) {
+ tableProperties.put(DATASOURCE_SCHEMA_NUMSORTCOLS, sortColumnNames.length.toString)
+ sortColumnNames.zipWithIndex.foreach { case (sortCol, index) =>
+ tableProperties.put(s"$DATASOURCE_SCHEMA_SORTCOL_PREFIX$index", sortCol)
+ }
+ }
+ }
+
+ // Converts the table metadata to the Spark SQL specific format, i.e. sets the schema, partition
+ // column names and bucket specification to empty.
+ def newSparkSQLSpecificMetastoreTable(): CatalogTable = {
+ tableDefinition.copy(
+ schema = new StructType,
+ partitionColumnNames = Nil,
+ bucketSpec = None,
+ properties = tableDefinition.properties ++ tableProperties)
+ }
+
+ // Converts the table metadata to a Hive compatible format, i.e. sets the serde information.
+ def newHiveCompatibleMetastoreTable(serde: HiveSerDe, path: String): CatalogTable = {
+ tableDefinition.copy(
+ storage = tableDefinition.storage.copy(
+ locationUri = Some(new Path(path).toUri.toString),
+ inputFormat = serde.inputFormat,
+ outputFormat = serde.outputFormat,
+ serde = serde.serde
+ ),
+ properties = tableDefinition.properties ++ tableProperties)
+ }
+
+ val qualifiedTableName = tableDefinition.identifier.quotedString
+ val maybeSerde = HiveSerDe.sourceToSerDe(tableDefinition.provider.get)
+ val maybePath = new CaseInsensitiveMap(tableDefinition.storage.properties).get("path")
+ val skipHiveMetadata = tableDefinition.storage.properties
+ .getOrElse("skipHiveMetadata", "false").toBoolean
+
+ val (hiveCompatibleTable, logMessage) = (maybeSerde, maybePath) match {
+ case _ if skipHiveMetadata =>
+ val message =
+ s"Persisting data source table $qualifiedTableName into Hive metastore in" +
+ "Spark SQL specific format, which is NOT compatible with Hive."
+ (None, message)
+
+ // Our bucketing is not compatible with Hive's (it uses a different hash function).
+ case _ if tableDefinition.bucketSpec.nonEmpty =>
+ val message =
+ s"Persisting bucketed data source table $qualifiedTableName into " +
+ "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. "
+ (None, message)
+
+ case (Some(serde), Some(path)) =>
+ val message =
+ s"Persisting file based data source table $qualifiedTableName with an input path " +
+ s"into Hive metastore in Hive compatible format."
+ (Some(newHiveCompatibleMetastoreTable(serde, path)), message)
+
+ case (Some(_), None) =>
+ val message =
+ s"Data source table $qualifiedTableName is not file based. Persisting it into " +
+ s"Hive metastore in Spark SQL specific format, which is NOT compatible with Hive."
+ (None, message)
+
+ case _ =>
+ val provider = tableDefinition.provider.get
+ val message =
+ s"Couldn't find corresponding Hive SerDe for data source provider $provider. " +
+ s"Persisting data source table $qualifiedTableName into Hive metastore in " +
+ s"Spark SQL specific format, which is NOT compatible with Hive."
+ (None, message)
+ }
+
+ (hiveCompatibleTable, logMessage) match {
+ case (Some(table), message) =>
+ // We first try to save the metadata of the table in a Hive compatible way.
+ // If Hive throws an error, we fall back to save its metadata in the Spark SQL
+ // specific way.
+ try {
+ logInfo(message)
+ saveTableIntoHive(table, ignoreIfExists)
+ } catch {
+ case NonFatal(e) =>
+ val warningMessage =
+ s"Could not persist ${tableDefinition.identifier.quotedString} in a Hive " +
+ "compatible way. Persisting it into Hive metastore in Spark SQL specific format."
+ logWarning(warningMessage, e)
+ saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists)
+ }
+
+ case (None, message) =>
+ logWarning(message)
+ saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists)
+ }
+ } else {
+ client.createTable(tableDefinition, ignoreIfExists)
+ }
+ }
- if (
+ private def saveTableIntoHive(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
+ assert(DDLUtils.isDatasourceTable(tableDefinition),
+ "saveTableIntoHive only takes data source table.")
// If this is an external data source table...
- tableDefinition.properties.contains("spark.sql.sources.provider") &&
- tableDefinition.tableType == CatalogTableType.EXTERNAL &&
- // ... that is not persisted as Hive compatible format (external tables in Hive compatible
- // format always set `locationUri` to the actual data location and should NOT be hacked as
- // following.)
- tableDefinition.storage.locationUri.isEmpty
- ) {
+ if (tableDefinition.tableType == EXTERNAL &&
+ // ... that is not persisted as Hive compatible format (external tables in Hive compatible
+ // format always set `locationUri` to the actual data location and should NOT be hacked as
+ // following.)
+ tableDefinition.storage.locationUri.isEmpty) {
// !! HACK ALERT !!
//
// Due to a restriction of Hive metastore, here we have to set `locationUri` to a temporary
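
A hedged round-trip sketch of the schema splitting described in the comment above: the property keys ("numParts", "part.<i>") are simplified placeholders for the real spark.sql.sources.schema.* constants and the threshold is deliberately tiny to force several parts, but `StructType#json` and `DataType.fromJson` are the same public APIs the patch relies on.

import org.apache.spark.sql.types.{DataType, StructType}

object SchemaSplitRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val schema = new StructType().add("a", "int").add("b", "string")

    // Split the JSON form of the schema into fixed-size parts, like createTable does.
    val threshold = 32
    val parts = schema.json.grouped(threshold).toSeq
    val props = Map("numParts" -> parts.size.toString) ++
      parts.zipWithIndex.map { case (part, index) => s"part.$index" -> part }

    // Stitch the parts back together, like getSchemaFromTableProperties does.
    val restored = DataType.fromJson(
      (0 until props("numParts").toInt).map(i => props(s"part.$i")).mkString)
      .asInstanceOf[StructType]
    assert(restored == schema)
  }
}
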
@@ -200,22 +366,79 @@ private[spark] class HiveExternalCatalog(client: HiveClient, hadoopConf: Configu
* Alter a table whose name that matches the one specified in `tableDefinition`,
* assuming the table exists.
*
- * Note: As of now, this only supports altering table properties, serde properties,
- * and num buckets!
+ * Note: As of now, this doesn't support altering the table schema, partition column names or
+ * bucket specification. They are ignored even if the caller specifies different values for them.
*/
override def alterTable(tableDefinition: CatalogTable): Unit = withClient {
assert(tableDefinition.identifier.database.isDefined)
val db = tableDefinition.identifier.database.get
requireTableExists(db, tableDefinition.identifier.table)
- client.alterTable(tableDefinition)
+ verifyTableProperties(tableDefinition)
+
+ if (DDLUtils.isDatasourceTable(tableDefinition)) {
+ val oldDef = client.getTable(db, tableDefinition.identifier.table)
+ // Take the `schema`, `partitionColumnNames` and `bucketSpec` from the old table definition,
+ // to retain the Spark-specific format if the old table used it. Also add the old data source
+ // properties to the table properties, to retain the data source table format.
+ val oldDataSourceProps = oldDef.properties.filter(_._1.startsWith(DATASOURCE_PREFIX))
+ val newDef = tableDefinition.copy(
+ schema = oldDef.schema,
+ partitionColumnNames = oldDef.partitionColumnNames,
+ bucketSpec = oldDef.bucketSpec,
+ properties = oldDataSourceProps ++ tableDefinition.properties)
+
+ client.alterTable(newDef)
+ } else {
+ client.alterTable(tableDefinition)
+ }
}
override def getTable(db: String, table: String): CatalogTable = withClient {
- client.getTable(db, table)
+ restoreTableMetadata(client.getTable(db, table))
}
override def getTableOption(db: String, table: String): Option[CatalogTable] = withClient {
- client.getTableOption(db, table)
+ client.getTableOption(db, table).map(restoreTableMetadata)
+ }
+
+ /**
+ * Restores table metadata from the table properties if it's a datasource table. This method is
+ * roughly the inverse of [[createTable]].
+ *
+ * It reads table schema, provider, partition column names and bucket specification from table
+ * properties, and filters out these special entries from the table properties.
+ */
+ private def restoreTableMetadata(table: CatalogTable): CatalogTable = {
+ if (table.tableType == VIEW) {
+ table
+ } else {
+ getProviderFromTableProperties(table).map { provider =>
+ assert(provider != "hive", "Hive serde table should not save provider in table properties.")
+ // SPARK-15269: Persisted data source tables always store the location URI as a storage
+ // property named "path" instead of standard Hive `dataLocation`, because Hive only
+ // allows directory paths as location URIs while Spark SQL data source tables also
+ // allow file paths. So the standard Hive `dataLocation` is meaningless for Spark SQL
+ // data source tables.
+ // Spark SQL may also save an external data source table in a Hive compatible format when
+ // possible, so that these tables can be directly accessed by Hive. For these tables,
+ // `dataLocation` is still necessary. Here we also check for input format because only
+ // these Hive compatible tables set this field.
+ val storage = if (table.tableType == EXTERNAL && table.storage.inputFormat.isEmpty) {
+ table.storage.copy(locationUri = None)
+ } else {
+ table.storage
+ }
+ table.copy(
+ storage = storage,
+ schema = getSchemaFromTableProperties(table),
+ provider = Some(provider),
+ partitionColumnNames = getPartitionColumnsFromTableProperties(table),
+ bucketSpec = getBucketSpecFromTableProperties(table),
+ properties = getOriginalTableProperties(table))
+ } getOrElse {
+ table.copy(provider = Some("hive"))
+ }
+ }
}
override def tableExists(db: String, table: String): Boolean = withClient {
@@ -363,3 +586,82 @@ private[spark] class HiveExternalCatalog(client: HiveClient, hadoopConf: Configu
}
}
+
+object HiveExternalCatalog {
+ val DATASOURCE_PREFIX = "spark.sql.sources."
+ val DATASOURCE_PROVIDER = DATASOURCE_PREFIX + "provider"
+ val DATASOURCE_SCHEMA = DATASOURCE_PREFIX + "schema"
+ val DATASOURCE_SCHEMA_PREFIX = DATASOURCE_SCHEMA + "."
+ val DATASOURCE_SCHEMA_NUMPARTS = DATASOURCE_SCHEMA_PREFIX + "numParts"
+ val DATASOURCE_SCHEMA_NUMPARTCOLS = DATASOURCE_SCHEMA_PREFIX + "numPartCols"
+ val DATASOURCE_SCHEMA_NUMSORTCOLS = DATASOURCE_SCHEMA_PREFIX + "numSortCols"
+ val DATASOURCE_SCHEMA_NUMBUCKETS = DATASOURCE_SCHEMA_PREFIX + "numBuckets"
+ val DATASOURCE_SCHEMA_NUMBUCKETCOLS = DATASOURCE_SCHEMA_PREFIX + "numBucketCols"
+ val DATASOURCE_SCHEMA_PART_PREFIX = DATASOURCE_SCHEMA_PREFIX + "part."
+ val DATASOURCE_SCHEMA_PARTCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "partCol."
+ val DATASOURCE_SCHEMA_BUCKETCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "bucketCol."
+ val DATASOURCE_SCHEMA_SORTCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "sortCol."
+
+ def getProviderFromTableProperties(metadata: CatalogTable): Option[String] = {
+ metadata.properties.get(DATASOURCE_PROVIDER)
+ }
+
+ def getOriginalTableProperties(metadata: CatalogTable): Map[String, String] = {
+ metadata.properties.filterNot { case (key, _) => key.startsWith(DATASOURCE_PREFIX) }
+ }
+
+ // A persisted data source table always stores its schema in the catalog.
+ def getSchemaFromTableProperties(metadata: CatalogTable): StructType = {
+ val errorMessage = "Could not read schema from the hive metastore because it is corrupted."
+ val props = metadata.properties
+ props.get(DATASOURCE_SCHEMA).map { schema =>
+ // Originally, we used `spark.sql.sources.schema` to store the schema of a data source table.
+ // After SPARK-6024, we removed this flag.
+ // Although we are not using `spark.sql.sources.schema` any more, we still need to support reading it.
+ DataType.fromJson(schema).asInstanceOf[StructType]
+ } getOrElse {
+ props.get(DATASOURCE_SCHEMA_NUMPARTS).map { numParts =>
+ val parts = (0 until numParts.toInt).map { index =>
+ val part = metadata.properties.get(s"$DATASOURCE_SCHEMA_PART_PREFIX$index").orNull
+ if (part == null) {
+ throw new AnalysisException(errorMessage +
+ s" (missing part $index of the schema, $numParts parts are expected).")
+ }
+ part
+ }
+ // Stick all parts back to a single schema string.
+ DataType.fromJson(parts.mkString).asInstanceOf[StructType]
+ } getOrElse {
+ throw new AnalysisException(errorMessage)
+ }
+ }
+ }
+
+ private def getColumnNamesByType(
+ props: Map[String, String],
+ colType: String,
+ typeName: String): Seq[String] = {
+ for {
+ numCols <- props.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").toSeq
+ index <- 0 until numCols.toInt
+ } yield props.getOrElse(
+ s"$DATASOURCE_SCHEMA_PREFIX${colType}Col.$index",
+ throw new AnalysisException(
+ s"Corrupted $typeName in catalog: $numCols parts expected, but part $index is missing."
+ )
+ )
+ }
+
+ def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = {
+ getColumnNamesByType(metadata.properties, "part", "partitioning columns")
+ }
+
+ def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = {
+ metadata.properties.get(DATASOURCE_SCHEMA_NUMBUCKETS).map { numBuckets =>
+ BucketSpec(
+ numBuckets.toInt,
+ getColumnNamesByType(metadata.properties, "bucket", "bucketing columns"),
+ getColumnNamesByType(metadata.properties, "sort", "sorting columns"))
+ }
+ }
+}
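
A hedged worked example of the key layout these constants produce: for a json table partitioned by "a" and bucketed by "b" (sorted by "c") into 4 buckets, with a schema that fits in a single part, createTable would write roughly the following properties (the schema JSON is elided as a placeholder).

object DatasourcePropertiesLayoutSketch {
  def main(args: Array[String]): Unit = {
    val exampleProps = Seq(
      "spark.sql.sources.provider"             -> "json",
      "spark.sql.sources.schema.numParts"      -> "1",
      "spark.sql.sources.schema.part.0"        -> "<schema JSON>",
      "spark.sql.sources.schema.numPartCols"   -> "1",
      "spark.sql.sources.schema.partCol.0"     -> "a",
      "spark.sql.sources.schema.numBuckets"    -> "4",
      "spark.sql.sources.schema.numBucketCols" -> "1",
      "spark.sql.sources.schema.bucketCol.0"   -> "b",
      "spark.sql.sources.schema.numSortCols"   -> "1",
      "spark.sql.sources.schema.sortCol.0"     -> "c")
    exampleProps.foreach { case (key, value) => println(s"$key = $value") }
  }
}
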
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 7118edabb83cf..181f470b2a100 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._
+import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{Partition => _, _}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
import org.apache.spark.sql.hive.orc.OrcFileFormat
@@ -68,64 +68,16 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() {
override def load(in: QualifiedTableName): LogicalPlan = {
logDebug(s"Creating new cached data source for $in")
- val table = client.getTable(in.database, in.name)
+ val table = sparkSession.sharedState.externalCatalog.getTable(in.database, in.name)
- // TODO: the following code is duplicated with FindDataSourceTable.readDataSourceTable
-
- def schemaStringFromParts: Option[String] = {
- table.properties.get(DATASOURCE_SCHEMA_NUMPARTS).map { numParts =>
- val parts = (0 until numParts.toInt).map { index =>
- val part = table.properties.get(s"$DATASOURCE_SCHEMA_PART_PREFIX$index").orNull
- if (part == null) {
- throw new AnalysisException(
- "Could not read schema from the metastore because it is corrupted " +
- s"(missing part $index of the schema, $numParts parts are expected).")
- }
-
- part
- }
- // Stick all parts back to a single schema string.
- parts.mkString
- }
- }
-
- def getColumnNames(colType: String): Seq[String] = {
- table.properties.get(s"$DATASOURCE_SCHEMA.num${colType.capitalize}Cols").map {
- numCols => (0 until numCols.toInt).map { index =>
- table.properties.getOrElse(s"$DATASOURCE_SCHEMA_PREFIX${colType}Col.$index",
- throw new AnalysisException(
- s"Could not read $colType columns from the metastore because it is corrupted " +
- s"(missing part $index of it, $numCols parts are expected)."))
- }
- }.getOrElse(Nil)
- }
-
- // Originally, we used spark.sql.sources.schema to store the schema of a data source table.
- // After SPARK-6024, we removed this flag.
- // Although we are not using spark.sql.sources.schema any more, we need to still support.
- val schemaString = table.properties.get(DATASOURCE_SCHEMA).orElse(schemaStringFromParts)
-
- val userSpecifiedSchema =
- schemaString.map(s => DataType.fromJson(s).asInstanceOf[StructType])
-
- // We only need names at here since userSpecifiedSchema we loaded from the metastore
- // contains partition columns. We can always get data types of partitioning columns
- // from userSpecifiedSchema.
- val partitionColumns = getColumnNames("part")
-
- val bucketSpec = table.properties.get(DATASOURCE_SCHEMA_NUMBUCKETS).map { n =>
- BucketSpec(n.toInt, getColumnNames("bucket"), getColumnNames("sort"))
- }
-
- val options = table.storage.properties
val dataSource =
DataSource(
sparkSession,
- userSpecifiedSchema = userSpecifiedSchema,
- partitionColumns = partitionColumns,
- bucketSpec = bucketSpec,
- className = table.properties(DATASOURCE_PROVIDER),
- options = options)
+ userSpecifiedSchema = Some(table.schema),
+ partitionColumns = table.partitionColumnNames,
+ bucketSpec = table.bucketSpec,
+ className = table.provider.get,
+ options = table.storage.properties)
LogicalRelation(
dataSource.resolveRelation(checkPathExist = true),
@@ -158,9 +110,10 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
tableIdent: TableIdentifier,
alias: Option[String]): LogicalPlan = {
val qualifiedTableName = getQualifiedTableName(tableIdent)
- val table = client.getTable(qualifiedTableName.database, qualifiedTableName.name)
+ val table = sparkSession.sharedState.externalCatalog.getTable(
+ qualifiedTableName.database, qualifiedTableName.name)
- if (table.properties.get(DATASOURCE_PROVIDER).isDefined) {
+ if (DDLUtils.isDatasourceTable(table)) {
val dataSourceTable = cachedDataSourceTables(qualifiedTableName)
val qualifiedTable = SubqueryAlias(qualifiedTableName.name, dataSourceTable, None)
// Then, if alias is specified, wrap the table with a Subquery using the alias.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index f8204e183f03a..9b7afd462841c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -45,7 +45,6 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.execution.QueryExecutionException
-import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.{CircularBuffer, Utils}
@@ -392,20 +391,7 @@ private[hive] class HiveClientImpl(
createTime = h.getTTable.getCreateTime.toLong * 1000,
lastAccessTime = h.getLastAccessTime.toLong * 1000,
storage = CatalogStorageFormat(
- locationUri = shim.getDataLocation(h).filterNot { _ =>
- // SPARK-15269: Persisted data source tables always store the location URI as a SerDe
- // property named "path" instead of standard Hive `dataLocation`, because Hive only
- // allows directory paths as location URIs while Spark SQL data source tables also
- // allows file paths. So the standard Hive `dataLocation` is meaningless for Spark SQL
- // data source tables.
- DDLUtils.isDatasourceTable(properties) &&
- h.getTableType == HiveTableType.EXTERNAL_TABLE &&
- // Spark SQL may also save external data source in Hive compatible format when
- // possible, so that these tables can be directly accessed by Hive. For these tables,
- // `dataLocation` is still necessary. Here we also check for input format class
- // because only these Hive compatible tables set this field.
- h.getInputFormatClass == null
- },
+ locationUri = shim.getDataLocation(h),
inputFormat = Option(h.getInputFormatClass).map(_.getName),
outputFormat = Option(h.getOutputFormatClass).map(_.getName),
serde = Option(h.getSerializationLib),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index c74d948a6fa52..286197b50e229 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -34,7 +34,6 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.hive.{HiveInspectors, HiveShim}
import org.apache.spark.sql.sources.{Filter, _}
@@ -222,7 +221,7 @@ private[orc] class OrcOutputWriter(
private lazy val recordWriter: RecordWriter[NullWritable, Writable] = {
recordWriterInstantiated = true
- val uniqueWriteJobId = conf.get(CreateDataSourceTableUtils.DATASOURCE_WRITEJOBUUID)
+ val uniqueWriteJobId = conf.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
val taskAttemptId = context.getTaskAttemptID
val partition = taskAttemptId.getTaskID.getId
val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 3892fe87e2a80..571ba49d115f8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -26,9 +26,9 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.hive.HiveExternalCatalog._
+import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
@@ -49,6 +49,10 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile
}
+ // To test `HiveExternalCatalog`, we need to read the raw table metadata (schema, partition
+ // columns and bucket specification are still in table properties) from the Hive client.
+ private def hiveClient: HiveClient = sharedState.asInstanceOf[HiveSharedState].metadataHive
+
test("persistent JSON table") {
withTable("jsonTable") {
sql(
@@ -697,18 +701,18 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
withTable("wide_schema") {
withTempDir { tempDir =>
// We will need 80 splits for this schema if the threshold is 4000.
- val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true)))
-
- // Manually create a metastore data source table.
- createDataSourceTable(
- sparkSession = spark,
- tableIdent = TableIdentifier("wide_schema"),
+ val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType)))
+
+ val tableDesc = CatalogTable(
+ identifier = TableIdentifier("wide_schema"),
+ tableType = CatalogTableType.EXTERNAL,
+ storage = CatalogStorageFormat.empty.copy(
+ properties = Map("path" -> tempDir.getCanonicalPath)
+ ),
schema = schema,
- partitionColumns = Array.empty[String],
- bucketSpec = None,
- provider = "json",
- options = Map("path" -> tempDir.getCanonicalPath),
- isExternal = false)
+ provider = Some("json")
+ )
+ spark.sessionState.catalog.createTable(tableDesc, ignoreIfExists = false)
sessionState.refreshTable("wide_schema")
@@ -741,7 +745,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
DATASOURCE_SCHEMA -> schema.json,
"EXTERNAL" -> "FALSE"))
- sharedState.externalCatalog.createTable(hiveTable, ignoreIfExists = false)
+ hiveClient.createTable(hiveTable, ignoreIfExists = false)
sessionState.refreshTable(tableName)
val actualSchema = table(tableName).schema
@@ -759,7 +763,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
withTable(tableName) {
df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName)
sessionState.refreshTable(tableName)
- val metastoreTable = sharedState.externalCatalog.getTable("default", tableName)
+ val metastoreTable = hiveClient.getTable("default", tableName)
val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil)
val numPartCols = metastoreTable.properties(DATASOURCE_SCHEMA_NUMPARTCOLS).toInt
@@ -794,7 +798,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
.sortBy("c")
.saveAsTable(tableName)
sessionState.refreshTable(tableName)
- val metastoreTable = sharedState.externalCatalog.getTable("default", tableName)
+ val metastoreTable = hiveClient.getTable("default", tableName)
val expectedBucketByColumns = StructType(df.schema("d") :: df.schema("b") :: Nil)
val expectedSortByColumns = StructType(df.schema("c") :: Nil)
@@ -985,35 +989,37 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
withTempDir { tempPath =>
val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType)))
- createDataSourceTable(
- sparkSession = spark,
- tableIdent = TableIdentifier("not_skip_hive_metadata"),
+ val tableDesc1 = CatalogTable(
+ identifier = TableIdentifier("not_skip_hive_metadata"),
+ tableType = CatalogTableType.EXTERNAL,
+ storage = CatalogStorageFormat.empty.copy(
+ properties = Map("path" -> tempPath.getCanonicalPath, "skipHiveMetadata" -> "false")
+ ),
schema = schema,
- partitionColumns = Array.empty[String],
- bucketSpec = None,
- provider = "parquet",
- options = Map("path" -> tempPath.getCanonicalPath, "skipHiveMetadata" -> "false"),
- isExternal = false)
+ provider = Some("parquet")
+ )
+ spark.sessionState.catalog.createTable(tableDesc1, ignoreIfExists = false)
// As a proxy for verifying that the table was stored in Hive compatible format,
// we verify that each column of the table is of native type StringType.
- assert(sharedState.externalCatalog.getTable("default", "not_skip_hive_metadata").schema
+ assert(hiveClient.getTable("default", "not_skip_hive_metadata").schema
.forall(_.dataType == StringType))
- createDataSourceTable(
- sparkSession = spark,
- tableIdent = TableIdentifier("skip_hive_metadata"),
+ val tableDesc2 = CatalogTable(
+ identifier = TableIdentifier("skip_hive_metadata", Some("default")),
+ tableType = CatalogTableType.EXTERNAL,
+ storage = CatalogStorageFormat.empty.copy(
+ properties = Map("path" -> tempPath.getCanonicalPath, "skipHiveMetadata" -> "true")
+ ),
schema = schema,
- partitionColumns = Array.empty[String],
- bucketSpec = None,
- provider = "parquet",
- options = Map("path" -> tempPath.getCanonicalPath, "skipHiveMetadata" -> "true"),
- isExternal = false)
+ provider = Some("parquet")
+ )
+ spark.sessionState.catalog.createTable(tableDesc2, ignoreIfExists = false)
// As a proxy for verifying that the table was stored in SparkSQL format,
// we verify that the table has a column type as array of StringType.
- assert(sharedState.externalCatalog.getTable("default", "skip_hive_metadata")
- .schema.forall(_.dataType == ArrayType(StringType)))
+ assert(hiveClient.getTable("default", "skip_hive_metadata").schema
+ .forall(_.dataType == ArrayType(StringType)))
}
}
@@ -1030,7 +1036,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
""".stripMargin
)
- val metastoreTable = sharedState.externalCatalog.getTable("default", "t")
+ val metastoreTable = hiveClient.getTable("default", "t")
assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMPARTCOLS).toInt === 1)
assert(!metastoreTable.properties.contains(DATASOURCE_SCHEMA_NUMBUCKETS))
assert(!metastoreTable.properties.contains(DATASOURCE_SCHEMA_NUMBUCKETCOLS))
@@ -1054,7 +1060,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
""".stripMargin
)
- val metastoreTable = sharedState.externalCatalog.getTable("default", "t")
+ val metastoreTable = hiveClient.getTable("default", "t")
assert(!metastoreTable.properties.contains(DATASOURCE_SCHEMA_NUMPARTCOLS))
assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETS).toInt === 2)
assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETCOLS).toInt === 1)
@@ -1076,7 +1082,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
""".stripMargin
)
- val metastoreTable = sharedState.externalCatalog.getTable("default", "t")
+ val metastoreTable = hiveClient.getTable("default", "t")
assert(!metastoreTable.properties.contains(DATASOURCE_SCHEMA_NUMPARTCOLS))
assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETS).toInt === 2)
assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETCOLS).toInt === 1)
@@ -1101,7 +1107,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
""".stripMargin
)
- val metastoreTable = sharedState.externalCatalog.getTable("default", "t")
+ val metastoreTable = hiveClient.getTable("default", "t")
assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMPARTCOLS).toInt === 1)
assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETS).toInt === 2)
assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETCOLS).toInt === 1)
@@ -1168,7 +1174,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
)
sql("insert into t values (2, 3, 4)")
checkAnswer(table("t"), Seq(Row(1, 2, 3), Row(2, 3, 4)))
- val catalogTable = sharedState.externalCatalog.getTable("default", "t")
+ val catalogTable = hiveClient.getTable("default", "t")
// there should not be a lowercase key 'path' now
assert(catalogTable.storage.properties.get("path").isEmpty)
assert(catalogTable.storage.properties.get("PATH").isDefined)
@@ -1188,4 +1194,28 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
}
}
}
+
+ test("read table with corrupted schema") {
+ try {
+ val schema = StructType(StructField("int", IntegerType, true) :: Nil)
+ val hiveTable = CatalogTable(
+ identifier = TableIdentifier("t", Some("default")),
+ tableType = CatalogTableType.MANAGED,
+ schema = new StructType,
+ storage = CatalogStorageFormat.empty,
+ properties = Map(
+ DATASOURCE_PROVIDER -> "json",
+ // no DATASOURCE_SCHEMA_NUMPARTS
+ DATASOURCE_SCHEMA_PART_PREFIX + 0 -> schema.json))
+
+ hiveClient.createTable(hiveTable, ignoreIfExists = false)
+
+ val e = intercept[AnalysisException] {
+ sharedState.externalCatalog.getTable("default", "t")
+ }.getMessage
+ assert(e.contains(s"Could not read schema from the hive metastore because it is corrupted"))
+ } finally {
+ hiveClient.dropTable("default", "t", ignoreIfNotExists = true, purge = true)
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
index 5d510197c4d95..76aa84b19410d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
@@ -18,21 +18,32 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._
+import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.StructType
class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
protected override def beforeAll(): Unit = {
super.beforeAll()
- sql(
- """
- |CREATE TABLE parquet_tab1 (c1 INT, c2 STRING)
- |USING org.apache.spark.sql.parquet.DefaultSource
- """.stripMargin)
+
+ // Use the catalog to create the table instead of a SQL string here, because we don't support
+ // specifying table properties for a data source table with the SQL API yet.
+ hiveContext.sessionState.catalog.createTable(
+ CatalogTable(
+ identifier = TableIdentifier("parquet_tab1"),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType().add("c1", "int").add("c2", "string"),
+ provider = Some("parquet"),
+ properties = Map("my_key1" -> "v1")
+ ),
+ ignoreIfExists = false
+ )
sql(
"""
@@ -101,23 +112,14 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
test("show tblproperties of data source tables - basic") {
checkAnswer(
- sql("SHOW TBLPROPERTIES parquet_tab1").filter(s"key = '$DATASOURCE_PROVIDER'"),
- Row(DATASOURCE_PROVIDER, "org.apache.spark.sql.parquet.DefaultSource") :: Nil
+ sql("SHOW TBLPROPERTIES parquet_tab1").filter(s"key = 'my_key1'"),
+ Row("my_key1", "v1") :: Nil
)
checkAnswer(
- sql(s"SHOW TBLPROPERTIES parquet_tab1($DATASOURCE_PROVIDER)"),
- Row("org.apache.spark.sql.parquet.DefaultSource") :: Nil
+ sql(s"SHOW TBLPROPERTIES parquet_tab1('my_key1')"),
+ Row("v1") :: Nil
)
-
- checkAnswer(
- sql("SHOW TBLPROPERTIES parquet_tab1").filter(s"key = '$DATASOURCE_SCHEMA_NUMPARTS'"),
- Row(DATASOURCE_SCHEMA_NUMPARTS, "1") :: Nil
- )
-
- checkAnswer(
- sql(s"SHOW TBLPROPERTIES parquet_tab1('$DATASOURCE_SCHEMA_NUMPARTS')"),
- Row("1"))
}
test("show tblproperties for datasource table - errors") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 970b6885f6254..f00a99b6d0b3d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -692,4 +692,27 @@ class HiveDDLSuite
))
}
}
+
+ test("datasource table property keys are not allowed") {
+ import org.apache.spark.sql.hive.HiveExternalCatalog.DATASOURCE_PREFIX
+
+ withTable("tbl") {
+ sql("CREATE TABLE tbl(a INT) STORED AS parquet")
+
+ val e = intercept[AnalysisException] {
+ sql(s"ALTER TABLE tbl SET TBLPROPERTIES ('${DATASOURCE_PREFIX}foo' = 'loser')")
+ }
+ assert(e.getMessage.contains(DATASOURCE_PREFIX + "foo"))
+
+ val e2 = intercept[AnalysisException] {
+ sql(s"ALTER TABLE tbl UNSET TBLPROPERTIES ('${DATASOURCE_PREFIX}foo')")
+ }
+ assert(e2.getMessage.contains(DATASOURCE_PREFIX + "foo"))
+
+ val e3 = intercept[AnalysisException] {
+ sql(s"CREATE TABLE tbl TBLPROPERTIES ('${DATASOURCE_PREFIX}foo'='anything')")
+ }
+ assert(e3.getMessage.contains(DATASOURCE_PREFIX + "foo"))
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index e6fe47aa65f34..4ca882f840a58 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation}
@@ -436,8 +435,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
assert(r.options("path") === location)
case None => // OK.
}
- assert(
- catalogTable.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER) === format)
+ assert(catalogTable.provider.get === format)
case r: MetastoreRelation =>
if (isDataSourceParquet) {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 67a58a3859b84..906de6bbcbee5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.{sources, Row, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.SerializableConfiguration
@@ -145,7 +144,7 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
val configuration = context.getConfiguration
- val uniqueWriteJobId = configuration.get(DATASOURCE_WRITEJOBUUID)
+ val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
val taskAttemptId = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
val name = FileOutputFormat.getOutputName(context)
From 083de00cb608a7414aae99a639825482bebfea8a Mon Sep 17 00:00:00 2001
From: Richael
Date: Mon, 22 Aug 2016 09:01:50 +0100
Subject: [PATCH 048/270] [SPARK-17127] Make unaligned access in unsafe
available for AArch64
## What changes were proposed in this pull request?
Since Spark 2.0.0, when MemoryMode.OFF_HEAP is set, Spark checks whether the architecture supports unaligned memory access; if the check fails, an exception is raised.
AArch64 also supports unaligned access, but currently only i386, x86, amd64, and x86_64 are included in the check.
This PR adds aarch64 to the list of architectures that pass the check.
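For illustration, a minimal sketch of the check described above, written in Scala for readability (the actual implementation lives in `Platform.java`; the object name and error message below are assumptions):

```scala
object UnalignedCheckSketch {
  // Architectures known to support unaligned memory access, now including aarch64.
  private val SupportedArch = "^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"

  def main(args: Array[String]): Unit = {
    val arch = System.getProperty("os.arch", "")
    // Mirrors the behaviour described above: fail when the architecture is not on the list.
    require(arch.matches(SupportedArch),
      s"Unaligned memory access is not known to be supported on '$arch'")
    println(s"Unaligned access supported on $arch")
  }
}
```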
## How was this patch tested?
Unit test suite
Author: Richael
Closes #14700 from yimuxi/zym_change_unsafe.
---
.../unsafe/src/main/java/org/apache/spark/unsafe/Platform.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index a2ee45c37e2b3..c892b9cdaf49c 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -55,7 +55,7 @@ public final class Platform {
// We at least know x86 and x64 support unaligned access.
String arch = System.getProperty("os.arch", "");
//noinspection DynamicRegexReplaceableByCompiledPattern
- _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64)$");
+ _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$");
}
unaligned = _unaligned;
}
From 4b6c2cbcb109c7cef6087bae32d87cc3ddb69cf9 Mon Sep 17 00:00:00 2001
From: GraceH
Date: Mon, 22 Aug 2016 09:03:46 +0100
Subject: [PATCH 049/270] [SPARK-16968] Document additional options in jdbc
Writer
## What changes were proposed in this pull request?
This adds documentation for the JDBC writer options introduced in a previous PR.
## How was this patch tested?
A unit test was added in the previous PR.
Author: GraceH
Closes #14683 from GraceH/jdbc_options.
---
docs/sql-programming-guide.md | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index c89286d0e49d0..28cc88c322b7e 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1058,6 +1058,20 @@ the Data Sources API. The following options are supported:
The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows).
    </td>
+
+  <tr>
+    <td><code>truncate</code></td>
+    <td>
+     This is a JDBC writer related option. When <code>SaveMode.Overwrite</code> is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g. indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to <code>false</code>.
+   </td>
+  </tr>
+
+  <tr>
+    <td><code>createTableOptions</code></td>
+    <td>
+     This is a JDBC writer related option. If specified, this option allows setting of database-specific table and partition options when creating a table. For example: <code>CREATE TABLE t (name string) ENGINE=InnoDB.</code>
+   </td>
+  </tr>
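For context, a sketch of how these writer options would be passed through the DataFrame writer API; the JDBC URL, table name, and credentials below are placeholder assumptions:

```scala
import java.util.Properties
import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().appName("jdbc-writer-options").getOrCreate()
val df = spark.range(10).toDF("id")

val props = new Properties()
props.setProperty("user", "test")        // placeholder credentials
props.setProperty("password", "secret")

df.write
  .mode(SaveMode.Overwrite)
  // Truncate the existing table instead of dropping and recreating it.
  .option("truncate", "true")
  // Database-specific clause appended when the table has to be created.
  .option("createTableOptions", "ENGINE=InnoDB")
  .jdbc("jdbc:mysql://localhost:3306/testdb", "t", props)
```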
From 8d35a6f68d6d733212674491cbf31bed73fada0f Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 22 Aug 2016 16:16:03 +0800
Subject: [PATCH 050/270] [SPARK-17115][SQL] decrease the threshold when splitting
expressions
## What changes were proposed in this pull request?
In 2.0, we changed the threshold for splitting expressions from 16K to 64K, which causes very bad performance on wide tables because the generated methods can't be JIT compiled by default (they exceed the 8K bytecode limit).
This PR decreases the threshold to 1K, based on benchmark results for a wide table with 400 columns of LongType.
It also fixes a bug around splitting expressions in whole-stage codegen (they should not be split there).
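As a rough, standalone illustration of the splitting strategy (not the actual `CodegenContext` code): snippets of generated source are grouped into blocks whose length stays near the threshold, and each block would become its own method.

```scala
// Sketch: group generated code snippets into roughly threshold-sized blocks so that
// each emitted method stays small enough to be JIT compiled.
def splitIntoBlocks(snippets: Seq[String], threshold: Int = 1024): Seq[String] = {
  val blocks = scala.collection.mutable.ArrayBuffer.empty[String]
  val current = new StringBuilder
  for (code <- snippets) {
    if (current.length > threshold) {
      blocks += current.toString()
      current.clear()
    }
    current.append(code).append("\n")
  }
  blocks += current.toString()
  blocks.toSeq
}
```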
## How was this patch tested?
Added benchmark suite.
Author: Davies Liu
Closes #14692 from davies/split_exprs.
---
.../expressions/codegen/CodeGenerator.scala | 9 ++--
.../aggregate/HashAggregateExec.scala | 2 -
.../benchmark/BenchmarkWideTable.scala | 53 +++++++++++++++++++
3 files changed, 59 insertions(+), 5 deletions(-)
create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 16fb1f683710f..4bd9ee03f96dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -584,15 +584,18 @@ class CodegenContext {
* @param expressions the codes to evaluate expressions.
*/
def splitExpressions(row: String, expressions: Seq[String]): String = {
- if (row == null) {
+ if (row == null || currentVars != null) {
// Cannot split these expressions because they are not created from a row object.
return expressions.mkString("\n")
}
val blocks = new ArrayBuffer[String]()
val blockBuilder = new StringBuilder()
for (code <- expressions) {
- // We can't know how many byte code will be generated, so use the number of bytes as limit
- if (blockBuilder.length > 64 * 1000) {
+      // We can't know how much bytecode will be generated, so use the length of the source
+      // code as a metric. A method should not go beyond 8K bytecode, otherwise it will not be
+      // JITted; it should also not be too small, or there will be many function calls (for wide
+      // tables), see the results in BenchmarkWideTable.
+ if (blockBuilder.length > 1024) {
blocks.append(blockBuilder.toString())
blockBuilder.clear()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index cfc47aba889aa..bd7efa606e0ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -603,8 +603,6 @@ case class HashAggregateExec(
// create grouping key
ctx.currentVars = input
- // make sure that the generated code will not be splitted as multiple functions
- ctx.INPUT_ROW = null
val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
val vectorizedRowKeys = ctx.generateExpressions(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala
new file mode 100644
index 0000000000000..9dcaca0ca93ee
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.util.Benchmark
+
+
+/**
+ * Benchmark to measure performance for wide table.
+ * To run this:
+ * build/sbt "sql/test-only *benchmark.BenchmarkWideTable"
+ *
+ * Benchmarks in this file are skipped in normal builds.
+ */
+class BenchmarkWideTable extends BenchmarkBase {
+
+ ignore("project on wide table") {
+ val N = 1 << 20
+ val df = sparkSession.range(N)
+ val columns = (0 until 400).map{ i => s"id as id$i"}
+ val benchmark = new Benchmark("projection on wide table", N)
+ benchmark.addCase("wide table", numIters = 5) { iter =>
+ df.selectExpr(columns : _*).queryExecution.toRdd.count()
+ }
+ benchmark.run()
+
+ /**
+ * Here are some numbers with different split threshold:
+ *
+ * Split threshold methods Rate(M/s) Per Row(ns)
+ * 10 400 0.4 2279
+ * 100 200 0.6 1554
+ * 1k 37 0.9 1116
+ * 8k 5 0.5 2025
+ * 64k 1 0.0 21649
+ */
+ }
+}
From bd9655063bdba8836b4ec96ed115e5653e246b65 Mon Sep 17 00:00:00 2001
From: Jagadeesan
Date: Mon, 22 Aug 2016 09:30:31 +0100
Subject: [PATCH 051/270] [SPARK-17085][STREAMING][DOCUMENTATION AND ACTUAL
CODE DIFFERS - UNSUPPORTED OPERATIONS]
Changes in the Spark Structured Streaming doc at this link:
https://spark.apache.org/docs/2.0.0/structured-streaming-programming-guide.html#unsupported-operations
Author: Jagadeesan
Closes #14715 from jagadeesanas2/SPARK-17085.
---
docs/structured-streaming-programming-guide.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index e2c881bf4a604..226ff740a5d67 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -726,9 +726,9 @@ However, note that all of the operations applicable on static DataFrames/Dataset
+ Full outer join with a streaming Dataset is not supported
- + Left outer join with a streaming Dataset on the left is not supported
+ + Left outer join with a streaming Dataset on the right is not supported
- + Right outer join with a streaming Dataset on the right is not supported
+ + Right outer join with a streaming Dataset on the left is not supported
- Any kind of joins between two streaming Datasets are not yet supported.
From b264cbb16fb97116e630fb593adf5898a5a0e8fa Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Mon, 22 Aug 2016 12:21:22 +0200
Subject: [PATCH 052/270] [SPARK-15113][PYSPARK][ML] Add missing num features
num classes
## What changes were proposed in this pull request?
Add missing `numFeatures` and `numClasses` to the wrapped Java models in PySpark ML pipelines. Also tag `DecisionTreeClassificationModel` as Experimental to match the Scala doc.
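On the Scala side, the `numFeatures` override added to `GeneralizedLinearRegressionModel` could be exercised roughly as follows (a sketch with made-up toy data):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("glr-numFeatures-sketch").getOrCreate()
import spark.implicits._

// Toy training data with two features per row (values are made up for illustration).
val training = Seq(
  (1.0, Vectors.dense(0.0, 1.0)),
  (2.0, Vectors.dense(1.0, 0.5)),
  (3.0, Vectors.dense(2.0, 2.0)),
  (4.0, Vectors.dense(3.0, 1.5))
).toDF("label", "features")

val model = new GeneralizedLinearRegression().setMaxIter(5).fit(training)
// numFeatures is now derived from the size of the coefficient vector.
println(model.numFeatures)  // expected: 2
```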
## How was this patch tested?
Extended doctests
Author: Holden Karau
Closes #12889 from holdenk/SPARK-15113-add-missing-numFeatures-numClasses.
---
.../GeneralizedLinearRegression.scala | 2 +
python/pyspark/ml/classification.py | 37 ++++++++++++++++---
python/pyspark/ml/regression.py | 22 ++++++++---
python/pyspark/ml/util.py | 16 ++++++++
4 files changed, 66 insertions(+), 11 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 2bdc09e1db246..1d4dfd1147589 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -788,6 +788,8 @@ class GeneralizedLinearRegressionModel private[ml] (
@Since("2.0.0")
override def write: MLWriter =
new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)
+
+ override val numFeatures: Int = coefficients.size
}
@Since("2.0.0")
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 6468007045691..33ada27454b72 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -43,6 +43,23 @@
'OneVsRest', 'OneVsRestModel']
+@inherit_doc
+class JavaClassificationModel(JavaPredictionModel):
+ """
+ (Private) Java Model produced by a ``Classifier``.
+ Classes are indexed {0, 1, ..., numClasses - 1}.
+ To be mixed in with class:`pyspark.ml.JavaModel`
+ """
+
+ @property
+ @since("2.1.0")
+ def numClasses(self):
+ """
+ Number of classes (values which the label can take).
+ """
+ return self._call_java("numClasses")
+
+
@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
@@ -212,7 +229,7 @@ def _checkThresholdConsistency(self):
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
-class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LogisticRegression.
@@ -522,6 +539,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
+ >>> model.numFeatures
+ 1
+ >>> model.numClasses
+ 2
>>> print(model.toDebugString)
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -595,7 +616,8 @@ def _create_model(self, java_model):
@inherit_doc
-class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
+class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable,
+ JavaMLReadable):
"""
Model fitted by DecisionTreeClassifier.
@@ -722,7 +744,8 @@ def _create_model(self, java_model):
return RandomForestClassificationModel(java_model)
-class RandomForestClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
+class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
+ JavaMLReadable):
"""
Model fitted by RandomForestClassifier.
@@ -873,7 +896,8 @@ def getLossType(self):
return self.getOrDefault(self.lossType)
-class GBTClassificationModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
+class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
+ JavaMLReadable):
"""
Model fitted by GBTClassifier.
@@ -1027,7 +1051,7 @@ def getModelType(self):
return self.getOrDefault(self.modelType)
-class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by NaiveBayes.
@@ -1226,7 +1250,8 @@ def getInitialWeights(self):
return self.getOrDefault(self.initialWeights)
-class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel, JavaMLWritable,
+ JavaMLReadable):
"""
.. note:: Experimental
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 1ae2bd4e400e8..56312f672f71d 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -88,6 +88,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
True
>>> model.intercept == model2.intercept
True
+ >>> model.numFeatures
+ 1
.. versionadded:: 1.4.0
"""
@@ -126,7 +128,7 @@ def _create_model(self, java_model):
return LinearRegressionModel(java_model)
-class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`LinearRegression`.
@@ -654,6 +656,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
3
>>> model.featureImportances
SparseVector(1, {0: 1.0})
+ >>> model.numFeatures
+ 1
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -719,7 +723,7 @@ def _create_model(self, java_model):
@inherit_doc
-class DecisionTreeModel(JavaModel):
+class DecisionTreeModel(JavaModel, JavaPredictionModel):
"""
Abstraction for Decision Tree models.
@@ -843,6 +847,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
+ >>> model.numFeatures
+ 1
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
>>> model.getNumTrees
@@ -909,7 +915,8 @@ def _create_model(self, java_model):
return RandomForestRegressionModel(java_model)
-class RandomForestRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
+class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
+ JavaMLReadable):
"""
Model fitted by :class:`RandomForestRegressor`.
@@ -958,6 +965,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
>>> model = gbt.fit(df)
>>> model.featureImportances
SparseVector(1, {0: 1.0})
+ >>> model.numFeatures
+ 1
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -1047,7 +1056,7 @@ def getLossType(self):
return self.getOrDefault(self.lossType)
-class GBTRegressionModel(TreeEnsembleModel, JavaMLWritable, JavaMLReadable):
+class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`GBTRegressor`.
@@ -1307,6 +1316,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
True
>>> model.coefficients
DenseVector([1.5..., -1.0...])
+ >>> model.numFeatures
+ 2
>>> abs(model.intercept - 1.5) < 0.001
True
>>> glr_path = temp_path + "/glr"
@@ -1412,7 +1423,8 @@ def getLink(self):
return self.getOrDefault(self.link)
-class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
+ JavaMLReadable):
"""
.. note:: Experimental
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 4a31a298096fc..7d39c30122350 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -238,3 +238,19 @@ class JavaMLReadable(MLReadable):
def read(cls):
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)
+
+
+@inherit_doc
+class JavaPredictionModel():
+ """
+ (Private) Java Model for prediction tasks (regression and classification).
+ To be mixed in with class:`pyspark.ml.JavaModel`
+ """
+
+ @property
+ @since("2.1.0")
+ def numFeatures(self):
+ """
+ Returns the number of features the model was trained on. If unknown, returns -1
+ """
+ return self._call_java("numFeatures")
From 209e1b3c0683a9106428e269e5041980b6cc327f Mon Sep 17 00:00:00 2001
From: Junyang Qian
Date: Mon, 22 Aug 2016 10:03:48 -0700
Subject: [PATCH 053/270] [SPARKR][MINOR] Fix Cache Folder Path in Windows
## What changes were proposed in this pull request?
This PR fixes the path of the local cache folder on Windows. The name of the environment variable should be `LOCALAPPDATA` rather than `%LOCALAPPDATA%`.
## How was this patch tested?
Manual test in Windows 7.
Author: Junyang Qian
Closes #14743 from junyangq/SPARKR-FixWindowsInstall.
---
R/pkg/R/install.R | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R
index 987bac7bebc0e..ff81e86835ff8 100644
--- a/R/pkg/R/install.R
+++ b/R/pkg/R/install.R
@@ -212,7 +212,7 @@ hadoop_version_name <- function(hadoopVersion) {
# adapt to Spark context
spark_cache_path <- function() {
if (.Platform$OS.type == "windows") {
- winAppPath <- Sys.getenv("%LOCALAPPDATA%", unset = NA)
+ winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA)
if (is.na(winAppPath)) {
msg <- paste("%LOCALAPPDATA% not found.",
"Please define the environment variable",
From 342278c09cf6e79ed4f63422988a6bbd1e7d8a91 Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Mon, 22 Aug 2016 11:15:53 -0700
Subject: [PATCH 054/270] [SPARK-16320][DOC] Document G1 heap region's effect
on Spark 2.0 vs 1.6
## What changes were proposed in this pull request?
Collect the GC discussion in one section, and document findings about the G1 GC heap region size.
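For example, the GC flags discussed in the updated section are passed to executors through `spark.executor.extraJavaOptions`; a minimal sketch (the region size value is an illustrative assumption, not a recommendation):

```scala
import org.apache.spark.sql.SparkSession

// Try G1GC on executors and enlarge the G1 region size for a large executor heap.
val spark = SparkSession.builder()
  .appName("g1gc-tuning-sketch")
  .config("spark.executor.extraJavaOptions", "-XX:+UseG1GC -XX:G1HeapRegionSize=16m")
  .getOrCreate()
```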
## How was this patch tested?
Jekyll doc build
Author: Sean Owen
Closes #14732 from srowen/SPARK-16320.
---
docs/tuning.md | 36 +++++++++++++++++-------------------
1 file changed, 17 insertions(+), 19 deletions(-)
diff --git a/docs/tuning.md b/docs/tuning.md
index 976f2eb8a7b23..cbf37213aa724 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -122,21 +122,8 @@ large records.
`R` is the storage space within `M` where cached blocks immune to being evicted by execution.
The value of `spark.memory.fraction` should be set in order to fit this amount of heap space
-comfortably within the JVM's old or "tenured" generation. Otherwise, when much of this space is
-used for caching and execution, the tenured generation will be full, which causes the JVM to
-significantly increase time spent in garbage collection. See
-Java GC sizing documentation
-for more information.
-
-The tenured generation size is controlled by the JVM's `NewRatio` parameter, which defaults to 2,
-meaning that the tenured generation is 2 times the size of the new generation (the rest of the heap).
-So, by default, the tenured generation occupies 2/3 or about 0.66 of the heap. A value of
-0.6 for `spark.memory.fraction` keeps storage and execution memory within the old generation with
-room to spare. If `spark.memory.fraction` is increased to, say, 0.8, then `NewRatio` may have to
-increase to 6 or more.
-
-`NewRatio` is set as a JVM flag for executors, which means adding
-`spark.executor.extraJavaOptions=-XX:NewRatio=x` to a Spark job's configuration.
+comfortably within the JVM's old or "tenured" generation. See the discussion of advanced GC
+tuning below for details.
## Determining Memory Consumption
@@ -217,14 +204,22 @@ temporary objects created during task execution. Some steps which may be useful
* Check if there are too many garbage collections by collecting GC stats. If a full GC is invoked multiple times for
before a task completes, it means that there isn't enough memory available for executing tasks.
-* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of
- memory used for caching by lowering `spark.memory.storageFraction`; it is better to cache fewer
- objects than to slow down task execution!
-
* If there are too many minor collections but not many major GCs, allocating more memory for Eden would help. You
can set the size of the Eden to be an over-estimate of how much memory each task will need. If the size of Eden
is determined to be `E`, then you can set the size of the Young generation using the option `-Xmn=4/3*E`. (The scaling
up by 4/3 is to account for space used by survivor regions as well.)
+
+* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of
+ memory used for caching by lowering `spark.memory.fraction`; it is better to cache fewer
+ objects than to slow down task execution. Alternatively, consider decreasing the size of
+ the Young generation. This means lowering `-Xmn` if you've set it as above. If not, try changing the
+ value of the JVM's `NewRatio` parameter. Many JVMs default this to 2, meaning that the Old generation
+ occupies 2/3 of the heap. It should be large enough such that this fraction exceeds `spark.memory.fraction`.
+
+* Try the G1GC garbage collector with `-XX:+UseG1GC`. It can improve performance in some situations where
+ garbage collection is a bottleneck. Note that with large executor heap sizes, it may be important to
+ increase the [G1 region size](https://blogs.oracle.com/g1gc/entry/g1_gc_tuning_a_case)
+ with `-XX:G1HeapRegionSize`
* As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using
the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the
@@ -237,6 +232,9 @@ Our experience suggests that the effect of GC tuning depends on your application
There are [many more tuning options](http://www.oracle.com/technetwork/java/javase/gc-tuning-6-140523.html) described online,
but at a high level, managing how frequently full GC takes place can help in reducing the overhead.
+GC tuning flags for executors can be specified by setting `spark.executor.extraJavaOptions` in
+a job's configuration.
+
# Other Considerations
## Level of Parallelism
From 0583ecda1b63a7e3f126c3276059e4f99548a741 Mon Sep 17 00:00:00 2001
From: Felix Cheung
Date: Mon, 22 Aug 2016 12:27:33 -0700
Subject: [PATCH 055/270] [SPARK-17173][SPARKR] R MLlib refactor, cleanup,
reformat, fix deprecation in test
## What changes were proposed in this pull request?
refactor, cleanup, reformat, fix deprecation in test
## How was this patch tested?
unit tests, manual tests
Author: Felix Cheung
Closes #14735 from felixcheung/rmllibutil.
---
R/pkg/R/mllib.R | 205 +++++++++++--------------
R/pkg/inst/tests/testthat/test_mllib.R | 10 +-
2 files changed, 98 insertions(+), 117 deletions(-)
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 9a53c80aecded..b36fbcee17671 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -88,9 +88,9 @@ setClass("ALSModel", representation(jobj = "jobj"))
#' @rdname write.ml
#' @name write.ml
#' @export
-#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.lda}, \link{spark.naiveBayes}
-#' @seealso \link{spark.survreg}, \link{spark.isoreg}
+#' @seealso \link{spark.glm}, \link{glm},
+#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
+#' @seealso \link{spark.lda}, \link{spark.naiveBayes}, \link{spark.survreg},
#' @seealso \link{read.ml}
NULL
@@ -101,11 +101,22 @@ NULL
#' @rdname predict
#' @name predict
#' @export
-#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
-#' @seealso \link{spark.isoreg}
+#' @seealso \link{spark.glm}, \link{glm},
+#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
+#' @seealso \link{spark.naiveBayes}, \link{spark.survreg},
NULL
+write_internal <- function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+}
+
+predict_internal <- function(object, newData) {
+ dataFrame(callJMethod(object@jobj, "transform", newData@sdf))
+}
#' Generalized Linear Models
#'
@@ -173,7 +184,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
tol, as.integer(maxIter), as.character(weightCol))
- return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+ new("GeneralizedLinearRegressionModel", jobj = jobj)
})
#' Generalized Linear Models (R-compliant)
@@ -219,7 +230,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat
#' @export
#' @note summary(GeneralizedLinearRegressionModel) since 2.0.0
setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
- function(object, ...) {
+ function(object) {
jobj <- object@jobj
is.loaded <- callJMethod(jobj, "isLoaded")
features <- callJMethod(jobj, "rFeatures")
@@ -245,7 +256,7 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
deviance = deviance, df.null = df.null, df.residual = df.residual,
aic = aic, iter = iter, family = family, is.loaded = is.loaded)
class(ans) <- "summary.GeneralizedLinearRegressionModel"
- return(ans)
+ ans
})
# Prints the summary of GeneralizedLinearRegressionModel
@@ -275,8 +286,7 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
" on", format(unlist(x[c("df.null", "df.residual")])), " degrees of freedom\n"),
1L, paste, collapse = " "), sep = "")
cat("AIC: ", format(x$aic, digits = 4L), "\n\n",
- "Number of Fisher Scoring iterations: ", x$iter, "\n", sep = "")
- cat("\n")
+ "Number of Fisher Scoring iterations: ", x$iter, "\n\n", sep = "")
invisible(x)
}
@@ -291,7 +301,7 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
#' @note predict(GeneralizedLinearRegressionModel) since 1.5.0
setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
function(object, newData) {
- return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ predict_internal(object, newData)
})
# Makes predictions from a naive Bayes model or a model produced by spark.naiveBayes(),
@@ -305,7 +315,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
#' @note predict(NaiveBayesModel) since 2.0.0
setMethod("predict", signature(object = "NaiveBayesModel"),
function(object, newData) {
- return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ predict_internal(object, newData)
})
# Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes}
@@ -317,7 +327,7 @@ setMethod("predict", signature(object = "NaiveBayesModel"),
#' @export
#' @note summary(NaiveBayesModel) since 2.0.0
setMethod("summary", signature(object = "NaiveBayesModel"),
- function(object, ...) {
+ function(object) {
jobj <- object@jobj
features <- callJMethod(jobj, "features")
labels <- callJMethod(jobj, "labels")
@@ -328,7 +338,7 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
tables <- matrix(tables, nrow = length(labels))
rownames(tables) <- unlist(labels)
colnames(tables) <- unlist(features)
- return(list(apriori = apriori, tables = tables))
+ list(apriori = apriori, tables = tables)
})
# Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda()
@@ -342,7 +352,7 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
#' @note spark.posterior(LDAModel) since 2.1.0
setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"),
function(object, newData) {
- return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ predict_internal(object, newData)
})
# Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda}
@@ -377,12 +387,11 @@ setMethod("summary", signature(object = "LDAModel"),
vocabSize <- callJMethod(jobj, "vocabSize")
topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic))
vocabulary <- callJMethod(jobj, "vocabulary")
- return(list(docConcentration = unlist(docConcentration),
- topicConcentration = topicConcentration,
- logLikelihood = logLikelihood, logPerplexity = logPerplexity,
- isDistributed = isDistributed, vocabSize = vocabSize,
- topics = topics,
- vocabulary = unlist(vocabulary)))
+ list(docConcentration = unlist(docConcentration),
+ topicConcentration = topicConcentration,
+ logLikelihood = logLikelihood, logPerplexity = logPerplexity,
+ isDistributed = isDistributed, vocabSize = vocabSize,
+ topics = topics, vocabulary = unlist(vocabulary))
})
# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda}
@@ -395,8 +404,8 @@ setMethod("summary", signature(object = "LDAModel"),
#' @note spark.perplexity(LDAModel) since 2.1.0
setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"),
function(object, data) {
- return(ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"),
- callJMethod(object@jobj, "computeLogPerplexity", data@sdf)))
+ ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"),
+ callJMethod(object@jobj, "computeLogPerplexity", data@sdf))
})
# Saves the Latent Dirichlet Allocation model to the input path.
@@ -412,11 +421,7 @@ setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFr
#' @note write.ml(LDAModel, character) since 2.1.0
setMethod("write.ml", signature(object = "LDAModel", path = "character"),
function(object, path, overwrite = FALSE) {
- writer <- callJMethod(object@jobj, "write")
- if (overwrite) {
- writer <- callJMethod(writer, "overwrite")
- }
- invisible(callJMethod(writer, "save", path))
+ write_internal(object, path, overwrite)
})
#' Isotonic Regression Model
@@ -471,9 +476,9 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"
}
jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
- data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
- as.character(weightCol))
- return(new("IsotonicRegressionModel", jobj = jobj))
+ data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
+ as.character(weightCol))
+ new("IsotonicRegressionModel", jobj = jobj)
})
# Predicted values based on an isotonicRegression model
@@ -487,7 +492,7 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"
#' @note predict(IsotonicRegressionModel) since 2.1.0
setMethod("predict", signature(object = "IsotonicRegressionModel"),
function(object, newData) {
- return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ predict_internal(object, newData)
})
# Get the summary of an IsotonicRegressionModel model
@@ -499,11 +504,11 @@ setMethod("predict", signature(object = "IsotonicRegressionModel"),
#' @export
#' @note summary(IsotonicRegressionModel) since 2.1.0
setMethod("summary", signature(object = "IsotonicRegressionModel"),
- function(object, ...) {
+ function(object) {
jobj <- object@jobj
boundaries <- callJMethod(jobj, "boundaries")
predictions <- callJMethod(jobj, "predictions")
- return(list(boundaries = boundaries, predictions = predictions))
+ list(boundaries = boundaries, predictions = predictions)
})
#' K-Means Clustering Model
@@ -553,7 +558,7 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"
initMode <- match.arg(initMode)
jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula,
as.integer(k), as.integer(maxIter), initMode)
- return(new("KMeansModel", jobj = jobj))
+ new("KMeansModel", jobj = jobj)
})
#' Get fitted result from a k-means model
@@ -576,14 +581,14 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"
#'}
#' @note fitted since 2.0.0
setMethod("fitted", signature(object = "KMeansModel"),
- function(object, method = c("centers", "classes"), ...) {
+ function(object, method = c("centers", "classes")) {
method <- match.arg(method)
jobj <- object@jobj
is.loaded <- callJMethod(jobj, "isLoaded")
if (is.loaded) {
- stop(paste("Saved-loaded k-means model does not support 'fitted' method"))
+ stop("Saved-loaded k-means model does not support 'fitted' method")
} else {
- return(dataFrame(callJMethod(jobj, "fitted", method)))
+ dataFrame(callJMethod(jobj, "fitted", method))
}
})
@@ -595,7 +600,7 @@ setMethod("fitted", signature(object = "KMeansModel"),
#' @export
#' @note summary(KMeansModel) since 2.0.0
setMethod("summary", signature(object = "KMeansModel"),
- function(object, ...) {
+ function(object) {
jobj <- object@jobj
is.loaded <- callJMethod(jobj, "isLoaded")
features <- callJMethod(jobj, "features")
@@ -610,8 +615,8 @@ setMethod("summary", signature(object = "KMeansModel"),
} else {
dataFrame(callJMethod(jobj, "cluster"))
}
- return(list(coefficients = coefficients, size = size,
- cluster = cluster, is.loaded = is.loaded))
+ list(coefficients = coefficients, size = size,
+ cluster = cluster, is.loaded = is.loaded)
})
# Predicted values based on a k-means model
@@ -623,7 +628,7 @@ setMethod("summary", signature(object = "KMeansModel"),
#' @note predict(KMeansModel) since 2.0.0
setMethod("predict", signature(object = "KMeansModel"),
function(object, newData) {
- return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ predict_internal(object, newData)
})
#' Naive Bayes Models
@@ -665,11 +670,11 @@ setMethod("predict", signature(object = "KMeansModel"),
#' }
#' @note spark.naiveBayes since 2.0.0
setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula, smoothing = 1.0, ...) {
+ function(data, formula, smoothing = 1.0) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
formula, data@sdf, smoothing)
- return(new("NaiveBayesModel", jobj = jobj))
+ new("NaiveBayesModel", jobj = jobj)
})
# Saves the Bernoulli naive Bayes model to the input path.
@@ -684,11 +689,7 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form
#' @note write.ml(NaiveBayesModel, character) since 2.0.0
setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
function(object, path, overwrite = FALSE) {
- writer <- callJMethod(object@jobj, "write")
- if (overwrite) {
- writer <- callJMethod(writer, "overwrite")
- }
- invisible(callJMethod(writer, "save", path))
+ write_internal(object, path, overwrite)
})
# Saves the AFT survival regression model to the input path.
@@ -702,11 +703,7 @@ setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
#' @seealso \link{read.ml}
setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"),
function(object, path, overwrite = FALSE) {
- writer <- callJMethod(object@jobj, "write")
- if (overwrite) {
- writer <- callJMethod(writer, "overwrite")
- }
- invisible(callJMethod(writer, "save", path))
+ write_internal(object, path, overwrite)
})
# Saves the generalized linear model to the input path.
@@ -720,11 +717,7 @@ setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "c
#' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0
setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"),
function(object, path, overwrite = FALSE) {
- writer <- callJMethod(object@jobj, "write")
- if (overwrite) {
- writer <- callJMethod(writer, "overwrite")
- }
- invisible(callJMethod(writer, "save", path))
+ write_internal(object, path, overwrite)
})
# Save fitted MLlib model to the input path
@@ -738,11 +731,7 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat
#' @note write.ml(KMeansModel, character) since 2.0.0
setMethod("write.ml", signature(object = "KMeansModel", path = "character"),
function(object, path, overwrite = FALSE) {
- writer <- callJMethod(object@jobj, "write")
- if (overwrite) {
- writer <- callJMethod(writer, "overwrite")
- }
- invisible(callJMethod(writer, "save", path))
+ write_internal(object, path, overwrite)
})
# Save fitted IsotonicRegressionModel to the input path
@@ -757,11 +746,7 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"),
#' @note write.ml(IsotonicRegression, character) since 2.1.0
setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"),
function(object, path, overwrite = FALSE) {
- writer <- callJMethod(object@jobj, "write")
- if (overwrite) {
- writer <- callJMethod(writer, "overwrite")
- }
- invisible(callJMethod(writer, "save", path))
+ write_internal(object, path, overwrite)
})
# Save fitted MLlib model to the input path
@@ -776,11 +761,7 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
#' @note write.ml(GaussianMixtureModel, character) since 2.1.0
setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"),
function(object, path, overwrite = FALSE) {
- writer <- callJMethod(object@jobj, "write")
- if (overwrite) {
- writer <- callJMethod(writer, "overwrite")
- }
- invisible(callJMethod(writer, "save", path))
+ write_internal(object, path, overwrite)
})
#' Load a fitted MLlib model from the input path.
@@ -801,21 +782,21 @@ read.ml <- function(path) {
path <- suppressWarnings(normalizePath(path))
jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)
if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) {
- return(new("NaiveBayesModel", jobj = jobj))
+ new("NaiveBayesModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) {
- return(new("AFTSurvivalRegressionModel", jobj = jobj))
+ new("AFTSurvivalRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper")) {
- return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+ new("GeneralizedLinearRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
- return(new("KMeansModel", jobj = jobj))
+ new("KMeansModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) {
- return(new("LDAModel", jobj = jobj))
+ new("LDAModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
- return(new("IsotonicRegressionModel", jobj = jobj))
+ new("IsotonicRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
- return(new("GaussianMixtureModel", jobj = jobj))
+ new("GaussianMixtureModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
- return(new("ALSModel", jobj = jobj))
+ new("ALSModel", jobj = jobj)
} else {
stop(paste("Unsupported model: ", jobj))
}
@@ -860,7 +841,7 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
"fit", formula, data@sdf)
- return(new("AFTSurvivalRegressionModel", jobj = jobj))
+ new("AFTSurvivalRegressionModel", jobj = jobj)
})
#' Latent Dirichlet Allocation
@@ -926,7 +907,7 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"),
as.numeric(subsamplingRate), topicConcentration,
as.array(docConcentration), as.array(customizedStopWords),
maxVocabSize)
- return(new("LDAModel", jobj = jobj))
+ new("LDAModel", jobj = jobj)
})
# Returns a summary of the AFT survival regression model produced by spark.survreg,
@@ -946,7 +927,7 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Value")
rownames(coefficients) <- unlist(features)
- return(list(coefficients = coefficients))
+ list(coefficients = coefficients)
})
# Makes predictions from an AFT survival regression model or a model produced by
@@ -960,7 +941,7 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
#' @note predict(AFTSurvivalRegressionModel) since 2.0.0
setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
function(object, newData) {
- return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ predict_internal(object, newData)
})
#' Multivariate Gaussian Mixture Model (GMM)
@@ -1014,7 +995,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula =
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.GaussianMixtureWrapper", "fit", data@sdf,
formula, as.integer(k), as.integer(maxIter), as.numeric(tol))
- return(new("GaussianMixtureModel", jobj = jobj))
+ new("GaussianMixtureModel", jobj = jobj)
})
# Get the summary of a multivariate gaussian mixture model
@@ -1027,7 +1008,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula =
#' @export
#' @note summary(GaussianMixtureModel) since 2.1.0
setMethod("summary", signature(object = "GaussianMixtureModel"),
- function(object, ...) {
+ function(object) {
jobj <- object@jobj
is.loaded <- callJMethod(jobj, "isLoaded")
lambda <- unlist(callJMethod(jobj, "lambda"))
@@ -1052,8 +1033,8 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
} else {
dataFrame(callJMethod(jobj, "posterior"))
}
- return(list(lambda = lambda, mu = mu, sigma = sigma,
- posterior = posterior, is.loaded = is.loaded))
+ list(lambda = lambda, mu = mu, sigma = sigma,
+ posterior = posterior, is.loaded = is.loaded)
})
# Predicted values based on a gaussian mixture model
@@ -1067,7 +1048,7 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
#' @note predict(GaussianMixtureModel) since 2.1.0
setMethod("predict", signature(object = "GaussianMixtureModel"),
function(object, newData) {
- return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ predict_internal(object, newData)
})
#' Alternating Least Squares (ALS) for Collaborative Filtering
@@ -1149,7 +1130,7 @@ setMethod("spark.als", signature(data = "SparkDataFrame"),
reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
as.integer(numUserBlocks), as.integer(numItemBlocks),
as.integer(checkpointInterval), as.integer(seed))
- return(new("ALSModel", jobj = jobj))
+ new("ALSModel", jobj = jobj)
})
# Returns a summary of the ALS model produced by spark.als.
@@ -1163,17 +1144,17 @@ setMethod("spark.als", signature(data = "SparkDataFrame"),
#' @export
#' @note summary(ALSModel) since 2.1.0
setMethod("summary", signature(object = "ALSModel"),
-function(object, ...) {
- jobj <- object@jobj
- user <- callJMethod(jobj, "userCol")
- item <- callJMethod(jobj, "itemCol")
- rating <- callJMethod(jobj, "ratingCol")
- userFactors <- dataFrame(callJMethod(jobj, "userFactors"))
- itemFactors <- dataFrame(callJMethod(jobj, "itemFactors"))
- rank <- callJMethod(jobj, "rank")
- return(list(user = user, item = item, rating = rating, userFactors = userFactors,
- itemFactors = itemFactors, rank = rank))
-})
+ function(object) {
+ jobj <- object@jobj
+ user <- callJMethod(jobj, "userCol")
+ item <- callJMethod(jobj, "itemCol")
+ rating <- callJMethod(jobj, "ratingCol")
+ userFactors <- dataFrame(callJMethod(jobj, "userFactors"))
+ itemFactors <- dataFrame(callJMethod(jobj, "itemFactors"))
+ rank <- callJMethod(jobj, "rank")
+ list(user = user, item = item, rating = rating, userFactors = userFactors,
+ itemFactors = itemFactors, rank = rank)
+ })
# Makes predictions from an ALS model or a model produced by spark.als.
@@ -1185,9 +1166,9 @@ function(object, ...) {
#' @export
#' @note predict(ALSModel) since 2.1.0
setMethod("predict", signature(object = "ALSModel"),
-function(object, newData) {
- return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
-})
+ function(object, newData) {
+ predict_internal(object, newData)
+ })
# Saves the ALS model to the input path.
@@ -1203,10 +1184,6 @@ function(object, newData) {
#' @seealso \link{read.ml}
#' @note write.ml(ALSModel, character) since 2.1.0
setMethod("write.ml", signature(object = "ALSModel", path = "character"),
-function(object, path, overwrite = FALSE) {
- writer <- callJMethod(object@jobj, "write")
- if (overwrite) {
- writer <- callJMethod(writer, "overwrite")
- }
- invisible(callJMethod(writer, "save", path))
-})
+ function(object, path, overwrite = FALSE) {
+ write_internal(object, path, overwrite)
+ })
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index d15c2393b94ac..de9bd48662c3a 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -95,6 +95,10 @@ test_that("spark.glm summary", {
expect_equal(stats$df.residual, rStats$df.residual)
expect_equal(stats$aic, rStats$aic)
+ out <- capture.output(print(stats))
+ expect_match(out[2], "Deviance Residuals:")
+ expect_true(any(grepl("AIC: 59.22", out)))
+
# binomial family
df <- suppressWarnings(createDataFrame(iris))
training <- df[df$Species %in% c("versicolor", "virginica"), ]
@@ -409,7 +413,7 @@ test_that("spark.naiveBayes", {
# Test e1071::naiveBayes
if (requireNamespace("e1071", quietly = TRUE)) {
- expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
+ expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA)
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
}
})
@@ -487,7 +491,7 @@ test_that("spark.isotonicRegression", {
weightCol = "weight")
# only allow one variable on the right hand side of the formula
expect_error(model2 <- spark.isoreg(df, ~., isotonic = FALSE))
- result <- summary(model, df)
+ result <- summary(model)
expect_equal(result$predictions, list(7, 5, 4, 4, 1))
# Test model prediction
@@ -503,7 +507,7 @@ test_that("spark.isotonicRegression", {
expect_error(write.ml(model, modelPath))
write.ml(model, modelPath, overwrite = TRUE)
model2 <- read.ml(modelPath)
- expect_equal(result, summary(model2, df))
+ expect_equal(result, summary(model2))
unlink(modelPath)
})
From 6f3cd36f93c11265449fdce3323e139fec8ab22d Mon Sep 17 00:00:00 2001
From: Shivaram Venkataraman
Date: Mon, 22 Aug 2016 12:53:52 -0700
Subject: [PATCH 056/270] [SPARKR][MINOR] Add Xiangrui and Felix to maintainers
## What changes were proposed in this pull request?
This change adds Xiangrui Meng and Felix Cheung to the maintainers field in the package description.
## How was this patch tested?
Author: Shivaram Venkataraman
Closes #14758 from shivaram/sparkr-maintainers.
---
R/pkg/DESCRIPTION | 2 ++
1 file changed, 2 insertions(+)
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 357ab007931f5..d81f1a3d4de68 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -5,6 +5,8 @@ Version: 2.0.0
Date: 2016-07-07
Author: The Apache Software Foundation
Maintainer: Shivaram Venkataraman
+ Xiangrui Meng
+ Felix Cheung
Depends:
R (>= 3.0),
methods
From 929cb8beed9b7014231580cc002853236a5337d6 Mon Sep 17 00:00:00 2001
From: Sean Zhong
Date: Mon, 22 Aug 2016 13:31:38 -0700
Subject: [PATCH 057/270] [MINOR][SQL] Fix some typos in comments and test
hints
## What changes were proposed in this pull request?
Fix some typos in comments and test hints
## How was this patch tested?
N/A.
Author: Sean Zhong
Closes #14755 from clockfly/fix_minor_typo.
---
.../apache/spark/sql/execution/UnsafeKVExternalSorter.java | 2 +-
.../execution/aggregate/TungstenAggregationIterator.scala | 6 +++---
.../src/test/scala/org/apache/spark/sql/QueryTest.scala | 6 +++---
3 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index eb105bd09a3ea..0d51dc9ff8a85 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -99,7 +99,7 @@ public UnsafeKVExternalSorter(
// The array will be used to do in-place sort, which require half of the space to be empty.
assert(map.numKeys() <= map.getArray().size() / 2);
// During spilling, the array in map will not be used, so we can borrow that and use it
- // as the underline array for in-memory sorter (it's always large enough).
+ // as the underlying array for in-memory sorter (it's always large enough).
// Since we will not grow the array, it's fine to pass `null` as consumer.
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 4b8adf5230717..4e072a92cc772 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -32,9 +32,9 @@ import org.apache.spark.unsafe.KVIterator
* An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s.
*
* This iterator first uses hash-based aggregation to process input rows. It uses
- * a hash map to store groups and their corresponding aggregation buffers. If we
- * this map cannot allocate memory from memory manager, it spill the map into disk
- * and create a new one. After processed all the input, then merge all the spills
+ * a hash map to store groups and their corresponding aggregation buffers. If
+ * this map cannot allocate memory from memory manager, it spills the map into disk
+ * and creates a new one. After processed all the input, then merge all the spills
* together using external sorter, and do sort-based aggregation.
*
* The process has the following step:
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 484e4380331f8..c7af40227d45f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -358,11 +358,11 @@ abstract class QueryTest extends PlanTest {
*/
def assertEmptyMissingInput(query: Dataset[_]): Unit = {
assert(query.queryExecution.analyzed.missingInput.isEmpty,
- s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}")
+ s"The analyzed logical plan has missing inputs:\n${query.queryExecution.analyzed}")
assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
- s"The optimized logical plan has missing inputs: ${query.queryExecution.optimizedPlan}")
+ s"The optimized logical plan has missing inputs:\n${query.queryExecution.optimizedPlan}")
assert(query.queryExecution.executedPlan.missingInput.isEmpty,
- s"The physical plan has missing inputs: ${query.queryExecution.executedPlan}")
+ s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}")
}
}
From 84770b59f773f132073cd2af4204957fc2d7bf35 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Mon, 22 Aug 2016 15:48:35 -0700
Subject: [PATCH 058/270] [SPARK-17162] Range does not support SQL generation
## What changes were proposed in this pull request?
The range operator previously did not support SQL generation, which made it impossible to use in views.
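As a hedged illustration of the user-facing effect (assuming a Hive-enabled `spark` session; the view name is hypothetical), a view over the `range` table-valued function can now be created, because its plan can be canonicalized back to SQL:
```scala
// Hedged sketch, not part of the patch: before this change, creating a
// permanent view over range(...) failed because the logical Range node
// had no SQL representation.
spark.sql("CREATE VIEW range_view AS SELECT * FROM range(1, 100, 20, 10)")
spark.sql("SELECT count(*) FROM range_view").show()
```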
## How was this patch tested?
Unit tests.
cc hvanhovell
Author: Eric Liang
Closes #14724 from ericl/spark-17162.
---
.../ResolveTableValuedFunctions.scala | 11 ++++------
.../plans/logical/basicLogicalOperators.scala | 21 ++++++++++++-------
.../spark/sql/catalyst/SQLBuilder.scala | 3 +++
.../execution/basicPhysicalOperators.scala | 2 +-
.../spark/sql/execution/command/views.scala | 3 +--
sql/hive/src/test/resources/sqlgen/range.sql | 4 ++++
.../resources/sqlgen/range_with_splits.sql | 4 ++++
.../sql/catalyst/LogicalPlanToSQLSuite.scala | 14 ++++++++++++-
8 files changed, 44 insertions(+), 18 deletions(-)
create mode 100644 sql/hive/src/test/resources/sqlgen/range.sql
create mode 100644 sql/hive/src/test/resources/sqlgen/range_with_splits.sql
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
index 7fdf7fa0c06a3..6b3bb68538dd1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
@@ -28,9 +28,6 @@ import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
* Rule that resolves table-valued function references.
*/
object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
- private lazy val defaultParallelism =
- SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism
-
/**
* List of argument names and their types, used to declare a function.
*/
@@ -84,25 +81,25 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
"range" -> Map(
/* range(end) */
tvf("end" -> LongType) { case Seq(end: Long) =>
- Range(0, end, 1, defaultParallelism)
+ Range(0, end, 1, None)
},
/* range(start, end) */
tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) =>
- Range(start, end, 1, defaultParallelism)
+ Range(start, end, 1, None)
},
/* range(start, end, step) */
tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) {
case Seq(start: Long, end: Long, step: Long) =>
- Range(start, end, step, defaultParallelism)
+ Range(start, end, step, None)
},
/* range(start, end, step, numPartitions) */
tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
"numPartitions" -> IntegerType) {
case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
- Range(start, end, step, numPartitions)
+ Range(start, end, step, Some(numPartitions))
})
)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index af1736e60799b..010aec7ba1a42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -422,17 +422,20 @@ case class Sort(
/** Factory for constructing new `Range` nodes. */
object Range {
- def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = {
+ def apply(start: Long, end: Long, step: Long, numSlices: Option[Int]): Range = {
val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes
new Range(start, end, step, numSlices, output)
}
+ def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = {
+ Range(start, end, step, Some(numSlices))
+ }
}
case class Range(
start: Long,
end: Long,
step: Long,
- numSlices: Int,
+ numSlices: Option[Int],
output: Seq[Attribute])
extends LeafNode with MultiInstanceRelation {
@@ -449,6 +452,14 @@ case class Range(
}
}
+ def toSQL(): String = {
+ if (numSlices.isDefined) {
+ s"SELECT id AS `${output.head.name}` FROM range($start, $end, $step, ${numSlices.get})"
+ } else {
+ s"SELECT id AS `${output.head.name}` FROM range($start, $end, $step)"
+ }
+ }
+
override def newInstance(): Range = copy(output = output.map(_.newInstance()))
override lazy val statistics: Statistics = {
@@ -457,11 +468,7 @@ case class Range(
}
override def simpleString: String = {
- if (step == 1) {
- s"Range ($start, $end, splits=$numSlices)"
- } else {
- s"Range ($start, $end, step=$step, splits=$numSlices)"
- }
+ s"Range ($start, $end, step=$step, splits=$numSlices)"
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
index af1de511da060..dde91b0a8606e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -208,6 +208,9 @@ class SQLBuilder private (
case p: LocalRelation =>
p.toSQL(newSubqueryName())
+ case p: Range =>
+ p.toSQL()
+
case OneRowRelation =>
""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index ad8a71689895b..3562083b06740 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -318,7 +318,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
def start: Long = range.start
def step: Long = range.step
- def numSlices: Int = range.numSlices
+ def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism)
def numElements: BigInt = range.numElements
override val output: Seq[Attribute] = range.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index e397cfa058e24..f0d7b64c3c160 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -179,8 +179,7 @@ case class CreateViewCommand(
sparkSession.sql(viewSQL).queryExecution.assertAnalyzed()
} catch {
case NonFatal(e) =>
- throw new RuntimeException(
- "Failed to analyze the canonicalized SQL. It is possible there is a bug in Spark.", e)
+ throw new RuntimeException(s"Failed to analyze the canonicalized SQL: ${viewSQL}", e)
}
val viewSchema = if (userSpecifiedColumns.isEmpty) {
diff --git a/sql/hive/src/test/resources/sqlgen/range.sql b/sql/hive/src/test/resources/sqlgen/range.sql
new file mode 100644
index 0000000000000..53c72ea71e6ac
--- /dev/null
+++ b/sql/hive/src/test/resources/sqlgen/range.sql
@@ -0,0 +1,4 @@
+-- This file is automatically generated by LogicalPlanToSQLSuite.
+select * from range(100)
+--------------------------------------------------------------------------------
+SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT id AS `gen_attr_0` FROM range(0, 100, 1)) AS gen_subquery_0) AS gen_subquery_1
diff --git a/sql/hive/src/test/resources/sqlgen/range_with_splits.sql b/sql/hive/src/test/resources/sqlgen/range_with_splits.sql
new file mode 100644
index 0000000000000..83d637d54a302
--- /dev/null
+++ b/sql/hive/src/test/resources/sqlgen/range_with_splits.sql
@@ -0,0 +1,4 @@
+-- This file is automatically generated by LogicalPlanToSQLSuite.
+select * from range(1, 100, 20, 10)
+--------------------------------------------------------------------------------
+SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT id AS `gen_attr_0` FROM range(1, 100, 20, 10)) AS gen_subquery_0) AS gen_subquery_1
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
index 742b065891a8e..9c6da6a628dcf 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
@@ -23,7 +23,10 @@ import java.nio.file.{Files, NoSuchFileException, Paths}
import scala.util.control.NonFatal
import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
@@ -180,7 +183,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
}
test("Test should fail if the SQL query cannot be regenerated") {
- spark.range(10).createOrReplaceTempView("not_sql_gen_supported_table_so_far")
+ case class Unsupported() extends LeafNode with MultiInstanceRelation {
+ override def newInstance(): Unsupported = copy()
+ override def output: Seq[Attribute] = Nil
+ }
+ Unsupported().createOrReplaceTempView("not_sql_gen_supported_table_so_far")
sql("select * from not_sql_gen_supported_table_so_far")
val m3 = intercept[org.scalatest.exceptions.TestFailedException] {
checkSQL("select * from not_sql_gen_supported_table_so_far", "in")
@@ -196,6 +203,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
}
}
+ test("range") {
+ checkSQL("select * from range(100)", "range")
+ checkSQL("select * from range(1, 100, 20, 10)", "range_with_splits")
+ }
+
test("in") {
checkSQL("SELECT id FROM parquet_t0 WHERE id IN (1, 2, 3)", "in")
}
From 71afeeea4ec8e67edc95b5d504c557c88a2598b9 Mon Sep 17 00:00:00 2001
From: Felix Cheung
Date: Mon, 22 Aug 2016 15:53:10 -0700
Subject: [PATCH 059/270] [SPARK-16508][SPARKR] doc updates and more CRAN check
fixes
## What changes were proposed in this pull request?
- replace ``` ` ``` in code doc with `\code{thing}`
- remove added `...` for drop(DataFrame)
- fix remaining CRAN check warnings
## How was this patch tested?
create doc with knitr
junyangq
Author: Felix Cheung
Closes #14734 from felixcheung/rdoccleanup.
---
R/pkg/NAMESPACE | 6 +++-
R/pkg/R/DataFrame.R | 71 ++++++++++++++++++++++----------------------
R/pkg/R/RDD.R | 10 +++----
R/pkg/R/SQLContext.R | 30 +++++++++----------
R/pkg/R/WindowSpec.R | 23 +++++++-------
R/pkg/R/column.R | 2 +-
R/pkg/R/functions.R | 36 +++++++++++-----------
R/pkg/R/generics.R | 15 +++++-----
R/pkg/R/group.R | 1 +
R/pkg/R/mllib.R | 19 ++++++------
R/pkg/R/pairRDD.R | 6 ++--
R/pkg/R/stats.R | 14 ++++-----
12 files changed, 119 insertions(+), 114 deletions(-)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index e1b87b28d35ae..709057675e578 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -1,5 +1,9 @@
# Imports from base R
-importFrom(methods, setGeneric, setMethod, setOldClass)
+# Do not include stats:: "rpois", "runif" - causes error at runtime
+importFrom("methods", "setGeneric", "setMethod", "setOldClass")
+importFrom("methods", "is", "new", "signature", "show")
+importFrom("stats", "gaussian", "setNames")
+importFrom("utils", "download.file", "packageVersion", "untar")
# Disable native libraries till we figure out how to package it
# See SPARKR-7839
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 540dc3122dd6d..52a6628ad7b32 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -150,7 +150,7 @@ setMethod("explain",
#' isLocal
#'
-#' Returns True if the `collect` and `take` methods can be run locally
+#' Returns True if the \code{collect} and \code{take} methods can be run locally
#' (without any Spark executors).
#'
#' @param x A SparkDataFrame
@@ -182,7 +182,7 @@ setMethod("isLocal",
#' @param numRows the number of rows to print. Defaults to 20.
#' @param truncate whether truncate long strings. If \code{TRUE}, strings more than
#' 20 characters will be truncated. However, if set greater than zero,
-#' truncates strings longer than `truncate` characters and all cells
+#' truncates strings longer than \code{truncate} characters and all cells
#' will be aligned right.
#' @param ... further arguments to be passed to or from other methods.
#' @family SparkDataFrame functions
@@ -642,10 +642,10 @@ setMethod("unpersist",
#' The following options for repartition are possible:
#' \itemize{
#' \item{1.} {Return a new SparkDataFrame partitioned by
-#' the given columns into `numPartitions`.}
-#' \item{2.} {Return a new SparkDataFrame that has exactly `numPartitions`.}
+#' the given columns into \code{numPartitions}.}
+#' \item{2.} {Return a new SparkDataFrame that has exactly \code{numPartitions}.}
#' \item{3.} {Return a new SparkDataFrame partitioned by the given column(s),
-#' using `spark.sql.shuffle.partitions` as number of partitions.}
+#' using \code{spark.sql.shuffle.partitions} as number of partitions.}
#'}
#' @param x a SparkDataFrame.
#' @param numPartitions the number of partitions to use.
@@ -1132,9 +1132,8 @@ setMethod("take",
#' Head
#'
-#' Return the first NUM rows of a SparkDataFrame as a R data.frame. If NUM is NULL,
-#' then head() returns the first 6 rows in keeping with the current data.frame
-#' convention in R.
+#' Return the first \code{num} rows of a SparkDataFrame as a R data.frame. If \code{num} is not
+#' specified, then head() returns the first 6 rows as with R data.frame.
#'
#' @param x a SparkDataFrame.
#' @param num the number of rows to return. Default is 6.
@@ -1406,11 +1405,11 @@ setMethod("dapplyCollect",
#'
#' @param cols grouping columns.
#' @param func a function to be applied to each group partition specified by grouping
-#' column of the SparkDataFrame. The function `func` takes as argument
+#' column of the SparkDataFrame. The function \code{func} takes as argument
#' a key - grouping columns and a data frame - a local R data.frame.
-#' The output of `func` is a local R data.frame.
+#' The output of \code{func} is a local R data.frame.
#' @param schema the schema of the resulting SparkDataFrame after the function is applied.
-#' The schema must match to output of `func`. It has to be defined for each
+#' The schema must match to output of \code{func}. It has to be defined for each
#' output column with preferred output column name and corresponding data type.
#' @return A SparkDataFrame.
#' @family SparkDataFrame functions
@@ -1497,9 +1496,9 @@ setMethod("gapply",
#'
#' @param cols grouping columns.
#' @param func a function to be applied to each group partition specified by grouping
-#' column of the SparkDataFrame. The function `func` takes as argument
+#' column of the SparkDataFrame. The function \code{func} takes as argument
#' a key - grouping columns and a data frame - a local R data.frame.
-#' The output of `func` is a local R data.frame.
+#' The output of \code{func} is a local R data.frame.
#' @return A data.frame.
#' @family SparkDataFrame functions
#' @aliases gapplyCollect,SparkDataFrame-method
@@ -1657,7 +1656,7 @@ setMethod("$", signature(x = "SparkDataFrame"),
getColumn(x, name)
})
-#' @param value a Column or NULL. If NULL, the specified Column is dropped.
+#' @param value a Column or \code{NULL}. If \code{NULL}, the specified Column is dropped.
#' @rdname select
#' @name $<-
#' @aliases $<-,SparkDataFrame-method
@@ -1747,7 +1746,7 @@ setMethod("[", signature(x = "SparkDataFrame"),
#' @family subsetting functions
#' @examples
#' \dontrun{
-#' # Columns can be selected using `[[` and `[`
+#' # Columns can be selected using [[ and [
#' df[[2]] == df[["age"]]
#' df[,2] == df[,"age"]
#' df[,c("name", "age")]
@@ -1792,7 +1791,7 @@ setMethod("subset", signature(x = "SparkDataFrame"),
#' select(df, df$name, df$age + 1)
#' select(df, c("col1", "col2"))
#' select(df, list(df$name, df$age + 1))
-#' # Similar to R data frames columns can also be selected using `$`
+#' # Similar to R data frames columns can also be selected using $
#' df[,df$age]
#' }
#' @note select(SparkDataFrame, character) since 1.4.0
@@ -2443,7 +2442,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
#' Return a new SparkDataFrame containing the union of rows
#'
#' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame
-#' and another SparkDataFrame. This is equivalent to `UNION ALL` in SQL.
+#' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL.
#' Note that this does not remove duplicate rows across the two SparkDataFrames.
#'
#' @param x A SparkDataFrame
@@ -2486,7 +2485,7 @@ setMethod("unionAll",
#' Union two or more SparkDataFrames
#'
-#' Union two or more SparkDataFrames. This is equivalent to `UNION ALL` in SQL.
+#' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL.
#' Note that this does not remove duplicate rows across the two SparkDataFrames.
#'
#' @param x a SparkDataFrame.
@@ -2519,7 +2518,7 @@ setMethod("rbind",
#' Intersect
#'
#' Return a new SparkDataFrame containing rows only in both this SparkDataFrame
-#' and another SparkDataFrame. This is equivalent to `INTERSECT` in SQL.
+#' and another SparkDataFrame. This is equivalent to \code{INTERSECT} in SQL.
#'
#' @param x A SparkDataFrame
#' @param y A SparkDataFrame
@@ -2547,7 +2546,7 @@ setMethod("intersect",
#' except
#'
#' Return a new SparkDataFrame containing rows in this SparkDataFrame
-#' but not in another SparkDataFrame. This is equivalent to `EXCEPT` in SQL.
+#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT} in SQL.
#'
#' @param x a SparkDataFrame.
#' @param y a SparkDataFrame.
@@ -2576,8 +2575,8 @@ setMethod("except",
#' Save the contents of SparkDataFrame to a data source.
#'
-#' The data source is specified by the `source` and a set of options (...).
-#' If `source` is not specified, the default data source configured by
+#' The data source is specified by the \code{source} and a set of options (...).
+#' If \code{source} is not specified, the default data source configured by
#' spark.sql.sources.default will be used.
#'
#' Additionally, mode is used to specify the behavior of the save operation when data already
@@ -2613,7 +2612,7 @@ setMethod("except",
#' @note write.df since 1.4.0
setMethod("write.df",
signature(df = "SparkDataFrame", path = "character"),
- function(df, path, source = NULL, mode = "error", ...){
+ function(df, path, source = NULL, mode = "error", ...) {
if (is.null(source)) {
source <- getDefaultSqlSource()
}
@@ -2635,14 +2634,14 @@ setMethod("write.df",
#' @note saveDF since 1.4.0
setMethod("saveDF",
signature(df = "SparkDataFrame", path = "character"),
- function(df, path, source = NULL, mode = "error", ...){
+ function(df, path, source = NULL, mode = "error", ...) {
write.df(df, path, source, mode, ...)
})
#' Save the contents of the SparkDataFrame to a data source as a table
#'
-#' The data source is specified by the `source` and a set of options (...).
-#' If `source` is not specified, the default data source configured by
+#' The data source is specified by the \code{source} and a set of options (...).
+#' If \code{source} is not specified, the default data source configured by
#' spark.sql.sources.default will be used.
#'
#' Additionally, mode is used to specify the behavior of the save operation when
@@ -2675,7 +2674,7 @@ setMethod("saveDF",
#' @note saveAsTable since 1.4.0
setMethod("saveAsTable",
signature(df = "SparkDataFrame", tableName = "character"),
- function(df, tableName, source = NULL, mode="error", ...){
+ function(df, tableName, source = NULL, mode="error", ...) {
if (is.null(source)) {
source <- getDefaultSqlSource()
}
@@ -2752,11 +2751,11 @@ setMethod("summary",
#' @param how "any" or "all".
#' if "any", drop a row if it contains any nulls.
#' if "all", drop a row only if all its values are null.
-#' if minNonNulls is specified, how is ignored.
+#' if \code{minNonNulls} is specified, how is ignored.
#' @param minNonNulls if specified, drop rows that have less than
-#' minNonNulls non-null values.
+#' \code{minNonNulls} non-null values.
#' This overwrites the how parameter.
-#' @param cols optional list of column names to consider. In `fillna`,
+#' @param cols optional list of column names to consider. In \code{fillna},
#' columns specified in cols that do not have matching data
#' type are ignored. For example, if value is a character, and
#' subset contains a non-character column, then the non-character
@@ -2879,8 +2878,8 @@ setMethod("fillna",
#' in your system to accommodate the contents.
#'
#' @param x a SparkDataFrame.
-#' @param row.names NULL or a character vector giving the row names for the data frame.
-#' @param optional If `TRUE`, converting column names is optional.
+#' @param row.names \code{NULL} or a character vector giving the row names for the data frame.
+#' @param optional If \code{TRUE}, converting column names is optional.
#' @param ... additional arguments to pass to base::as.data.frame.
#' @return A data.frame.
#' @family SparkDataFrame functions
@@ -3058,7 +3057,7 @@ setMethod("str",
#' @note drop since 2.0.0
setMethod("drop",
signature(x = "SparkDataFrame"),
- function(x, col, ...) {
+ function(x, col) {
stopifnot(class(col) == "character" || class(col) == "Column")
if (class(col) == "Column") {
@@ -3218,8 +3217,8 @@ setMethod("histogram",
#' and to not change the existing data.
#' }
#'
-#' @param x s SparkDataFrame.
-#' @param url JDBC database url of the form `jdbc:subprotocol:subname`.
+#' @param x a SparkDataFrame.
+#' @param url JDBC database url of the form \code{jdbc:subprotocol:subname}.
#' @param tableName yhe name of the table in the external database.
#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default).
#' @param ... additional JDBC database connection properties.
@@ -3237,7 +3236,7 @@ setMethod("histogram",
#' @note write.jdbc since 2.0.0
setMethod("write.jdbc",
signature(x = "SparkDataFrame", url = "character", tableName = "character"),
- function(x, url, tableName, mode = "error", ...){
+ function(x, url, tableName, mode = "error", ...) {
jmode <- convertToJSaveMode(mode)
jprops <- varargsToJProperties(...)
write <- callJMethod(x@sdf, "write")
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 6b254bb0d302c..6cd0704003f1a 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -887,17 +887,17 @@ setMethod("sampleRDD",
# Discards some random values to ensure each partition has a
# different random seed.
- runif(partIndex)
+ stats::runif(partIndex)
for (elem in part) {
if (withReplacement) {
- count <- rpois(1, fraction)
+ count <- stats::rpois(1, fraction)
if (count > 0) {
res[ (len + 1) : (len + count) ] <- rep(list(elem), count)
len <- len + count
}
} else {
- if (runif(1) < fraction) {
+ if (stats::runif(1) < fraction) {
len <- len + 1
res[[len]] <- elem
}
@@ -965,7 +965,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical",
set.seed(seed)
samples <- collectRDD(sampleRDD(x, withReplacement, fraction,
- as.integer(ceiling(runif(1,
+ as.integer(ceiling(stats::runif(1,
-MAXINT,
MAXINT)))))
# If the first sample didn't turn out large enough, keep trying to
@@ -973,7 +973,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical",
# multiplier for thei initial size
while (length(samples) < total)
samples <- collectRDD(sampleRDD(x, withReplacement, fraction,
- as.integer(ceiling(runif(1,
+ as.integer(ceiling(stats::runif(1,
-MAXINT,
MAXINT)))))
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index a9cd2d85f898c..572e71e25b80b 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -115,7 +115,7 @@ infer_type <- function(x) {
#' Get Runtime Config from the current active SparkSession
#'
#' Get Runtime Config from the current active SparkSession.
-#' To change SparkSession Runtime Config, please see `sparkR.session()`.
+#' To change SparkSession Runtime Config, please see \code{sparkR.session()}.
#'
#' @param key (optional) The key of the config to get, if omitted, all config is returned
#' @param defaultValue (optional) The default value of the config to return if they config is not
@@ -720,11 +720,11 @@ dropTempView <- function(viewName) {
#'
#' Returns the dataset in a data source as a SparkDataFrame
#'
-#' The data source is specified by the `source` and a set of options(...).
-#' If `source` is not specified, the default data source configured by
+#' The data source is specified by the \code{source} and a set of options(...).
+#' If \code{source} is not specified, the default data source configured by
#' "spark.sql.sources.default" will be used. \cr
-#' Similar to R read.csv, when `source` is "csv", by default, a value of "NA" will be interpreted
-#' as NA.
+#' Similar to R read.csv, when \code{source} is "csv", by default, a value of "NA" will be
+#' interpreted as NA.
#'
#' @param path The path of files to load
#' @param source The name of external data source
@@ -791,8 +791,8 @@ loadDF <- function(x, ...) {
#' Creates an external table based on the dataset in a data source,
#' Returns a SparkDataFrame associated with the external table.
#'
-#' The data source is specified by the `source` and a set of options(...).
-#' If `source` is not specified, the default data source configured by
+#' The data source is specified by the \code{source} and a set of options(...).
+#' If \code{source} is not specified, the default data source configured by
#' "spark.sql.sources.default" will be used.
#'
#' @param tableName a name of the table.
@@ -830,22 +830,22 @@ createExternalTable <- function(x, ...) {
#' Additional JDBC database connection properties can be set (...)
#'
#' Only one of partitionColumn or predicates should be set. Partitions of the table will be
-#' retrieved in parallel based on the `numPartitions` or by the predicates.
+#' retrieved in parallel based on the \code{numPartitions} or by the predicates.
#'
#' Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
#' your external database systems.
#'
-#' @param url JDBC database url of the form `jdbc:subprotocol:subname`
+#' @param url JDBC database url of the form \code{jdbc:subprotocol:subname}
#' @param tableName the name of the table in the external database
#' @param partitionColumn the name of a column of integral type that will be used for partitioning
-#' @param lowerBound the minimum value of `partitionColumn` used to decide partition stride
-#' @param upperBound the maximum value of `partitionColumn` used to decide partition stride
-#' @param numPartitions the number of partitions, This, along with `lowerBound` (inclusive),
-#' `upperBound` (exclusive), form partition strides for generated WHERE
-#' clause expressions used to split the column `partitionColumn` evenly.
+#' @param lowerBound the minimum value of \code{partitionColumn} used to decide partition stride
+#' @param upperBound the maximum value of \code{partitionColumn} used to decide partition stride
+#' @param numPartitions the number of partitions, This, along with \code{lowerBound} (inclusive),
+#' \code{upperBound} (exclusive), form partition strides for generated WHERE
+#' clause expressions used to split the column \code{partitionColumn} evenly.
#' This defaults to SparkContext.defaultParallelism when unset.
#' @param predicates a list of conditions in the where clause; each one defines one partition
-#' @param ... additional JDBC database connection named propertie(s).
+#' @param ... additional JDBC database connection named properties.
#' @return SparkDataFrame
#' @rdname read.jdbc
#' @name read.jdbc
diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R
index b55356b07d5e3..ddd2ef2fcdee5 100644
--- a/R/pkg/R/WindowSpec.R
+++ b/R/pkg/R/WindowSpec.R
@@ -44,6 +44,7 @@ windowSpec <- function(sws) {
}
#' @rdname show
+#' @export
#' @note show(WindowSpec) since 2.0.0
setMethod("show", "WindowSpec",
function(object) {
@@ -125,11 +126,11 @@ setMethod("orderBy",
#' rowsBetween
#'
-#' Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
+#' Defines the frame boundaries, from \code{start} (inclusive) to \code{end} (inclusive).
#'
-#' Both `start` and `end` are relative positions from the current row. For example, "0" means
-#' "current row", while "-1" means the row before the current row, and "5" means the fifth row
-#' after the current row.
+#' Both \code{start} and \code{end} are relative positions from the current row. For example,
+#' "0" means "current row", while "-1" means the row before the current row, and "5" means the
+#' fifth row after the current row.
#'
#' @param x a WindowSpec
#' @param start boundary start, inclusive.
@@ -157,12 +158,12 @@ setMethod("rowsBetween",
#' rangeBetween
#'
-#' Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
+#' Defines the frame boundaries, from \code{start} (inclusive) to \code{end} (inclusive).
+#'
+#' Both \code{start} and \code{end} are relative from the current row. For example, "0" means
+#' "current row", while "-1" means one off before the current row, and "5" means the five off
+#' after the current row.
#'
-#' Both `start` and `end` are relative from the current row. For example, "0" means "current row",
-#' while "-1" means one off before the current row, and "5" means the five off after the
-#' current row.
-
#' @param x a WindowSpec
#' @param start boundary start, inclusive.
#' The frame is unbounded if this is the minimum long value.
@@ -195,8 +196,8 @@ setMethod("rangeBetween",
#' Define a windowing column.
#'
#' @param x a Column, usually one returned by window function(s).
-#' @param window a WindowSpec object. Can be created by `windowPartitionBy` or
-#' `windowOrderBy` and configured by other WindowSpec methods.
+#' @param window a WindowSpec object. Can be created by \code{windowPartitionBy} or
+#' \code{windowOrderBy} and configured by other WindowSpec methods.
#' @rdname over
#' @name over
#' @aliases over,Column,WindowSpec-method
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index af486e1ce212d..539d91b0f8797 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -284,7 +284,7 @@ setMethod("%in%",
#' otherwise
#'
#' If values in the specified column are null, returns the value.
-#' Can be used in conjunction with `when` to specify a default value for expressions.
+#' Can be used in conjunction with \code{when} to specify a default value for expressions.
#'
#' @param x a Column.
#' @param value value to replace when the corresponding entry in \code{x} is NA.
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index b3c10de71f3fe..f042adddef91f 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -1250,7 +1250,7 @@ setMethod("rint",
#' round
#'
-#' Returns the value of the column `e` rounded to 0 decimal places using HALF_UP rounding mode.
+#' Returns the value of the column \code{e} rounded to 0 decimal places using HALF_UP rounding mode.
#'
#' @param x Column to compute on.
#'
@@ -1974,7 +1974,7 @@ setMethod("atan2", signature(y = "Column"),
#' datediff
#'
-#' Returns the number of days from `start` to `end`.
+#' Returns the number of days from \code{start} to \code{end}.
#'
#' @param x start Column to use.
#' @param y end Column to use.
@@ -2043,7 +2043,7 @@ setMethod("levenshtein", signature(y = "Column"),
#' months_between
#'
-#' Returns number of months between dates `date1` and `date2`.
+#' Returns number of months between dates \code{date1} and \code{date2}.
#'
#' @param x start Column to use.
#' @param y end Column to use.
@@ -2430,7 +2430,7 @@ setMethod("add_months", signature(y = "Column", x = "numeric"),
#' date_add
#'
-#' Returns the date that is `days` days after `start`
+#' Returns the date that is \code{x} days after
#'
#' @param y Column to compute on
#' @param x Number of days to add
@@ -2450,7 +2450,7 @@ setMethod("date_add", signature(y = "Column", x = "numeric"),
#' date_sub
#'
-#' Returns the date that is `days` days before `start`
+#' Returns the date that is \code{x} days before
#'
#' @param y Column to compute on
#' @param x Number of days to substract
@@ -3113,7 +3113,7 @@ setMethod("ifelse",
#' N = total number of rows in the partition
#' cume_dist(x) = number of values before (and including) x / N
#'
-#' This is equivalent to the CUME_DIST function in SQL.
+#' This is equivalent to the \code{CUME_DIST} function in SQL.
#'
#' @rdname cume_dist
#' @name cume_dist
@@ -3141,7 +3141,7 @@ setMethod("cume_dist",
#' and had three people tie for second place, you would say that all three were in second
#' place and that the next person came in third.
#'
-#' This is equivalent to the DENSE_RANK function in SQL.
+#' This is equivalent to the \code{DENSE_RANK} function in SQL.
#'
#' @rdname dense_rank
#' @name dense_rank
@@ -3159,11 +3159,11 @@ setMethod("dense_rank",
#' lag
#'
-#' Window function: returns the value that is `offset` rows before the current row, and
-#' `defaultValue` if there is less than `offset` rows before the current row. For example,
-#' an `offset` of one will return the previous row at any given point in the window partition.
+#' Window function: returns the value that is \code{offset} rows before the current row, and
+#' \code{defaultValue} if there is less than \code{offset} rows before the current row. For example,
+#' an \code{offset} of one will return the previous row at any given point in the window partition.
#'
-#' This is equivalent to the LAG function in SQL.
+#' This is equivalent to the \code{LAG} function in SQL.
#'
#' @param x the column as a character string or a Column to compute on.
#' @param offset the number of rows back from the current row from which to obtain a value.
@@ -3193,11 +3193,11 @@ setMethod("lag",
#' lead
#'
-#' Window function: returns the value that is `offset` rows after the current row, and
-#' `null` if there is less than `offset` rows after the current row. For example,
-#' an `offset` of one will return the next row at any given point in the window partition.
+#' Window function: returns the value that is \code{offset} rows after the current row, and
+#' NULL if there is less than \code{offset} rows after the current row. For example,
+#' an \code{offset} of one will return the next row at any given point in the window partition.
#'
-#' This is equivalent to the LEAD function in SQL.
+#' This is equivalent to the \code{LEAD} function in SQL.
#'
#' @param x Column to compute on
#' @param offset Number of rows to offset
@@ -3226,11 +3226,11 @@ setMethod("lead",
#' ntile
#'
-#' Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window
-#' partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second
+#' Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window
+#' partition. For example, if n is 4, the first quarter of the rows will get value 1, the second
#' quarter will get 2, the third quarter will get 3, and the last quarter will get 4.
#'
-#' This is equivalent to the NTILE function in SQL.
+#' This is equivalent to the \code{NTILE} function in SQL.
#'
#' @param x Number of ntile groups
#'
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 6610a25c8c05a..88884e62575df 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -438,17 +438,17 @@ setGeneric("columns", function(x) {standardGeneric("columns") })
setGeneric("count", function(x) { standardGeneric("count") })
#' @rdname cov
-#' @param x a Column object or a SparkDataFrame.
-#' @param ... additional argument(s). If `x` is a Column object, a Column object
-#' should be provided. If `x` is a SparkDataFrame, two column names should
+#' @param x a Column or a SparkDataFrame.
+#' @param ... additional argument(s). If \code{x} is a Column, a Column
+#' should be provided. If \code{x} is a SparkDataFrame, two column names should
#' be provided.
#' @export
setGeneric("cov", function(x, ...) {standardGeneric("cov") })
#' @rdname corr
-#' @param x a Column object or a SparkDataFrame.
-#' @param ... additional argument(s). If `x` is a Column object, a Column object
-#' should be provided. If `x` is a SparkDataFrame, two column names should
+#' @param x a Column or a SparkDataFrame.
+#' @param ... additional argument(s). If \code{x} is a Column, a Column
+#' should be provided. If \code{x} is a SparkDataFrame, two column names should
#' be provided.
#' @export
setGeneric("corr", function(x, ...) {standardGeneric("corr") })
@@ -851,7 +851,7 @@ setGeneric("array_contains", function(x, value) { standardGeneric("array_contain
setGeneric("ascii", function(x) { standardGeneric("ascii") })
#' @param x Column to compute on or a GroupedData object.
-#' @param ... additional argument(s) when `x` is a GroupedData object.
+#' @param ... additional argument(s) when \code{x} is a GroupedData object.
#' @rdname avg
#' @export
setGeneric("avg", function(x, ...) { standardGeneric("avg") })
@@ -1339,7 +1339,6 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s
setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
#' @rdname spark.lda
-#' @param ... Additional parameters to tune LDA.
#' @export
setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 3c85ada91a444..e3479ef5fa583 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -48,6 +48,7 @@ groupedData <- function(sgd) {
#' @rdname show
#' @aliases show,GroupedData-method
+#' @export
#' @note show(GroupedData) since 1.4.0
setMethod("show", "GroupedData",
function(object) {
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index b36fbcee17671..a40310d194d27 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -131,7 +131,7 @@ predict_internal <- function(object, newData) {
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
-#' @param weightCol the weight column name. If this is not set or NULL, we treat all instance
+#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
@@ -197,7 +197,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
-#' @param weightCol the weight column name. If this is not set or NULL, we treat all instance
+#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param epsilon positive convergence tolerance of iterations.
#' @param maxit integer giving the maximal number of IRLS iterations.
@@ -434,8 +434,8 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"),
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param isotonic Whether the output sequence should be isotonic/increasing (TRUE) or
#' antitonic/decreasing (FALSE)
-#' @param featureIndex The index of the feature if \code{featuresCol} is a vector column (default: `0`),
-#' no effect otherwise
+#' @param featureIndex The index of the feature if \code{featuresCol} is a vector column
+#' (default: 0), no effect otherwise
#' @param weightCol The weight column name.
#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model
#' @rdname spark.isoreg
@@ -647,7 +647,7 @@ setMethod("predict", signature(object = "KMeansModel"),
#' @rdname spark.naiveBayes
#' @aliases spark.naiveBayes,SparkDataFrame,formula-method
#' @name spark.naiveBayes
-#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
+#' @seealso e1071: \url{https://cran.r-project.org/package=e1071}
#' @export
#' @examples
#' \dontrun{
@@ -815,7 +815,7 @@ read.ml <- function(path) {
#' Note that operator '.' is not supported currently.
#' @return \code{spark.survreg} returns a fitted AFT survival regression model.
#' @rdname spark.survreg
-#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
+#' @seealso survival: \url{https://cran.r-project.org/package=survival}
#' @export
#' @examples
#' \dontrun{
@@ -870,10 +870,11 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula
#' @param customizedStopWords stopwords that need to be removed from the given corpus. Ignore the
#' parameter if libSVM-format column is used as the features column.
#' @param maxVocabSize maximum vocabulary size, default 1 << 18
+#' @param ... additional argument(s) passed to the method.
#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model
#' @rdname spark.lda
#' @aliases spark.lda,SparkDataFrame-method
-#' @seealso topicmodels: \url{https://cran.r-project.org/web/packages/topicmodels/}
+#' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels}
#' @export
#' @examples
#' \dontrun{
@@ -962,7 +963,7 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
#' @return \code{spark.gaussianMixture} returns a fitted multivariate gaussian mixture model.
#' @rdname spark.gaussianMixture
#' @name spark.gaussianMixture
-#' @seealso mixtools: \url{https://cran.r-project.org/web/packages/mixtools/}
+#' @seealso mixtools: \url{https://cran.r-project.org/package=mixtools}
#' @export
#' @examples
#' \dontrun{
@@ -1075,7 +1076,7 @@ setMethod("predict", signature(object = "GaussianMixtureModel"),
#' @param numUserBlocks number of user blocks used to parallelize computation (> 0).
#' @param numItemBlocks number of item blocks used to parallelize computation (> 0).
#' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1).
-#'
+#' @param ... additional argument(s) passed to the method.
#' @return \code{spark.als} returns a fitted ALS model
#' @rdname spark.als
#' @aliases spark.als,SparkDataFrame-method
diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R
index f0605db1e9e83..4dee3245f9b75 100644
--- a/R/pkg/R/pairRDD.R
+++ b/R/pkg/R/pairRDD.R
@@ -917,19 +917,19 @@ setMethod("sampleByKey",
len <- 0
# mixing because the initial seeds are close to each other
- runif(10)
+ stats::runif(10)
for (elem in part) {
if (elem[[1]] %in% names(fractions)) {
frac <- as.numeric(fractions[which(elem[[1]] == names(fractions))])
if (withReplacement) {
- count <- rpois(1, frac)
+ count <- stats::rpois(1, frac)
if (count > 0) {
res[ (len + 1) : (len + count) ] <- rep(list(elem), count)
len <- len + count
}
} else {
- if (runif(1) < frac) {
+ if (stats::runif(1) < frac) {
len <- len + 1
res[[len]] <- elem
}
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
index 8ea24d81729ec..dcd7198f41ea7 100644
--- a/R/pkg/R/stats.R
+++ b/R/pkg/R/stats.R
@@ -29,9 +29,9 @@ setOldClass("jobj")
#' @param col1 name of the first column. Distinct items will make the first item of each row.
#' @param col2 name of the second column. Distinct items will make the column names of the output.
#' @return a local R data.frame representing the contingency table. The first column of each row
-#' will be the distinct values of `col1` and the column names will be the distinct values
-#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no
-#' occurrences will have zero as their counts.
+#' will be the distinct values of \code{col1} and the column names will be the distinct values
+#' of \code{col2}. The name of the first column will be "\code{col1}_\code{col2}". Pairs
+#' that have no occurrences will have zero as their counts.
#'
#' @rdname crosstab
#' @name crosstab
@@ -116,7 +116,7 @@ setMethod("corr",
#'
#' @param x A SparkDataFrame.
#' @param cols A vector column names to search frequent items in.
-#' @param support (Optional) The minimum frequency for an item to be considered `frequent`.
+#' @param support (Optional) The minimum frequency for an item to be considered \code{frequent}.
#' Should be greater than 1e-4. Default support = 0.01.
#' @return a local R data.frame with the frequent items in each column
#'
@@ -142,9 +142,9 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"),
#'
#' Calculates the approximate quantiles of a numerical column of a SparkDataFrame.
#' The result of this algorithm has the following deterministic bound:
-#' If the SparkDataFrame has N elements and if we request the quantile at probability `p` up to
-#' error `err`, then the algorithm will return a sample `x` from the SparkDataFrame so that the
-#' *exact* rank of `x` is close to (p * N). More precisely,
+#' If the SparkDataFrame has N elements and if we request the quantile at probability p up to
+#' error err, then the algorithm will return a sample x from the SparkDataFrame so that the
+#' *exact* rank of x is close to (p * N). More precisely,
#' floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
#' This method implements a variation of the Greenwald-Khanna algorithm (with some speed
#' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670
From 8e223ea67acf5aa730ccf688802f17f6fc10907c Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Mon, 22 Aug 2016 16:32:14 -0700
Subject: [PATCH 060/270] [SPARK-16550][SPARK-17042][CORE] Certain classes fail
to deserialize in block manager replication
## What changes were proposed in this pull request?
This is a straightforward clone of JoshRosen's original patch. I have follow-up changes to fix block replication for repl-defined classes as well, but those appear to be hitting flaky tests, so I'm going to leave that for SPARK-17042.
## How was this patch tested?
End-to-end test in ReplSuite (also more tests in DistributedSuite from the original patch).
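For reference, a hedged sketch of the scenario this fixes, mirroring the new ReplSuite test (run in the spark-shell against a local cluster):
```scala
// Replicating blocks whose elements are a repl-defined class previously
// failed to deserialize on the remote executor.
import org.apache.spark.storage.StorageLevel._
case class Foo(i: Int)
val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_ONLY_2)
ret.count()
// With 10 partitions replicated twice, 20 blocks are expected across executors:
sc.getExecutorStorageStatus.map(_.rddBlocksById(ret.id).size).sum
```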
Author: Eric Liang
Closes #14311 from ericl/spark-16550.
---
.../spark/serializer/SerializerManager.scala | 14 +++-
.../apache/spark/storage/BlockManager.scala | 13 +++-
.../org/apache/spark/DistributedSuite.scala | 77 ++++++-------------
.../org/apache/spark/repl/ReplSuite.scala | 14 ++++
4 files changed, 60 insertions(+), 58 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 9dc274c9fe288..07caadbe40438 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -68,7 +68,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
* loaded yet. */
private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
- private def canUseKryo(ct: ClassTag[_]): Boolean = {
+ def canUseKryo(ct: ClassTag[_]): Boolean = {
primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
}
@@ -128,8 +128,18 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
/** Serializes into a chunked byte buffer. */
def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = {
+ dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])
+ }
+
+ /** Serializes into a chunked byte buffer. */
+ def dataSerializeWithExplicitClassTag(
+ blockId: BlockId,
+ values: Iterator[_],
+ classTag: ClassTag[_]): ChunkedByteBuffer = {
val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
- dataSerializeStream(blockId, bbos, values)
+ val byteStream = new BufferedOutputStream(bbos)
+ val ser = getSerializer(classTag).newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
bbos.toChunkedByteBuffer
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 015e71d1260ea..fe8465279860d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -498,7 +498,8 @@ private[spark] class BlockManager(
diskStore.getBytes(blockId)
} else if (level.useMemory && memoryStore.contains(blockId)) {
// The block was not found on disk, so serialize an in-memory copy:
- serializerManager.dataSerialize(blockId, memoryStore.getValues(blockId).get)
+ serializerManager.dataSerializeWithExplicitClassTag(
+ blockId, memoryStore.getValues(blockId).get, info.classTag)
} else {
handleLocalReadFailure(blockId)
}
@@ -973,8 +974,16 @@ private[spark] class BlockManager(
if (level.replication > 1) {
val remoteStartTime = System.currentTimeMillis
val bytesToReplicate = doGetLocalBytes(blockId, info)
+ // [SPARK-16550] Erase the typed classTag when using default serialization, since
+ // NettyBlockRpcServer crashes when deserializing repl-defined classes.
+ // TODO(ekl) remove this once the classloader issue on the remote end is fixed.
+ val remoteClassTag = if (!serializerManager.canUseKryo(classTag)) {
+ scala.reflect.classTag[Any]
+ } else {
+ classTag
+ }
try {
- replicate(blockId, bytesToReplicate, level, classTag)
+ replicate(blockId, bytesToReplicate, level, remoteClassTag)
} finally {
bytesToReplicate.dispose()
}
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 6beae842b04d1..4ee0e00fde506 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -149,61 +149,16 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
sc.parallelize(1 to 10).count()
}
- test("caching") {
+ private def testCaching(storageLevel: StorageLevel): Unit = {
sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).cache()
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching on disk") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching in memory, replicated") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching in memory, serialized, replicated") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_SER_2)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching on disk, replicated") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY_2)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching in memory and disk, replicated") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_2)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching in memory and disk, serialized, replicated") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
-
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
+ sc.jobProgressListener.waitUntilExecutorsUp(2, 30000)
+ val data = sc.parallelize(1 to 1000, 10)
+ val cachedData = data.persist(storageLevel)
+ assert(cachedData.count === 1000)
+ assert(sc.getExecutorStorageStatus.map(_.rddBlocksById(cachedData.id).size).sum ===
+ storageLevel.replication * data.getNumPartitions)
+ assert(cachedData.count === 1000)
+ assert(cachedData.count === 1000)
// Get all the locations of the first partition and try to fetch the partitions
// from those locations.
@@ -221,6 +176,20 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
}
}
+ Seq(
+ "caching" -> StorageLevel.MEMORY_ONLY,
+ "caching on disk" -> StorageLevel.DISK_ONLY,
+ "caching in memory, replicated" -> StorageLevel.MEMORY_ONLY_2,
+ "caching in memory, serialized, replicated" -> StorageLevel.MEMORY_ONLY_SER_2,
+ "caching on disk, replicated" -> StorageLevel.DISK_ONLY_2,
+ "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2,
+ "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2
+ ).foreach { case (testName, storageLevel) =>
+ test(testName) {
+ testCaching(storageLevel)
+ }
+ }
+
test("compute without caching when no partitions fit in memory") {
val size = 10000
val conf = new SparkConf()
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index c10db947bcb44..06b09f3158d77 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -396,6 +396,20 @@ class ReplSuite extends SparkFunSuite {
assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output)
}
+ test("replicating blocks of object with class defined in repl") {
+ val output = runInterpreter("local-cluster[2,1,1024]",
+ """
+ |import org.apache.spark.storage.StorageLevel._
+ |case class Foo(i: Int)
+ |val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_ONLY_2)
+ |ret.count()
+ |sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains(": Int = 20", output)
+ }
+
test("line wrapper only initialized once when used as encoder outer scope") {
val output = runInterpreter("local",
"""
From 6d93f9e0236aa61e39a1abfb0f7f7c558fb7d5d5 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Tue, 23 Aug 2016 08:03:08 +0800
Subject: [PATCH 061/270] [SPARK-17144][SQL] Removal of useless
CreateHiveTableAsSelectLogicalPlan
## What changes were proposed in this pull request?
`CreateHiveTableAsSelectLogicalPlan` is dead code after refactoring.
## How was this patch tested?
N/A
Author: gatorsmile
Closes #14707 from gatorsmile/removeCreateHiveTable.
---
.../spark/sql/execution/command/tables.scala | 19 +------------------
1 file changed, 1 insertion(+), 18 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index af2b5ffd1c427..21544a37d9975 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -33,28 +33,11 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT
import org.apache.spark.sql.catalyst.catalog.CatalogTableType._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.util.quoteIdentifier
-import org.apache.spark.sql.execution.datasources.{PartitioningUtils}
+import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-case class CreateHiveTableAsSelectLogicalPlan(
- tableDesc: CatalogTable,
- child: LogicalPlan,
- allowExisting: Boolean) extends UnaryNode with Command {
-
- override def output: Seq[Attribute] = Seq.empty[Attribute]
-
- override lazy val resolved: Boolean =
- tableDesc.identifier.database.isDefined &&
- tableDesc.schema.nonEmpty &&
- tableDesc.storage.serde.isDefined &&
- tableDesc.storage.inputFormat.isDefined &&
- tableDesc.storage.outputFormat.isDefined &&
- childrenResolved
-}
-
/**
* A command to create a table with the same definition of the given existing table.
*
From 37f0ab70d25802b609317bc93421d2fe3ee9db6e Mon Sep 17 00:00:00 2001
From: hqzizania
Date: Mon, 22 Aug 2016 17:09:08 -0700
Subject: [PATCH 062/270] [SPARK-17090][FOLLOW-UP][ML] Add expert param support
to SharedParamsCodeGen
## What changes were proposed in this pull request?
Add expert param support to SharedParamsCodeGen, where aggregationDepth, an expert param, is added.
Author: hqzizania
Closes #14738 from hqzizania/SPARK-17090-minor.
---
.../ml/param/shared/SharedParamsCodeGen.scala | 14 ++++++++++----
.../spark/ml/param/shared/sharedParams.scala | 4 ++--
2 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 0f48a16a429ff..480b03d0f35c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -80,7 +80,7 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " +
"empty, default value is 'auto'", Some("\"auto\"")),
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
- isValid = "ParamValidators.gtEq(2)"))
+ isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
@@ -95,7 +95,8 @@ private[shared] object SharedParamsCodeGen {
doc: String,
defaultValueStr: Option[String] = None,
isValid: String = "",
- finalMethods: Boolean = true) {
+ finalMethods: Boolean = true,
+ isExpertParam: Boolean = false) {
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
require(doc.nonEmpty) // TODO: more rigorous on doc
@@ -153,6 +154,11 @@ private[shared] object SharedParamsCodeGen {
} else {
""
}
+ val groupStr = if (param.isExpertParam) {
+ Array("expertParam", "expertGetParam")
+ } else {
+ Array("param", "getParam")
+ }
val methodStr = if (param.finalMethods) {
"final def"
} else {
@@ -167,11 +173,11 @@ private[shared] object SharedParamsCodeGen {
|
| /**
| * Param for $doc.
- | * @group param
+ | * @group ${groupStr(0)}
| */
| final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
|$setDefault
- | /** @group getParam */
+ | /** @group ${groupStr(1)} */
| $methodStr get$Name: $T = $$($name)
|}
|""".stripMargin
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 6803772c63d62..9125d9e19bf09 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -397,13 +397,13 @@ private[ml] trait HasAggregationDepth extends Params {
/**
* Param for suggested depth for treeAggregate (>= 2).
- * @group param
+ * @group expertParam
*/
final val aggregationDepth: IntParam = new IntParam(this, "aggregationDepth", "suggested depth for treeAggregate (>= 2)", ParamValidators.gtEq(2))
setDefault(aggregationDepth, 2)
- /** @group getParam */
+ /** @group expertGetParam */
final def getAggregationDepth: Int = $(aggregationDepth)
}
// scalastyle:on
From 920806ab272ba58a369072a5eeb89df5e9b470a6 Mon Sep 17 00:00:00 2001
From: Shivaram Venkataraman
Date: Mon, 22 Aug 2016 17:09:32 -0700
Subject: [PATCH 063/270] [SPARK-16577][SPARKR] Add CRAN documentation checks
to run-tests.sh
## What changes were proposed in this pull request?
This change adds CRAN documentation checks to be run as part of `R/run-tests.sh`.
## How was this patch tested?
As this script is also used by Jenkins, documentation checks will now run on every PR going forward.
Author: Shivaram Venkataraman
Closes #14759 from shivaram/sparkr-cran-jenkins.
---
R/check-cran.sh | 18 +++++++++++++++---
R/run-tests.sh | 27 ++++++++++++++++++++++++---
2 files changed, 39 insertions(+), 6 deletions(-)
diff --git a/R/check-cran.sh b/R/check-cran.sh
index 5c90fd07f28e4..bb331466ae931 100755
--- a/R/check-cran.sh
+++ b/R/check-cran.sh
@@ -43,10 +43,22 @@ $FWDIR/create-docs.sh
"$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg
# Run check as-cran.
-# TODO(shivaram): Remove the skip tests once we figure out the install mechanism
-
VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'`
-"$R_SCRIPT_PATH/"R CMD check --as-cran SparkR_"$VERSION".tar.gz
+CRAN_CHECK_OPTIONS="--as-cran"
+
+if [ -n "$NO_TESTS" ]
+then
+ CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-tests"
+fi
+
+if [ -n "$NO_MANUAL" ]
+then
+ CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual"
+fi
+
+echo "Running CRAN check with $CRAN_CHECK_OPTIONS options"
+
+"$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz
popd > /dev/null
diff --git a/R/run-tests.sh b/R/run-tests.sh
index 9dcf0ace7d97e..1a1e8ab9ffe18 100755
--- a/R/run-tests.sh
+++ b/R/run-tests.sh
@@ -26,6 +26,17 @@ rm -f $LOGFILE
SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
FAILED=$((PIPESTATUS[0]||$FAILED))
+# Also run the documentation tests for CRAN
+CRAN_CHECK_LOG_FILE=$FWDIR/cran-check.out
+rm -f $CRAN_CHECK_LOG_FILE
+
+NO_TESTS=1 NO_MANUAL=1 $FWDIR/check-cran.sh 2>&1 | tee -a $CRAN_CHECK_LOG_FILE
+FAILED=$((PIPESTATUS[0]||$FAILED))
+
+NUM_CRAN_WARNING="$(grep -c WARNING$ $CRAN_CHECK_LOG_FILE)"
+NUM_CRAN_ERROR="$(grep -c ERROR$ $CRAN_CHECK_LOG_FILE)"
+NUM_CRAN_NOTES="$(grep -c NOTE$ $CRAN_CHECK_LOG_FILE)"
+
if [[ $FAILED != 0 ]]; then
cat $LOGFILE
echo -en "\033[31m" # Red
@@ -33,7 +44,17 @@ if [[ $FAILED != 0 ]]; then
echo -en "\033[0m" # No color
exit -1
else
- echo -en "\033[32m" # Green
- echo "Tests passed."
- echo -en "\033[0m" # No color
+ # We have 2 existing NOTEs for new maintainer, attach()
+ # We have one more NOTE in Jenkins due to "No repository set"
+ if [[ $NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 3 ]]; then
+ cat $CRAN_CHECK_LOG_FILE
+ echo -en "\033[31m" # Red
+ echo "Had CRAN check errors; see logs."
+ echo -en "\033[0m" # No color
+ exit -1
+ else
+ echo -en "\033[32m" # Green
+ echo "Tests passed."
+ echo -en "\033[0m" # No color
+ fi
fi
From 2cdd92a7cd6f85186c846635b422b977bdafbcdd Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Tue, 23 Aug 2016 09:11:47 +0800
Subject: [PATCH 064/270] [SPARK-17182][SQL] Mark Collect as non-deterministic
## What changes were proposed in this pull request?
This PR marks the abstract class `Collect` as non-deterministic since the results of `CollectList` and `CollectSet` depend on the actual order of input rows.
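As an aside, not part of the patch: a minimal sketch (assuming a running `SparkSession` named `spark`) of why the collected result is order-dependent:

```scala
import org.apache.spark.sql.functions.collect_list

// collect_list / collect_set gather rows in whatever order the partitions deliver them,
// so two logically equivalent plans can produce differently ordered arrays.
val df = spark.range(0, 1000).toDF("id")
val a = df.repartition(8).agg(collect_list("id")).first().getSeq[Long](0)
val b = df.repartition(3).agg(collect_list("id")).first().getSeq[Long](0)
// a and b contain the same elements, but not necessarily in the same order, which is
// why the planner must not assume two evaluations of Collect are interchangeable.
```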
## How was this patch tested?
Existing test cases should be enough.
Author: Cheng Lian
Closes #14749 from liancheng/spark-17182-non-deterministic-collect.
---
.../spark/sql/catalyst/expressions/aggregate/collect.scala | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index ac2cefaddcf59..896ff61b23093 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -54,6 +54,10 @@ abstract class Collect extends ImperativeAggregate {
override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
+ // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
+ // actual order of input rows.
+ override def deterministic: Boolean = false
+
protected[this] val buffer: Growable[Any] with Iterable[Any]
override def initialize(b: MutableRow): Unit = {
From d2b3d3e63e1a9217de6ef507c350308017664a62 Mon Sep 17 00:00:00 2001
From: Felix Cheung
Date: Mon, 22 Aug 2016 20:15:03 -0700
Subject: [PATCH 065/270] [SPARKR][MINOR] Update R DESCRIPTION file
## What changes were proposed in this pull request?
Update DESCRIPTION
## How was this patch tested?
Run install and CRAN tests
Author: Felix Cheung
Closes #14764 from felixcheung/rpackagedescription.
---
R/pkg/DESCRIPTION | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index d81f1a3d4de68..e5afed2d0a93e 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -3,10 +3,15 @@ Type: Package
Title: R Frontend for Apache Spark
Version: 2.0.0
Date: 2016-07-07
-Author: The Apache Software Foundation
-Maintainer: Shivaram Venkataraman
- Xiangrui Meng
- Felix Cheung
+Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"),
+ email = "shivaram@cs.berkeley.edu"),
+ person("Xiangrui", "Meng", role = "aut",
+ email = "meng@databricks.com"),
+ person("Felix", "Cheung", role = "aut",
+ email = "felixcheung@apache.org"),
+ person(family = "The Apache Software Foundation", role = c("aut", "cph")))
+URL: http://www.apache.org/ http://spark.apache.org/
+BugReports: https://issues.apache.org/jira/secure/CreateIssueDetails!init.jspa?pid=12315420&components=12325400&issuetype=4
Depends:
R (>= 3.0),
methods
From cc33460a51d2890fe8f50f5b6b87003d6d210f04 Mon Sep 17 00:00:00 2001
From: Sean Zhong
Date: Tue, 23 Aug 2016 14:57:00 +0800
Subject: [PATCH 066/270] [SPARK-17188][SQL] Moves class QuantileSummaries to
project catalyst for implementing percentile_approx
## What changes were proposed in this pull request?
This is a sub-task of [SPARK-16283](https://issues.apache.org/jira/browse/SPARK-16283) (Implement percentile_approx SQL function); it moves the class QuantileSummaries to the catalyst project so that it can be reused when implementing the aggregation function `percentile_approx`.
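For readers unfamiliar with the class, a minimal usage sketch based only on the API shown in the diff below (`insert`, `compress`, `query`); this is the same machinery `StatFunctions` uses for approximate quantiles:

```scala
import org.apache.spark.sql.catalyst.util.QuantileSummaries

// Build a Greenwald-Khanna summary over 100k doubles and query the approximate median.
val summary = (1 to 100000).map(_.toDouble)
  .foldLeft(new QuantileSummaries(
    QuantileSummaries.defaultCompressThreshold,
    QuantileSummaries.defaultRelativeError)) { (s, x) => s.insert(x) }
  .compress()                           // query() requires a compressed summary

val approxMedian = summary.query(0.5)   // accurate to within relativeError * count ranks
```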
## How was this patch tested?
This PR only relocates the class; the class implementation is not changed.
Author: Sean Zhong
Closes #14754 from clockfly/move_QuantileSummaries_to_catalyst.
---
.../sql/catalyst/util/QuantileSummaries.scala | 264 ++++++++++++++++++
.../util/QuantileSummariesSuite.scala} | 7 +-
.../sql/execution/stat/StatFunctions.scala | 247 +---------------
3 files changed, 267 insertions(+), 251 deletions(-)
create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
rename sql/{core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala => catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala} (96%)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
new file mode 100644
index 0000000000000..493b5faf9e50a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
@@ -0,0 +1,264 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats
+
+/**
+ * Helper class to compute approximate quantile summary.
+ * This implementation is based on the algorithm proposed in the paper:
+ * "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael
+ * and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670)
+ *
+ * In order to optimize for speed, it maintains an internal buffer of the last seen samples,
+ * and only inserts them after crossing a certain size threshold. This guarantees a near-constant
+ * runtime complexity compared to the original algorithm.
+ *
+ * @param compressThreshold the compression threshold.
+ * After the internal buffer of statistics crosses this size, it attempts to compress the
+ * statistics together.
+ * @param relativeError the target relative error.
+ * It is uniform across the complete range of values.
+ * @param sampled a buffer of quantile statistics.
+ * See the G-K article for more details.
+ * @param count the count of all the elements *inserted in the sampled buffer*
+ * (excluding the head buffer)
+ */
+class QuantileSummaries(
+ val compressThreshold: Int,
+ val relativeError: Double,
+ val sampled: Array[Stats] = Array.empty,
+ val count: Long = 0L) extends Serializable {
+
+ // a buffer of latest samples seen so far
+ private val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty
+
+ import QuantileSummaries._
+
+ /**
+ * Returns a summary with the given observation inserted into the summary.
+ * This method may either modify in place the current summary (and return the same summary,
+ * modified in place), or it may create a new summary from scratch if necessary.
+ * @param x the new observation to insert into the summary
+ */
+ def insert(x: Double): QuantileSummaries = {
+ headSampled.append(x)
+ if (headSampled.size >= defaultHeadSize) {
+ this.withHeadBufferInserted
+ } else {
+ this
+ }
+ }
+
+ /**
+ * Inserts an array of (unsorted samples) in a batch, sorting the array first to traverse
+ * the summary statistics in a single batch.
+ *
+ * This method does not modify the current object and returns if necessary a new copy.
+ *
+ * @return a new quantile summary object.
+ */
+ private def withHeadBufferInserted: QuantileSummaries = {
+ if (headSampled.isEmpty) {
+ return this
+ }
+ var currentCount = count
+ val sorted = headSampled.toArray.sorted
+ val newSamples: ArrayBuffer[Stats] = new ArrayBuffer[Stats]()
+ // The index of the next element to insert
+ var sampleIdx = 0
+ // The index of the sample currently being inserted.
+ var opsIdx: Int = 0
+ while(opsIdx < sorted.length) {
+ val currentSample = sorted(opsIdx)
+ // Add all the samples before the next observation.
+ while(sampleIdx < sampled.size && sampled(sampleIdx).value <= currentSample) {
+ newSamples.append(sampled(sampleIdx))
+ sampleIdx += 1
+ }
+
+ // If it is the first one to insert, or if it is the last one
+ currentCount += 1
+ val delta =
+ if (newSamples.isEmpty || (sampleIdx == sampled.size && opsIdx == sorted.length - 1)) {
+ 0
+ } else {
+ math.floor(2 * relativeError * currentCount).toInt
+ }
+
+ val tuple = Stats(currentSample, 1, delta)
+ newSamples.append(tuple)
+ opsIdx += 1
+ }
+
+ // Add all the remaining existing samples
+ while(sampleIdx < sampled.size) {
+ newSamples.append(sampled(sampleIdx))
+ sampleIdx += 1
+ }
+ new QuantileSummaries(compressThreshold, relativeError, newSamples.toArray, currentCount)
+ }
+
+ /**
+ * Returns a new summary that compresses the summary statistics and the head buffer.
+ *
+ * This implements the COMPRESS function of the GK algorithm. It does not modify the object.
+ *
+ * @return a new summary object with compressed statistics
+ */
+ def compress(): QuantileSummaries = {
+ // Inserts all the elements first
+ val inserted = this.withHeadBufferInserted
+ assert(inserted.headSampled.isEmpty)
+ assert(inserted.count == count + headSampled.size)
+ val compressed =
+ compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count)
+ new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count)
+ }
+
+ private def shallowCopy: QuantileSummaries = {
+ new QuantileSummaries(compressThreshold, relativeError, sampled, count)
+ }
+
+ /**
+ * Merges two (compressed) summaries together.
+ *
+ * Returns a new summary.
+ */
+ def merge(other: QuantileSummaries): QuantileSummaries = {
+ require(headSampled.isEmpty, "Current buffer needs to be compressed before merge")
+ require(other.headSampled.isEmpty, "Other buffer needs to be compressed before merge")
+ if (other.count == 0) {
+ this.shallowCopy
+ } else if (count == 0) {
+ other.shallowCopy
+ } else {
+ // Merge the two buffers.
+ // The GK algorithm is a bit unclear about it, but it seems there is no need to adjust the
+ // statistics during the merging: the invariants are still respected after the merge.
+ // TODO: could replace full sort by ordered merge, the two lists are known to be sorted
+ // already.
+ val res = (sampled ++ other.sampled).sortBy(_.value)
+ val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count)
+ new QuantileSummaries(
+ other.compressThreshold, other.relativeError, comp, other.count + count)
+ }
+ }
+
+ /**
+ * Runs a query for a given quantile.
+ * The result follows the approximation guarantees detailed above.
+ * The query can only be run on a compressed summary: you need to call compress() before using
+ * it.
+ *
+ * @param quantile the target quantile
+ * @return
+ */
+ def query(quantile: Double): Double = {
+ require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]")
+ require(headSampled.isEmpty,
+ "Cannot operate on an uncompressed summary, call compress() first")
+
+ if (quantile <= relativeError) {
+ return sampled.head.value
+ }
+
+ if (quantile >= 1 - relativeError) {
+ return sampled.last.value
+ }
+
+ // Target rank
+ val rank = math.ceil(quantile * count).toInt
+ val targetError = math.ceil(relativeError * count)
+ // Minimum rank at current sample
+ var minRank = 0
+ var i = 1
+ while (i < sampled.size - 1) {
+ val curSample = sampled(i)
+ minRank += curSample.g
+ val maxRank = minRank + curSample.delta
+ if (maxRank - targetError <= rank && rank <= minRank + targetError) {
+ return curSample.value
+ }
+ i += 1
+ }
+ sampled.last.value
+ }
+}
+
+object QuantileSummaries {
+ // TODO(tjhunter) more tuning could be done one the constants here, but for now
+ // the main cost of the algorithm is accessing the data in SQL.
+ /**
+ * The default value for the compression threshold.
+ */
+ val defaultCompressThreshold: Int = 10000
+
+ /**
+ * The size of the head buffer.
+ */
+ val defaultHeadSize: Int = 50000
+
+ /**
+ * The default value for the relative error (1%).
+ * With this value, the best extreme percentiles that can be approximated are 1% and 99%.
+ */
+ val defaultRelativeError: Double = 0.01
+
+ /**
+ * Statistics from the Greenwald-Khanna paper.
+ * @param value the sampled value
+ * @param g the minimum rank jump from the previous value's minimum rank
+ * @param delta the maximum span of the rank.
+ */
+ case class Stats(value: Double, g: Int, delta: Int)
+
+ private def compressImmut(
+ currentSamples: IndexedSeq[Stats],
+ mergeThreshold: Double): Array[Stats] = {
+ if (currentSamples.isEmpty) {
+ return Array.empty[Stats]
+ }
+ val res: ArrayBuffer[Stats] = ArrayBuffer.empty
+ // Start for the last element, which is always part of the set.
+ // The head contains the current new head, that may be merged with the current element.
+ var head = currentSamples.last
+ var i = currentSamples.size - 2
+ // Do not compress the last element
+ while (i >= 1) {
+ // The current sample:
+ val sample1 = currentSamples(i)
+ // Do we need to compress?
+ if (sample1.g + head.g + head.delta < mergeThreshold) {
+ // Do not insert yet, just merge the current element into the head.
+ head = head.copy(g = head.g + sample1.g)
+ } else {
+ // Prepend the current head, and keep the current sample as target for merging.
+ res.prepend(head)
+ head = sample1
+ }
+ i -= 1
+ }
+ res.prepend(head)
+ // If necessary, add the minimum element:
+ res.prepend(currentSamples.head)
+ res.toArray
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala
similarity index 96%
rename from sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala
index 0a989d026ce1c..89b2a22a3de45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala
@@ -15,15 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.stat
+package org.apache.spark.sql.catalyst.util
import scala.util.Random
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.execution.stat.StatFunctions.QuantileSummaries
-
-class ApproxQuantileSuite extends SparkFunSuite {
+class QuantileSummariesSuite extends SparkFunSuite {
private val r = new Random(1)
private val n = 100
@@ -125,5 +123,4 @@ class ApproxQuantileSuite extends SparkFunSuite {
checkQuantile(0.001, data, s)
}
}
-
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 7c58c4897fcd5..822f49ecab47b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -17,20 +17,17 @@
package org.apache.spark.sql.execution.stat
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.expressions.{Cast, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
object StatFunctions extends Logging {
- import QuantileSummaries.Stats
-
/**
* Calculates the approximate quantiles of multiple numerical columns of a DataFrame in one pass.
*
@@ -95,248 +92,6 @@ object StatFunctions extends Logging {
summaries.map { summary => probabilities.map(summary.query) }
}
- /**
- * Helper class to compute approximate quantile summary.
- * This implementation is based on the algorithm proposed in the paper:
- * "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael
- * and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670)
- *
- * In order to optimize for speed, it maintains an internal buffer of the last seen samples,
- * and only inserts them after crossing a certain size threshold. This guarantees a near-constant
- * runtime complexity compared to the original algorithm.
- *
- * @param compressThreshold the compression threshold.
- * After the internal buffer of statistics crosses this size, it attempts to compress the
- * statistics together.
- * @param relativeError the target relative error.
- * It is uniform across the complete range of values.
- * @param sampled a buffer of quantile statistics.
- * See the G-K article for more details.
- * @param count the count of all the elements *inserted in the sampled buffer*
- * (excluding the head buffer)
- */
- class QuantileSummaries(
- val compressThreshold: Int,
- val relativeError: Double,
- val sampled: Array[Stats] = Array.empty,
- val count: Long = 0L) extends Serializable {
-
- // a buffer of latest samples seen so far
- private val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty
-
- import QuantileSummaries._
-
- /**
- * Returns a summary with the given observation inserted into the summary.
- * This method may either modify in place the current summary (and return the same summary,
- * modified in place), or it may create a new summary from scratch it necessary.
- * @param x the new observation to insert into the summary
- */
- def insert(x: Double): QuantileSummaries = {
- headSampled.append(x)
- if (headSampled.size >= defaultHeadSize) {
- this.withHeadBufferInserted
- } else {
- this
- }
- }
-
- /**
- * Inserts an array of (unsorted samples) in a batch, sorting the array first to traverse
- * the summary statistics in a single batch.
- *
- * This method does not modify the current object and returns if necessary a new copy.
- *
- * @return a new quantile summary object.
- */
- private def withHeadBufferInserted: QuantileSummaries = {
- if (headSampled.isEmpty) {
- return this
- }
- var currentCount = count
- val sorted = headSampled.toArray.sorted
- val newSamples: ArrayBuffer[Stats] = new ArrayBuffer[Stats]()
- // The index of the next element to insert
- var sampleIdx = 0
- // The index of the sample currently being inserted.
- var opsIdx: Int = 0
- while(opsIdx < sorted.length) {
- val currentSample = sorted(opsIdx)
- // Add all the samples before the next observation.
- while(sampleIdx < sampled.size && sampled(sampleIdx).value <= currentSample) {
- newSamples.append(sampled(sampleIdx))
- sampleIdx += 1
- }
-
- // If it is the first one to insert, of if it is the last one
- currentCount += 1
- val delta =
- if (newSamples.isEmpty || (sampleIdx == sampled.size && opsIdx == sorted.length - 1)) {
- 0
- } else {
- math.floor(2 * relativeError * currentCount).toInt
- }
-
- val tuple = Stats(currentSample, 1, delta)
- newSamples.append(tuple)
- opsIdx += 1
- }
-
- // Add all the remaining existing samples
- while(sampleIdx < sampled.size) {
- newSamples.append(sampled(sampleIdx))
- sampleIdx += 1
- }
- new QuantileSummaries(compressThreshold, relativeError, newSamples.toArray, currentCount)
- }
-
- /**
- * Returns a new summary that compresses the summary statistics and the head buffer.
- *
- * This implements the COMPRESS function of the GK algorithm. It does not modify the object.
- *
- * @return a new summary object with compressed statistics
- */
- def compress(): QuantileSummaries = {
- // Inserts all the elements first
- val inserted = this.withHeadBufferInserted
- assert(inserted.headSampled.isEmpty)
- assert(inserted.count == count + headSampled.size)
- val compressed =
- compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count)
- new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count)
- }
-
- private def shallowCopy: QuantileSummaries = {
- new QuantileSummaries(compressThreshold, relativeError, sampled, count)
- }
-
- /**
- * Merges two (compressed) summaries together.
- *
- * Returns a new summary.
- */
- def merge(other: QuantileSummaries): QuantileSummaries = {
- require(headSampled.isEmpty, "Current buffer needs to be compressed before merge")
- require(other.headSampled.isEmpty, "Other buffer needs to be compressed before merge")
- if (other.count == 0) {
- this.shallowCopy
- } else if (count == 0) {
- other.shallowCopy
- } else {
- // Merge the two buffers.
- // The GK algorithm is a bit unclear about it, but it seems there is no need to adjust the
- // statistics during the merging: the invariants are still respected after the merge.
- // TODO: could replace full sort by ordered merge, the two lists are known to be sorted
- // already.
- val res = (sampled ++ other.sampled).sortBy(_.value)
- val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count)
- new QuantileSummaries(
- other.compressThreshold, other.relativeError, comp, other.count + count)
- }
- }
-
- /**
- * Runs a query for a given quantile.
- * The result follows the approximation guarantees detailed above.
- * The query can only be run on a compressed summary: you need to call compress() before using
- * it.
- *
- * @param quantile the target quantile
- * @return
- */
- def query(quantile: Double): Double = {
- require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]")
- require(headSampled.isEmpty,
- "Cannot operate on an uncompressed summary, call compress() first")
-
- if (quantile <= relativeError) {
- return sampled.head.value
- }
-
- if (quantile >= 1 - relativeError) {
- return sampled.last.value
- }
-
- // Target rank
- val rank = math.ceil(quantile * count).toInt
- val targetError = math.ceil(relativeError * count)
- // Minimum rank at current sample
- var minRank = 0
- var i = 1
- while (i < sampled.size - 1) {
- val curSample = sampled(i)
- minRank += curSample.g
- val maxRank = minRank + curSample.delta
- if (maxRank - targetError <= rank && rank <= minRank + targetError) {
- return curSample.value
- }
- i += 1
- }
- sampled.last.value
- }
- }
-
- object QuantileSummaries {
- // TODO(tjhunter) more tuning could be done one the constants here, but for now
- // the main cost of the algorithm is accessing the data in SQL.
- /**
- * The default value for the compression threshold.
- */
- val defaultCompressThreshold: Int = 10000
-
- /**
- * The size of the head buffer.
- */
- val defaultHeadSize: Int = 50000
-
- /**
- * The default value for the relative error (1%).
- * With this value, the best extreme percentiles that can be approximated are 1% and 99%.
- */
- val defaultRelativeError: Double = 0.01
-
- /**
- * Statistics from the Greenwald-Khanna paper.
- * @param value the sampled value
- * @param g the minimum rank jump from the previous value's minimum rank
- * @param delta the maximum span of the rank.
- */
- case class Stats(value: Double, g: Int, delta: Int)
-
- private def compressImmut(
- currentSamples: IndexedSeq[Stats],
- mergeThreshold: Double): Array[Stats] = {
- if (currentSamples.isEmpty) {
- return Array.empty[Stats]
- }
- val res: ArrayBuffer[Stats] = ArrayBuffer.empty
- // Start for the last element, which is always part of the set.
- // The head contains the current new head, that may be merged with the current element.
- var head = currentSamples.last
- var i = currentSamples.size - 2
- // Do not compress the last element
- while (i >= 1) {
- // The current sample:
- val sample1 = currentSamples(i)
- // Do we need to compress?
- if (sample1.g + head.g + head.delta < mergeThreshold) {
- // Do not insert yet, just merge the current element into the head.
- head = head.copy(g = head.g + sample1.g)
- } else {
- // Prepend the current head, and keep the current sample as target for merging.
- res.prepend(head)
- head = sample1
- }
- i -= 1
- }
- res.prepend(head)
- // If necessary, add the minimum element:
- res.prepend(currentSamples.head)
- res.toArray
- }
- }
-
/** Calculate the Pearson Correlation Coefficient for the given columns */
def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
val counts = collectStatisticalData(df, cols, "correlation")
From 9d376ad76ca702ae3fc6ffd0567e7590d9a8daf3 Mon Sep 17 00:00:00 2001
From: Jacek Laskowski
Date: Tue, 23 Aug 2016 12:59:25 +0200
Subject: [PATCH 067/270] [SPARK-17199] Use CatalystConf.resolver for
case-sensitivity comparison
## What changes were proposed in this pull request?
Use `CatalystConf.resolver` consistently for case-sensitivity comparisons (removing duplicated code).
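For context, a small sketch of the pattern being deduplicated; `Resolver` is just a `(String, String) => Boolean` name-equality function, and `findColumn` is a hypothetical helper for illustration:

```scala
import org.apache.spark.sql.catalyst.analysis.Resolver

// Call sites no longer branch on caseSensitiveAnalysis themselves; they simply take
// the Resolver that the configuration already exposes via conf.resolver.
def findColumn(columns: Seq[String], name: String, resolver: Resolver): Option[String] =
  columns.find(resolver(_, name))

// e.g. findColumn(schema.fieldNames, "ID", conf.resolver) matches "id" only when
// case-insensitive analysis is in effect.
```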
## How was this patch tested?
Local build. Waiting for Jenkins to ensure clean build and test.
Author: Jacek Laskowski
Closes #14771 from jaceklaskowski/17199-catalystconf-resolver.
---
.../apache/spark/sql/catalyst/analysis/Analyzer.scala | 8 +-------
.../spark/sql/execution/datasources/DataSource.scala | 10 ++--------
.../sql/execution/datasources/DataSourceStrategy.scala | 8 +-------
.../spark/sql/execution/streaming/FileStreamSink.scala | 6 +-----
4 files changed, 5 insertions(+), 27 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 41e0e6d65e9ad..e559f235c5a38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -64,13 +64,7 @@ class Analyzer(
this(catalog, conf, conf.optimizerMaxIterations)
}
- def resolver: Resolver = {
- if (conf.caseSensitiveAnalysis) {
- caseSensitiveResolution
- } else {
- caseInsensitiveResolution
- }
- }
+ def resolver: Resolver = conf.resolver
protected val fixedPoint = FixedPoint(maxIterations)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 5ad6ae0956e1c..b783d699745b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -394,13 +394,7 @@ case class DataSource(
sparkSession, globbedPaths, options, partitionSchema, !checkPathExist)
val dataSchema = userSpecifiedSchema.map { schema =>
- val equality =
- if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
- org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
- } else {
- org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
- }
-
+ val equality = sparkSession.sessionState.conf.resolver
StructType(schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))
}.orElse {
format.inferSchema(
@@ -430,7 +424,7 @@ case class DataSource(
relation
}
- /** Writes the give [[DataFrame]] out to this [[DataSource]]. */
+ /** Writes the given [[DataFrame]] out to this [[DataSource]]. */
def write(
mode: SaveMode,
data: DataFrame): BaseRelation = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 5eba7df060c4e..a6621054fc74b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -45,13 +45,7 @@ import org.apache.spark.unsafe.types.UTF8String
*/
case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
- def resolver: Resolver = {
- if (conf.caseSensitiveAnalysis) {
- caseSensitiveResolution
- } else {
- caseInsensitiveResolution
- }
- }
+ def resolver: Resolver = conf.resolver
// Visible for testing.
def convertStaticPartitions(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index 117d6672ee2f7..0f7d958136835 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -102,11 +102,7 @@ class FileStreamSinkWriter(
// Get the actual partition columns as attributes after matching them by name with
// the given columns names.
private val partitionColumns = partitionColumnNames.map { col =>
- val nameEquality = if (data.sparkSession.sessionState.conf.caseSensitiveAnalysis) {
- org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
- } else {
- org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
- }
+ val nameEquality = data.sparkSession.sessionState.conf.resolver
data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse {
throw new RuntimeException(s"Partition column $col not found in schema $dataSchema")
}
From 97d461b75badbfa323d7f1508b20600ea189bb95 Mon Sep 17 00:00:00 2001
From: Jagadeesan
Date: Tue, 23 Aug 2016 12:23:30 +0100
Subject: [PATCH 068/270] [SPARK-17095] [Documentation] [Latex and Scala doc do
not play nicely]
## What changes were proposed in this pull request?
In LaTeX, it is common to find "}}}" when closing several expressions at once. [SPARK-16822](https://issues.apache.org/jira/browse/SPARK-16822) added MathJax to render LaTeX equations in scaladoc. However, when scaladoc sees "}}}" or "{{{" it treats them as code-block delimiters, which results in some very strange output.
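As an illustration (not taken from the patch), the fix keeps LaTeX inside `$$ ... $$`, which MathJax picks up, and avoids scaladoc's `{{{ ... }}}` markers entirely:

```scala
// Wrapping an equation in scaladoc's {{{ ... }}} renders it as a code block, so MathJax
// never sees it; $$ ... $$ passes through to MathJax instead.

/**
 * The model prediction is
 *
 * $$
 * y = w^T x + b
 * $$
 */
def predictDoc(): Unit = ()   // hypothetical member carrying the doc comment
```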
Author: Jagadeesan
Closes #14688 from jagadeesanas2/SPARK-17095.
---
.../spark/ml/feature/PolynomialExpansion.scala | 8 +++++---
.../ml/regression/GeneralizedLinearRegression.scala | 8 +++++---
.../spark/ml/regression/LinearRegression.scala | 9 ++++++---
.../spark/mllib/clustering/StreamingKMeans.scala | 12 ++++++++----
4 files changed, 24 insertions(+), 13 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 6e872c1f2cada..25fb6be5afd81 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -76,9 +76,11 @@ class PolynomialExpansion @Since("1.4.0") (@Since("1.4.0") override val uid: Str
* (n + d choose d) (including 1 and first-order values). For example, let f([a, b, c], 3) be the
* function that expands [a, b, c] to their monomials of degree 3. We have the following recursion:
*
- * {{{
- * f([a, b, c], 3) = f([a, b], 3) ++ f([a, b], 2) * c ++ f([a, b], 1) * c^2 ++ [c^3]
- * }}}
+ *
+ * $$
+ * f([a, b, c], 3) &= f([a, b], 3) ++ f([a, b], 2) * c ++ f([a, b], 1) * c^2 ++ [c^3]
+ * $$
+ *
*
* To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the
* current index and increment it properly for sparse input.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 1d4dfd1147589..02b27fb650979 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -196,9 +196,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
/**
* Sets the regularization parameter for L2 regularization.
* The regularization term is
- * {{{
- * 0.5 * regParam * L2norm(coefficients)^2
- * }}}
+ *
+ * $$
+ * 0.5 * regParam * L2norm(coefficients)^2
+ * $$
+ *
* Default is 0.0.
*
* @group setParam
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index b1bb9b9fe0058..7fddfd9b10f84 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -338,9 +338,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
/*
Note that in Linear Regression, the objective history (loss + regularization) returned
from optimizer is computed in the scaled space given by the following formula.
- {{{
- L = 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2 + regTerms
- }}}
+
+ $$
+ L &= 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2
+ + regTerms \\
+ $$
+
*/
val arrayBuilder = mutable.ArrayBuilder.make[Double]
var state: optimizer.State = null
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 52bdccb919a61..f20ab09bf0b42 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -39,10 +39,14 @@ import org.apache.spark.util.random.XORShiftRandom
* generalized to incorporate forgetfullness (i.e. decay).
* The update rule (for each cluster) is:
*
- * {{{
- * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t]
- * n_t+t = n_t * a + m_t
- * }}}
+ *
+ * $$
+ * \begin{align}
+ * c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\
+ * n_t+t &= n_t * a + m_t
+ * \end{align}
+ * $$
+ *
*
* Where c_t is the previously estimated centroid for that cluster,
* n_t is the number of points assigned to it thus far, x_t is the centroid
From 9afdfc94f49395e69a7959e881c19d787ce00c3e Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 23 Aug 2016 09:45:13 -0700
Subject: [PATCH 069/270] [SPARK-13286] [SQL] add the next exception of
SQLException as cause
## What changes were proposed in this pull request?
Some JDBC drivers (for example, PostgreSQL) do not use the underlying exception as the cause, but expose another API (`getNextException`) to access it, so it is not included in the error logging, making it hard to find the root cause, especially in batch mode.
This PR pulls out the next exception and adds it as the cause (if it is different) or as a suppressed exception (if there is already a different cause).
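A minimal sketch of the attach-as-cause-or-suppressed pattern outside of `JdbcUtils` (the `enrich` helper name is hypothetical):

```scala
import java.sql.SQLException

// Surface the driver's chained "next" exception so it shows up in stack traces and logs.
def enrich(e: SQLException): SQLException = {
  val next = e.getNextException
  if (next != null && next != e.getCause) {
    if (e.getCause == null) e.initCause(next)   // no cause yet: make it the cause
    else e.addSuppressed(next)                  // cause already set: keep it as suppressed
  }
  e
}
```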
## How was this patch tested?
Can't reproduce this on the default JDBC driver, so did not add a regression test.
Author: Davies Liu
Closes #14722 from davies/keep_cause.
---
.../execution/datasources/jdbc/JdbcUtils.scala | 15 +++++++++++++--
1 file changed, 13 insertions(+), 2 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index a33c26d81354f..cbd504603bbf4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.jdbc
-import java.sql.{Connection, Driver, DriverManager, PreparedStatement}
+import java.sql.{Connection, Driver, DriverManager, PreparedStatement, SQLException}
import java.util.Properties
import scala.collection.JavaConverters._
@@ -289,7 +289,7 @@ object JdbcUtils extends Logging {
}
val stmt = insertStatement(conn, table, rddSchema, dialect)
val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
- .map(makeSetter(conn, dialect, _)).toArray
+ .map(makeSetter(conn, dialect, _)).toArray
try {
var rowCount = 0
@@ -322,6 +322,17 @@ object JdbcUtils extends Logging {
conn.commit()
}
committed = true
+ } catch {
+ case e: SQLException =>
+ val cause = e.getNextException
+ if (e.getCause != cause) {
+ if (e.getCause == null) {
+ e.initCause(cause)
+ } else {
+ e.addSuppressed(cause)
+ }
+ }
+ throw e
} finally {
if (!committed) {
// The stage must fail. We got here through an exception path, so
From 8fd63e808e15c8a7e78fef847183c86f332daa91 Mon Sep 17 00:00:00 2001
From: Junyang Qian
Date: Tue, 23 Aug 2016 11:22:32 -0700
Subject: [PATCH 070/270] [SPARKR][MINOR] Remove reference link for common
Windows environment variables
## What changes were proposed in this pull request?
This PR removes the reference link in the doc about environment variables for common Windows folders. The CRAN check got a 503 (Service Unavailable) error for the original link.
## How was this patch tested?
Manual check.
Author: Junyang Qian
Closes #14767 from junyangq/SPARKR-RemoveLink.
---
R/pkg/R/install.R | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R
index ff81e86835ff8..c6ed88e032a71 100644
--- a/R/pkg/R/install.R
+++ b/R/pkg/R/install.R
@@ -50,9 +50,7 @@
#' \itemize{
#' \item Mac OS X: \file{~/Library/Caches/spark}
#' \item Unix: \env{$XDG_CACHE_HOME} if defined, otherwise \file{~/.cache/spark}
-#' \item Windows: \file{\%LOCALAPPDATA\%\\spark\\spark\\Cache}. See
-#' \href{https://www.microsoft.com/security/portal/mmpc/shared/variables.aspx}{
-#' Windows Common Folder Variables} about \%LOCALAPPDATA\%
+#' \item Windows: \file{\%LOCALAPPDATA\%\\spark\\spark\\Cache}.
#' }
#' @param overwrite If \code{TRUE}, download and overwrite the existing tar file in localDir
#' and force re-install Spark (in case the local directory or file is corrupted)
From 588559911de94bbe0932526ee1e1dd36a581a423 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 23 Aug 2016 21:21:43 +0100
Subject: [PATCH 071/270] [MINOR][DOC] Use standard quotes instead of "curly
quote" marks from Mac in structured streaming programming guides
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## What changes were proposed in this pull request?
This PR replaces curly quotes (`“` and `”`) with standard quotes (`"`).
Curly quotes are an actual problem when users copy and paste the examples: the pasted code would not work.
This seems to happen only in `structured-streaming-programming-guide.md`.
## How was this patch tested?
Manually built.
This changes the affected examples to use standard quotes so that they render and copy correctly.


Author: hyukjinkwon
Closes #14770 from HyukjinKwon/minor-quotes.
---
.../structured-streaming-programming-guide.md | 38 +++++++++----------
1 file changed, 19 insertions(+), 19 deletions(-)
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index 226ff740a5d67..090b14f4ce2bc 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -88,7 +88,7 @@ val words = lines.as[String].flatMap(_.split(" "))
val wordCounts = words.groupBy("value").count()
{% endhighlight %}
-This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named “value”, and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as[String]`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
+This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as[String]`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
@@ -117,7 +117,7 @@ Dataset words = lines
Dataset wordCounts = words.groupBy("value").count();
{% endhighlight %}
-This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named “value”, and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as(Encoders.STRING())`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
+This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as(Encoders.STRING())`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
@@ -142,12 +142,12 @@ words = lines.select(
wordCounts = words.groupBy('word').count()
{% endhighlight %}
-This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named “value”, and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have used two built-in SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we use the function `alias` to name the new column as “word”. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
+This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have used two built-in SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we use the function `alias` to name the new column as "word". Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
-We have now set up the query on the streaming data. All that is left is to actually start receiving data and computing the counts. To do this, we set it up to print the complete set of counts (specified by `outputMode(“complete”)`) to the console every time they are updated. And then start the streaming computation using `start()`.
+We have now set up the query on the streaming data. All that is left is to actually start receiving data and computing the counts. To do this, we set it up to print the complete set of counts (specified by `outputMode("complete")`) to the console every time they are updated. And then start the streaming computation using `start()`.
@@ -361,16 +361,16 @@ table, and Spark runs it as an *incremental* query on the *unbounded* input
table. Let’s understand this model in more detail.
## Basic Concepts
-Consider the input data stream as the “Input Table”. Every data item that is
+Consider the input data stream as the "Input Table". Every data item that is
arriving on the stream is like a new row being appended to the Input Table.

-A query on the input will generate the “Result Table”. Every trigger interval (say, every 1 second), new rows get appended to the Input Table, which eventually updates the Result Table. Whenever the result table gets updated, we would want to write the changed result rows to an external sink.
+A query on the input will generate the "Result Table". Every trigger interval (say, every 1 second), new rows get appended to the Input Table, which eventually updates the Result Table. Whenever the result table gets updated, we would want to write the changed result rows to an external sink.

-The “Output” is defined as what gets written out to the external storage. The output can be defined in different modes
+The "Output" is defined as what gets written out to the external storage. The output can be defined in different modes
- *Complete Mode* - The entire updated Result Table will be written to the external storage. It is up to the storage connector to decide how to handle writing of the entire table.
@@ -386,7 +386,7 @@ the final `wordCounts` DataFrame is the result table. Note that the query on
streaming `lines` DataFrame to generate `wordCounts` is *exactly the same* as
it would be a static DataFrame. However, when this query is started, Spark
will continuously check for new data from the socket connection. If there is
-new data, Spark will run an “incremental” query that combines the previous
+new data, Spark will run an "incremental" query that combines the previous
running counts with the new data to compute updated counts, as shown below.

@@ -682,8 +682,8 @@ Streaming DataFrames can be joined with static DataFrames to create new streamin
val staticDf = spark.read. ...
val streamingDf = spark.readStream. ...
-streamingDf.join(staticDf, “type”) // inner equi-join with a static DF
-streamingDf.join(staticDf, “type”, “right_join”) // right outer join with a static DF
+streamingDf.join(staticDf, "type") // inner equi-join with a static DF
+streamingDf.join(staticDf, "type", "right_join") // right outer join with a static DF
{% endhighlight %}
@@ -789,7 +789,7 @@ Here is a table of all the sinks, and the corresponding settings.
File Sink (only parquet in Spark 2.0) |
Append |
- writeStream .format(“parquet”) .start() |
+ writeStream .format("parquet") .start() |
Yes |
Supports writes to partitioned tables. Partitioning by time may be useful. |
@@ -803,14 +803,14 @@ Here is a table of all the sinks, and the corresponding settings.
Console Sink |
Append, Complete |
- writeStream .format(“console”) .start() |
+ writeStream .format("console") .start() |
No |
|
Memory Sink |
Append, Complete |
- writeStream .format(“memory”) .queryName(“table”) .start() |
+ writeStream .format("memory") .queryName("table") .start() |
No |
Saves the output data as a table, for interactive querying. Table name is the query name. |
@@ -839,7 +839,7 @@ noAggDF
.start()
// ========== DF with aggregation ==========
-val aggDF = df.groupBy(“device”).count()
+val aggDF = df.groupBy("device").count()
// Print updated aggregations to console
aggDF
@@ -879,7 +879,7 @@ noAggDF
.start();
// ========== DF with aggregation ==========
-Dataset
aggDF = df.groupBy(“device”).count();
+Dataset aggDF = df.groupBy("device").count();
// Print updated aggregations to console
aggDF
@@ -919,7 +919,7 @@ noAggDF\
.start()
# ========== DF with aggregation ==========
-aggDF = df.groupBy(“device”).count()
+aggDF = df.groupBy("device").count()
# Print updated aggregations to console
aggDF\
@@ -1095,7 +1095,7 @@ In case of a failure or intentional shutdown, you can recover the previous progr
aggDF
.writeStream
.outputMode("complete")
- .option(“checkpointLocation”, “path/to/HDFS/dir”)
+ .option("checkpointLocation", "path/to/HDFS/dir")
.format("memory")
.start()
{% endhighlight %}
@@ -1107,7 +1107,7 @@ aggDF
aggDF
.writeStream()
.outputMode("complete")
- .option(“checkpointLocation”, “path/to/HDFS/dir”)
+ .option("checkpointLocation", "path/to/HDFS/dir")
.format("memory")
.start();
{% endhighlight %}
@@ -1119,7 +1119,7 @@ aggDF
aggDF\
.writeStream()\
.outputMode("complete")\
- .option(“checkpointLocation”, “path/to/HDFS/dir”)\
+ .option("checkpointLocation", "path/to/HDFS/dir")\
.format("memory")\
.start()
{% endhighlight %}
From 6555ef0ccbecd09c3071670e10f0c1e2d7713bfe Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Tue, 23 Aug 2016 21:25:04 +0100
Subject: [PATCH 072/270] [TRIVIAL] Typo Fix
## What changes were proposed in this pull request?
Fix a typo
## How was this patch tested?
no tests
Author: Zheng RuiFeng
Closes #14772 from zhengruifeng/minor_numClasses.
---
.../scala/org/apache/spark/ml/classification/Classifier.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 6decea72719fd..d1b21b16f2342 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -83,7 +83,7 @@ abstract class Classifier[
case Row(label: Double, features: Vector) =>
require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
- s" [0, 1, ..., $numClasses), where numClasses=$numClasses.")
+ s" [0, $numClasses).")
LabeledPoint(label, features)
}
}
From bf8ff833e30b39e5e5e35ba8dcac31b79323838c Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Tue, 23 Aug 2016 22:31:58 +0200
Subject: [PATCH 073/270] [SPARK-17194] Use single quotes when generating SQL
for string literals
When Spark emits SQL for a string literal, it should wrap the string in single quotes, not double quotes. Databases that adhere more strictly to the ANSI SQL standard, such as Postgres, allow only single quotes for denoting string literals (see http://stackoverflow.com/a/1992331/590203). A sketch of the quoting rule follows.
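As a hedged illustration (not the Spark source itself), the quoting rule adopted by this patch boils down to the following; the helper name `toSqlLiteral` is made up for the sketch.
```
// Sketch only: wrap string literals in single quotes, escaping backslashes and single quotes.
def toSqlLiteral(s: String): String =
  "'" + s.replace("\\", "\\\\").replace("'", "\\'") + "'"

// toSqlLiteral("foo")   => 'foo'
// toSqlLiteral("it's")  => 'it\'s'
// toSqlLiteral("a\\b")  => 'a\\b'   (the backslash is doubled inside the literal)
```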
Author: Josh Rosen
Closes #14763 from JoshRosen/SPARK-17194.
---
.../org/apache/spark/sql/catalyst/expressions/literals.scala | 4 ++--
.../src/test/resources/sqlgen/broadcast_join_subquery.sql | 2 +-
sql/hive/src/test/resources/sqlgen/case_with_key.sql | 2 +-
.../src/test/resources/sqlgen/case_with_key_and_else.sql | 2 +-
sql/hive/src/test/resources/sqlgen/inline_tables.sql | 2 +-
.../src/test/resources/sqlgen/json_tuple_generator_1.sql | 2 +-
.../src/test/resources/sqlgen/json_tuple_generator_2.sql | 2 +-
sql/hive/src/test/resources/sqlgen/not_like.sql | 2 +-
sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql | 2 +-
sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql | 2 +-
.../src/test/resources/sqlgen/subquery_exists_having_1.sql | 2 +-
.../src/test/resources/sqlgen/subquery_exists_having_2.sql | 2 +-
.../src/test/resources/sqlgen/subquery_exists_having_3.sql | 2 +-
sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql | 2 +-
sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql | 2 +-
sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql | 2 +-
sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql | 2 +-
.../test/resources/sqlgen/subquery_not_exists_having_1.sql | 2 +-
.../test/resources/sqlgen/subquery_not_exists_having_2.sql | 2 +-
.../spark/sql/catalyst/ExpressionSQLBuilderSuite.scala | 5 +++--
20 files changed, 23 insertions(+), 22 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 7040008769a32..55fd9c0834fcc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -245,8 +245,8 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression with
case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL"
case _ if value == null => s"CAST(NULL AS ${dataType.sql})"
case (v: UTF8String, StringType) =>
- // Escapes all backslashes and double quotes.
- "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\""
+ // Escapes all backslashes and single quotes.
+ "'" + v.toString.replace("\\", "\\\\").replace("'", "\\'") + "'"
case (v: Byte, ByteType) => v + "Y"
case (v: Short, ShortType) => v + "S"
case (v: Long, LongType) => v + "L"
diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql
index 3e2111d58a3c6..ec881a216e0b0 100644
--- a/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql
+++ b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql
@@ -5,4 +5,4 @@ FROM (SELECT x.key as key1, x.value as value1, y.key as key2, y.value as value2
JOIN srcpart z ON (subq.key1 = z.key and z.ds='2008-04-08' and z.hr=11)
ORDER BY subq.key1, z.value
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `key1`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_7` AS `gen_attr_6`, `gen_attr_9` AS `gen_attr_8`, `gen_attr_11` AS `gen_attr_10` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_7` FROM `default`.`src1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`src`) AS gen_subquery_1 ON (`gen_attr_5` = `gen_attr_9`)) AS subq INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1`, `ds` AS `gen_attr_3`, `hr` AS `gen_attr_4` FROM `default`.`srcpart`) AS gen_subquery_2 ON (((`gen_attr_0` = `gen_attr_2`) AND (`gen_attr_3` = "2008-04-08")) AND (CAST(`gen_attr_4` AS DOUBLE) = CAST(11 AS DOUBLE))) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_3
+SELECT `gen_attr_0` AS `key1`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_7` AS `gen_attr_6`, `gen_attr_9` AS `gen_attr_8`, `gen_attr_11` AS `gen_attr_10` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_7` FROM `default`.`src1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`src`) AS gen_subquery_1 ON (`gen_attr_5` = `gen_attr_9`)) AS subq INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1`, `ds` AS `gen_attr_3`, `hr` AS `gen_attr_4` FROM `default`.`srcpart`) AS gen_subquery_2 ON (((`gen_attr_0` = `gen_attr_2`) AND (`gen_attr_3` = '2008-04-08')) AND (CAST(`gen_attr_4` AS DOUBLE) = CAST(11 AS DOUBLE))) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_3
diff --git a/sql/hive/src/test/resources/sqlgen/case_with_key.sql b/sql/hive/src/test/resources/sqlgen/case_with_key.sql
index dff65f10835f3..e991ebafdc90e 100644
--- a/sql/hive/src/test/resources/sqlgen/case_with_key.sql
+++ b/sql/hive/src/test/resources/sqlgen/case_with_key.sql
@@ -1,4 +1,4 @@
-- This file is automatically generated by LogicalPlanToSQLSuite.
SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM parquet_t0
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `CASE WHEN (id = CAST(0 AS BIGINT)) THEN foo WHEN (id = CAST(1 AS BIGINT)) THEN bar END` FROM (SELECT CASE WHEN (`gen_attr_1` = CAST(0 AS BIGINT)) THEN "foo" WHEN (`gen_attr_1` = CAST(1 AS BIGINT)) THEN "bar" END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1
+SELECT `gen_attr_0` AS `CASE WHEN (id = CAST(0 AS BIGINT)) THEN foo WHEN (id = CAST(1 AS BIGINT)) THEN bar END` FROM (SELECT CASE WHEN (`gen_attr_1` = CAST(0 AS BIGINT)) THEN 'foo' WHEN (`gen_attr_1` = CAST(1 AS BIGINT)) THEN 'bar' END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1
diff --git a/sql/hive/src/test/resources/sqlgen/case_with_key_and_else.sql b/sql/hive/src/test/resources/sqlgen/case_with_key_and_else.sql
index af3e169b54315..492777e376ecc 100644
--- a/sql/hive/src/test/resources/sqlgen/case_with_key_and_else.sql
+++ b/sql/hive/src/test/resources/sqlgen/case_with_key_and_else.sql
@@ -1,4 +1,4 @@
-- This file is automatically generated by LogicalPlanToSQLSuite.
SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM parquet_t0
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `CASE WHEN (id = CAST(0 AS BIGINT)) THEN foo WHEN (id = CAST(1 AS BIGINT)) THEN bar ELSE baz END` FROM (SELECT CASE WHEN (`gen_attr_1` = CAST(0 AS BIGINT)) THEN "foo" WHEN (`gen_attr_1` = CAST(1 AS BIGINT)) THEN "bar" ELSE "baz" END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1
+SELECT `gen_attr_0` AS `CASE WHEN (id = CAST(0 AS BIGINT)) THEN foo WHEN (id = CAST(1 AS BIGINT)) THEN bar ELSE baz END` FROM (SELECT CASE WHEN (`gen_attr_1` = CAST(0 AS BIGINT)) THEN 'foo' WHEN (`gen_attr_1` = CAST(1 AS BIGINT)) THEN 'bar' ELSE 'baz' END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1
diff --git a/sql/hive/src/test/resources/sqlgen/inline_tables.sql b/sql/hive/src/test/resources/sqlgen/inline_tables.sql
index 602551e69da6e..18803a3ee59b9 100644
--- a/sql/hive/src/test/resources/sqlgen/inline_tables.sql
+++ b/sql/hive/src/test/resources/sqlgen/inline_tables.sql
@@ -1,4 +1,4 @@
-- This file is automatically generated by LogicalPlanToSQLSuite.
select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b) where b > 1
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (VALUES ("one", 1), ("two", 2), ("three", CAST(NULL AS INT)) AS gen_subquery_0(gen_attr_0, gen_attr_1)) AS data WHERE (`gen_attr_1` > 1)) AS data
+SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (VALUES ('one', 1), ('two', 2), ('three', CAST(NULL AS INT)) AS gen_subquery_0(gen_attr_0, gen_attr_1)) AS data WHERE (`gen_attr_1` > 1)) AS data
diff --git a/sql/hive/src/test/resources/sqlgen/json_tuple_generator_1.sql b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_1.sql
index 6f5562a20cccd..11e45a48f1b89 100644
--- a/sql/hive/src/test/resources/sqlgen/json_tuple_generator_1.sql
+++ b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_1.sql
@@ -3,4 +3,4 @@ SELECT c0, c1, c2
FROM parquet_t3
LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `c0`, `gen_attr_1` AS `c1`, `gen_attr_2` AS `c2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_5`, `json` AS `gen_attr_3`, `id` AS `gen_attr_6` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW json_tuple(`gen_attr_3`, "f1", "f2", "f3") gen_subquery_1 AS `gen_attr_0`, `gen_attr_1`, `gen_attr_2`) AS jt
+SELECT `gen_attr_0` AS `c0`, `gen_attr_1` AS `c1`, `gen_attr_2` AS `c2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_5`, `json` AS `gen_attr_3`, `id` AS `gen_attr_6` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW json_tuple(`gen_attr_3`, 'f1', 'f2', 'f3') gen_subquery_1 AS `gen_attr_0`, `gen_attr_1`, `gen_attr_2`) AS jt
diff --git a/sql/hive/src/test/resources/sqlgen/json_tuple_generator_2.sql b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_2.sql
index 0d4f67f18426b..d86b39df57442 100644
--- a/sql/hive/src/test/resources/sqlgen/json_tuple_generator_2.sql
+++ b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_2.sql
@@ -3,4 +3,4 @@ SELECT a, b, c
FROM parquet_t3
LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt AS a, b, c
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_2` AS `c` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_5`, `json` AS `gen_attr_3`, `id` AS `gen_attr_6` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW json_tuple(`gen_attr_3`, "f1", "f2", "f3") gen_subquery_1 AS `gen_attr_0`, `gen_attr_1`, `gen_attr_2`) AS jt
+SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_2` AS `c` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_5`, `json` AS `gen_attr_3`, `id` AS `gen_attr_6` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW json_tuple(`gen_attr_3`, 'f1', 'f2', 'f3') gen_subquery_1 AS `gen_attr_0`, `gen_attr_1`, `gen_attr_2`) AS jt
diff --git a/sql/hive/src/test/resources/sqlgen/not_like.sql b/sql/hive/src/test/resources/sqlgen/not_like.sql
index da39a62225a53..22485045e212e 100644
--- a/sql/hive/src/test/resources/sqlgen/not_like.sql
+++ b/sql/hive/src/test/resources/sqlgen/not_like.sql
@@ -1,4 +1,4 @@
-- This file is automatically generated by LogicalPlanToSQLSuite.
SELECT id FROM t0 WHERE id + 5 NOT LIKE '1%'
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`t0`) AS gen_subquery_0 WHERE (NOT CAST((`gen_attr_0` + CAST(5 AS BIGINT)) AS STRING) LIKE "1%")) AS t0
+SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`t0`) AS gen_subquery_0 WHERE (NOT CAST((`gen_attr_0` + CAST(5 AS BIGINT)) AS STRING) LIKE '1%')) AS t0
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql
index d598e4c036a29..bd28d8dca94c2 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql
@@ -5,4 +5,4 @@ where exists (select a.key
from src a
where b.value = a.value and a.key = b.key and a.value > 'val_9')
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > "val_9")) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3)) AS b
+SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_9')) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3)) AS b
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql
index a353c33af21a6..d2965fc0b9b77 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql
@@ -6,4 +6,4 @@ from (select *
from src a
where b.value = a.value and a.key = b.key and a.value > 'val_9')) a
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > "val_9")) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3)) AS a) AS a
+SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_9')) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3)) AS a) AS a
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_1.sql
index f6873d24e16ec..93ce902b75994 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_1.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_1.sql
@@ -6,4 +6,4 @@ having exists (select a.key
from src a
where a.key = b.key and a.value > 'val_9')
--------------------------------------------------------------------------------
-SELECT `gen_attr_1` AS `key`, `gen_attr_2` AS `count(1)` FROM (SELECT `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_1` HAVING EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_0` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_5` > "val_9")) AS gen_subquery_1 WHERE (`gen_attr_0` = `gen_attr_1`)) AS gen_subquery_3)) AS b
+SELECT `gen_attr_1` AS `key`, `gen_attr_2` AS `count(1)` FROM (SELECT `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_1` HAVING EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_0` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_5` > 'val_9')) AS gen_subquery_1 WHERE (`gen_attr_0` = `gen_attr_1`)) AS gen_subquery_3)) AS b
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_2.sql
index 8452ef946f61d..411e073f0d280 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_2.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_2.sql
@@ -7,4 +7,4 @@ from (select b.key, count(*)
from src a
where a.key = b.key and a.value > 'val_9')) a
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_2` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_5` > "val_9")) AS gen_subquery_1 WHERE (`gen_attr_2` = `gen_attr_0`)) AS gen_subquery_3)) AS a) AS a
+SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_2` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_5` > 'val_9')) AS gen_subquery_1 WHERE (`gen_attr_2` = `gen_attr_0`)) AS gen_subquery_3)) AS a) AS a
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_3.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_3.sql
index 2ef38ce42944f..b2ed0b0557aff 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_3.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_3.sql
@@ -6,4 +6,4 @@ having exists (select a.key
from src a
where a.value > 'val_9' and a.value = min(b.value))
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_4`) AS `gen_attr_1`, min(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING EXISTS(SELECT `gen_attr_5` AS `1` FROM (SELECT 1 AS `gen_attr_5` FROM (SELECT `gen_attr_6`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_2` > "val_9")) AS gen_subquery_2 WHERE (`gen_attr_2` = `gen_attr_3`)) AS gen_subquery_4)) AS gen_subquery_1) AS b
+SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_4`) AS `gen_attr_1`, min(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING EXISTS(SELECT `gen_attr_5` AS `1` FROM (SELECT 1 AS `gen_attr_5` FROM (SELECT `gen_attr_6`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_2` > 'val_9')) AS gen_subquery_2 WHERE (`gen_attr_2` = `gen_attr_3`)) AS gen_subquery_4)) AS gen_subquery_1) AS b
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql
index bfa58211b12f1..9894f5ab39c76 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql
@@ -5,4 +5,4 @@ group by key
having count(*) in (select count(*) from src s1 where s1.key = '90' group by s1.key)
order by key
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (`gen_attr_2` IN (SELECT `gen_attr_5` AS `_c0` FROM (SELECT `gen_attr_3` AS `gen_attr_5` FROM (SELECT count(1) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_7` FROM `default`.`src`) AS gen_subquery_3 WHERE (CAST(`gen_attr_6` AS DOUBLE) = CAST("90" AS DOUBLE)) GROUP BY `gen_attr_6`) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS src
+SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (`gen_attr_2` IN (SELECT `gen_attr_5` AS `_c0` FROM (SELECT `gen_attr_3` AS `gen_attr_5` FROM (SELECT count(1) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_7` FROM `default`.`src`) AS gen_subquery_3 WHERE (CAST(`gen_attr_6` AS DOUBLE) = CAST('90' AS DOUBLE)) GROUP BY `gen_attr_6`) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS src
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql
index f7503bce068f8..c3a122aa889b9 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql
@@ -7,4 +7,4 @@ having b.key in (select a.key
where a.value > 'val_9' and a.value = min(b.value))
order by b.key
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > "val_9")) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS b
+SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS b
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql
index 54a38ec0edb4c..eed20a5d311f3 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql
@@ -5,4 +5,4 @@ where not exists (select a.key
from src a
where b.value = a.value and a.key = b.key and a.value > 'val_2')
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE (NOT EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > "val_2")) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3))) AS b
+SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE (NOT EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_2')) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3))) AS b
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql
index c05bb5d991b4b..7040e106e7ba2 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql
@@ -5,4 +5,4 @@ where not exists (select a.key
from src a
where b.value = a.value and a.value > 'val_2')
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE (NOT EXISTS(SELECT `gen_attr_3` AS `1` FROM (SELECT 1 AS `gen_attr_3` FROM (SELECT `gen_attr_4`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_4`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > "val_2")) AS gen_subquery_1 WHERE (`gen_attr_1` = `gen_attr_2`)) AS gen_subquery_3))) AS b
+SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE (NOT EXISTS(SELECT `gen_attr_3` AS `1` FROM (SELECT 1 AS `gen_attr_3` FROM (SELECT `gen_attr_4`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_4`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_2')) AS gen_subquery_1 WHERE (`gen_attr_1` = `gen_attr_2`)) AS gen_subquery_3))) AS b
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_1.sql
index d6047c52f20fc..3c0e90ed42223 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_1.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_1.sql
@@ -6,4 +6,4 @@ having not exists (select a.key
from src a
where b.value = a.value and a.key = b.key and a.value > 'val_12')
--------------------------------------------------------------------------------
-SELECT `gen_attr_3` AS `key`, `gen_attr_0` AS `value` FROM (SELECT `gen_attr_3`, `gen_attr_0` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_0` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_3`, `gen_attr_0` HAVING (NOT EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_2`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_1` > "val_12")) AS gen_subquery_1 WHERE ((`gen_attr_0` = `gen_attr_1`) AND (`gen_attr_2` = `gen_attr_3`))) AS gen_subquery_3))) AS b
+SELECT `gen_attr_3` AS `key`, `gen_attr_0` AS `value` FROM (SELECT `gen_attr_3`, `gen_attr_0` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_0` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_3`, `gen_attr_0` HAVING (NOT EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_2`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_1` > 'val_12')) AS gen_subquery_1 WHERE ((`gen_attr_0` = `gen_attr_1`) AND (`gen_attr_2` = `gen_attr_3`))) AS gen_subquery_3))) AS b
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_2.sql
index 8b5402d8aa77f..0c16f9e58b9b9 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_2.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_2.sql
@@ -6,4 +6,4 @@ having not exists (select distinct a.key
from src a
where b.value = a.value and a.value > 'val_12')
--------------------------------------------------------------------------------
-SELECT `gen_attr_2` AS `key`, `gen_attr_0` AS `value` FROM (SELECT `gen_attr_2`, `gen_attr_0` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_0` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_2`, `gen_attr_0` HAVING (NOT EXISTS(SELECT `gen_attr_3` AS `1` FROM (SELECT 1 AS `gen_attr_3` FROM (SELECT DISTINCT `gen_attr_4`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_4`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_1` > "val_12")) AS gen_subquery_1 WHERE (`gen_attr_0` = `gen_attr_1`)) AS gen_subquery_3))) AS b
+SELECT `gen_attr_2` AS `key`, `gen_attr_0` AS `value` FROM (SELECT `gen_attr_2`, `gen_attr_0` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_0` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_2`, `gen_attr_0` HAVING (NOT EXISTS(SELECT `gen_attr_3` AS `1` FROM (SELECT 1 AS `gen_attr_3` FROM (SELECT DISTINCT `gen_attr_4`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_4`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_1` > 'val_12')) AS gen_subquery_1 WHERE (`gen_attr_0` = `gen_attr_1`)) AS gen_subquery_3))) AS b
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
index 7249df813b17f..93dc0f493eb7b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
@@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.{If, Literal, SpecifiedWindowFr
class ExpressionSQLBuilderSuite extends SQLBuilderTest {
test("literal") {
- checkSQL(Literal("foo"), "\"foo\"")
- checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"")
+ checkSQL(Literal("foo"), "'foo'")
+ checkSQL(Literal("\"foo\""), "'\"foo\"'")
+ checkSQL(Literal("'foo'"), "'\\'foo\\''")
checkSQL(Literal(1: Byte), "1Y")
checkSQL(Literal(2: Short), "2S")
checkSQL(Literal(4: Int), "4")
From c1937dd19a23bd096a4707656c7ba19fb5c16966 Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Tue, 23 Aug 2016 18:48:08 -0700
Subject: [PATCH 074/270] [SPARK-16862] Configurable buffer size in
`UnsafeSorterSpillReader`
## What changes were proposed in this pull request?
Jira: https://issues.apache.org/jira/browse/SPARK-16862
The `BufferedInputStream` used in `UnsafeSorterSpillReader` reads data off disk with the default 8k buffer. This PR makes the buffer size configurable to improve disk reads. The default value is set to 1 MB, since that value showed improved performance in testing (a usage sketch follows at the end of this message).
## How was this patch tested?
I am relying on the existing unit tests.
## Performance
After deploying this change to production and setting the config to 1 MB, there was a 12% reduction in CPU time and a 19.5% reduction in CPU reservation time.
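A hedged usage sketch of the new knob: the config key comes from the diff below, while the `SparkConf`/`SparkSession` setup is just illustrative.
```
// Sketch only: values below 1 MB or above 16 MB cause the reader to log a warning
// and fall back to the 1 MB default (see UnsafeSorterSpillReader in the diff below).
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

val conf = new SparkConf()
  .set("spark.unsafe.sorter.spill.reader.buffer.size", "1m")

val spark = SparkSession.builder().config(conf).appName("spill-buffer-sketch").getOrCreate()
```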
Author: Tejas Patil
Closes #14726 from tejasapatil/spill_buffer_2.
---
.../unsafe/sort/UnsafeSorterSpillReader.java | 22 ++++++++++++++++++-
1 file changed, 21 insertions(+), 1 deletion(-)
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 1d588c37c5db0..d048cf7aeb5f1 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -22,15 +22,21 @@
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;
+import org.apache.spark.SparkEnv;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockId;
import org.apache.spark.unsafe.Platform;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
* of the file format).
*/
public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
+ private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class);
+ private static final int DEFAULT_BUFFER_SIZE_BYTES = 1024 * 1024; // 1 MB
+ private static final int MAX_BUFFER_SIZE_BYTES = 16777216; // 16 mb
private InputStream in;
private DataInputStream din;
@@ -50,7 +56,21 @@ public UnsafeSorterSpillReader(
File file,
BlockId blockId) throws IOException {
assert (file.length() > 0);
- final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
+ long bufferSizeBytes =
+ SparkEnv.get() == null ?
+ DEFAULT_BUFFER_SIZE_BYTES:
+ SparkEnv.get().conf().getSizeAsBytes("spark.unsafe.sorter.spill.reader.buffer.size",
+ DEFAULT_BUFFER_SIZE_BYTES);
+ if (bufferSizeBytes > MAX_BUFFER_SIZE_BYTES || bufferSizeBytes < DEFAULT_BUFFER_SIZE_BYTES) {
+ // fall back to a sane default value
+ logger.warn("Value of config \"spark.unsafe.sorter.spill.reader.buffer.size\" = {} not in " +
+ "allowed range [{}, {}). Falling back to default value : {} bytes", bufferSizeBytes,
+ DEFAULT_BUFFER_SIZE_BYTES, MAX_BUFFER_SIZE_BYTES, DEFAULT_BUFFER_SIZE_BYTES);
+ bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES;
+ }
+
+ final BufferedInputStream bs =
+ new BufferedInputStream(new FileInputStream(file), (int) bufferSizeBytes);
try {
this.in = serializerManager.wrapForCompression(blockId, bs);
this.din = new DataInputStream(this.in);
From b9994ad05628077016331e6b411fbc09017b1e63 Mon Sep 17 00:00:00 2001
From: Weiqing Yang
Date: Tue, 23 Aug 2016 23:44:45 -0700
Subject: [PATCH 075/270] [MINOR][SQL] Remove implemented functions from
comments of 'HiveSessionCatalog.scala'
## What changes were proposed in this pull request?
This PR removes functions that are now implemented natively from the comments of `HiveSessionCatalog.scala`: `java_method`, `posexplode`, and `str_to_map`. A quick usage sketch follows.
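As a hedged sanity check of why these entries can be dropped: all three functions now resolve to native Spark SQL expressions instead of being delegated to Hive. The snippet assumes an active `SparkSession` named `spark`.
```
// Illustrative only: each call resolves to a built-in Spark expression, not a Hive UDF.
spark.sql("SELECT java_method('java.lang.Math', 'max', 2, 3)").show()
spark.sql("SELECT posexplode(array(10, 20))").show()
spark.sql("SELECT str_to_map('a:1,b:2', ',', ':')").show()
```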
## How was this patch tested?
Manual.
Author: Weiqing Yang
Closes #14769 from Sherry302/cleanComment.
---
.../org/apache/spark/sql/hive/HiveSessionCatalog.scala | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index ebed9eb6e7dca..ca8c7347f23e9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -230,10 +230,8 @@ private[sql] class HiveSessionCatalog(
// List of functions we are explicitly not supporting are:
// compute_stats, context_ngrams, create_union,
// current_user, ewah_bitmap, ewah_bitmap_and, ewah_bitmap_empty, ewah_bitmap_or, field,
- // in_file, index, java_method,
- // matchpath, ngrams, noop, noopstreaming, noopwithmap, noopwithmapstreaming,
- // parse_url_tuple, posexplode, reflect2,
- // str_to_map, windowingtablefunction.
+ // in_file, index, matchpath, ngrams, noop, noopstreaming, noopwithmap,
+ // noopwithmapstreaming, parse_url_tuple, reflect2, windowingtablefunction.
private val hiveFunctions = Seq(
"hash",
"histogram_numeric",
From 52fa45d62a5a0bc832442f38f9e634c5d8e29e08 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 23 Aug 2016 23:46:09 -0700
Subject: [PATCH 076/270] [SPARK-17186][SQL] remove catalog table type INDEX
## What changes were proposed in this pull request?
Spark SQL does not actually support indexes; the catalog table type `INDEX` comes from Hive. Moreover, most operations in Spark SQL cannot handle index tables, e.g. create table, alter table, etc.
Logically, index tables should be invisible to end users, and Hive generates special table names for index tables to keep users from accessing them directly. Hive also has dedicated SQL syntax to create/show/drop index tables.
On the Spark SQL side, an index table can be described directly, but the result is unreadable, so the dedicated SQL syntax (e.g. `SHOW INDEX ON tbl`) should be used instead. Spark SQL can also read an index table directly, but the result is always empty. (Can Hive read an index table directly?)
This PR removes the table type `INDEX` to make it clear that Spark SQL does not currently support indexes.
## How was this patch tested?
existing tests.
Author: Wenchen Fan
Closes #14752 from cloud-fan/minor2.
---
.../org/apache/spark/sql/catalyst/catalog/interface.scala | 1 -
.../org/apache/spark/sql/execution/command/tables.scala | 8 +++-----
.../org/apache/spark/sql/hive/MetastoreRelation.scala | 1 -
.../org/apache/spark/sql/hive/client/HiveClientImpl.scala | 4 ++--
.../spark/sql/hive/execution/HiveCommandSuite.scala | 2 +-
5 files changed, 6 insertions(+), 10 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index f7762e0f8acd3..83e01f95c06af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -200,7 +200,6 @@ case class CatalogTableType private(name: String)
object CatalogTableType {
val EXTERNAL = new CatalogTableType("EXTERNAL")
val MANAGED = new CatalogTableType("MANAGED")
- val INDEX = new CatalogTableType("INDEX")
val VIEW = new CatalogTableType("VIEW")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 21544a37d9975..b4a15b8b2882e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -620,12 +620,11 @@ case class ShowPartitionsCommand(
* Validate and throws an [[AnalysisException]] exception under the following conditions:
* 1. If the table is not partitioned.
* 2. If it is a datasource table.
- * 3. If it is a view or index table.
+ * 3. If it is a view.
*/
- if (tab.tableType == VIEW ||
- tab.tableType == INDEX) {
+ if (tab.tableType == VIEW) {
throw new AnalysisException(
- s"SHOW PARTITIONS is not allowed on a view or index table: ${tab.qualifiedName}")
+ s"SHOW PARTITIONS is not allowed on a view: ${tab.qualifiedName}")
}
if (tab.partitionColumnNames.isEmpty) {
@@ -708,7 +707,6 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman
case EXTERNAL => " EXTERNAL TABLE"
case VIEW => " VIEW"
case MANAGED => " TABLE"
- case INDEX => reportUnsupportedError(Seq("index table"))
}
builder ++= s"CREATE$tableTypeString ${table.quotedString}"
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala
index 195fce8354134..d62bc983d0279 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala
@@ -80,7 +80,6 @@ private[hive] case class MetastoreRelation(
tTable.setTableType(catalogTable.tableType match {
case CatalogTableType.EXTERNAL => HiveTableType.EXTERNAL_TABLE.toString
case CatalogTableType.MANAGED => HiveTableType.MANAGED_TABLE.toString
- case CatalogTableType.INDEX => HiveTableType.INDEX_TABLE.toString
case CatalogTableType.VIEW => HiveTableType.VIRTUAL_VIEW.toString
})
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index 9b7afd462841c..81d5a124e9d4a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -379,8 +379,9 @@ private[hive] class HiveClientImpl(
tableType = h.getTableType match {
case HiveTableType.EXTERNAL_TABLE => CatalogTableType.EXTERNAL
case HiveTableType.MANAGED_TABLE => CatalogTableType.MANAGED
- case HiveTableType.INDEX_TABLE => CatalogTableType.INDEX
case HiveTableType.VIRTUAL_VIEW => CatalogTableType.VIEW
+ case HiveTableType.INDEX_TABLE =>
+ throw new AnalysisException("Hive index table is not supported.")
},
schema = schema,
partitionColumnNames = partCols.map(_.name),
@@ -757,7 +758,6 @@ private[hive] class HiveClientImpl(
HiveTableType.EXTERNAL_TABLE
case CatalogTableType.MANAGED =>
HiveTableType.MANAGED_TABLE
- case CatalogTableType.INDEX => HiveTableType.INDEX_TABLE
case CatalogTableType.VIEW => HiveTableType.VIRTUAL_VIEW
})
// Note: In Hive the schema and partition columns must be disjoint sets
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
index 76aa84b19410d..df33731df2d00 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
@@ -424,7 +424,7 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
val message4 = intercept[AnalysisException] {
sql("SHOW PARTITIONS parquet_view1")
}.getMessage
- assert(message4.contains("is not allowed on a view or index table"))
+ assert(message4.contains("is not allowed on a view"))
}
}
From 673a80d2230602c9e6573a23e35fb0f6b832bfca Mon Sep 17 00:00:00 2001
From: Weiqing Yang
Date: Wed, 24 Aug 2016 10:12:44 +0100
Subject: [PATCH 077/270] [MINOR][BUILD] Fix Java CheckStyle Error
## What changes were proposed in this pull request?
Since Spark 2.0.1 will be released soon (as mentioned on the Spark dev mailing list), it is better to fix these code style errors before the release, in addition to the critical bugs.
Before:
```
./dev/lint-java
Checkstyle checks failed at following occurrences:
[ERROR] src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java:[525] (sizes) LineLength: Line is longer than 100 characters (found 119).
[ERROR] src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java:[64] (sizes) LineLength: Line is longer than 100 characters (found 103).
```
After:
```
./dev/lint-java
Using `mvn` from path: /usr/local/bin/mvn
Checkstyle checks passed.
```
## How was this patch tested?
Manual.
Author: Weiqing Yang
Closes #14768 from Sherry302/fixjavastyle.
---
.../collection/unsafe/sort/UnsafeExternalSorter.java | 3 ++-
.../sql/streaming/JavaStructuredNetworkWordCount.java | 11 ++++++-----
2 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index ccf76643db2b4..196e67d8b29b6 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -522,7 +522,8 @@ public long spill() throws IOException {
// is accessing the current record. We free this page in that caller's next loadNext()
// call.
for (MemoryBlock page : allocatedPages) {
- if (!loaded || page.pageNumber != ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) {
+ if (!loaded || page.pageNumber !=
+ ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) {
released += page.size();
freePage(page);
} else {
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java
index c913ee0658504..5f342e1ead6ca 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java
@@ -61,11 +61,12 @@ public static void main(String[] args) throws Exception {
.load();
// Split the lines into words
- Dataset words = lines.as(Encoders.STRING()).flatMap(new FlatMapFunction() {
- @Override
- public Iterator call(String x) {
- return Arrays.asList(x.split(" ")).iterator();
- }
+ Dataset words = lines.as(Encoders.STRING())
+ .flatMap(new FlatMapFunction() {
+ @Override
+ public Iterator call(String x) {
+ return Arrays.asList(x.split(" ")).iterator();
+ }
}, Encoders.STRING());
// Generate running word count
From 92c0eaf348b42b3479610da0be761013f9d81c54 Mon Sep 17 00:00:00 2001
From: VinceShieh
Date: Wed, 24 Aug 2016 10:16:58 +0100
Subject: [PATCH 078/270] [SPARK-17086][ML] Fix InvalidArgumentException issue
in QuantileDiscretizer when some quantiles are duplicated
## What changes were proposed in this pull request?
In cases where `QuantileDiscretizer` is fit on a numeric column with duplicated elements, we now take the unique elements produced by `approxQuantile` as the splits passed to `Bucketizer`.
## How was this patch tested?
A unit test is added in QuantileDiscretizerSuite.
Previously, QuantileDiscretizer.fit would fail when calling setSplits on a list of splits with duplicated elements, because Bucketizer.setSplits only accepts a numeric vector of two or more unique cut points. Deduplicating the splits may therefore produce fewer buckets than requested; see the sketch below.
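A minimal sketch, mirroring the new unit test in the diff below, of the behavior after this change; it assumes an active `SparkSession` with `spark.implicits._` imported.
```
// Duplicated values can make the computed quantiles collide; after this patch the
// duplicated splits are dropped instead of failing Bucketizer.setSplits, so the
// result may simply contain fewer buckets than requested.
import org.apache.spark.ml.feature.QuantileDiscretizer

val df = Seq(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0)
  .map(Tuple1.apply).toDF("input")

val discretizer = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(5)

val bucketed = discretizer.fit(df).transform(df)
bucketed.select("result").distinct().count()   // at most 5 distinct buckets observed
```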
Signed-off-by: VinceShieh
Author: VinceShieh
Closes #14747 from VinceShieh/SPARK-17086.
---
.../ml/feature/QuantileDiscretizer.scala | 7 ++++++-
.../ml/feature/QuantileDiscretizerSuite.scala | 19 +++++++++++++++++++
2 files changed, 25 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 558a7bbf0a2df..e09800877c694 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -114,7 +114,12 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
splits(0) = Double.NegativeInfinity
splits(splits.length - 1) = Double.PositiveInfinity
- val bucketizer = new Bucketizer(uid).setSplits(splits)
+ val distinctSplits = splits.distinct
+ if (splits.length != distinctSplits.length) {
+ log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
+ s" buckets as a result.")
+ }
+ val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted)
copyValues(bucketizer.setParent(this))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index b73dbd62328cf..18f1e89ee8148 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -52,6 +52,25 @@ class QuantileDiscretizerSuite
"Bucket sizes are not within expected relative error tolerance.")
}
+ test("Test Bucketizer on duplicated splits") {
+ val spark = this.spark
+ import spark.implicits._
+
+ val datasetSize = 12
+ val numBuckets = 5
+ val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0))
+ .map(Tuple1.apply).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(numBuckets)
+ val result = discretizer.fit(df).transform(df)
+
+ val observedNumBuckets = result.select("result").distinct.count
+ assert(2 <= observedNumBuckets && observedNumBuckets <= numBuckets,
+ "Observed number of buckets are not within expected range.")
+ }
+
test("Test transform method on unseen data") {
val spark = this.spark
import spark.implicits._
From 45b786aca2b5818dc233643e6b3a53b869560563 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Wed, 24 Aug 2016 08:24:16 -0700
Subject: [PATCH 079/270] [MINOR][DOC] Fix wrong ml.feature.Normalizer
document.
## What changes were proposed in this pull request?
The ```ml.feature.Normalizer``` examples illustrate the L1 norm rather than L2, so the corresponding documentation should be corrected; a sketch of the corrected example follows.
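For context, a hedged Scala sketch of the example the corrected sentence describes: normalizing each row to unit $L^1$ norm via `setP(1.0)`, and to unit $L^\infty$ norm via `setP(Double.PositiveInfinity)`. The `dataFrame` variable, holding a `Vector` column named "features", is assumed.
```
// Sketch only: L^1 and L^infinity normalization of a Vector column.
import org.apache.spark.ml.feature.Normalizer

val l1NormData = new Normalizer()
  .setInputCol("features")
  .setOutputCol("normFeatures")
  .setP(1.0)
  .transform(dataFrame)

val lInfNormData = new Normalizer()
  .setInputCol("features")
  .setOutputCol("normFeatures")
  .setP(Double.PositiveInfinity)
  .transform(dataFrame)
```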
## How was this patch tested?
Doc change, no test.
Author: Yanbo Liang
Closes #14787 from yanboliang/normalizer.
---
docs/ml-features.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/ml-features.md b/docs/ml-features.md
index 6020114845486..e41bf78521b6e 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -734,7 +734,7 @@ for more details on the API.
`Normalizer` is a `Transformer` which transforms a dataset of `Vector` rows, normalizing each `Vector` to have unit norm. It takes parameter `p`, which specifies the [p-norm](http://en.wikipedia.org/wiki/Norm_%28mathematics%29#p-norm) used for normalization. ($p = 2$ by default.) This normalization can help standardize your input data and improve the behavior of learning algorithms.
-The following example demonstrates how to load a dataset in libsvm format and then normalize each row to have unit $L^2$ norm and unit $L^\infty$ norm.
+The following example demonstrates how to load a dataset in libsvm format and then normalize each row to have unit $L^1$ norm and unit $L^\infty$ norm.
From d2932a0e987132c694ed59515b7c77adaad052e6 Mon Sep 17 00:00:00 2001
From: Junyang Qian
Date: Wed, 24 Aug 2016 10:40:09 -0700
Subject: [PATCH 080/270] [SPARKR][MINOR] Fix doc for show method
## What changes were proposed in this pull request?
The original doc of `show` grouped methods for multiple classes together, but the text only talked about `SparkDataFrame`. This PR fixes that.
## How was this patch tested?
Manual test.
Author: Junyang Qian
Closes #14776 from junyangq/SPARK-FixShowDoc.
---
R/pkg/R/DataFrame.R | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 52a6628ad7b32..e12b58e2eefc5 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -212,9 +212,9 @@ setMethod("showDF",
#' show
#'
-#' Print the SparkDataFrame column names and types
+#' Print class and type information of a Spark object.
#'
-#' @param object a SparkDataFrame.
+#' @param object a Spark object. Can be a SparkDataFrame, Column, GroupedData, WindowSpec.
#'
#' @family SparkDataFrame functions
#' @rdname show
From 2fbdb606392631b1dff88ec86f388cc2559c28f5 Mon Sep 17 00:00:00 2001
From: Xin Ren
Date: Wed, 24 Aug 2016 11:18:10 -0700
Subject: [PATCH 081/270] [SPARK-16445][MLLIB][SPARKR] Multilayer Perceptron
Classifier wrapper in SparkR
https://issues.apache.org/jira/browse/SPARK-16445
## What changes were proposed in this pull request?
Create a Multilayer Perceptron Classifier wrapper (`spark.mlp`) in SparkR; a sketch of the underlying Scala estimator follows.
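For orientation, a hedged sketch of the Scala estimator the new wrapper delegates to; the parameter values mirror the example in the R documentation below, and the `train` DataFrame (with "features" and "label" columns) is assumed.
```
// Illustrative only: the JVM-side estimator behind spark.mlp.
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

val mlp = new MultilayerPerceptronClassifier()
  .setLayers(Array(4, 5, 4, 3))   // input layer, two hidden layers, output layer
  .setBlockSize(128)
  .setSolver("l-bfgs")
  .setMaxIter(100)
  .setTol(0.5)
  .setStepSize(1.0)
  .setSeed(1L)

val model = mlp.fit(train)        // `train` is an assumed DataFrame of (features, label)
```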
## How was this patch tested?
Tested manually on local machine
Author: Xin Ren
Closes #14447 from keypointt/SPARK-16445.
---
R/pkg/NAMESPACE | 1 +
R/pkg/R/generics.R | 4 +
R/pkg/R/mllib.R | 125 +++++++++++++++-
R/pkg/inst/tests/testthat/test_mllib.R | 32 +++++
...ultilayerPerceptronClassifierWrapper.scala | 134 ++++++++++++++++++
.../org/apache/spark/ml/r/RWrappers.scala | 2 +
6 files changed, 293 insertions(+), 5 deletions(-)
create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 709057675e578..ad587a6b7d03a 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -27,6 +27,7 @@ exportMethods("glm",
"summary",
"spark.kmeans",
"fitted",
+ "spark.mlp",
"spark.naiveBayes",
"spark.survreg",
"spark.lda",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 88884e62575df..7e626be50808d 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1330,6 +1330,10 @@ setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark
#' @export
setGeneric("fitted")
+#' @rdname spark.mlp
+#' @export
+setGeneric("spark.mlp", function(data, ...) { standardGeneric("spark.mlp") })
+
#' @rdname spark.naiveBayes
#' @export
setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index a40310d194d27..a670600ca6938 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -60,6 +60,13 @@ setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))
#' @note KMeansModel since 2.0.0
setClass("KMeansModel", representation(jobj = "jobj"))
+#' S4 class that represents a MultilayerPerceptronClassificationModel
+#'
+#' @param jobj a Java object reference to the backing Scala MultilayerPerceptronClassifierWrapper
+#' @export
+#' @note MultilayerPerceptronClassificationModel since 2.1.0
+setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"))
+
#' S4 class that represents an IsotonicRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala IsotonicRegressionModel
@@ -90,7 +97,7 @@ setClass("ALSModel", representation(jobj = "jobj"))
#' @export
#' @seealso \link{spark.glm}, \link{glm},
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
-#' @seealso \link{spark.lda}, \link{spark.naiveBayes}, \link{spark.survreg},
+#' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{read.ml}
NULL
@@ -103,7 +110,7 @@ NULL
#' @export
#' @seealso \link{spark.glm}, \link{glm},
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
-#' @seealso \link{spark.naiveBayes}, \link{spark.survreg},
+#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
NULL
write_internal <- function(object, path, overwrite = FALSE) {
@@ -631,6 +638,95 @@ setMethod("predict", signature(object = "KMeansModel"),
predict_internal(object, newData)
})
+#' Multilayer Perceptron Classification Model
+#'
+#' \code{spark.mlp} fits a multi-layer perceptron neural network model against a SparkDataFrame.
+#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
+#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
+#' Only categorical data is supported.
+#' For more details, see
+#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{
+#' Multilayer Perceptron}
+#'
+#' @param data a \code{SparkDataFrame} of observations and labels for model fitting.
+#' @param blockSize blockSize parameter.
+#' @param layers integer vector containing the number of nodes for each layer
+#' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "l-bfgs".
+#' @param maxIter maximum iteration number.
+#' @param tol convergence tolerance of iterations.
+#' @param stepSize stepSize parameter.
+#' @param seed seed parameter for weights initialization.
+#' @param ... additional arguments passed to the method.
+#' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model.
+#' @rdname spark.mlp
+#' @aliases spark.mlp,SparkDataFrame-method
+#' @name spark.mlp
+#' @seealso \link{read.ml}
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm")
+#'
+#' # fit a Multilayer Perceptron Classification Model
+#' model <- spark.mlp(df, blockSize = 128, layers = c(4, 5, 4, 3), solver = "l-bfgs",
+#' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1)
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.mlp since 2.1.0
+setMethod("spark.mlp", signature(data = "SparkDataFrame"),
+ function(data, blockSize = 128, layers = c(3, 5, 2), solver = "l-bfgs", maxIter = 100,
+ tol = 0.5, stepSize = 1, seed = 1) {
+ jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper",
+ "fit", data@sdf, as.integer(blockSize), as.array(layers),
+ as.character(solver), as.integer(maxIter), as.numeric(tol),
+ as.numeric(stepSize), as.integer(seed))
+ new("MultilayerPerceptronClassificationModel", jobj = jobj)
+ })
+
+# Makes predictions from a model produced by spark.mlp().
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#' "prediction".
+#' @rdname spark.mlp
+#' @aliases predict,MultilayerPerceptronClassificationModel-method
+#' @export
+#' @note predict(MultilayerPerceptronClassificationModel) since 2.1.0
+setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel"),
+ function(object, newData) {
+ predict_internal(object, newData)
+ })
+
+# Returns the summary of a Multilayer Perceptron Classification Model produced by \code{spark.mlp}
+
+#' @param object a Multilayer Perceptron Classification Model fitted by \code{spark.mlp}.
+#' @return \code{summary} returns a list containing \code{labelCount}, the number of distinct labels,
+#' \code{layers}, the integer vector of layer sizes, and \code{weights}, the fitted weights.
+#' @rdname spark.mlp
+#' @export
+#' @aliases summary,MultilayerPerceptronClassificationModel-method
+#' @note summary(MultilayerPerceptronClassificationModel) since 2.1.0
+setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"),
+ function(object) {
+ jobj <- object@jobj
+ labelCount <- callJMethod(jobj, "labelCount")
+ layers <- unlist(callJMethod(jobj, "layers"))
+ weights <- callJMethod(jobj, "weights")
+ weights <- matrix(weights, nrow = length(weights))
+ list(labelCount = labelCount, layers = layers, weights = weights)
+ })
+
#' Naive Bayes Models
#'
#' \code{spark.naiveBayes} fits a Bernoulli naive Bayes model against a SparkDataFrame.
@@ -685,7 +781,7 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form
#'
#' @rdname spark.naiveBayes
#' @export
-#' @seealso \link{read.ml}
+#' @seealso \link{write.ml}
#' @note write.ml(NaiveBayesModel, character) since 2.0.0
setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
function(object, path, overwrite = FALSE) {
@@ -700,7 +796,7 @@ setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
#' @rdname spark.survreg
#' @export
#' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0
-#' @seealso \link{read.ml}
+#' @seealso \link{write.ml}
setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
@@ -734,6 +830,23 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"),
write_internal(object, path, overwrite)
})
+# Saves the Multilayer Perceptron Classification Model to the input path.
+
+#' @param path the directory where the model is saved.
+#' @param overwrite whether to overwrite the output path if it already exists. Default is FALSE,
+#' which means an error is thrown if the output path exists.
+#'
+#' @rdname spark.mlp
+#' @aliases write.ml,MultilayerPerceptronClassificationModel,character-method
+#' @export
+#' @seealso \link{write.ml}
+#' @note write.ml(MultilayerPerceptronClassificationModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationModel",
+ path = "character"),
+ function(object, path, overwrite = FALSE) {
+ write_internal(object, path, overwrite)
+ })
+
# Save fitted IsotonicRegressionModel to the input path
#' @param path The directory where the model is saved
@@ -791,6 +904,8 @@ read.ml <- function(path) {
new("KMeansModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) {
new("LDAModel", jobj = jobj)
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper")) {
+ new("MultilayerPerceptronClassificationModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
new("IsotonicRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
@@ -798,7 +913,7 @@ read.ml <- function(path) {
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
new("ALSModel", jobj = jobj)
} else {
- stop(paste("Unsupported model: ", jobj))
+ stop("Unsupported model: ", jobj)
}
}
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index de9bd48662c3a..1e6da650d1bb8 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -347,6 +347,38 @@ test_that("spark.kmeans", {
unlink(modelPath)
})
+test_that("spark.mlp", {
+ df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm")
+ model <- spark.mlp(df, blockSize = 128, layers = c(4, 5, 4, 3), solver = "l-bfgs", maxIter = 100,
+ tol = 0.5, stepSize = 1, seed = 1)
+
+ # Test summary method
+ summary <- summary(model)
+ expect_equal(summary$labelCount, 3)
+ expect_equal(summary$layers, c(4, 5, 4, 3))
+ expect_equal(length(summary$weights), 64)
+
+ # Test predict method
+ mlpTestDF <- df
+ mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
+ expect_equal(head(mlpPredictions$prediction, 6), c(0, 1, 1, 1, 1, 1))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ summary2 <- summary(model2)
+
+ expect_equal(summary2$labelCount, 3)
+ expect_equal(summary2$layers, c(4, 5, 4, 3))
+ expect_equal(length(summary2$weights), 64)
+
+ unlink(modelPath)
+
+})
+
test_that("spark.naiveBayes", {
# R code to reproduce the result.
# We do not support instance weights yet. So we ignore the frequencies.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
new file mode 100644
index 0000000000000..be51e74187faa
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
+import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class MultilayerPerceptronClassifierWrapper private (
+ val pipeline: PipelineModel,
+ val labelCount: Long,
+ val layers: Array[Int],
+ val weights: Array[Double]
+ ) extends MLWritable {
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset)
+ }
+
+ /**
+ * Returns an [[MLWriter]] instance for this ML instance.
+ */
+ override def write: MLWriter =
+ new MultilayerPerceptronClassifierWrapper.MultilayerPerceptronClassifierWrapperWriter(this)
+}
+
+private[r] object MultilayerPerceptronClassifierWrapper
+ extends MLReadable[MultilayerPerceptronClassifierWrapper] {
+
+ val PREDICTED_LABEL_COL = "prediction"
+
+ def fit(
+ data: DataFrame,
+ blockSize: Int,
+ layers: Array[Double],
+ solver: String,
+ maxIter: Int,
+ tol: Double,
+ stepSize: Double,
+ seed: Int
+ ): MultilayerPerceptronClassifierWrapper = {
+ // get labels and feature names from output schema
+ val schema = data.schema
+
+ // assemble and fit the pipeline
+ val mlp = new MultilayerPerceptronClassifier()
+ .setLayers(layers.map(_.toInt))
+ .setBlockSize(blockSize)
+ .setSolver(solver)
+ .setMaxIter(maxIter)
+ .setTol(tol)
+ .setStepSize(stepSize)
+ .setSeed(seed)
+ .setPredictionCol(PREDICTED_LABEL_COL)
+ val pipeline = new Pipeline()
+ .setStages(Array(mlp))
+ .fit(data)
+
+ val multilayerPerceptronClassificationModel: MultilayerPerceptronClassificationModel =
+ pipeline.stages.head.asInstanceOf[MultilayerPerceptronClassificationModel]
+
+ val weights = multilayerPerceptronClassificationModel.weights.toArray
+ val layersFromPipeline = multilayerPerceptronClassificationModel.layers
+ val labelCount = data.select("label").distinct().count()
+
+ new MultilayerPerceptronClassifierWrapper(pipeline, labelCount, layersFromPipeline, weights)
+ }
+
+ /**
+ * Returns an [[MLReader]] instance for this class.
+ */
+ override def read: MLReader[MultilayerPerceptronClassifierWrapper] =
+ new MultilayerPerceptronClassifierWrapperReader
+
+ override def load(path: String): MultilayerPerceptronClassifierWrapper = super.load(path)
+
+ class MultilayerPerceptronClassifierWrapperReader
+ extends MLReader[MultilayerPerceptronClassifierWrapper]{
+
+ override def load(path: String): MultilayerPerceptronClassifierWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val labelCount = (rMetadata \ "labelCount").extract[Long]
+ val layers = (rMetadata \ "layers").extract[Array[Int]]
+ val weights = (rMetadata \ "weights").extract[Array[Double]]
+
+ val pipeline = PipelineModel.load(pipelinePath)
+ new MultilayerPerceptronClassifierWrapper(pipeline, labelCount, layers, weights)
+ }
+ }
+
+ class MultilayerPerceptronClassifierWrapperWriter(instance: MultilayerPerceptronClassifierWrapper)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("labelCount" -> instance.labelCount) ~
+ ("layers" -> instance.layers.toSeq) ~
+ ("weights" -> instance.weights.toArray.toSeq)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 51a65f7fc4fe8..d64de1b6abb63 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -44,6 +44,8 @@ private[r] object RWrappers extends MLReader[Object] {
GeneralizedLinearRegressionWrapper.load(path)
case "org.apache.spark.ml.r.KMeansWrapper" =>
KMeansWrapper.load(path)
+ case "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper" =>
+ MultilayerPerceptronClassifierWrapper.load(path)
case "org.apache.spark.ml.r.LDAWrapper" =>
LDAWrapper.load(path)
case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
From 0b3a4be92ca6b38eef32ea5ca240d9f91f68aa65 Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Wed, 24 Aug 2016 20:04:09 +0100
Subject: [PATCH 082/270] [SPARK-16781][PYSPARK] Java launched by PySpark as
gateway may not be the same Java used in the Spark environment
## What changes were proposed in this pull request?
Update to py4j 0.10.3 to enable JAVA_HOME support
## How was this patch tested?
PySpark tests
Author: Sean Owen
Closes #14748 from srowen/SPARK-16781.
---
LICENSE | 2 +-
bin/pyspark | 2 +-
bin/pyspark2.cmd | 2 +-
core/pom.xml | 2 +-
.../apache/spark/api/python/PythonUtils.scala | 2 +-
dev/deps/spark-deps-hadoop-2.2 | 2 +-
dev/deps/spark-deps-hadoop-2.3 | 2 +-
dev/deps/spark-deps-hadoop-2.4 | 2 +-
dev/deps/spark-deps-hadoop-2.6 | 2 +-
dev/deps/spark-deps-hadoop-2.7 | 2 +-
python/docs/Makefile | 2 +-
python/lib/py4j-0.10.1-src.zip | Bin 61356 -> 0 bytes
python/lib/py4j-0.10.3-src.zip | Bin 0 -> 91275 bytes
sbin/spark-config.sh | 2 +-
.../org/apache/spark/deploy/yarn/Client.scala | 6 +++---
.../spark/deploy/yarn/YarnClusterSuite.scala | 2 +-
16 files changed, 16 insertions(+), 16 deletions(-)
delete mode 100644 python/lib/py4j-0.10.1-src.zip
create mode 100644 python/lib/py4j-0.10.3-src.zip
diff --git a/LICENSE b/LICENSE
index 94fd46f568473..d68609cc28733 100644
--- a/LICENSE
+++ b/LICENSE
@@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf)
(The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net)
(The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net)
- (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.1 - http://py4j.sourceforge.net/)
+ (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.3 - http://py4j.sourceforge.net/)
(Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/)
(BSD licence) sbt and sbt-launch-lib.bash
(BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE)
diff --git a/bin/pyspark b/bin/pyspark
index a0d7e22e8ad82..7590309b442ed 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -57,7 +57,7 @@ export PYSPARK_PYTHON
# Add the PySpark classes to the Python path:
export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
-export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.1-src.zip:$PYTHONPATH"
+export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.3-src.zip:$PYTHONPATH"
# Load the PySpark shell.py script when ./pyspark is used interactively:
export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index 3e2ff100fb8af..1217a4f2f97a2 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
)
set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH%
-set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.1-src.zip;%PYTHONPATH%
+set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.3-src.zip;%PYTHONPATH%
set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py
diff --git a/core/pom.xml b/core/pom.xml
index 04b94a258c71c..ab6c3ce805275 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -326,7 +326,7 @@
net.sf.py4j
py4j
- 0.10.1
+ 0.10.3
org.apache.spark
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index 64cf4981714c0..701097ace8974 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -32,7 +32,7 @@ private[spark] object PythonUtils {
val pythonPath = new ArrayBuffer[String]
for (sparkHome <- sys.env.get("SPARK_HOME")) {
pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator)
- pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.1-src.zip").mkString(File.separator)
+ pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.3-src.zip").mkString(File.separator)
}
pythonPath ++= SparkContext.jarOfObject(this)
pythonPath.mkString(File.pathSeparator)
diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2
index e2433bd71822e..326271a7e2b23 100644
--- a/dev/deps/spark-deps-hadoop-2.2
+++ b/dev/deps/spark-deps-hadoop-2.2
@@ -139,7 +139,7 @@ parquet-jackson-1.8.1.jar
pmml-model-1.2.15.jar
pmml-schema-1.2.15.jar
protobuf-java-2.5.0.jar
-py4j-0.10.1.jar
+py4j-0.10.3.jar
pyrolite-4.9.jar
scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3
index 51eaec5e6ae53..1ff6ecb7342bb 100644
--- a/dev/deps/spark-deps-hadoop-2.3
+++ b/dev/deps/spark-deps-hadoop-2.3
@@ -146,7 +146,7 @@ parquet-jackson-1.8.1.jar
pmml-model-1.2.15.jar
pmml-schema-1.2.15.jar
protobuf-java-2.5.0.jar
-py4j-0.10.1.jar
+py4j-0.10.3.jar
pyrolite-4.9.jar
scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4
index 43c85fabfd481..68333849cf4c9 100644
--- a/dev/deps/spark-deps-hadoop-2.4
+++ b/dev/deps/spark-deps-hadoop-2.4
@@ -146,7 +146,7 @@ parquet-jackson-1.8.1.jar
pmml-model-1.2.15.jar
pmml-schema-1.2.15.jar
protobuf-java-2.5.0.jar
-py4j-0.10.1.jar
+py4j-0.10.3.jar
pyrolite-4.9.jar
scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 93f68f3f9e3fe..787d06c3512db 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -154,7 +154,7 @@ parquet-jackson-1.8.1.jar
pmml-model-1.2.15.jar
pmml-schema-1.2.15.jar
protobuf-java-2.5.0.jar
-py4j-0.10.1.jar
+py4j-0.10.3.jar
pyrolite-4.9.jar
scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index 9740fc8d59698..386495bf1bbb1 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -155,7 +155,7 @@ parquet-jackson-1.8.1.jar
pmml-model-1.2.15.jar
pmml-schema-1.2.15.jar
protobuf-java-2.5.0.jar
-py4j-0.10.1.jar
+py4j-0.10.3.jar
pyrolite-4.9.jar
scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
diff --git a/python/docs/Makefile b/python/docs/Makefile
index 12e397e4507c5..de86e97d862f0 100644
--- a/python/docs/Makefile
+++ b/python/docs/Makefile
@@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build
PAPER ?=
BUILDDIR ?= _build
-export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.1-src.zip)
+export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.3-src.zip)
# User-friendly check for sphinx-build
ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
diff --git a/python/lib/py4j-0.10.1-src.zip b/python/lib/py4j-0.10.1-src.zip
deleted file mode 100644
index a54bcae03afb823da3b2b69814811d192db44630..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001
literal 61356
zcmb5VV~}XgvMt)SZQHhO+qUi1wr$&XueNR5-K%ZCwf8;e$M@pC*!#_hnlb-W&WNnc
zF)Bw^DM$l@Kmq*gK~}Dk`1i;Eyg&hP031A7tmsu$AOV0^teDkB%b3+&J)i*qLC$~y
z0RH___&Xit?*s?{amO5qn}#$mRagK3ATa;{jQ^c(W$0$8Z)|U4V`}VTX>aFD=ivEo
zi>=ed>i-)5FN;??w)UGGNZ;4`{m#7-+N%TPoAXuZKSQZ<
zZJSrwri-dX8gxfY(xyD2H%&C)2aj4upbYh^pP7Rt+?E$$;571sub?ytt`oS7->Y)1{vg?)^8gvqu
ze+{{59`UlBAocpk7BJ6wLSJRbCBCvMToYZ!^m@jCk@{sGNmVSdo`~^OjTA@;SyCTYacj=b2?i1S3Yf*8
z{>-78VE7oXg@}zwM)l63q`@TRKmst83`CgQ|4U{|4rHHjk}OGe257$b;-$L#!KOWP
z3LL=2ASR#y7O*!RohudxBgDwY>M$tZNX=O@6O{vfxmqQ;L{y6Wveq5-`(^I|i32FE
zfl$b6Y0jF^nMjVcd0020&KDOKy{!r>87R3xiltIQc=D|jRF5bkxf`N1MxQ4ekOERi
z;HfQ&ZS0Mp(;@1;|m)YSRo*1
zqDSPxy`TiUb3zdK+RZIk-e|kqPtwePo35+6=>_A8AQr_aOqa;POP4T0f?Y^lf8+-_
zV3Y}U9IPdR5T3n!Kv>TR
z7Q8IGs$0bYAh;LB{14BYI^bu|pE^7I`!P^@pd&cGkoK;@v`dD^Gsg8#5I7-~M$33C
zs_M$MuTRY4AJxcfA(tn9Ge;~0?K5&9i@8wH$PUfPgJ3#y)qy(iYN_yvXHiVipr)vK
z{CNk40bOK^AorU=I>JY-XTq2S0?d_d8t=B3+^;D=6CywbqKWz{Qvi1sAi*;6q5xE2
zv~{V(hV3>nNWL7lLXc+ns^)n9+wO!+FRMqZH&&){I`l9Df$jBh+;R(r;TTIAa-Q9>
zGQ#|cimWWme$S;NyIEOplxoY3-mMzg?9JwYwsJwYs#=rLwf>WbvR00U8lE)W8_m1ud2N$rz6
zBDw=>iuU~S2~hPDreb!@?+BP#<@ymXN`jQtmG}3#Gg|-Q#2K=Migs6KY`(Ad#9_*`
z;eBejs%_o&`cHr>o9NDfu2DZe5x{y-0!`yrp*CQWzT&f&v2Pz5$^Ze*8YGFJQ&+)~
zg&G@d0-1tsNG+j#U3Jx?wZ8oKQGrRu(#~nJV3T}Cd8uL9U}YW_>PdF466sL2JXHYN
zl>4uu$;l%A!b5g*bMez&(Nl4A;O9?ydF!-meSe6-wqCp9HcY!@#Vvtg5JHTm5A%(h
zio4Hxw?L$`^Cee5yJWn+IkAwaqRE7kql#|C4%T4&G+#;s2wIxCpVht!FCS!z4BO%C
zydQC<2y!!Skjp{%a}F!^LwI@lSnxPwKM*KXxAZ(J-Y<>qBZ7lZ?oM;z9N;MK)OtY_
z9DglU^`prNunOB@h@kg(b%O*d8xhbJipkr2R=iJx1L8-<(n_&E`;NM`;P5-0mcsdn
zHt<;fOJF8uEdUOKIS8RB>hVr57Hxqbo2g_iviQ>gTzk*BqPMWK6onMkW`vLx===SZ
zZDM06ysqCMU9HTD_jJd)1~#Ac^&);4#t=yuzbdI7RKRfrDQRppXf(X2ABJUFTduYh
z(thzm;Q3iY)FN6(vBvY6a;#mnn@Xz)Kx(M6kkOKqA+o_}^O6xmSrF6-gos0AcQ6+o
zR3zwFCq||{{*f5r4?lR%Z*51_`HanmqPIeMJ--Kn=}<9;Q5sRV%%Bn$6YfxH%RtTt
z?J$T4Y#LUB?lKVH@8|
z3qBZGaAz0b!p-LABH-oSbD&yMzEysJ)RE4)(3^Z~b9~_lC%(Rr#PJ5@&Tl_nE50UX
zXl(#fO`{)rvisQ^eIe|MFieKj6=&{#)2o#=lrCg;m4A7EVOHgUYstsBiBU@&fG5KU
zYB8|M4%3MlM=_jLsV7iJokw;}3QB6x5(MnYoVrUy9o=8FtiF!SuN=5qOg@*AuS_&&
zW}sIz_w+$6;~_LOX}AByD8N4d!j<_kK`swZ*R3lnv-+~&eY&vugsBIxPI3u-k~xub
zY>;)o>E%|-UNrK_lQm7U9euSFg`P2Fd`C4%C*$)u0a`
zzqFbPls}c^bGk~WRb%JA489T^!x|`Ea%0qUx$>T`g(f3ESqWcu^?ne4^-+<{WCvqy
zyvw=+kyr08B6=h`tjEv33GIj`t2WcIpRWrE<1mNk80>1Ai;pD>14cXxEF@VoAT0b~
zcjH?$5Za`xi?mc+hgDE{%U&amU&nVmNovOPk^@0ik-TQ3P5OIZZb0m=wohYQZZrPQ
za>bgX6jtYXvLI!r^HUq=WR-|6k4k((5N@)i@#nW|dg_yf`hhNRPjSGr_A&bvc2)?x
zs+D$H*Xw%W(Xt8lgx*#2bP2s*EA_9thg+odcKiaE5&5P7OS0X`{Sl|vZ}I0c^1*qu
z62^pe;E$Wrr`_30{L}7lYnsSu_+VeiLG~r2*LAwCQF+?BP|)IFyEFs-u>`l~va_L$
zAoB8a-Ps(3m^3Z9GBp9UDMDUf91|WtXFng74N$E4{&EH3UV0H^e$(Pf_Kfv7PSFaB
z?hWfChy&H}5JTJe>qzbDMBUHp=jIttqP?tCFkv1mp5iMiDRl!b+^-Q1NJ~W@$(aa8`|@y&gMDs?2~crWvI(}
zMC9_*aC*7!O*-(iC+qBK1^XGttkli|YcY}=lJ@3BE1-7Tm?0_BG0y)n4EaCi_(uermmx-I5o9Ru4uqib&&!9h>+L`%$EY`=xF*
zYc|@nA&!p~aOcoA7GRJH21*-h9TR&Wwt2#q7PjG^guSid9b0(8-vShyCQg
zL9fS@;wZsDe@K)tcvo4}3~a?tWSuXEYda)wLeXHd)8?je8fL|LMbCP(ua&F=wbt6}
z>iAi_&bGlGMYTUQ0=tMP6oqKMfCIxF4reHYsIbM%ecmXl8n0!?hKip}lB%xO4PW+0
zMq40}*P2}lKIjJNB?8H6HFBJEj2(T{mB0*x_63Ky&xX>HQ7ONyn)%}0xNLvf5!?zO
z^);o;J#3N*;gci8Y5!yfR5Gh1I=HzY#(cM9JZcP}A5z-?IuXkl>`vSIwvcPz?n}Pj
zayQYU=Z3GrOYOQ@yRa?K0WV|kxaMTBW95eTVZzeEi`miH@pkz*rDLnO{q@f>?ZSAo
zs|_9iz=#n50O5bjGzTYp7yG|u+J7j|TClADmS^`}exULn!#nLl3t(N)yKaD^irVhD
zo95iU`PyU6GMiRU4382QwC}t2p~VzRiAi>j=qR!xGf~Wgo|3J$+LOG*VJzdw{E(8#
zHj(iOu$y3vlM*v8ffQpw+#bbuq4W(?Wnlz62}H)&jeJ^7hsV~ULk<3B@rbYDk~6=L
zE8XhGC
zm_j`8kGzoKOpA0;d%STH*YHdPUn^9zg58O%=~haNiD+{SsQ
z!}EqT>kTQxafgIpStdIW)DTW-Vk{+AXr5mXW%-M!8hxE++3syP$&ngm_)QSB{V_~Q
z))Q%7vmjl;5S=9|>SO^I$g+1+-#hZfB-4-3h(zh3-M53M)gk{A@c4zGKlyGCRH)HK|Yk%lmoqxX#I%#gj2JI`zhI7H`<$EAZ1IswU#Y
zsUD}t5jLINHNPJujoifunBAQbXG+KZ)rw(sx%WG8I4>Og^z4nI{8%pa;bq4L>cC^M
zYW3RTJf;?rt*vI?g()Cp-sNEWWHEzV|2Mm#&n8g9R$I`^&OlDDtV
z)^*f8H^^((?6K$9$w;T0K8)R{H2in!+Bx~AMh=Xfj!syB{m)N!M-a
zjkyp0+#bWVbbP4}l#Mk{PPfXe@n&bH+~~DIVd_esjh8p8+wuh1I=nc0P=i+I2=kOU
zR>@TB9?SJs4h;3wQY+X4Cr-BxfFb&|67RKO_I#liob9CyC{0@;7d|<@pTwq-9ANu7el6
zN_)?)>b-qOqdAjG;H(2M4Q+kA_5A?E^y85&`dGpcuSj+)2hWuO
zZB;Y%9Ds?I1c!)Qdu;jjKv-Q#Y6ztj&@=TmzMB9Ywh5aYV@^Tv?nT2cGpbG-^Cxsee?m6FlBozvTGUD;!*XE2Yr4w=
zJNhwpfeLZ?$mS
z32#}XIo0dHh1SVkfRFVW?e`PR)KyP4(|Sm{)_BU;08;8AUN;D0K2$#%`1Q(nGRiNq
zGoOoIDZ(0}T^rEJt@^#41yC|4tjLuh*^dpKYAsDeFhoCNeLMg91aE%c70y&Jrd*T*{@-nD7q1&_^w8E4T
zJ+%#pkVez|mg`^)OjhgxJ&DD4@#@OnzNdz5FdV;(tn0p2eUi#^5yGlrPxd)aFwT=n
zOm`J^tX9`SUBiWi9ua)gmxH)c-Z)P@)nF?DZ6a;0+}%)UXuP={tp!=7ZexR4XG)jw
zT3-tVd1{PMK+Uz?D>S`RSyf1?uVmYdZ!;9K-fCWGf-MU7IuEh+?$@t-`o<0
z#}DlWJ-K9bN!L+6=0nD%ehi5)p&>W~@tSHbB))qG-$PDND(1V15d8?nZOkVST*T=N
z;EPuZ1~u@6`sOyBs~K4nF6$ttQNdl492VqWqkO8@58QpXqEb`V_>4+ICYij_uo-nz
zuWYqMrqqhElP-{5PDOcit;j-`gv#NMfg|MTBcsgksCy)TqA^Wc856+fiYEy!%46MJ
z>;j-TfC1Y$Vd1Q2;W;oCdC;LunkA3akO^h;aoZ0X8UE~N$wcYapN^zT!ay7RS^$Gz
zld(3-k)P{!O9oYVXd(UF2SdIDC+7|p+?mHdVE(hzxm_Eu2nUW<)2AQ|o((?E+)`*E
zzrBDBQ#M%)4Wd0~fpAz2>_ed}4B1ihL8EVA2rN89zyh;lhN^Hi^asvPb)vi%oHgj&
zCzCPRcVra(NHZtmDAwePXqWG8`dX?L(^{g~e-b%zJcc|F{bv_hP$EAP6#;c1t^fwN
zc_|Vs+BueBXCnoGtK@lKqJ$=Cj@SiHIL`7jA0sH@y3uNPNj?Q>)TdojA=*6KVv;O!4$tdHVQk+xL*eLY4r9
z>9_ePM7%Ub@ErFLzRZ#FgPh#XjVugqMnId)_z}
zF}yL}L=_$bWy`Etg|VEDltF2I&j7Tb5W5Qb{>*W7i635B%v_((KM4EOKq=zr{}8d~
zihF&woQ@|WAq#?Ge9*K9am?=G8VQKyqzp~!7?b3;qAcV_C)#GHi?ExSpb9E4BXFnk
z35Yq^6$CEGy_FV|bnE=yV^vD!aHPd9FcS6F;O~mHy3mTJP$gLzVh7;)5k-6kt*4GG
zNUXgtOaq!HT*Sbd!hnf4eWlVv<5*zm?vz|lEph1AeDl)#TZ~zR0RE1RJ1eM)s?rgZ
zh|KkuT0&xNd;a3Pec_Ui$DI(|QidVf%>&&Tp7rQeyYF9bVvyn5Rz;r~s3O7Q>ckmq
zrwl9y`{{tl-gJjjKtm+LX<(66@^tLPpW(L}-bzIec7m?#g=Q8r_yS_r{qw8+)xzff
z1ETAxVGn%hog%`|W8+7ox7utHmWkqh%{aqwI5}(-3QGV!N)@*=WpZ`};6MsPATD;3
zq(v5E5G
zM@-KGl0f9;C$g9C|tbun_jOr!7is-NaEI;_&xp8?~eIn_xj_@h`YD*H|a2T1Jd+(Qd92m
z2%4uvqOibz>L$^R1Za=)Tq&Xoz6g|b9$u(DsOnEVNMJg!M&0!4g7cc7l%{b(>=2aL
zSTl6|``Y;5@WV8wO=5?!KXqs`Vs_7Eyl;?2)m&`59eU;|sL>W$|
z=rnyV6;7F(+D@@=g-+pv&A)w`n735}&abK9f(fFi7oO&mf=I~eWdDWnMj1+>S49qE(lrQ)!zek&uan0
zS*M?T*|!$4-{r!q`7FQe_Oh;NT1(iibKm{>sq=p5>{oEqQcJScqLSL?el7JM&y70V
zw|R#Kf_n-K?%79QYs}VIB#@V2<{D>fJqdL4<^6ix$Pb*80+~!vsfS_z|vjU)2XR~j21Z6!e0-4JUQ;0Wj
ztQ>@=Z@S2QWua?XLR_uvI$*aqt$nuQG&KzdIA8Z=0`A%CIlZxR6T~qNOu1f}!{Ulq
zb|bc~16&r7bW&Vc-NR9Ulk%x0k4^1X5ibHK*_?Nvk)D#QY7~mWsPe8#6@d*j1@#~r
zCfvc&(p0n?)1*dn`s2hE3OM*MXU_zAtMLG!Ut}Cu79-yk
zR}}QFiV2b(R}Nel^`}pRea9E^ZSwg*L0irLUdJ1zGv0{%I~70lcQctuSr@i$8Ol_6v1O4iZ{yIvcdLE1DlQ?b$(M^QhoYTc$;v%>M!8jc!cSg3
zV?>2
z+q}^KD*(8dI=eXk|7iARY3W)5fdc^i!Ug~!_@@zL8%t9=7ky__CpS|k{eKz$530hI
z#^0td8^XVuzFhhweW;nu9}9HDN~@Ad$HU=0big%`UAON`+!dj@|mb#FXn^jyCLq
zVQA@!EJ&X2-K(dunG%Re3M5$9tmei82PzV})k7==%JDTmu2`*Pp?ov#W)&!e6vK&@
z{Y&oDpe++4mae4I=_~mzlM#~Wb#_78UgF})m4o^%1!hE?HO0E*W3e1T#YhFn+%Dd1C
z_M|V->Skbm$x8(ks>qb!i6TToOgv+JBsWnNXUAE-x|*y&F_z)`J3C^P`S2=VM?%Fh
zm;9_&HO;Tjj}|HZSiBYp!we_#=)6Oz9s*SCBG}`pb4fxlSFGVi@8uSw?q4ozGz|6fiu0m51zg)c4)ugM!ft;A!BWJ~7HD`q;ODH;FzkXDsOkHi;OO-sN=(W|1CT~w^iU9G>o
zO?LwA!1X>lISK9VB}!8lcJgucUBFJ)C#z7OkAzVgz++UfxrDf0m%{{h`#!XHGOVe_?WAl7E
z3!g0Un_AGEiICFkrwwZ_!-9tqZlrt^jAd1iPK{5CVGr@bDwmw5ABNVpw76!Cu!YrFm;ap~*0Q+O)LFCoDY|i0@AQbP5J3IynHS*)9aS
zLwI@?Gkz32Vgld~ZN*d_)JX1buyKKpM-0efjW`d@>VQk>*Vm9O#1Ku$Z)5cG@KDrz+51um+u*h7dWn9;e!;I27n@;
zA49wwC6YP9lb=d>oXawFeKy9!2wQBN`?7c*t2)$%pW(XFKJH2lkLSBWyvJ=H-G_&3
z%X(|7SJYaKf;p>KyPE@pdIhWl0PU=y9KxvmK*yP6^<%n@siW|)g`Gt7??E&cry_?9
z0pik8{*cN1eMX<;D8GYTE%ws5qUfxZ}5eRN4!Q7-x7;|KkzDfXf6eFJ~
zR>l2}&C!Rxv1#9CP7qk&szE*H-u*^6ymd`!Gi;A|2E<_INLl}?JuGfUTxXBAxjEM;
z+g?N0OvmgWmV(H+0Uy(v3WN=ZG#OYGy9PrR+f!HJn$9;F)Df@Hf`>jB9q0vdf)XWGxjzvN5q!l=;m0S%A
z72L!c-GB^#9^f$0gqx4JiY{yR-$wL5m{Rnjd8I&Kop4Q+MTM9h9}*X=r0JBFoT>ae
za+9^n*_m
zg=L|A(-}Xb25A_7Ldi$^lI(D0=2|h8%i&uA6Y^1=346JolEWKbe7hn1Y>p`DUUnDW
zCqH3U<2|32jj*qqYesMLjy*!%k90t}-8MvLA)eaCEZ6iRQW`wJC6Q=Ax>PDsn7sn
zWwhINv|P$AR>PYIY6u11v8L?`hTQv%dqT7E2MVflV>m(qWa{OV89e#p4fhBKdT1Ib
zGj&TcRWw$4So`tk*6AU>tLkU0?r4uuIlqzz
zn28kVWv$`Q(FIS0sWeBce6dXaO2>1(AuaN>iNQWorpU>Rf}CnJqx747YOgPbyh!^z(WK-iQQ;fMly9FJu
z?JOwJ3UC=Aitrn}TGQbSUeWS{SjQAq-LSwqwz#@ahf+kXoZq@U>dl5s(chM8@
zN$)qrqLfB(!45~;cPtVP40ZI%-18x~G85`rhAz$sA#uSysACw5+L4HE9XCOzTsY*{
zEuO4kN;w|>%$K=1s|i=}4j#$%bO(e<6LyLd7qcZMmg^0x5V$_7FIlf~Ojw)lbu_?F
z*{qfJ#v_0TzV|}GwLwtx^Jj=vM6Y>8n_}ffi4@N!R|)U>(KP6(*UD&i*VRZA+9U_e
z=9sAn;EmF`d@vg>b=4eU^VU^3t`CfFn&nN!jy;GLt>UeB=JU$!?N?ZcI^-l^zs8=((xGaFSvSW>bX<@oxd|4}$VS)?^zWvdjDvx-M(|
z-IFEntuV*=oBiRW=hPE#6B#NB(*m?BlVAe`fmU65#7@*0KC?9*q7j^?!7Aev&S?R4
z>Bo3J#UcU5n1(ilv3`=tvPA8-
z9F)wJ(PW=&D5uGLg%}dAW&IRzJRk}n&Y^gXgU%nZ69*8>{Gi@>I|y50C=wk(Mr4i5
zyw#N~icU?tl`7QB^L_*L(*t$YE-VF2V8(7wU&vQVuBG50ec2su;~A9GiIQWS0}P%S
zYC_c0oqSeTH2sz1S^I?sHoNy74_r*`X4SZL(n%(3
zbm*)Gw(&EwE~(vh#U?AUm&SGaakuS$>U_u%OUD8M%$Wq+6$t9hZL)5`6Iu9D$ZB0t
zW<%JlDGd7Li<0y$HJR!7+e)@7z*%ic%4%?@QQtH2pX#b$YXR3GmMS{^~Redkr*^=N&kLX
zxhkj*#c)chv;a_~5h%`T&G>O^!7P*jY$3&j;jmr*tyZ}SgR35k!V2sMZj{Hwh%g%c
z8N0N2v-SerH!6Spej`%&A7SgG5Swd(EPD|0_2Zn4I_QA8k7h%JDfbitl`6e6f<^Pn
zwLgx3#tBq*7usrJuRnt}v9m$;`>O$W*+>({P{H*|!J$U)Ky-2_8uqU4d518n}6+
zNNG*yU(;z9CHJsm#5Wy-yPHisB1(^Sub}c;OUoB#fmZcx8+f8PQ_e!C8U@I*ZG7T2
zGvg}X5emx^H>8j-CG|>gi%Q7kolG;%q2&sG>kvFSo9Dsg+28h=ii-p~
zU6eBQW1bEu+{~SS-X<`YYRk0GYB`OdK_OsOktN2fJ~E0E6g2hj)iC{z0E~0%%5*cY+oC)3N;yoACUC*h_d72po1T1a8B{6<62Q5v1HdA
z1n4`qDP4#&&45l}c^w|pLMUQ6Q>raE;<_$7^VZd*Funo)nLiSU!?#m^Swr+M;$i(W
zfBwxH|3eo40px7sWUc>lMd-~pst`GPsN^UeEr=L9#vc}o8!_GEA=B?ZSG9q&Uv&~(
zYxzRxJ#~1#1chFdGAZ!!MBVq~H0toTYb?|yxt8U=RMu2jrav50xNQSaVBV+tcT1M3>tD
zS*P1980-G&pG7Qr3fsT<@&x^V#}{W)m;VM|w$*l{vIG$PUg|LZL3`lnX_)HaV103(`0o=8L*b=e~;G03n`kS_<5zT2Sk(2&IgLl1q
z`wPtDyylb2xT?F}Mf2JSszh)B8xMnqtQD~|QtPHCqX97%mye8*YH`;c)k+Uc)iIkh
z`Mq1LH}ow|OcGJu#1MYGLTJLwWIO5R8%1usuW$6CFM7PTU$wiCn$f}C{z6&;8q%1+
zVW}OBkaZurM(hr4r=9ele94ueBb-MBv%1*X_(3$Y)r%>
z<^UbK1Xrow+|VW_r$6&RoUft8x%(P3Ex+!tX
zw&JvoEXyv)M}Epk>0zh>oJLKkhH`)_TA1Rflt`$G=|G8tmE-`;)Fh@5?@)*p%@d_x
zyxXO4G|rGw>k!Y*&x2XxrMk^*v{25WLDNPQt>|EVFPTA@3jvV)J_-7Jv#DRIB9SkA>)-ST_}enWY~}+L{^ioHkaq~
zys)0vy4Q@!>p@e2tazlJd;h{CF9eA>hdx;Xuwt%1tUsLPx~L;)Q(2$J%-rlL6`qx+
zhl*>3j1#==pSu_2hfKs(4O6eckyn(;4h!Z>OJ{c7Ep6ad2r*R92()J;ESL~CPUhsG
zCB52nuc5RH^*<|mz)dU_641DP8>IUg8|51GslaC|7CXK1Puzv6A}^Z*W5Y&YJ=}M(
zC_4e#k8@oO^5WJg|KM3GTfB1kfZj)T!InoC%VJU7`njX+2*iICGWa@#7A4KOwJTt4
z-U}jS>08Q-R!LV}xyJn=sKG@p$;uXd&;5!_>*w4^TE?OKls@n~^?7`HyZ15X9fkWS
z^XVB~T@CkWp-a?@Bioka)7_gNRR5C4LYcF8|EY3^$lHhgXB-t_Q^-91jUyEh006vy
z#*vYyi>bb|i<70D`F{h3Au5xxSqv!K52$V9V3m+ob2OC_>&>#%uJdYErCExmoDPs8
zBc;!z$b1#M^Z1vnAVLeH1ZdAsoTGfBahx=IbOfxL6))5tYrplopHm3hN`VOPL$RWS
z1T>dNkj8SDD+y72`5|VW#L`60EY6fQ4dk#qyxbH&&uTX|4t?E4y67Wlt%#7ZyzMoKrvq
zsz{#h$x6Sb(1bOZwm8(Ie(Y$UGibREV4m#7<6X=ZOw113iCvGdR{!w2V#Mrm+C(h?
z{+Ix?92N&ajh~~ANj-_rp04hA&yfP~IW&))b=`<-+s2K{ddo%|9iEo~Q;-Y3TkJb`l?
zY`K_Ww=XRnG9N!w1eWVFfUJMFuQ9yc!9jj)#5a-kn^$s+HiIKcDr3DQ)uUmY5}LcO
zF0Tn%UW;0Lc38Rl)VJ(2O9eL6&NE2zJ5bUN>Cibem5m>(LMUV-YAeA0u}NgMKZ%m(
z4S#DbdKq8e*IbPLdUJL!Y4ihex8*JI^b^HG^DWF^8W2ixTnU=-vdy)d^!J4257De-lY1YTAD_U=+XAI)bIJv2-#P!8RaJ78=*47Qq#dh$ce|
zsKC-#8!nWJ1Qn#858R6biPT1ybiJv`ea>C?TDU=mbZi4Srb5&=P_-ixGV;2(Uoe4r
zn#XJ#-NbpIy~Mxlf`bS|-nPyH6*ARpnFJ^iMJr;6Q;>Ab1xP0_$!ulTCag3$^J7Hj
zx@oA+MJ!pstwXfjXVIr6=j?1-sKUt0cRB0r=nYNVmgHyMT3MB^cWjbqyUbl;{@|_E
zGNXWk^g4IHwRD7!XG?Y%eO_Zt@uq_^R|Ak39%?NG7GWO2I4bW*DUEO@$8st9*Xqb<
z>B`%i^J3t&;`$eSTZCKOt!6oU`Mr-)>SojNj|?^9ly9rhEy6>!A?0U-_Os_l=#h_-
z^XJ?62M%$W;e?2*B_ow5rEX-Jv3Re}ZF+Ovy%*9yRs0CWaYReL^3va(0_-JeOgFF^
zBsG~i`uaL*r7Ag`1gq_@p6)lVaPyE#cA11LxWUr-#tu|_#xK1nHdTEs)xsJ?QYjT$
z(q*(9u}sFd>)8f$tz?fTpHCt=uK>r;jPZ~Wk!`Bwmr2CmjIou=WBc+&ECMQBj-`H9|rOe)dy?n-z>gK(FU%4cu)drm%_e=U;PKwru(nb_i@S
zYO2Qe5#Lr`BZgDIALVR}x;Xb*|8E%k~I%k(}BhrC_#S{)w+V|w(
zwe#O^4lf_yJNbDECOm)L{Q<@0>n7Tr2g)+4viEaqDONY{-Sz3@`nX+i{ta0DC#H1<
zSqxTMxnDZ2*bLJaR(g9WH-_Wkws+2U6zLpUxuAWqMJ?c6tr?YE3`&Ycgv0D6qM3wY!^Kfp%Y;684fS}{0PNE=4dCfZ
z08|gC2#sOHN*eICE(r?+mw`>+ObveCknXpWEf=2}(NI?WjrZ_xeG5
z*^)cdXH#}5NpW*usD$MjhZle87p2z${=d-a|Ma^4bte(k{}uFKHx&Oh(Edqv`udi3
zmM;4G|DhlvCQ$HXks?s}+r}bL_$#^@nwvuX69Gf)NV&)`000KO0RZs-`Hz2Z%gha3
zOx+DV{~Hv%;yriT6kGUtLNPuQ7YGxQP-w(mE~`^Zpw!mWrJj&$r9p=x2}7|#$amB*
z+p)+q*081u%0!?98ra->4P^^p$s)pBgBb$s=Dlgu}*iLi*;(A<@GaMR1kmDO2r&L4IYv?r8i*kw>%J
zo3R`!D0MlY)zJ^^KKf-B&VP#hX1PzciesGc9?=#(02yP<5IDqhDo^5r;e5uC?Nb;5
zvcc~XD;Jig%f!X~+Qr=2jd7jl0vbIOy?CcomNGZ&bn+%tcZ4;?qspP*2##6b7Vn?H
zU102fF9}3POfP}+&Z-OrlS9Ii`m9si%CD8=(BQ%xFbxaDpMX7OndWilk)psJxrO@O
zpd(lW2JC}Iu9#**e?@mUXs~12{v%;2#SImdf!_Z6g*(ROP5e3
z8yeNMn?P>MKq$wdG=bBd7`=QELg%sX$(Q9hTuk4h`1f13SD)EI7~|UQ@vXzp)m@3R
zpPS>jWdV#bmvetjb_BsedlIBB862BgdLV+*H3LNZXs^_^MP(+c4;dsOkqdTCUF$AP
z5d}!qp7@gV+e0!LyWyiGPk=&S`Q&E0V!q%$Wx?H$t`d1ZJ3;~ASu$9O_#F_@_C5+X
zHa2$A{&3em=wjv{L~nBj3>zId@w5wM@59f$P5bPJsQvwU`tAFzq
zf13h$^G@!ZySim8&BgQ{4RPRH{*piv+Hn}=^XV)-vI&MK${a~SCc|s|z
zWY~+iPIW(dB*9QNH|}~iL^GZ#6xu!7tU&nC6EMQmIkc8Zv!7-M;xUDEf&g)6V(U|G
zx(_jsxJ>qT)wVk}$?W9iCJ8r7rw2^?uWThK@8NS;_`l3mxhqP
zwn5`0pyS+6+h_ZHGfIp3gkmZ?P|*hVE$r)f;f0LuS(WqyDr4*Te*TZPz4A1{li6_|
zW2NcSzD(oJYIU;*?X6t|nnYPLeWGUb#&mV(I2F`EqGhK>H3YrU*eHt}Lps_uIaCdX
zqLc9W5NxSQUW$Lf1X^_?jp>9H0IR@Lg1-<9w0+q3VQ()h2fwu
zlkA!Cm^2;aGa`pJ6rnO+(bND4P~qv2K=Kq7$?2?6lWYazg
zSl$bYEGpk22T}G?RrrnP$M;04;MVDHrE-MWSPa-cv<-`gY$$Y|C_#`>TIBQmI{W2?
zsy+J)m=Bn6QZ)s;)v0o2Wl{P8LZb=65J4;(-7!m4_SeU`pbfa~4;M4X)imys-ZB-a
z#nxHfPWbb?#1%V-9O=NYqJ7G{!kM+!exejH-g;99Qgl3P*$iBrS&kHBd|@M6XYeMi
zG(`vloa!$^bpn-74-uG628+9~#`GK3$QI*A(#dKKhpZCX7;{SaGC~lc;e0}U`GrHH
zkqo5Ph_!vcK~P-{9*;?5jYpR9Tt!X#)9GUS(Y!~fpkPT^5XKT&NAMH_Rlp5}i+LB?
zz&GEabNKwv;9U$bup)f`cIw6Fq@DXc(*hkLeWdI2p^)fS1?CK(c+ln^l??&!(gF+H
z!bz`QiOx#W~lC#sTWDq?N{-VB#
zdX$+ex1-oPFc}MN7WLce)rf}YFRQi3s;jsOQy*7-w8nTd7_Fi%G7_mSYd@HzBFd3F
zg+Hjw5Nw?bhXVg|##@VY?LEAtppaktsEG9tNayP=WLhUkW|JEltsad&Ze1xgfd*&W
z_#Rb|*%=MR{Adn~8404;i3Yk#!`}gwSVaK-8)}cA;rEP06ix|ptHZe0q~riF;WAW>}*mp`wzje=<}iB1U8fn4TTSm4yc49q0nqC
zT9Uhb};ab~c^cn74@f_fnp?n|TsFhA!Aj9FB^*WB+y?e3i1t1sAm{K>+I4
zbrT-y7igk|I@GuKO#E;6w#S_kkwW>4lJ0mRpu`w?ZsYmkMG5k}<4uS(L1;Ck~
z&jW*z+2S*?I87|$oj_bqc#N9lt*xcYv361?9QdRx*3E4bC#@}LU*lwb8WD-HKkE^R
z1_#ojQKC^sP)7J!s>~d6NrGLzM6Tp}1v{#$GBp{}2<1?iQzLtO-_NLfnfE;vXk3OU>q@i4qa@`;927WCLY@CMdf7M4vuBo)*yj~mexVmenYR4TH7LsrX1hu7GL?$gN>`p0O
zhafWE>=zLD%Q+FQ8s>}3uOpoieF`};qtv=G?GcbvucJySpou&>Ux;4pRd=(fFtAd8
zc?8ap=##(joz^*t3OeE?S`NIR9S1f$blE^^s)=`8F55;ZN{=1{$~HBwYj+@#V7~-+
ztPUo7wIrPsUf~xnRo#DHWHgj|-Nu^Y!cjjLc_wK6Ka9O&bZFnU?V0R2*|BZgwr$(C
zZQHhO+qP}&7&}(wRJD8G`=7d1?|hiwS6g$nHRl+;_uqKeX-=9;_Lp>@8t4=?E?Qu#
z(%06sGOz`|gr%2#lt%OTy`sN1aeKsmYM**si-4G-4vFG5*-DnOVW`!FNAtax&~e<^
zzlA~EXZ9t@k-4wY6=34uy8pY<9$6bi6(YCM5I1@v4erwyy(4lcA&!wDvay{xal==j
zTv)5|#CNF9rwV*ClqY3ehg>Se(0+6QwOt;XZObZ>HR!oIBk>*
zsNPdqXOJszUXe{`oRb3X$OznYb(U1NBJ-G7Tf`}#Gud{YPuP(+q#4Tnt#>dK#p+?6c!_XWO8M1G?884_iRh;YN(YqN1gt8nepBur_&LE8Wpbo0?AY?>oAF3q(PY>7RLRD(0AecgNIcPlc&9EZf>
zIP+qbDW_|6s5Ci&&hHl%$AD*^V(S%$nxgUL$}c*w1|Z7*w_gfPl9_R-<-3Tm{oI3c7S%3H`kus
zZ@mbNnnt_SX|vkxm$#v15*wEBeO?*NhE9*7AtiCw;5_Xf#8gJU_^9)!S%mPMNXRA0
zL?P+Rq!*gM$ZI}eX9CR{+A5qUOIgO4qu97y-%BC*tjaUDevD*A1?#}AYbO#GOq=s)`PACPSiCu2@s4>|u@19-
z1Y3EUX(Yfu{$%mhUAlvNK0VS`RZ7COgfW>`KQ~T`@wVtcV=nnt6$N@*#z9!m_a17RaGi#_R7POTygRAt}8!XLePwJb(sCK?O2aY#@7^
z*wn-FdBoxHy^900YtaPnXt#Fm8b_7rY2DFdY&B=7%%mXOMA}s2rG9<7uQP}-P0RxM
zir8E4Ov6suU&cWvgDZB~yU?{Cu(4!&eeO+@Ge*_B!aq@vAjr+phH(`qIAI$S##&D~
zup_ZW-BWcj0yrpvaUM8anJ`_sKJ3RvJ-xg)ZSi`#@jaGZ_XhUn5d^;mqrRJY#78*r
z5e==U)VnO90i__pTQDw0JY85L#BZw(*p%MW2ku$$-F*i@%VP7<0kZAgo!MJIr5X^q
z1m1W?(Q`+|X`8o5l1ZZ*E?V|ja7vuEjzy;@b6Ka?05~x<=7&c(2YRKQs*b~9wH;*o
z<|e`E`Fq?YxfhIbeFR&EtBERCak&4gs{+Bj2Md)*Ke_a5AJ{u7Gv!eIB0lsxEIVJB
zsm>rDoM7K4D1w`5843W>7m>NjCbBt^n$Q`p}8LBp=qCS_C_9tzXfa
zKdLNFC5-biMTe_@8v%(S(s)UnB6W^U@pA>9ghdCD!T65cp95K#Ame84(QC}C^{EcB
zJjq*b!kcnkhm7b>63M5)->Et@FM-+CyoG2T)eDEO*{d%B+em}bTin6FT!Kmgo#W#<
zV=N8|>TE;dbB6n54P@vt0Ox_5hWM~z#PQ=)ICYJa^H%m7-P39_I$K&6#qks$xb!)g
z7=On&M*FgK1;*&Y?tQwA%o?br8}wb6@F@>IbG;PMvlAT?deg6aF)Vvs^KI1HT6@6K
z`V_&WF}f9rFjX!Sx=Wdsy0IFC>>f+sJI^YC-ep#ttOkb>+LdemX`?}%otR+mMxTgA
zAfNQ$m(?Brg~}_=F8Ref)T&OpVR0FfoeObRbv)I&!4~aqlkpE@M+9s0G0df0QE+lE
zXe*EyeG89NWzHzz4|2j37%|~vGicn{InM;V)5pZx$O-D$@^9CXnw0_<5Mfqa2)e7{
z$nrqIP7EP*C_Go#qz5Z5?T5}*ChZSNhf3-UH%xn}VL*zpbyp%)3ypWpCTb`&vprbL
z%c_FX05Kc&iPQZgLAhHl{zw@PRosKUg`f()?KITXwXjiK;e)
zrdyEu+GxWU#1TZE8xOzJK1=kA+qt#`Owo3CpwSN#Heh-^rtc(Ne|&`Jd>+nn<%o%z
z6xwN0r;e-yC$BNT7c)_w9Fb^1&DkaaipQBdw|-mK=#;`ejJm-iafK*xxE9juxDcio
zmA4#L;^8;~7A@Hg2+ZCu6q_(q$t^CgD(3c*@t^
z_!To40HHfc*-Y2lZ#%7m3#!N&Qq~I`_u*@v8Eq2M!%e4gDKo`@bZuIr6k}~qd*L!o
zJ7{U(hX?(@3H1Kjh^j3!&cV*%ymN)yi|v9l=H-Q0RHvIpg3|dZko2o~+x(K@fUreM
zU7E-x{UFM=8yONC_*Y?5
znxGkg3mfcRiuNbUM?u<+%#Qot5;845+gp2ij;<+~sY@?@B*XfESqT~h)*#3&=T$ou
zcJRWyW-ot}dEXl`ZAL)X9k6=8`)W{P)76$V-frRqRE^5B4l>^dMMeTBAN
zGis0{G7k{S*9^3aKTEd7$q=3eS_)w5nHN4S^
zikUkcGT6B|8Qox8ObTGf*HQdXAExEu=E7h_R_zxD(i>L@yRy*;6xB&7KP1PN68>E5
z-`%BaA8qEC)>J#+%bxCCIytrbdX#Tr-DbQ^-ZYyCZx6Y-S%WFAmh(0F^L%`%LIdj7
z1&Yu=y}|f!hPIWupASp!YKIHjH23?1=!syzuliKKxP{$r1)qd#RAuZ)BTo0rRTH1K
zno{Lp#8w#@Gu=<*Pha1sU^wVWBa?h|JCz4A58)GJea+Z6ly!>H(+syb5zj@max>D2
zdVR;uiFCiF?Wz{&BO4jCxy3dLRM1cilYTnClsh#u&UJKRHWULgNz&--I}{xx0e}54
zMW8PDyJa8-GVkz}L%G%ob9VNjAYF)(>fH=mUZEax4%xu~#OE6=Fm7f`Cs*HDi#8HE`Cj)$Mt0e4
zs(b%n<*Bc918zw;1Cs3lPxL-{%k(HYYISW8TjhIH)4Tot>TZe(^vwK$JxR;p$JAXE
z-%Ah==XWA0B|LLJ1tWu7znSyV$23PAk&1Q#S#L{2W$_ms+>hi#o93@6hA?vM=0z<8
zt+Ru{Z6DM6`&18*oHyj!8t
z*7^K9(oL?unasJ53^ZtoR|&Pm%58Ix*(KV#`C7bW7Q!<~`gZ+ie^IqpV6RZ^AX{&3
z$S-Eg4pAXcPP~EO3bYUp!07chh~36R&UKUNPo_0y=k=GT_59cuiwcX
z!_l1;kKrXY(o@zS=5K(F9C-teLs~llkZ^q*`dALhcDT_26pY|K@eI?AdjaD2OFsb;
z6lO6*hZl1Obu}d9fO-t`M5Og>V`>ZzD&vREdj#!0=%Mfh?fsQ9-;#~diY0!Il0W;w
zl*~lA_8w8iBEVlEOsQmev%)ov+!%;-BTooX`!~%Yo2~|L(m7+@+=2ai+7Vc{W6&G8
z?Y43_7Y5}rwemBpvDP1XAe}RWgz%2ODAfTyVvBbrn^f_dzL9yl4V4(Ne1mi51z{(#
zP@*SVYk`NnFiUp~Em;$DRW@_XTgxKGkJ3l#1y)W1@5r^a$aRLGGPW{P8iRnn>e*$A
z$4q=9-bbXyp>+&^-n>gskE(zRxTpzk
zYy97xhSL2$MxmIsB*i{B!h*OQyRFUCnn7a<=k4bN>>feE=Lz@l1iJDJAoi)d^HnPy9ih_bKe
zQEGvXbR1!I$iubIWd@YUsv_KEKj_!$eDPS~3F8=sm~erI82bf?^6$7%Y|Pw=wwf=V
zYP8b257Y{HnCxk!37R5Z+mw9X5woRwBiS(^VTwegGGY!=OqNPXIK52=LBb;722l{<
z0E#5=Xo3;}0zyNTy>O}0^p>Ojn89igCvbLIydI(=mY!Ov+7e2(^USHb5QF~c3Mc)-
z9VW?gAyfkO0EsJiywN6}WR^*U5CM%)JbY@BMh0vxSo|p*V2kL#Ur**F$_w&~-B&0(&DTh~{95AP1L71o~r!0&>=aT@IIk
z+dgEPkKENK8F}|7xlDxLhG5oO8!wS$jZd}9vPpXlbY?Dlda91+ne6Q?pm{*;G12l@
z>Kb5HM|W0GJy0OqS?~^&f8u7sDwY-Bofy$!usCTbYefgLL<6ZFdS{Njv*<`mX&zl2
z1iI7%?Z}9Edm*XV0Uo_pt{vMoA7%G3fu$Z;^lJ%D2=)r{Zfv^t?&$ooF73HrR9hO*
zNEy~{eV3TrWbkt9mMBJ3_@vI@KH4$CoUS^H*wnh|IB6ZKn=C?)B0fchNAt_9IJCYq
zx49y5*(P$`zp?VGeQL%u3H?4->so%FK@HY<=(@Sxwd&&J?&;d_r0sa3Z7}oL#JrPi
zbO0xpkgaCG4lgCb6;5-XrcacB%cD=SgG=it0`D3$$ivomN+f>GRNn>O5=5f4uT={s
zjXEDD3<8rBCy>dFYv)};C5W|9E)00vJvjZ`59Kvnf>mT4Wp2TD+ki=XVQ5y}56JNthd2L3zW^PgmAHU2w0+o5YDh4|4>
z!^aBbh}9~+$|Q1WjlJQQN4%j=5RM_t{221}j9reJm~gbR@bx`>~Gmm7<9*x
zn4X75h(fF-mNYzd3OJ0Ps1!%%46k0yHAOVsE<}opZk%rvzQ-f2h+;dj|Ik&e*j2J#
z5jQS6?&yb!b}zXt=LC7MhHLmGv+utcVZ-IS~-ysx%dj2M;2kpc*S(
zJ_M+ZxCihaOhnQnix^HENj7UV3y_ajih@*?Va4^J+w1PHE9H~3_iFUIB}}fan3ZG7
zu%}Aj6_P|kqi{$;%oA&PE;(|jQ@Z#HE+?~#`5Tv9A5N$ocoo1yJSm7SS*Lu_GoiMQ
zA*|IH1s*9xTk}x?
zKoc;}Ga>pm==jCiXJJ2G2XjRsT9zgBwQgSr41zdoXgl9zH
zX($JkP;j&a9|%7
zak|`fZ%|@|ql5c}Ux)L0a*qmrPe-8Y+u5|j+2Ejf!Wu$`>A&i2Yr+Y(8zaI@v%p%9
z`_+8?qw;I#hClr?Ov;bu`%$7ao>=OS6)p3BZ9?`K5&*AbfETZA7}Wp-CM*R6`zFn1
zK#5`HA=ax*{^t|H
zNu0aw&6n7(P+m~n4xBnn_%&C6Lf6r~6x#JLiwWn)dd)0#YaVTTfn!VUyV
zh$?C=6?4NzCPylgiUyKdV>^Xq@`DkI%?6{A5d)(BWp|-IaFcea}Fbd@el0VsGmYN;K^gjJnkOqpLy$9IXc-F
zZ(C-vHuuxC6p~B=2;60eG#b-_E?tv#fwe?yl*c|LBm4Hd8QHd4>*lplOvH7&BL5a=
zq6!ckQMZ&A#kbyJM~0;)bpauux#dnMDClSt6HeU(>#qzPnTO1b|oS3Rk$_#0BzL0@61
z)}(B8Q^bb7G#Ka*(9P;MavREV#UBr9N}+r#2jI+fiBMk+X^yBU7~ZRynQ!jq>2}q)
z#I91Ip5u5XwIp#QLtbQ7++B(v823sm^Ydn3+rC*#ol72klrZAX1YmxX`lW{~Pv`!+
zFRpB~w27g2F4u*ypy0z=U32v4_@Qf@8KBtKVYo)iYv6~W@YdU^O8M1{7y~cTzWKG-
z0G8=->3uO35zQpX)~|+qPoJBYc8nWXqv#Bc-@B9?v(Z)gNT^9DX2zbsIMlkXaoy6E
z(-HPwAriw~fRjCh7pVa!9eE%(6T|Z2szqLGAK&D}nb7UH)LLmN|AB&X&fN_2BGEdI
zZLa(Snr!`Alm8|CZ54aj+H$!X*ocx!`E)hZbJ70!Q}9FW47CZfo9W>317|IT)6lq)
z%w*TvPK?7wjaP~wJM#$t#`q*PkTb-kD)JkAP;qZNUirC&>J{{C%A%MYY}ft4mj-@Z
z{{DEGTA3bQo#L=5`p`{E)=AiHh4S|NqTO+h)<{Tz{te9zG4j@T_4)97>paRD^qnpC
z<>KuIea^d$Op#Z5M9(;C+16S+ig3UnsHNrx12^GM00yACJy0EijfC&CJ8>|bxvi7rU?WM3%N
z5=6v0-Mbr)Oe)hs1U0F48VM(P$}M_B_=iXyDIODOi(f`r_=bP+>jDQ`>#-KxebSOu
z^(14a&E7Fiu0|{3s10j9FUwpsM&fGI;wkn|+2!%jQu7e$A
zO89XOvUD0?y~G~I3bR_ZP&lJp!Etp2Q-_<+)TJMn-IEVivd6;k+U}C&w=;paPAB5*
zha6=D#PBFb)8t-ws>ZKF4MuAvP!wo&h&)6AF_QLXr-+Tt+C(2LETYwDp3q-^z?=Hj
ztoTv5J7`*9Ze`_)IPJ$xR(gx83+*XMGyj&<1Q=x2$%Q~tRPd4FxRiVpMwe-AO=r^J
zOo9qi$j9&0Rs}LlTNvg!D9R5AiIH1n5pjlN0N*i!1}6R%@PB{3LB_Sh&-OWUZL2=?
z%Y3$YjdCLk{#-~Z5LovgD$ib}d48Mm@{DLnF>KP||=KhWs^MW(dQ~QNl!i!corMvzyKe_2e@Bu3@EcvZ#$5jGvNQ{bX
zC?9A2mW3$7oKa
zg+3BqXdr3Vxp+eS_OGusY;P!2P%`*_k54jEL@IAHCs`HKi+{$szrEFoC^%-5xq$|m
z((@M!*crfr)`gMVd++`0(R-Lx4W6B6DYXu1uy*MI14c!@o=2y3LRAVELbg#9#E?Wf
zpGAciMF@&Cx(={^5jDgfSa*B&2$K`VKSF7w_~+W|YI>YxSjtkezV?n9#+GfKv#YX;
z=eEODco`rWH8t%3+72}RmkhbSk<0<4TRZ1BULR@_0!N&%dpRF)Z`Fz=Sw#iP7S2I#
zl)bUQ+JQ8NVAvwl5%)Cyq2cT;!hltBBS^cX8$GDld$*1)prjK}RAmyqtIU1SBWVLQ
zq}B|ucCw&rA&&+^1OXCp*ufE~#0|Qw5zZl%V6B75PTUa++513wdva9{S$98r5_KdE
zquzIs4d`7%RnlMgdCB=D9+*~L*BTPN3X7X-@1ar!%5B-9CVFO?SSo}WsOv>qrcj>M
zl7k6=#a69mP;KRffimb}%|tSJcBtiTjYV`tqeV}@YB)jkveD9W#4T;42{P(Z&K!*J
zEMb^+*|W1fe8JMDK^3OI@(*#?ci~76qr`U4;IKW%T>e_ci9OzMAC!hnxr_pjcV4)Y
z8-1&-nCPE6GX-m&!sI*RUlgUIIOC)PIyY!zGg0*$c&S1|h;gg04iHP7g>25Xt^|HU
zh{7AdhxWQo8Ya{opPw%5qCN%7rU@S@1rSbZ)+&W2!HJ3@zmR9=H~50B>l%<
z1TN06HHs13Q3;l|&=6Dy-hqf}h%0t~?X@?8g;`?D+2!{EF$IEVn|yxHmI0EHO>71(
z*~=wxE@45zVC?O=-KsZUFs?u2d?S5^pX$ytPub@H8dQ=Ae!vp8|B5e0vq{=AdV*nf
z>s%Gl+O3=>;V7f6q2#m70e&WhB#OGv#F7gYC9J<#T?k6Ae}R#Vy6gCQf*2-MV2oypQDgCO@u3F&&VJU}d)A}l(G_pbV{MT_D*
zCzQrZ`3Ma7x&0H$!Q$}xF9XG}@bkjLi+L3??ey8$Tz^`PWcjM9CLMOeU02XC_x@g!
zjRoSsOwRff>v^kli}2G$l(jH)>1=pR=@AwK5yjAm=uIX@3{JvR*_kVQ6FO(%OZ{?T
z*HA5$(U*auyPRzV(fSK7op>&tJmek_A`88IF6u!S7;?ey12!=CD0Xv}LadpwLl(~J
zM{)DdEyr@w*c;iGUS@Brlk0{}TCygGCJ!gRMtJY&L?QpRTmp+vZafUyBUPbcN^|SC2DM=ME5k+FDtp~^Tv0kc8sfR
zQATg>JH5%FHmCKK3i;J@q(hewhL%Zn8cC2U5TceYiYkg!K5s_?6sb_!?3f;Aa|{ae
zV*M>XwBH~HR%F!tjfjcGm75W$#J>*S0u9nU@YJT6|C-7DUJtOzusmzr&qM15v8h|V
ztzeoN+mVUX^=9hPg=>KwJF;&O?w1rlzm~l)-Z$0%asJ#L4Ce-`%Mq@4?0S0Xgfc?y
zIvM0EU$4RjR_bV8BS)X&WoHLqF;sN8_b?qLtgomgOU;F-2P&?`Y0|k`wqzUl{xEpl
zR4T1)0>1ZxNvE%F-i-h^Z#>Z|pL%^!F`U}onMa4`T7iQt+#+^p<=HK}aJKiOcE@eU
zR}Z*JTYkMpa!nL~{EpbgzbUdY2=o#?TP3wn6dwVg1vNX1q^2_PD*avN
z?P$(D%NVMh*xWxi&wqG9*6;MJdiWValVe2HtG#hhXJxg~;I8Ju_0XoWxk+(}RyYc%v*8rC%IK&)d$CZMJYgzHhLB70O`kC8Q7IG&fIWoh|
zom8WMPU=>q%WkaLWxy`5kmz^zuO{Dg(Itnzf8(mN1h;Cp|4hc3h5w)P+T~YMXR=#fSjDq;`F1F>K#I5wLLg#k3*+F&`&+_q`vkzC5F6YK
zupxGyuNhY_8Wy8~RPe(p`DlPBcxp=;Bnyg+TO_(dNML=3n6c!eDtW(bVK^fb({A<-
z9j!#-p8sCXRnq+j~F#{dc$*XU*Hyq`J#EC-6?IZxS#0(kvz
z(+%ZuIl@3%4b_Nfgj*cq=y+%VmXPd&$Sr1BlObO5>=60e`rG?C210Xlz52xyk`Rh%
zM<@q+4{Jv}hqyhe2C1snyw#y(X_RvWbH0U^gs!#R;eoS9Y_w`(qN%}rAQFSo8$_v+
zp3PvUaBcNxMukB|W5;8)^rl%OIDGhhr|@cehIP+e=p)FAHC5vn<9rprs{>}j34P!u
znE)xNSwVCOc_-U(Vmcys0oKSk>wMo1khyYB^J~e;5EK)RDknp*asMhpYSBXR2+LX%
zU{RWpM;KUMkGdn{=lR
zVJ7X;(yRgt>i*0kr0e5b3}i<@)WvE4x12yP)V?XIfC0-5`!vep$(x5V%_V0INOI(S
z^_*K!6RJL@L_Ti1&7S_1jj%p9Bd^AUp<1UIJ#EQuS!C?#w>Ib?YV|hjYAm#l
z%RCiTrm-dmr&YL&5a7W2NcB~BRm*5YX3mI>w`=e4{xKuLE@~b{nV7a)=q#Gf^Jaej
z;i_GxOI31_{=0X-lu|17g*>4%5ruWV>@U!zrvy+InYv>*ZlML77n&H0>E{{yS2*#)
zh+MSsdwCUmq~9bB=q>Ep>;75=NhavvWk?cjUxIlGU%(hE&kfYKQLF;UH3PAr&qH>y
zTUYt%6Wb|=c6^o{oGF7(XFa7Y<6c2XE%;hI0Wu$rXIQcC?Ih#SWKZDDW=O`deeZz(EYI@y;My*KGS+85eO~H+sq&i
zWERP>bf8H->}1|(`O#kjLFOw!p%^|ohPF%7piE^$qfkwXB}_N0<&kxG`hL*P@`Deq
zTyDk+!_v(G_W-CdB(PNSU)p;Y6VasCXB=&xSq!PSlVU!HmD@Y;We6(gZw-mk1Bn%G
zTaFhXnGLq=@S7mNf1bz=K^05Z-6>H4JjCV5$_?%jB`5bG6hr%qDnGc4-r;;y@8b%d
z#Y>JB443qGNsHxfBA%Ql5A^hV)j}_e66Q87Cpb%`GiE)nosE;ne#xYZP=Kn4gt({p
z8{tI+;9jDxQzu3!=jBHwgzKxuyIbYw19>dRXut~S=PYY9E1cA=lvm!=(^7mCdJ~LU
z=GDk^;^!WJD;VEe#Uj1b(7evPPCTao`<5~@<&Lkhp#$QoNI9`gSZe~z4MsO;E@}I6
zkrVUn+}$Kn6i3&qIk3`i0NRuf(GmftyV1B6Nv9e3n1>Ytxfn!pnP;?Du2`g6ei8~(w>+(zHZ
z+~YsPZ7!?s{uKN_1QT8SE!0d9#LR&_IJNPWP=406j#Bv`t6a^}hGdC*f#=m5r0=AQ
zX7Ns%YFq?y*N*pV&h2p*?NxONJq!f`35OXXcV5DIt}M^<
zjH>yuIGXjmUArI~x{SEKp&OEQwFqzja!&45Ay1F#rP=8vxNIgJvXN`Z(j)v{H56As
zMmE*Rv&nfQh3Pk?eo*ja+nxdC0E`*$1c8{B8c0_-YC`oO(BUt@#2kStW*-qLFZDB>
z?r@)Svb5ko=`JpQ^&AQVWM>Q;S-D4wKK_#g^>LL#dIqVM@dWIiUEw~L&OB}`V{HSoUgqX@3W8BItilZUjix2ORKAx
z0}X#FH-zVEP+$b2*B+tJ*aggPG89HH^k53rflYQbf(?G?5dc+t#qKY!#sC7Ohru}_
zT2DTz%xOcs6I=ywX#H#*)CG}gcA3=x*U$Wjy&@nGBSvncUW@q1F)udi0)KOUh`$-=
z%})}3k_IA%5fRhw0J2!pe!CS{LtWA%65
z%3TYISLQfzfI!I+>lmLvI^SO+iHkj|KG7p6qCBKTSE7(0Wyr8Gu?04!=%S~x0K797
zZDO^SEJI$jPV@vUnfNYfFsS;E`W@1cEBI~p?)}}Fkuv~$vhH%3&1djlR=Zm7to?U*
zE>z(*(ukTZZLjphfh*v(5pI-gbvn;crd{}Jr5;&!-wX0u+o)6jwd8?Q(=Cd={@^#r
znt8KSPbQBS0`I>Jo&||2fA&^4F03;`t&(jE_O^!HbRjN64se
zEdv!=O?oUtOpwvP|4pBBji%gS{GpKX{|qxl|L4Ek(b&qw(cIL=*oel?*!ur(l4*C!
zI!vWl!oMIZyNkpwbj9q{#F6XI(=Rc@ety=CEiiB?P=Wk0|UBQCkuc|~CtugkXzL9_^R
zdKPia^4yjJ`~n@O3Cd{lfR{dKs4O7b*Bf#b#cf_a83eoQ(m_`Bj&rKA2s{mnia}Wk
z8}SzK@(??1ecx2d8i~_UWei<+q16?^f2nnr%!(!F21mnMefGc-5JPT5aBo!seF_*y
zK4-ylO5T9zvFF@4d9Ko)WL9&!4{S@eN$?vrT^kf8`N}FI22H%K-2a<@Lzyz|x+-%^
z@AvUC5i0__wt^I8?uPi+g3LO8MTzL9duq6c0W5w1w7TS=c|Hak`HjH=Im}Obec4|^
zN*zIkXQD>o`CdlZSdHPfG^6yq