Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-10931][ML][PYSPARK] PySpark Models Copy Param Values from Estimator #17849

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
super(MultilayerPerceptronClassifier, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
self._setDefault(maxIter=100, tol=1E-4, blockSize=128, stepSize=0.03, solver="l-bfgs")
self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a difference in default values between Python and Java that wasn't being caught, because check_params was returning prematurely.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like 1e-6 is the correct default value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the check_params test was meant to catch that but was broken

kwargs = self._input_kwargs
self.setParams(**kwargs)

Expand Down
8 changes: 7 additions & 1 deletion python/pyspark/ml/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,13 @@ def toLocal(self):

WARNING: This involves collecting a large :py:func:`topicsMatrix` to the driver.
"""
return LocalLDAModel(self._call_java("toLocal"))
model = LocalLDAModel(self._call_java("toLocal"))

# SPARK-10931: Temporary fix to be removed once LDAModel defines Params
model._create_params_from_java()
model._transfer_params_from_java()

return model

@since("2.0.0")
def trainingLogLikelihood(self):
Expand Down
87 changes: 52 additions & 35 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,54 @@ def test_logistic_regression_check_thresholds(self):
LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
)

@staticmethod
def check_params(test_self, py_stage, check_params_exist=True):
"""
Checks common requirements for Params.params:
- set of params exist in Java and Python and are ordered by names
- param parent has the same UID as the object's UID
- default param value from Java matches value in Python
- optionally check if all params from Java also exist in Python
"""
py_stage_str = "%s %s" % (type(py_stage), py_stage)
if not hasattr(py_stage, "_to_java"):
return
java_stage = py_stage._to_java()
if java_stage is None:
return
test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str)
if check_params_exist:
param_names = [p.name for p in py_stage.params]
java_params = list(java_stage.params())
java_param_names = [jp.name() for jp in java_params]
test_self.assertEqual(
param_names, sorted(java_param_names),
"Param list in Python does not match Java for %s:\nJava = %s\nPython = %s"
% (py_stage_str, java_param_names, param_names))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 436-443 is the only change to check_params?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also changed the return to continue on line 454, this loop is checking all params so it was meant to skip over random seed params - not break out of the loop entirely (this is why that default value for MLP was missed). I cleaned up the NaN checks, before it was just checking for Imputer params, but it should be the same for any params with NaN's as default values. This is lines 460-462

for p in py_stage.params:
test_self.assertEqual(p.parent, py_stage.uid)
java_param = java_stage.getParam(p.name)
py_has_default = py_stage.hasDefault(p)
java_has_default = java_stage.hasDefault(java_param)
test_self.assertEqual(py_has_default, java_has_default,
"Default value mismatch of param %s for Params %s"
% (p.name, str(py_stage)))
if py_has_default:
if p.name == "seed":
continue # Random seeds between Spark and PySpark are different
java_default = _java2py(test_self.sc,
java_stage.clear(java_param).getOrDefault(java_param))
py_stage._clear(p)
py_default = py_stage.getOrDefault(p)
# equality test for NaN is always False
if isinstance(java_default, float) and np.isnan(java_default):
java_default = "NaN"
py_default = "NaN" if np.isnan(py_default) else "not NaN"
test_self.assertEqual(
java_default, py_default,
"Java default %s != python default %s of param %s for Params %s"
% (str(java_default), str(py_default), p.name, str(py_stage)))


class EvaluatorTests(SparkSessionTestCase):

Expand Down Expand Up @@ -473,6 +521,8 @@ def test_idf(self):
"Model should inherit the UID from its parent estimator.")
output = idf0m.transform(dataset)
self.assertIsNotNone(output.head().idf)
# Test that parameters transferred to Python Model
ParamTests.check_params(self, idf0m)

def test_ngram(self):
dataset = self.spark.createDataFrame([
Expand Down Expand Up @@ -1525,40 +1575,6 @@ class DefaultValuesTests(PySparkTestCase):
those in their Scala counterparts.
"""

def check_params(self, py_stage):
    """
    Check that each Python param's default matches its Java counterpart
    for the given stage. Stages without a Java companion are skipped.
    """
    import pyspark.ml.feature
    if not hasattr(py_stage, "_to_java"):
        return
    java_stage = py_stage._to_java()
    if java_stage is None:
        return
    for p in py_stage.params:
        java_param = java_stage.getParam(p.name)
        py_has_default = py_stage.hasDefault(p)
        java_has_default = java_stage.hasDefault(java_param)
        self.assertEqual(py_has_default, java_has_default,
                         "Default value mismatch of param %s for Params %s"
                         % (p.name, str(py_stage)))
        if py_has_default:
            if p.name == "seed":
                # Random seeds between Spark and PySpark are different.
                # BUG FIX: this must be `continue`, not `return` — returning
                # here silently skipped every param after "seed", which is how
                # mismatched defaults (e.g. MLP's tol) went undetected.
                continue
            java_default = \
                _java2py(self.sc, java_stage.clear(java_param).getOrDefault(java_param))
            py_stage._clear(p)
            py_default = py_stage.getOrDefault(p)
            if isinstance(py_stage, pyspark.ml.feature.Imputer) and p.name == "missingValue":
                # SPARK-15040 - default value for Imputer param 'missingValue' is NaN,
                # and NaN != NaN, so handle it specially here
                import math
                self.assertTrue(math.isnan(java_default) and math.isnan(py_default),
                                "Java default %s and python default %s are not both NaN for "
                                "param %s for Params %s"
                                % (str(java_default), str(py_default), p.name, str(py_stage)))
                # BUG FIX: `continue` (not `return`) so remaining params are
                # still checked after the special-cased NaN comparison.
                continue
            self.assertEqual(java_default, py_default,
                             "Java default %s != python default %s of param %s for Params %s"
                             % (str(java_default), str(py_default), p.name, str(py_stage)))

def test_java_params(self):
import pyspark.ml.feature
import pyspark.ml.classification
Expand All @@ -1572,7 +1588,8 @@ def test_java_params(self):
for name, cls in inspect.getmembers(module, inspect.isclass):
if not name.endswith('Model') and issubclass(cls, JavaParams)\
and not inspect.isabstract(cls):
self.check_params(cls())
# NOTE: disable check_params_exist until there is parity with Scala API
ParamTests.check_params(self, cls(), check_params_exist=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This skips the param test for Model classes. Should we do a similar check for all models?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, ideally but most of the models need to be trained first so that is why they are skipped here. Some basic framework would need to be added to allow this, and I'm looking into that as a follow on.



def _squared_distance(a, b):
Expand Down
32 changes: 31 additions & 1 deletion python/pyspark/ml/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ def _transfer_param_map_to_java(self, pyParamMap):
paramMap.put([pair])
return paramMap

def _create_params_from_java(self):
    """
    SPARK-10931: Temporary fix to create params that are defined in the Java obj but not here.

    For each param on the companion Java object that has no same-named
    attribute on this Python object, create a matching :py:class:`Param`.
    If a same-named attribute already exists (even if it is not a Param),
    it is deliberately left untouched — silently skipping is preferred here
    since this shim is not critical to the model.
    """
    java_params = list(self._java_obj.params())
    from pyspark.ml.param import Param
    for java_param in java_params:
        java_param_name = java_param.name()
        if not hasattr(self, java_param_name):
            param = Param(self, java_param_name, java_param.doc())
            # Tag the param so that, if something goes wrong later, we can
            # tell it was synthesized from the Java side by this temporary fix.
            setattr(param, "created_from_java_param", True)
            setattr(self, java_param_name, param)
    # need to reset so self.params will discover the newly added params
    self._params = None

def _transfer_params_from_java(self):
"""
Transforms the embedded params from the companion Java object.
Expand All @@ -147,6 +161,10 @@ def _transfer_params_from_java(self):
if self._java_obj.isSet(java_param):
value = _java2py(sc, self._java_obj.getOrDefault(java_param))
self._set(**{param.name: value})
# SPARK-10931: Temporary fix for params that have a default in Java
if self._java_obj.hasDefault(java_param) and not self.isDefined(param):
value = _java2py(sc, self._java_obj.getDefault(java_param)).get()
self._setDefault(**{param.name: value})

def _transfer_param_map_from_java(self, javaParamMap):
"""
Expand Down Expand Up @@ -204,6 +222,11 @@ def __get_class(clazz):
# Load information from java_stage to the instance.
py_stage = py_type()
py_stage._java_obj = java_stage

# SPARK-10931: Temporary fix so that persisted models would own params from Estimator
if issubclass(py_type, JavaModel):
py_stage._create_params_from_java()

py_stage._resetUid(java_stage.uid())
py_stage._transfer_params_from_java()
elif hasattr(py_type, "_from_java"):
Expand Down Expand Up @@ -263,7 +286,8 @@ def _fit_java(self, dataset):

def _fit(self, dataset):
    """
    Fit the companion Java estimator on *dataset* and wrap the result.

    SPARK-10931: after creating the Python model, copy this estimator's
    param values onto it via ``_copyValues`` — without this, params set on
    the Python estimator would never reach the returned model.
    (The stale duplicate ``return self._create_model(java_model)`` line left
    over from the diff is removed; it made the copy step unreachable.)
    """
    java_model = self._fit_java(dataset)
    model = self._create_model(java_model)
    return self._copyValues(model)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the crucial line being added in this PR. Without this, if a Python model defines a param (matching one from Scala), then when the model is fit in Scala that param value will never be sent back to Python.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I think it is going to copy values from the estimator to the created model. So I think we assume that the params in estimator and model are the same?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is the assumption and it's the same on the Scala side too. The estimators and models should both have a shared mixin that defines the common params used. That's how it's done on the Scala side and Python should follow (once that's done, the temporary fix from here can be removed).



@inherit_doc
Expand Down Expand Up @@ -307,4 +331,10 @@ def __init__(self, java_model=None):
"""
super(JavaModel, self).__init__(java_model)
if java_model is not None:

# SPARK-10931: This is a temporary fix to allow models to own params
# from estimators. Eventually, these params should be in models through
# using common base classes between estimators and models.
self._create_params_from_java()

self._resetUid(java_model.uid())