forked from huggingface/transformers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodular_model_converter.py
1205 lines (1076 loc) · 61.7 KB
/
modular_model_converter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import importlib
import os
import re
from collections import defaultdict, deque
from typing import Dict, List, Optional, Set
import libcst as cst
from check_copies import run_ruff
from create_dependency_mapping import find_priority_list
from libcst import ClassDef, CSTTransformer, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider
from transformers import logging
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
# Top-level assignments listed here are never recorded in `ClassFinder.assignments`, so they are not overwritten even
# if they are in the dependency graph. Otherwise, the value from the dependency is used, then mapped to current name
# convention, resulting in wrong value.
# The corresponding mapped value is used to define the file target for the assignment (e.g. "modeling").
ASSIGNMENTS_TO_KEEP = {
    "_CHECKPOINT_FOR_DOC": "modeling",
}
# Header template for generated files; it contains the `{relative_path}` and `{short_name}` placeholders.
# NOTE(review): the code filling the placeholders is not in this section — presumably `str.format`; confirm.
AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from {relative_path}.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# {short_name} file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""
def get_module_source_from_name(module_name: str) -> str:
    """Return the source code of the module `module_name`, located through the import system.

    Args:
        module_name (`str`): dotted module path, e.g. `"transformers.models.llama.modeling_llama"`.

    Returns:
        `str`: the module's source text, or the message `"Module {module_name} not found"` when the module cannot be
        located or has no source file on disk (built-in or frozen modules).

    Raises:
        ModuleNotFoundError: propagated from `importlib.util.find_spec` when a parent package of a dotted name is missing.
    """
    # `import importlib` alone does not guarantee that the `importlib.util` submodule has been loaded.
    import importlib.util

    spec = importlib.util.find_spec(module_name)
    # `has_location` is False when `origin` is not a file path (e.g. "built-in" or "frozen"), which `open` cannot read.
    if spec is None or spec.origin is None or not spec.has_location:
        return f"Module {module_name} not found"
    with open(spec.origin, "r", encoding="utf-8") as file:
        return file.read()
