Skip to content

Commit

Permalink
[data/preprocessors] do not fail transform_batch on missing column (ray-project#48137)
Browse files Browse the repository at this point in the history

Signed-off-by: JP-sDEV <jon.pablo80@gmail.com>
  • Loading branch information
martinbomio authored and JP-sDEV committed Nov 14, 2024
1 parent e6f71e9 commit e5b836b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
12 changes: 10 additions & 2 deletions python/ray/data/preprocessors/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

@PublicAPI(stability="alpha")
class SimpleImputer(Preprocessor):
"""Replace missing values with imputed values.
"""Replace missing values with imputed values. If the column is missing from a
batch, it will be filled with the imputed value.
Examples:
>>> import pandas as pd
Expand Down Expand Up @@ -131,7 +132,14 @@ def _transform_pandas(self, df: pd.DataFrame):
if is_categorical_dtype(df.dtypes[column]):
df[column] = df[column].cat.add_categories(value)

df = df.fillna(new_values)
for column_name in new_values:
if column_name not in df.columns:
# Create the column with the fill_value if it doesn't exist
df[column_name] = new_values[column_name]
else:
# Fill NaN (empty) values in the existing column with the fill_value
df[column_name].fillna(new_values[column_name], inplace=True)

return df

def __repr__(self):
Expand Down
12 changes: 12 additions & 0 deletions python/ray/data/tests/preprocessors/test_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ def test_simple_imputer():

assert pred_out_df.equals(pred_expected_df)

# with missing column
pred_in_df = pd.DataFrame.from_dict({"A": pred_col_a, "B": pred_col_b})
pred_out_df = imputer.transform_batch(pred_in_df)
pred_expected_df = pd.DataFrame.from_dict(
{
"A": pred_processed_col_a,
"B": pred_processed_col_b,
"C": pred_processed_col_c,
}
)
assert pred_out_df.equals(pred_expected_df)

# Test "most_frequent" strategy.
most_frequent_col_a = [1, 2, 2, None, None, None]
most_frequent_col_b = [None, "c", "c", "b", "b", "a"]
Expand Down

0 comments on commit e5b836b

Please sign in to comment.