Add BROS #23190

Merged
merged 120 commits from add-bros into main on Sep 14, 2023

Changes from 1 commit
3db764f
add Bros boilerplate
jinhopark8345 May 10, 2023
2dc368a
copy and pasted modeling_bros.py from official Bros repo
jinhopark8345 May 10, 2023
5603061
update copyright of bros files
jinhopark8345 May 10, 2023
dbc56b8
copy tokenization_bros.py from official repo and update import path
jinhopark8345 May 10, 2023
e2a2d9d
copy tokenization_bros_fast.py from official repo and update import path
jinhopark8345 May 10, 2023
90ce711
copy configuration_bros.py from official repo and update import path
jinhopark8345 May 10, 2023
6126da1
remove trailing period in copyright line
jinhopark8345 May 10, 2023
63139eb
copy and paste bros/__init__.py from official repo
jinhopark8345 May 10, 2023
596d1a7
save formatting
jinhopark8345 May 14, 2023
764e8df
remove unused pe_type argument - using only crel type
jinhopark8345 May 14, 2023
f35348f
resolve import issue
jinhopark8345 May 14, 2023
892dd2d
remove unused model classes
jinhopark8345 May 14, 2023
37c7d9f
remove unnecessary tests
jinhopark8345 May 18, 2023
d878de0
remove unused classes
jinhopark8345 May 18, 2023
772d20e
fix original code's bug - layer_module's argument order
jinhopark8345 May 18, 2023
6ef6ca7
clean up modeling auto
jinhopark8345 May 18, 2023
c338261
add bbox to prepare_config_and_inputs
jinhopark8345 May 18, 2023
7379457
set temporary value to hidden_size (32 is too low because of the
jinhopark8345 May 18, 2023
602e2d9
remove decoder test, update create_and_check* input arguments
jinhopark8345 May 18, 2023
79b886c
add missing variable to model tests
jinhopark8345 May 18, 2023
5f35f68
do make fixup
jinhopark8345 May 20, 2023
3eace5d
update bros.mdx
jinhopark8345 May 21, 2023
9f0e8ca
add boilerplate for no_head inference test
jinhopark8345 May 21, 2023
66ff6ce
update BROS_PRETRAINED_MODEL_ARCHIVE_LIST (add naver-clova-ocr prefix)
jinhopark8345 May 21, 2023
f3e9dab
add prepare_bros_batch_inputs function
jinhopark8345 May 21, 2023
7022d4c
update modeling_common to add bbox inputs in Bros Model Test
jinhopark8345 May 21, 2023
f9aab55
remove unnecessary model inference
jinhopark8345 May 22, 2023
41e1ad9
add test case
jinhopark8345 May 22, 2023
94cf5fc
add model_doc
jinhopark8345 May 23, 2023
d10e166
add test case for token_classification
jinhopark8345 May 24, 2023
2845c23
apply fixup
jinhopark8345 May 24, 2023
55d5d7b
update modeling code
jinhopark8345 Jul 23, 2023
e41ab5d
update BrosForTokenClassification loss calculation logic
jinhopark8345 Aug 1, 2023
4ef71fd
revert logits preprocessing logic to make sure logits have original s…
jinhopark8345 Aug 1, 2023
d735bd5
- update class name
jinhopark8345 Aug 8, 2023
5ce570e
- add BrosSpadeOutput
jinhopark8345 Aug 8, 2023
4933093
add boilerplate for no_head inference test
jinhopark8345 Aug 8, 2023
0d53a2d
add prepare_bros_batch_inputs function
jinhopark8345 May 21, 2023
7228d98
add test case
jinhopark8345 May 22, 2023
9e758d7
add test case for token_classification
jinhopark8345 May 24, 2023
7be8d1d
update modeling code
jinhopark8345 Jul 23, 2023
ca9f5e8
update BrosForTokenClassification loss calculation logic
jinhopark8345 Aug 1, 2023
a6e77d7
revert logits preprocessing logic to make sure logits have original s…
jinhopark8345 Aug 1, 2023
13639d7
Merge remote-tracking branch 'upstream/main' into add-bros
jinhopark8345 Aug 9, 2023
725e145
apply masking on the fly
jinhopark8345 Aug 9, 2023
f5113b3
add BrosSpadeForTokenLinking
jinhopark8345 Aug 10, 2023
a955d3c
update class name
jinhopark8345 Aug 13, 2023
0cb524f
separate the logits calculation logic and loss calculation logic
jinhopark8345 Aug 13, 2023
5939860
update logic for loss calculation so that logits shape doesn't change
jinhopark8345 Aug 14, 2023
179c4f9
update typo
jinhopark8345 Aug 14, 2023
24d55f9
update prepare_config_and_inputs
jinhopark8345 Aug 14, 2023
aa28567
update dummy node initialization
jinhopark8345 Aug 15, 2023
d1a120f
update last_hidden_states getting logic to consider when return_dict …
jinhopark8345 Aug 15, 2023
ed5efb3
update box first token mask param
jinhopark8345 Aug 15, 2023
2d7bcc7
bugfix: remove random attention mask generation
jinhopark8345 Aug 15, 2023
8379565
update keys to ignore on load missing
jinhopark8345 Aug 15, 2023
632fde5
run make style and quality
jinhopark8345 Aug 15, 2023
1f2a956
apply make style and quality of other codes
jinhopark8345 Aug 15, 2023
d83b042
update box_first_token_mask to bool type
jinhopark8345 Aug 16, 2023
794dbba
update index.md
jinhopark8345 Aug 16, 2023
863155f
apply make style and quality
jinhopark8345 Aug 16, 2023
d18a5e6
Merge remote-tracking branch 'upstream/main' into add-bros
jinhopark8345 Aug 16, 2023
6f7c3d3
apply make fix-copies
jinhopark8345 Aug 16, 2023
eb7ba73
pass check_repo
jinhopark8345 Aug 17, 2023
16f4830
update bros model doc
jinhopark8345 Aug 18, 2023
ecce552
docstring bugfix
jinhopark8345 Aug 18, 2023
d927015
add checkpoint for doc, tokenizer for doc
jinhopark8345 Aug 18, 2023
46ec931
Update README.md
jinhopark8345 Aug 18, 2023
4433162
Update docs/source/en/model_doc/bros.md
jinhopark8345 Aug 18, 2023
828c9b0
Update bros.md
jinhopark8345 Aug 18, 2023
41de331
Update src/transformers/__init__.py
jinhopark8345 Aug 18, 2023
fa52d90
Update docs/source/en/model_doc/bros.md
jinhopark8345 Aug 18, 2023
d219760
Apply suggestions from code review
jinhopark8345 Aug 18, 2023
3b64b10
apply suggestions from code review
jinhopark8345 Aug 19, 2023
6811e44
apply suggestions from code review
jinhopark8345 Aug 19, 2023
8922ffa
revert test_processor_markuplm.py
jinhopark8345 Aug 19, 2023
09c3f82
Update test_processor_markuplm.py
jinhopark8345 Aug 19, 2023
a4d2e91
apply suggestions from code review
jinhopark8345 Aug 19, 2023
6bce6e1
apply suggestions from code review
jinhopark8345 Aug 21, 2023
0b3e750
apply suggestions from code review
jinhopark8345 Aug 21, 2023
a10fbac
update BrosSpadeELForTokenClassification head name to entity linker
jinhopark8345 Aug 21, 2023
336a94c
add doc string for config params
jinhopark8345 Aug 21, 2023
9da2fa4
update class, var names to more explicit and apply suggestions from c…
jinhopark8345 Aug 21, 2023
e2e304f
remove unnecessary keys to ignore
jinhopark8345 Aug 21, 2023
f621427
update relation extractor to be initialized with config
jinhopark8345 Aug 21, 2023
8a7d54c
add bros processor
jinhopark8345 Aug 21, 2023
fb7a991
apply make style and quality
jinhopark8345 Aug 21, 2023
9a47510
update bros.md
jinhopark8345 Aug 21, 2023
ab706c0
remove bros tokenizer, add bros processor that wraps bert tokenizer
jinhopark8345 Aug 21, 2023
5222230
revert change
jinhopark8345 Aug 21, 2023
3ef8bd5
apply make fix-copies
jinhopark8345 Aug 21, 2023
2a5a010
update processor code, update itc -> initial token, stc -> subsequent…
jinhopark8345 Aug 21, 2023
7761029
add type hint
jinhopark8345 Aug 21, 2023
e9449d1
remove unnecessary condition branches in embedding forward
jinhopark8345 Aug 21, 2023
b001e88
Merge remote-tracking branch 'upstream/main' into add-bros
jinhopark8345 Aug 21, 2023
6a22091
fix auto tokenizer fail
jinhopark8345 Aug 21, 2023
c16e4d8
update docstring for each classes
jinhopark8345 Aug 23, 2023
66f1446
Merge remote-tracking branch 'upstream/main' into add-bros
jinhopark8345 Aug 23, 2023
3f07cb4
update bbox input dimension as standard 2 points and convert them to 4
jinhopark8345 Aug 24, 2023
20a2bee
update bros docs
jinhopark8345 Aug 24, 2023
14e5591
Merge remote-tracking branch 'upstream/main' into add-bros
jinhopark8345 Aug 30, 2023
52dcb38
apply suggestions from code review : update Bros -> BROS in bros.md
jinhopark8345 Sep 2, 2023
6cdcaf2
1. box prefix var -> bbox
jinhopark8345 Sep 2, 2023
983ac62
replace einsum with torch matmul
jinhopark8345 Sep 4, 2023
007333a
apply style and quality
jinhopark8345 Sep 4, 2023
a51a66d
remove unused argument
jinhopark8345 Sep 4, 2023
0403675
remove unused arguments
jinhopark8345 Sep 4, 2023
e15b019
update docstrings
jinhopark8345 Sep 4, 2023
2b6a8f4
apply suggestions from code review: add BrosBboxEmbeddings, replace
jinhopark8345 Sep 5, 2023
039afcb
Merge remote-tracking branch 'upstream/main' into add-bros
jinhopark8345 Sep 5, 2023
0fb70f1
Merge remote-tracking branch 'upstream/main' into add-bros
jinhopark8345 Sep 8, 2023
1a8558b
revert einsum update
jinhopark8345 Sep 10, 2023
8eb78e1
update bros processor
jinhopark8345 Sep 10, 2023
44a0fc9
apply suggestions from code review
jinhopark8345 Sep 14, 2023
19993a7
add conversion script for bros
jinhopark8345 Sep 14, 2023
8fe9f5a
Apply suggestions from code review
jinhopark8345 Sep 14, 2023
e1d0c73
Merge remote-tracking branch 'upstream/main' into add-bros
jinhopark8345 Sep 14, 2023
9e883fb
fix readme
jinhopark8345 Sep 14, 2023
8223fed
apply fix-copies
jinhopark8345 Sep 14, 2023
187c411
Merge remote-tracking branch 'upstream/main' into add-bros
jinhopark8345 Sep 14, 2023
separate the logits calculation logic and loss calculation logic
jinhopark8345 committed Aug 13, 2023
commit 0cb524f15e33f22ce7204c48a1d83f5ea8fae5fb
29 changes: 13 additions & 16 deletions src/transformers/models/bros/modeling_bros.py
@@ -1122,35 +1122,32 @@ def forward(
         itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
         stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)

+        itc_logits = itc_outputs.view(-1, self.num_labels)
+
+        # calculate stc_logits
+        inv_attention_mask = 1 - attention_mask
+        bsz, max_seq_length = inv_attention_mask.shape
Collaborator:

Suggested change:
-        bsz, max_seq_length = inv_attention_mask.shape
+        batch_size, max_seq_length = inv_attention_mask.shape

Contributor Author (jinhopark8345):

Apply suggestions from code review.

+        device = inv_attention_mask.device
+        invalid_token_mask = torch.cat([inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1).bool()
+        stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)
+        self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
+        stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)
+        stc_mask = attention_mask.view(-1).bool()
+        stc_logits = stc_outputs.view(-1, max_seq_length + 1)

         loss = None
         if itc_labels is not None and stc_labels is not None:
             loss_fct = CrossEntropyLoss()

             # get itc loss
-            itc_logits = itc_outputs.view(-1, self.num_labels)
             itc_labels = itc_labels.view(-1)
             if itc_mask is not None:
                 itc_mask = itc_mask.view(-1)
                 itc_loss = loss_fct(itc_logits[itc_mask], itc_labels[itc_mask])
             else:
                 itc_loss = loss_fct(itc_logits, itc_labels)

             # get stc loss
-            inv_attention_mask = 1 - attention_mask
-
-            bsz, max_seq_length = inv_attention_mask.shape
-            device = inv_attention_mask.device
-
-            invalid_token_mask = torch.cat([inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1).bool()
-            stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)
-
-            self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
-            stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)
-
-            stc_mask = attention_mask.view(-1).bool()
-            stc_logits = stc_outputs.view(-1, max_seq_length + 1)
             stc_labels = stc_labels.view(-1)

             stc_loss = loss_fct(stc_logits[stc_mask], stc_labels[stc_mask])

             loss = itc_loss + stc_loss
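In this PR's vocabulary, "itc" stands for initial token classification and "stc" for subsequent token classification (a later commit renames them accordingly). The point of the change is to compute and mask the logits once, before the loss branch, so the returned logits keep the same shape whether or not labels are passed. A minimal sketch of that pattern follows, with assumed shapes and an illustrative helper name rather than the exact BROS code (the real method also filters itc positions with an optional itc_mask):

import torch
from torch.nn import CrossEntropyLoss

def spade_head(itc_outputs, stc_outputs, attention_mask, num_labels,
               itc_labels=None, stc_labels=None):
    # Assumed shapes: itc_outputs (bsz, seq_len, num_labels),
    # stc_outputs (bsz, seq_len, seq_len + 1), attention_mask (bsz, seq_len).
    itc_logits = itc_outputs.reshape(-1, num_labels)

    inv_attention_mask = (1 - attention_mask).float()
    bsz, max_seq_length = inv_attention_mask.shape
    device = inv_attention_mask.device

    # Push padded positions and token self-links down to -10000.0 so they
    # effectively never win a softmax.
    invalid_token_mask = torch.cat(
        [inv_attention_mask, torch.zeros(bsz, 1, device=device)], dim=1
    ).bool()
    stc_outputs = stc_outputs.masked_fill(invalid_token_mask[:, None, :], -10000.0)
    self_token_mask = torch.eye(max_seq_length, max_seq_length + 1, device=device).bool()
    stc_outputs = stc_outputs.masked_fill(self_token_mask[None, :, :], -10000.0)
    stc_logits = stc_outputs.reshape(-1, max_seq_length + 1)

    # The loss branch only consumes the precomputed logits.
    loss = None
    if itc_labels is not None and stc_labels is not None:
        loss_fct = CrossEntropyLoss()
        itc_loss = loss_fct(itc_logits, itc_labels.reshape(-1))
        stc_mask = attention_mask.reshape(-1).bool()
        stc_loss = loss_fct(stc_logits[stc_mask], stc_labels.reshape(-1)[stc_mask])
        loss = itc_loss + stc_loss

    return loss, itc_logits, stc_logits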
43 changes: 42 additions & 1 deletion tests/models/bros/test_modeling_bros.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """ Testing suite for the PyTorch Bros model. """

-
+import copy
 import unittest

 from transformers import BrosConfig, is_torch_available
@@ -29,6 +29,8 @@

     from transformers import (
         BrosForTokenClassification,
+        BrosSpadeEEForTokenClassification,
+        BrosSpadeELForTokenClassification,
         BrosModel,
     )
     from transformers.models.bros.modeling_bros import (
@@ -162,6 +164,12 @@ def create_and_check_for_token_classification(
         )
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

+    def create_and_check_for_spade_ee_token_classification(self):
+        ...
+
+    def create_and_check_for_spade_el_token_classification(self):
+        ...
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -185,9 +193,15 @@ def prepare_config_and_inputs_for_common(self):

 @require_torch
 class BrosModelTest(ModelTesterMixin, unittest.TestCase):
+    test_pruning = False
+    test_torchscript = False
+    test_mismatched_shapes = False
+
     all_model_classes = (
         (
             BrosForTokenClassification,
+            BrosSpadeEEForTokenClassification,
+            BrosSpadeELForTokenClassification,
             BrosModel,
         )
         if is_torch_available()
@@ -199,13 +213,25 @@ def setUp(self):
         self.model_tester = BrosModelTester(self)
         self.config_tester = ConfigTester(self, config_class=BrosConfig, hidden_size=37)

+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = copy.deepcopy(inputs_dict)
+
+        if return_labels:
+            ...
+        ...
+
+        return inputs_dict
+
     def test_config(self):
         self.config_tester.run_common_tests()

     def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

+    def test_multi_gpu_data_parallel_forward(self):
+        pass
+
Collaborator:

This should be skipped with an explicit reason using a unittest.skip decorator.

Contributor Author (jinhopark8345):

Updated code to not just pass.
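For reference, the change the reviewer is asking for looks roughly like the sketch below. The reason string is an assumed placeholder, not the wording that was merged:

import unittest

class BrosModelTest(unittest.TestCase):
    # Placeholder reason string; the merged PR chooses its own wording.
    @unittest.skip(reason="Bros does not support data parallelism in the common tests")
    def test_multi_gpu_data_parallel_forward(self):
        pass  # never executed; the runner reports the skip reason instead

Unlike a bare pass, the decorator makes the skipped test visible in the test report along with its reason.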

     def test_model_various_embeddings(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         for type in ["absolute", "relative_key", "relative_key_query"]:
@@ -216,12 +242,27 @@ def test_for_token_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

+    def test_for_spade_ee_token_classification(self):
+        ...
+
+    def test_for_spade_el_token_classification(self):
+        ...
+
+    def test_attention_outputs(self):
+        ...
+
+    def test_hidden_states_output(self):
+        ...
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in BROS_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
             model = BrosModel.from_pretrained(model_name)
             self.assertIsNotNone(model)

+    def test_initialization(self):
+        ...
+

 def prepare_bros_batch_inputs():
     attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
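The helper above is truncated in this view, but its job is to pair token inputs with one bounding box per token, which is what distinguishes BROS inputs from plain BERT inputs. As an end-to-end sketch of how such a batch reaches the model: the checkpoint id below is an assumption based on this PR's discussion (the archive list originally used a naver-clova-ocr prefix), and random boxes stand in for real OCR coordinates:

import torch
from transformers import BrosModel, BrosProcessor  # BrosProcessor wraps a BERT tokenizer

checkpoint = "jinhopark8345/bros-base-uncased"  # assumed checkpoint id
processor = BrosProcessor.from_pretrained(checkpoint)
model = BrosModel.from_pretrained(checkpoint)

encoding = processor("his name is Rocco", return_tensors="pt")
seq_len = encoding["input_ids"].shape[-1]
# One normalized (x0, y0, x1, y1) box per token; per a commit above, the model
# converts these 2-point boxes to 4 corner points internally.
bbox = torch.rand(1, seq_len, 4)

outputs = model(
    input_ids=encoding["input_ids"],
    bbox=bbox,
    attention_mask=encoding["attention_mask"],
)
print(outputs.last_hidden_state.shape)  # (1, seq_len, hidden_size)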