Skip to content

Commit

Permalink
RL backtest pipeline on 5-min data (microsoft#1417)
Browse files Browse the repository at this point in the history
* Workflow runnable

* CI

* Slight changes to make the workflow runnable. The changes of handler/provider should be reverted before merging.

* Train experiment successful

* Refine handler & provider

* test passed

* Ready to test on server

* Minor

* Test passed

* TWAP training

* Add PPOReward

* Add a FIXME

* Refine PPO reward according to PR comments

* Minor

* Resolve PR comments

* CI issues

* CI issues

* CI issues
  • Loading branch information
lihuoran authored and qianyun210603 committed Mar 23, 2023
1 parent c385d34 commit 507a5a1
Show file tree
Hide file tree
Showing 25 changed files with 250 additions and 166 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/test_qlib_from_source.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ jobs:
# W1309: f-string-without-interpolation
# E1102: not-callable
# E1136: unsubscriptable-object
# FIXME: Due to the version change of Pylint, some code will cause W0719 error after PR 1417. W0719 is temporarily disabled in PR 1417 and should be fixed.
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
- name: Check Qlib with pylint
run: |
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
# The following flake8 error codes were ignored:
# E501 line too long
Expand Down
6 changes: 3 additions & 3 deletions qlib/backtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def get_exchange(
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
limit_threshold: Union[Tuple[str, str], float, None] = None,
deal_price: Union[str, Tuple[str, str], List[str]] = None,
limit_threshold: Union[Tuple[str, str], float, None] | None = None,
deal_price: Union[str, Tuple[str, str], List[str]] | None = None,
**kwargs: Any,
) -> Exchange:
"""get_exchange
Expand Down Expand Up @@ -284,7 +284,7 @@ def collect_data(
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
return_value: dict = None,
return_value: dict | None = None,
) -> Generator[object, None, None]:
"""initialize the strategy and executor, then collect the trade decision data for rl training
Expand Down
4 changes: 3 additions & 1 deletion qlib/backtest/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ def reset_report(self, freq: str, benchmark_config: dict) -> None:
# trading related metrics(e.g. high-frequency trading)
self.indicator = Indicator()

def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabled: bool = None) -> None:
def reset(
self, freq: str | None = None, benchmark_config: dict | None = None, port_metr_enabled: bool | None = None
) -> None:
"""reset freq and report of account
Parameters
Expand Down
2 changes: 1 addition & 1 deletion qlib/backtest/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def collect_data_loop(
end_time: Union[pd.Timestamp, str],
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
return_value: dict = None,
return_value: dict | None = None,
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
"""Generator for collecting the trade decision data for rl training
Expand Down
6 changes: 3 additions & 3 deletions qlib/backtest/decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def __init__(self, start_idx: int, end_idx: int) -> None:
self._start_idx = start_idx
self._end_idx = end_idx

def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
def __call__(self, trade_calendar: TradeCalendarManager | None = None) -> Tuple[int, int]:
return self._start_idx, self._end_idx

