-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-10931][ML][PYSPARK] PySpark Models Copy Param Values from Estimator #17849
Changes from all commits
a4ede3f
3b921a4
dff7863
398ef27
1f3de13
d621c89
acdb4b9
9b7b886
765eb5f
ca52db4
a22a2cc
4a66e90
4affa01
f4a657e
07f6e85
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -417,6 +417,54 @@ def test_logistic_regression_check_thresholds(self): | |
LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] | ||
) | ||
|
||
@staticmethod | ||
def check_params(test_self, py_stage, check_params_exist=True): | ||
""" | ||
Checks common requirements for Params.params: | ||
- set of params exist in Java and Python and are ordered by names | ||
- param parent has the same UID as the object's UID | ||
- default param value from Java matches value in Python | ||
- optionally check if all params from Java also exist in Python | ||
""" | ||
py_stage_str = "%s %s" % (type(py_stage), py_stage) | ||
if not hasattr(py_stage, "_to_java"): | ||
return | ||
java_stage = py_stage._to_java() | ||
if java_stage is None: | ||
return | ||
test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str) | ||
if check_params_exist: | ||
param_names = [p.name for p in py_stage.params] | ||
java_params = list(java_stage.params()) | ||
java_param_names = [jp.name() for jp in java_params] | ||
test_self.assertEqual( | ||
param_names, sorted(java_param_names), | ||
"Param list in Python does not match Java for %s:\nJava = %s\nPython = %s" | ||
% (py_stage_str, java_param_names, param_names)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Lines 436-443 are the only change to `check_params`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I also changed the `return` to `continue` on line 454; this loop checks all params, so it was meant to skip over random seed params, not break out of the loop entirely (this is why the default value for MLP was missed). I also cleaned up the NaN checks: before, it was only checking Imputer params, but the handling should be the same for any param with NaN as its default value. This is lines 460-462. |
||
for p in py_stage.params: | ||
test_self.assertEqual(p.parent, py_stage.uid) | ||
java_param = java_stage.getParam(p.name) | ||
py_has_default = py_stage.hasDefault(p) | ||
java_has_default = java_stage.hasDefault(java_param) | ||
test_self.assertEqual(py_has_default, java_has_default, | ||
"Default value mismatch of param %s for Params %s" | ||
% (p.name, str(py_stage))) | ||
if py_has_default: | ||
if p.name == "seed": | ||
continue # Random seeds between Spark and PySpark are different | ||
java_default = _java2py(test_self.sc, | ||
java_stage.clear(java_param).getOrDefault(java_param)) | ||
py_stage._clear(p) | ||
py_default = py_stage.getOrDefault(p) | ||
# equality test for NaN is always False | ||
if isinstance(java_default, float) and np.isnan(java_default): | ||
java_default = "NaN" | ||
py_default = "NaN" if np.isnan(py_default) else "not NaN" | ||
test_self.assertEqual( | ||
java_default, py_default, | ||
"Java default %s != python default %s of param %s for Params %s" | ||
% (str(java_default), str(py_default), p.name, str(py_stage))) | ||
|
||
|
||
class EvaluatorTests(SparkSessionTestCase): | ||
|
||
|
@@ -473,6 +521,8 @@ def test_idf(self): | |
"Model should inherit the UID from its parent estimator.") | ||
output = idf0m.transform(dataset) | ||
self.assertIsNotNone(output.head().idf) | ||
# Test that parameters transferred to Python Model | ||
ParamTests.check_params(self, idf0m) | ||
|
||
def test_ngram(self): | ||
dataset = self.spark.createDataFrame([ | ||
|
@@ -1525,40 +1575,6 @@ class DefaultValuesTests(PySparkTestCase): | |
those in their Scala counterparts. | ||
""" | ||
|
||
def check_params(self, py_stage): | ||
import pyspark.ml.feature | ||
if not hasattr(py_stage, "_to_java"): | ||
return | ||
java_stage = py_stage._to_java() | ||
if java_stage is None: | ||
return | ||
for p in py_stage.params: | ||
java_param = java_stage.getParam(p.name) | ||
py_has_default = py_stage.hasDefault(p) | ||
java_has_default = java_stage.hasDefault(java_param) | ||
self.assertEqual(py_has_default, java_has_default, | ||
"Default value mismatch of param %s for Params %s" | ||
% (p.name, str(py_stage))) | ||
if py_has_default: | ||
if p.name == "seed": | ||
return # Random seeds between Spark and PySpark are different | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should not return; I changed it to `continue`. |
||
java_default =\ | ||
_java2py(self.sc, java_stage.clear(java_param).getOrDefault(java_param)) | ||
py_stage._clear(p) | ||
py_default = py_stage.getOrDefault(p) | ||
if isinstance(py_stage, pyspark.ml.feature.Imputer) and p.name == "missingValue": | ||
# SPARK-15040 - default value for Imputer param 'missingValue' is NaN, | ||
# and NaN != NaN, so handle it specially here | ||
import math | ||
self.assertTrue(math.isnan(java_default) and math.isnan(py_default), | ||
"Java default %s and python default %s are not both NaN for " | ||
"param %s for Params %s" | ||
% (str(java_default), str(py_default), p.name, str(py_stage))) | ||
return | ||
self.assertEqual(java_default, py_default, | ||
"Java default %s != python default %s of param %s for Params %s" | ||
% (str(java_default), str(py_default), p.name, str(py_stage))) | ||
|
||
def test_java_params(self): | ||
import pyspark.ml.feature | ||
import pyspark.ml.classification | ||
|
@@ -1572,7 +1588,8 @@ def test_java_params(self): | |
for name, cls in inspect.getmembers(module, inspect.isclass): | ||
if not name.endswith('Model') and issubclass(cls, JavaParams)\ | ||
and not inspect.isabstract(cls): | ||
self.check_params(cls()) | ||
# NOTE: disable check_params_exist until there is parity with Scala API | ||
ParamTests.check_params(self, cls(), check_params_exist=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This skips the param test for Models. Should we do a similar check for all models? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, ideally, but most of the models need to be trained first, which is why they are skipped here. Some basic framework would need to be added to allow this, and I'm looking into that as a follow-up. |
||
|
||
|
||
def _squared_distance(a, b): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,6 +135,20 @@ def _transfer_param_map_to_java(self, pyParamMap): | |
paramMap.put([pair]) | ||
return paramMap | ||
|
||
def _create_params_from_java(self): | ||
""" | ||
SPARK-10931: Temporary fix to create params that are defined in the Java obj but not here | ||
""" | ||
java_params = list(self._java_obj.params()) | ||
from pyspark.ml.param import Param | ||
for java_param in java_params: | ||
java_param_name = java_param.name() | ||
if not hasattr(self, java_param_name): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If `self` contains a same-name attribute which is not a `Param` […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point, it's possible that there could be an attribute with that name that is not a param. If that's the case, then it is probably best to just ignore it silently, since this is not critical to the model. |
||
param = Param(self, java_param_name, java_param.doc()) | ||
setattr(param, "created_from_java_param", True) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. BTW, would you mind if I ask where `created_from_java_param` is used? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since this is part of a temporary fix to add Params that are defined in Java but not in Python, this just adds a tag to the Param so that, in case something goes wrong, we will know the param was created here. |
||
setattr(self, java_param_name, param) | ||
self._params = None # need to reset so self.params will discover new params | ||
|
||
def _transfer_params_from_java(self): | ||
""" | ||
Transforms the embedded params from the companion Java object. | ||
|
@@ -147,6 +161,10 @@ def _transfer_params_from_java(self): | |
if self._java_obj.isSet(java_param): | ||
value = _java2py(sc, self._java_obj.getOrDefault(java_param)) | ||
self._set(**{param.name: value}) | ||
# SPARK-10931: Temporary fix for params that have a default in Java | ||
if self._java_obj.hasDefault(java_param) and not self.isDefined(param): | ||
value = _java2py(sc, self._java_obj.getDefault(java_param)).get() | ||
self._setDefault(**{param.name: value}) | ||
|
||
def _transfer_param_map_from_java(self, javaParamMap): | ||
""" | ||
|
@@ -204,6 +222,11 @@ def __get_class(clazz): | |
# Load information from java_stage to the instance. | ||
py_stage = py_type() | ||
py_stage._java_obj = java_stage | ||
|
||
# SPARK-10931: Temporary fix so that persisted models would own params from Estimator | ||
if issubclass(py_type, JavaModel): | ||
py_stage._create_params_from_java() | ||
|
||
py_stage._resetUid(java_stage.uid()) | ||
py_stage._transfer_params_from_java() | ||
elif hasattr(py_type, "_from_java"): | ||
|
@@ -263,7 +286,8 @@ def _fit_java(self, dataset): | |
|
||
def _fit(self, dataset): | ||
java_model = self._fit_java(dataset) | ||
return self._create_model(java_model) | ||
model = self._create_model(java_model) | ||
return self._copyValues(model) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the crucial line being added in this PR. Without this, if a Python model defines a param (matching one from Scala), then when the model is fit in Scala that param value will never be sent back to Python. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here I think it is going to copy values from the estimator to the created model. So I think we assume that the params in the estimator and the model are the same? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, that is the assumption, and it's the same on the Scala side too. The estimators and models should both have a shared mixin that defines the common params used. That's how it's done on the Scala side and Python should follow (once that's done, the temporary fix from here can be removed). |
||
|
||
|
||
@inherit_doc | ||
|
@@ -307,4 +331,10 @@ def __init__(self, java_model=None): | |
""" | ||
super(JavaModel, self).__init__(java_model) | ||
if java_model is not None: | ||
|
||
# SPARK-10931: This is a temporary fix to allow models to own params | ||
# from estimators. Eventually, these params should be in models through | ||
# using common base classes between estimators and models. | ||
self._create_params_from_java() | ||
|
||
self._resetUid(java_model.uid()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a difference in default values between Python and Java that wasn't being caught because of `check_params` prematurely returning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like 1e-6 is correct default value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the
check_params
test was meant to catch that but was broken