Adding gradient checkpointing to GPT2 #7446

Merged: 6 commits, Sep 29, 2020
4 changes: 4 additions & 0 deletions src/transformers/configuration_gpt2.py
@@ -103,6 +103,8 @@ class GPT2Config(PretrainedConfig):
            :class:`~transformers.GPT2DoubleHeadsModel` and :class:`~transformers.TFGPT2DoubleHeadsModel`.

            The dropout ratio to be used after the projection and activation.
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.

    Example::

@@ -142,6 +144,7 @@ def __init__(
        summary_first_dropout=0.1,
        bos_token_id=50256,
        eos_token_id=50256,
        gradient_checkpointing=False,
        **kwargs
    ):
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -164,6 +167,7 @@
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels
        self.gradient_checkpointing = gradient_checkpointing

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
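Usage note (not part of the diff): with the new config field, enabling checkpointing is a single constructor argument. A minimal sketch, assuming a transformers version that includes this PR:

from transformers import GPT2Config, GPT2LMHeadModel

# Build a GPT-2 model that recomputes block activations during the backward
# pass instead of storing them.
config = GPT2Config(gradient_checkpointing=True)
model = GPT2LMHeadModel(config)

# The flag can also be forwarded through from_pretrained, as the updated
# generation test further down does:
# model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=True)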
40 changes: 29 additions & 11 deletions src/transformers/modeling_gpt2.py
@@ -15,7 +15,6 @@
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""


import os
import warnings
from dataclasses import dataclass
@@ -624,16 +623,35 @@ def forward(
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            if getattr(self.config, "gradient_checkpointing", False):
Review thread on this line:

Contributor: I think `if self.config.gradient_checkpointing:` is nicer.

Contributor (author): Most model configs don't actually have this attribute, only the ones that support checkpointing (AFAIK, Bert and Longformer for now), so it's less risky to do things this way.

Contributor: This is in modeling_gpt2.py, which only works with configuration_gpt2.py. So if you add gradient_checkpointing to the config with default = False, I don't see why this would be risky.
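(An aside on the thread above: the two spellings differ only when a config object lacks the attribute, as the hypothetical classes below illustrate; these are not transformers classes, and the diff resumes right after.)

class LegacyConfig:
    pass  # stands in for a config that predates the gradient_checkpointing flag


class CurrentConfig:
    gradient_checkpointing = True


getattr(LegacyConfig(), "gradient_checkpointing", False)   # -> False, no error
getattr(CurrentConfig(), "gradient_checkpointing", False)  # -> True
# LegacyConfig().gradient_checkpointing would raise AttributeError instead.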


                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # checkpointing only works with tuple returns, not with lists
                        return tuple(output for output in module(*inputs, use_cache, output_attentions))

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    layer_past,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states, present = outputs[:2]
            if use_cache is True:
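The create_custom_forward indirection above serves two purposes: it converts the block's output into the plain tuple that torch.utils.checkpoint.checkpoint expects (hence the in-code comment about tuples versus lists), and it closes over the non-tensor flags use_cache and output_attentions so that only tensors are passed through checkpoint(). A stripped-down sketch of the same pattern outside GPT-2 follows; ToyBlock and the tensor shapes are invented for illustration.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    # Stand-in for a transformer block: returns a tuple and takes a non-tensor flag.
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, hidden_states, use_cache):
        hidden_states = torch.tanh(self.ff(hidden_states))
        present = hidden_states  # pretend this is the cached state
        return hidden_states, present


block = ToyBlock(64)
hidden_states = torch.randn(2, 8, 64, requires_grad=True)
use_cache = True  # non-tensor argument: closed over, not passed to checkpoint()


def create_custom_forward(module):
    def custom_forward(*inputs):
        # checkpoint() hands back only the positional tensors; everything else
        # comes from the enclosing scope, and the result must be a plain tuple.
        return tuple(output for output in module(*inputs, use_cache))

    return custom_forward


# Activations inside the block are not stored; they are recomputed on backward.
outputs = checkpoint(create_custom_forward(block), hidden_states)
outputs[0].sum().backward()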
4 changes: 3 additions & 1 deletion src/transformers/trainer.py
@@ -661,8 +661,10 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
                find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
            )
            # find_unused_parameters breaks checkpointing as per
            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021

            if self.tb_writer is not None:
                self.tb_writer.add_text("args", self.args.to_json_string())
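Taken in isolation, the Trainer change amounts to the helper sketched below (wrap_for_ddp is a hypothetical name, not a transformers API; in the real code the model and rank come from self.model and self.args.local_rank):

from torch.nn.parallel import DistributedDataParallel


def wrap_for_ddp(model, local_rank):
    # find_unused_parameters=True is reported to break gradient checkpointing
    # (see the issue comment linked above), so it is only left on when the
    # model's config does not request checkpointing.
    return DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
    )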
71 changes: 43 additions & 28 deletions tests/test_modeling_gpt2.py
@@ -88,7 +88,7 @@ def __init__(
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 1

    def prepare_config_and_inputs(self):
    def prepare_config_and_inputs(self, gradient_checkpointing=False):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
@@ -127,6 +127,7 @@ def prepare_config_and_inputs(self):
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            return_dict=True,
            gradient_checkpointing=gradient_checkpointing,
        )

        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -269,6 +270,15 @@ def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mas
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = GPT2LMHeadModel(config)
        model.to(torch_device)

        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        result.loss.backward()

    def create_and_check_double_lm_head_model(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
    ):
@@ -355,6 +365,10 @@ def test_gpt2_double_lm_head_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)

    def test_gpt2_gradient_checkpointing(self):
Review comment on this line:

Contributor: Awesome that you added a test!

        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -366,33 +380,34 @@ def test_model_from_pretrained(self):
class GPT2ModelLanguageGenerationTest(unittest.TestCase):
    @slow
    def test_lm_generate_gpt2(self):
        model = GPT2LMHeadModel.from_pretrained("gpt2")
        model.to(torch_device)
        input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
        expected_output_ids = [
            464,
            3290,
            373,
            1043,
            287,
            257,
            2214,
            1474,
            262,
            16246,
            286,
            2688,
            290,
            2688,
            27262,
            13,
            198,
            198,
            464,
            3290,
        ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
        output_ids = model.generate(input_ids, do_sample=False)
        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
        for checkpointing in [True, False]:
            model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
            model.to(torch_device)
            input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
            expected_output_ids = [
                464,
                3290,
                373,
                1043,
                287,
                257,
                2214,
                1474,
                262,
                16246,
                286,
                2688,
                290,
                2688,
                27262,
                13,
                198,
                198,
                464,
                3290,
            ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
            output_ids = model.generate(input_ids, do_sample=False)
            self.assertListEqual(output_ids[0].tolist(), expected_output_ids)

    @slow
    def test_lm_generate_distilgpt2(self):
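Beyond the unit tests, a quick way to observe the trade-off described in the config docstring is to compare peak GPU memory for a forward/backward pass with and without checkpointing. A rough sketch, assuming a CUDA device, the public gpt2 checkpoint, and a transformers version that contains this PR (exact numbers depend on hardware and batch shape):

import torch
from transformers import GPT2LMHeadModel


def peak_memory_gib(checkpointing):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing).cuda()
    input_ids = torch.randint(0, model.config.vocab_size, (4, 512), device="cuda")
    loss = model(input_ids, labels=input_ids)[0]
    loss.backward()  # with checkpointing enabled, block activations are recomputed here
    return torch.cuda.max_memory_allocated() / 2 ** 30


for checkpointing in (False, True):
    print(f"gradient_checkpointing={checkpointing}: peak {peak_memory_gib(checkpointing):.2f} GiB")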