[AIR][DOC] Fix minor usage nits in the HuggingFace + AIR example #30637

Merged: 1 commit, Nov 28, 2022
26 changes: 13 additions & 13 deletions doc/source/ray-air/examples/huggingface_text_classification.ipynb
@@ -250,7 +250,7 @@
"id": "4RRkXuteIrIh"
},
"source": [
"This notebook is built to run on any of the tasks in the list above, with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a version with a classification head. Depending on you model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set those three parameters, then the rest of the notebook should run smoothly:"
"This notebook is built to run on any of the tasks in the list above, with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a version with a classification head. Depending on your model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set those three parameters, then the rest of the notebook should run smoothly:"
]
},
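For instance, the three parameters could be set as in the sketch below; the values are illustrative defaults, not mandated by the notebook:

```python
# Illustrative choices; any GLUE task and any Model Hub checkpoint with a
# classification head will work. Lower batch_size if you hit GPU OOM errors.
task = "cola"
model_checkpoint = "distilbert-base-uncased"
batch_size = 16
```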
{
@@ -313,7 +313,7 @@
"id": "RzfPtOMoIrIu"
},
"source": [
"The `dataset` object itself is [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for the training, validation and test set (with more keys for the mismatched validation and test set in the special case of `mnli`)."
"The `dataset` object itself is [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for the training, validation, and test set (with more keys for the mismatched validation and test set in the special case of `mnli`)."
]
},
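As a sketch, loading a GLUE task with 🤗 Datasets shows that structure (the task name here is illustrative):

```python
from datasets import load_dataset

dataset = load_dataset("glue", "cola")
print(dataset)
# DatasetDict({
#     train: Dataset({features: [...], num_rows: ...})
#     validation: Dataset({...})
#     test: Dataset({...})
# })
```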
{
@@ -363,9 +363,9 @@
"id": "YVx71GdAIrJH"
},
"source": [
"Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer` which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put it in a format the model expects, as well as generate the other inputs that model requires.\n",
"Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers' `Tokenizer`, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put it in a format the model expects, as well as generate the other inputs that model requires.\n",
"\n",
"To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:\n",
"To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure that:\n",
"\n",
"- we get a tokenizer that corresponds to the model architecture we want to use,\n",
"- we download the vocabulary used when pretraining this specific checkpoint."
@@ -469,9 +469,9 @@
"id": "2C0hcmp9IrJQ"
},
"source": [
"We can them write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This will ensure that an input longer that what the model selected can handle will be truncated to the maximum length accepted by the model.\n",
"We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This will ensure that an input longer than what the model selected can handle will be truncated to the maximum length accepted by the model.\n",
"\n",
"We use a `BatchMapper` to create a Ray AIR preprocessor that will map the function to the dataset in a distributed fashion. It will be ran during training and prediction."
"We use a `BatchMapper` to create a Ray AIR preprocessor that will map the function to the dataset in a distributed fashion. It will run during training and prediction."
]
},
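A minimal sketch of that preprocessing for a single-sentence task such as `cola` (the notebook's real function also handles sentence-pair tasks); `tokenizer` is the one created above:

```python
import pandas as pd
from ray.data.preprocessors import BatchMapper

def preprocess_function(batch: pd.DataFrame) -> pd.DataFrame:
    # truncation=True clips any input longer than the model's maximum length.
    encoded = tokenizer(list(batch["sentence"]), truncation=True)
    return pd.DataFrame.from_dict(dict(encoded))

# Pandas is assumed as the batch format here; depending on your Ray version
# you may need to pass batch_format="pandas" explicitly.
preprocessor = BatchMapper(preprocess_function)
```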
{
@@ -524,13 +524,13 @@
"\n",
"We will not go into details about each specific component of the training (see the [original notebook](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb) for that). The tokenizer is the same as we have used to encoded the dataset before.\n",
"\n",
"The main difference when using the Ray AIR is that we need to create our 🤗 Transformers `Trainer` inside a function (`trainer_init_per_worker`) and return it. That function will be passed to the `HuggingFaceTrainer` and ran on every Ray worker. The training will then proceed by the means of PyTorch DDP.\n",
"The main difference when using the Ray AIR is that we need to create our 🤗 Transformers `Trainer` inside a function (`trainer_init_per_worker`) and return it. That function will be passed to the `HuggingFaceTrainer` and will run on every Ray worker. The training will then proceed by the means of PyTorch DDP.\n",
"\n",
"Make sure that you initialize the model, metric and tokenizer inside that function. Otherwise, you may run into serialization errors.\n",
"Make sure that you initialize the model, metric, and tokenizer inside that function. Otherwise, you may run into serialization errors.\n",
"\n",
"Furthermore, `push_to_hub=True` is not yet supported. Ray will however checkpoint the model at every epoch, allowing you to push it to hub manually. We will do that after the training.\n",
"Furthermore, `push_to_hub=True` is not yet supported. Ray will, however, checkpoint the model at every epoch, allowing you to push it to hub manually. We will do that after the training.\n",
"\n",
"If you wish to use thrid party logging libraries, such as MLFlow or Weights&Biases, do not set them in `TrainingArguments` (they will be automatically disabled) - instead, you should be passing Ray AIR callbacks to `HuggingFaceTrainer`'s `run_config`. In this example, we will use MLFlow."
"If you wish to use thrid party logging libraries, such as MLflow or Weights&Biases, do not set them in `TrainingArguments` (they will be automatically disabled) - instead, you should pass Ray AIR callbacks to `HuggingFaceTrainer`'s `run_config`. In this example, we will use MLflow."
]
},
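A hedged sketch of such a function; the hyperparameters and `num_labels` below are illustrative, not the notebook's exact values:

```python
from transformers import (
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)

def trainer_init_per_worker(train_dataset, eval_dataset, **config):
    # Instantiate the model (and any metric/tokenizer you need) *inside*
    # this function to avoid serialization errors on the Ray workers.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_checkpoint, num_labels=2
    )
    args = TrainingArguments(
        "ray-air-text-classification",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        num_train_epochs=3,
        # Leave logging integrations and push_to_hub unset here; logging is
        # handled via Ray AIR callbacks, and push_to_hub is not yet supported.
    )
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
```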
{
@@ -600,7 +600,7 @@
"source": [
"With our `trainer_init_per_worker` complete, we can now instantiate the `HuggingFaceTrainer`. Aside from the function, we set the `scaling_config`, controlling the amount of workers and resources used, and the `datasets` we will use for training and evaluation.\n",
"\n",
"We specify the `MlflowLoggerCallback` inside the `run_config`, and pass the preprocessor we have defined earlier as an argument. It will be included with the returned `Checkpoint`, meaning it will also be applied during inference."
"We specify the `MlflowLoggerCallback` inside the `run_config`, and pass the preprocessor we have defined earlier as an argument. The preprocessor will be included with the returned `Checkpoint`, meaning it will also be applied during inference."
]
},
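Wiring that together might look like the sketch below. `ray_train_ds` and `ray_evaluation_ds` are assumed to be Ray Datasets created from the 🤗 dataset splits, the worker count and experiment name are illustrative, and the callback's import path varies slightly across Ray versions:

```python
from ray.air.config import RunConfig, ScalingConfig
from ray.air.callbacks.mlflow import MLflowLoggerCallback
from ray.train.huggingface import HuggingFaceTrainer

trainer = HuggingFaceTrainer(
    trainer_init_per_worker=trainer_init_per_worker,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    datasets={"train": ray_train_ds, "evaluation": ray_evaluation_ds},
    run_config=RunConfig(
        callbacks=[MLflowLoggerCallback(experiment_name="hf-text-classification")]
    ),
    # The preprocessor is saved with the Checkpoint and reused at inference.
    preprocessor=preprocessor,
)
```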
{
Expand Down Expand Up @@ -633,7 +633,7 @@
"id": "XvS136zKhYba"
},
"source": [
"Finally, we call the `fit` method to being training with Ray AIR. We will save the `Result` object to a variable so we can access metrics and checkpoints."
"Finally, we call the `fit` method to start training with Ray AIR. We will save the `Result` object to a variable so we can access metrics and checkpoints."
]
},
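In sketch form:

```python
result = trainer.fit()

# `result` is a Result object exposing training metrics and checkpoints.
print(result.metrics)
best_checkpoint = result.checkpoint
```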
{
@@ -2269,7 +2269,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.9"
"version": "3.9.6"
},
"vscode": {
"interpreter": {