forked from microsoft/qlib
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request microsoft#374 from bxdd/qlib_loaderhandler
Add DataLoader Based on DataHandler & Add Rolling Process Example & Restructure the Config & Setup_data
- Loading branch information
Showing
8 changed files
with
372 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Rolling Process Data | ||
|
||
This workflow is an example for `Rolling Process Data`. | ||
|
||
## Background | ||
|
||
When rolling-training models, data also needs to be generated for each rolling window. As the rolling window moves, the training data changes, and the processors' learnable state (such as standard deviation, mean, etc.) changes with it.
|
||
In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then uses Processors to generate the processed features that are related to the rolling window.
|
||
|
||
## Run the Code | ||
|
||
Run the example by running the following command: | ||
```bash | ||
python workflow.py rolling_process | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from qlib.data.dataset.handler import DataHandlerLP | ||
from qlib.data.dataset.loader import DataLoaderDH | ||
from qlib.contrib.data.handler import check_transform_proc | ||
|
||
|
||
class RollingDataHandler(DataHandlerLP):
    """A ``DataHandlerLP`` whose raw features come from another (pickled) handler.

    It wraps a ``DataLoaderDH`` (DataHandler-based DataLoader) so that the
    window-independent raw features are loaded once, while the learnable
    processors (fitted on ``fit_start_time``..``fit_end_time``) regenerate the
    window-dependent processed features for each rolling window.

    Parameters
    ----------
    start_time, end_time :
        Overall time range of the data exposed by this handler.
    infer_processors, learn_processors : list | None
        Processor configs for inference / learning; ``None`` means no processors.
    fit_start_time, fit_end_time :
        Fit window injected into time-aware processors via ``check_transform_proc``.
    data_loader_kwargs : dict | None
        Extra kwargs forwarded to the ``DataLoaderDH`` (e.g. ``handler_config``).
    """

    def __init__(
        self,
        start_time=None,
        end_time=None,
        infer_processors=None,
        learn_processors=None,
        fit_start_time=None,
        fit_end_time=None,
        data_loader_kwargs=None,
    ):
        # Bug fix: the original used mutable defaults ([] / {}), which are shared
        # across calls and can leak state between handler instantiations.
        # `None` sentinels with a local fallback are behavior-compatible.
        infer_processors = check_transform_proc(infer_processors or [], fit_start_time, fit_end_time)
        learn_processors = check_transform_proc(learn_processors or [], fit_start_time, fit_end_time)

        data_loader = {
            "class": "DataLoaderDH",
            # copy so the caller's dict is never mutated downstream
            "kwargs": dict(data_loader_kwargs or {}),
        }

        super().__init__(
            instruments=None,  # instrument filtering is delegated to the inner handler
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
import qlib | ||
import fire | ||
import pickle | ||
import pandas as pd | ||
|
||
from datetime import datetime | ||
from qlib.config import REG_CN | ||
from qlib.data.dataset.handler import DataHandlerLP | ||
from qlib.contrib.data.handler import Alpha158 | ||
from qlib.utils import exists_qlib_data, init_instance_by_config | ||
from qlib.tests.data import GetData | ||
|
||
|
||
class RollingDataWorkflow:
    """Example workflow: generate processed data for successive rolling windows.

    A fully-fitted ``Alpha158`` handler is dumped to disk once, then reused
    (via ``DataLoaderDH``) as the raw-feature source for every rolling window,
    so only the processors are re-fitted per window.
    """

    # Static configuration of the example: market universe, overall data
    # range of the pre-dumped handler, and number of rolling steps.
    MARKET = "csi300"
    start_time = "2010-01-01"
    end_time = "2019-12-31"
    rolling_cnt = 5

    def _init_qlib(self):
        """Initialize qlib, downloading the CN data bundle if it is missing."""
        # China daily data bundle (the path is cn_data, not 1-minute data).
        provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
        qlib.init(provider_uri=provider_uri, region=REG_CN)

    def _dump_pre_handler(self, path):
        """Build an Alpha158 handler over the full date range and pickle it to *path*.

        No processors are attached: this handler only supplies raw,
        window-independent features to the rolling handler.
        """
        handler_config = {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": {
                "start_time": self.start_time,
                "end_time": self.end_time,
                "instruments": self.MARKET,
                "infer_processors": [],
                "learn_processors": [],
            },
        }
        pre_handler = init_instance_by_config(handler_config)
        # dump_all=True makes to_pickle serialize the computed data as well,
        # so later loads need no recomputation.
        pre_handler.config(dump_all=True)
        pre_handler.to_pickle(path)

    def _load_pre_handler(self, path):
        """Unpickle and return the handler previously dumped by _dump_pre_handler."""
        with open(path, "rb") as file_dataset:
            pre_handler = pickle.load(file_dataset)
        return pre_handler

    def rolling_process(self):
        """Iterate rolling_cnt windows, re-fitting processors and preparing data each time."""
        self._init_qlib()
        self._dump_pre_handler("pre_handler.pkl")
        pre_handler = self._load_pre_handler("pre_handler.pkl")

        # Initial (offset 0) segment boundaries as (year, month, day) tuples;
        # each rolling step shifts only the year component (see the loop below).
        train_start_time = (2010, 1, 1)
        train_end_time = (2012, 12, 31)
        valid_start_time = (2013, 1, 1)
        valid_end_time = (2013, 12, 31)
        test_start_time = (2014, 1, 1)
        test_end_time = (2014, 12, 31)

        dataset_config = {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {
                "handler": {
                    # Local module from this example dir (rolling_handler.py).
                    "class": "RollingDataHandler",
                    "module_path": "rolling_handler",
                    "kwargs": {
                        "start_time": datetime(*train_start_time),
                        "end_time": datetime(*test_end_time),
                        # Learnable processors are fitted on the train window only.
                        "fit_start_time": datetime(*train_start_time),
                        "fit_end_time": datetime(*train_end_time),
                        "infer_processors": [
                            {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
                        ],
                        "learn_processors": [
                            {"class": "DropnaLabel"},
                            {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
                        ],
                        # Raw features come from the pre-dumped handler instance.
                        "data_loader_kwargs": {
                            "handler_config": pre_handler,
                        },
                    },
                },
                "segments": {
                    "train": (datetime(*train_start_time), datetime(*train_end_time)),
                    "valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
                    "test": (datetime(*test_start_time), datetime(*test_end_time)),
                },
            },
        }

        dataset = init_instance_by_config(dataset_config)

        for rolling_offset in range(self.rolling_cnt):

            print(f"===========rolling{rolling_offset} start===========")
            # Offset 0 uses the configuration above; later windows re-config
            # the handler's time range / processor fit window and the segments
            # BEFORE setup_data re-initializes the data (order matters).
            if rolling_offset:
                dataset.config(
                    handler_kwargs={
                        # Shift every boundary forward by `rolling_offset` years.
                        "start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
                        "end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
                        "processor_kwargs": {
                            "fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
                            "fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
                        },
                    },
                    segments={
                        "train": (
                            datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
                            datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
                        ),
                        "valid": (
                            datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]),
                            datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]),
                        ),
                        "test": (
                            datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]),
                            datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
                        ),
                    },
                )
            # IT_FIT_SEQ: fit learnable processors then produce data
            # (sequential fit per DataHandlerLP init types).
            dataset.setup_data(
                handler_kwargs={
                    "init_type": DataHandlerLP.IT_FIT_SEQ,
                }
            )

            dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
            print(dtrain, dvalid, dtest)
            ## print or dump data
            print(f"===========rolling{rolling_offset} end===========")
|
||
|
||
if __name__ == "__main__":
    # Expose the workflow via python-fire, e.g.: python workflow.py rolling_process
    fire.Fire(RollingDataWorkflow)
Oops, something went wrong.