Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: overwrite keyword in DataFrame.update #1478

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3143,7 +3143,7 @@ def combine_first(self, other):
combiner = lambda x, y: np.where(isnull(x), y, x)
return self.combine(other, combiner)

def update(self, other, join='left'):
def update(self, other, join='left', overwrite=True, filter_func=None):
"""
Modify DataFrame in place using non-NA values from passed
DataFrame. Aligns on indices
Expand All @@ -3152,6 +3152,11 @@ def update(self, other, join='left'):
----------
other : DataFrame
join : {'left', 'right', 'outer', 'inner'}, default 'left'
overwrite : boolean, default True
If True then overwrite values for common keys in the calling frame
filter_func : callable(1d-array) -> 1d-array<boolean>, default None
Can choose to replace values other than NA. Return True for values
that should be updated
"""
if join != 'left':
raise NotImplementedError
Expand All @@ -3160,7 +3165,14 @@ def update(self, other, join='left'):
for col in self.columns:
this = self[col].values
that = other[col].values
self[col] = np.where(isnull(that), this, that)
if filter_func is not None:
mask = -filter_func(this) | isnull(that)
else:
if overwrite:
mask = isnull(that)
else:
mask = notnull(this)
self[col] = np.where(mask, this, that)

#----------------------------------------------------------------------
# Misc methods
Expand Down
35 changes: 35 additions & 0 deletions pandas/tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5235,6 +5235,41 @@ def test_update(self):
[1.5, nan, 7.]])
assert_frame_equal(df, expected)

def test_update_nooverwrite(self):
df = DataFrame([[1.5, nan, 3.],
[1.5, nan, 3.],
[1.5, nan, 3],
[1.5, nan, 3]])

other = DataFrame([[3.6, 2., np.nan],
[np.nan, np.nan, 7]], index=[1, 3])

df.update(other, overwrite=False)

expected = DataFrame([[1.5, nan, 3],
[1.5, 2, 3],
[1.5, nan, 3],
[1.5, nan, 3.]])
assert_frame_equal(df, expected)

def test_update_filtered(self):
df = DataFrame([[1.5, nan, 3.],
[1.5, nan, 3.],
[1.5, nan, 3],
[1.5, nan, 3]])

other = DataFrame([[3.6, 2., np.nan],
[np.nan, np.nan, 7]], index=[1, 3])

df.update(other, filter_func=lambda x: x > 2)

expected = DataFrame([[1.5, nan, 3],
[1.5, nan, 3],
[1.5, nan, 3],
[1.5, nan, 7.]])
assert_frame_equal(df, expected)


def test_combineAdd(self):
# trivial
comb = self.frame.combineAdd(self.frame)
Expand Down