[SPARK-10931][ML][PYSPARK] PySpark Models Copy Param Values from Estimator #17849

BryanCutler
Member

What changes were proposed in this pull request?

Added a call to copy the values of Params from the Estimator to the Model after fit in PySpark ML. This copies values for any params that are also defined in the Model. Since most Models currently do not define the same params as their Estimator, also added a method that creates new Params by inspecting the Java object when they do not exist in the Python object. This is a temporary fix that can be removed once the PySpark models properly define the params themselves.

How was this patch tested?

Refactored the check_params test to optionally check if the model params for Python and Java match and added this check to an existing fitted model that shares params between Estimator and Model.

@@ -1325,7 +1325,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
         super(MultilayerPerceptronClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
-        self._setDefault(maxIter=100, tol=1E-4, blockSize=128, stepSize=0.03, solver="l-bfgs")
+        self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
Member Author

This is a difference in default values between Python and Java that wasn't being caught because check_params was returning prematurely.

Member

Looks like 1e-6 is correct default value.

Member Author

Yes, the check_params test was meant to catch that but was broken

                        % (p.name, str(py_stage)))
                if py_has_default:
                    if p.name == "seed":
                        return  # Random seeds between Spark and PySpark are different
Member Author

this should not return, I changed it to continue above
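
For illustration, a minimal pure-Python sketch of why the early return hid the MLP default mismatch. The function names and the default dictionaries are made up for this example; this is not the actual check_params code, only the same loop shape:

```python
def first_mismatch_with_return(py_defaults, java_defaults):
    """Buggy variant: `return` at the seed check ends the whole scan."""
    for name in py_defaults:
        if name == "seed":
            return None  # bug: every param listed after "seed" is never compared
        if py_defaults[name] != java_defaults[name]:
            return name
    return None


def first_mismatch_with_continue(py_defaults, java_defaults):
    """Fixed variant: `continue` skips only the seed param."""
    for name in py_defaults:
        if name == "seed":
            continue  # seeds legitimately differ between Spark and PySpark
        if py_defaults[name] != java_defaults[name]:
            return name
    return None


# Hypothetical defaults: "tol" disagrees, but it is listed after "seed".
py_defaults = {"maxIter": 100, "seed": 1, "tol": 1e-4}
java_defaults = {"maxIter": 100, "seed": 2, "tol": 1e-6}

missed = first_mismatch_with_return(py_defaults, java_defaults)
caught = first_mismatch_with_continue(py_defaults, java_defaults)
```

With the early return the scan stops at "seed" and reports nothing; with continue it goes on and flags "tol".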

@@ -1355,7 +1370,7 @@ def test_java_params(self):
             for name, cls in inspect.getmembers(module, inspect.isclass):
                 if not name.endswith('Model') and issubclass(cls, JavaParams)\
                         and not inspect.isabstract(cls):
-                    self.check_params(cls())
+                    ParamTests.check_params(self, cls(), check_params_exist=False)
Member Author

Setting check_params_exist to True will uncover any params that exist in Java but not in Python

Contributor

This might make sense to include as a comment in the code for whoever is coming to update this.

Member Author

sure, will do
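
As a rough sketch of what enabling check_params_exist would surface, the comparison boils down to a difference between two lists of param names. The helper name and the name lists below are hypothetical; in the real test the names would come from the Python stage and its Java counterpart:

```python
def missing_in_python(py_param_names, java_param_names):
    """Return Java param names that have no Python counterpart, sorted."""
    return sorted(set(java_param_names) - set(py_param_names))


# Hypothetical name lists for a single stage.
py_names = ["featuresCol", "labelCol", "maxIter"]
java_names = ["featuresCol", "labelCol", "maxIter", "tol"]

gaps = missing_in_python(py_names, java_names)
```

An empty result would mean the Python stage declares every param the Java stage does.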

@SparkQA

SparkQA commented May 3, 2017

Test build #76429 has finished for PR 17849 at commit 765eb5f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@BryanCutler
Member Author

@jkbradley @holdenk the heart of this change is just adding the call to _copyValues to copy param values from Estimator to Model. That doesn't really do much though, since most of the Python models do not define any params and there is nothing to copy to. So I added a temporary little hack to look at the Java Model params after fitting and create any params that don't already exist, then any set values can be copied. Also needed to do the same after loading a Python model or this will fail persistence tests.

I know having this temporary 'fix' isn't ideal, but it would allow us to incrementally add missing Params or restructure the class hierarchy to match the Scala versions, and it will continue to copy these values to the Models. Until that is done, there won't be explicit methods to get each param, such as getMaxDepth(), but the param value can still be accessed with model.getOrDefault("maxDepth"), which gives users a workaround for all of those types of JIRAs that have come up. What do you guys think?

# SPARK-10931: This is a temporary fix to allow models to own params
# from estimators. Eventually, these params should be in models through
# using common base classes between estimators and models.
model._create_params_from_java()
Member Author

This might be better to move to JavaModel.__init__() for the case of creating a model without fitting, e.g. CountVectorizerModel from vocabulary.

Contributor

Right now this would apply to all of the models. Would it make sense to make it selective, so that we can move the params forward one at a time?

Member Author

I don't think there is really any downside to just creating all the Params from Java, see my comment below.

@@ -404,6 +404,53 @@ def test_copy_param_extras(self):
         self.assertEqual(tp._paramMap, copied_no_extra)
         self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap)

+    @staticmethod
+    def check_params(test_self, py_stage, check_params_exist=True):
Contributor

@holdenk holdenk May 6, 2017

Thank you so much for putting in the time on this. :D :D

Member Author

no problem!

Contributor

@holdenk holdenk left a comment

Thanks a lot for working on this, I've done a first read through with some questions :)

@BryanCutler
Member Author

Thanks @holdenk for the review! I think I wrote the description a little too rushed, so let me clarify a bit...

The temporary "fix" will just create empty params in the model if they exist in the Java model but not the Python one. There should be no risk in adding these to the Python model, since they are empty when created and not yet defined with a value. These params will be set in 2 ways: 1) after the model is fit, in the call to _copyValues, where the value is copied from the estimator for any matching params; 2) when the model is loaded, there is a call to _transfer_params_from_java that will copy the value if the Java param has been explicitly set (I think I need to add something here for the case where the Java model has a default value but the Python model doesn't).

I think the best way forward to get parity with the Scala API is to then organize a JIRA with subtasks to update the Python ML class hierarchies to match the Scala ones, so that the Params will be defined that way with proper "get" and "set" methods too. It might be good to also have a Python test that checks for matching params in Java for both the estimators and models. It could be ignored by default and then enabled during the QA period. The temporary fix here would continue to work and not interfere while the params are being added. It could be removed once we feel that most of the params have been properly added and close to matching the Scala API.
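
The create-then-copy behaviour described above can be sketched in plain Python. ToyParams, create_params_from_java and copy_values are simplified stand-ins invented for this example; they only mimic the idea and are not the PySpark implementation:

```python
class ToyParams:
    """Simplified stand-in for a PySpark Params holder (not the real class)."""

    def __init__(self, declared=()):
        self.declared = set(declared)  # param names the object knows about
        self.values = {}               # explicitly set values
        self.defaults = {}             # default values

    def is_defined(self, name):
        return name in self.values or name in self.defaults

    def get_or_default(self, name):
        if name in self.values:
            return self.values[name]
        return self.defaults[name]


def create_params_from_java(model, java_param_names):
    """Declare any Java-side param the Python model lacks; assign no value."""
    for name in java_param_names:
        model.declared.add(name)


def copy_values(estimator, model):
    """Copy explicitly set values for params that both objects declare."""
    for name, value in estimator.values.items():
        if name in model.declared:
            model.values[name] = value
    return model


estimator = ToyParams(declared=["maxDepth", "seed"])
estimator.values["maxDepth"] = 7

model = ToyParams()  # declares nothing at first
create_params_from_java(model, ["maxDepth", "seed"])
defined_before_copy = model.is_defined("maxDepth")
copy_values(estimator, model)
```

A created param stays undefined until a value is copied into it, which is why creating all of them up front should be harmless.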

@SparkQA

SparkQA commented May 8, 2017

Test build #76596 has finished for PR 17849 at commit 4a66e90.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@BryanCutler
Member Author

ping @jkbradley @holdenk , please have a look when you can, thanks!

@holdenk
Contributor

holdenk commented Jul 30, 2017

This looks pretty reasonable, sorry for the delay. If you have a chance, it would be good to update this to master.

@BryanCutler
Member Author

Thanks @holdenk! Sure, I'll update to master

@SparkQA

SparkQA commented Jul 31, 2017

Test build #80089 has finished for PR 17849 at commit 4affa01.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@BryanCutler
Member Author

ping @holdenk - think this is good to go?

@WeichenXu123
Contributor

Thanks for your work on this, but I am curious: what is the benefit of doing this? Currently in PySpark there are no params in the Model itself, so what problems or bugs does adding params to the PySpark model resolve?

@BryanCutler
Member Author

If params are defined in the PySpark model, then when the model is fit a Scala version is created and the PySpark model is wrapped around it. The param values from the Scala version are never transferred to the PySpark model, so the defined params will only have default values.

@BryanCutler
Member Author

ping @holdenk , also @HyukjinKwon if you are able to take a look

@@ -263,7 +284,8 @@ def _fit_java(self, dataset):

     def _fit(self, dataset):
         java_model = self._fit_java(dataset)
-        return self._create_model(java_model)
+        model = self._create_model(java_model)
+        return self._copyValues(model)
Member Author

This is the crucial line being added in this PR. Without this, if a Python model defines a param (matching one from Scala), then when the model is fit in Scala that param value will never be sent back to Python.

Member

Here I think it is going to copy values from the estimator to the created model. So I think we assume that the params in estimator and model are the same?

Member Author

Yes, that is the assumption and it's the same on the Scala side too. The estimators and models should both have a shared mixin that defines the common params used. That's how it's done on the Scala side and Python should follow (once that's done, the temporary fix from here can be removed).
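
To make the shared-mixin assumption concrete, here is a small pure-Python sketch. HasMaxDepth, TreeEstimator, TreeModel, the copy_values flag and the default of 5 are all hypothetical; the update call only stands in for what `return self._copyValues(model)` achieves:

```python
class HasMaxDepth:
    """Hypothetical shared mixin: declares maxDepth, its default and getter."""

    def __init__(self):
        self._values = {}  # explicitly set param values

    def getMaxDepth(self):
        return self._values.get("maxDepth", 5)  # 5 is an assumed default


class TreeModel(HasMaxDepth):
    """Stand-in for the Python wrapper around a fitted Java model."""


class TreeEstimator(HasMaxDepth):
    def setMaxDepth(self, value):
        self._values["maxDepth"] = value
        return self

    def fit(self, copy_values=True):
        model = TreeModel()
        if copy_values:
            # Analogue of the added copy step: it works because both
            # classes get the same param from the shared mixin.
            model._values.update(self._values)
        return model


estimator = TreeEstimator().setMaxDepth(9)
without_copy = estimator.fit(copy_values=False)
with_copy = estimator.fit()
```

Without the copy step the model only ever reports the default; with it, the value set on the estimator carries over and an explicit getter becomes meaningful.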

@HyukjinKwon
Member

I am rather a backend developer and work together with data scientists, so my ML knowledge is limited (I am studying hard :)). I will leave a few comments if there are some nits once someone starts to review, so that they can be addressed together.

cc @viirya, who I believe knows ML a bit, and @zero323, who I believe should be able to review this (but is inactive now). Are you maybe able to make a pass on this one?

@HyukjinKwon
Member

HyukjinKwon commented Aug 10, 2017

Will try to give a pass anyway.

java_param_name = java_param.name()
if not hasattr(self, java_param_name):
    param = Param(self, java_param_name, java_param.doc())
    setattr(param, "created_from_java_param", True)
Member

BTW, would you mind if I ask where created_from_java_param is used?

Member Author

Since this is part of a temporary fix to add Params that are defined in Java but not in Python, this just tags the Param so that, if something goes wrong, we will know the param was created here.

from pyspark.ml.param import Param
for java_param in java_params:
    java_param_name = java_param.name()
    if not hasattr(self, java_param_name):
Member

If self contains an attribute with the same name that is not a Param, should we handle it, e.g. by throwing an exception?

Member Author

Good point, it's possible that there could be an attribute with that name that is not a param. If that's the case, then it is probably best to just ignore silently since this is not critical to the model.
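
A small sketch of the "ignore silently" behaviour, using toy classes rather than the real pyspark.ml.param.Param. The model class, its numFeatures attribute and the helper name are assumptions made for the example:

```python
class Param:
    """Minimal stand-in for a param descriptor (not the PySpark class)."""

    def __init__(self, name, doc):
        self.name = name
        self.doc = doc


class ToyModel:
    """Hypothetical model with a plain attribute that is not a Param."""
    numFeatures = 10


def create_missing_params(model, java_params):
    """Add a Param for each (name, doc) pair unless the name is already taken."""
    created = []
    for name, doc in java_params:
        if hasattr(model, name):
            continue  # existing Param or unrelated attribute: leave untouched
        setattr(model, name, Param(name, doc))
        created.append(name)
    return created


model = ToyModel()
created = create_missing_params(
    model, [("numFeatures", "number of features"), ("maxDepth", "max tree depth")])
```

The name clash on numFeatures is skipped without an error, so the existing attribute keeps its value and only maxDepth is created.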

@holdenk
Contributor

holdenk commented Aug 10, 2017

Sorry, let me try and take a look tomorrow.

-            if self._java_obj.isSet(java_param):
+            if self._java_obj.isSet(java_param) or (
+                    # SPARK-10931: Temporary fix for params that have a default in Java
+                    self._java_obj.hasDefault(java_param) and not self.isDefined(param)):
Member

@viirya viirya Aug 10, 2017

This change will turn a default value for a param on the Java side into a user-provided param value on the Python side. I think we should use _setDefault for default values instead of _set.

Member Author

True. I was thinking that since this is part of the temporary fix it doesn't matter, but it won't be much extra work to use _setDefault and it will probably be clearer.

Member Author

ok, fixed to use _setDefault
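
The distinction raised in this thread, sketched with a toy class: a Java-side default should land in the Python default map, not in the explicitly-set map. ToyParams and transfer_from_java are illustrative names only and do not reproduce the real _transfer_params_from_java:

```python
class ToyParams:
    """Simplified stand-in for a Params object; not the PySpark class."""

    def __init__(self):
        self.param_map = {}    # values the user set explicitly
        self.default_map = {}  # default values

    def is_set(self, name):
        return name in self.param_map

    def is_defined(self, name):
        return name in self.param_map or name in self.default_map

    def get_or_default(self, name):
        if name in self.param_map:
            return self.param_map[name]
        return self.default_map[name]


def transfer_from_java(py_obj, java_set, java_defaults):
    """Copy explicitly set values as set values and defaults as defaults."""
    for name in sorted(set(java_set) | set(java_defaults)):
        if name in java_set:
            py_obj.param_map[name] = java_set[name]
        elif not py_obj.is_defined(name):
            # Kept as a default so it does not look user-provided.
            py_obj.default_map[name] = java_defaults[name]
    return py_obj


py_obj = transfer_from_java(
    ToyParams(), {"maxIter": 50}, {"maxIter": 100, "tol": 1e-6})
```

After the transfer, tol is readable through get_or_default but is not reported as set, which matches what a user would see on the Java side.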

test_self.assertEqual(
    param_names, sorted(java_param_names),
    "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s"
    % (py_stage_str, java_param_names, param_names))
Member

Are lines 436-443 the only change to check_params?

Member Author

I also changed the return to continue on line 454. This loop checks all params, so it was meant to skip over random seed params, not break out of the loop entirely (this is why the default value for MLP was missed). I also cleaned up the NaN checks: before, it only checked Imputer params, but the handling should be the same for any param with NaN as its default value. This is lines 460-462.
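
The NaN handling needs special care because NaN never compares equal to itself. A minimal sketch of a comparison that treats two NaN defaults as matching; the helper name defaults_match is made up and this is not the code from the test:

```python
import math


def defaults_match(py_default, java_default):
    """Compare two default values, treating NaN as equal to NaN."""
    both_nan = (
        isinstance(py_default, float)
        and isinstance(java_default, float)
        and math.isnan(py_default)
        and math.isnan(java_default)
    )
    return both_nan or py_default == java_default


nan_pair_matches = defaults_match(float("nan"), float("nan"))
tol_pair_matches = defaults_match(1e-4, 1e-6)
```

Applying this to every param, not just Imputer's, keeps the check generic for any stage whose default is NaN.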

@@ -1572,7 +1588,8 @@ def test_java_params(self):
             for name, cls in inspect.getmembers(module, inspect.isclass):
                 if not name.endswith('Model') and issubclass(cls, JavaParams)\
                         and not inspect.isabstract(cls):
-                    self.check_params(cls())
+                    # NOTE: disable check_params_exist until there is parity with Scala API
+                    ParamTests.check_params(self, cls(), check_params_exist=False)
Member

This skips the param test for Models. Should we do a similar check for all models?

Member Author

Yes, ideally, but most of the models need to be trained first, which is why they are skipped here. Some basic framework would need to be added to allow this, and I'm looking into that as a follow-up.

@BryanCutler
Member Author

Thanks for reviewing @viirya and @HyukjinKwon !
Btw, the temporary fix I talk about here is an optional addition to this PR that lets users access model param values via decision_tree_model.getOrDefault("maxDepth") as a workaround until proper accessors (like getMaxDepth()) can be added, since I've seen a lot of JIRAs with people asking for this.

@SparkQA

SparkQA commented Aug 10, 2017

Test build #80499 has finished for PR 17849 at commit f4a657e.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Aug 10, 2017

Test build #80506 has finished for PR 17849 at commit 07f6e85.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@holdenk
Contributor

holdenk commented Aug 18, 2017

LGTM. It's certainly sort of an intermediary fix state, but making the params accessible without users having to go through py4j manually is worthwhile.

I'll leave this over the weekend in case anyone has issues.

@BryanCutler
Member Author

@holdenk , do you think this is good to go now?

@WeichenXu123
Contributor

What do you think about this ? @jkbradley

@holdenk
Contributor

holdenk commented Aug 22, 2017

I think it's good to go for master pending Jenkins (it's been a while since the last run). So let's just make sure everything is still ok: Jenkins retest this please.

@SparkQA

SparkQA commented Aug 22, 2017

Test build #81004 has finished for PR 17849 at commit 07f6e85.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@asfgit asfgit closed this in 41bb1dd Aug 23, 2017
@holdenk
Contributor

holdenk commented Aug 23, 2017

Merged to master, thanks everyone :) (There is also a follow up JIRA https://issues.apache.org/jira/browse/SPARK-21812 for explicitly defining all of the params in Python).

@BryanCutler
Member Author

Thanks @holdenk!
