Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plot enhancement #1390

Merged
merged 16 commits into from
Dec 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 57 additions & 37 deletions qlib/contrib/report/analysis_model/analysis_model_performance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from functools import partial

import pandas as pd

Expand All @@ -10,7 +11,11 @@

from scipy import stats

from typing import Sequence
from qlib.typehint import Literal

from ..graph import ScatterGraph, SubplotsGraph, BarGraph, HeatmapGraph
from ..utils import guess_plotly_rangebreaks


def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int = 5, **kwargs) -> tuple:
Expand Down Expand Up @@ -48,12 +53,13 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int
t_df["long-average"] = t_df["Group1"] - pred_label.groupby(level="datetime")["label"].mean()

t_df = t_df.dropna(how="all") # for days which does not contain label
# FIXME: support HIGH-FREQ
t_df.index = t_df.index.strftime("%Y-%m-%d")
# Cumulative Return By Group
group_scatter_figure = ScatterGraph(
t_df.cumsum(),
layout=dict(title="Cumulative Return", xaxis=dict(type="category", tickangle=45)),
layout=dict(
title="Cumulative Return",
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(t_df.index))),
),
).figure

t_df = t_df.loc[:, ["long-short", "long-average"]]
Expand Down Expand Up @@ -110,22 +116,36 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
return fig


def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple:
def _pred_ic(
pred_label: pd.DataFrame = None, methods: Sequence[Literal["IC", "Rank IC"]] = ("IC", "Rank IC"), **kwargs
) -> tuple:
"""

:param pred_label:
:param rank:
:param pred_label: pd.DataFrame
must contain one column of realized return with name `label` and one column of predicted score names `score`.
:param methods: Sequence[Literal["IC", "Rank IC"]]
IC series to plot.
IC is sectional pearson correlation between label and score
Rank IC is the spearman correlation between label and score
For the Monthly IC, IC histogram, IC Q-Q plot. Only the first type of IC will be plotted.
:return:
"""
if rank:
ic = pred_label.groupby(level="datetime").apply(
lambda x: x["label"].rank(pct=True).corr(x["score"].rank(pct=True))
)
else:
ic = pred_label.groupby(level="datetime").apply(lambda x: x["label"].corr(x["score"]))
_methods_mapping = {"IC": "pearson", "Rank IC": "spearman"}

def _corr_series(x, method):
return x["label"].corr(x["score"], method=method)

_index = ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6)
_monthly_ic = ic.groupby(_index).mean()
ic_df = pd.concat(
[
pred_label.groupby(level="datetime").apply(partial(_corr_series, method=_methods_mapping[m])).rename(m)
for m in methods
],
axis=1,
)
_ic = ic_df.iloc(axis=1)[0]

_index = _ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6)
_monthly_ic = _ic.groupby(_index).mean()
_monthly_ic.index = pd.MultiIndex.from_arrays(
[_monthly_ic.index.str.slice(0, 4), _monthly_ic.index.str.slice(4, 6)],
names=["year", "month"],
Expand All @@ -148,27 +168,27 @@ def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> t

_monthly_ic = _monthly_ic.reindex(fill_index)

_ic_df = ic.to_frame("ic")
ic_bar_figure = ic_figure(_ic_df, kwargs.get("show_nature_day", True))
ic_bar_figure = ic_figure(ic_df, kwargs.get("show_nature_day", False))

ic_heatmap_figure = HeatmapGraph(
_monthly_ic.unstack(),
layout=dict(title="Monthly IC", yaxis=dict(tickformat=",d")),
layout=dict(title="Monthly IC", xaxis=dict(dtick=1), yaxis=dict(tickformat="04d", dtick=1)),
graph_kwargs=dict(xtype="array", ytype="array"),
).figure

dist = stats.norm
_qqplot_fig = _plot_qq(ic, dist)
_qqplot_fig = _plot_qq(_ic, dist)

if isinstance(dist, stats.norm.__class__):
dist_name = "Normal"
else:
dist_name = "Unknown"

_ic_df = _ic.to_frame("IC")
_bin_size = ((_ic_df.max() - _ic_df.min()) / 20).min()
_sub_graph_data = [
(
"ic",
"IC",
dict(
row=1,
col=1,
Expand Down Expand Up @@ -202,12 +222,13 @@ def _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple:
pred = pred_label.copy()
pred["score_last"] = pred.groupby(level="instrument")["score"].shift(lag)
ac = pred.groupby(level="datetime").apply(lambda x: x["score"].rank(pct=True).corr(x["score_last"].rank(pct=True)))
# FIXME: support HIGH-FREQ
_df = ac.to_frame("value")
_df.index = _df.index.strftime("%Y-%m-%d")
ac_figure = ScatterGraph(
_df,
layout=dict(title="Auto Correlation", xaxis=dict(type="category", tickangle=45)),
layout=dict(
title="Auto Correlation",
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(_df.index))),
),
).figure
return (ac_figure,)

Expand All @@ -233,32 +254,33 @@ def _pred_turnover(pred_label: pd.DataFrame, N=5, lag=1, **kwargs) -> tuple:
"Bottom": bottom,
}
)
# FIXME: support HIGH-FREQ
r_df.index = r_df.index.strftime("%Y-%m-%d")
turnover_figure = ScatterGraph(
r_df,
layout=dict(title="Top-Bottom Turnover", xaxis=dict(type="category", tickangle=45)),
layout=dict(
title="Top-Bottom Turnover",
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(r_df.index))),
),
).figure
return (turnover_figure,)


