Skip to content

Commit 0258e47

Browse files
authored
feat: customize to tensor function in pytorch dataset (#1890)
Users can define their own tensor SerDe (serialization/deserialization) function.
1 parent 085b4d9 commit 0258e47

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

python/python/lance/torch/data.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ def __init__(
150150
rank: Optional[int] = None,
151151
world_size: Optional[int] = None,
152152
shard_granularity: Optional[Literal["fragment", "batch"]] = "fragment",
153+
to_tensor_fn: Optional[
154+
callable[[pa.RecordBatch], Union[dict[str, torch.Tensor], torch.Tensor]]
155+
] = None,
153156
**kwargs,
154157
):
155158
super().__init__(*args, **kwargs)
@@ -159,6 +162,9 @@ def __init__(
159162
self.samples: Optional[int] = samples
160163
self.filter = filter
161164
self.with_row_id = with_row_id
165+
if to_tensor_fn is None:
166+
to_tensor_fn = _to_tensor
167+
self._to_tensor_fn = to_tensor_fn
162168

163169
# As Shared Dataset
164170
self.rank = rank
@@ -217,5 +223,7 @@ def __iter__(self):
217223
stream = self.cached_ds
218224

219225
for batch in stream:
220-
yield _to_tensor(batch)
226+
if self._to_tensor_fn is not None:
227+
batch = self._to_tensor_fn(batch)
228+
yield batch
221229
del batch

0 commit comments

Comments
 (0)