Skip to content

Commit

Permalink
test: Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
setu4993 committed Jul 7, 2023
1 parent 99552ba commit 6165778
Show file tree
Hide file tree
Showing 24 changed files with 98 additions and 94 deletions.
30 changes: 23 additions & 7 deletions dataquality/core/finish.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,18 @@ def finish(
job_name=JobName.inference,
non_inference_logged=data_logger.non_inference_logged(),
)
res = api_client.make_request(RequestType.POST, url=f"{config.api_url}/{Route.jobs}", body=body)
print(f"Job {res['job_name']} successfully submitted. Results will be available " f"soon at {res['link']}")
res = api_client.make_request(
RequestType.POST, url=f"{config.api_url}/{Route.jobs}", body=body
)
print(
f"Job {res['job_name']} successfully submitted. Results will be available "
f"soon at {res['link']}"
)
if data_logger.logger_config.conditions:
print("Waiting for run to process before building run report... " "Don't close laptop or terminate shell.")
print(
"Waiting for run to process before building run report... "
"Don't close laptop or terminate shell."
)
wait_for_run()
build_run_report(
data_logger.logger_config.conditions,
Expand All @@ -96,7 +104,9 @@ def finish(


@check_noop
def wait_for_run(project_name: Optional[str] = None, run_name: Optional[str] = None) -> None:
def wait_for_run(
project_name: Optional[str] = None, run_name: Optional[str] = None
) -> None:
"""
Waits until a specific project run transitions from started to finished.
Defaults to the current run if project_name and run_name are empty.
Expand All @@ -110,7 +120,9 @@ def wait_for_run(project_name: Optional[str] = None, run_name: Optional[str] = N


@check_noop
def get_run_status(project_name: Optional[str] = None, run_name: Optional[str] = None) -> Dict[str, Any]:
def get_run_status(
project_name: Optional[str] = None, run_name: Optional[str] = None
) -> Dict[str, Any]:
"""
Returns the latest job of a specified project run.
Defaults to the current run if project_name and run_name are empty.
Expand All @@ -126,7 +138,9 @@ def get_run_status(project_name: Optional[str] = None, run_name: Optional[str] =


@check_noop
def _reset_run(project_id: UUID4, run_id: UUID4, task_type: Optional[TaskType] = None) -> None:
def _reset_run(
project_id: UUID4, run_id: UUID4, task_type: Optional[TaskType] = None
) -> None:
"""Clear the data in minio before uploading new data
If this is a run that already existed, we want to fully overwrite the old data.
Expand All @@ -135,7 +149,9 @@ def _reset_run(project_id: UUID4, run_id: UUID4, task_type: Optional[TaskType] =
"""
old_run_id = run_id
api_client.reset_run(project_id, old_run_id, task_type)
project_dir = f"{dataquality.get_data_logger().LOG_FILE_DIR}/{config.current_project_id}"
project_dir = (
f"{dataquality.get_data_logger().LOG_FILE_DIR}/{config.current_project_id}"
)
# All of the logged user data is to the old run ID, so rename it to the new ID
os.rename(f"{project_dir}/{old_run_id}", f"{project_dir}/{config.current_run_id}")
# Move std logs as well
Expand Down
25 changes: 19 additions & 6 deletions dataquality/core/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class InitManager:
wait=wait_exponential_jitter(initial=0.1, max=2),
stop=stop_after_attempt(5),
)
def get_or_create_project(self, project_name: str, is_public: bool) -> Tuple[Dict, bool]:
def get_or_create_project(
self, project_name: str, is_public: bool
) -> Tuple[Dict, bool]:
"""Gets a project by name, or creates a new one if it doesn't exist.
Returns:
Expand All @@ -58,7 +60,9 @@ def get_or_create_project(self, project_name: str, is_public: bool) -> Tuple[Dic
print(f"✨ Initializing {created_str} {visibility} project '{project_name}'")
return project, created

def get_or_create_run(self, project_name: str, run_name: str, task_type: TaskType) -> Tuple[Dict, bool]:
def get_or_create_run(
self, project_name: str, run_name: str, task_type: TaskType
) -> Tuple[Dict, bool]:
"""Gets a run by name, or creates a new one if it doesn't exist.
Returns:
Expand All @@ -75,7 +79,9 @@ def get_or_create_run(self, project_name: str, run_name: str, task_type: TaskTyp
print(f"🏃‍♂️ {verb} {created_str} run '{run_name}'")
return run, created

def create_log_file_dir(self, project_id: UUID4, run_id: UUID4, overwrite_local: bool) -> None:
def create_log_file_dir(
self, project_id: UUID4, run_id: UUID4, overwrite_local: bool
) -> None:
write_output_dir = f"{BaseGalileoLogger.LOG_FILE_DIR}/{project_id}/{run_id}"
stdout_dir = f"{DQ_LOG_FILE_HOME}/{run_id}"
for out_dir in [write_output_dir, stdout_dir]:
Expand Down Expand Up @@ -222,13 +228,20 @@ def init(
"images",
GALILEO_DEFAULT_IMG_BUCKET_NAME,
)
config.minio_fqdn = _dq_healthcheck_response.get("minio_fqdn", os.getenv("MINIO_FQDN", None))
if config.minio_fqdn is not None and config.minio_fqdn.endswith(EXOSCALE_FQDN_SUFFIX):
config.minio_fqdn = _dq_healthcheck_response.get(
"minio_fqdn", os.getenv("MINIO_FQDN", None)
)
if config.minio_fqdn is not None and config.minio_fqdn.endswith(
EXOSCALE_FQDN_SUFFIX
):
config.is_exoscale_cluster = True

proj_created_str = "new" if proj_created else "existing"
run_created_str = "new" if run_created else "existing"
print(f"🛰 Connected to {proj_created_str} project '{project_name}', " f"and {run_created_str} run '{run_name}'.")
print(
f"🛰 Connected to {proj_created_str} project '{project_name}', "
f"and {run_created_str} run '{run_name}'."
)

config.update_file_config()
if config.current_project_id and config.current_run_id:
Expand Down
3 changes: 2 additions & 1 deletion dataquality/utils/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def version_check() -> None:
server_semver = _get_api_version()
if _major_version(client_semver) != _major_version(server_semver):
get_dq_logger().warning(
"Major version mismatched between client, " f"{client_semver}, and server {server_semver}."
"Major version mismatched between client, "
f"{client_semver}, and server {server_semver}."
)


Expand Down
8 changes: 0 additions & 8 deletions tests/core/test_finish.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def test_get_run_status(mock_client: MagicMock) -> None:
assert status.get("status") == "in_progress"


@mock.patch.object(dataquality.core.finish, "_version_check")
@mock.patch.object(dataquality.core.finish, "_reset_run")
@mock.patch.object(dataquality.core.finish, "upload_dq_log_file")
@mock.patch.object(dataquality.clients.api.ApiClient, "make_request")
Expand All @@ -60,7 +59,6 @@ def test_finish_waits_default(
mock_make_request: MagicMock,
mock_upload_log_file: MagicMock,
mock_reset_run: MagicMock,
mock_version_check: MagicMock,
set_test_config,
) -> None:
set_test_config(task_type=TaskType.text_classification)
Expand All @@ -71,7 +69,6 @@ def test_finish_waits_default(
mock_wait_for_run.assert_called_once()


@mock.patch.object(dataquality.core.finish, "_version_check")
@mock.patch.object(dataquality.core.finish, "_reset_run")
@mock.patch.object(dataquality.core.finish, "upload_dq_log_file")
@mock.patch.object(dataquality.clients.api.ApiClient, "make_request")
Expand All @@ -86,7 +83,6 @@ def test_finish_no_waits_when_false(
mock_make_request: MagicMock,
mock_upload_log_file: MagicMock,
mock_reset_run: MagicMock,
mock_version_check: MagicMock,
set_test_config,
) -> None:
set_test_config(task_type=TaskType.text_classification)
Expand All @@ -97,7 +93,6 @@ def test_finish_no_waits_when_false(
mock_wait_for_run.assert_not_called()


@mock.patch.object(dataquality.core.finish, "_version_check")
@mock.patch.object(dataquality.core.finish, "_reset_run")
@mock.patch.object(dataquality.core.finish, "upload_dq_log_file")
@mock.patch.object(dataquality.clients.api.ApiClient, "make_request")
Expand All @@ -107,7 +102,6 @@ def test_finish_ignores_missing_inference_name_inframe(
mock_make_request: MagicMock,
mock_upload_log_file: MagicMock,
mock_reset_run: MagicMock,
mock_version_check: MagicMock,
set_test_config: Callable,
cleanup_after_use: Generator,
) -> None:
Expand All @@ -132,7 +126,6 @@ def test_finish_ignores_missing_inference_name_inframe(
dataquality.finish()


@mock.patch.object(dataquality.core.finish, "_version_check")
@mock.patch.object(dataquality.core.finish, "_reset_run")
@mock.patch.object(dataquality.core.finish, "upload_dq_log_file")
@mock.patch.object(dataquality.clients.api.ApiClient, "make_request")
Expand All @@ -149,7 +142,6 @@ def test_finish_with_conditions(
mock_make_request: MagicMock,
mock_upload_log_file: MagicMock,
mock_reset_run: MagicMock,
mock_version_check: MagicMock,
test_session_vars: TestSessionVariables,
set_test_config,
) -> None:
Expand Down
Loading

0 comments on commit 6165778

Please sign in to comment.