diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 4353a1132238..1ae006ba674c 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -908,6 +908,16 @@ def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) - for metadata_field in METADATA_FIELDS: config_dict.pop(metadata_field, None) + def convert_keys_to_string(obj): + if isinstance(obj, dict): + return {str(key): convert_keys_to_string(value) for key, value in obj.items()} + elif isinstance(obj, list): + return [convert_keys_to_string(item) for item in obj] + else: + return obj + + config_dict = convert_keys_to_string(config_dict) + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):