-
Notifications
You must be signed in to change notification settings - Fork 27.8k
/
Copy pathmodular_model_converter.py
1511 lines (1318 loc) · 76.9 KB
/
modular_model_converter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import importlib
import os
import re
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from typing import Dict, Set
import libcst as cst
from check_copies import run_ruff
from create_dependency_mapping import find_priority_list
from libcst import ClassDef, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider
from transformers import logging
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)
# Header text for files generated from a modular file. `{relative_path}` and `{short_name}` are placeholders;
# presumably they are filled with `str.format` by the code that writes the generated files (not visible here).
AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from {relative_path}.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# {short_name} file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""
def get_module_source_from_name(module_name: str) -> str:
    """Return the source code of the module importable as `module_name`.

    If the module cannot be located, or has no source file on disk (e.g. built-in or frozen modules), the string
    `"Module {module_name} not found"` is returned instead of raising.
    """
    # `import importlib` alone does not guarantee that the `importlib.util` submodule is loaded.
    import importlib.util

    try:
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:
        # `find_spec("a.b")` imports `a` first and raises if that parent package does not exist.
        spec = None
    # `has_location` is False for specs whose `origin` is not a file path (e.g. "built-in", "frozen").
    if spec is None or spec.origin is None or not spec.has_location:
        return f"Module {module_name} not found"
    with open(spec.origin, "r", encoding="utf-8") as file:
        return file.read()
def preserve_case_replace(text, patterns: dict, default_name: str):
    """Replace every case-insensitive occurrence of a key of `patterns` in `text`.

    A match whose exact spelling is a key of `patterns` is replaced by the corresponding value. Any other casing
    (e.g. `LLaMa` when only `llama`, `LLAMA` and `Llama` are keys) is replaced by `default_name`.
    Longer keys are tried first, so a key that extends another one (e.g. `Llama3` vs `llama`) is not shadowed.
    """
    if not patterns:
        # An empty alternation would match the empty string at every position of `text`.
        return text
    # Regex alternation is leftmost-first, so order the alternatives longest-first (stable for equal lengths).
    ordered_keys = sorted(patterns, key=len, reverse=True)
    regex_pattern = "|".join(re.escape(key) for key in ordered_keys)
    compiled_regex = re.compile(regex_pattern, re.IGNORECASE)

    def replace(match):
        return patterns.get(match.group(0), default_name)

    return compiled_regex.sub(replace, text)
def convert_to_camelcase(text, old_name: str, default_old_name: str):
    """Normalize the casing of a leading `old_name` prefix of `text` to `default_old_name`.

    The prefix is matched case-insensitively, and only when it is immediately followed by a letter. For example,
    `LLaMaModel` becomes `LlamaModel` for `old_name="llama"` and `default_old_name="Llama"`, while a bare `llama`
    is left unchanged. `old_name` is matched literally (regex metacharacters are escaped).
    """
    pattern = rf"^({re.escape(old_name)})(?=[a-z]+)"
    # A callable replacement keeps `default_old_name` literal (no backslash/group-reference interpretation).
    return re.sub(pattern, lambda _match: default_old_name, text, flags=re.IGNORECASE, count=1)
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
    """A transformer that replaces `old_name` with `new_name` in comments, strings and any references.

    It should take into account names like `MyNewModel`, or `my_new_model`, without using the AUTO_MAPPING.
    Supported renaming patterns:
        - llama -> my_new_model and my_new_model -> llama
        - Llama -> MyNewModel and MyNewModel -> Llama
        - LLAMA -> MY_NEW_MODEL and MY_NEW_MODEL -> LLAMA
        - LLaMa -> MyNewModel and MyNewModel -> Llama
    Comments (and names/strings) containing `# Copied from` are removed.
    """

    def __init__(
        self,
        old_name,
        new_name,
        given_old_name=None,
        given_new_name=None,
    ):
        super().__init__()
        self.old_name = old_name
        self.new_name = new_name
        self.default_name = "".join(x.title() for x in new_name.split("_"))
        if self.new_name in CONFIG_MAPPING_NAMES:
            # The config class name is the best source of truth for the CamelCase model prefix.
            self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace("Config", "")
        self.patterns = {
            old_name: new_name,
            old_name.upper(): new_name.upper(),
            "".join(x.title() for x in old_name.split("_")): self.default_name,
        }
        if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
            self.patterns[given_old_name] = given_new_name
        # Fallback so that `leave_ClassDef` also works for model names absent from CONFIG_MAPPING_NAMES
        # (previously the attribute was simply missing in that case).
        self.default_old_name = "".join(x.title() for x in old_name.split("_"))
        if self.old_name in CONFIG_MAPPING_NAMES:
            self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "")
            if self.default_old_name.isupper():
                self.default_old_name = self.default_old_name.capitalize()

    @m.leave(m.Name() | m.SimpleString() | m.Comment())
    def replace_name(self, original_node, updated_node):
        """Rename occurrences of the old model name in names, string literals and comments."""
        if "# Copied from" in updated_node.value:
            return cst.RemoveFromParent()
        update = preserve_case_replace(updated_node.value, self.patterns, self.default_name)
        return updated_node.with_changes(value=update)

    def leave_ClassDef(self, original_node, updated_node):
        """Normalize the casing of an old-name prefix in class names (e.g. `LLaMaModel` -> `LlamaModel`)."""
        new_name = convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name)
        return updated_node.with_changes(name=cst.Name(new_name))
