RL backtest pipeline on 5-min data #1417

Merged
merged 20 commits on Feb 13, 2023
Changes from 14 commits
12 changes: 7 additions & 5 deletions qlib/contrib/ops/high_freq.py
@@ -70,7 +70,7 @@ class DayCumsum(ElemOperator):
Otherwise, the value is zero.
"""

def __init__(self, feature, start: str = "9:30", end: str = "14:59"):
def __init__(self, feature, start: str = "9:30", end: str = "14:59", data_granularity: int = 1):
self.feature = feature
self.start = datetime.strptime(start, "%H:%M")
self.end = datetime.strptime(end, "%H:%M")
@@ -80,15 +80,17 @@ def __init__(self, feature, start: str = "9:30", end: str = "14:59"):
self.noon_open = datetime.strptime("13:00", "%H:%M")
self.noon_close = datetime.strptime("15:00", "%H:%M")

self.start_id = time_to_day_index(self.start)
self.end_id = time_to_day_index(self.end)
self.data_granularity = data_granularity
self.start_id = time_to_day_index(self.start) // self.data_granularity
self.end_id = time_to_day_index(self.end) // self.data_granularity
assert 240 % self.data_granularity == 0

def period_cusum(self, df):
df = df.copy()
assert len(df) == 240
assert len(df) == 240 // self.data_granularity
df.iloc[0 : self.start_id] = 0
df = df.cumsum()
df.iloc[self.end_id + 1 : 240] = 0
df.iloc[self.end_id + 1 : 240 // self.data_granularity] = 0
return df

def _load_internal(self, instrument, start_index, end_index, freq):
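To make the new indexing concrete: with `data_granularity=5` a trading day shrinks from 240 one-minute bars to 48 five-minute bars, and both boundary indices are divided accordingly. Below is a standalone sketch of the `period_cusum` arithmetic under that assumption; the `time_to_day_index` mapping (9:30 to 0, 14:59 to 239) is assumed here, not taken from this diff.

```python
# Hedged sketch of period_cusum on 5-min bars; not the operator itself.
import pandas as pd

data_granularity = 5
assert 240 % data_granularity == 0

start_id = 0 // data_granularity          # assume time_to_day_index("9:30") == 0
end_id = 239 // data_granularity          # assume time_to_day_index("14:59") == 239

day = pd.Series(1.0, index=range(240 // data_granularity))  # one day = 48 five-minute bars
day.iloc[0:start_id] = 0                  # zero out bars before the start time
cum = day.cumsum()
cum.iloc[end_id + 1 : 240 // data_granularity] = 0          # zero out bars after the end time
print(cum.iloc[-1])                       # 48.0 -> cumulative sum over the full 9:30-14:59 window
```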
6 changes: 3 additions & 3 deletions qlib/rl/contrib/backtest.py
@@ -35,7 +35,7 @@ def _get_multi_level_executor_config(
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "1min",
"time_per_step": "5min", # FIXME: move this into config
"verbose": False,
"trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL,
"generate_report": generate_report,
@@ -187,7 +187,7 @@ def single_with_simulator(
exchange_config.update(
{
"codes": stocks,
"freq": "1min",
"freq": "5min", # FIXME: move this into config
}
)

@@ -286,7 +286,7 @@ def single_with_collect_data_loop(
exchange_config.update(
{
"codes": stocks,
"freq": "1min",
"freq": "5min", # FIXME: move this into config
}
)

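The three `# FIXME: move this into config` notes all hard-code the same bar frequency. A possible follow-up, purely as a sketch (the `freq` key and the variable names are assumptions, not part of this PR), would read the frequency once from the backtest config:

```python
# Hypothetical: pull the bar frequency from the backtest config instead of
# hard-coding "5min" in three places.
backtest_config = {"freq": "5min"}        # in practice loaded from the YAML config
stocks = ["SH600000", "SH600519"]         # placeholder instrument list

freq = backtest_config.get("freq", "5min")
executor_time_per_step = freq             # would replace the hard-coded "time_per_step": "5min"
exchange_config = {"codes": stocks, "freq": freq}
print(executor_time_per_step, exchange_config)
```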
2 changes: 1 addition & 1 deletion qlib/rl/contrib/naive_config_parser.py
@@ -98,7 +98,7 @@ def get_backtest_config_fromfile(path: str) -> dict:
"debug_single_day": None,
"concurrency": -1,
"multiplier": 1.0,
"output_dir": "outputs/",
"output_dir": "outputs_backtest/",
"generate_report": False,
}
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
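The new default only matters when the user config leaves `output_dir` unset. A rough illustration, assuming `merge_a_into_b` lets keys in `a` (the user config) override the defaults in `b` (its exact semantics are not shown in this diff):

```python
# Stand-in for merge_a_into_b: user values win, defaults fill the gaps.
defaults = {
    "debug_single_day": None,
    "concurrency": -1,
    "multiplier": 1.0,
    "output_dir": "outputs_backtest/",
    "generate_report": False,
}
user_config = {"concurrency": 4}          # no output_dir given
merged = {**defaults, **user_config}
print(merged["output_dir"])               # -> "outputs_backtest/"
```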
151 changes: 82 additions & 69 deletions qlib/rl/contrib/train_onpolicy.py
@@ -3,6 +3,7 @@
import argparse
import os
import random
import warnings
from pathlib import Path
from typing import cast, List, Optional

@@ -101,6 +102,7 @@ def train_and_test(
action_interpreter: ActionInterpreter,
policy: BasePolicy,
reward: Reward,
run_training: bool,
run_backtest: bool,
) -> None:
qlib.init()
@@ -122,62 +124,67 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
assert data_config["source"]["default_end_time_index"] % data_granularity == 0

train_dataset, valid_dataset, test_dataset = [
LazyLoadDataset(
order_file_path=order_root_path / tag,
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
for tag in ("train", "valid", "test")
]
if run_training:
train_dataset, valid_dataset = [
LazyLoadDataset(
order_file_path=order_root_path / tag,
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
for tag in ("train", "valid")
]

if "checkpoint_path" in trainer_config:
callbacks: List[Callback] = []
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
callbacks.append(
Checkpoint(
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
save_latest="copy",
),
)
if "earlystop_patience" in trainer_config:
callbacks.append(
EarlyStopping(
patience=trainer_config["earlystop_patience"],
monitor="val/pa",
if "checkpoint_path" in trainer_config:
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
callbacks.append(
Checkpoint(
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
save_latest="copy",
),
)
if "earlystop_patience" in trainer_config:
callbacks.append(
EarlyStopping(
patience=trainer_config["earlystop_patience"],
monitor="val/pa",
)
)
)

trainer_kwargs = {
"max_iters": trainer_config["max_epoch"],
"finite_env_type": env_config["parallel_mode"],
"concurrency": env_config["concurrency"],
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
"callbacks": callbacks,
}
vessel_kwargs = {
"episode_per_iter": trainer_config["episode_per_collect"],
"update_kwargs": {
"batch_size": trainer_config["batch_size"],
"repeat": trainer_config["repeat_per_collect"],
},
"val_initial_states": valid_dataset,
}

train(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
reward=reward,
initial_states=cast(List[Order], train_dataset),
trainer_kwargs=trainer_kwargs,
vessel_kwargs=vessel_kwargs,
)
train(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
reward=reward,
initial_states=cast(List[Order], train_dataset),
trainer_kwargs={
"max_iters": trainer_config["max_epoch"],
"finite_env_type": env_config["parallel_mode"],
"concurrency": env_config["concurrency"],
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
"callbacks": callbacks,
},
vessel_kwargs={
"episode_per_iter": trainer_config["episode_per_collect"],
"update_kwargs": {
"batch_size": trainer_config["batch_size"],
"repeat": trainer_config["repeat_per_collect"],
},
"val_initial_states": valid_dataset,
},
)

if run_backtest:
test_dataset = LazyLoadDataset(
order_file_path=order_root_path / "test",
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)

backtest(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
@@ -186,35 +193,39 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
policy=policy,
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
reward=reward,
finite_env_type=trainer_kwargs["finite_env_type"],
concurrency=trainer_kwargs["concurrency"],
finite_env_type=env_config["parallel_mode"],
concurrency=env_config["concurrency"],
)


def main(config: dict, run_backtest: bool) -> None:
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
if not run_training and not run_backtest:
warnings.warn("Skip the entire job since training and backtest are both skipped.")
return

if "seed" in config["runtime"]:
seed_everything(config["runtime"]["seed"])

state_config = config["state_interpreter"]
state_interpreter: StateInterpreter = init_instance_by_config(state_config)

state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
reward: Reward = init_instance_by_config(config["reward"])

additional_policy_kwargs = {
"obs_space": state_interpreter.observation_space,
"action_space": action_interpreter.action_space,
}

# Create torch network
if "kwargs" not in config["network"]:
config["network"]["kwargs"] = {}
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
network: nn.Module = init_instance_by_config(config["network"])
if "network" in config:
if "kwargs" not in config["network"]:
config["network"]["kwargs"] = {}
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
additional_policy_kwargs["network"] = init_instance_by_config(config["network"])

# Create policy
config["policy"]["kwargs"].update(
{
"network": network,
"obs_space": state_interpreter.observation_space,
"action_space": action_interpreter.action_space,
}
)
if "kwargs" not in config["policy"]:
config["policy"]["kwargs"] = {}
config["policy"]["kwargs"].update(additional_policy_kwargs)
policy: BasePolicy = init_instance_by_config(config["policy"])

use_cuda = config["runtime"].get("use_cuda", False)
@@ -230,6 +241,7 @@ def main(config: dict, run_backtest: bool) -> None:
state_interpreter=state_interpreter,
policy=policy,
reward=reward,
run_training=run_training,
run_backtest=run_backtest,
)

@@ -242,10 +254,11 @@ def main(config: dict, run_backtest: bool) -> None:

parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished")
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
args = parser.parse_args()

with open(args.config_path, "r") as input_stream:
config = yaml.safe_load(input_stream)

main(config, run_backtest=args.run_backtest)
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
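A quick sketch of how the two flags combine; the config path is a placeholder and the snippet only exercises the argparse wiring shown above:

```python
# --config_path exp.yml                                -> run_training=True,  run_backtest=False
# --config_path exp.yml --run_backtest                 -> run_training=True,  run_backtest=True
# --config_path exp.yml --no_training --run_backtest   -> run_training=False, run_backtest=True
# --config_path exp.yml --no_training                  -> both False; main() warns and returns
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True)
parser.add_argument("--no_training", action="store_true")
parser.add_argument("--run_backtest", action="store_true")

args = parser.parse_args(["--config_path", "exp.yml", "--no_training", "--run_backtest"])
print(not args.no_training, args.run_backtest)  # -> False True
```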
7 changes: 3 additions & 4 deletions qlib/rl/data/integration.py
@@ -82,10 +82,9 @@ def _convert_to_path(path: str | Path) -> Path:
return path if isinstance(path, Path) else Path(path)

provider_uri_map = {}
if "provider_uri_day" in qlib_config:
provider_uri_map["day"] = _convert_to_path(qlib_config["provider_uri_day"]).as_posix()
if "provider_uri_1min" in qlib_config:
provider_uri_map["1min"] = _convert_to_path(qlib_config["provider_uri_1min"]).as_posix()
for granularity in ["1min", "5min", "day"]:
if f"provider_uri_{granularity}" in qlib_config:
provider_uri_map[f"{granularity}"] = _convert_to_path(qlib_config[f"provider_uri_{granularity}"]).as_posix()

qlib.init(
region=REG_CN,
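A self-contained sketch of the generalized loop with a made-up `qlib_config`, showing that only the granularities actually present in the config end up in `provider_uri_map`:

```python
# Illustration only; the paths below are placeholders.
from pathlib import Path

qlib_config = {
    "provider_uri_5min": "~/.qlib/qlib_data/cn_data_5min",
    "provider_uri_day": "~/.qlib/qlib_data/cn_data",
}

provider_uri_map = {}
for granularity in ["1min", "5min", "day"]:
    if f"provider_uri_{granularity}" in qlib_config:
        provider_uri_map[granularity] = Path(qlib_config[f"provider_uri_{granularity}"]).as_posix()

print(provider_uri_map)  # {'5min': '~/.qlib/qlib_data/cn_data_5min', 'day': '~/.qlib/qlib_data/cn_data'}
```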
12 changes: 12 additions & 0 deletions qlib/rl/order_execution/interpreter.py
@@ -51,6 +51,18 @@ class FullHistoryObs(TypedDict):
target: Any
position: Any
position_history: Any


class NoReturnStateInterpreter(StateInterpreter[SAOEState, dict]):
"""Do not return any observation. For policies that do not need inputs (for example, AllOne).
"""

def interpret(self, state: SAOEState) -> dict:
return {"DUMMY": _to_int32(1)} # FIXME: A fake state, used to pass `check_nan_observation`

@property
def observation_space(self) -> spaces.Dict:
return spaces.Dict({"DUMMY": spaces.Box(-np.inf, np.inf, shape=(), dtype=np.int32)}) # FIXME:


class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
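A hedged sketch of how this interpreter is meant to pair with a network-free policy, written as the Python dict that `init_instance_by_config` consumes (an equivalent YAML file works the same way). With the `train_onpolicy.py` changes in this PR the `network` key can simply be omitted; whether any further kwargs are required here is an assumption.

```python
# Config fragment (sketch): rule-based policy, dummy observations, no network.
config_fragment = {
    "state_interpreter": {
        "class": "NoReturnStateInterpreter",
        "module_path": "qlib.rl.order_execution.interpreter",
    },
    "policy": {
        "class": "AllOne",
        "module_path": "qlib.rl.order_execution.policy",
        "kwargs": {"fill_value": 1.0},
    },
    # no "network" key: main() only builds a network when the key is present
}
```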
9 changes: 7 additions & 2 deletions qlib/rl/order_execution/policy.py
@@ -48,14 +48,19 @@ class AllOne(NonLearnablePolicy):

Useful when implementing some baselines (e.g., TWAP).
"""


def __init__(self, obs_space: gym.Space, action_space: gym.Space, fill_value: float | int = 1.0) -> None:
super().__init__(obs_space, action_space)

self.fill_value = fill_value

def forward(
self,
batch: Batch,
state: dict | Batch | np.ndarray = None,
**kwargs: Any,
) -> Batch:
return Batch(act=np.full(len(batch), 1.0), state=state)
return Batch(act=np.full(len(batch), self.fill_value), state=state)


# ppo #
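Usage sketch for the new `fill_value` argument (the spaces below are placeholders, and this assumes qlib with this branch plus tianshou are installed): a baseline that always emits 0.5 instead of the default 1.0.

```python
import gym
import numpy as np
from tianshou.data import Batch

from qlib.rl.order_execution.policy import AllOne

obs_space = gym.spaces.Box(-np.inf, np.inf, shape=(1,))  # placeholder spaces
action_space = gym.spaces.Box(0.0, 1.0, shape=(1,))

policy = AllOne(obs_space, action_space, fill_value=0.5)
out = policy.forward(Batch(obs=np.zeros((4, 1))))
print(out.act)  # -> [0.5 0.5 0.5 0.5]
```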
49 changes: 49 additions & 0 deletions qlib/rl/order_execution/reward.py
@@ -7,6 +7,7 @@

import numpy as np

from qlib.backtest.decision import OrderDir
from qlib.rl.order_execution.state import SAOEMetrics, SAOEState
from qlib.rl.reward import Reward

@@ -47,3 +48,51 @@ def reward(self, simulator_state: SAOEState) -> float:
self.log("reward/pa", pa)
self.log("reward/penalty", penalty)
return reward * self.scale


def _weighted_average(val: np.ndarray, weight: np.ndarray):
return (val * weight).sum() / weight.sum()


class PPOReward(Reward[SAOEState]):
"""Reward proposed by paper "An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization".

Parameters
----------
max_step
Maximum number of steps.
start_time_index
First time index at which trading is allowed.
end_time_index
Last time index at which trading is allowed.
"""

def __init__(self, max_step: int, start_time_index: int = 0, end_time_index: int = 239) -> None:
self.max_step = max_step
self.start_time_index = start_time_index
self.end_time_index = end_time_index

def reward(self, simulator_state: SAOEState) -> float:
if simulator_state.cur_step == self.max_step - 1 or simulator_state.position < 1e-6:
traded = simulator_state.order.amount - simulator_state.position
trade_value = sum(simulator_state.history_steps["trade_value"])
intraday_data = simulator_state.backtest_data.data[self.start_time_index:self.end_time_index + 1]

if traded >= 1e-6:
vwap_price = trade_value / traded
else:
vwap_price = _weighted_average( # VWAP price of the entire day
val=intraday_data["$close0"].to_numpy(),
weight=intraday_data["$volume0"].to_numpy(),
)
twap_price = np.mean(intraday_data["$close0"].to_numpy())

ratio = vwap_price / twap_price if simulator_state.order.direction == OrderDir.SELL else twap_price / vwap_price
if ratio < 1.0:
return -1.0
elif ratio < 1.1:
return 0.0
else:
return 1.0
else:
return 0.0
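To make the tiering concrete, a tiny worked example with made-up prices; for a SELL order the ratio is vwap_price / twap_price, so filling above TWAP is what gets rewarded.

```python
# Worked example of the three-level reward (illustrative numbers only).
vwap_price, twap_price = 10.56, 10.50
ratio = vwap_price / twap_price           # ~1.0057 for a SELL order

if ratio < 1.0:
    reward = -1.0   # executed worse than TWAP
elif ratio < 1.1:
    reward = 0.0    # roughly on par with TWAP (within 10%)
else:
    reward = 1.0    # clearly better than TWAP

print(round(ratio, 4), reward)            # -> 1.0057 0.0
```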