Skip to content

Commit

Permalink
pre-compute inputs in Dask test functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Mar 17, 2024
1 parent 09b3eb7 commit f1aa403
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,17 @@ def _create_data(objective, n_samples=1_000, output="array", chunk_size=500, **k


def _r2_score(dy_true, dy_pred):
numerator = ((dy_true - dy_pred) ** 2).sum(axis=0, dtype=np.float64)
denominator = ((dy_true - dy_true.mean(axis=0)) ** 2).sum(axis=0, dtype=np.float64)
return (1 - numerator / denominator).compute()
y_true = dy_true.compute()
y_pred = dy_pred.compute()
numerator = ((y_true - y_pred) ** 2).sum(axis=0)
denominator = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
return 1 - numerator / denominator


def _accuracy_score(dy_true, dy_pred):
return da.average(dy_true == dy_pred).compute()
y_true = dy_true.compute()
y_pred = dy_pred.compute()
return (y_true == y_pred).mean()


def _constant_metric(y_true, y_pred):
Expand Down

0 comments on commit f1aa403

Please sign in to comment.