Skip to content

Commit

Permalink
Update bros checkpoint (huggingface#26277)
Browse files Browse the repository at this point in the history
* fix bros integration test

* update bros checkpoint
  • Loading branch information
jinhopark8345 authored and parambharat committed Sep 26, 2023
1 parent 8c8f26f commit 3c88d60
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 25 deletions.
10 changes: 5 additions & 5 deletions src/transformers/models/bros/configuration_bros.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
logger = logging.get_logger(__name__)

# Maps each Bros checkpoint id to the URL of its raw `config.json` on the Hugging Face Hub.
# Every URL uses the `/resolve/<revision>/` route, which serves the file contents;
# the `/blob/<revision>/` route is the HTML file-viewer page and does not return JSON.
BROS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/config.json",
    "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/config.json",
    "jinho8345/bros-base-uncased": "https://huggingface.co/jinho8345/bros-base-uncased/resolve/main/config.json",
    "jinho8345/bros-large-uncased": "https://huggingface.co/jinho8345/bros-large-uncased/resolve/main/config.json",
}


Expand All @@ -31,7 +31,7 @@ class BrosConfig(PretrainedConfig):
This is the configuration class to store the configuration of a [`BrosModel`] or a [`TFBrosModel`]. It is used to
instantiate a Bros model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Bros
[naver-clova-ocr/bros-base-uncased](https://huggingface.co/naver-clova-ocr/bros-base-uncased) architecture.
[jinho8345/bros-base-uncased](https://huggingface.co/jinho8345/bros-base-uncased) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Expand Down Expand Up @@ -81,10 +81,10 @@ class BrosConfig(PretrainedConfig):
```python
>>> from transformers import BrosConfig, BrosModel
>>> # Initializing a BROS naver-clova-ocr/bros-base-uncased style configuration
>>> # Initializing a BROS jinho8345/bros-base-uncased style configuration
>>> configuration = BrosConfig()
>>> # Initializing a model from the naver-clova-ocr/bros-base-uncased style configuration
>>> # Initializing a model from the jinho8345/bros-base-uncased style configuration
>>> model = BrosModel(configuration)
>>> # Accessing the model configuration
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/bros/convert_bros_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_h
# Required parameters
parser.add_argument(
"--model_name",
default="naver-clova-ocr/bros-base-uncased",
default="jinho8345/bros-base-uncased",
required=False,
type=str,
help="Name of the original model you'd like to convert.",
Expand Down
22 changes: 11 additions & 11 deletions src/transformers/models/bros/modeling_bros.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "naver-clova-ocr/bros-base-uncased"
_CHECKPOINT_FOR_DOC = "jinho8345/bros-base-uncased"
_CONFIG_FOR_DOC = "BrosConfig"

BROS_PRETRAINED_MODEL_ARCHIVE_LIST = [
"naver-clova-ocr/bros-base-uncased",
"naver-clova-ocr/bros-large-uncased",
"jinho8345/bros-base-uncased",
"jinho8345/bros-large-uncased",
# See all Bros models at https://huggingface.co/models?filter=bros
]

Expand Down Expand Up @@ -846,9 +846,9 @@ def forward(
>>> import torch
>>> from transformers import BrosProcessor, BrosModel
>>> processor = BrosProcessor.from_pretrained("naver-clova-ocr/bros-base-uncased")
>>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
>>> model = BrosModel.from_pretrained("naver-clova-ocr/bros-base-uncased")
>>> model = BrosModel.from_pretrained("jinho8345/bros-base-uncased")
>>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
>>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
Expand Down Expand Up @@ -1011,9 +1011,9 @@ def forward(
>>> import torch
>>> from transformers import BrosProcessor, BrosForTokenClassification
>>> processor = BrosProcessor.from_pretrained("naver-clova-ocr/bros-base-uncased")
>>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
>>> model = BrosForTokenClassification.from_pretrained("naver-clova-ocr/bros-base-uncased")
>>> model = BrosForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
>>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
>>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
Expand Down Expand Up @@ -1130,9 +1130,9 @@ def forward(
>>> import torch
>>> from transformers import BrosProcessor, BrosSpadeEEForTokenClassification
>>> processor = BrosProcessor.from_pretrained("naver-clova-ocr/bros-base-uncased")
>>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
>>> model = BrosSpadeEEForTokenClassification.from_pretrained("naver-clova-ocr/bros-base-uncased")
>>> model = BrosSpadeEEForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
>>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
>>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
Expand Down Expand Up @@ -1261,9 +1261,9 @@ def forward(
>>> import torch
>>> from transformers import BrosProcessor, BrosSpadeELForTokenClassification
>>> processor = BrosProcessor.from_pretrained("naver-clova-ocr/bros-base-uncased")
>>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
>>> model = BrosSpadeELForTokenClassification.from_pretrained("naver-clova-ocr/bros-base-uncased")
>>> model = BrosSpadeELForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
>>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
>>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
Expand Down
13 changes: 5 additions & 8 deletions tests/models/bros/test_modeling_bros.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
import copy
import unittest

from transformers import BrosProcessor
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from transformers.utils import is_torch_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
Expand Down Expand Up @@ -412,13 +411,10 @@ def prepare_bros_batch_inputs():

@require_torch
class BrosModelIntegrationTest(unittest.TestCase):
@cached_property
def default_processor(self):
return BrosProcessor.from_pretrained("naver-clova-ocr/bros-base-uncased") if is_vision_available() else None

@slow
def test_inference_no_head(self):
model = BrosModel.from_pretrained("naver-clova-ocr/bros-base-uncased").to(torch_device)
model = BrosModel.from_pretrained("jinho8345/bros-base-uncased").to(torch_device)

input_ids, bbox, attention_mask = prepare_bros_batch_inputs()

with torch.no_grad():
Expand All @@ -434,7 +430,8 @@ def test_inference_no_head(self):
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

expected_slice = torch.tensor(
[[-0.4027, 0.0756, -0.0647], [-0.0192, -0.0065, 0.1042], [-0.0671, 0.0214, 0.0960]]
[[-0.3074, 0.1363, 0.3143], [0.0925, -0.1155, 0.1050], [0.0221, 0.0003, 0.1285]]
).to(torch_device)
torch.set_printoptions(sci_mode=False)

self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))

0 comments on commit 3c88d60

Please sign in to comment.