Skip to content

Commit

Permalink
[Data] Deprecate BatchPredictor (#36947)
Browse files Browse the repository at this point in the history
`BatchPredictor` has been superseded by `Dataset.map_batches`. To learn more, see ray-project/enhancements#25.
---------

Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
  • Loading branch information
bveeramani authored Jun 30, 2023
1 parent 7275fad commit 811907d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
10 changes: 8 additions & 2 deletions python/ray/train/batch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, Optional, List, Type, Union, Callable
import pandas as pd
import numpy as np
import warnings

import ray
from ray.air import Checkpoint
Expand All @@ -11,12 +12,12 @@
from ray.data import Dataset, DatasetPipeline, Preprocessor
from ray.data.context import DataContext
from ray.train.predictor import Predictor
from ray.util.annotations import PublicAPI
from ray.util.annotations import Deprecated

logger = logging.getLogger(__name__)


@PublicAPI(stability="beta")
@Deprecated
class BatchPredictor:
"""Batch predictor class.
Expand All @@ -30,6 +31,11 @@ class BatchPredictor:
def __init__(
self, checkpoint: Checkpoint, predictor_cls: Type[Predictor], **predictor_kwargs
):
warnings.warn(
"`BatchPredictor` is deprecated. Use `Dataset.map_batches` instead. To "
"learn more, see http://batchinference.io.",
DeprecationWarning,
)
self._checkpoint = checkpoint
# Store as object ref so we only serialize it once for all map workers
self._checkpoint_ref = ray.put(checkpoint)
Expand Down
8 changes: 8 additions & 0 deletions python/ray/train/tests/test_batch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def _transform_numpy(self, np_data):
return np_data * self.multiplier


def test_batch_predictor_warns_deprecation(shutdown_only):
with pytest.warns(DeprecationWarning):
BatchPredictor.from_checkpoint(
Checkpoint.from_dict({"factor": 0}),
DummyPredictorFS,
)


def test_repr(shutdown_only):
predictor = BatchPredictor.from_checkpoint(
Checkpoint.from_dict({"factor": 2.0}),
Expand Down

0 comments on commit 811907d

Please sign in to comment.