[SPARK-19454][PYTHON][SQL] DataFrame.replace improvements #16793
Changes from 1 commit
@@ -1268,7 +1268,7 @@ def fillna(self, value, subset=None):
         return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)

     @since(1.4)
-    def replace(self, to_replace, value, subset=None):
+    def replace(self, to_replace, value=None, subset=None):
         """Returns a new :class:`DataFrame` replacing a value with another value.
         :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
         aliases of each other.
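Since value now defaults to None, a dict of {old: new} replacements can be passed on its own. For illustration (not part of the diff; spark and schema here are assumed to be a running SparkSession and the (name, age, height) schema used in the tests further down):

    df = spark.createDataFrame([(u'Alice', 10, 80.1)], schema)
    df.replace({10: 11}).first()   # Row(name=u'Alice', age=11, height=80.1)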
@@ -1307,43 +1307,66 @@ def replace(self, to_replace, value, subset=None):
         |null| null|null|
         +----+------+----+
         """
-        if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)):
+        # Helper functions
+        def all_of(types):
+            def all_of_(xs):
+                return all(isinstance(x, types) for x in xs)
+            return all_of_
+
+        all_of_bool = all_of(bool)
+        all_of_str = all_of(basestring)
+        all_of_numeric = all_of((float, int, long))
+
+        # Validate input types
+        valid_types = (bool, float, int, long, basestring, list, tuple)
+        if not isinstance(to_replace, valid_types + (dict, )):
             raise ValueError(
-                "to_replace should be a float, int, long, string, list, tuple, or dict")
+                "to_replace should be a float, int, long, string, list, tuple, or dict. "
+                "Got {0}".format(type(to_replace)))

-        if not isinstance(value, (float, int, long, basestring, list, tuple)):
-            raise ValueError("value should be a float, int, long, string, list, or tuple")
+        if (not isinstance(value, valid_types) and

Review comment: This seems like a weird split.

+                not isinstance(to_replace, dict)):
+            raise ValueError("If to_replace is not a dict, value should be "
+                             "a float, int, long, string, list, or tuple. "
+                             "Got {0}".format(type(value)))

+        if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
+            if len(to_replace) != len(value):
+                raise ValueError("to_replace and value lists should be of the same length. "
+                                 "Got {0} and {1}".format(len(to_replace), len(value)))

-        rep_dict = dict()
+        if not (subset is None or isinstance(subset, (list, tuple, basestring))):
+            raise ValueError("subset should be a list or tuple of column names, "
+                             "column name or None. Got {0}".format(type(subset)))

+        # Reshape input arguments if necessary
         if isinstance(to_replace, (float, int, long, basestring)):
             to_replace = [to_replace]

-        if isinstance(to_replace, tuple):
-            to_replace = list(to_replace)
         if isinstance(value, (float, int, long, basestring)):
             value = [value for _ in range(len(to_replace))]

-        if isinstance(value, tuple):
-            value = list(value)

-        if isinstance(to_replace, list) and isinstance(value, list):
-            if len(to_replace) != len(value):
-                raise ValueError("to_replace and value lists should be of the same length")
-            rep_dict = dict(zip(to_replace, value))
-        elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)):
-            rep_dict = dict([(tr, value) for tr in to_replace])
-        elif isinstance(to_replace, dict):
+        if isinstance(to_replace, dict):
             rep_dict = to_replace
+            if value is not None:
+                warnings.warn("to_replace is a dict, but value is not None. "

Review comment: Does this need to be split?
Reply: Maybe not.

+                              "value will be ignored.")
+        else:
+            rep_dict = dict(zip(to_replace, value))

-        if subset is None:
-            return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
-        elif isinstance(subset, basestring):
+        if isinstance(subset, basestring):
             subset = [subset]

-        if not isinstance(subset, (list, tuple)):
-            raise ValueError("subset should be a list or tuple of column names")
+        # Check if we won't pass mixed type generics

Review comment: This reads a bit awkwardly. How about "Verify we were not passed in mixed type generics."?

+        if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values())
+                   for all_of_type in [all_of_bool, all_of_str, all_of_numeric]):
+            raise ValueError("Mixed type replacements are not supported")

-        return DataFrame(
-            self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
+        if subset is None:
+            return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
+        else:
+            return DataFrame(
+                self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)

     @since(2.0)
     def approxQuantile(self, col, probabilities, relativeError):
@@ -1591,6 +1591,67 @@ def test_replace(self):
         self.assertEqual(row.age, 10)
         self.assertEqual(row.height, None)

+        # replace with lists
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first()
+        self.assertTupleEqual(row, (u'Ann', 10, 80.1))
+
+        # replace with dict
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
+        self.assertTupleEqual(row, (u'Alice', 11, 80.1))

Review comment: This is the only test of "new" functionality (excluding error cases), correct?
Reply: These tests are mostly a side effect of discussions related to #16792. Right now test coverage is low and we depend on a certain behavior of Py4j and the Scala counterpart. Also, I wanted to be sure that all the expected types are still accepted after the changes I've made. So maybe not necessary, but I will argue it is a good idea to have these.
Reply: I think (and I could be wrong) that @nchammas was suggesting it might make sense to have some more tests with dict, not that the other additional new tests are bad.
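In reply: if more dict coverage is wanted, one possible extra case (a sketch only, not part of this commit) would exercise the dict form together with subset, reusing the schema from the tests above:

    # hypothetical extra test: dict replacement restricted to a subset of columns
    row = self.spark.createDataFrame(
        [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=['age']).first()
    self.assertTupleEqual(row, (u'Alice', 11, 80.1))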
+
+        # replace with tuples
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first()
+        self.assertTupleEqual(row, (u'Bob', 10, 80.1))
+
+        # replace multiple columns
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.0))
+
+        # test for mixed numerics
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.5))
+
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.5))
+
+        # replace with boolean
+        row = (self
+               .spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
+               .selectExpr("name = 'Bob'", 'age <= 15')
+               .replace(False, True).first())
+        self.assertTupleEqual(row, (True, True))
+
+        # should fail if subset is not list, tuple or None
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first()
+
+        # should fail if to_replace and value have different length
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first()
+
+        # should fail if when received unexpected type
+        with self.assertRaises(ValueError):
+            from datetime import datetime
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first()
+
+        # should fail if provided mixed type replacements
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first()
+
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()
+
     def test_capture_analysis_exception(self):
         self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
         self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
Review comment: Maybe give this a doc-string to clarify what all_of does; even though it's not user facing, better to have a docstring than not.
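In reply: a possible docstring for the helper, sketched from its behavior in this patch (the wording is illustrative, not the author's):

    def all_of(types):
        """Return a predicate that checks whether every element of a
        sequence is an instance of the given type (or tuple of types).

        >>> all_of(bool)([True, False])
        True
        >>> all_of(basestring)(["a", 1])
        False
        """
        def all_of_(xs):
            return all(isinstance(x, types) for x in xs)
        return all_of_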