class ClassFinder(CSTVisitor):
    """A visitor class which analyses a module, creating a mapping of dependencies between classes and functions.

    For example if the visited code has
    ```python3
    def init_value(): return 1

    class LlamaModel(PreTrainedModel):
        def __init__(self):
            super().__init__(self)
            self.value = init_value()
    ```
    then the `class_dependency_mapping` should be: `{"LlamaModel": {"PreTrainedModel", "init_value"}}`. Entities
    without dependencies of their own (like `init_value`) get no entry.

    The mapping is updated from `visit_ClassDef` (base classes), `leave_Name`, `leave_Arg`, `leave_Dict` and
    `leave_Decorator`. This is very broad: by checking the parent node or the scope of a `cst.Name`, `cst.Arg`,
    `cst.Dict` or `cst.Decorator` we are able to map the dependence parent -> child.
    All `visit_XXX` / `leave_XXX` methods run when entering / leaving a cst.Node of type XXX.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)

    def __init__(self, python_module: cst.Module):
        super().__init__()
        # fmt: off
        self.python_module: cst.Module = python_module  # module used to render nodes back to source (`code_for_node`)
        self.classes: Dict[str, cst.ClassDef] = {}  # top-level class name -> cst.ClassDef
        self.imports = {}  # imported names -> import statement (or the enclosing `if` for conditional imports)
        self.function_def = {}  # top-level function name -> cst.FunctionDef
        self.assignments = {}  # top-level assigned name (e.g. LLAMA_DOCSTRING) -> statement / assigned value node
        self.class_dependency_mapping = {}  # name -> set of dependencies, also including the recorded dependencies of each dependency
        self.first_lvl_dependency_mapping = {}  # name -> set of direct dependencies only
        # fmt: on

    def _update_class_dependency(self, name, value):
        """Record that `name` depends on `value`.

        `first_lvl_dependency_mapping[name]` gains `value` only. `class_dependency_mapping[name]` gains `value` plus
        whatever dependencies were already recorded for `value` at this point of the traversal.
        """
        dep = set(self.first_lvl_dependency_mapping.get(name, set())) | {value}
        self.first_lvl_dependency_mapping[name] = dep

        dep = set(self.class_dependency_mapping.get(value, set()))
        dep |= set(self.class_dependency_mapping.get(name, set())) | {value}
        self.class_dependency_mapping[name] = dep

    def visit_ClassDef(self, node: ClassDef) -> None:
        """We don't have non global scope class defs in transformers. Here we add the inheritance dependencies"""
        self.classes[node.name.value] = node
        for k in node.bases:  # deal with inheritance
            base_name = self.python_module.code_for_node(k)
            self._update_class_dependency(node.name.value, base_name)

    def visit_SimpleStatementLine(self, node):
        """
        Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements
        are extracted and saved in their corresponding dict. They are then used when updating dependency mappings.
        """
        if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(
            self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()
        ):
            left_hand_side = node.body[0].targets[0].target
            if hasattr(left_hand_side, "value"):
                if left_hand_side.value not in ASSIGNMENTS_TO_KEEP.keys():
                    self.assignments[left_hand_side.value] = node
            else:
                # Tuple/list unpacking: map each target name to the matching element of the right-hand side.
                for idx, target in enumerate(list(left_hand_side.elements)):
                    if target.value.value not in ASSIGNMENTS_TO_KEEP.keys():
                        self.assignments[target.value.value] = node.body[0].value.elements[idx].value
        if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
            self.imports[node.body[0].names] = node

    def visit_FunctionDef(self, node):
        """Register top-level (module-scope) function definitions."""
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.function_def[node.name.value] = node

    def leave_If(self, node):
        """Register imports placed directly inside an `if` body, keyed to the whole `if` node."""
        for stmt in node.body.body:
            if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
                self.imports[stmt.body[0].names] = node

    def leave_Name(self, node):
        """A reference to a known class/function/assignment from a non-global scope is a dependency of the
        top-level entity owning that scope (first component of the scope's qualified name prefix)."""
        if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys():
            parent = self.get_metadata(cst.metadata.ScopeProvider, node)
            if not isinstance(parent, cst.metadata.scope_provider.GlobalScope):
                self._update_class_dependency(parent._name_prefix.split(".")[0], node.value)

    def leave_Arg(self, node):
        """A plain-name argument of a class header (a base class or keyword like `metaclass=`) is a dependency."""
        if m.matches(node.value, m.Name()):
            parent = self.get_metadata(ParentNodeProvider, node)
            if m.matches(parent, m.ClassDef()) and parent.bases:
                self._update_class_dependency(parent.name.value, node.value.value)

    def leave_Dict(self, node):
        """For a recorded assignment `NAME = {...}`, every dict value that is the name of a known class becomes a
        dependency of `NAME` (e.g. mappings like `{"eager": LlamaAttention}`)."""
        parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        # Only simple-name targets can be keys of `self.assignments`; tuple/attribute targets are skipped.
        if m.matches(parent, m.Assign(targets=[m.AssignTarget(target=m.Name())])):
            name = parent.targets[0].target.value
            if name in self.assignments:
                for k in node.elements:
                    # Values such as calls or literals cannot name a class (and calls have no `.value`).
                    if not m.matches(k.value, m.Name()):
                        continue
                    dep_name = k.value.value
                    if dep_name in self.classes:
                        self._update_class_dependency(name, dep_name)

    def _decorated_entity_name(self, node):
        """Return the name of the top-level entity owning the decorator `node`."""
        parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        scope = self.get_metadata(cst.metadata.ScopeProvider, node)
        return scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value

    def leave_Decorator(self, node):
        """Decorator arguments that reference recorded assignments (`@deco(DOC.format(...))` or `@deco(DOC)`) are
        dependencies of the decorated entity.

        Raises:
            ValueError: if an argument calls a method on a name that is not a recorded top-level assignment.
        """
        if hasattr(node.decorator, "args"):
            for k in node.decorator.args:
                if m.matches(k.value, m.Call(func=m.Attribute(value=m.Name()))):  # and k.value.func.value.value:
                    if k.value.func.value.value not in self.assignments:
                        raise ValueError(
                            f"We detected a call to {k.value.func.value.value}, but it was not assigned. See the list of assignments {self.assignments.keys()}"
                        )
                    self._update_class_dependency(self._decorated_entity_name(node), k.value.func.value.value)
                elif m.matches(k, m.Arg(value=m.Name())) and k.value.value in self.assignments:
                    self._update_class_dependency(self._decorated_entity_name(node), k.value.value)

    def leave_Module(self, node):
        """When leaving the module, we store the position of each global scoped node (Assigns, function def and class def)
        to allow sorting the dependencies based on their position in the code. We use the PositionProvider metadata wrapper for this.
        """
        self.global_nodes = {**self.assignments, **self.classes, **self.function_def}
        # now sort the class dependency_mapping based on the position of the nodes
        self.class_start_line = {}
        for name, global_node in self.global_nodes.items():
            self.class_start_line[name] = self.get_metadata(cst.metadata.PositionProvider, global_node).start.line
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
    """A transformer that replaces `old_name` with `new_name` in comments, strings and any name references.

    It handles the snake_case, UPPER_CASE and CamelCase spellings without using the AUTO_MAPPING, e.g. with
    `old_name="llama"` and `new_name="my_new_model"`:
        - llama -> my_new_model
        - LLAMA -> MY_NEW_MODEL
        - Llama -> MyNewModel
        - any other casing (e.g. LLaMa) -> MyNewModel (`default_name`)
    Comments/strings containing "# Copied from" are removed.
    """

    def __init__(
        self,
        old_name,
        new_name,
        given_old_name=None,
        given_new_name=None,
        old_class_name: str = None,
        new_class_name: str = None,
    ):
        super().__init__()
        self.old_name = old_name
        self.new_name = new_name
        # CamelCase spelling of the new model; the registered config class name (minus "Config") is the best
        # source of truth for class names when available.
        self.default_name = "".join(x.title() for x in new_name.split("_"))
        if self.new_name in CONFIG_MAPPING_NAMES:
            self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace("Config", "")
        self.patterns = {
            old_name: new_name,
            old_name.upper(): new_name.upper(),
            "".join(x.title() for x in old_name.split("_")): self.default_name,
        }

        if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
            self.patterns[given_old_name] = given_new_name

        # CamelCase spelling of the old model, built like `default_name`. It must always be defined:
        # `convert_to_camelcase` and the suffix logic below read it even when `old_name` is not a registered config.
        self.default_old_name = "".join(x.title() for x in old_name.split("_"))
        if self.old_name in CONFIG_MAPPING_NAMES:
            self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "")
            if self.default_old_name.isupper():
                self.default_old_name = self.default_old_name.capitalize()

        if new_class_name is not None and old_class_name is not None and old_class_name not in self.patterns:
            # In last recourse, when the suffix of the new class is not the same as the old class,
            # and if the old and new classes start with the default name, we keep the default class name
            # and replace the old suffix with the new one.
            # Useful when we have a class like `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`
            # where a model extends another model, but is used for a different task.
            if old_class_name.startswith(self.default_old_name) and new_class_name.startswith(self.default_name):
                old_suffix = old_class_name[len(self.default_old_name) :]
                # An empty key would match at every position and insert the new suffix everywhere.
                if old_suffix:
                    self.patterns[old_suffix] = new_class_name[len(self.default_name) :]

    def preserve_case_replace(self, text):
        """Replace every case-insensitive occurrence of a `self.patterns` key in `text`.

        Exact-case matches are replaced by their mapped value; any other casing falls back to `self.default_name`.
        """
        regex_pattern = "|".join(re.escape(key) for key in self.patterns.keys())
        compiled_regex = re.compile(regex_pattern, re.IGNORECASE)

        def replace(match):
            word = match.group(0)
            return self.patterns.get(word, self.default_name)

        return compiled_regex.sub(replace, text)

    def convert_to_camelcase(self, text):
        """Replace a leading `old_name` (any casing, directly followed by a letter) with `self.default_old_name`."""
        # With IGNORECASE, `[a-z]` in the lookahead also matches uppercase letters.
        return re.sub(
            rf"^({re.escape(self.old_name)})(?=[a-z]+)",
            lambda _match: self.default_old_name,
            text,
            flags=re.IGNORECASE,
            count=1,
        )

    @m.leave(m.Name() | m.SimpleString() | m.Comment())
    def replace_name(self, original_node, updated_node):
        """Rename inside names, strings and comments; drop those containing "# Copied from"."""
        if re.findall(r"# Copied from", updated_node.value):
            return cst.RemoveFromParent()
        update = self.preserve_case_replace(updated_node.value)
        return updated_node.with_changes(value=update)

    def leave_ClassDef(self, original_node, updated_node):
        """Normalise a class name that (still) starts with the old model name."""
        return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value)))
