Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Revert "[train] TransformersPredictor: Add support for custom pipeline class"" #36705

Merged
merged 3 commits into from
Jun 23, 2023

Conversation

krfricke
Copy link
Contributor

@krfricke krfricke commented Jun 22, 2023

Reverts #36701

This re-activates the changes in #36494 which were generally working. The problem was that an import of TFPreTrainedModel on a GPU instance seems to initialize the GPU and make it unusable by Ray workers, so that CUDA memory allocations fail.

Thus, imports of TF modules should be guarded behind the TYPE_CHECKING variable:

if TYPE_CHECKING:
    # ...
    from transformers.modeling_utils import PreTrainedModel
    from transformers.modeling_tf_utils import TFPreTrainedModel

@krfricke krfricke requested a review from Yard1 June 23, 2023 10:51
@krfricke krfricke added the tests-ok The tagger certifies test failures are unrelated and assumes personal liability. label Jun 23, 2023
@krfricke krfricke merged commit 0d9cb92 into master Jun 23, 2023
@krfricke krfricke deleted the revert-36701-revert-36494-train/hf-predictor branch June 23, 2023 19:17
arvind-chandra pushed a commit to lmco/ray that referenced this pull request Aug 31, 2023
… pipeline class"" (ray-project#36705)

Reverts ray-project#36701

This re-activates the changes in ray-project#36494 which were generally working. The problem was that an import of `TFPreTrainedModel` on a GPU instance seems to initialize the GPU and make it unusable by Ray workers, so that CUDA memory allocations fail.

Thus, imports of TF modules should be guarded behind the TYPE_CHECKING variable:

```
if TYPE_CHECKING:
    # ...
    from transformers.modeling_utils import PreTrainedModel
    from transformers.modeling_tf_utils import TFPreTrainedModel
```

Signed-off-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tests-ok The tagger certifies test failures are unrelated and assumes personal liability.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants