Config: unified logic to retrieve text config #33219

Merged
17 commits merged on Sep 4, 2024
2 changes: 1 addition & 1 deletion .circleci/parse_test_outputs.py
@@ -67,4 +67,4 @@ def main():


if __name__ == "__main__":
main()
main()
37 changes: 33 additions & 4 deletions src/transformers/configuration_utils.py
@@ -1026,10 +1026,9 @@ def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
try:
default_config = self.__class__()
except ValueError:
for decoder_attribute_name in ("decoder", "generator", "text_config"):
if hasattr(self, decoder_attribute_name):
default_config = getattr(self, decoder_attribute_name).__class__()
break
decoder_config = self.get_text_config(decoder=True)
if decoder_config is not self:
default_config = decoder_config.__class__()

# If it is a composite model, we want to check the subconfig that will be used for generation
self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)
@@ -1057,6 +1056,36 @@ def _get_non_default_generation_parameters(self) -> Dict[str, Any]:

return non_default_generation_parameters

def get_text_config(self, decoder=False) -> "PretrainedConfig":
"""
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
itself. On specific composite models, it is stored under one of a set of valid attribute names.

If `decoder` is set to `True`, then only search for decoder config names.
"""

Review comment (Member): For my own understanding, what does it mean to search in "decoder config names"? Is it somehow related to a model being encoder-decoder or decoder-only? From what I see, `text_encoder` is only used in MusicGen, and we never use `decoder=False` in transformers.

Reply (Member, PR author): It is indeed mostly for MusicGen at the moment. Needing the encoder text config is uncommon, but there is at least one pair of use cases:

  1. `resize_token_embeddings` (uses `decoder=False`)
  2. The tests for `resize_token_embeddings`

decoder_possible_text_config_names = ("decoder", "generator", "text_config")
encoder_possible_text_config_names = ("text_encoder",)
if decoder:
possible_text_config_names = decoder_possible_text_config_names
else:
possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names

valid_text_config_names = []
for text_config_name in possible_text_config_names:
if hasattr(self, text_config_name):
text_config = getattr(self, text_config_name, None)
if text_config is not None:
valid_text_config_names += [text_config_name]

if len(valid_text_config_names) > 1:
raise ValueError(
f"Multiple valid text configs were found in the model config: {valid_text_config_names}. "
"Using `get_text_config()` is ambiguous in this case -- please access the desired text config attribute directly."
)

Review comment (Member): I don't entirely understand what this comment hints at doing 😅

Reply (Member, PR author): Indeed, I've updated the code but left this exception halfway 😅

elif len(valid_text_config_names) == 1:
return getattr(self, valid_text_config_names[0])
return self


def get_configuration_file(configuration_files: List[str]) -> str:
"""
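For illustration, a minimal usage sketch of the new `get_text_config` helper (this assumes a transformers version that already includes this PR; `LlamaConfig` and `LlavaConfig` are used purely as examples of a plain and a composite config):

```python
from transformers import LlamaConfig, LlavaConfig

plain = LlamaConfig()      # plain text model: no nested text config
composite = LlavaConfig()  # composite model: the text config lives under `text_config`

# On a plain config, none of the candidate attribute names match,
# so the method falls through to `return self`
assert plain.get_text_config(decoder=True) is plain

# On a composite config, the single matching sub-config is returned
assert composite.get_text_config(decoder=True) is composite.text_config

# Callers can now read text-model attributes without caring about nesting
vocab_size = composite.get_text_config(decoder=True).vocab_size
```

A config that matches more than one candidate name (e.g. MusicGen queried with `decoder=False`, which matches both `text_encoder` and `decoder`) raises the `ValueError` above.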
26 changes: 14 additions & 12 deletions src/transformers/generation/configuration_utils.py
@@ -1177,20 +1177,22 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
"""
config_dict = model_config.to_dict()
config_dict.pop("_from_model_config", None)
config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)

# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config.
for decoder_name in ("decoder", "generator", "text_config"):
if decoder_name in config_dict:
default_generation_config = GenerationConfig()
decoder_config = config_dict[decoder_name]
for attr in config.to_dict().keys():
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
setattr(config, attr, decoder_config[attr])

config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
return config
# generation config (which, in turn, is defined from the outer attributes of the model config).
decoder_config = model_config.get_text_config(decoder=True)
if decoder_config is not model_config:
default_generation_config = GenerationConfig()
decoder_config_dict = decoder_config.to_dict()
for attr in generation_config.to_dict().keys():
is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
if attr in decoder_config_dict and is_unset:
setattr(generation_config, attr, decoder_config_dict[attr])

# Hash to detect whether the instance was modified
generation_config._original_object_hash = hash(generation_config)
return generation_config

def update(self, **kwargs):
"""
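As a standalone sketch of the inheritance rule implemented in `from_model_config` above (this mirrors the rule, it is not the library code path): an attribute coming from the decoder sub-config only wins if the generation config still holds the library default for it.

```python
from transformers import GenerationConfig

default = GenerationConfig()
generation_config = GenerationConfig(num_beams=8)  # explicitly non-default value

# Values that would come from the decoder sub-config of a composite model
decoder_config_dict = {"num_beams": 4, "max_length": 256}

for attr, value in decoder_config_dict.items():
    is_unset = getattr(generation_config, attr) == getattr(default, attr)
    if is_unset:
        setattr(generation_config, attr, value)

print(generation_config.num_beams)   # 8   -> kept, already set to a non-default value
print(generation_config.max_length)  # 256 -> inherited, it was still the default (20)
```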
12 changes: 3 additions & 9 deletions src/transformers/integrations/awq.py
@@ -209,10 +209,7 @@ def get_modules_to_fuse(model, quantization_config):
current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]

# Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
if not hasattr(model.config, "text_config"):
config = model.config
else:
config = model.config.text_config
config = model.config.get_text_config(decoder=True)

# Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
hidden_size = config.hidden_size
@@ -345,11 +342,8 @@ def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_cls):
previous_device = gate_proj.qweight.device

# Also handle the case where the model has a `text_config` attribute
hidden_act = (
model.config.hidden_act
if not hasattr(model.config, "text_config")
else model.config.text_config.hidden_act
)
config = model.config.get_text_config(decoder=True)
hidden_act = config.hidden_act
activation_fn = ACT2FN[hidden_act]
new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)

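A short sketch of what the fusion helpers above gain from the unified accessor: the few text-model attributes they need can be read the same way for plain and multi-modal configs (`LlavaConfig` is again only an illustrative multi-modal config, not something this PR touches):

```python
from transformers import LlavaConfig

text_config = LlavaConfig().get_text_config(decoder=True)

# The same attribute reads work on a plain config such as LlamaConfig,
# because get_text_config() would simply return that config itself.
fusion_params = {
    "hidden_size": text_config.hidden_size,
    "num_attention_heads": text_config.num_attention_heads,
    "num_key_value_heads": getattr(text_config, "num_key_value_heads", None),
    "hidden_act": text_config.hidden_act,
}
print(fusion_params)
```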
7 changes: 2 additions & 5 deletions src/transformers/modeling_utils.py
@@ -2025,11 +2025,8 @@ def resize_token_embeddings(
else:
vocab_size = model_embeds.weight.shape[0]

# Update base model and current model config
if hasattr(self.config, "text_config"):
self.config.text_config.vocab_size = vocab_size
else:
self.config.vocab_size = vocab_size
# Update base model and current model config.
self.config.get_text_config().vocab_size = vocab_size
self.vocab_size = vocab_size

# Tie weights again if needed
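One subtlety behind the one-liner above: `get_text_config()` returns the (possibly nested) config object itself rather than a copy, so assigning through it updates the sub-config in place. A minimal sketch, again using `LlavaConfig` only as an example of a composite config:

```python
from transformers import LlavaConfig

config = LlavaConfig()

# get_text_config() hands back the nested `text_config` object itself,
# so this assignment mutates config.text_config in place.
config.get_text_config().vocab_size = 32064
assert config.text_config.vocab_size == 32064
```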
2 changes: 1 addition & 1 deletion src/transformers/models/clvp/modeling_clvp.py
@@ -735,7 +735,7 @@ def _init_weights(self, module):
nn.init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std)
elif isinstance(module, ClvpEncoder):
config = self.config.text_config if hasattr(self.config, "text_config") else self.config
config = self.config.get_text_config()
factor = config.initializer_factor
module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
elif isinstance(module, ClvpConditioningEncoder):
12 changes: 7 additions & 5 deletions tests/generation/test_utils.py
@@ -831,7 +831,7 @@ def test_constrained_beam_search_generate(self):

# Sample constraints
min_id = 3
max_id = config.vocab_size
max_id = config.get_text_config(decoder=True).vocab_size

force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
@@ -889,7 +889,7 @@ def test_constrained_beam_search_generate_dict_output(self):

# Sample constraints
min_id = 3
max_id = model.config.vocab_size
max_id = model.config.get_text_config(decoder=True).vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
@@ -2012,18 +2012,20 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_
self.assertTrue(output.past_key_values is None)

def _check_scores(self, batch_size, scores, length, config):
expected_shape = (batch_size, config.vocab_size)
vocab_size = config.get_text_config(decoder=True).vocab_size
expected_shape = (batch_size, vocab_size)
self.assertIsInstance(scores, tuple)
self.assertEqual(len(scores), length)
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))

def _check_logits(self, batch_size, scores, config):
vocab_size = config.get_text_config(decoder=True).vocab_size
self.assertIsInstance(scores, tuple)
self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
vocab_diff = config.vocab_size - scores[0].shape[-1]
vocab_diff = vocab_size - scores[0].shape[-1]
self.assertTrue(vocab_diff in [0, 1])
self.assertListEqual([config.vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))

def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
57 changes: 16 additions & 41 deletions tests/test_modeling_common.py
@@ -1747,12 +1747,13 @@ def test_resize_position_vector_embeddings(self):
self.assertTrue(models_equal)

def test_resize_tokens_embeddings(self):
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
Review comment on lines +1750 to +1751 (Member, PR author): (moved the skip here: no point in spending compute if we are going to skip the test)


(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")

for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
@@ -1764,18 +1765,15 @@ def test_resize_tokens_embeddings(self):
if self.model_tester.is_training is False:
model.eval()

model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
model_vocab_size = config.get_text_config().vocab_size
# Retrieve the embeddings and clone them
model_embed = model.resize_token_embeddings(model_vocab_size)
cloned_embeddings = model_embed.weight.clone()

# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size

self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
@@ -1787,11 +1785,7 @@ def test_resize_tokens_embeddings(self):

# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
@@ -1817,21 +1811,13 @@ def test_resize_tokens_embeddings(self):
model = model_class(config)
model.to(torch_device)

model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
model_vocab_size = config.get_text_config().vocab_size
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertTrue(new_model_vocab_size + 10, model_vocab_size)

model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertTrue(model_embed.weight.shape[0] // 64, 0)

self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
@@ -1852,13 +1838,10 @@ def test_resize_tokens_embeddings(self):
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)

def test_resize_embeddings_untied(self):
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")

original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
original_config.tie_word_embeddings = False

# if the model cannot untie embeddings -> leave test
@@ -1874,13 +1857,9 @@ def test_resize_embeddings_untied(self):
continue

# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
model_vocab_size = config.get_text_config().vocab_size
model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
output_embeds = model.get_output_embeddings()
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
@@ -1892,11 +1871,7 @@ def test_resize_embeddings_untied(self):

# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model.resize_token_embeddings(model_vocab_size - 15)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
output_embeds = model.get_output_embeddings()
@@ -1988,7 +1963,7 @@ def check_same_values(layer_1, layer_2):
# self.assertTrue(check_same_values(embeddings, decoding))

# Check that after resize they remain tied.
vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
vocab_size = config.get_text_config().vocab_size
model_tied.resize_token_embeddings(vocab_size + 10)
params_tied_2 = list(model_tied.parameters())
self.assertEqual(len(params_tied_2), len(params_tied))
@@ -4831,7 +4806,7 @@ def test_forward_with_num_logits_to_keep(self):

config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
batch_size, sequence_length = inputs["input_ids"].shape
vocab_size = config.vocab_size
vocab_size = config.get_text_config().vocab_size
model = model_class(config).to(device=torch_device).eval()

# num_logits_to_keep=0 is a special case meaning "keep all logits"
8 changes: 3 additions & 5 deletions tests/test_pipeline_mixin.py
@@ -675,14 +675,12 @@ def validate_test_components(test_case, task, model, tokenizer, processor):
# Avoid `IndexError` in embedding layers
CONFIG_WITHOUT_VOCAB_SIZE = ["CanineConfig"]
if tokenizer is not None:
config_vocab_size = getattr(model.config, "vocab_size", None)
# Removing `decoder=True` in `get_text_config` can lead to conflicting values e.g. in MusicGen
Review comment (Member, PR author): Alternatively, we can add a flag to a) return the first valid match; OR b) return all matches when there is more than one match.

config_vocab_size = getattr(model.config.get_text_config(decoder=True), "vocab_size", None)
# For CLIP-like models
if config_vocab_size is None:
if hasattr(model.config, "text_config"):
if hasattr(model.config, "text_encoder"):
config_vocab_size = getattr(model.config.text_config, "vocab_size", None)
elif hasattr(model.config, "text_encoder"):
config_vocab_size = getattr(model.config.text_encoder, "vocab_size", None)

if config_vocab_size is None and model.config.__class__.__name__ not in CONFIG_WITHOUT_VOCAB_SIZE:
raise ValueError(
"Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`."