def find_classes_in_file(
    module: cst.Module,
    old_id="llama",
    new_id="gemma",
    given_old_name=None,
    given_new_name=None,
    old_class_name=None,
    new_class_name=None,
):
    """Rename `old_id` to `new_id` throughout `module`, then analyse the renamed module with a `ClassFinder`.

    The naming arguments are forwarded unchanged to `ReplaceNameTransformer`.

    Returns:
        `ClassFinder`: the visitor after traversal, holding the classes, functions, assignments, imports and
        dependency mappings of the renamed module.
    """
    renamer = ReplaceNameTransformer(
        old_id,
        new_id,
        given_old_name=given_old_name,
        given_new_name=given_new_name,
        old_class_name=old_class_name,
        new_class_name=new_class_name,
    )
    renamed_module = module.visit(renamer)
    finder = ClassFinder(renamed_module)
    # The wrapper provides the metadata (parents, scopes, positions) that `ClassFinder` declares it needs.
    MetadataWrapper(renamed_module).visit(finder)
    return finder
# Matcher for a statement line made of a single bare string expression whose source contains a pair of triple double
# quotes (including an `r` prefix), i.e. a docstring-style literal. `'''`-quoted strings are not matched.
DOCSTRING_NODE = m.SimpleStatementLine(
    body=[
        m.Expr(
            value=m.SimpleString(
                # match anything between """ """
                value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
            )
        )
    ]
)
def SUPER_CALL_NODE(func_name):
    """Build a matcher for the call expression `super().<func_name>(...)`, whatever its arguments."""
    bare_super = m.Call(func=m.Name("super"))
    bound_method = m.Attribute(value=bare_super, attr=m.Name(func_name))
    return m.Call(func=bound_method)
def is_call_to_super(node, func_name):
    """Return whether `node` is a statement line consisting only of `super().<func_name>(...)`, either as a bare
    expression statement or as `return super().<func_name>(...)`."""
    super_call = SUPER_CALL_NODE(func_name)
    lone_statement = m.Return(super_call) | m.Expr(super_call)
    return m.matches(node, m.SimpleStatementLine(body=[lone_statement]))
# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
class ReplaceMethodCallTransformer(cst.CSTTransformer):
    """Rewrite explicit calls through a base class into `super()` calls.

    For any `ClassB` in `all_bases`, both `ClassB.method` and `ClassB().method` become `super().method`, and when
    such a method is called with `self` as its first positional argument, that argument is dropped.
    """

    def __init__(self, all_bases: Set[str]):
        super().__init__()
        self.all_bases = all_bases

    def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
        """Swap the receiver `ClassB` / `ClassB(...)` of an attribute access for `super()`."""
        # Handle ClassB.call_to_method
        if (
            isinstance(original_node.value, cst.Name)
            and original_node.value.value in self.all_bases
            and isinstance(original_node.attr, cst.Name)
        ):
            # Replace with super().call_to_method
            return updated_node.with_changes(value=cst.Call(func=cst.Name("super")))
        # Handle ClassB().call_to_method
        if (
            isinstance(original_node.value, cst.Call)
            and isinstance(original_node.value.func, cst.Name)
            and original_node.value.func.value in self.all_bases
            and isinstance(original_node.attr, cst.Name)
        ):
            # Replace with super().call_to_method. `Attribute` has no `func` field: only its receiver (`value`)
            # changes, the accessed attribute is kept.
            return updated_node.with_changes(value=cst.Call(func=cst.Name("super")))
        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
        """Drop a leading `self` argument from `ClassB().func_a(self, ...)` / `ClassB.func_a(self, ...)` calls,
        since the call goes through `super()` after `leave_Attribute`."""
        # Check if the function being called is of the form ClassB().func_a or ClassB.func_a
        if isinstance(original_node.func, cst.Attribute) and (
            # Match ClassB().func_a(...)
            (
                isinstance(original_node.func.value, cst.Call)
                and isinstance(original_node.func.value.func, cst.Name)
                and original_node.func.value.func.value in self.all_bases
            )
            or
            # Match ClassB.func_a(...)
            (isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases)
        ):
            # Check if the first argument is 'self', and remove it
            if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
                # Create the new argument list without 'self'
                new_args = updated_node.args[1:]
            else:
                new_args = updated_node.args
            return updated_node.with_changes(args=new_args)
        return updated_node
def get_docstring_indent(docstring):
    """Return the length of the whitespace run right after the first opening fence of `docstring`.

    A fence is three double quotes, three single quotes or three backticks, and must be directly followed by a
    newline. The whitespace run may itself span blank lines (they are counted too). Returns 0 when no fence
    followed by a newline and whitespace is found.
    """
    first_indented_line = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
    return 0 if first_indented_line is None else len(first_indented_line.group(1))
def merge_docstrings(original_docstring, updated_docstring):
    """Combine a parent docstring with the docstring written in the modular file.

    Both arguments are the raw source of string literals (quotes and any `r` prefix included).

    - If `updated_docstring` contains an `Args:` section, it is returned unchanged.
    - Else, if `updated_docstring` contains a triple-backtick example and the original has at least one complete
      fenced block, the original's first fenced block is replaced by the updated example; the updated text preceding
      its example is placed after the original text preceding its own, and everything after the original's first
      fenced block is kept.
    - Else, if `updated_docstring` is not already contained in the original, it is appended to the original (shifted
      one indentation level when it looks like an argument description, i.e. `name (type):` followed by text).

    Raises:
        ValueError: if the updated docstring's example is not delimited by exactly one pair of fences.
    """
    original_level = get_docstring_indent(original_docstring)
    if not re.findall(r"\n\s*Args:\n", updated_docstring):
        # Split the docstring at the example section, assuming `"""` is used to define the docstring
        parts = original_docstring.split("```")
        # A complete fenced block (opening and closing fence) yields at least 3 parts.
        if "```" in updated_docstring and len(parts) >= 3:
            updated_docstring = updated_docstring.lstrip('r"')
            new_parts = updated_docstring.split("```")
            if len(new_parts) != 3:
                raise ValueError("There should only be one example, and it should have opening and closing '```'")
            parts[1] = new_parts[1]
            updated_docstring = "".join(
                [
                    parts[0].rstrip(" \n") + new_parts[0],
                    f"\n{original_level*' '}```",
                    parts[1],
                    "```",
                    # Keep everything after the replaced block, including any further fenced blocks.
                    "```".join(parts[2:]),
                ]
            )
        elif updated_docstring not in original_docstring:
            # add tabulation if we are at the lowest level: shift every line indented by 4 spaces one level deeper.
            if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
                updated_docstring = updated_docstring.replace("\n    ", "\n        ")
            updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')
    return updated_docstring
class SuperTransformer(cst.CSTTransformer):
    """Unroll explicit `super().<method>(...)` calls of a modular class into the parent class' code.

    For each visited method whose name is in `updated_methods`, a statement that is exactly
    `super().<same method>(...)` (bare or returned) is replaced by the parent's body for that method (taken from
    `original_methods`) merged with the modular method's statements by `update_body`. Method docstrings are merged
    with `merge_docstrings`, and `BaseClass.method(self, ...)` calls on any of `all_bases` become `super()` calls.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None):
        super().__init__()
        self.python_module = python_module  # used to render nodes back to source with `code_for_node`
        self.original_methods = original_methods  # parent class body: method name (or statement source) -> node
        self.updated_methods = updated_methods  # modular class body: method name (or statement source) -> node
        # Rendered assignment target -> overriding statement from the modular code.
        # NOTE(review): never reset, so entries accumulate across all visited methods — confirm this is intended.
        self.all_assign_target = {}
        self.deleted_targets = {}  # child node can delete some arguments: rendered `del` target -> `del` statement
        self.class_name = class_name
        self.all_bases = all_bases or []
        self.transformer = ReplaceMethodCallTransformer(set(self.all_bases))

    def update_body(self, existing_body, new_statements):
        """
        Helper method to update the body by removing duplicates before adding new statements.
        `existing_body` is the body of the original method, the parent class
        `new_statements` are the additional statements

        Parent assignments are overridden by modular assignments to the same target or dropped when the modular code
        `del`s that target. Remaining modular statements (except `del`s, `super().__init__` calls and statements whose
        comment-less source already appears) are inserted just before the last statement of the result.
        """
        deduplicated_new_body = []
        existing_nodes = set()
        for node in new_statements:
            if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(node.body[0].targets[0].target)
                self.all_assign_target[target] = node
            if m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                target = self.python_module.code_for_node(node.body[0].target)
                self.deleted_targets[target] = node
                continue

        for stmt in existing_body:
            if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
                if target in self.deleted_targets:
                    logger.warning(f"Deleted the assign for {target}")
                    continue
                if target in self.all_assign_target:
                    stmt = self.all_assign_target[target]
            # Comment-less, trailing-space-less source is the key used for de-duplication.
            comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            deduplicated_new_body.append(stmt)
            existing_nodes.add(comment_less_code)

        for node in new_statements:
            code = self.python_module.code_for_node(node)
            comment_less_code = re.sub(r"#.*", "", code).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            if (
                node not in deduplicated_new_body
                and "super().__init__" not in comment_less_code
                and comment_less_code not in existing_nodes
            ):
                if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                    # HACK here to fix the pos_init() that has to be last we kinda do this.
                    deduplicated_new_body = deduplicated_new_body[:-1] + [node] + deduplicated_new_body[-1:]
                    existing_nodes.add(comment_less_code)
        return deduplicated_new_body

    def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
        """Updates the body of the input `node`'s `func_name` function by replacing calls
        to super().func_name() with the source code of the parent class' `func_name`.
        It keeps everything that is defined before `super().func_name()`.

        A `super().func_name()` call is left as written when the parent class does not define `func_name`.
        """
        self.has_docstring = False
        parent_has_docstring = False
        if func_name in self.original_methods:
            parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
        new_body = []
        has_super_call = False
        for expr in node.body:
            # Unrolling needs the parent's implementation; without it there is nothing to inline.
            if is_call_to_super(expr, func_name) and func_name in self.original_methods:
                has_super_call = True
                new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
            else:
                expr = expr.visit(self.transformer)
                if m.matches(expr, DOCSTRING_NODE):
                    self.has_docstring = True
                    if parent_has_docstring:  # actually here we ought to de-duplicate?
                        original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
                        updated_docstring = expr.body[0].value.value
                        merged_doc = merge_docstrings(original_docstring, updated_docstring)
                        new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
                    else:
                        new_node = [expr]
                    new_body.extend(new_node)
                elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
                    # Statements after the unrolled super call are already part of `update_body`'s result.
                    new_body.append(expr)
        if not self.has_docstring and parent_has_docstring:
            new_body = [self.original_methods[func_name].body.body[0]] + new_body
        return node.with_changes(body=new_body)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        """Unroll super calls in every method redefined by the modular class."""
        if updated_node.name.value in self.updated_methods:
            name = updated_node.name.value
            new_body = self.replace_super_calls(updated_node.body, name)
            return updated_node.with_changes(body=new_body, params=updated_node.params)
        return updated_node

    def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
        """Rewrite `return <expr>.super(...)` so that its only argument is `super(<method name>)`.

        NOTE(review): `ParentNodeProvider` reports the enclosing statement line (not the `FunctionDef`) as the parent
        of a `Return`, so the `FunctionDef` check below looks like it never succeeds; confirm the intent.
        """
        if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
            func_def = self.get_metadata(ParentNodeProvider, original_node)
            if m.matches(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
                updated_return_value = updated_node.value.with_changes(
                    args=[
                        cst.Arg(
                            value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))])
                        )
                    ]
                )
                return updated_node.with_changes(value=updated_return_value)
        return updated_node
def replace_call_to_super(
    class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str, all_bases: List[str]
):
    """
    Merge the modular class `updated_node` into the parent class `class_name` found by `class_finder`, then unpack
    its calls to `super()`.

    The result is the parent's `ClassDef` node with its body rebuilt as:
    class docstring (merged with the modular one) + class-level assignments (modular values override the parent's
    for the same target) + methods.

    - A parent method redefined in the modular class takes the modular body, keeping the parent's decorators. When the
      modular signature uses `**super_kwargs`, the parameters become the parent's updated with the modular ones (first
      parameter skipped), and the parent's `**kwargs` parameter is kept.
    - A modular redefinition whose body is only `raise ...Error(...)` removes the method.
    - Methods only defined in the modular class are appended.
    - Finally `SuperTransformer` replaces e.g. `super().__init__()` by the parent's `__init__` body, so that
      `def __init__(self): super().__init__(); self.dropout = 0.2` becomes the parent's `__init__` body with
      `self.dropout = 0.2` inserted before its last statement (typically `self.post_init()`).
    """
    original_node = class_finder.classes[class_name]
    original_methods = {
        f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f
        for f in original_node.body.body
    }
    updated_methods = {
        f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f
        for f in updated_node.body.body
    }
    end_meth = []
    assign_targets = {}
    docstring_node = []
    # Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict
    for func in original_node.body.body:
        name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func)
        if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None:
            new_params = updated_methods[name].params
            # Replace the method in the replacement class, preserving decorators
            kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
            if kwarg_name and kwarg_name.name.value == "super_kwargs":
                parent_params = {k.name.value: k for k in func.params.params}
                parent_params.update({k.name.value: k for k in new_params.params[1:]})
                new_params = new_params.with_changes(
                    params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
                )
            # A "stub" redefinition (signature, optional return annotation, then `raise ...Error(`) removes the method.
            # `\s` keeps the check independent of the rendered indentation width and of leading blank lines.
            if not re.match(
                r"\s*def .*\(.*\)[^:\n]*:\n\s+raise.*Error\(.*",
                class_finder.python_module.code_for_node(updated_methods[name]),
            ):
                func = func.with_changes(body=updated_methods[name].body, params=new_params)
            else:
                continue

        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
            target = class_finder.python_module.code_for_node(func.body[0].targets[0])
            assign_targets[target] = func
        elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
            target = class_finder.python_module.code_for_node(func.body[0].target)
            assign_targets[target] = func
        elif m.matches(func, DOCSTRING_NODE):
            docstring_node = [func]
        else:
            end_meth.append(func)

    # Port new methods that are defined only in modular-file and append at the end
    for func in updated_node.body.body:
        name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func)
        if m.matches(func, DOCSTRING_NODE):  # This processes the docstring of the class!
            if docstring_node:
                # Merge with the parent's class docstring
                updated_docstring = func.body[0].value.value
                original_docstring = docstring_node[0].body[0].value.value
                merged_doc = merge_docstrings(original_docstring, updated_docstring)
                docstring_node = [
                    docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
                ]
            else:
                # The parent class has no docstring: use the modular one as-is.
                docstring_node = [func]
        if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
            end_meth.append(func)
        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
            # TODO we only use single assign might cause issues
            target = class_finder.python_module.code_for_node(func.body[0].targets[0])
            assign_targets[target] = func
        if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
            target = class_finder.python_module.code_for_node(func.body[0].target)
            assign_targets[target] = func
    end_meth = docstring_node + list(assign_targets.values()) + end_meth

    result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
    # Wrap the merged class in a throwaway module so `SuperTransformer` can resolve parent-node metadata.
    temp_module = cst.Module(body=[result_node])
    new_module = MetadataWrapper(temp_module)
    new_replacement_class = new_module.visit(
        SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases)
    )
    new_replacement_body = new_replacement_class.body[0].body  # get the indented block

    return original_node.with_changes(body=new_replacement_body)
# Maps a class-name suffix to a file-type prefix (e.g. a `...Config` class to "configuration").
# NOTE(review): the lookup is not in this section — presumably used to route generated classes to their
# `<prefix>_<model>.py` files; confirm at the call site.
TYPE_TO_FILE_TYPE = {
    "Config": "configuration",
    "Tokenizer": "tokenization",
    "Processor": "processing",
    "ImageProcessor": "image_processing",
    "FeatureExtractor": "feature_extractor",
}
def get_new_part(class_name, base_class):
    """
    Return, in snake_case, the part of `class_name` preceding the suffix it shares with `base_class`.

    When `MyClassNameAttention` inherits from `MistralAttention`, the shared suffix is `Attention` and the result is
    `my_class_name`. The shared suffix is computed character by character, not on word boundaries: `LlamaModel` vs
    `GemmaModel` share `maModel`, giving `lla`. With no shared trailing character the whole `class_name` is converted;
    when `class_name` is entirely a suffix of `base_class` the result is `""`.
    """
    shared = 0
    for own_char, base_char in zip(reversed(class_name), reversed(base_class)):
        if own_char != base_char:
            break
        shared += 1
    differing_head = class_name[: len(class_name) - shared]
    # Put "_" before every uppercase letter except the first character, then lowercase everything.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", differing_head).lower()
def find_all_dependencies(function: str, dependency_mapping: Dict[str, set]):
    """Return all the dependencies of the given top-level function. Given the following structure in the `modular_xxx.py` file:
    ```
    def foo1():
        pass

    def foo2():
        pass

    def bar():
        foo1()

    def foobar():
        bar()
        foo2()

    class MyLayer(SomeOtherModelLayer):
        def forward(...):
            foobar()
    ```
    and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get:
    ```
    dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}}
    find_all_dependencies('foobar', dependency_mapping)
    >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')]
    ```
    That is, all the functions needed (and their immediate parent) so that the function to be added in MyLayer (`foobar`) can
    work correctly. Siblings appear in the iteration order of their dependency set.

    Entities without dependencies (such as `foo1`/`foo2` above) may be absent from `dependency_mapping`; they are
    treated as leaves. Each entity, including `function` itself, is expanded at most once, so cycles terminate.
    """
    direct_dependencies = dependency_mapping.get(function, set())
    all_dependencies = deque(direct_dependencies)
    all_dependencies_with_parent = [(dep, function) for dep in direct_dependencies]
    # A set literal: `set(function)` would be the set of the name's characters.
    checked_dependencies = {function}
    while all_dependencies:
        # Pick element to visit
        parent = all_dependencies.popleft()
        if parent in checked_dependencies:
            continue
        # Update dependencies (leaves are usually not keys of the mapping)
        children = dependency_mapping.get(parent, set())
        all_dependencies.extend(children)
        all_dependencies_with_parent += [(child, parent) for child in children]
        # add visited node to the list
        checked_dependencies.add(parent)

    # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later)
    return all_dependencies_with_parent
class PostModularConverterCleaner(CSTTransformer):
    """Post-conversion cleanup pass.

    Drops top-level functions/classes that were injected as dependencies but are never referenced anywhere in the
    generated module. Such leftovers can appear through dependency mapping even when the code that used them was
    overwritten in the modular file.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self, added_dependencies: set):
        super().__init__()
        # name -> node, for every function/class defined directly in the module body
        self.top_level_functions_or_classes = {}
        # every `Name` value seen in the module, except the name token of a class/function header
        self.all_used_functions_or_classes = set()
        # names injected as dependencies during conversion; only these are candidates for removal
        self.added_dependencies = added_dependencies

    def _register_if_top_level(self, node):
        """Store `node` under its name when its parent is the module itself."""
        owner = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(owner, m.Module()):
            self.top_level_functions_or_classes[node.name.value] = node

    def visit_FunctionDef(self, node):
        self._register_if_top_level(node)

    def visit_ClassDef(self, node):
        self._register_if_top_level(node)

    def visit_Name(self, node: cst.Name):
        """Record any mention of a name, except the one that defines a class/function.

        Other, unrelated names are recorded as well but are simply never looked up. Collecting every `Name` is the
        most general approach, because a top-level function/class can be mentioned in many contexts besides a plain
        `Call`, e.g. Attention classes are only mentioned by their name in a dict assignment.
        """
        owner = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        defines_itself = isinstance(owner, (cst.ClassDef, cst.FunctionDef)) and owner.name.value == node.value
        if defines_itself:
            return
        self.all_used_functions_or_classes.add(node.value)

    def leave_Module(self, original_node: cst.Module, node):
        # Dependencies that were added but never mentioned afterwards are dead code
        never_used = self.added_dependencies - self.all_used_functions_or_classes
        nodes_to_remove = [
            self.top_level_functions_or_classes[dead_name]
            for dead_name in never_used
            if dead_name in self.top_level_functions_or_classes
        ]
        kept_statements = []
        for statement in original_node.body:
            if statement not in nodes_to_remove:
                kept_statements.append(statement)
        # Rebuild the module with only the surviving top-level statements
        return node.with_changes(body=kept_statements)
class ModularConverterTransformer(CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
    """Initialise the bookkeeping used while converting a `modular_xxx.py` file.

    Args:
        python_module: the parsed modular module; stored so that nodes can be rendered back with `code_for_node`.
        new_name: name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`.
        given_old_name: optional name override forwarded to `find_classes_in_file` when renaming parent-model code.
        given_new_name: optional name override forwarded to `find_classes_in_file` together with `given_old_name`.
    """
    super().__init__()
    self.model_name = new_name
    self.given_old_name = given_old_name
    self.given_new_name = given_new_name
    # fmt: off
    self.python_module = python_module  # we store the original module to use `code_for_node`
    self.transformers_imports = {}      # maps the imports name like "from transformers.models.xxx" to the parsed AST module
    self.imported_mapping = {}          # stores the name of the imported classes, with their source {"LlamaModel":"transformers.models.llama.modeling_llama"}
    self.visited_module = {}            # modules visited like "transformers.models.llama.modeling_llama"
    self.inserted_deps = []             # nodes inserted via super dependency
    self.all_imports = []               # just stores all of the imports
    self.all_safe_imports = []          # stores the import under simple statements
    self.global_scope_index = 0         # running insertion index for top-level nodes
    # fmt: on
    self.files = {  # mapping for different component bodies
        "modeling": {},
        "configuration": {},
        "tokenization": {},
        "processing": {},
        "image_processing": {},
        "feature_extractor": {},
    }
    # regex alternation of the file types above, used to recognise `<type>_<model>` module names in import paths
    self.match_patterns = "|".join(self.files.keys())
    self.all_definitions = {}
    self.class_to_file_type = {}
    self.current_class = None  # keep track of current top-level class during visit
    self.current_top_level_function = None  # keep track of current top-level function during visit
    # Mapping from top-level functions to classes using them
    self.function_call_class_mapping = defaultdict(set)
    # Mapping from top-level functions to other top-level functions dependencies
    self.function_call_dependency_mapping = defaultdict(set)
    self.added_dependencies = set()
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
    """Register `from ... import ...` statements that point at another model's file.

    When visiting imports from `transformers.models.xxx` (or the relative `..xxx.modeling_xxx` form) we need to:
    1. Get the original source code
    2. Parse it into an AST Tree
    3. Add this import to `self.transformers_imports` as visited to not parse it twice
    4. Record in `self.imported_mapping` the fully-qualified module each imported name comes from

    Raises:
        ValueError: if a name containing "Config" is imported from a modeling file, or if the modular file imports
            directly from the top-level `transformers` package.
    """
    if node.module is None:
        # `from . import xxx` has no module part: nothing to resolve (and `code_for_node(None)` would crash)
        return
    import_statement = self.python_module.code_for_node(node.module)
    if m.matches(node.module, m.Attribute()):
        # The match only depends on the module path, not on the imported names, so compute it once.
        _import = re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", import_statement)
        if _import:
            # Group 2 is the file type (e.g. "modeling"). Group 1 is the repeated path-prefix group, whose value is
            # just its last two-character repetition, so comparing it to "modeling" could never succeed.
            source = _import.group(2)
            # Normalise the path into the fully-qualified module name *before* the cache lookup, so that later
            # import statements from the same module hit the cache instead of re-parsing the file.
            if "models" not in import_statement:
                import_statement = "models." + import_statement
            if "transformers" not in import_statement:
                import_statement = "transformers." + import_statement
            for imported_ in node.names:
                imported_code = self.python_module.code_for_node(imported_)
                if source == "modeling" and "Config" in imported_code:
                    raise ValueError(
                        f"You are importing {imported_code} from the modeling file. Import from the `configuration_xxxx.py` file instead"
                    )
                if import_statement not in self.transformers_imports:
                    source_code = get_module_source_from_name(import_statement)
                    tree = cst.parse_module(source_code)
                    self.transformers_imports[import_statement] = tree
                imported_class = self.python_module.code_for_node(imported_.name)
                self.imported_mapping[imported_class] = import_statement
    if m.matches(node.module, m.Name()) and import_statement == "transformers":
        raise ValueError(
            f"You are importing from {import_statement} directly using global imports. Import from the correct local path"
        )
def leave_SimpleStatementLine(self, original_node, updated_node):
    """Handle module-level simple statements: imports and whitelisted assignments.

    - `import ...` lines are recorded once in `self.all_imports` and kept.
    - `from ... import ...` lines whose module path matches a `<file type>_` pattern (another model's file) are
      removed; other `from` imports are recorded once in `self.all_imports` and kept.
    - single-name assignments listed in `ASSIGNMENTS_TO_KEEP` are registered in the matching `self.files` body
      at the current `global_scope_index`, which then advances by 100.
    Statements that are not direct children of the module are returned unchanged.
    """
    parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
    if not m.matches(parent_node, m.Module()):
        return updated_node

    if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
        if updated_node not in self.all_imports:
            self.all_imports.append(updated_node)
        return updated_node

    if m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
        module = updated_node.body[0].module
        # `from . import xxx` has no module part, so it cannot point at another model file
        if module is not None:
            full_statement = self.python_module.code_for_node(module)
            if re.search(
                rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement
            ):  # OR MATCH ..llama.modeling_llama
                return cst.RemoveFromParent()
        if updated_node not in self.all_imports:
            self.all_imports.append(updated_node)
        return updated_node

    if m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
        target = original_node.body[0].targets[0].target
        # Only plain-name targets can be listed in ASSIGNMENTS_TO_KEEP. Tuple/list targets (`a, b = ...`) have no
        # `.value` attribute and used to raise AttributeError here.
        if m.matches(target, m.Name()) and target.value in ASSIGNMENTS_TO_KEEP:
            file_ = ASSIGNMENTS_TO_KEEP[target.value]
            self.files[file_][target.value] = {
                "node": original_node,
                "insert_idx": self.global_scope_index,
            }
            self.global_scope_index += 100
    return updated_node
def visit_ClassDef(self, node: cst.ClassDef):
    """Remember the name of the class being entered; `leave_ClassDef` resets it to None."""
    entered_name = node.name.value
    self.current_class = entered_name
def leave_ClassDef(self, original_node, updated_node):
    """Merge parent-model code into a class of the modular file and register the class for output.

    1. Filter the `base` classes of this class, keeping those imported from another model file (present in
       `self.imported_mapping`). For each of them:
       - take the AST tree of the module it comes from and parse it with a `ClassFinder` (cached per module in
         `self.visited_module`), renaming every instance of `old_name` (llama) to `new_name` (gemma);
       - if no dependency is found for this class, retry with two alternative renaming patterns.
    2. Insert the nodes the inherited base depends on. This has to be done in the order of the dependencies.
       If one is already in the target body (e.g. because it's defined in the modular file), its insertion index
       is reused so that this class is written after it.
    3. Replace the calls to `super().xxxx` by merging the parent code.
    Finally, the class is registered in `self.files` under the file type inferred from its name suffix
    (`TYPE_TO_FILE_TYPE`), defaulting to "modeling".

    Raises:
        ValueError: if the parent module path cannot be parsed, if no dependency can be found for the class with
            any renaming pattern, or if a direct dependency defined in the modular file is never called.
    """
    class_name = original_node.name.value
    # For plain bases (e.g. `LlamaModel`) `k.value.value` is the base's name; for dotted bases such as
    # `nn.Module` it is a node, which is never a key of `self.imported_mapping`.
    all_bases = [k.value.value for k in original_node.bases]
    # Only bases imported from another model file need merging. This filter also guarantees that every
    # `super_class` below has an entry in `self.imported_mapping`.
    bases = [base for base in all_bases if base in self.imported_mapping]
    self.global_scope_index += 100

    def _dependencies_with_start_line(finder):
        # dependency name -> `class_start_line` value (1000 when unknown), used below to order the insertions
        return {
            dep: finder.class_start_line.get(dep, 1000)
            for dep in finder.class_dependency_mapping.get(class_name, [])
        }

    for super_class in bases:
        super_file_name = self.imported_mapping[super_class]  # we need to get the parsed tree
        # e.g. "transformers.models.llama.modeling_llama" -> file type "modeling", parent model name "llama".
        # The known file types are matched explicitly. The previous lazy `(\w*?)_` stopped at the first
        # underscore, which gave file type "image" (not a key of `self.files`) and model name "processing_xxx"
        # for `image_processing_xxx` modules.
        path_match = re.search(rf"models\.\w*?\.({self.match_patterns})_(\S*)", super_file_name)
        if path_match is None:
            raise ValueError(
                f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
            )
        file_type, model_name = path_match.groups()

        visited_module = self.visited_module
        if super_file_name not in visited_module:  # only extract classes once
            class_finder = find_classes_in_file(
                self.transformers_imports[super_file_name],
                model_name,
                self.model_name,
                self.given_old_name,
                self.given_new_name,
            )
            visited_module[super_file_name] = class_finder
        else:  # we are re-using the previously parsed data
            class_finder = visited_module[super_file_name]
        list_dependencies = _dependencies_with_start_line(class_finder)

        if len(list_dependencies) == 0:
            # so, maybe standard renaming did not work (the class name is different)
            # we try with another renaming pattern
            potential_given_name = get_new_part(class_name, super_class)
            # NOTE(review): the cache entry is dropped here and only re-set by the last-resort attempt below, which
            # then caches a finder renamed for this specific class. When this second attempt succeeds, the next
            # class inheriting from the same module re-parses it with the default renaming.
            del visited_module[super_file_name]
            class_finder = find_classes_in_file(
                self.transformers_imports[super_file_name],
                model_name,
                potential_given_name,
                self.model_name,
                potential_given_name,
            )
            list_dependencies = _dependencies_with_start_line(class_finder)
            if len(list_dependencies) == 0:
                # last recourse, if the suffix of the new class is different from the one of the super class
                # e.g. MyNewClassForSegmentation extends MyOldClassForObjectDetection
                # we try with another renaming pattern
                class_finder = find_classes_in_file(
                    self.transformers_imports[super_file_name],
                    model_name,
                    self.model_name,
                    self.given_old_name,
                    self.given_new_name,
                    super_class,
                    class_name,
                )
                visited_module[super_file_name] = class_finder
                list_dependencies = _dependencies_with_start_line(class_finder)
                if len(list_dependencies) == 0:
                    raise ValueError(
                        f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
                        f" Here are all the global dependencies that we found in your modular file: {list(class_finder.class_dependency_mapping.keys())}."
                        f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
                    )

        # Visit dependencies from the largest start line downwards while decrementing the insertion index, so the
        # inserted nodes keep the relative order they had in the parent file (assuming start lines are line numbers).
        list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
        start_insert_idx = self.global_scope_index
        file_to_update = self.files[file_type]
        is_empty_node = self.python_module.code_for_node(original_node.body) == "pass\n"
        for dependency, _ in list_dependencies:
            # we can write to the correct body, using the source of the parent class
            node = class_finder.global_nodes.get(dependency, None)
            if node is not None:
                if dependency not in file_to_update:
                    # use the node stored in `self.all_definitions` when present (presumably the modular file's own
                    # definition), otherwise the parent's copy
                    node = self.all_definitions.pop(dependency, node)
                    start_insert_idx -= 1
                    file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node}
                    self.added_dependencies.add(dependency)
                elif dependency not in self.inserted_deps:
                    # make sure the node is written after its dependencies
                    start_insert_idx = file_to_update[dependency]["insert_idx"] - 1
                if (
                    dependency in file_to_update
                    and dependency in class_finder.first_lvl_dependency_mapping[class_name]
                ):
                    # If dependency is defined, but not used, raise error
                    calls = m.findall(original_node, m.Call(func=m.Name(dependency)))
                    if not calls and not is_empty_node and dependency not in all_bases:
                        raise ValueError(
                            f"You defined `{dependency}` in the modular_{self.model_name}.py, it should be used"
                            f" when you define `{class_name}`, as it is one of its direct dependencies. Make sure"
                            " you use it in the `__init__` function."
                        )
            self.inserted_deps.append(dependency)

        if len(list_dependencies) > 0:
            updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases)

    # Register the class (whether or not it has parents) in the file chosen by its name suffix
    match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
    match = re.search(rf"({match_pattern})$", class_name)
    key = TYPE_TO_FILE_TYPE[match.group(1)] if match else "modeling"
    self.class_to_file_type[class_name] = key
    self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
    self.current_class = None
    return updated_node
def visit_FunctionDef(self, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(parent_node, m.Module()):
self.current_top_level_function = node.name.value