# Matcher for a statement line made of a single string expression containing a triple-double-quoted block,
# i.e. a `"""..."""` docstring.
DOCSTRING_NODE = m.SimpleStatementLine(
    body=[
        m.Expr(
            value=m.SimpleString(
                value=m.MatchIfTrue(lambda literal: bool(re.search(r"\"\"\"[\s\S]*\"\"\"", literal)))
            )
        )
    ]
)
def SUPER_CALL_NODE(func_name):
    """Build a matcher for the call expression `super().<func_name>(...)`."""
    super_call = m.Call(func=m.Name("super"))
    method_access = m.Attribute(value=super_call, attr=m.Name(func_name))
    return m.Call(func=method_access)
def is_call_to_super(node, func_name):
    """Return True if `node` is a statement line consisting only of `super().<func_name>(...)` or
    `return super().<func_name>(...)`."""
    super_call = SUPER_CALL_NODE(func_name)
    statement = m.Return(super_call) | m.Expr(super_call)
    return m.matches(node, m.SimpleStatementLine(body=[statement]))
# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
class ReplaceMethodCallTransformer(cst.CSTTransformer):
    """Rewrite explicit references to a base class into `super()` accesses.

    For any `ClassB` in `all_bases`, both `ClassB.method` and `ClassB().method` become `super().method`. When such an
    attribute is called, a leading `self` argument is dropped (`ClassB.f(self, x)` -> `super().f(x)`).
    """

    def __init__(self, all_bases: Set[str]):
        super().__init__()
        self.all_bases = all_bases

    def _is_base_reference(self, node) -> bool:
        """Whether `node` is `ClassB` or a call `ClassB(...)` for some `ClassB` in `self.all_bases`."""
        if isinstance(node, cst.Name):
            return node.value in self.all_bases
        return isinstance(node, cst.Call) and isinstance(node.func, cst.Name) and node.func.value in self.all_bases

    def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
        # Handles both `ClassB.call_to_method` and `ClassB().call_to_method`.
        if self._is_base_reference(original_node.value) and isinstance(original_node.attr, cst.Name):
            # Replace the object with `super()`. The previous `ClassB()` branch passed a `func` field to
            # `Attribute.with_changes`, but `Attribute` has no such field, so that rewrite always failed.
            return updated_node.with_changes(value=cst.Call(func=cst.Name("super")))
        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
        # Match `ClassB().func_a(...)` or `ClassB.func_a(...)`; once the object becomes `super()`, `self` must go.
        if isinstance(original_node.func, cst.Attribute) and self._is_base_reference(original_node.func.value):
            if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
                return updated_node.with_changes(args=updated_node.args[1:])
        return updated_node
def get_docstring_indent(docstring):
    """Return the length of the whitespace run right after the first opening triple-quote (or triple-backtick) line.

    Only a delimiter immediately followed by a newline counts. The whitespace run may itself contain further
    newlines, so a blank first line adds to the count. Returns 0 when no such delimiter is found.
    """
    indent_match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
    return 0 if indent_match is None else len(indent_match.group(1))
def merge_docstrings(original_docstring, updated_docstring):
    """Merge a docstring from the modular file (`updated_docstring`) into the parent's one (`original_docstring`).

    Both arguments are the source text of string literals, quotes and any `r` prefix included (the caller passes
    `SimpleString.value`), which is why quote characters are stripped below.

    - If `updated_docstring` contains an `Args:` section, it is returned unchanged (it fully replaces the original).
    - Otherwise, if both contain a ``` code fence, the result is: the original text before its first fence (trailing
      spaces/newlines removed) + the updated text before its fence + a fence indented like the original docstring +
      the updated example + the original text after the original example's closing fence.
    - Otherwise, if `updated_docstring` is not already a substring of `original_docstring`, it is appended to the
      original (whose trailing quotes are stripped).
    - Otherwise `updated_docstring` is returned as is.

    Raises:
        ValueError: if the updated docstring's fences do not split it into exactly three parts.
    """
    # indent_level = get_docstring_indent(updated_docstring)
    original_level = get_docstring_indent(original_docstring)
    if not re.findall(r"\n\s*Args:\n", updated_docstring):
        # Split the docstring at the example section, assuming `"""` is used to define the docstring
        parts = original_docstring.split("```")
        if "```" in updated_docstring and len(parts) > 1:
            # Strips any leading `r` / `"` characters (the literal prefix and opening quotes).
            updated_docstring = updated_docstring.lstrip('r"')
            new_parts = updated_docstring.split("```")
            if len(new_parts) != 3:
                raise ValueError("There should only be one example, and it should have opening and closing '```'")
            parts[1] = new_parts[1]
            # NOTE(review): `parts[2]` raises IndexError if the original has a single unmatched fence, and any
            # further fenced blocks in the original (`parts[3:]`) as well as the text after the updated example's
            # closing fence (`new_parts[2]`) are dropped — confirm this is intended.
            updated_docstring = "".join(
                [
                    parts[0].rstrip(" \n") + new_parts[0],
                    f"\n{original_level*' '}```",
                    parts[1],
                    "```",
                    parts[2],
                ]
            )
        elif updated_docstring not in original_docstring:
            # add tabulation if we are at the lowest level.
            # NOTE(review): in this copy of the file both arguments of `replace` below are identical, so it is a
            # no-op; the intended whitespace was probably lost — verify against the upstream source.
            if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
                updated_docstring = updated_docstring.replace("\n ", "\n ")
            updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')
    return updated_docstring
