Skip to content

Commit

Permalink
New design of the transformer API (#1022)
Browse files Browse the repository at this point in the history
* implement new design of the Transformer API on top of the release-23.02 branch

* add support of ragged tensor to weight tying

* update example notebook with the new API

* include PR comments

* fix masking of sequence-predict-next transform

* adjust sample_weights to targets shape

* add masking support to SequencePredictRandom transform

* rebase with main branch to include data loader changes

* fix linting

* Fix the adjust-predictions logic to support targets as 2-D scalars

* Fix transformer example notebook

* update import of transformer blocks in transforms/sequence and move them inside configure_for_train() function.
  • Loading branch information
sararb authored Mar 21, 2023
1 parent 13e7fcb commit 62e0591
Show file tree
Hide file tree
Showing 9 changed files with 876 additions and 179 deletions.
42 changes: 29 additions & 13 deletions examples/usecases/transformers-next-item-prediction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,25 @@
"seq_schema"
]
},
{
"cell_type": "markdown",
"id": "0a87439c",
"metadata": {},
"source": [
"Align the schema of train and validation datasets with the model's schema"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b90424a",
"metadata": {},
"outputs": [],
"source": [
"train_set_processed.schema = seq_schema\n",
"validation_set_processed.schema = seq_schema"
]
},
{
"cell_type": "markdown",
"id": "8d422833",
Expand Down Expand Up @@ -724,20 +743,20 @@
"id": "0a460e4c",
"metadata": {},
"source": [
"For the transformer portion of our model, we will use the `XLNet` architecture. Additionally, we are passing `mm.ReplaceMaskedEmbeddings()` as our `pre` block. We will be training a masked language model and this parameter is responsible for the masking of our sequences."
"For the transformer portion of our model, we will use the `XLNet` architecture."
]
},
{
"cell_type": "markdown",
"id": "23bf02dc",
"metadata": {},
"source": [
"Later, when we run the `fit` method on our model, we will specify the `masking_probability` of `0.3`. Through the combination of these parameters, our model will train on sequences where any given timestep will be masked with a probability of 0.3 and it will be our model's training task to infer the target value for that step!\n",
"Later, when we run the `fit` method on our model, we will specify the `masking_probability` of `0.3` and link it to the transformer block defined in out model. Through the combination of these parameters, our model will train on sequences where any given timestep will be masked with a probability of 0.3 and it will be our model's training task to infer the target value for that step!\n",
"\n",
"To summarize, Masked Language Modeling is implemented by using two blocks in combination:\n",
"To summarize, Masked Language Modeling is implemented by:\n",
"\n",
"* `SequenceMaskRandom()` - Used as a pre for model.fit(), it randomly selects items from the sequence to be masked for prediction as targets, by using Keras masking.\n",
"* `ReplaceMaskedEmbeddings()` - Used as a pre for a `TransformerBlock`, it replaces the input embeddings at masked positions for prediction by a dummy trainable embedding, to avoid leakage of the targets.\n",
"* `SequenceMaskRandom()` - Used as a pre for model.fit(), it randomly selects items from the sequence to be masked for prediction as targets, by using Keras masking. This block also adds the necessary configuration to the specified `transformer` block so as it\n",
"is pre-configured with the necessary layers needed to prepare the inputs to the HuggingFace transformer layer and to post-process its outputs. For example, one pre-processing operation is to replace the input embeddings at masked positions for prediction by a dummy trainable embedding, to avoid leakage of the targets.\n",
"\n",
"\n",
"**Read more about the apis used to construct models** \n",
Expand All @@ -746,7 +765,6 @@
"- [InputBlockV2](https://github.com/NVIDIA-Merlin/models/blob/main/merlin/models/tf/inputs/base.py)\n",
"- [Embeddings](https://github.com/NVIDIA-Merlin/models/blob/main/merlin/models/tf/inputs/embedding.py)\n",
"- [XLNetBlock](https://github.com/NVIDIA-Merlin/models/blob/main/merlin/models/tf/transformers/block.py)\n",
"- [ReplaceMaskedEmbeddings](https://github.com/NVIDIA-Merlin/models/blob/main/merlin/models/tf/transforms/sequence.py)\n",
"- [CategoricalOutput](https://github.com/NVIDIA-Merlin/models/blob/main/merlin/models/tf/outputs/classification.py)\n",
"- [.schema.select_by_name](https://github.com/NVIDIA-Merlin/core/blob/main/merlin/schema/schema.py)\n",
"- [.schema.select_by_tag](https://github.com/NVIDIA-Merlin/core/blob/main/merlin/schema/schema.py)\n",
Expand All @@ -770,6 +788,7 @@
" activation='relu',\n",
" no_activation_last_layer=True,\n",
" )\n",
"transformer_block = mm.XLNetBlock(d_model=dmodel, n_head=4, n_layer=2)\n",
"model = mm.Model(\n",
" mm.InputBlockV2(\n",
" seq_schema,\n",
Expand All @@ -778,10 +797,7 @@
" ),\n",
" ),\n",
" mlp_block,\n",
" mm.XLNetBlock(d_model=dmodel, n_head=4, n_layer=2, \n",
" pre=mm.ReplaceMaskedEmbeddings(),\n",
" post=\"inference_hidden_state\",\n",
" ),\n",
" transformer_block,\n",
" mm.CategoricalOutput(\n",
" train_set_processed.schema.select_by_name(target),\n",
" default_loss=\"categorical_crossentropy\",\n",
Expand Down Expand Up @@ -891,7 +907,7 @@
],
"source": [
"model.compile(run_eagerly=False, optimizer='adam', loss=\"categorical_crossentropy\")\n",
"model.fit(train_set_processed, batch_size=64, epochs=5, pre=mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3))"
"model.fit(train_set_processed, batch_size=64, epochs=5, pre=mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3, transformer=transformer_block))"
]
},
{
Expand Down Expand Up @@ -960,7 +976,7 @@
"model.evaluate(\n",
" validation_set_processed,\n",
" batch_size=128,\n",
" pre=mm.SequenceMaskLast(schema=seq_schema, target=target),\n",
" pre=mm.SequenceMaskLast(schema=validation_set_processed.schema, target=target, transformer=transformer_block),\n",
" return_dict=True\n",
")"
]
Expand Down Expand Up @@ -1000,7 +1016,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.2"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 62e0591

Please sign in to comment.