Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Clarify to_torch "features" and "label" parameter behaviour when return type is not "dataset" #16218

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,15 +1669,16 @@ def to_torch(
Set return type; a 2D PyTorch tensor, PolarsDataset (a frame-specialized
TensorDataset), or dict of Tensors.
label
One or more column names or expressions that label the feature data; when
`return_type` is "dataset", the PolarsDataset returns `(features, label)`
tensor tuples for each row. Otherwise, it returns `(features,)` tensor
tuples where the feature contains all the row data. This parameter is a
no-op for the other return-types.
One or more column names, expressions, or selectors that label the feature
data; when `return_type` is "dataset", the PolarsDataset will return
`(features, label)` tensor tuples for each row. Otherwise, it returns
`(features,)` tensor tuples where the feature contains all the row data;
note that setting this parameter with any other result type will raise an
informative error.
features
One or more column names or expressions that contain the feature data; if
omitted, all columns that are not designated as part of the label are used.
This parameter is a no-op for return-types other than "dataset".
One or more column names, expressions, or selectors that contain the feature
data; if omitted, all columns that are not designated as part of the label
are used. This parameter is a no-op for return-types other than "dataset".
dtype
Unify the dtype of all returned tensors; this casts any frame Series
that are not of the required dtype before converting to tensor. This
Expand Down Expand Up @@ -1770,6 +1771,10 @@ def to_torch(
... batch_size=64,
... ) # doctest: +SKIP
"""
if return_type != "dataset" and (label is not None or features is not None):
msg = "the `label` and `features` parameters can only be set when `return_type='dataset'`"
raise ValueError(msg)

torch = import_optional("torch")

if dtype in (UInt16, UInt32, UInt64):
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/dataframe/test_to_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,15 @@ def test_misc_errors(self, df: pl.DataFrame) -> None:
match="tensors used as indices must be long, int",
):
_res2 = ds[torch.tensor([0, 3], dtype=torch.complex64)]

with pytest.raises(
ValueError,
match="`label` and `features` parameters .* when `return_type='dataset'`",
):
_res3 = df.to_torch(label="stroopwafel")

with pytest.raises(
ValueError,
match="`label` and `features` parameters .* when `return_type='dataset'`",
):
_res4 = df.to_torch("dict", features=cs.float())