[SPARK-10931][ML][PYSPARK] PySpark Models Copy Param Values from Estimator #17849
… instead of continue
…n-params-SPARK-10931
@@ -1325,7 +1325,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
 super(MultilayerPerceptronClassifier, self).__init__()
 self._java_obj = self._new_java_obj(
     "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
-self._setDefault(maxIter=100, tol=1E-4, blockSize=128, stepSize=0.03, solver="l-bfgs")
+self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
This is a difference in default values between Python and Java that wasn't being caught because of check_params prematurely returning
Looks like 1e-6 is correct default value.
Yes, the check_params test was meant to catch that but was broken
% (p.name, str(py_stage)))
if py_has_default:
    if p.name == "seed":
        return  # Random seeds between Spark and PySpark are different
this should not return, I changed it to continue above
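To illustrate why the `return` hid the tolerance mismatch, here is a minimal, self-contained sketch. The two functions and the dictionaries are hypothetical stand-ins for the real test, which compares defaults on live Python and Java stages; only the control flow is meant to mirror the bug described above.

```python
def check_defaults_with_return(py_defaults, java_defaults):
    """Buggy variant: returning on "seed" abandons every param after it."""
    mismatches = []
    for name in sorted(py_defaults):
        if name == "seed":
            return mismatches  # exits the whole loop, not just this param
        if py_defaults[name] != java_defaults[name]:
            mismatches.append(name)
    return mismatches


def check_defaults_with_continue(py_defaults, java_defaults):
    """Fixed variant: skip only the seed and keep checking the rest."""
    mismatches = []
    for name in sorted(py_defaults):
        if name == "seed":
            continue  # seeds legitimately differ, so skip just this one
        if py_defaults[name] != java_defaults[name]:
            mismatches.append(name)
    return mismatches


# Hypothetical defaults: "tol" sorts after "seed", so the buggy loop never sees it.
py = {"maxIter": 100, "seed": 1, "tol": 1e-4}
java = {"maxIter": 100, "seed": 2, "tol": 1e-6}

missed = check_defaults_with_return(py, java)
caught = check_defaults_with_continue(py, java)
```

With these inputs the buggy variant reports nothing, while the fixed variant reports the `tol` mismatch.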
python/pyspark/ml/tests.py
Outdated
@@ -1355,7 +1370,7 @@ def test_java_params(self):
 for name, cls in inspect.getmembers(module, inspect.isclass):
     if not name.endswith('Model') and issubclass(cls, JavaParams)\
             and not inspect.isabstract(cls):
-        self.check_params(cls())
+        ParamTests.check_params(self, cls(), check_params_exist=False)
Setting check_params_exist to True will uncover any params that exist in Java but not in Python
This might make sense to include as a comment in the code for whoever is coming to update this.
sure, will do
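As a rough illustration of what an existence check like this reports, the sketch below compares two lists of param names. The helper `compare_param_names` and the sample lists are hypothetical; the real test reads the names from the Python stage and its Java counterpart.

```python
def compare_param_names(py_param_names, java_param_names):
    """Return (missing_in_python, missing_in_java) as sorted lists of names."""
    py_names, java_names = set(py_param_names), set(java_param_names)
    return sorted(java_names - py_names), sorted(py_names - java_names)


# Hypothetical param lists, for illustration only.
python_params = ["featuresCol", "labelCol", "maxIter"]
java_params = ["featuresCol", "labelCol", "maxIter", "tol"]

missing_in_python, missing_in_java = compare_param_names(python_params, java_params)
```

Here `tol` would be flagged as defined in Java but not in Python, which is the kind of gap the flag is meant to surface once it is enabled.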
Test build #76429 has finished for PR 17849 at commit
@jkbradley @holdenk the heart of this change is just adding the call to

I know having this temporary 'fix' isn't ideal but it would allow us to incrementally add missing Params or restructure class hierarchy to match Scala versions and will continue to copy these values to the Models. Until that is done, there won't be explicit methods to get each param, such as
python/pyspark/ml/wrapper.py
Outdated
# SPARK-10931: This is a temporary fix to allow models to own params
# from estimators. Eventually, these params should be in models through
# using common base classes between estimators and models.
model._create_params_from_java()
This might be better to move to JavaModel.__init__() for the case of creating a model without fitting - e.g. CountVectorizerModel from vocabulary.
So right now this would apply to all of the models. Would it make sense to make it so that we can selectively move the params forward one at a time?
I don't think there is really any downside of just creating all the Params from Java, see my comment below.
…ake a model without fitting
python/pyspark/ml/tests.py
Outdated
@@ -404,6 +404,53 @@ def test_copy_param_extras(self):
 self.assertEqual(tp._paramMap, copied_no_extra)
 self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap)

+@staticmethod
+def check_params(test_self, py_stage, check_params_exist=True):
Thank you so much for putting in the time on this. :D :D
no problem!
Thanks a lot for working on this, I've done a first read through with some questions :)
Thanks @holdenk for the review! I think I wrote the description a little too rushed, so let me clarify a bit...

The temporary "fix" will just create empty params in the model if they exist in the Java model but not the Python one. There should be no risk of having these added to the Python model since they are empty when created and not yet defined with a value. These params will be set in 2 ways: 1) after the model is fit in the call to

I think the best way forward to get parity with the Scala API is to then organize a JIRA with subtasks to update the Python ML class hierarchies to match the Scala ones, so that the Params will be defined that way with proper "get" and "set" methods too. It might be good to also have a Python test that checks for matching params in Java for both the estimators and models. It could be ignored by default and then enabled during the QA period.

The temporary fix here would continue to work and not interfere while the params are being added. It could be removed once we feel that most of the params have been properly added and close to matching the Scala API.
Test build #76596 has finished for PR 17849 at commit
ping @jkbradley @holdenk, please have a look when you can, thanks!
This looks pretty reasonable, sorry for the delay. If you have a chance, it would be good to update this to master.
Thanks @holdenk! Sure, I'll update to master
…n-params-SPARK-10931
Test build #80089 has finished for PR 17849 at commit
ping @holdenk - think this is good to go?
Thanks for your work on this, but I am curious: what is the benefit of doing this? In PySpark there is currently no param in the Model itself. What problem or bug does adding params to the PySpark model resolve?
If params are defined in the PySpark model, then when that model is fit, a Scala version is created and the PySpark model is wrapped around it. The param values from the Scala version are never transferred to the PySpark model, so the defined params will only have default values.
ping @holdenk, also @HyukjinKwon if you are able to take a look
@@ -263,7 +284,8 @@ def _fit_java(self, dataset):

 def _fit(self, dataset):
     java_model = self._fit_java(dataset)
-    return self._create_model(java_model)
+    model = self._create_model(java_model)
+    return self._copyValues(model)
This is the crucial line being added in this PR. Without this, if a Python model defines a param (matching one from Scala), then when the model is fit in Scala that param value will never be sent back to Python.
Here I think it is going to copy values from the estimator to the created model. So I think we assume that the params in estimator and model are the same?
Yes, that is the assumption and it's the same on the Scala side too. The estimators and models should both have a shared mixin that defines the common params used. That's how it's done on the Scala side and Python should follow (once that's done, the temporary fix from here can be removed).
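The assumption discussed in this thread can be sketched without Spark. The `ToyParams` class and `copy_values` function below are hypothetical stand-ins, not the PySpark implementation; they only show the idea that values set on the estimator are carried over for params the model also defines, and that everything else is left alone.

```python
class ToyParams:
    """Toy param holder, loosely modelled on the idea of a Params object."""

    def __init__(self, param_names):
        self.param_names = set(param_names)
        self.values = {}

    def set(self, name, value):
        if name not in self.param_names:
            raise KeyError(name)
        self.values[name] = value


def copy_values(source, target):
    """Copy each value set on source whose param is also defined on target."""
    for name, value in source.values.items():
        if name in target.param_names:
            target.set(name, value)
    return target


# Hypothetical estimator and model: they share maxIter and tol, but not labelCol.
estimator = ToyParams(["maxIter", "tol", "labelCol"])
estimator.set("maxIter", 50)
estimator.set("labelCol", "y")

model = copy_values(estimator, ToyParams(["maxIter", "tol"]))
```

In this sketch `maxIter` reaches the model while `labelCol` is ignored, which is why a shared mixin defining the common params on both sides matters.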
I am more of a backend developer and work together with data scientists, so my ML knowledge is limited (I am studying hard :)). I will leave a few comments if there are some nits once someone starts to review, so that they can be addressed together. cc @viirya, who I believe knows ML a bit, and @zero323, who I believe should be able to review this (but is inactive now): are you maybe able to make a pass on this one?
Will try to give a pass anyway.
java_param_name = java_param.name()
if not hasattr(self, java_param_name):
    param = Param(self, java_param_name, java_param.doc())
    setattr(param, "created_from_java_param", True)
BTW, would you mind if I ask where created_from_java_param is used?
Since this is part of a temporary fix to add Params that are defined in Java but not in Python, this just adds a tag to the Param so that, if something goes wrong, we will know the param was created here.
from pyspark.ml.param import Param
for java_param in java_params:
    java_param_name = java_param.name()
    if not hasattr(self, java_param_name):
If self contains an attribute with the same name that is not a Param, should we handle it, for example by throwing an exception?
Good point, it's possible that there could be an attribute with that name that is not a param. If that's the case, then it is probably best to just ignore silently since this is not critical to the model.
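The behaviour discussed here, creating a param only when no attribute of that name exists and silently skipping otherwise, might look roughly like the sketch below. `ToyParam`, `FakeJavaParam`, and `ToyModel` are hypothetical classes written for illustration; in the real code the Java params come through py4j and the Python class is `pyspark.ml.param.Param`.

```python
class ToyParam:
    """Toy stand-in for a Param: just a parent, a name, and a doc string."""

    def __init__(self, parent, name, doc):
        self.parent = parent
        self.name = name
        self.doc = doc


class FakeJavaParam:
    """Mimics the name()/doc() accessors used on the Java param object."""

    def __init__(self, name, doc):
        self._name = name
        self._doc = doc

    def name(self):
        return self._name

    def doc(self):
        return self._doc


class ToyModel:
    numFeatures = 3  # a plain attribute that is not a param

    def create_params_from_java(self, java_params):
        """Add a tagged param for each Java param name not already an attribute."""
        created = []
        for java_param in java_params:
            name = java_param.name()
            if not hasattr(self, name):
                param = ToyParam(self, name, java_param.doc())
                param.created_from_java_param = True  # tag for later debugging
                setattr(self, name, param)
                created.append(name)
        return created


model = ToyModel()
created = model.create_params_from_java(
    [FakeJavaParam("tol", "convergence tolerance"),
     FakeJavaParam("numFeatures", "number of features")])
```

Only `tol` is created in this sketch; the existing `numFeatures` attribute is left untouched and no error is raised.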
Sorry, let me try and take a look tomorrow.
python/pyspark/ml/wrapper.py
Outdated
-if self._java_obj.isSet(java_param):
+if self._java_obj.isSet(java_param) or (
+        # SPARK-10931: Temporary fix for params that have a default in Java
+        self._java_obj.hasDefault(java_param) and not self.isDefined(param)):
This change will turn a default value for a param on the Java side into a user-provided param value on the Python side. I think we should use _setDefault for default values instead of _set.
True. I was thinking since this is part of the temporary fix, then it doesn't matter, but it won't be much extra to use _setDefault and probably be clearer.
ok, fixed to use _setDefault
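The distinction raised in this thread is between a value the user set explicitly and a value that is merely the default. A minimal sketch, assuming a toy class that keeps the two in separate maps (the method names mirror those mentioned in the review, but `ToyParams` itself is hypothetical and not the PySpark class):

```python
class ToyParams:
    """Keeps user-set values and defaults in separate maps."""

    def __init__(self):
        self._paramMap = {}         # values explicitly set by the user
        self._defaultParamMap = {}  # fallback defaults

    def _set(self, **kwargs):
        self._paramMap.update(kwargs)
        return self

    def _setDefault(self, **kwargs):
        self._defaultParamMap.update(kwargs)
        return self

    def isSet(self, name):
        return name in self._paramMap

    def hasDefault(self, name):
        return name in self._defaultParamMap

    def getOrDefault(self, name):
        if name in self._paramMap:
            return self._paramMap[name]
        return self._defaultParamMap[name]


# Transferring a Java-side default with _setDefault keeps it marked as a default.
as_default = ToyParams()._setDefault(tol=1e-6)

# Using _set instead would make the same value look user-provided.
as_user_value = ToyParams()._set(tol=1e-6)
```

Both objects return the same value from `getOrDefault("tol")`, but only the second reports `isSet("tol")` as true, which is the misleading state the reviewer wanted to avoid.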
test_self.assertEqual(
    param_names, sorted(java_param_names),
    "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s"
    % (py_stage_str, java_param_names, param_names))
Line 436-443 is the only change to check_params?
I also changed the return to continue on line 454. This loop checks all params, so it was meant to skip over random seed params, not break out of the loop entirely (this is why that default value for MLP was missed). I also cleaned up the NaN checks: before, it was only checking Imputer params, but it should be the same for any params with NaN as the default value. This is lines 460-462.
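The NaN case needs special handling because NaN never compares equal to itself, so a plain equality check on two NaN defaults would always report a mismatch. A minimal sketch of a NaN-aware comparison, where `defaults_match` is a hypothetical helper rather than the code in the test:

```python
import math


def defaults_match(py_default, java_default):
    """Treat two NaN floats as matching; otherwise fall back to equality."""
    both_nan = (isinstance(py_default, float)
                and isinstance(java_default, float)
                and math.isnan(py_default)
                and math.isnan(java_default))
    return both_nan or py_default == java_default


nan_equal_by_operator = float("nan") == float("nan")
nan_equal_by_helper = defaults_match(float("nan"), float("nan"))
tol_matches = defaults_match(1e-4, 1e-6)
solver_matches = defaults_match("l-bfgs", "l-bfgs")
```

The helper applies to any param whose default is NaN, not only to one particular stage.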
@@ -1572,7 +1588,8 @@ def test_java_params(self):
 for name, cls in inspect.getmembers(module, inspect.isclass):
     if not name.endswith('Model') and issubclass(cls, JavaParams)\
             and not inspect.isabstract(cls):
-        self.check_params(cls())
+        # NOTE: disable check_params_exist until there is parity with Scala API
+        ParamTests.check_params(self, cls(), check_params_exist=False)
This skips the param test for Models. Should we do a similar check for all models?
Yes, ideally but most of the models need to be trained first so that is why they are skipped here. Some basic framework would need to be added to allow this, and I'm looking into that as a follow on.
Thanks for reviewing @viirya and @HyukjinKwon!
… made NaN check better message when fail
Test build #80499 has finished for PR 17849 at commit
Test build #80506 has finished for PR 17849 at commit
LGTM. It's certainly sort of an intermediary fix state, but making the params accessible without users having to go through py4j manually is worthwhile. I'll leave this over the weekend in case anyone has issues.
@holdenk, do you think this is good to go now?
What do you think about this, @jkbradley?
I think it's good to go for master pending Jenkins (it's been a while since the last run). So let's just make sure everything is still ok: Jenkins retest this please.
Test build #81004 has finished for PR 17849 at commit
Merged to master, thanks everyone :) (There is also a follow up JIRA https://issues.apache.org/jira/browse/SPARK-21812 for explicitly defining all of the params in Python).
Thanks @holdenk!
What changes were proposed in this pull request?
Added call to copy values of Params from Estimator to Model after fit in PySpark ML. This will copy values for any params that are also defined in the Model. Since currently most Models do not define the same params from the Estimator, also added method to create new Params from looking at the Java object if they do not exist in the Python object. This is a temporary fix that can be removed once the PySpark models properly define the params themselves.
How was this patch tested?
Refactored the check_params test to optionally check if the model params for Python and Java match and added this check to an existing fitted model that shares params between Estimator and Model.