Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make '--platform' argument mandatory in qualification and profiling CLI to prevent incorrect behavior #1463

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions user_tools/src/spark_rapids_tools/cmdli/argprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,15 @@ def init_extra_arg_cases(self) -> list:
def define_invalid_arg_cases(self) -> None:
super().define_invalid_arg_cases()
self.define_rejected_missing_eventlogs()
self.rejected['Missing Platform argument'] = {
'valid': False,
'callable': partial(self.raise_validation_exception,
'Cannot run tool cmd without platform argument. Re-run the command '
'providing the platform argument.'),
'cases': [
[ArgValueCase.UNDEFINED, ArgValueCase.IGNORE, ArgValueCase.IGNORE]
]
}
self.rejected['Cluster By Name Without Platform Hints'] = {
'valid': False,
'callable': partial(self.raise_validation_exception,
Expand Down
65 changes: 23 additions & 42 deletions user_tools/tests/spark_rapids_tools_ut/test_tool_argprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,9 @@ def test_with_platform_with_eventlogs(self, get_ut_data_dir, tool_name, csp):
cost_savings_enabled=False,
expected_platform=csp)

# should pass: platform not provided; event logs are provided
tool_args = self.create_tool_args_should_pass(tool_name,
eventlogs=f'{get_ut_data_dir}/eventlogs')
# for qualification, cost savings should be disabled because cluster is not provided
self.validate_tool_args(tool_name=tool_name, tool_args=tool_args,
cost_savings_enabled=False,
expected_platform=CspEnv.ONPREM)
# should fail: platform must be provided
self.create_tool_args_should_fail(tool_name,
eventlogs=f'{get_ut_data_dir}/eventlogs')

@pytest.mark.parametrize('tool_name', ['qualification', 'profiling'])
@pytest.mark.parametrize('csp', all_csps)
Expand All @@ -150,17 +146,19 @@ def test_with_platform_with_eventlogs_with_jar_files(self, get_ut_data_dir, tool
tools_jar=f'{get_ut_data_dir}/tools_mock.jar')
assert tool_args['toolsJar'] == f'{get_ut_data_dir}/tools_mock.jar'

# should pass: tools_jar is correct
tool_args = self.create_tool_args_should_pass(tool_name, eventlogs=f'{get_ut_data_dir}/eventlogs',
tools_jar=f'{get_ut_data_dir}/tools_mock.jar')
assert tool_args['toolsJar'] == f'{get_ut_data_dir}/tools_mock.jar'
# should fail: platform must be provided
self.create_tool_args_should_fail(tool_name,
eventlogs=f'{get_ut_data_dir}/eventlogs',
tools_jar=f'{get_ut_data_dir}/tools_mock.jar')

# should fail: tools_jar does not exist
self.create_tool_args_should_fail(tool_name, eventlogs=f'{get_ut_data_dir}/eventlogs',
self.create_tool_args_should_fail(tool_name, platform=csp,
eventlogs=f'{get_ut_data_dir}/eventlogs',
tools_jar=f'{get_ut_data_dir}/tools_mock.txt')

# should fail: tools_jar is not .jar extension
self.create_tool_args_should_fail(tool_name, eventlogs=f'{get_ut_data_dir}/eventlogs',
self.create_tool_args_should_fail(tool_name, platform=csp,
eventlogs=f'{get_ut_data_dir}/eventlogs',
tools_jar=f'{get_ut_data_dir}/worker_info.yaml')

@pytest.mark.parametrize('tool_name', ['qualification', 'profiling'])
Expand Down Expand Up @@ -230,25 +228,15 @@ def test_with_platform_with_cluster_props(self, get_ut_data_dir, tool_name, csp,
self.validate_tool_args(tool_name=tool_name, tool_args=tool_args,
cost_savings_enabled=True,
expected_platform=csp)

# should pass: platform not provided; missing eventlogs should be accepted for all CSPs (except onPrem)
# because the eventlogs can be retrieved from the cluster properties
tool_args = self.create_tool_args_should_pass(tool_name,
cluster=cluster_prop_file)
# for qualification, cost savings should be enabled because cluster is provided
self.validate_tool_args(tool_name=tool_name, tool_args=tool_args,
cost_savings_enabled=True,
expected_platform=csp)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add back this should-pass test case for when the platform is provided?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting observation @cindyyuanjiang !
I am fine with whichever you folks agree on.

Copy link
Collaborator Author

@parthosa parthosa Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all of these tests already include a should_pass case with the platform defined in the previous block.

Previously, these updated tests were expected to pass when the platform was not defined. Now, they are expected to fail when the platform is not defined.

Hence, I don't think we need to add more tests for should_pass. Let me know your thoughts.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@parthosa thanks, I revisited the files and confirmed the other test cases.

else:
# should fail: onprem platform cannot retrieve eventlogs from cluster properties
self.create_tool_args_should_fail(tool_name,
platform=csp,
cluster=cluster_prop_file)

# should fail: platform not provided; defaults platform to onprem, cannot retrieve eventlogs from
# cluster properties
self.create_tool_args_should_fail(tool_name,
cluster=cluster_prop_file)
# should fail: platform must be provided for all CSPs as well as onprem
self.create_tool_args_should_fail(tool_name,
cluster=cluster_prop_file)

@pytest.mark.parametrize('tool_name', ['qualification', 'profiling'])
@pytest.mark.parametrize('csp,prop_path', all_cpu_cluster_props)
Expand All @@ -266,14 +254,10 @@ def test_with_platform_with_cluster_props_with_eventlogs(self, get_ut_data_dir,
cost_savings_enabled=CspEnv(csp) != CspEnv.ONPREM,
expected_platform=csp)

# should pass: platform not provided; cluster properties and eventlogs are provided
tool_args = self.create_tool_args_should_pass(tool_name,
cluster=cluster_prop_file,
eventlogs=f'{get_ut_data_dir}/eventlogs')
# for qualification, cost savings should be enabled because cluster is provided (except for onprem)
self.validate_tool_args(tool_name=tool_name, tool_args=tool_args,
cost_savings_enabled=CspEnv(csp) != CspEnv.ONPREM,
expected_platform=csp)
# should fail: platform must be provided
self.create_tool_args_should_fail(tool_name,
cluster=cluster_prop_file,
eventlogs=f'{get_ut_data_dir}/eventlogs')

@pytest.mark.parametrize('tool_name', ['profiling'])
@pytest.mark.parametrize('csp', all_csps)
Expand Down Expand Up @@ -308,18 +292,15 @@ def test_with_platform_with_autotuner_with_eventlogs(self, get_ut_data_dir, tool
cost_savings_enabled=False,
expected_platform=csp)

# should pass: platform not provided; autotuner properties and eventlogs are provided
tool_args = self.create_tool_args_should_pass(tool_name,
cluster=autotuner_prop_file,
eventlogs=f'{get_ut_data_dir}/eventlogs')
# cost savings should be disabled for profiling
self.validate_tool_args(tool_name=tool_name, tool_args=tool_args,
cost_savings_enabled=False,
expected_platform=CspEnv.ONPREM)
# should fail: platform must be provided
self.create_tool_args_should_fail(tool_name,
cluster=autotuner_prop_file,
eventlogs=f'{get_ut_data_dir}/eventlogs')

@pytest.mark.parametrize('prop_path', [autotuner_prop_path])
def test_profiler_with_driverlog(self, get_ut_data_dir, prop_path):
prof_args = AbsToolUserArgModel.create_tool_args('profiling',
platform=CspEnv.get_default(),
driverlog=f'{get_ut_data_dir}/{prop_path}')
assert not prof_args['requiresEventlogs']
assert prof_args['rapidOptions']['driverlog'] == f'{get_ut_data_dir}/{prop_path}'
Expand Down
Loading