Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
rwedge committed Feb 25, 2025
1 parent eed8b6a commit be9f309
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions tests/benchmark/supported_dtypes_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,12 @@ def test_transformer(dtype, data, sdtype, transformer, transformer_kwargs):
if sdtype != transformer.INPUT_SDTYPE:
pytest.skip("Sdtype does not match transformer's input type, skipping.")

previous_fit_result, _ = get_previous_dtype_result(dtype, sdtype, f"RDT_{transformer_name}_FIT")
previous_fit_result, _ = get_previous_dtype_result(dtype, sdtype, f'RDT_{transformer_name}_FIT')
previous_transform_result, _ = get_previous_dtype_result(
dtype, sdtype, f"RDT_{transformer_name}_TRANSFORM"
dtype, sdtype, f'RDT_{transformer_name}_TRANSFORM'
)
previous_reverse_result, _ = get_previous_dtype_result(
dtype, sdtype, f"RDT_{transformer_name}_REVERSE"
dtype, sdtype, f'RDT_{transformer_name}_REVERSE'
)
fit_result = False
transform_result = False
Expand All @@ -303,9 +303,9 @@ def test_transformer(dtype, data, sdtype, transformer, transformer_kwargs):
save_results_to_json({
'dtype': dtype,
'sdtype': sdtype,
f"RDT_{transformer_name}_FIT": fit_result,
f"RDT_{transformer_name}_TRANSFORM": transform_result,
f"RDT_{transformer_name}_REVERSE": reverse_result,
f'RDT_{transformer_name}_FIT': fit_result,
f'RDT_{transformer_name}_TRANSFORM': transform_result,
f'RDT_{transformer_name}_REVERSE': reverse_result,
})

fit_assertion_message = f"{dtype} is no longer supported by 'RDT_{transformer_name}_FIT'."
Expand Down
6 changes: 3 additions & 3 deletions tests/benchmark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _load_temp_results(filename):
df.iloc[:, 2:] = df.groupby(['dtype', 'sdtype']).transform(lambda x: x.ffill().bfill())
for column in df.columns:
if column not in ('sdtype', 'dtype'):
df[column] = df[column].astype("float")
df[column] = df[column].astype('float')

return df.drop_duplicates().reset_index(drop=True)

Expand Down Expand Up @@ -193,14 +193,14 @@ def compare_and_store_results_in_gdrive():
for name, current_results_df in results.items():
for startswith in measurement_prefixes:
supported_df = calculate_support_percentage(current_results_df, startswith)
column_name = f"{name} {startswith}"
column_name = f'{name} {startswith}'
if summary.empty:
summary = supported_df.rename(columns={'percentage_supported': column_name})
else:
summary[column_name] = supported_df['percentage_supported']

for startswith in measurement_prefixes:
measurement_columns = [f"{name} {startswith}" for name in results]
measurement_columns = [f'{name} {startswith}' for name in results]
summary[startswith] = summary[measurement_columns].mean(axis=1).round(2)

for col in summary.columns:
Expand Down

0 comments on commit be9f309

Please sign in to comment.