Skip to content

Commit

Permalink
Merge pull request #345 from D-X-Y/main
Browse files Browse the repository at this point in the history
Fix errors when SignalRecord is not called before SigAna/PortAna
  • Loading branch information
you-n-g authored Mar 17, 2021
2 parents d47e35d + 872ddc6 commit aa552fd
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 11 deletions.
Empty file.
45 changes: 45 additions & 0 deletions qlib/contrib/workflow/record_temp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import re
import pandas as pd
from sklearn.metrics import mean_squared_error
from pprint import pprint
import numpy as np

from ...workflow.record_temp import SignalRecord
from ...log import get_module_logger

logger = get_module_logger("workflow", "INFO")


class SignalMseRecord(SignalRecord):
"""
This is the Signal MSE Record class that computes the mean squared error (MSE).
This class inherits the ``SignalMseRecord`` class.
"""

artifact_path = "sig_analysis"

def __init__(self, recorder, **kwargs):
super().__init__(recorder=recorder, **kwargs)

def generate(self, **kwargs):
try:
self.check(parent=True)
except FileExistsError:
super().generate()

pred = self.load("pred.pkl")
label = self.load("label.pkl")
masks = ~np.isnan(label.values)
mse = mean_squared_error(pred.values[masks], label[masks])
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
pprint(metrics)

def list(self):
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
return paths
21 changes: 13 additions & 8 deletions qlib/workflow/record_temp.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class SignalRecord(RecordTemp):
This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class.
"""

def __init__(self, model=None, dataset=None, recorder=None, **kwargs):
def __init__(self, model=None, dataset=None, recorder=None):
super().__init__(recorder=recorder)
self.model = model
self.dataset = dataset
Expand Down Expand Up @@ -164,13 +164,15 @@ class SigAnaRecord(SignalRecord):
artifact_path = "sig_analysis"

def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
super().__init__(recorder=recorder, **kwargs)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler
super().__init__(recorder=recorder, **kwargs)
# The name must be unique. Otherwise it will be overridden

def generate(self):
self.check(parent=True)
def generate(self, **kwargs):
try:
self.check(parent=True)
except FileExistsError:
super().generate()

pred = self.load("pred.pkl")
label = self.load("label.pkl")
Expand Down Expand Up @@ -228,18 +230,21 @@ def __init__(self, recorder, config, **kwargs):
config["backtest"] : dict
define the backtest kwargs.
"""
super().__init__(recorder=recorder)
super().__init__(recorder=recorder, **kwargs)

self.strategy_config = config["strategy"]
self.backtest_config = config["backtest"]
self.strategy = init_instance_by_config(self.strategy_config, accept_types=BaseStrategy)

def generate(self, **kwargs):
# check previously stored prediction results
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
try:
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
except FileExistsError:
super().generate()

# custom strategy and get backtest
pred_score = super().load()
pred_score = super().load("pred.pkl")
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
report_normal = report_dict.get("report_df")
positions_normal = report_dict.get("positions")
Expand Down
45 changes: 42 additions & 3 deletions tests/test_all_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
backtest as normal_backtest,
risk_analysis,
)
from qlib.contrib.workflow.record_temp import SignalMseRecord
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
Expand Down Expand Up @@ -139,6 +140,38 @@ def train():
return pred_score, {"ic": ic, "ric": ric}, rid


def train_with_sigana():
"""train model followed by SigAnaRecord
Returns
-------
pred_score: pandas.DataFrame
predict scores
performance: dict
model performance
"""
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

# start exp
with R.start(experiment_name="workflow_with_sigana"):
R.log_params(**flatten_dict(task))
model.fit(dataset)

# predict and calculate ic and ric
recorder = R.get_recorder()
sar = SigAnaRecord(recorder, model=model, dataset=dataset)
sar.generate()
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
pred_score = sar.load("pred.pkl")

smr = SignalMseRecord(recorder)
smr.generate()
uri_path = R.get_uri()
return pred_score, {"ic": ic, "ric": ric}, uri_path


def fake_experiment():
"""A fake experiment workflow to test uri
Expand Down Expand Up @@ -195,20 +228,26 @@ class TestAllFlow(TestAutoData):
def tearDownClass(cls) -> None:
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))

def test_0_train(self):
def test_0_train_with_sigana(self):
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana()
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))

def test_1_train(self):
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")

def test_1_backtest(self):
def test_2_backtest(self):
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
self.assertGreaterEqual(
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
0.10,
"backtest failed",
)

def test_2_expmanager(self):
def test_3_expmanager(self):
pass_default, pass_current, uri_path = fake_experiment()
self.assertTrue(pass_default, msg="default uri is incorrect")
self.assertTrue(pass_current, msg="current uri is incorrect")
Expand Down

0 comments on commit aa552fd

Please sign in to comment.