diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ea779f7d409cf..51ce19d29cd29 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,8 +24,7 @@ import scala.annotation.varargs import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.ml.util.Identifiable /** * :: AlphaComponent :: @@ -381,18 +380,6 @@ trait Params extends Identifiable with Serializable { this } - /** - * Check whether the given schema contains an input column. - * @param colName Input column name - * @param dataType Input column DataType - */ - protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { - val actualDataType = schema(colName).dataType - SchemaUtils.checkColumnType(schema, colName, dataType) - require(actualDataType.equals(dataType), s"Input column Name: $colName Description: ${getParam(colName)}") - } - - /** * Gets the default value of a parameter. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 0383bf0b382b7..9252618715625 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -34,10 +34,11 @@ object SchemaUtils { * @param colName column name * @param dataType required column data type */ - def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { + def checkColumnType(schema: StructType, colName: String, dataType: DataType, + msg: String = ""): Unit = { val actualDataType = schema(colName).dataType require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.") + s"Column $colName must be of type $dataType but was actually $actualDataType.$msg") } /**