def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
Expand Down Expand Up @@ -315,7 +315,7 @@ class BaseTradeDecision(Generic[DecisionType]):
2. Same as `case 1.3`
"""

def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None) -> None:
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange, None] = None) -> None:
"""
Parameters
----------
Expand Down Expand Up @@ -554,7 +554,7 @@ def __init__(
self,
order_list: List[Order],
strategy: BaseStrategy,
trade_range: Union[Tuple[int, int], TradeRange] = None,
trade_range: Union[Tuple[int, int], TradeRange, None] = None,
) -> None:
super().__init__(strategy, trade_range=trade_range)
self.order_list = cast(List[Order], order_list)
Expand Down
26 changes: 13 additions & 13 deletions qlib/backtest/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def __init__(
start_time: Union[pd.Timestamp, str] = None,
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
deal_price: Union[str, Tuple[str, str], List[str]] = None,
deal_price: Union[str, Tuple[str, str], List[str], None] = None,
subscribe_fields: list = [],
limit_threshold: Union[Tuple[str, str], float, None] = None,
volume_threshold: Union[tuple, dict] = None,
volume_threshold: Union[tuple, dict, None] = None,
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
Expand Down Expand Up @@ -340,7 +340,7 @@ def check_stock_limit(
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: int = None,
direction: int | None = None,
) -> bool:
"""
Parameters
Expand Down Expand Up @@ -406,7 +406,7 @@ def is_stock_tradable(
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: int = None,
direction: int | None = None,
) -> bool:
# check if stock can be traded
return not (
Expand All @@ -421,8 +421,8 @@ def check_order(self, order: Order) -> bool:
def deal_order(
self,
order: Order,
trade_account: Account = None,
position: BasePosition = None,
trade_account: Account | None = None,
position: BasePosition | None = None,
dealt_order_amount: Dict[str, float] = defaultdict(float),
) -> Tuple[float, float, float]:
"""
Expand Down Expand Up @@ -586,7 +586,7 @@ def generate_amount_position_from_weight_position(
)
return amount_dict

def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float | None = None) -> float:
"""
Calculate the real adjust deal amount when considering the trading unit
:param current_amount:
Expand Down Expand Up @@ -712,8 +712,8 @@ def calculate_amount_position_value(

def _get_factor_or_raise_error(
self,
factor: float = None,
stock_id: str = None,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> float:
Expand All @@ -728,8 +728,8 @@ def _get_factor_or_raise_error(

def get_amount_of_trade_unit(
self,
factor: float = None,
stock_id: str = None,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> Optional[float]:
Expand Down Expand Up @@ -762,8 +762,8 @@ def get_amount_of_trade_unit(
def round_amount_by_trade_unit(
self,
deal_amount: float,
factor: float = None,
stock_id: str = None,
factor: float | None = None,
stock_id: str | None = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> float:
Expand Down
12 changes: 6 additions & 6 deletions qlib/backtest/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def __init__(
generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
trade_exchange: Exchange = None,
common_infra: CommonInfrastructure = None,
trade_exchange: Exchange | None = None,
common_infra: CommonInfrastructure | None = None,
settle_type: str = BasePosition.ST_NO,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -161,7 +161,7 @@ def trade_calendar(self) -> TradeCalendarManager:
"""
return self.level_infra.get("trade_calendar")

