Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify the code and prevent float when shifting #473

Merged
merged 2 commits into from
Jun 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions qlib/data/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def acquire(lock, lock_name):
lock.acquire()
except redis_lock.AlreadyAcquired:
raise QlibCacheException(
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
You can use the following command to clear your redis keys and rerun your commands:
$ redis-cli
> select {C.redis_task_db}
Expand Down Expand Up @@ -784,10 +784,10 @@ def append_index(self, data, to_disk=True):
def build_index_from_data(data, start_index=0):
if data.empty:
return pd.DataFrame()
line_data = data.iloc[:, 0].fillna(0).groupby("datetime").count()
line_data = data.groupby("datetime").size()
line_data.sort_index(inplace=True)
index_end = line_data.cumsum()
index_start = index_end.shift(1).fillna(0)
index_start = index_end.shift(1, fill_value=0)

index_data = pd.DataFrame()
index_data["start"] = index_start
Expand Down
4 changes: 3 additions & 1 deletion qlib/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,9 @@ class TSDatasetH(DatasetH):
- The dimension of a batch of data <batch_idx, feature, timestep>
"""

def __init__(self, step_len=30, **kwargs):
DEFAULT_STEP_LEN = 30

def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
self.step_len = step_len
super().__init__(**kwargs)

Expand Down
7 changes: 5 additions & 2 deletions qlib/workflow/online/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
"""

from typing import List, Union
from qlib.data.dataset import TSDatasetH

from qlib.log import get_module_logger
from qlib.utils import get_cls_kwargs
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
Expand Down Expand Up @@ -161,8 +163,9 @@ def update_online_pred(self, to_date=None):
hist_ref = 0
task = rec.load_object("task")
# Special treatment of historical dependencies
if task["dataset"]["class"] == "TSDatasetH":
hist_ref = task["dataset"]["kwargs"]["step_len"]
cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset")
if issubclass(cls, TSDatasetH):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()

self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")