diff --git a/python/ray/train/batch_predictor.py b/python/ray/train/batch_predictor.py
index 576d14699e5ff..284c2b026347e 100644
--- a/python/ray/train/batch_predictor.py
+++ b/python/ray/train/batch_predictor.py
@@ -3,6 +3,7 @@
 from typing import Any, Dict, Optional, List, Type, Union, Callable
 import pandas as pd
 import numpy as np
+import warnings
 
 import ray
 from ray.air import Checkpoint
@@ -11,12 +12,12 @@
 from ray.data import Dataset, DatasetPipeline, Preprocessor
 from ray.data.context import DataContext
 from ray.train.predictor import Predictor
-from ray.util.annotations import PublicAPI
+from ray.util.annotations import Deprecated
 
 logger = logging.getLogger(__name__)
 
 
-@PublicAPI(stability="beta")
+@Deprecated
 class BatchPredictor:
     """Batch predictor class.
 
@@ -30,6 +31,13 @@ class BatchPredictor:
     def __init__(
         self, checkpoint: Checkpoint, predictor_cls: Type[Predictor], **predictor_kwargs
     ):
+        warnings.warn(
+            "`BatchPredictor` is deprecated. Use `Dataset.map_batches` instead. To "
+            "learn more, see http://batchinference.io.",
+            DeprecationWarning,
+            # Attribute the warning to the code constructing the predictor.
+            stacklevel=2,
+        )
         self._checkpoint = checkpoint
         # Store as object ref so we only serialize it once for all map workers
         self._checkpoint_ref = ray.put(checkpoint)
diff --git a/python/ray/train/tests/test_batch_predictor.py b/python/ray/train/tests/test_batch_predictor.py
index 1e2f5ef73cf8b..569db17a6b971 100644
--- a/python/ray/train/tests/test_batch_predictor.py
+++ b/python/ray/train/tests/test_batch_predictor.py
@@ -37,6 +37,14 @@ def _transform_numpy(self, np_data):
         return np_data * self.multiplier
 
 
+def test_batch_predictor_warns_deprecation(shutdown_only):
+    with pytest.warns(DeprecationWarning):
+        BatchPredictor.from_checkpoint(
+            Checkpoint.from_dict({"factor": 0}),
+            DummyPredictorFS,
+        )
+
+
 def test_repr(shutdown_only):
     predictor = BatchPredictor.from_checkpoint(
         Checkpoint.from_dict({"factor": 2.0}),