Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine backtest codes #1120

Merged
merged 7 commits into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 79 additions & 56 deletions qlib/backtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,29 @@
# Licensed under the MIT License.

from __future__ import annotations

import copy
from typing import List, Tuple, Union, TYPE_CHECKING
from pathlib import Path
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union

import pandas as pd

from .account import Account
from .report import Indicator, PortfolioMetrics

if TYPE_CHECKING:
from ..strategy.base import BaseStrategy
from .executor import BaseExecutor
from .decision import BaseTradeDecision
from .position import Position

from ..config import C
from ..log import get_module_logger
from ..utils import init_instance_by_config
from .backtest import backtest_loop, collect_data_loop
from .decision import Order
from .exchange import Exchange
from .backtest import backtest_loop
from .backtest import collect_data_loop
from .position import Position
from .utils import CommonInfrastructure
from .decision import Order
from ..utils import init_instance_by_config
from ..log import get_module_logger
from ..config import C

# make import more user-friendly by adding `from qlib.backtest import STH`

Expand All @@ -28,26 +33,34 @@


def get_exchange(
exchange=None,
freq="day",
start_time=None,
end_time=None,
codes="all",
subscribe_fields=[],
open_cost=0.0015,
close_cost=0.0025,
min_cost=5.0,
limit_threshold=None,
exchange: Union[str, dict, object, Path] = None,
freq: str = "day",
start_time: Union[pd.Timestamp, str] = None,
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
subscribe_fields: list = [],
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
limit_threshold: Union[Tuple[str, str], float, None] = None,
deal_price: Union[str, Tuple[str], List[str]] = None,
**kwargs,
):
) -> Exchange:
"""get_exchange

Parameters
----------

# exchange related arguments
exchange: Exchange().
exchange: Exchange(). It could be None or any types that are acceptable by `init_instance_by_config`.
freq: str
frequency of data.
start_time: Union[pd.Timestamp, str]
closed start time for backtest.
end_time: Union[pd.Timestamp, str]
closed end time for backtest.
codes: list|str
list stock_id list or a string of instruments (i.e. all, csi500, sse50)
subscribe_fields: list
subscribe fields.
open_cost : float
Expand All @@ -57,8 +70,6 @@ def get_exchange(
min_cost : float
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
trade_unit : int
Included in kwargs. Please refer to the docs of `__init__` of `Exchange`
deal_price: Union[str, Tuple[str], List[str]]
The `deal_price` supports following two types of input
- <deal_price> : str
Expand Down Expand Up @@ -101,10 +112,14 @@ def get_exchange(


def create_account_instance(
start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position"
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
benchmark: str,
account: Union[float, int, dict],
pos_type: str = "Position",
) -> Account:
"""
# TODO: is very strange pass benchmark_config in the account(maybe for report)
# TODO: is very strange pass benchmark_config in the account (maybe for report)
# There should be a post-step to process the report.

Parameters
Expand Down Expand Up @@ -132,6 +147,8 @@ def create_account_instance(
key "cash" means initial cash.
key "stock1" means the information of first stock with amount and price(optional).
...
pos_type: str
Postion type.
"""
if isinstance(account, (int, float)):
pos_kwargs = {"init_cash": account}
Expand Down Expand Up @@ -159,15 +176,15 @@ def create_account_instance(


def get_strategy_executor(
start_time,
end_time,
strategy: BaseStrategy,
executor: BaseExecutor,
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
):
) -> Tuple[BaseStrategy, BaseExecutor]:

# NOTE:
# - for avoiding recursive import
Expand All @@ -176,7 +193,11 @@ def get_strategy_executor(
from .executor import BaseExecutor # pylint: disable=C0415

trade_account = create_account_instance(
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
start_time=start_time,
end_time=end_time,
benchmark=benchmark,
account=account,
pos_type=pos_type,
)

exchange_kwargs = copy.copy(exchange_kwargs)
Expand All @@ -196,29 +217,31 @@ def get_strategy_executor(


def backtest(
start_time,
end_time,
strategy,
executor,
benchmark="SH000300",
account=1e9,
exchange_kwargs={},
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
strategy_config: Union[str, dict, object, Path],
executor_config: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
):
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and executor in the nested decision execution
) -> Tuple[PortfolioMetrics, Indicator]:
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
executor in the nested decision execution

Parameters
----------
start_time : pd.Timestamp|str
start_time : Union[pd.Timestamp, str]
closed start time for backtest
**NOTE**: This will be applied to the outmost executor's calendar.
end_time : pd.Timestamp|str
end_time : Union[pd.Timestamp, str]
closed end time for backtest
**NOTE**: This will be applied to the outmost executor's calendar.
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
strategy : Union[str, dict, BaseStrategy]
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information.
executor : Union[str, dict, BaseExecutor]
strategy_config : Union[str, dict, object, Path]
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more
information.
executor_config : Union[str, dict, object, Path]
for initializing the outermost executor.
benchmark: str
the benchmark for reporting.
Expand All @@ -245,8 +268,8 @@ def backtest(
trade_strategy, trade_executor = get_strategy_executor(
start_time,
end_time,
strategy,
executor,
strategy_config,
executor_config,
benchmark,
account,
exchange_kwargs,
Expand All @@ -257,16 +280,16 @@ def backtest(


def collect_data(
start_time,
end_time,
strategy,
executor,
benchmark="SH000300",
account=1e9,
exchange_kwargs={},
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
return_value: dict = None,
):
) -> Generator[object, None, None]:
"""initialize the strategy and executor, then collect the trade decision data for rl training

please refer to the docs of the backtest for the explanation of the parameters
Expand All @@ -291,7 +314,7 @@ def collect_data(

def format_decisions(
decisions: List[BaseTradeDecision],
) -> Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]:
) -> Optional[Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]]:
"""
format the decisions collected by `qlib.backtest.collect_data`
The decisions will be organized into a tree-like structure.
Expand Down Expand Up @@ -326,4 +349,4 @@ def format_decisions(
return res


__all__ = ["Order"]
__all__ = ["Order", "backtest"]
Loading