Skip to content

Commit

Permalink
fix workflow bug (#882)
Browse files Browse the repository at this point in the history
* fix workflow bug

* Fix output of pytorch NN

* Fix parameter bug
  • Loading branch information
you-n-g authored Jan 22, 2022
1 parent d533219 commit 01afd06
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion qlib/contrib/model/pytorch_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def fit(
R.log_metrics(val_loss=loss_val.val, step=step)
if verbose:
self.logger.info(
"[Epoch {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
"[Step {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
)
evals_result["train"].append(train_loss)
evals_result["valid"].append(loss_val.val)
Expand Down
4 changes: 2 additions & 2 deletions qlib/model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def _log_task_info(task_config: dict):
def _exe_task(task_config: dict):
rec = R.get_recorder()
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
reweighter: Reweighter = task_config.get("reweighter", None)
# model training
auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)
Expand Down
16 changes: 11 additions & 5 deletions qlib/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def list_recorders(self, experiment_id=None, experiment_name=None):
"""
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()

def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
def get_exp(
self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False
) -> Experiment:
"""
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
Expand Down Expand Up @@ -291,6 +293,10 @@ def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = Tr
create : boolean
an argument determines whether the method will automatically create a new experiment
according to user's specification if the experiment hasn't been created before.
start : bool
when start is True,
if the experiment has not started(not activated), it will start
It is designed for R.log_params to auto start experiments
Returns
-------
Expand All @@ -300,7 +306,7 @@ def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = Tr
experiment_id=experiment_id,
experiment_name=experiment_name,
create=create,
start=False,
start=start,
)

def delete_exp(self, experiment_id=None, experiment_name=None):
Expand Down Expand Up @@ -542,7 +548,7 @@ def log_params(self, **kwargs):
keyword argument:
name1=value1, name2=value2, ...
"""
self.get_exp().get_recorder(start=True).log_params(**kwargs)
self.get_exp(start=True).get_recorder(start=True).log_params(**kwargs)

def log_metrics(self, step=None, **kwargs):
"""
Expand All @@ -567,7 +573,7 @@ def log_metrics(self, step=None, **kwargs):
keyword argument:
name1=value1, name2=value2, ...
"""
self.get_exp().get_recorder(start=True).log_metrics(step, **kwargs)
self.get_exp(start=True).get_recorder(start=True).log_metrics(step, **kwargs)

def set_tags(self, **kwargs):
"""
Expand All @@ -592,7 +598,7 @@ def set_tags(self, **kwargs):
keyword argument:
name1=value1, name2=value2, ...
"""
self.get_exp().get_recorder(start=True).set_tags(**kwargs)
self.get_exp(start=True).get_recorder(start=True).set_tags(**kwargs)


class RecorderWrapper(Wrapper):
Expand Down
2 changes: 1 addition & 1 deletion qlib/workflow/expm.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = Tr
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
False,
)
if is_new and start:
if self.active_experiment is None and start:
self.active_experiment = exp
# start the recorder
self.active_experiment.start()
Expand Down

0 comments on commit 01afd06

Please sign in to comment.