def ic_figure(ic_df: pd.DataFrame, show_nature_day=True, **kwargs) -> go.Figure:
"""IC figure
r"""IC figure

:param ic_df: ic DataFrame
:param show_nature_day: whether to display the abscissa of non-trading day
:param \*\*kwargs: contains some parameters to control plot style in plotly. Currently, supports
- `rangebreaks`: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays
:return: plotly.graph_objs.Figure
"""
if show_nature_day:
date_index = pd.date_range(ic_df.index.min(), ic_df.index.max())
ic_df = ic_df.reindex(date_index)
# FIXME: support HIGH-FREQ
ic_df.index = ic_df.index.strftime("%Y-%m-%d")
ic_bar_figure = BarGraph(
ic_df,
layout=dict(
title="Information Coefficient (IC)",
xaxis=dict(type="category", tickangle=45),
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(ic_df.index))),
),
).figure
return ic_bar_figure
Expand All @@ -272,9 +294,10 @@ def model_performance_graph(
rank=False,
graph_names: list = ["group_return", "pred_ic", "pred_autocorr"],
show_notebook: bool = True,
show_nature_day=True,
show_nature_day: bool = False,
**kwargs,
) -> [list, tuple]:
"""Model performance
r"""Model performance

:param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**.
It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1").
Expand All @@ -297,17 +320,14 @@ def model_performance_graph(
:param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover'].
:param show_notebook: whether to display graphics in notebook, the default is `True`.
:param show_nature_day: whether to display the abscissa of non-trading day.
:param \*\*kwargs: contains some parameters to control plot style in plotly. Currently, supports
- `rangebreaks`: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays
:return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list.
"""
figure_list = []
for graph_name in graph_names:
fun_res = eval(f"_{graph_name}")(
pred_label=pred_label,
lag=lag,
N=N,
reverse=reverse,
rank=rank,
show_nature_day=show_nature_day,
pred_label=pred_label, lag=lag, N=N, reverse=reverse, rank=rank, show_nature_day=show_nature_day, **kwargs
)
figure_list += fun_res

Expand Down
2 changes: 1 addition & 1 deletion qlib/contrib/report/analysis_position/risk_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _get_risk_analysis_figure(analysis_df: pd.DataFrame) -> Iterable[py.Figure]:
_figure = SubplotsGraph(
_get_all_risk_analysis(analysis_df),
kind_map=dict(kind="BarGraph", kwargs={}),
subplots_kwargs={"rows": 4, "cols": 1},
subplots_kwargs={"rows": 1, "cols": 4},
).figure
return (_figure,)

Expand Down
11 changes: 7 additions & 4 deletions qlib/contrib/report/analysis_position/score_ic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd

from ..graph import ScatterGraph
from ..utils import guess_plotly_rangebreaks


def _get_score_ic(pred_label: pd.DataFrame):
Expand All @@ -19,7 +20,7 @@ def _get_score_ic(pred_label: pd.DataFrame):
return pd.DataFrame({"ic": _ic, "rank_ic": _rank_ic})


def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [list, tuple]:
def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True, **kwargs) -> [list, tuple]:
"""score IC

Example:
Expand Down Expand Up @@ -53,11 +54,13 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [lis
:return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list.
"""
_ic_df = _get_score_ic(pred_label)
# FIXME: support HIGH-FREQ
_ic_df.index = _ic_df.index.strftime("%Y-%m-%d")

_figure = ScatterGraph(
_ic_df,
layout=dict(title="Score IC", xaxis=dict(type="category", tickangle=45)),
layout=dict(
title="Score IC",
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(_ic_df.index))),
),
graph_kwargs={"mode": "lines+markers"},
).figure
if show_notebook:
Expand Down
29 changes: 29 additions & 0 deletions qlib/contrib/report/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import matplotlib.pyplot as plt
import pandas as pd


def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
Expand Down Expand Up @@ -43,3 +44,31 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None
res = res.item()
yield res
plt.show()


def guess_plotly_rangebreaks(dt_index: pd.DatetimeIndex):
"""
This function `guesses` the rangebreaks required to remove gaps in datetime index.
It basically calculates the difference between a `continuous` datetime index and index given.

For more details on `rangebreaks` params in plotly, see
https://plotly.com/python/reference/layout/xaxis/#layout-xaxis-rangebreaks

Parameters
----------
dt_index: pd.DatetimeIndex
The datetimes of the data.

Returns
-------
the `rangebreaks` to be passed into plotly axis.

"""
dt_idx = dt_index.sort_values()
gaps = dt_idx[1:] - dt_idx[:-1]
min_gap = gaps.min()
gaps_to_break = {}
for gap, d in zip(gaps, dt_idx[:-1]):
if gap > min_gap:
gaps_to_break.setdefault(gap - min_gap, []).append(d + min_gap)
return [dict(values=v, dvalue=int(k.total_seconds() * 1000)) for k, v in gaps_to_break.items()]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_version(rel_path: str) -> str:
# What packages are required for this module to be executed?
# `estimator` may depend on other packages. In order to reduce dependencies, it is not written here.
REQUIRED = [
"numpy>=1.12.0",
"numpy>=1.12.0, <1.24",
"pandas>=0.25.1",
"scipy>=1.0.0",
"requests>=2.18.0",
Expand Down