Skip to content

Commit

Permalink
models module upgrades
Browse files Browse the repository at this point in the history
Summary: PSS2 upgrades that can be shipped pre-upgrade for the models module and tests

Reviewed By: islijepcevic

Differential Revision: D67507978

fbshipit-source-id: 712a143f8613abb189b9d227849878c0a3085405
  • Loading branch information
proof-by-accident authored and facebook-github-bot committed Dec 20, 2024
1 parent 10dd208 commit 35265cb
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion kats/models/bayesian_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def predict(
[Y_curr, look_ahead_pred[:, np.newaxis]], axis=1
)

times += ahead_times
times += list(ahead_times)

forecast_length = len(times)

Expand Down
3 changes: 2 additions & 1 deletion kats/models/globalmodel/backtester.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,8 @@ def _evaluate(
]
ans.extend(tmp_ans)
ensemble_fcst = np.median(
np.column_stack(fcst_all[i][k][j] for i in range(n)), axis=1
np.column_stack(fcst_all[i][k][j] for i in range(n)),
axis=1,
)
evl = eval_func(ensemble_fcst, tmp_actuals)
evl["step"] = j
Expand Down
2 changes: 1 addition & 1 deletion kats/models/metalearner/metalearner_hpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def __init__(
)
# pyre-fixme[4]: Attribute must be annotated.
self._dim_output_num = (
self._target_num.shape[1] if self.numerical_idx else 0
self._target_num.shape[1] if self._target_num is not None else 0
)
self._get_target_cat()
self._validate_data()
Expand Down

0 comments on commit 35265cb

Please sign in to comment.