class SuperTransformer(cst.CSTTransformer):
    """Inline the parent implementation in place of `super().<method>(...)` statements.

    For every method whose name is in `updated_methods`, a statement that is exactly `super().<method>(...)` (or
    `return super().<method>(...)`) is replaced by the body of the parent's method (from `original_methods`), merged
    with the statements that follow the call. Statements before the call are kept (with explicit base-class calls
    rewritten to `super()` calls), and docstrings are merged.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None):
        super().__init__()
        self.python_module = python_module
        # presumably method name -> parent `cst.FunctionDef` (values are accessed via `.body.body`); TODO confirm
        self.original_methods = original_methods
        self.updated_methods = updated_methods
        self.all_assign_target = {}  # assignment target code -> child statement, for the method being processed
        self.deleted_targets = {}  # child node can delete some arguments
        self.all_bases = all_bases or []
        self.transformer = ReplaceMethodCallTransformer(set(self.all_bases))

    def _comment_less_code(self, node) -> str:
        """Source of `node` with comments and spaces before newlines stripped, used as a de-duplication key."""
        code = re.sub(r"#.*", "", self.python_module.code_for_node(node)).strip()
        return re.sub(r"\ *\n", "\n", code).strip()

    def update_body(self, existing_body, new_statements):
        """Merge the parent's method body with the child's statements that follow the `super()` call.

        `existing_body` is the body of the original method, from the parent class.
        `new_statements` are the additional statements from the child.
        Parent assignments re-assigned by the child are replaced by the child's statement. Parent assignments whose
        target the child `del`s are dropped. Duplicate statements (ignoring comments) and `del` statements are not
        added, and `self.post_init(...)` is moved last.
        """
        # Start fresh for every method: assignments/deletions recorded for one method must not affect another one.
        self.all_assign_target = {}
        self.deleted_targets = {}
        deduplicated_new_body = []
        existing_nodes = set()
        for node in new_statements:
            if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(node.body[0].targets[0].target)
                self.all_assign_target[target] = node
            if m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                target = self.python_module.code_for_node(node.body[0].target)
                self.deleted_targets[target] = node

        for stmt in existing_body:
            if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
                if target in self.deleted_targets:
                    logger.warning(f"Deleted the assign for {target}")
                    continue
                if target in self.all_assign_target:
                    # The child re-assigns this target: use its statement in place of the parent's.
                    stmt = self.all_assign_target[target]
            # Skip the docstring (will be added later on, at the beginning)
            elif m.matches(stmt, DOCSTRING_NODE):
                continue
            deduplicated_new_body.append(stmt)
            existing_nodes.add(self._comment_less_code(stmt))

        for node in new_statements:
            comment_less_code = self._comment_less_code(node)
            if node not in deduplicated_new_body and comment_less_code not in existing_nodes:
                if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                    deduplicated_new_body.append(node)
                    existing_nodes.add(comment_less_code)

        return self._fix_post_init_location(deduplicated_new_body)

    def _fix_post_init_location(self, new_body: list[cst.CSTNode]):
        """Fix the location of the `post_init()` in the new body, if we added statements after the call to
        `super()` (it needs to be the very last statement called). Mutates and returns `new_body`."""
        for i, node in enumerate(new_body):
            if "self.post_init(" in self._comment_less_code(node) and i < len(new_body) - 1:
                # Remove it and add it again at the end
                new_body.pop(i)
                new_body.append(node)
                break
        return new_body

    def _fix_init_location(self, new_body):
        """Move `super().__init__` right after the leading docstring(s), if new statements were placed before it."""
        start_index = 0
        for i, node in enumerate(new_body):
            if m.matches(node, DOCSTRING_NODE) and i == start_index:
                start_index += 1
                continue
            if "super().__init__" in self._comment_less_code(node) and i > start_index:
                # Remove it and add it again at the top after the docstrings
                node = new_body.pop(i)
                new_body = new_body[:start_index] + [node] + new_body[start_index:]
                break
        return new_body

    def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.IndentedBlock:
        """Updates the body of the input `node`'s `func_name` function by replacing calls
        to super().func_name() with the source code of the parent class' `func_name`.
        It keeps everything that is defined before `super().func_name()`.
        """
        self.has_docstring = False
        parent_has_docstring = False
        if func_name in self.original_methods:
            parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
        new_body = []
        has_super_call = False

        for i, expr in enumerate(node.body):
            if is_call_to_super(expr, func_name):
                has_super_call = True
                new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :]))
                new_body = self._fix_init_location(new_body)
            else:
                expr = expr.visit(self.transformer)
            if m.matches(expr, DOCSTRING_NODE):
                self.has_docstring = True
                if parent_has_docstring:  # actually here we ought to de-duplicate?
                    original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
                    updated_docstring = expr.body[0].value.value
                    merged_doc = merge_docstrings(original_docstring, updated_docstring)
                    new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
                else:
                    new_node = [expr]
                new_body.extend(new_node)
            # Statements after the super call were already merged by `update_body`; `del`s are never kept.
            elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
                new_body.append(expr)

        if not self.has_docstring and parent_has_docstring:
            new_body = [self.original_methods[func_name].body.body[0]] + new_body
        return node.with_changes(body=new_body)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        """Unroll `super()` calls in the methods listed in `self.updated_methods`."""
        if updated_node.name.value in self.updated_methods:
            name = updated_node.name.value
            new_body = self.replace_super_calls(updated_node.body, name)
            return updated_node.with_changes(body=new_body)
        return updated_node

    def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
        """Rewrite `return <obj>.super(...)` so that its arguments become `super(<enclosing method name>)`.

        NOTE(review): the parent node of a `Return` is normally its statement line, not a `FunctionDef`, so the
        condition below appears never to hold in practice — verify before relying on this rewrite.
        """
        if updated_node.value is not None and m.matches(
            updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))
        ):
            func_def = self.get_metadata(ParentNodeProvider, original_node)
            # `m.matches` (the previous `m.matched` does not exist in libcst.matchers)
            if m.matches(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
                updated_return_value = updated_node.value.with_changes(
                    args=[
                        cst.Arg(
                            value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))])
                        )
                    ]
                )
                return updated_node.with_changes(value=updated_return_value)
        return updated_node
def find_all_dependencies(
dependency_mapping: Dict[str, set],
start_entity: str | None = None,
initial_dependencies: set | None = None,
initial_checked_dependencies: set | None = None,
return_parent: bool = False,
) -> list | set:
"""Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of
BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`.
Args:
dependency_mapping (`Dict[str, set]`):
A mapping from entities (usually function/assignment names), to immediate dependencies. That is, for function names,
a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called
in `foo`'s definition.
start_entity (str | None, *optional*):
A key of `dependency_mapping`, indicating from which entity to start the search.
initial_dependencies (set | None, *optional*):
If `start_entity` is not provided, this can be used as an alternative. In this case, the search will continue
from all the entities in `initial_dependencies`, if they are in `dependency_mapping`.
initial_checked_dependencies (set | None, *optional*):
If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies.
return_parent (bool, *optional*):
If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note
that the order of the items in the list reflects the traversal order. Thus, no parent can ever appear before childs.
Returns:
A set of all the dependencies, or a list of tuples `(dependency, parent)` if `return_parent=True`.
Example:
Given the following structure in the `modular_xxx.py` file:
```
def foo1():
pass
def foo2():
pass
def bar():
foo1()
def foobar():
bar()
foo2()
class MyLayer(SomeOtherModelLayer):
def forward(...):
foobar()
```
and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get:
```
dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}}
find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True)
>>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')]
```
That is, all the functions needed (and potentially their immediate parent) so that the function to be added
in MyLayer (`foobar`) can work correctly.
"""
if initial_dependencies is None and start_entity is not None:
initial_dependencies = dependency_mapping[start_entity]
if initial_checked_dependencies is None:
initial_checked_dependencies = set()
dependency_queue = deque(initial_dependencies)
all_dependencies = set()
all_dependencies_with_parent = []
checked_dependencies = set(initial_checked_dependencies)
parents = {initial_dep: start_entity for initial_dep in initial_dependencies}
while len(dependency_queue) > 0:
# Pick element to visit
current = dependency_queue.popleft()
if current not in checked_dependencies:
# Add the dependencies
all_dependencies.add(current)
all_dependencies_with_parent += [(current, parents[current])]
if current in dependency_mapping.keys():
# Update dependency queue
dependency_queue.extend(dependency_mapping[current])
parents.update({dep: current for dep in dependency_mapping[current]})
# add visited node to the list
checked_dependencies.add(current)
if not return_parent:
return all_dependencies
# no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later)
return all_dependencies_with_parent
# Names of top-level assignments whose value is always taken from the `modular_xxx.py` file.
ASSIGNMENTS_TO_KEEP = {"_CHECKPOINT_FOR_DOC"}
class ClassDependencyMapper(CSTVisitor):
    """Visitor collecting the names referenced inside a single class node that also belong to `global_names`.

    The class's own name and any name in `objects_imported_from_modeling` are never recorded.
    Results accumulate in `self.dependencies`.
    """

    def __init__(
        self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None
    ):
        super().__init__()
        self.class_name = class_name
        self.dependencies = set()
        self.global_names = global_names
        if objects_imported_from_modeling is None:
            objects_imported_from_modeling = set()
        self.objects_imported_from_modeling = objects_imported_from_modeling

    def visit_Name(self, node):
        referenced = node.value
        if referenced == self.class_name:
            return
        if referenced in self.global_names and referenced not in self.objects_imported_from_modeling:
            self.dependencies.add(referenced)
def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set:
    """Return the names from `global_names` that `node` references directly (its own name excluded)."""
    collector = ClassDependencyMapper(node.name.value, global_names)
    # The class node is wrapped in a throwaway module so it can be visited on its own.
    cst.Module(body=[node]).visit(collector)
    return collector.dependencies
def augmented_dependencies_for_class_node(
    node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None
) -> set:
    """Return the augmented dependencies of a class node, resolved against `mapper`.

    "Augmented" means the immediate dependencies (names of `mapper.global_nodes` used in the class, excluding
    `objects_imported_from_modeling`), plus the recursive dependencies of the functions/assignments among them.
    """
    known_globals = set(mapper.global_nodes)
    collector = ClassDependencyMapper(node.name.value, known_globals, objects_imported_from_modeling)
    cst.Module(body=[node]).visit(collector)
    immediate_dependencies = collector.dependencies
    return mapper.augment_dependencies(immediate_dependencies)
# Every file-type prefix a generated file can have (also joined into the relative-import regex of `ModuleMapper`).
# NOTE(review): transformers' feature-extraction files are usually named `feature_extraction_xxx.py`, so the
# `feature_extractor` entry may never match an import — confirm against the rest of this file before changing it.
ALL_FILE_TYPES = ("modeling", "configuration", "tokenization", "processing", "image_processing", "feature_extractor")
class ModuleMapper(CSTVisitor, ABC):
    """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions
    and assignments.

    Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies
    are stored in `self.object_recursive_dependency_mapping` (can be computed by
    `_compute_recursive_object_dependencies()`).
    It defines common visiting patterns (i.e. common visit_xxx/leave_xxx functions) between the modular file and the
    modeling files that will be visited.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider)

    def __init__(self, python_module: cst.Module):
        # Initialize the libcst `CSTVisitor` base class.
        super().__init__()
        # fmt: off
        self.python_module: cst.Module = python_module  # original cst.Module being visited
        self.classes: Dict[str, cst.ClassDef] = {}  # mapping from class names to Nodes (it will be ordered by default!!)
        self.imports = []  # stores all import statements
        self.functions: Dict[str, cst.FunctionDef] = {}  # mapping of global scope function names to Nodes
        self.object_dependency_mapping = defaultdict(set)  # immediate function/assignment dependency mapping (i.e. dependencies immediately in the function/assignment definition)
        self.assignments: Dict[str, cst.SimpleStatementLine] = {}  # mapping of global assignments names to Nodes
        self.current_function = None  # this keeps track of the current module-scope function
        self.current_assignment = None  # this keeps track of the current module-scope assignment
        # this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency
        self.objects_imported_from_modeling = set()
        # regex pattern joining every possible file type
        self.match_patterns = "|".join(ALL_FILE_TYPES)
        # fmt: on

    def visit_ImportFrom(self, node):
        """This keeps track of objects imported from neighbor modeling files (e.g. in `modeling_xxx.py`, we have
        `from .configuration_xxx import Config`, then `Config` should be recorded as it is not a dependency that needs
        to be added (because it will be part of the imports)."""
        # `from . import x` has no module part (it cannot match the pattern anyway), and `from .x import *` has no
        # explicit names to record; both previously crashed below.
        if node.module is None or isinstance(node.names, cst.ImportStar):
            return
        import_module = self.python_module.code_for_node(node.module)
        import_statement = "." * len(node.relative) + import_module
        if re.search(rf"^\.({self.match_patterns})_.*", import_statement):
            for imported_object in node.names:
                # If an alias is present, we record it and not the original name
                if imported_object.evaluated_alias is not None:
                    self.objects_imported_from_modeling.add(imported_object.evaluated_alias)
                else:
                    self.objects_imported_from_modeling.add(imported_object.evaluated_name)

    def visit_SimpleStatementLine(self, node):
        """
        Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements
        are extracted and saved in their corresponding dict. They are then used when updating dependency mappings.
        """
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        simple_top_level_assign_structure = m.SimpleStatementLine(
            body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
        )
        if m.matches(parent_node, m.Module()):
            if m.matches(node, simple_top_level_assign_structure):
                left_hand_side = node.body[0].targets[0].target.value
                self.current_assignment = left_hand_side
                self.assignments[left_hand_side] = node
            elif m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
                self.imports.append(node)

    def leave_SimpleStatementLine(self, node):
        # No need to check for the parent here -> everytime we exit one, it should be None anyway independently of where the
        # SimpleStatement is located
        self.current_assignment = None

    def visit_FunctionDef(self, node):
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.current_function = node.name.value
            self.functions[node.name.value] = node

    def leave_FunctionDef(self, node):
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.current_function = None

    def visit_If(self, node):
        """Record an `if` block as an import when its body directly contains at least one import statement."""
        # Append the block once, even if it contains several imports (it used to be appended once per import).
        if any(
            m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])) for stmt in node.body.body
        ):
            self.imports.append(node)

    def visit_ClassDef(self, node: ClassDef) -> None:
        """Record class nodes to create their dependencies at the end."""
        self.classes[node.name.value] = node

    def visit_Name(self, node: cst.Name):
        """This is used to create a mapping from module-scope functions and assignments to objects used inside them."""
        if self.current_function is not None:
            self.object_dependency_mapping[self.current_function].add(node.value)
        if self.current_assignment is not None:
            self.object_dependency_mapping[self.current_assignment].add(node.value)

    def leave_Module(self, node):
        """When leaving the module, we store the position of each global scoped node to allow sorting the dependencies
        based on their position in the code later. We use the PositionProvider metadata wrapper for this.
        We also make sure to update `self.object_dependency_mapping` so that it contains only names recorded in
        `self.global_nodes`.
        """
        # assign all nodes
        self.global_nodes = {**self.assignments, **self.classes, **self.functions}
        # now sort the class dependency_mapping based on the position of the nodes
        self.start_lines = {}
        for name, global_node in self.global_nodes.items():
            self.start_lines[name] = self.get_metadata(cst.metadata.PositionProvider, global_node).start.line

        # Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that
        # are not part of the recorded objects (i.e. built-in variables, imports, etc)
        global_objects = set(self.global_nodes.keys())
        for object_name, dependencies in self.object_dependency_mapping.items():
            self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects}

    def _compute_recursive_object_dependencies(self) -> dict[str, set]:
        """Based on immediate dependency mapping, create the recursive dependency mapping. For example, given the
        following file:
        ```
        def foo():
            pass

        def bar():
            foo()

        def test():
            bar()
        ```
        this visitor can only record immediate dependencies, i.e. it will record the following
        `self.object_dependency_mapping = {"test": {"bar"}, "bar": {"foo"}}`. This function is used to create
        the recursive mapping, i.e. `recursive_dependencies = {"test": {"bar", "foo"}, "bar": {"foo"}}`.
        """
        recursive_dependencies = {}
        for object_name in self.object_dependency_mapping.keys():
            all_dependencies = find_all_dependencies(self.object_dependency_mapping, start_entity=object_name)
            recursive_dependencies[object_name] = all_dependencies
        return recursive_dependencies

    def augment_dependencies(self, dependencies: set[str]) -> set[str]:
        """For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** and
        **assignments** present in the `dependencies`. Requires `self.object_recursive_dependency_mapping` to be set.
        """
        new_dependencies = dependencies.copy()
        # Go through the set of dependencies
        for dep in tuple(dependencies):
            if dep in self.object_recursive_dependency_mapping.keys():
                new_dependencies.update(self.object_recursive_dependency_mapping[dep])
        return new_dependencies

    def compute_class_dependencies(self):
        """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies."""
        self.class_dependency_mapping = {}
        for class_name, class_node in self.classes.items():
            dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys()))
            # Correctly augment class dependencies with all needed objects
            self.class_dependency_mapping[class_name] = self.augment_dependencies(dependencies)

    @abstractmethod
    def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
        raise NotImplementedError
class ModelFileMapper(ModuleMapper):
    """A mapper designed to parse modeling files (like `modeling_llama.py`). When encountering such a file
    in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file.
    For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes
    care of correctly merging dependencies, then finalizes all dependency graph computations.
    Note that we only merge functions and assignments here, as classes will be treated later on as they may be modified.
    For example, if you redefine `apply_rotary_pos_emb()` in the modular, the new node should be used in the dependencies
    of the modeling files as well.
    """

    def __init__(self, python_module: cst.Module):
        super().__init__(python_module)

    def _sort_by_original_position(self, names) -> list[str]:
        """Sort `names` by their start line in the current file; names without a recorded line go last.

        Ties (including all names without a recorded line) are broken alphabetically, so the result does not
        depend on the iteration order of the set the names come from.
        """
        return sorted(names, key=lambda name: (self.start_lines.get(name, float("inf")), name))

    def _sort_by_modular_position(self, names) -> list[str]:
        """Sort `names` by their start line in the modular file, breaking ties alphabetically."""
        return sorted(names, key=lambda name: (self.modular_file_start_lines[name], name))

    def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, int]:
        """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the
        final file that will be created based on the modular.

        Classes are processed in the order they appear in the current file. Each class is preceded by those of its
        dependencies that were not placed yet (nodes originally present in this file first, then nodes merged from
        the modular). Whatever remains afterwards is appended at the end following the same rule.

        Args:
            missing_dependencies: names of the nodes to order. The set itself is not modified.

        Returns:
            A mapping from each name in `missing_dependencies` to its 0-based rank.

        Raises:
            ValueError: if some names are classes but the class dependencies were not computed yet.
        """
        relative_order = {}
        idx = 0

        classes = sorted(
            [dep for dep in missing_dependencies if dep in self.classes], key=lambda x: self.start_lines[x]
        )
        # This is because for merged dependencies, we only have relative order in the other visited file, so we need
        # to track dependency order relative to a given class
        if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"):
            raise ValueError("Cannot correctly find the relative order of the dependencies.")

        remaining_dependencies = missing_dependencies.copy()

        # Start by tracking relative order class by class
        for class_name in classes:
            class_dependencies = self.class_dependency_mapping[class_name] & remaining_dependencies
            # Nodes already present in this file can be ordered globally, while merged nodes can only be ordered
            # relative to each other, using their position in the modular file
            original_dependencies = [dep for dep in class_dependencies if dep in self.start_lines]
            merged_dependencies = [dep for dep in class_dependencies if dep not in self.start_lines]
            ordered_dependencies = self._sort_by_original_position(
                original_dependencies
            ) + self._sort_by_modular_position(merged_dependencies)
            # Add all original nodes first, then merged ones
            for dep in ordered_dependencies:
                remaining_dependencies.remove(dep)
                relative_order[dep] = idx
                idx += 1
            # Add the class itself. It may already have been placed, as a dependency of a class processed earlier
            # (i.e. a class used before its definition in the source file): removing it again would raise a KeyError
            if class_name in remaining_dependencies:
                remaining_dependencies.remove(class_name)
                relative_order[class_name] = idx
                idx += 1

        # Now add what still remains: all original nodes first, then merged ones
        merged_dependencies = [dep for dep in remaining_dependencies if dep in self.modular_file_start_lines]
        original_dependencies = [dep for dep in remaining_dependencies if dep not in self.modular_file_start_lines]
        for dep in self._sort_by_original_position(original_dependencies) + self._sort_by_modular_position(
            merged_dependencies
        ):
            relative_order[dep] = idx
            idx += 1

        return relative_order

    def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
        """Update `self.functions` and the object dependency mapping with the functions from the modular file.
        Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies
        instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one).
        """
        # Add/overwrite all needed function nodes and dependencies
        self.functions.update(functions)
        self.object_dependency_mapping.update(
            {obj: dep for obj, dep in object_mapping.items() if obj in functions}
        )

    def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
        """Update `self.assignments` (and their dependencies) with the assignments from the modular file.
        Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it is
        in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the
        big docstrings. Assignments that do not exist in the current file are always added.
        """
        for assignment, node in assignments.items():
            if assignment in ASSIGNMENTS_TO_KEEP or assignment not in self.assignments:
                self.assignments[assignment] = node
                if assignment in object_mapping:
                    self.object_dependency_mapping[assignment] = object_mapping[assignment]

    def _merge_classes(self, classes: dict[str, cst.CSTNode]):
        """Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and
        are not imported). We do NOT update any dependency mapping here. This is because we only need the names of newly defined
        classes in the modular to be discoverable when computing dependencies for new nodes later on. For this reason, we
        do not add the new classes to `self.classes`, but only to `global_nodes`.
        """
        # Only add classes that are neither defined in this file nor imported from another modeling file
        self.global_nodes.update(
            {
                name: node
                for name, node in classes.items()
                if name not in self.classes and name not in self.objects_imported_from_modeling
            }
        )

    def merge_modular_dependencies(self, classes, functions, assignments, object_mapping, start_lines):
        """Merge classes, functions and assignments from the modular definitions into the current module file,
        then store the modular start lines (`start_lines`), which are later used to order merged nodes.
        Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the
        merge with other files dependencies.
        """
        self._merge_functions(functions, object_mapping)
        self._merge_assignments(assignments, object_mapping)
        self._merge_classes(classes)
        self.modular_file_start_lines = start_lines

        # Correctly re-set the global nodes at this point
        self.global_nodes.update(self.functions)
        self.global_nodes.update(self.assignments)
        # Create the global mapping of recursive dependencies for functions and assignments
        self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()

    @classmethod
    def visit_and_merge_dependencies(
        cls, module: cst.Module, classes, functions, assignments, object_mapping, start_lines
    ) -> "ModelFileMapper":
        """Visit `module`, merge the dependencies coming from the modular file, then compute the class dependencies.

        This is the intended way to build a `ModelFileMapper`, as it leaves the mapper with all its dependency graphs
        computed.
        """
        wrapper = MetadataWrapper(module)
        mapper = cls(module)
        wrapper.visit(mapper)
        # Merge dependencies
        mapper.merge_modular_dependencies(classes, functions, assignments, object_mapping, start_lines)
        # Create the class dependencies graph
        mapper.compute_class_dependencies()
        return mapper
def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str):
    """
    Replace a class node which inherits from another modeling class. This function works in the following way:
    - start from the base class node of the inherited class (a cst.Node)
    - replace all methods of the base node with the methods defined in the child class
    - append all new methods defined in the child class
    - replace all calls to super() with the unravelled code

    Going from:

        class GemmaModel(LlamaModel):
            def __init__(self):
                super().__init__()
                self.dropout = 0.2

    to:

        class GemmaModel(nn.Module):
            def __init__(self):
                super().__init__(config)
                self.dropout = 0.2
                self.padding_idx = config.pad_token_id
                self.vocab_size = config.vocab_size
                self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
                self.layers = nn.ModuleList(
                    [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
                )
                self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
                self.gradient_checkpointing = False
                # Initialize weights and apply final processing
                self.post_init()

    Special cases handled along the way:
    - a method redefined in the modular whose code matches `def ...(...):` directly followed by `raise ...Error(`
      is dropped from the resulting class;
    - if a redefined method declares `**super_kwargs`, its parameters (except the first one) are merged into the
      parent's parameters, and the parent's `**kwargs`-style parameter is kept;
    - decorators defined in the modular (on the class or on a method) take precedence over the original ones;
    - class docstrings are merged with `merge_docstrings`; if the parent class has no docstring, the modular one is
      used as-is.

    Args:
        mapper: mapper of the parent modeling file, whose `classes` contain `renamed_super_class`.
        class_node: the class as defined in the modular file.
        renamed_super_class: name of the parent class node in `mapper.classes`.

    Returns:
        The merged class node, always named like `class_node`.
    """

    def _body_key(node):
        # Named statements (methods, nested classes) are keyed by name, anything else by its source code
        return node.name.value if hasattr(node, "name") else mapper.python_module.code_for_node(node)

    all_bases = [k.value.value for k in class_node.bases]
    original_node = mapper.classes[renamed_super_class]
    original_methods = {_body_key(f): f for f in original_node.body.body}
    updated_methods = {_body_key(f): f for f in class_node.body.body}
    end_meth = []
    assign_targets = {}
    docstring_node = []

    # Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict
    for func in original_node.body.body:
        name = _body_key(func)
        if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None:
            updated_method = updated_methods[name]
            new_params = updated_method.params
            # Replace the method in the replacement class, preserving decorators
            kwarg_name = getattr(updated_method.params, "star_kwarg", None)
            if kwarg_name and kwarg_name.name.value == "super_kwargs":
                # Start from the parent signature, override/append the modular params (skipping the first one),
                # and keep the parent's `**kwargs`-style parameter
                parent_params = {k.name.value: k for k in func.params.params}
                parent_params.update({k.name.value: k for k in new_params.params[1:]})
                new_params = new_params.with_changes(
                    params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
                )
            # Keep decorators in `modular_xxx.py` if any, else original decorators
            new_decorators = updated_method.decorators if len(updated_method.decorators) > 0 else func.decorators
            # A modular method directly raising an error means the method must be removed from the class
            if not re.match(
                r"\ndef .*\(.*\):\n raise.*Error\(.*",
                mapper.python_module.code_for_node(updated_method),
            ):
                func = func.with_changes(body=updated_method.body, params=new_params, decorators=new_decorators)
            else:
                continue

        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
            target = mapper.python_module.code_for_node(func.body[0].targets[0])
            assign_targets[target] = func
        elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
            target = mapper.python_module.code_for_node(func.body[0].target)
            assign_targets[target] = func
        elif m.matches(func, DOCSTRING_NODE):
            docstring_node = [func]
        else:
            end_meth.append(func)

    # Port new methods that are defined only in modular-file and append at the end
    for func in class_node.body.body:
        name = _body_key(func)
        if m.matches(func, DOCSTRING_NODE):  # This processes the docstring of the class!
            if len(docstring_node) == 0:
                # The parent class has no docstring: simply use the one from the modular
                docstring_node = [func]
            else:
                updated_docstring = func.body[0].value.value
                original_docstring = docstring_node[0].body[0].value.value
                merged_doc = merge_docstrings(original_docstring, updated_docstring)
                # Update the docstring in the original function
                docstring_node = [
                    docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
                ]
        if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
            end_meth.append(func)
        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
            # TODO we only use single assign might cause issues
            target = mapper.python_module.code_for_node(func.body[0].targets[0])
            assign_targets[target] = func
        if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
            target = mapper.python_module.code_for_node(func.body[0].target)
            assign_targets[target] = func

    end_meth = docstring_node + list(assign_targets.values()) + end_meth

    # Replace the calls to `super()` with the unrolled code
    result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
    temp_module = cst.Module(body=[result_node])
    new_module = MetadataWrapper(temp_module)
    new_replacement_class = new_module.visit(
        SuperTransformer(temp_module, original_methods, updated_methods, all_bases)
    )
    new_replacement_body = new_replacement_class.body[0].body  # get the indented block

    # Use decorators redefined in `modular_xxx.py` if any
    new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators

    # Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
    name = class_node.name
    return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name)
# Class-name suffix -> file type used by `find_file_type` (e.g. a class ending in `Config` maps to `configuration`).
# Any name ending with none of these suffixes is treated as `modeling`.
TYPE_TO_FILE_TYPE = dict(
    Config="configuration",
    Tokenizer="tokenization",
    Processor="processing",
    ImageProcessor="image_processing",
    FeatureExtractor="feature_extractor",
    ProcessorKwargs="processing",
    ImagesKwargs="processing",
    TextKwargs="processing",
)
def find_file_type(class_name: str) -> str:
    """Return the file type a class belongs to, derived from the suffix of its name.

    The recognised suffixes are the keys of `TYPE_TO_FILE_TYPE`; e.g. `LlamaConfig` gives `configuration`.
    A name that ends with none of them falls back to `modeling`.
    """
    suffixes = "|".join(TYPE_TO_FILE_TYPE)
    suffix_match = re.search(rf"({suffixes})$", class_name)
    if suffix_match is None:
        return "modeling"
    return TYPE_TO_FILE_TYPE[suffix_match.group(1)]
# Top-level variables that are always emitted at the very start of the generated file, in the order listed here.
# This keeps them from ending up at odd places in the file, even when nothing uses them early on.
VARIABLES_AT_THE_BEGINNING = ("logger", "_CHECKPOINT_FOR_DOC", "_CONFIG_FOR_DOC")

# Modeling imports that must not be visited the way other modeling files are
IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",)
def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_to_keep: list[cst.CSTNode]):
    """Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports`.

    Each imported name is checked through the name it binds in the module: the alias for `import x as y` /
    `from m import x as y`, otherwise the imported name itself. Unused names are removed from the statement; if
    none remain, nothing is appended. Star imports (`from m import *`) cannot be pruned and are kept as-is.

    Args:
        node: a statement line whose first element is an `Import`/`ImportFrom` node.
        unused_imports: bound names that should be dropped.
        imports_to_keep: list the (possibly pruned) statement is appended to.
    """
    import_node = node.body[0]
    # `from m import *` exposes an `ImportStar` instead of a list of aliases: keep it untouched
    if isinstance(import_node.names, cst.ImportStar):
        imports_to_keep.append(node)
        return
    names_to_keep = []
    for name in import_node.names:
        # Compare on the bound name, so that aliased imports (e.g. `import torch.nn.functional as F`) can be pruned too
        name_value = name.evaluated_alias or name.evaluated_name
        if name_value not in unused_imports:
            # Reset the comma so that separators are regenerated once some names were dropped
            names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
    if len(names_to_keep) > 0:
        new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)])
        imports_to_keep.append(new_node)
def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]: