Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

overload StringArrayParam.w #3

Merged
merged 1 commit into from
May 7, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.NoSuchElementException

import scala.annotation.varargs
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.collection.JavaConverters._

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.util.Identifiable
Expand Down Expand Up @@ -228,7 +228,8 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array

override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)

private[param] def wCast(value: Seq[String]): ParamPair[Array[String]] = w(value.toArray)
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
}

/**
Expand Down Expand Up @@ -323,13 +324,7 @@ trait Params extends Identifiable with Serializable {
* Sets a parameter in the embedded param map.
*/
protected final def set[T](param: Param[T], value: T): this.type = {
shouldOwn(param)
if (param.isInstanceOf[StringArrayParam] && value.isInstanceOf[Seq[_]]) {
paramMap.put(param.asInstanceOf[StringArrayParam].wCast(value.asInstanceOf[Seq[String]]))
} else {
paramMap.put(param.w(value))
}
this
set(param -> value)
}

/**
Expand All @@ -339,6 +334,15 @@ trait Params extends Identifiable with Serializable {
set(getParam(param), value)
}

/**
* Sets a parameter in the embedded param map.
*/
protected final def set(paramPair: ParamPair[_]): this.type = {
shouldOwn(paramPair.param)
paramMap.put(paramPair)
this
}

/**
* Optionally returns the user-supplied value of a param.
*/
Expand Down
8 changes: 4 additions & 4 deletions python/pyspark/ml/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ def _transfer_params_to_java(self, params, java_obj):
for param in self.params:
if param in paramMap:
value = paramMap[param]
if isinstance(value, list):
value = _jvm().PythonUtils.toSeq(value)
java_obj.set(param.name, value)
java_param = java_obj.getParam(param.name)
java_obj.set(java_param.w(value))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was missing this part :(


def _empty_java_param_map(self):
"""
Expand All @@ -82,7 +81,8 @@ def _create_java_param_map(self, params, java_obj):
paramMap = self._empty_java_param_map()
for param, value in params.items():
if param.parent is self:
paramMap.put(java_obj.getParam(param.name), value)
java_param = java_obj.getParam(param.name)
paramMap.put(java_param.w(value))
return paramMap


Expand Down