
Commit 35d9da6

fea(): Added the roberta files
1 parent 417cbee
4 files changed (+174 −0 lines)

optimum/habana/transformers/modeling_utils.py (+4)

@@ -153,6 +153,7 @@
     GaudiQwen2VLModel,
     GaudiQwen2VLSdpaAttention,
     GaudiQwen2VLVisionBlock,
+    GaudiRobertaForCausalLM,
     GaudiStableLmAttention,
     GaudiStableLmDecoderLayer,
     GaudiStableLmForCausalLM,
@@ -387,6 +388,9 @@ def adapt_transformers_to_gaudi():
         gaudi_BartForConditionalGeneration_prepare_inputs_for_generation
     )

+    # Overwrite Roberta fwd
+    transformers.models.roberta.modeling_roberta.RobertaForCausalLM = GaudiRobertaForCausalLM
+
     # Optimization for BERT on Gaudi
     transformers.models.bert.modeling_bert.BertModel.forward = gaudi_BertModel_forward
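For context, a minimal sketch of what this hunk changes at runtime (assuming an environment where optimum-habana and its Gaudi dependencies are installed; none of the code below is part of the commit):

# Sketch only: inspect the module-level attribute that the hunk above overwrites.
import transformers

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from optimum.habana.transformers.models import GaudiRobertaForCausalLM

adapt_transformers_to_gaudi()

# After the call, the stock class in transformers.models.roberta.modeling_roberta has been
# swapped for the Gaudi subclass, so code resolving RobertaForCausalLM through that module
# picks up the override added in this commit.
assert transformers.models.roberta.modeling_roberta.RobertaForCausalLM is GaudiRobertaForCausalLM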

optimum/habana/transformers/models/__init__.py (+1)

@@ -262,6 +262,7 @@
     GaudiQwen2VLVisionBlock,
     GaudiVisionSdpaAttention,
 )
+from .roberta import GaudiRobertaForCausalLM
 from .seamless_m4t import (
     gaudi_SeamlessM4TAttention_forward,
     gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths,
optimum/habana/transformers/models/roberta/__init__.py (new file, +1)

@@ -0,0 +1 @@
+from .modeling_roberta import GaudiRobertaForCausalLM
optimum/habana/transformers/models/roberta/modeling_roberta.py (new file, +168)

@@ -0,0 +1,168 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import (
+    CausalLMOutputWithCrossAttentions,
+)
+from transformers.models.llama.modeling_llama import (
+    KwargsForCausalLM,
+)
+from transformers.models.roberta.modeling_roberta import (
+    ROBERTA_INPUTS_DOCSTRING,
+    RobertaLMHead,
+    RobertaModel,
+    RobertaPreTrainedModel,
+)
+from transformers.processing_utils import Unpack
+from transformers.utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+
+from ...generation.utils import GaudiGenerationMixin
+
+
+logger = logging.get_logger(__name__)
+_CONFIG_FOR_DOC = "RobertaConfig"
+
+
+class GaudiRobertaForCausalLM(RobertaPreTrainedModel, GaudiGenerationMixin):
+    """
+    Updated from: https://github.com/huggingface/transformers/blob/v4.48.2/src/transformers/models/roberta/modeling_roberta.py with passing **kwargs to forward function
+    """
+
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if not config.is_decoder:
+            logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
+
+        self.roberta = RobertaModel(config, add_pooling_layer=False)
+        self.lm_head = RobertaLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
+    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Unpack[KwargsForCausalLM],
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+        r"""
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
+            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, RobertaForCausalLM, AutoConfig
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
+        >>> config = AutoConfig.from_pretrained("FacebookAI/roberta-base")
+        >>> config.is_decoder = True
+        >>> model = RobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", config=config)
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> prediction_logits = outputs.logits
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.lm_head(sequence_output)
+
+        lm_loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(prediction_scores.device)
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
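A quick way to exercise the loss path above, shown as a hedged sketch (the checkpoint name and inputs are illustrative and mirror the class docstring; an optimum-habana environment is assumed; none of this is part of the commit):

# Sketch only: feed labels so forward() computes the shifted next-token loss.
from transformers import AutoConfig, AutoTokenizer

from optimum.habana.transformers.models import GaudiRobertaForCausalLM

config = AutoConfig.from_pretrained("FacebookAI/roberta-base")
config.is_decoder = True  # avoids the standalone-decoder warning emitted in __init__

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
model = GaudiRobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", config=config)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# Passing labels triggers the CrossEntropyLoss over the shifted prediction scores.
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss, outputs.logits.shape)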