def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
def reset(self, common_infra: CommonInfrastructure | None = None, **kwargs: Any) -> None:
"""
- reset `start_time` and `end_time`, used in trade calendar
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
Expand Down Expand Up @@ -227,7 +227,7 @@ def _collect_data(
def collect_data(
self,
trade_decision: BaseTradeDecision,
return_value: dict = None,
return_value: dict | None = None,
level: int = 0,
) -> Generator[Any, Any, List[object]]:
"""Generator for collecting the trade decision data for rl training
Expand Down Expand Up @@ -327,7 +327,7 @@ def __init__(
track_data: bool = False,
skip_empty_decision: bool = True,
align_range_limit: bool = True,
common_infra: CommonInfrastructure = None,
common_infra: CommonInfrastructure | None = None,
**kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -534,7 +534,7 @@ def __init__(
generate_portfolio_metrics: bool = False,
verbose: bool = False,
track_data: bool = False,
common_infra: CommonInfrastructure = None,
common_infra: CommonInfrastructure | None = None,
trade_type: str = TT_SERIAL,
**kwargs: Any,
) -> None:
Expand Down
3 changes: 2 additions & 1 deletion qlib/backtest/position.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from datetime import timedelta
from typing import Any, Dict, List, Union
Expand Down Expand Up @@ -320,7 +321,7 @@ def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last
self.position[stock]["price"] = price_dict[stock]
self.position["now_account_value"] = self.calculate_value()

def _init_stock(self, stock_id: str, amount: float, price: float = None) -> None:
def _init_stock(self, stock_id: str, amount: float, price: float | None = None) -> None:
"""
initialization the stock in current position
Expand Down
21 changes: 11 additions & 10 deletions qlib/backtest/report.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

import pathlib
from collections import OrderedDict
Expand Down Expand Up @@ -86,7 +87,7 @@ def init_vars(self) -> None:
self.benches: dict = OrderedDict()
self.latest_pm_time: Optional[pd.TimeStamp] = None

def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
def init_bench(self, freq: str | None = None, benchmark_config: dict | None = None) -> None:
if freq is not None:
self.freq = freq
self.benchmark_config = benchmark_config
Expand Down Expand Up @@ -149,15 +150,15 @@ def update_portfolio_metrics_record(
self,
trade_start_time: Union[str, pd.Timestamp] = None,
trade_end_time: Union[str, pd.Timestamp] = None,
account_value: float = None,
cash: float = None,
return_rate: float = None,
total_turnover: float = None,
turnover_rate: float = None,
total_cost: float = None,
cost_rate: float = None,
stock_value: float = None,
bench_value: float = None,
account_value: float | None = None,
cash: float | None = None,
return_rate: float | None = None,
total_turnover: float | None = None,
turnover_rate: float | None = None,
total_cost: float | None = None,
cost_rate: float | None = None,
stock_value: float | None = None,
bench_value: float | None = None,
) -> None:
# check data
if None in [
Expand Down
4 changes: 2 additions & 2 deletions qlib/backtest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
freq: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
level_infra: LevelInfrastructure = None,
level_infra: LevelInfrastructure | None = None,
) -> None:
"""
Parameters
Expand Down Expand Up @@ -99,7 +99,7 @@ def get_trade_len(self) -> int:
def get_trade_step(self) -> int:
return self.trade_step

def get_step_time(self, trade_step: int = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
def get_step_time(self, trade_step: int | None = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
"""
Get the left and right endpoints of the trade_step'th trading interval
Expand Down
12 changes: 7 additions & 5 deletions qlib/contrib/ops/high_freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class DayCumsum(ElemOperator):
Otherwise, the value is zero.
"""

def __init__(self, feature, start: str = "9:30", end: str = "14:59"):
def __init__(self, feature, start: str = "9:30", end: str = "14:59", data_granularity: int = 1):
self.feature = feature
self.start = datetime.strptime(start, "%H:%M")
self.end = datetime.strptime(end, "%H:%M")
Expand All @@ -80,15 +80,17 @@ def __init__(self, feature, start: str = "9:30", end: str = "14:59"):
self.noon_open = datetime.strptime("13:00", "%H:%M")
self.noon_close = datetime.strptime("15:00", "%H:%M")

self.start_id = time_to_day_index(self.start)
self.end_id = time_to_day_index(self.end)
self.data_granularity = data_granularity
self.start_id = time_to_day_index(self.start) // self.data_granularity
self.end_id = time_to_day_index(self.end) // self.data_granularity
assert 240 % self.data_granularity == 0

def period_cusum(self, df):
df = df.copy()
assert len(df) == 240
assert len(df) == 240 // self.data_granularity
df.iloc[0 : self.start_id] = 0
df = df.cumsum()
df.iloc[self.end_id + 1 : 240] = 0
df.iloc[self.end_id + 1 : 240 // self.data_granularity] = 0
return df

def _load_internal(self, instrument, start_index, end_index, freq):
Expand Down
12 changes: 6 additions & 6 deletions qlib/rl/contrib/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@

def _get_multi_level_executor_config(
strategy_config: dict,
cash_limit: float = None,
cash_limit: float | None = None,
generate_report: bool = False,
) -> dict:
executor_config = {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "1min",
"time_per_step": "5min", # FIXME: move this into config
"verbose": False,
"trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL,
"generate_report": generate_report,
Expand Down Expand Up @@ -127,7 +127,7 @@ def single_with_simulator(
backtest_config: dict,
orders: pd.DataFrame,
split: Literal["stock", "day"] = "stock",
cash_limit: float = None,
cash_limit: float | None = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
Expand Down Expand Up @@ -187,7 +187,7 @@ def single_with_simulator(
exchange_config.update(
{
"codes": stocks,
"freq": "1min",
"freq": "5min", # FIXME: move this into config
}
)

Expand Down Expand Up @@ -226,7 +226,7 @@ def single_with_collect_data_loop(
backtest_config: dict,
orders: pd.DataFrame,
split: Literal["stock", "day"] = "stock",
cash_limit: float = None,
cash_limit: float | None = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
"""Run backtest in a single thread with collect_data_loop.
Expand Down Expand Up @@ -286,7 +286,7 @@ def single_with_collect_data_loop(
exchange_config.update(
{
"codes": stocks,
"freq": "1min",
"freq": "5min", # FIXME: move this into config
}
)

Expand Down
2 changes: 1 addition & 1 deletion qlib/rl/contrib/naive_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def get_backtest_config_fromfile(path: str) -> dict:
"debug_single_day": None,
"concurrency": -1,
"multiplier": 1.0,
"output_dir": "outputs/",
"output_dir": "outputs_backtest/",
"generate_report": False,
}
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
Expand Down
Loading

0 comments on commit 507a5a1

Please sign in to comment.