Remove hard-coded keys from production DFP (#607)
Remove the hard-coded keys `username` and `timestamp` from the production DFP stages. This allows the source dataframe column names to be configured with `config.ae.userid_column_name` and `config.ae.timestamp_column_name`.

Fixes #606
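For reference, a minimal sketch of how these fields might be set when wiring up a pipeline (assuming Morpheus' `Config`/`ConfigAutoEncoder` objects; the column values here are illustrative, not project defaults):

```python
# Minimal sketch, assuming Morpheus' Config / ConfigAutoEncoder API;
# the column names below are illustrative values for a custom data source.
from morpheus.config import Config, ConfigAutoEncoder

config = Config()
config.ae = ConfigAutoEncoder()
config.ae.userid_column_name = "user_name"    # column holding the user identifier
config.ae.timestamp_column_name = "event_ts"  # column holding the event time
```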

Authors:
  - Eli Fajardo (https://github.com/efajardo-nv)
  - David Gardner (https://github.com/dagardner-nv)
  - Michael Demoret (https://github.com/mdemoret-nv)

Approvers:
  - David Gardner (https://github.com/dagardner-nv)
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #607
efajardo-nv authored Jan 24, 2023
1 parent cdbe4ce commit f5a3558
Showing 6 changed files with 23 additions and 12 deletions.
@@ -204,7 +204,7 @@ def _get_or_create_dataframe_from_s3_batch(
output_df: pd.DataFrame = pd.concat(dfs)

# Finally sort by timestamp and then reset the index
- output_df.sort_values(by=["timestamp"], inplace=True)
+ output_df.sort_values(by=[self._config.ae.timestamp_column_name], inplace=True)

output_df.reset_index(drop=True, inplace=True)

@@ -164,8 +164,8 @@ def on_data(self, message: MultiAEMessage):
"Epochs": model.lr_decay.state_dict().get("last_epoch", "unknown"),
"Learning rate": model.lr,
"Batch size": model.batch_size,
"Start Epoch": message.get_meta("timestamp").min(),
"End Epoch": message.get_meta("timestamp").max(),
"Start Epoch": message.get_meta(self._config.ae.timestamp_column_name).min(),
"End Epoch": message.get_meta(self._config.ae.timestamp_column_name).max(),
"Log Count": message.mess_count,
})

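In isolation, the "Start Epoch" and "End Epoch" values being logged are just the bounds of the configured timestamp column. A pandas stand-in with hypothetical data (`message.get_meta(col)` returns, roughly, that column of the message's underlying dataframe):

```python
# Pandas stand-in for message.get_meta(<timestamp column>).min()/.max():
# the configured timestamp column bounds the training window.
import pandas as pd

ts_col = "event_ts"  # would come from config.ae.timestamp_column_name
df = pd.DataFrame({ts_col: pd.to_datetime(["2023-01-01", "2023-01-03", "2023-01-02"])})

print("Start Epoch:", df[ts_col].min())  # 2023-01-01 00:00:00
print("End Epoch:", df[ts_col].max())    # 2023-01-03 00:00:00
```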
@@ -84,9 +84,11 @@ def extract_users(self, message: cudf.DataFrame):

if (self._include_individual):

-   split_dataframes.update(
-       {username: user_df
-        for username, user_df in message.groupby("username", sort=False)})
+   split_dataframes.update({
+       username: user_df
+       for username,
+       user_df in message.groupby(self._config.ae.userid_column_name, sort=False)
+   })

output_messages: typing.List[DFPMessageMeta] = []

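The comprehension above builds one sub-dataframe per user, keyed by the configured user-id column. A self-contained pandas equivalent with hypothetical data (the stage itself operates on cudf, whose `groupby` behaves the same way here):

```python
# Split a dataframe into per-user dataframes keyed by the configured
# user-id column (pandas shown; message.groupby in the stage is cudf).
import pandas as pd

userid_col = "user_name"  # would come from config.ae.userid_column_name
df = pd.DataFrame({userid_col: ["alice", "bob", "alice"], "bytes": [10, 20, 30]})

split_dataframes = {username: user_df for username, user_df in df.groupby(userid_col, sort=False)}
print({name: len(user_df) for name, user_df in split_dataframes.items()})  # {'alice': 2, 'bob': 1}
```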
@@ -44,11 +44,11 @@ class CachedUserWindow:
def append_dataframe(self, incoming_df: pd.DataFrame) -> bool:

    # Filter the incoming df by epochs later than the current max_epoch
-   filtered_df = incoming_df[incoming_df["timestamp"] > self.max_epoch]
+   filtered_df = incoming_df[incoming_df[self.timestamp_column] > self.max_epoch]

    if (len(filtered_df) == 0):
        # We have nothing new to add. Double check that we fit within the window
-       before_history = incoming_df[incoming_df["timestamp"] < self.min_epoch]
+       before_history = incoming_df[incoming_df[self.timestamp_column] < self.min_epoch]

        return len(before_history) == 0

@@ -59,7 +59,7 @@ def append_dataframe(self, incoming_df: pd.DataFrame) -> bool:
# Set the filtered index
filtered_df.index = range(self.total_count, self.total_count + len(filtered_df))

- # Save the row hash to make it easier to find later. Do this before the batch so it doesnt participate
+ # Save the row hash to make it easier to find later. Do this before the batch so it doesn't participate
filtered_df["_row_hash"] = pd.util.hash_pandas_object(filtered_df, index=False)

# Use batch id to distinguish groups in the same dataframe
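The check above only admits rows newer than everything already cached. A compact, self-contained sketch of that admission test against a configurable timestamp column (hypothetical values):

```python
# Rolling-window admission test: keep only incoming rows whose timestamp
# is later than the newest row already cached (max_epoch).
import pandas as pd

timestamp_column = "event_ts"  # mirrors CachedUserWindow.timestamp_column
max_epoch = pd.Timestamp("2023-01-02")

incoming_df = pd.DataFrame({timestamp_column: pd.to_datetime(["2023-01-01", "2023-01-03"])})
filtered_df = incoming_df[incoming_df[timestamp_column] > max_epoch]
print(len(filtered_df))  # 1 -- only the 2023-01-03 row is new
```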
@@ -231,10 +231,16 @@ def run_pipeline(train_users,
             groupby_column=config.ae.userid_column_name),
CustomColumn(name="locincrement",
             dtype=int,
-            process_column_fn=partial(create_increment_col, column_name="location")),
+            process_column_fn=partial(create_increment_col,
+                                      column_name="location",
+                                      groupby_column=config.ae.userid_column_name,
+                                      timestamp_column=config.ae.timestamp_column_name)),
CustomColumn(name="appincrement",
             dtype=int,
-            process_column_fn=partial(create_increment_col, column_name="appDisplayName")),
+            process_column_fn=partial(create_increment_col,
+                                      column_name="appDisplayName",
+                                      groupby_column=config.ae.userid_column_name,
+                                      timestamp_column=config.ae.timestamp_column_name))
]

preprocess_schema = DataFrameInputSchema(column_info=preprocess_column_info, preserve_columns=["_batch_id"])
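`create_increment_col` itself is not shown in this diff. As a hypothetical sketch of the kind of feature these columns carry, here is a running count of distinct values a user has been seen with, grouped by the configured user-id column and ordered by the configured timestamp column (the real helper may compute its increments differently):

```python
# Hypothetical increment feature: per user (groupby_column), the running
# count of distinct values of column_name, ordered by timestamp_column.
# Illustrative only -- the real create_increment_col may differ.
import pandas as pd

def increment_col(df: pd.DataFrame, column_name: str, groupby_column: str, timestamp_column: str) -> pd.Series:
    ordered = df.sort_values(timestamp_column)
    # A value is "new" for a user the first time that (user, value) pair appears.
    first_seen = (~ordered.duplicated(subset=[groupby_column, column_name])).astype(int)
    return first_seen.groupby(ordered[groupby_column]).cumsum().reindex(df.index)

df = pd.DataFrame({
    "user_name": ["alice", "alice", "alice"],
    "location": ["NY", "NY", "LA"],
    "event_ts": pd.to_datetime(["2023-01-01", "2023-01-02", "2023-01-03"]),
})
print(increment_col(df, "location", "user_name", "event_ts").tolist())  # [1, 1, 2]
```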
@@ -232,7 +232,10 @@ def run_pipeline(train_users,
             groupby_column=config.ae.userid_column_name),
CustomColumn(name="locincrement",
             dtype=int,
-            process_column_fn=partial(create_increment_col, column_name="location")),
+            process_column_fn=partial(create_increment_col,
+                                      column_name="location",
+                                      groupby_column=config.ae.userid_column_name,
+                                      timestamp_column=config.ae.timestamp_column_name))
]

preprocess_schema = DataFrameInputSchema(column_info=preprocess_column_info, preserve_columns=["_batch_id"])
