From 8a32408991a1aeeaf5243e64ffd4ace49a7aaaa4 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Fri, 18 Nov 2022 13:11:31 +0800 Subject: [PATCH] Optimize the implementation of uri & Fix async log bug (#1364) * Optimize the implementation of uri * remove redundant func * Set the right order of _set_client_uri * Update qlib/workflow/expm.py * Simplify client & add test.Add docs; Fix async bug * Fix comments & pylint * Improve README --- qlib/workflow/__init__.py | 9 +-- qlib/workflow/exp.py | 1 - qlib/workflow/expm.py | 105 ++++++++++++-------------- qlib/workflow/recorder.py | 3 +- setup.py | 2 +- tests/dependency_tests/README.md | 3 + tests/dependency_tests/test_mlflow.py | 34 +++++++++ 7 files changed, 94 insertions(+), 63 deletions(-) create mode 100644 tests/dependency_tests/README.md create mode 100644 tests/dependency_tests/test_mlflow.py diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index aecf0ac992..d14782c60d 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -8,7 +8,6 @@ from .recorder import Recorder from ..utils import Wrapper from ..utils.exceptions import RecorderInitializationError -from qlib.config import C class QlibRecorder: @@ -347,14 +346,14 @@ def get_uri(self): def set_uri(self, uri: Optional[Text]): """ - Method to reset the current uri of current experiment manager. + Method to reset the **default** uri of current experiment manager. NOTE: - When the uri is refer to a file path, please using the absolute path instead of strings like "~/mlruns/" The backend don't support strings like this. """ - self.exp_manager.set_uri(uri) + self.exp_manager.default_uri = uri @contextmanager def uri_context(self, uri: Text): @@ -370,11 +369,11 @@ def uri_context(self, uri: Text): the temporal uri """ prev_uri = self.exp_manager.default_uri - C.exp_manager["kwargs"]["uri"] = uri + self.exp_manager.default_uri = uri try: yield finally: - C.exp_manager["kwargs"]["uri"] = prev_uri + self.exp_manager.default_uri = prev_uri def get_recorder( self, diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index d3dd0a535d..95e5db4738 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -249,7 +249,6 @@ class MLflowExperiment(Experiment): def __init__(self, id, name, uri): super(MLflowExperiment, self).__init__(id, name) self._uri = uri - self._default_name = None self._default_rec_name = "mlflow_recorder" self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 419848517d..3aaa574dd2 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -15,23 +15,32 @@ from ..log import get_module_logger from ..utils.exceptions import ExpAlreadyExistError + logger = get_module_logger("workflow") class ExpManager: """ - This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow. - (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) + This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow. + (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) + + The `ExpManager` is expected to be a singleton (btw, we can have multiple `Experiment`s with different uri. user can get different experiments from different uri, and then compare records of them). Global Config (i.e. `C`) is also a singleton. + So we try to align them together. They share the same variable, which is called **default uri**. Please refer to `ExpManager.default_uri` for details of variable sharing. + + When the user starts an experiment, the user may want to set the uri to a specific uri (it will override **default uri** during this period), and then unset the **specific uri** and fallback to the **default uri**. `ExpManager._active_exp_uri` is that **specific uri**. """ + active_experiment: Optional[Experiment] + def __init__(self, uri: Text, default_exp_name: Optional[Text]): - self._current_uri = uri + self.default_uri = uri + self._active_exp_uri = None # No active experiments. So it is set to None self._default_exp_name = default_exp_name self.active_experiment = None # only one experiment can be active each time - logger.info(f"experiment manager uri is at {self._current_uri}") + logger.info(f"experiment manager uri is at {self.uri}") def __repr__(self): - return "{name}(current_uri={curi})".format(name=self.__class__.__name__, curi=self._current_uri) + return "{name}(uri={uri})".format(name=self.__class__.__name__, uri=self.uri) def start_exp( self, @@ -43,11 +52,13 @@ def start_exp( uri: Optional[Text] = None, resume: bool = False, **kwargs, - ): + ) -> Experiment: """ Start an experiment. This method includes first get_or_create an experiment, and then set it to be active. + Maintaining `_active_exp_uri` is included in start_exp, remaining implementation should be included in _end_exp in subclass + Parameters ---------- experiment_id : str @@ -67,12 +78,28 @@ def start_exp( ------- An active experiment. """ + self._active_exp_uri = uri + # The subclass may set the underlying uri back. + # So setting `_active_exp_uri` come before `_start_exp` + return self._start_exp( + experiment_id=experiment_id, + experiment_name=experiment_name, + recorder_id=recorder_id, + recorder_name=recorder_name, + resume=resume, + **kwargs, + ) + + def _start_exp(self, *args, **kwargs) -> Experiment: + """Please refer to the doc of `start_exp`""" raise NotImplementedError(f"Please implement the `start_exp` method.") def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs): """ End an active experiment. + Maintaining `_active_exp_uri` is included in end_exp, remaining implementation should be included in _end_exp in subclass + Parameters ---------- experiment_name : str @@ -80,6 +107,12 @@ def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs): recorder_status : str the status of the active recorder of the experiment. """ + self._active_exp_uri = None + # The subclass may set the underlying uri back. + # So setting `_active_exp_uri` come before `_end_exp` + self._end_exp(recorder_status=recorder_status, **kwargs) + + def _end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs): raise NotImplementedError(f"Please implement the `end_exp` method.") def create_exp(self, experiment_name: Optional[Text] = None): @@ -254,6 +287,10 @@ def default_uri(self): raise ValueError("The default URI is not set in qlib.config.C") return C.exp_manager["kwargs"]["uri"] + @default_uri.setter + def default_uri(self, value): + C.exp_manager.setdefault("kwargs", {})["uri"] = value + @property def uri(self): """ @@ -263,33 +300,7 @@ def uri(self): ------- The tracking URI string. """ - return self._current_uri or self.default_uri - - def set_uri(self, uri: Optional[Text] = None): - """ - Set the current tracking URI and the corresponding variables. - - Parameters - ---------- - uri : str - - """ - if uri is None: - if self._current_uri is None: - logger.debug("No tracking URI is provided. Use the default tracking URI.") - self._current_uri = self.default_uri - else: - # Temporarily re-set the current uri as the uri argument. - self._current_uri = uri - # Customized features for subclasses. - self._set_uri() - - def _set_uri(self): - """ - Customized features for subclasses' set_uri function. - This method is designed for the underlying experiment backend storage. - """ - raise NotImplementedError(f"Please implement the `_set_uri` method.") + return self._active_exp_uri or self.default_uri def list_experiments(self): """ @@ -307,33 +318,21 @@ class MLflowExpManager(ExpManager): Use mlflow to implement ExpManager. """ - def __init__(self, uri: Text, default_exp_name: Optional[Text]): - super(MLflowExpManager, self).__init__(uri, default_exp_name) - self._client = None - - def _set_uri(self): - self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) - logger.info("{:}".format(self._client)) - @property def client(self): - # Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib - if self._client is None: - self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) - return self._client + # Please refer to `tests/dependency_tests/test_mlflow.py::MLflowTest::test_creating_client` + # The test ensure the speed of create a new client + return mlflow.tracking.MlflowClient(tracking_uri=self.uri) - def start_exp( + def _start_exp( self, *, experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, recorder_id: Optional[Text] = None, recorder_name: Optional[Text] = None, - uri: Optional[Text] = None, resume: bool = False, ): - # Set the tracking uri - self.set_uri(uri) # Create experiment if experiment_name is None: experiment_name = self._default_exp_name @@ -345,12 +344,10 @@ def start_exp( return self.active_experiment - def end_exp(self, recorder_status: Text = Recorder.STATUS_S): + def _end_exp(self, recorder_status: Text = Recorder.STATUS_S): if self.active_experiment is not None: self.active_experiment.end(recorder_status) self.active_experiment = None - # When an experiment end, we will release the current uri. - self._current_uri = None def create_exp(self, experiment_name: Optional[Text] = None): assert experiment_name is not None @@ -362,9 +359,7 @@ def create_exp(self, experiment_name: Optional[Text] = None): raise ExpAlreadyExistError() from e raise e - experiment = MLflowExperiment(experiment_id, experiment_name, self.uri) - experiment._default_name = self._default_exp_name - return experiment + return MLflowExperiment(experiment_id, experiment_name, self.uri) def _get_exp(self, experiment_id=None, experiment_name=None): """ diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 1b46466013..9d82bf0a47 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -378,14 +378,15 @@ def end_run(self, status: str = Recorder.STATUS_S): Recorder.STATUS_FI, Recorder.STATUS_FA, ], f"The status type {status} is not supported." - mlflow.end_run(status) self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") if self.status != Recorder.STATUS_S: self.status = status if self.async_log is not None: + # Waiting Queue should go before mlflow.end_run. Otherwise mlflow will raise error with TimeInspector.logt("waiting `async_log`"): self.async_log.wait() self.async_log = None + mlflow.end_run(status) def save_objects(self, local_path=None, artifact_path=None, **kwargs): assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly." diff --git a/setup.py b/setup.py index a796ecf4b7..faf058d631 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ def get_version(rel_path: str) -> str: "matplotlib>=3.3", "tables>=3.6.1", "pyyaml>=5.3.1", - "mlflow>=1.12.1", + "mlflow>=1.12.1, <=1.30.0", "tqdm", "loguru", "lightgbm>=3.3.0", diff --git a/tests/dependency_tests/README.md b/tests/dependency_tests/README.md new file mode 100644 index 0000000000..544fac130a --- /dev/null +++ b/tests/dependency_tests/README.md @@ -0,0 +1,3 @@ +Some implementations of Qlib depend on some assumptions of its dependencies. + +So some tests are requried to ensure that these assumptions are valid. diff --git a/tests/dependency_tests/test_mlflow.py b/tests/dependency_tests/test_mlflow.py new file mode 100644 index 0000000000..94f164a357 --- /dev/null +++ b/tests/dependency_tests/test_mlflow.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest +import mlflow +import time +from pathlib import Path +import shutil + + +class MLflowTest(unittest.TestCase): + TMP_PATH = Path("./.mlruns_tmp/") + + def tearDown(self) -> None: + if self.TMP_PATH.exists(): + shutil.rmtree(self.TMP_PATH) + + def test_creating_client(self): + """ + Please refer to qlib/workflow/expm.py:MLflowExpManager._client + we don't cache _client (this is helpful to reduce maintainance work when MLflowExpManager's uri is chagned) + + This implementation is based on the assumption creating a client is fast + """ + start = time.time() + for i in range(10): + _ = mlflow.tracking.MlflowClient(tracking_uri=str(self.TMP_PATH)) + end = time.time() + elasped = end - start + self.assertLess(elasped, 1e-2) # it can be done in less than 10ms + print(elasped) + + +if __name__ == "__main__": + unittest.main()