| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
import argparse
import glob
import importlib
import importlib.util
import os
import re
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, deque
from typing import Dict, Optional, Set, Union

import libcst as cst
from libcst import ClassDef, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider

from check_copies import run_ruff
from create_dependency_mapping import find_priority_list
from transformers import logging
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
| |
|
| |
|
logger = logging.get_logger(__name__)


# Warning banner prepended to every generated modeling file, formatted with the path of the
# modular file it was generated from.
# NOTE(review): the "π¨" runs look like a UTF-8 🚨-emoji banner that was mis-decoded at some
# point — confirm the file's encoding; the literal is kept verbatim here.
AUTO_GENERATED_MESSAGE = """# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
# This file was automatically generated from {relative_path}.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# {short_name} file directly. One of our CI enforces this.
# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
"""
| |
|
| |
|
def get_module_source_from_name(module_name: str) -> str:
    """Return the source code of the module `module_name` as a string.

    Args:
        module_name (`str`): A dotted, importable module name (e.g. `"transformers.models.llama.modeling_llama"`).

    Raises:
        ValueError: If the module cannot be located, or has no source file on disk.
    """
    # Note: `find_spec` lives in the `importlib.util` submodule, which must be imported
    # explicitly (`import importlib` alone does not bind `importlib.util`).
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        raise ValueError(f"Cannot open file associated with {module_name} module.")

    with open(spec.origin, "r", encoding="utf-8") as file:
        source_code = file.read()
    return source_code
| |
|
| |
|
def preserve_case_replace(text, patterns: dict, default_name: str):
    """Replace every occurrence of a key of `patterns` in `text` by its value, preserving casing.
    A case-insensitive hit whose exact casing is not a key of `patterns` is replaced by `default_name`.
    """
    # Keys are matched case-insensitively, but never right after a lowercase letter or digit
    # (so e.g. `myllama` is left untouched). DOTALL lets the trailing `(.|$)` capture a newline.
    alternatives = "|".join(re.escape(name) for name in patterns)
    finder = re.compile(rf"(?<![a-z0-9])({alternatives})(.|$)", re.IGNORECASE | re.DOTALL)

    def _substitute(match):
        found, trailing = match.group(1), match.group(2)
        replacement = patterns.get(found, default_name)
        # When only the lowercase/UPPERCASE pair was registered, an all-caps hit followed by a
        # non-letter (e.g. `LLAMA_`) keeps its SCREAMING_CASE form.
        if len(patterns) == 2 and found.isupper() and not trailing.isalpha():
            replacement = patterns[found.lower()].upper()
        return replacement + trailing

    return finder.sub(_substitute, text)
| |
|
| |
|
def get_cased_name(lowercase_name: str) -> str:
    """From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`."""
    # The auto-mapping sometimes registers models with dashes instead of underscores: try both.
    dashed_name = lowercase_name.replace("_", "-")
    for candidate in (lowercase_name, dashed_name):
        if candidate in CONFIG_MAPPING_NAMES:
            return CONFIG_MAPPING_NAMES[candidate].replace("Config", "")
    # Unknown model: fall back to naively title-casing each underscore-separated part.
    return "".join(part.title() for part in lowercase_name.split("_"))
| |
|
| |
|
def get_lowercase_name(cased_name: str) -> str:
    """From a model name in Camelcase in the format `MyModel`, return the lowercase name in the format `my_model`."""
    # Invert the config mapping (config class name -> lowercase model name).
    config_to_model = {config: model for model, config in CONFIG_MAPPING_NAMES.items()}
    config_name = cased_name + "Config"
    if config_name in config_to_model:
        return config_to_model[config_name]
    # Unknown model: split on capital letters and join the lowercased chunks with underscores.
    chunks = re.findall(r"[A-Z][^A-Z]*", cased_name)
    return "_".join(chunk.lower() for chunk in chunks)
| |
|
| |
|
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
    """A transformer that replaces `old_name` with `new_name` in comments, string and any references.
    It should take into account name like `MyNewModel`, or `my_new_model`. Without using the AUTO_MAPPING.
    Supported renaming patterns:
        - llama -> my_new_model     and my_new_model    -> llama
        - Llama -> MyNewModel       and MyNewModel      -> Llama
        - LLAMA -> MY_NEW_MODEL     and MY_NEW_MODEL    -> LLAMA
        - LLaMa -> MyNewModel       and MyNewModel      -> Llama
    """

    def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False):
        super().__init__()
        # Normalize dashed spellings (e.g. `my-model`) to underscores before building the patterns.
        old_name = old_name.replace("-", "_")
        new_name = new_name.replace("-", "_")
        self.old_name = old_name
        self.new_name = new_name
        self.cased_new_name = get_cased_name(self.new_name)
        self.cased_old_name = get_cased_name(self.old_name)
        # Mapping of each supported old-name casing to its replacement; anything matched
        # case-insensitively but absent from this dict falls back to `self.cased_new_name`.
        self.patterns = {
            old_name: new_name,
            old_name.upper(): new_name.upper(),
            # Use the cased names directly (they may differ from naive `.title()`, e.g. `LLaMa`).
            self.cased_old_name: self.cased_new_name,
        }
        # Name of the model as written by the user in the modular (may differ from `new_name`).
        self.original_new_model_name = original_new_model_name
        # When True, only strings/comments are renamed, not code references.
        self.only_doc = only_doc

    def _replace_name(self, original_node, updated_node):
        # Drop stale `# Copied from` comments entirely: they no longer hold after renaming.
        if re.findall(r"# Copied from", updated_node.value):
            return cst.RemoveFromParent()
        update = preserve_case_replace(updated_node.value, self.patterns, self.cased_new_name)
        return updated_node.with_changes(value=update)

    # Rename inside string literals and comments (documentation text).
    @m.leave(m.SimpleString() | m.Comment())
    def replace_name(self, original_node, updated_node):
        return self._replace_name(original_node, updated_node)

    def leave_Name(self, original_node, updated_node):
        # Code identifiers are only renamed when not restricted to documentation.
        if not self.only_doc:
            return self._replace_name(original_node, updated_node)
        return updated_node

    def leave_ImportFrom(self, original_node, updated_node):
        """The imports from other file types (configuration, processing etc) should use original model name."""
        if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()):
            # Rewrite e.g. `configuration_<new_name>` back to `configuration_<original_new_model_name>`.
            patterns = "|".join(ALL_FILE_TYPES)
            regex = rf"({patterns})_{self.new_name}"
            new_source = re.sub(
                regex, lambda m: f"{m.group(1)}_{self.original_new_model_name}", updated_node.module.value
            )
            updated_node = updated_node.with_changes(module=updated_node.module.with_changes(value=new_source))
        return updated_node
| |
|
| |
|
# Matcher for a statement line consisting of a single triple-quoted string expression, i.e. a docstring.
DOCSTRING_NODE = m.SimpleStatementLine(
    body=[
        m.Expr(
            value=m.SimpleString(
                # Require an actual `"""..."""` literal (possibly multi-line), not just any string.
                value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
            )
        )
    ]
)
| |
|
| |
|
def SUPER_CALL_NODE(func_name):
    """Build a matcher for a call to `super().<func_name>(...)`."""
    return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
| |
|
| |
|
def is_call_to_super(node, func_name):
    """Check whether `node` is a statement line consisting solely of `super().<func_name>(...)`,
    either returned or evaluated for its side effects."""
    return m.matches(
        node, m.SimpleStatementLine(body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))])
    )
| |
|
| |
|
def get_full_attribute_name(node: Union[cst.Attribute, cst.Name]) -> Optional[str]:
    """Get the full name of an Attribute or Name node (e.g. `"nn.Module"` for an Attribute representing it). If the
    successive value of an Attribute are not Name nodes, return `None`."""
    if m.matches(node, m.Name()):
        return node.value
    if m.matches(node, m.Attribute()):
        if not m.matches(node.attr, m.Name()):
            return None
        # Resolve the left-hand part recursively, then append this attribute's own name.
        prefix = get_full_attribute_name(node.value)
        return None if prefix is None else f"{prefix}.{node.attr.value}"
    # Anything else (e.g. a Call or Subscript in the chain) has no simple dotted name.
    return None
| |
|
| |
|
| | |
class ReplaceMethodCallTransformer(cst.CSTTransformer):
    """Replace explicit references to any of the parent classes in `all_bases` by `super()`,
    e.g. turning `Base.forward(self, x)` into `super().forward(x)`."""

    def __init__(self, all_bases: Set[str]):
        # Dotted names of all the parent classes whose explicit references must be rewritten.
        self.all_bases = all_bases

    def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
        # Handle `Base.method` (attribute access on the class itself).
        if (
            m.matches(original_node.value, m.Name() | m.Attribute())
            and get_full_attribute_name(original_node.value) in self.all_bases
            and m.matches(original_node.attr, m.Name())
        ):
            # Replace the class reference by a `super()` call: `Base.method` -> `super().method`.
            return updated_node.with_changes(
                value=cst.Call(cst.Name("super")),
            )
        # Handle `Base().method` (attribute access on a fresh instance of the class).
        elif (
            m.matches(original_node.value, m.Call())
            and m.matches(original_node.value.func, m.Name() | m.Attribute())
            and get_full_attribute_name(original_node.value.func) in self.all_bases
            and m.matches(original_node.attr, m.Name())
        ):
            # `Base().method` -> `super().method`.
            return updated_node.with_changes(value=cst.Call(cst.Name("super")))
        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
        # For calls whose callee is one of the rewritten parent-class attributes, drop a leading
        # explicit `self` argument (it becomes implicit once the callee is `super().method`).
        if m.matches(original_node.func, m.Attribute()) and (
            # `Base().method(...)` form
            (
                m.matches(original_node.func.value, m.Call())
                and m.matches(original_node.func.value.func, m.Name() | m.Attribute())
                and get_full_attribute_name(original_node.func.value.func) in self.all_bases
            )
            or
            # `Base.method(...)` form
            (
                m.matches(original_node.func.value, m.Name() | m.Attribute())
                and get_full_attribute_name(original_node.func.value) in self.all_bases
            )
        ):
            # Only strip the first argument when it is literally `self`.
            if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
                # Remove "self" from the arguments of the call.
                new_args = updated_node.args[1:]
            else:
                new_args = updated_node.args

            return updated_node.with_changes(args=new_args)
        return updated_node
| |
|
| |
|
def get_docstring_indent(docstring):
    """Return the indentation level (number of whitespace characters) of the body of `docstring`,
    inferred from the whitespace following the opening triple-quote/backtick marker; 0 if the
    docstring has no indented body."""
    opening = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
    return len(opening.group(1)) if opening else 0
| |
|
| |
|
def is_full_docstring(new_docstring: str) -> bool:
    """Check if `new_docstring` is a full docstring, or if it is only part of a docstring that should then
    be merged with the existing old one.
    """
    # Only look at what follows the opening `"""` marker.
    body = new_docstring.split('"""', 1)[1]
    # An `Args:` section is only ever written in a complete docstring.
    if re.search(r"\n\s*Args:\n", body):
        return True
    # Otherwise, a docstring holding a `Returns:` section is complete when free text precedes it,
    # i.e. when the first body line is indented one level deeper than the `Returns:` header.
    returns_header = re.search(r"\n([^\S\n]*)Returns:\n", body)
    if returns_header is None:
        return False
    header_indent = returns_header.group(1)
    trimmed = body.strip("\n")
    return trimmed.startswith(header_indent + " " * 4) or trimmed.startswith(header_indent + "\t")
| |
|
| |
|
def merge_docstrings(original_docstring, updated_docstring):
    """Merge a partial `updated_docstring` into `original_docstring`: a partial update either
    replaces the code example (the part between triple backticks) of the original docstring, or is
    appended to it. A full `updated_docstring` (see `is_full_docstring`) is returned unchanged."""
    # Indentation of the original body, needed to rebuild the ``` fences at the right level.
    original_level = get_docstring_indent(original_docstring)
    if not is_full_docstring(updated_docstring):
        # Split the original docstring around its code example (if any).
        parts = original_docstring.split("```")
        if "```" in updated_docstring and len(parts) > 1:
            # The update brings a new example: graft it in place of the original one.
            updated_docstring = updated_docstring.lstrip('r"')
            new_parts = updated_docstring.split("```")
            if len(new_parts) != 3:
                raise ValueError("There should only be one example, and it should have opening and closing '```'")
            parts[1] = new_parts[1]
            updated_docstring = "".join(
                [
                    parts[0].rstrip(" \n") + new_parts[0],
                    f"\n{original_level * ' '}```",
                    parts[1],
                    "```",
                    parts[2],
                ]
            )
        elif updated_docstring not in original_docstring:
            # The update brings argument descriptions (lines like `name (type):` followed by text):
            # push them one indentation level deeper before appending.
            # (The previous `replace("\n ", "\n ")` was an identity no-op.)
            if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
                updated_docstring = updated_docstring.replace("\n    ", "\n        ")
            updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')
    return updated_docstring
| |
|
| |
|
class SuperTransformer(cst.CSTTransformer):
    """Transformer in charge of "unravelling" methods: calls to `super().func_name(...)` inside the
    methods listed in `updated_methods` are replaced by the (deduplicated) body of the parent class'
    `func_name` taken from `original_methods`, merging docstrings along the way."""

    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None):
        # Module used only to render nodes back to source code (`code_for_node`).
        self.python_module = python_module
        # Mapping of method names to their node in the parent (original) class.
        self.original_methods = original_methods
        # Mapping of method names to their node in the overriding (modular) class.
        self.updated_methods = updated_methods
        # Assignment targets seen in the new statements (they override the parent's assignments).
        self.all_assign_target = {}
        # Targets removed with `del` in the new statements (they drop the parent's assignments).
        self.deleted_targets = {}
        self.all_bases = all_bases or []
        self.transformer = ReplaceMethodCallTransformer(set(self.all_bases))

    def update_body(self, existing_body, new_statements):
        """
        Helper method to update the body by removing duplicates before adding new statements.
        `existing_body` is the body of the original method, the parent class
        `new_statements` are the additional statements
        """
        deduplicated_new_body = []
        existing_nodes = set()
        # First pass: record which targets the new statements assign or delete.
        for node in new_statements:
            if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(node.body[0].targets[0].target)
                self.all_assign_target[target] = node
            if m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                target = self.python_module.code_for_node(node.body[0].target)
                self.deleted_targets[target] = node

        # Keep the parent's statements, applying the child's assignment overrides/deletions.
        for stmt in existing_body:
            if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
                if target in self.deleted_targets:
                    continue
                if target in self.all_assign_target:
                    stmt = self.all_assign_target[target]
            # Skip the docstring here: it is merged separately in `replace_super_calls`.
            elif m.matches(stmt, DOCSTRING_NODE):
                continue
            # Deduplicate on comment-free, whitespace-normalized source text.
            comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            deduplicated_new_body.append(stmt)
            existing_nodes.add(comment_less_code)

        # Then append the new statements that are not already present in the parent's body.
        for node in new_statements:
            code = self.python_module.code_for_node(node)
            comment_less_code = re.sub(r"#.*", "", code).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            if node not in deduplicated_new_body and comment_less_code not in existing_nodes:
                if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                    deduplicated_new_body.append(node)
                    existing_nodes.add(comment_less_code)

        deduplicated_new_body = self._fix_post_init_location(deduplicated_new_body)

        return deduplicated_new_body

    def _fix_post_init_location(self, new_body: list[cst.CSTNode]):
        """Fix the location of the `post_init()` in the new body, if we added statements after the call to
        `super()` (it needs to be the very last statement called)"""
        # Find the statement calling `self.post_init(...)` and move it to the end if needed.
        for i, node in enumerate(new_body):
            code = self.python_module.code_for_node(node)
            comment_less_code = re.sub(r"#.*", "", code).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            if "self.post_init(" in comment_less_code and i < len(new_body) - 1:
                # Move the node to the end of the body.
                new_body.pop(i)
                new_body.append(node)
                break
        return new_body

    def _fix_init_location(self, new_body):
        """Fix the location of the `super().__init__()` in the new body, if we had new statements before it."""
        start_index = 0
        for i, node in enumerate(new_body):
            # A leading docstring stays in place; start looking right after it.
            if m.matches(node, DOCSTRING_NODE) and i == start_index:
                start_index += 1
                continue
            code = self.python_module.code_for_node(node)
            comment_less_code = re.sub(r"#.*", "", code).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            if "super().__init__" in comment_less_code and i > start_index:
                # Move the node to the start of the body (after the docstring, if any).
                node = new_body.pop(i)
                new_body = new_body[:start_index] + [node] + new_body[start_index:]
                break
        return new_body

    def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
        """Updates the body of the input `node`'s `func_name` function by replacing calls
        to super().func_name() with the source code of the parent class' `func_name`.
        It keeps everything that is defined before `super().func_name()`.
        """
        self.has_docstring = False
        parent_has_docstring = False
        if func_name in self.original_methods:
            parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
        new_body = []
        has_super_call = False

        for i, expr in enumerate(node.body):
            if is_call_to_super(expr, func_name):
                # Unroll the parent's body here, merged with everything following the super call.
                has_super_call = True
                new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :]))
                new_body = self._fix_init_location(new_body)
            else:
                expr = expr.visit(self.transformer)
                if m.matches(expr, DOCSTRING_NODE):
                    self.has_docstring = True
                    if parent_has_docstring:
                        # Merge the child's (possibly partial) docstring with the parent's one.
                        original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
                        updated_docstring = expr.body[0].value.value
                        merged_doc = merge_docstrings(original_docstring, updated_docstring)
                        new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
                    else:
                        new_node = [expr]
                    new_body.extend(new_node)
                elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
                    new_body.append(expr)
        if not self.has_docstring and parent_has_docstring:
            new_body = [self.original_methods[func_name].body.body[0]] + new_body
        return node.with_changes(body=new_body)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        if updated_node.name.value in self.updated_methods:
            name = updated_node.name.value
            new_body = self.replace_super_calls(updated_node.body, name)
            return updated_node.with_changes(body=new_body, params=updated_node.params)
        return updated_node

    def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
        """When a return statement is reached, it is replaced with the unrolled super code"""
        if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
            func_def = self.get_metadata(ParentNodeProvider, original_node)
            # Fix: the libcst matchers API is `m.matches` (`m.matched` does not exist and would
            # raise AttributeError at runtime).
            if m.matches(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
                updated_return_value = updated_node.value.with_changes(
                    args=[
                        cst.Arg(
                            value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))])
                        )
                    ]
                )
                return updated_node.with_changes(value=updated_return_value)
        return updated_node
| |
|
| |
|
def find_all_dependencies(
    dependency_mapping: Dict[str, set],
    start_entity: Optional[str] = None,
    initial_dependencies: Optional[set] = None,
    initial_checked_dependencies: Optional[set] = None,
    return_parent: bool = False,
) -> Union[list, set]:
    """Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of
    BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`.

    Args:
        dependency_mapping (`Dict[str, set]`):
            A mapping from entities (usually function/assignment names), to immediate dependencies. That is, for function names,
            a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called
            in `foo`'s definition.
        start_entity (str | None, *optional*):
            A key of `dependency_mapping`, indicating from which entity to start the search.
        initial_dependencies (set | None, *optional*):
            If `start_entity` is not provided, this can be used as an alternative. In this case, the search will continue
            from all the entities in `initial_dependencies`, if they are in `dependency_mapping`.
        initial_checked_dependencies (set | None, *optional*):
            If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies.
        return_parent (bool, *optional*):
            If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note
            that the order of the items in the list reflects the traversal order. Thus, no parent can ever appear before children.
    Returns:
        A set of all the dependencies, or a list of tuples `(dependency, parent)` if `return_parent=True`.

    Example:
        Given the following structure in the `modular_xxx.py` file:
        ```
        def foo1():
            pass

        def foo2():
            pass

        def bar():
            foo1()

        def foobar():
            bar()
            foo2()

        class MyLayer(SomeOtherModelLayer):
            def forward(...):
                foobar()
        ```
        and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get:
        ```
        dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}}
        find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True)
        >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')]
        ```
        That is, all the functions needed (and potentially their immediate parent) so that the function to be added
        in MyLayer (`foobar`) can work correctly.
    """
    if initial_dependencies is None and start_entity is not None:
        initial_dependencies = dependency_mapping[start_entity]
    if initial_checked_dependencies is None:
        initial_checked_dependencies = set()

    # Plain BFS over the dependency graph, remembering through which parent each entity was reached.
    to_visit = deque(initial_dependencies)
    parent_of = dict.fromkeys(initial_dependencies, start_entity)
    seen = set(initial_checked_dependencies)
    found = set()
    found_with_parent = []
    while to_visit:
        entity = to_visit.popleft()
        if entity in seen:
            continue
        found.add(entity)
        found_with_parent.append((entity, parent_of[entity]))
        # Enqueue this entity's own dependencies (if any), recording it as their parent.
        children = dependency_mapping.get(entity)
        if children is not None:
            to_visit.extend(children)
            parent_of.update(dict.fromkeys(children, entity))
        seen.add(entity)

    return found_with_parent if return_parent else found
| |
|
| |
|
| | |
# Top-level assignment names matching one of these patterns — presumably always kept when
# generating the modeling file (checkpoint/doc-related constants).
# NOTE(review): usage is further down the file, outside this view — confirm.
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC", r"_HIDDEN_STATES_START_POSITION"]

# Same idea, but — per the name — only kept when the assigned value is not `None`.
ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE = [r"_DOCSTRING"]
| |
|
| |
|
class ClassDependencyMapper(CSTVisitor):
    """A visitor which is designed to analyze a single class node to get all its dependencies that are shared with the set of
    `global_names`.
    """

    def __init__(
        self, class_name: str, global_names: set[str], objects_imported_from_modeling: Optional[set[str]] = None
    ):
        super().__init__()
        self.class_name = class_name
        # Names collected while visiting (filled by `visit_Name`).
        self.dependencies = set()
        self.global_names = global_names
        # Keep the caller-provided set object when given (even if empty), otherwise a fresh one.
        self.objects_imported_from_modeling = (
            objects_imported_from_modeling if objects_imported_from_modeling is not None else set()
        )

    def visit_Name(self, node):
        # A name is a dependency when it is a known global that is neither the class itself nor
        # something already imported from a neighbouring modeling file.
        name = node.value
        if name == self.class_name or name in self.objects_imported_from_modeling:
            return
        if name in self.global_names:
            self.dependencies.add(name)
| |
|
| |
|
def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set:
    """Create immediate dependencies for a class node based on the `global_names`."""
    # Wrap the class in a throwaway module so it can be visited on its own.
    mapper = ClassDependencyMapper(node.name.value, global_names)
    cst.Module(body=[node]).visit(mapper)
    return mapper.dependencies
| |
|
| |
|
def augmented_dependencies_for_class_node(
    node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: Optional[set[str]] = None
) -> set:
    """Create augmented dependencies for a class node based on a `mapper`.
    Augmented dependencies means immediate dependencies + recursive function and assignments dependencies.
    """
    visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys()), objects_imported_from_modeling)
    # Visit the class in isolation, then expand the collected names through the mapper's graph.
    cst.Module(body=[node]).visit(visitor)
    return mapper.augment_dependencies(visitor.dependencies)
| |
|
| |
|
| | |
# All the file type prefixes used by model files (e.g. `modeling_llama.py`, `configuration_llama.py`).
ALL_FILE_TYPES = (
    "modeling",
    "configuration",
    "tokenization",
    "processing",
    "image_processing",
    "feature_extractor",
)
| |
|
| |
|
class ModuleMapper(CSTVisitor, ABC):
    """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions and assignments.
    Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies are stored in
    `self.object_recursive_dependency_mapping` (can be computed by `_compute_recursive_object_dependencies()`).
    It defines common visiting patterns (i.e. common visit_xxx/leave_xxx functions) between the modular file and the
    modeling files that will be visited.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider)

    def __init__(self, python_module: cst.Module):
        # The module being visited (used to render nodes back to source code).
        self.python_module: cst.Module = python_module
        # Mapping from class names to their nodes.
        self.classes: Dict[str, cst.ClassDef] = {}
        # All module-level import statement nodes, in order of appearance.
        self.imports = []
        # Mapping from module-scope function names to their nodes.
        self.functions: Dict[str, cst.FunctionDef] = {}
        # Immediate (non-recursive) dependency mapping for functions and assignments.
        self.object_dependency_mapping = defaultdict(set)
        # Mapping from module-scope assignment targets to their statement nodes.
        self.assignments: Dict[str, cst.SimpleStatementLine] = {}
        # Visiting state: name of the function/class/assignment currently being traversed (None outside).
        self.current_function = None
        self.current_class = None
        self.current_assignment = None
        # Names imported from neighbouring modeling-type files (not missing dependencies).
        self.objects_imported_from_modeling = set()
        # Regex alternation of all file-type prefixes, used to recognize such imports.
        self.match_patterns = "|".join(ALL_FILE_TYPES)

    def visit_ImportFrom(self, node):
        """This keeps track of objects imported from neighbor modeling files (e.g. in `modeling_xxx.py, we have
        `from .configuration_xxx import Config`, then `Config` should be recorded as it is not a dependency that needs
        to be added (because it will be part of the imports)"""
        import_module = self.python_module.code_for_node(node.module)
        import_statement = "." * len(node.relative) + import_module
        if re.search(rf"^\.({self.match_patterns})_.*", import_statement):
            for imported_object in node.names:
                # Record the name under which the object is visible in this module (alias if any).
                if imported_object.evaluated_alias is not None:
                    self.objects_imported_from_modeling.add(imported_object.evaluated_alias)
                else:
                    self.objects_imported_from_modeling.add(imported_object.evaluated_name)

    def visit_SimpleStatementLine(self, node):
        """
        Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements
        are extracted and saved in their corresponding dict. They are then used when updating dependency mappings.
        """
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        simple_top_level_assign_structure = m.SimpleStatementLine(
            body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
        )
        # Only module-scope statements are of interest here.
        if m.matches(parent_node, m.Module()):
            if m.matches(node, simple_top_level_assign_structure):
                left_hand_side = node.body[0].targets[0].target.value
                self.current_assignment = left_hand_side
                self.assignments[left_hand_side] = node
            elif m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
                self.imports.append(node)

    def leave_SimpleStatementLine(self, node):
        # Reset the assignment-visiting state (set in `visit_SimpleStatementLine` when the line
        # was a top-level assignment).
        self.current_assignment = None

    def visit_FunctionDef(self, node):
        # Only record module-scope functions (not methods or nested defs).
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.current_function = node.name.value
            self.functions[node.name.value] = node

    def leave_FunctionDef(self, node):
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.current_function = None

    def visit_If(self, node):
        # Record top-level conditional import blocks (e.g. availability-guarded imports).
        if self.current_function is None and self.current_class is None:
            for stmt in node.body.body:
                if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
                    self.imports.append(node)

    def visit_ClassDef(self, node: ClassDef) -> None:
        """Record class nodes to create their dependencies at the end."""
        self.classes[node.name.value] = node
        self.current_class = node.name.value

    def leave_ClassDef(self, node):
        self.current_class = None

    def visit_Name(self, node: cst.Name):
        """This is used to create a mapping from module-scope functions and assignments to objects used inside them."""
        if self.current_function is not None:
            self.object_dependency_mapping[self.current_function].add(node.value)
        if self.current_assignment is not None:
            self.object_dependency_mapping[self.current_assignment].add(node.value)

    def leave_Module(self, node):
        """When leaving the module, we store the position of each global scoped node to allow sorting the dependencies
        based on their position in the code later. We use the PositionProvider metadata wrapper for this.
        We also make sure to update `self.object_dependency_mapping` so that it contains only names recorded in
        `self.global_nodes`.
        """
        # All names visible at the module's global scope.
        self.global_nodes = {**self.assignments, **self.classes, **self.functions}
        # Record the start line of each global node for later ordering.
        # NOTE(review): the loop variables shadow the builtin `id` and the `node` parameter.
        self.start_lines = {}
        for id, node in self.global_nodes.items():
            self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line

    def _restrict_dependencies_to_known_entities(self):
        """Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that
        are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc).
        This should be called only after all merging operations have been finalized!!"""
        global_objects = set(self.global_nodes.keys())
        for object_name, dependencies in self.object_dependency_mapping.items():
            self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects}

    def _compute_recursive_object_dependencies(self) -> dict[str, set]:
        """Based on immediate dependency mapping, create the recursive dependency mapping. For example, given the
        following file:
        ```
        def foo():
            pass

        def bar():
            foo()

        def test():
            bar()
        ```
        this visitor can only record immediate dependencies, i.e. it will record the following
        `self.object_dependency_mapping = {"test": {"bar"}, "bar": {"foo}}`. This function is used to create
        the recursive mapping, i.e. `recursive_dependencies = {"test": {"bar", "foo"}, "bar": {"foo}}`.
        """
        recursive_dependencies = {}
        for object_name in self.object_dependency_mapping.keys():
            all_dependencies = find_all_dependencies(self.object_dependency_mapping, start_entity=object_name)
            recursive_dependencies[object_name] = all_dependencies
        return recursive_dependencies

    def augment_dependencies(self, dependencies: set[str]) -> set[str]:
        """For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** and
        **assignments** present in the `dependencies`.
        """
        new_dependencies = dependencies.copy()
        # Go through the set of dependencies, pulling in their own recursive dependencies.
        for dep in tuple(dependencies):
            if dep in self.object_recursive_dependency_mapping.keys():
                new_dependencies.update(self.object_recursive_dependency_mapping[dep])
        return new_dependencies

    def compute_class_dependencies(self):
        """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies."""
        self.class_dependency_mapping = {}
        for class_name, class_node in self.classes.items():
            dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys()))
            # Augment immediate class dependencies with the recursive function/assignment ones.
            self.class_dependency_mapping[class_name] = self.augment_dependencies(dependencies)

    @abstractmethod
    def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
        raise NotImplementedError
| |
|
| |
|
class ModelFileMapper(ModuleMapper):
    """A mapper designed to parse modeling files (like `modeling_llama.py`). When encountering such a file
    in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file.
    For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes
    care of correctly merging dependencies, then finalizes all dependency graph computations.
    Note that we only merge functions and assignments here, as classes will be treated later on as they may be modified.
    For example, if you redefine `apply_rotary_pos_emb()` in the modular, the new node should be used in the dependencies
    of the modeling files as well.
    """

    def __init__(self, python_module: cst.Module):
        super().__init__(python_module)

    def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, int]:
        """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
        will be created based on the modular.
        """
        relative_order = {}
        idx = 0
        # Classes present in the dependencies, ordered by their position in this (visited) file
        classes = sorted(
            [dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x]
        )
        # Merged dependencies only have positions relative to the modular file, so we need the per-class
        # dependency mapping to interleave them correctly with the original nodes
        if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"):
            raise ValueError("Cannot correctly find the relative order of the dependencies.")

        remaining_dependencies = missing_dependencies.copy()

        # First, order the dependencies class by class
        for class_name in classes:
            class_dependencies = tuple(self.class_dependency_mapping[class_name] & remaining_dependencies)
            original_dependencies = []
            merged_dependencies = []
            # Split between nodes of this file (position known via `start_lines`) and nodes merged
            # from the modular (position known via `modular_file_start_lines`)
            for class_dep in class_dependencies:
                if class_dep in self.start_lines:
                    original_dependencies.append(class_dep)
                else:
                    merged_dependencies.append(class_dep)

            # Pre-sort deterministically so names missing from `start_lines` (all sharing the 1e10
            # sentinel below) always come out in the same order, independently of set iteration order
            original_dependencies = sorted(original_dependencies, reverse=True)
            # Sort both lists according to the order in their respective file
            original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
            merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])

            # Emit all original nodes first, then the merged ones
            for dep in original_dependencies + merged_dependencies:
                remaining_dependencies.remove(dep)
                relative_order[dep] = idx
                idx += 1
            # Then the class itself (it may already have been emitted as a dependency of an earlier class)
            if class_name in remaining_dependencies:
                remaining_dependencies.remove(class_name)
                relative_order[class_name] = idx
                idx += 1

        # Finally, order whatever is not tied to any class
        remaining_dependencies = tuple(remaining_dependencies)
        original_dependencies = []
        merged_dependencies = []
        for dep in remaining_dependencies:
            if dep in self.modular_file_start_lines:
                merged_dependencies.append(dep)
            else:
                original_dependencies.append(dep)

        # Same deterministic pre-sort as above for names missing from `start_lines`
        original_dependencies = sorted(original_dependencies, reverse=True)
        # Sort both lists according to the order in their respective file
        original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
        merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])

        # Emit all original nodes first, then the merged ones
        for dep in original_dependencies + merged_dependencies:
            relative_order[dep] = idx
            idx += 1

        return relative_order

    def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
        """Update the global nodes and function dependency mapping with those from the modular file.

        Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies
        instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one).
        """
        # Add/overwrite the function nodes, and the dependencies of the functions being merged
        self.functions.update(functions)
        self.object_dependency_mapping.update(
            {obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()}
        )
        # Make the merged functions discoverable by name
        self.global_nodes.update(self.functions)

    def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
        """Update the global nodes with the assignment from the modular file.

        Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it matches
        a pattern in `ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE` and its value is not None, or if it matches a pattern in `ASSIGNMENTS_REGEX_TO_KEEP`.
        Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the big docstrings.
        """
        for assignment, node in assignments.items():
            should_keep = any(re.search(pattern, assignment) for pattern in ASSIGNMENTS_REGEX_TO_KEEP)
            # For the "keep if not None" patterns, check the assigned value is not the literal `None`
            # (compared against the cst node's textual value)
            should_keep_if_not_none = any(
                re.search(pattern, assignment) for pattern in ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE
            ) and not (hasattr(node.body[0].value, "value") and node.body[0].value.value == "None")

            # New assignments are always kept; existing ones are overridden only per the rule above
            if should_keep or should_keep_if_not_none or assignment not in self.assignments:
                self.assignments[assignment] = node
                if assignment in object_mapping:
                    self.object_dependency_mapping[assignment] = object_mapping[assignment]
        # Make the merged assignments discoverable by name
        self.global_nodes.update(self.assignments)

    def _merge_classes(self, classes: dict[str, cst.CSTNode]):
        """Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and
        are not imported). We do NOT update any dependency mapping here. This is because we only need the names of newly defined
        classes in the modular to be discoverable when computing dependencies for new nodes later on. For this reason, we
        do not add the new classes to `self.classes`, but only to `global_nodes`.
        """
        # Only add classes that are genuinely new to this file
        self.global_nodes.update(
            {
                name: node
                for name, node in classes.items()
                if name not in self.classes and name not in self.objects_imported_from_modeling
            }
        )

    def merge_modular_dependencies(self, classes, functions, assignments, object_mapping, start_lines):
        """Merge classes, functions and assignments from the modular definitions into the current module file,
        then record the relative order of all nodes.
        Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the
        merge with other files dependencies.
        """
        self._merge_functions(functions, object_mapping)
        self._merge_assignments(assignments, object_mapping)
        self._merge_classes(classes)
        # Record where each modular object starts, used to sort merged dependencies deterministically
        self.modular_file_start_lines = start_lines

        # Drop dependency edges pointing to unknown names (helper defined outside this chunk)
        self._restrict_dependencies_to_known_entities()
        # Recompute the transitive dependency closure now that everything is merged
        self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()

    @classmethod
    def visit_and_merge_dependencies(
        cls, module: cst.Module, classes, functions, assignments, object_mapping, start_lines
    ) -> "ModelFileMapper":
        """Build a fully-initialized mapper: visit `module`, merge the modular-level definitions into it,
        then finalize all dependency computations. This is the supported way to instantiate this class."""
        wrapper = MetadataWrapper(module)
        mapper = cls(module)
        wrapper.visit(mapper)
        # Merge the modular definitions (functions/assignments/new classes) into the visited file
        mapper.merge_modular_dependencies(classes, functions, assignments, object_mapping, start_lines)
        # With all nodes merged, compute per-class dependency sets
        mapper.compute_class_dependencies()
        return mapper
| |
|
| |
|
def common_partial_suffix(str1: str, str2: str) -> str:
    """Return the longest common suffix between the two strings. If that suffix is one of the strings
    itself (i.e. one string fully ends the other), it is not considered a *partial* suffix and `""`
    is returned instead."""
    # Walk both strings from the end, collecting matching characters
    matched = []
    for char1, char2 in zip(reversed(str1), reversed(str2)):
        if char1 != char2:
            break
        matched.append(char1)
    common_suffix = "".join(reversed(matched))
    # A full-string match is not "partial"
    return "" if common_suffix in (str1, str2) else common_suffix
| |
|
| |
|
def replace_class_node(
    mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
):
    """
    Replace a class node which inherits from another modeling class. This function works in the following way:
    - start from the base class node of the inherited class (a cst.Node)
    - replace all methods of the base node with the methods defined in the child class
    - append all new methods defined in the child class
    - replace all calls to super() with the unravelled code

    For example, going from `class GemmaModel(LlamaModel)` whose `__init__` only calls `super().__init__()`
    to a `class GemmaModel(nn.Module)` whose `__init__` contains the full unravelled body of
    `LlamaModel.__init__` followed by the extra statements of the child `__init__`.
    """
    # Textual name of every base of the child class (helper defined outside this chunk)
    all_bases = [get_full_attribute_name(k.value) for k in class_node.bases]
    if any(base is None for base in all_bases):
        raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}")

    # The (already renamed) node of the super class, as recorded when visiting the modeling file
    original_node = mapper.classes[renamed_super_class]
    # The name of the new class, e.g. `GemmaModel`
    new_name = class_node.name

    # If the child class name differs from the renamed super class name, rename occurrences in the
    # inherited docstrings accordingly (only in docs: `only_doc=True`)
    if new_name.value != renamed_super_class:
        common_suffix = common_partial_suffix(new_name.value, renamed_super_class)
        # Strip the shared suffix to isolate the two differing prefixes
        old, new = renamed_super_class.replace(common_suffix, ""), new_name.value.replace(common_suffix, "")
        temp_module = cst.Module(body=[original_node])
        original_node = temp_module.visit(
            ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True)
        ).body[0]

    # If the modular declared extra bases (besides the inherited modeling class), swap matching bases of
    # the original node for them, matched through a capitalized common suffix
    additional_bases = [base for base in all_bases if base != original_super_class]
    new_bases = []
    for original_base in original_node.bases:
        new_base = original_base
        # Only plain-name bases can be swapped (attribute bases like `nn.Module` are kept untouched)
        if m.matches(original_base.value, m.Name()):
            original_base_name = original_base.value.value
            for additional_base_name in additional_bases:
                suffix = common_partial_suffix(original_base_name, additional_base_name)
                if len(suffix) > 0 and suffix[0].isupper():
                    new_name_node = original_base.value.with_changes(value=additional_base_name)
                    new_base = original_base.with_changes(value=new_name_node)
                    break
        new_bases.append(new_base)

    # Index the body statements of parent and child by method name (or full source text for statements
    # without a `name`, e.g. assignments and docstrings)
    original_methods = {
        f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f
        for f in original_node.body.body
    }
    updated_methods = {
        f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in class_node.body.body
    }
    end_meth = []

    assign_targets = {}
    docstring_node = []
    # First pass: walk the PARENT body in order, overriding methods redefined by the child
    for func in original_node.body.body:
        name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
        if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None:
            new_params = updated_methods[name].params
            # A child signature ending in `**super_kwargs` means: reuse the parent's explicit parameters,
            # overridden/extended by the child's own (minus the first, i.e. `self`)
            kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
            if kwarg_name and kwarg_name.name.value == "super_kwargs":
                parent_params = {k.name.value: k for k in func.params.params}
                parent_params.update({k.name.value: k for k in new_params.params[1:]})
                new_params = new_params.with_changes(
                    params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
                )
            # Child decorators win when present, otherwise keep the parent's
            new_decorators = (
                updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
            )

            # Same rule for the return annotation
            new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns

            # A child method whose whole body is a bare `raise ...Error(...)` means "remove this method"
            if not re.match(
                r"\ndef .*\(.*\):\n    raise.*Error\(.*",
                mapper.python_module.code_for_node(updated_methods[name]),
            ):
                func = func.with_changes(
                    body=updated_methods[name].body,
                    params=new_params,
                    decorators=new_decorators,
                    returns=new_return_annotation,
                )
            else:
                continue

        # Dispatch each statement: assignments are deduplicated by target, the docstring is kept aside,
        # everything else goes to the tail of the body
        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
            target = mapper.python_module.code_for_node(func.body[0].targets[0])
            assign_targets[target] = func
        elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
            target = mapper.python_module.code_for_node(func.body[0].target)
            assign_targets[target] = func
        elif m.matches(func, DOCSTRING_NODE):
            docstring_node = [func]
        else:
            end_meth.append(func)

    # Second pass: append what exists only in the CHILD body
    for func in class_node.body.body:
        name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
        if m.matches(func, DOCSTRING_NODE):
            # Merge the child docstring into the parent's (or use it directly if the parent had none)
            updated_docstring = func.body[0].value.value
            if len(docstring_node) == 0:
                docstring_node = [
                    cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString(value=updated_docstring))])
                ]
            else:
                original_docstring = docstring_node[0].body[0].value.value
                merged_doc = merge_docstrings(original_docstring, updated_docstring)
                # Replace the docstring text in the kept node
                docstring_node = [
                    docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
                ]
        if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
            end_meth.append(func)
        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
            # Child assignments override the parent's for the same target
            target = mapper.python_module.code_for_node(func.body[0].targets[0])
            assign_targets[target] = func
        if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
            target = mapper.python_module.code_for_node(func.body[0].target)
            assign_targets[target] = func
    # Final body order: docstring first, then class attributes, then methods
    end_meth = docstring_node + list(assign_targets.values()) + end_meth

    # Unravel all `super()` calls inside the rebuilt body
    result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
    temp_module = cst.Module(body=[result_node])
    new_module = MetadataWrapper(temp_module)
    new_replacement_class = new_module.visit(
        SuperTransformer(temp_module, original_methods, updated_methods, all_bases)
    )
    new_replacement_body = new_replacement_class.body[0].body

    # Class decorators: the child's win when present, otherwise keep the parent's
    new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators

    return original_node.with_changes(
        body=new_replacement_body, decorators=new_decorators, bases=new_bases, name=new_name
    )
| |
|
| |
|
# Mapping from a class-name suffix to the corresponding file type (the prefix of the file where such
# a class is dispatched by `find_file_type`, e.g. `LlamaConfig` -> "configuration").
# NOTE(review): values containing "*" (fast image processors) look like glob-style patterns; the
# consumer of these values is outside this chunk — confirm how they are matched against filenames.
TYPE_TO_FILE_TYPE = {
    "Config": "configuration",
    "Tokenizer": "tokenization",
    "Processor": "processing",
    "ImageProcessor": "image_processing",
    "ImageProcessorFast": "image_processing*_fast",
    "FastImageProcessorKwargs": "image_processing*_fast",
    "FeatureExtractor": "feature_extractor",
    "ProcessorKwargs": "processing",
    "ImagesKwargs": "processing",
    "TextKwargs": "processing",
}
| |
|
| |
|
def find_file_type(class_name: str) -> str:
    """Based on a class name, find the file type corresponding to the class.
    If the class name is `LlamaConfig` it will return `configuration`.
    The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there are no match, we match by default to `modeling`
    """
    # Anchor the alternation of known suffixes at the end of the class name
    suffix_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
    suffix_match = re.search(rf"({suffix_pattern})$", class_name)
    return TYPE_TO_FILE_TYPE[suffix_match.group(1)] if suffix_match else "modeling"
| |
|
| |
|
| | |
| | |
# Top-level variables expected at the very beginning of generated files.
# NOTE(review): the consumer of this tuple is outside this chunk — confirm the exact ordering semantics.
VARIABLES_AT_THE_BEGINNING = (
    "logger",
    "_CHECKPOINT_FOR_DOC",
    "_CONFIG_FOR_DOC",
)

# Imports whose statement contains any of these substrings are ignored when visiting the modular file
# (see `ModularFileMapper.visit_ImportFrom` and `visit_SimpleStatementLine`).
IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",)
| |
|
| |
|
def append_new_import_node(
    node: cst.CSTNode, unused_imports: set[str], added_names: set, imports_to_keep: list[cst.CSTNode]
):
    """Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports` or `added_names`.
    Also modifies `added_names` in-place accordingly."""
    import_statement = node.body[0]
    kept_aliases = []
    for alias in import_statement.names:
        # The binding introduced by this alias: `import x as y` binds `y`, plain `import x` binds `x`
        alias_name = alias.evaluated_alias or alias.evaluated_name
        if alias_name in unused_imports or alias_name in added_names:
            continue
        # Normalize the trailing comma so the rebuilt name list stays syntactically valid
        kept_aliases.append(alias.with_changes(comma=cst.MaybeSentinel.DEFAULT))
        added_names.add(alias_name)
    if kept_aliases:
        imports_to_keep.append(node.with_changes(body=[import_statement.with_changes(names=kept_aliases)]))
| |
|
| |
|
def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]:
    """Get all the imports needed in the `body`, from the list of `all_imports`.
    `body` is a dict with the following structure `{str: {"insert_idx": int, "node": cst.CSTNode}}`.
    Note: we need to use `isinstance` on scope assignments, m.matches apparently does not work here yet!
    """
    # Build a temporary module (imports + body in insertion order) to resolve scopes on it
    new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
    wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body))
    scopes = set(wrapper.resolve(ScopeProvider).values())
    unused_imports = set()
    import_ref_count = defaultdict(lambda: 0)
    for scope in scopes:
        for assignment in scope.assignments:
            node = assignment.node
            if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)):
                ref_count = len(assignment.references)
                name = assignment.name
                # Keep the maximum reference count seen for a name across all scopes
                import_ref_count[name] = max(ref_count, import_ref_count[name])
    # An import is unused if it is never referenced, or if the same name is (re)defined in `body` itself
    unused_imports = {name for name, count in import_ref_count.items() if count <= 0 or name in body.keys()}

    imports_to_keep = []
    # Plain imports are filtered name-by-name; protected blocks (`if ...:` guards around imports) are
    # filtered statement-by-statement and deduplicated via `existing_protected_statements`
    added_names = set()
    existing_protected_statements = set()
    for node in all_imports:
        if m.matches(node, m.If()):
            new_statements = []
            for stmt_node in node.body.body:
                append_new_import_node(stmt_node, unused_imports, added_names, new_statements)
            new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements]
            if len(new_statements) > 0:
                new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
                imports_to_keep.append(new_node)
                existing_protected_statements.update({str(stmt) for stmt in new_statements})
        else:
            append_new_import_node(node, unused_imports, added_names, imports_to_keep)

    protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())]
    usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())]

    # Usual imports first, then the protected import blocks
    return usual_import_nodes + protected_import_nodes
| |
|
| |
|
def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]:
    """Split the `__all__` assignment found in the modular between each corresponding files."""
    per_file_nodes = {}
    assignment = node.body[0]
    if not isinstance(assignment.value, cst.List):
        return per_file_nodes
    # Group the quoted entries by the file type their (unquoted) name dispatches to
    names_per_file = defaultdict(list)
    for element in assignment.value.elements:
        if isinstance(element.value, cst.SimpleString):
            # `.value` keeps the original quoting; `.evaluated_value` is the bare class name
            names_per_file[find_file_type(element.value.evaluated_value)].append(element.value.value)
    # Rebuild one `__all__ = [...]` node per destination file
    for file, quoted_names in names_per_file.items():
        new_elements = [cst.Element(value=cst.SimpleString(value=name)) for name in quoted_names]
        new_assignment = assignment.with_changes(value=cst.List(elements=new_elements))
        per_file_nodes[file] = node.with_changes(body=[new_assignment])
    return per_file_nodes
| |
|
| |
|
| | class ModularFileMapper(ModuleMapper): |
| | """This is a Mapper to visit a modular file (like `modular_llama.py`). It visits the whole file, recording dependency, |
| | then visits all imported modeling files (like `modeling_llama.py`), and manages their mutual dependencies. |
| | Calling the method `create_modules()` after visit will create all modules based on this modular file. |
| | """ |
| |
|
    def __init__(self, python_module, new_name):
        super().__init__(python_module)
        # Name of the model being defined, e.g. "llama" when parsing `modular_llama.py`
        self.model_name = new_name

        # imported object name -> module it comes from (filled in `visit_ImportFrom`)
        self.model_specific_imported_objects: Dict[str, str] = {}
        # module name -> parsed cst.Module, for every model-specific file imported by the modular
        self.model_specific_modules: Dict[str, cst.Module] = {}

        # file type -> `__all__` assignment node to dispatch to that file (set when visiting `__all__`)
        self.all_all_to_add = {}
| | |
| |
|
    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        """When visiting imports from modeling files (i.e. `transformers.models.xxx`) we get the code, parse it,
        and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`.
        """
        import_module = self.python_module.code_for_node(node.module)
        # Reconstruct the textual statement, with the leading dots of a relative import
        import_statement = "." * len(node.relative) + import_module
        if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR):
            return
        if m.matches(node.module, m.Attribute()):
            for imported_ in node.names:
                # `self.match_patterns` is defined outside this chunk; presumably the alternation of known
                # file-type prefixes (configuration|modeling|...) — TODO confirm
                _import = re.search(
                    rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement
                )
                if _import:
                    source = _import.group(1)
                    # Importing a Config from a modeling file is forbidden
                    if source == "modeling" and "Config" in self.python_module.code_for_node(imported_):
                        raise ValueError(
                            f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead"
                        )
                    if import_module not in self.model_specific_modules:
                        # Normalize relative/short module names to the full `transformers.models.*` path
                        if "models" not in import_module:
                            import_module = "models." + import_module
                        if "transformers" not in import_module:
                            import_module = "transformers." + import_module
                        # Load and parse the imported modeling file, cached per module
                        source_code = get_module_source_from_name(import_module)
                        tree = cst.parse_module(source_code)
                        self.model_specific_modules[import_module] = tree
                    imported_object = self.python_module.code_for_node(imported_.name)
                    self.model_specific_imported_objects[imported_object] = import_module
        if m.matches(node.module, m.Name()):
            # Global `from transformers import ...` is forbidden in a modular file
            if "transformers" == import_module:
                raise ValueError(
                    f"You are importing from {import_module} directly using global imports. Import from the correct local path"
                )
| |
|
    def visit_SimpleStatementLine(self, node):
        """If we visit an import statement not previously visited, record it. If we visit a module-scope assignment,
        simply record it or, if it is `__all__`, split it between files where we should dispatch it.
        """
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        # Matches a plain `NAME = <value>` statement
        simple_top_level_assign_structure = m.SimpleStatementLine(
            body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
        )
        # Only consider statements at module scope
        if m.matches(parent_node, m.Module()):
            if m.matches(node, m.SimpleStatementLine(body=[m.Import()])):
                self.imports.append(node)
            elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])):
                import_module = self.python_module.code_for_node(node.body[0].module)
                import_statement = "." * len(node.body[0].relative) + import_module
                # Keep the import unless it is a model-specific import (those are handled by
                # `visit_ImportFrom`) that is not in the explicit skip list
                if not (
                    re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement)
                    and not any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR)
                ):
                    self.imports.append(node)
            elif m.matches(node, simple_top_level_assign_structure):
                assigned_variable = node.body[0].targets[0].target.value
                # `__all__` is split per destination file; every other assignment is recorded as-is
                if assigned_variable == "__all__":
                    self.all_all_to_add = split_all_assignment(node)
                else:
                    self.current_assignment = assigned_variable
                    self.assignments[assigned_variable] = node
| |
|
    def leave_Module(self, node):
        """When we leave the modular file, we do the following in order:
        1. for each modeling file found in the imports, rename it with the new model name, visit it, and update
        its dependency graph with the new function and assignment definitions found in the modular
        2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
        3. compute the nested (recursive) function and assignment dependencies
        """
        # Finalize the base visitor state first (implemented outside this chunk)
        super().leave_Module(node)

        # 1. Rename then visit every imported modeling file, merging the modular definitions into it
        self.visited_modules = {}
        self.renamers = {}
        name_prefixes = self.infer_new_model_name()
        for file, module in self.model_specific_modules.items():
            # e.g. "llama" from "transformers.models.llama.modeling_llama"
            file_model_name = file.split(".")[-2]
            new_name = name_prefixes[file]
            renamer = ReplaceNameTransformer(file_model_name, new_name, self.model_name)
            renamed_module = module.visit(renamer)
            self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies(
                renamed_module,
                self.classes,
                self.functions,
                self.assignments,
                self.object_dependency_mapping,
                self.start_lines,
            )
            # Keep the renamer so later steps can apply the exact same renaming
            self.renamers[file] = renamer

        # 2. Pull the imported functions/assignments (and their dependencies) into the modular's own graph
        self.merge_model_specific_imports(self.visited_modules)

        # 3. Compute the transitive closure of the dependency graph
        self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()

        # Record, per file type (configuration/modeling/...), the objects each visited file imports
        # from its modeling counterpart
        self.imported_objects_per_file = defaultdict(set)
        for file, mapper in self.visited_modules.items():
            file_type = re.search(rf"^transformers\.models\.\w+\.({self.match_patterns})_.*", file).group(1)
            self.imported_objects_per_file[file_type].update(mapper.objects_imported_from_modeling)
| |
|
    def merge_model_specific_imports(self, visited_modules):
        """Merge the functions and assignments imported from the modeling files to the modular nodes and dependency graph,
        based on the visited files."""
        self.start_lines_file_mapping = {}
        self.added_objects_file_mapping = {}
        for object_name, file in self.model_specific_imported_objects.items():
            visited_module = visited_modules[file]
            self.start_lines_file_mapping[file] = visited_module.start_lines
            # Imported functions: add them (and their direct dependencies) unless the modular redefined them
            if object_name in visited_module.functions and object_name not in self.functions:
                self.functions[object_name] = visited_module.functions[object_name]
                self.added_objects_file_mapping[object_name] = file
                dependencies = visited_module.object_dependency_mapping.get(object_name, None)
                if dependencies is not None:
                    self.object_dependency_mapping[object_name] = dependencies
                    for dep in dependencies:
                        if dep not in self.global_nodes:
                            self.added_objects_file_mapping[dep] = file
                            self.functions[dep] = visited_module.global_nodes[dep]

                # Propagate the imported function (and its recursive dependencies) to the OTHER visited
                # modules as well, so they can resolve it when computing their own dependencies
                recursive_dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, set())
                node_recursive_dependencies_mapping = {
                    dep: visited_module.global_nodes[dep] for dep in recursive_dependencies
                }
                for filename, module_mapper in self.visited_modules.items():
                    if filename != file:
                        module_mapper.global_nodes[object_name] = visited_module.functions[object_name]
                        if len(recursive_dependencies) > 0:
                            module_mapper.object_recursive_dependency_mapping[object_name] = recursive_dependencies
                            module_mapper.global_nodes.update(node_recursive_dependencies_mapping)

            # Imported assignments: same merging rule, keyed on assignments instead of functions
            elif object_name in visited_module.assignments and object_name not in self.assignments:
                self.assignments[object_name] = visited_module.assignments[object_name]
                self.added_objects_file_mapping[object_name] = file
                dependencies = visited_module.object_dependency_mapping.get(object_name, None)
                if dependencies is not None:
                    self.object_dependency_mapping[object_name] = dependencies
                    for dep in dependencies:
                        if dep not in self.global_nodes:
                            self.added_objects_file_mapping[dep] = file
                            self.assignments[dep] = visited_module.global_nodes[dep]

        # Rebuild the flat name -> node mapping after the merge
        self.global_nodes = {**self.assignments, **self.classes, **self.functions}
        # Drop dependency edges pointing to unknown names (helper defined outside this chunk)
        self._restrict_dependencies_to_known_entities()
| |
|
| | def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: |
| | """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that |
| | will be created based on the modular. |
| | """ |
| | relative_order = {} |
| | idx = 0 |
| |
|
| | original_dependencies = [] |
| | other_files_dependencies = defaultdict(list) |
| | for dep in tuple(missing_dependencies): |
| | if dep in self.added_objects_file_mapping: |
| | file = self.added_objects_file_mapping[dep] |
| | other_files_dependencies[file].append(dep) |
| | else: |
| | original_dependencies.append(dep) |
| | |
| | all_dependencies = [] |
| | for file, dependencies in other_files_dependencies.items(): |
| | sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x]) |
| | all_dependencies += sorted_dependencies |
| | all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x]) |
| |
|
| | |
| | for dep in all_dependencies: |
| | relative_order[dep] = idx |
| | idx += 1 |
| |
|
| | return relative_order |
| |
|
| | def infer_new_model_name(self) -> dict: |
| | """Infer whether we are using a model name prefix different from the usual model name as defined from the filename. |
| | This is useful e.g. when we define a new multi-modal model, and only the text part inherits from `LlamaModel`, |
| | so we have something like: |
| | ```python |
| | class NewModelNameTextDecoderLayer(LlamaDecoderLayer): |
| | pass |
| | ``` |
| | with the `Text` prefix added to the model name. |
| | However, in case of multiple prefix used, we raise a warning and use the most frequent prefix, to avoid parsing |
| | the same file multiple times and inconsistencies in the objects added from dependencies. |
| | If the new prefix collides with a prefix of another class in the file where we are importing from, then we also |
| | raise a warning, and use the default prefix (model name) to avoid collisions in dependencies. |
| | """ |
| | prefix_model_name_mapping = defaultdict(Counter) |
| | cased_default_name = get_cased_name(self.model_name) |
| | |
| | for class_name, class_node in self.classes.items(): |
| | modeling_bases = [ |
| | k.value.value for k in class_node.bases if k.value.value in self.model_specific_imported_objects |
| | ] |
| | if len(modeling_bases) > 1: |
| | raise ValueError( |
| | f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {(*modeling_bases,)}." |
| | ) |
| | if len(modeling_bases) == 1: |
| | filename = self.model_specific_imported_objects[modeling_bases[0]] |
| | cased_model_name = cased_default_name |
| | suffix = common_partial_suffix(class_name, modeling_bases[0]) |
| | if len(suffix) > 0 and suffix[0].isupper(): |
| | cased_model_name = class_name.replace(suffix, "") |
| | prefix_model_name_mapping[filename].update([cased_model_name]) |
| |
|
| | |
| | final_name_mapping = {} |
| | for file, prefixes_counter in prefix_model_name_mapping.items(): |
| | if len(prefixes_counter) > 1: |
| | _, total = prefixes_counter.most_common(1)[0] |
| | most_used_entities = [name for name, count in prefixes_counter.most_common() if count == total] |
| | |
| | final_name = cased_default_name if cased_default_name in most_used_entities else most_used_entities[-1] |
| | else: |
| | final_name = list(prefixes_counter)[0] |
| | |
| | old_cased_model_name = get_cased_name(file.split(".")[-2]) |
| | old_model_name_prefix = final_name.replace(cased_default_name, old_cased_model_name) |
| | |
| | has_prefix_collision = f"\nclass {old_model_name_prefix}" in get_module_source_from_name(file) |
| | if final_name != cased_default_name and has_prefix_collision: |
| | if len(prefixes_counter) > 1: |
| | logger.warning( |
| | f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. However, the " |
| | f"most used one, '{final_name}', is already present in the source file and will likely cause consistency " |
| | f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args " |
| | "and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different " |
| | f"from '{cased_default_name}') or use a single prefix in all the modular (best)." |
| | ) |
| | else: |
| | logger.warning( |
| | f"We detected the use of the new default prefix {final_name} when inheriting from {file}. However, it is " |
| | "already present in the source file and will likely cause consistency issues. For this reason we fallback " |
| | f"to the default prefix '{cased_default_name}' when grabbing args and dependencies. Make sure to subclass " |
| | f"the intermediate classes with the prefix you want (if different from '{cased_default_name}')" |
| | ) |
| | final_name = cased_default_name |
| | elif len(prefixes_counter) > 1: |
| | logger.warning( |
| | f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. We will only " |
| | f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the " |
| | f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix " |
| | "in all the modular (best)." |
| | ) |
| | final_name_mapping[file] = get_lowercase_name(final_name) |
| |
|
| | |
| | for file in self.model_specific_modules.keys(): |
| | if file not in final_name_mapping.keys(): |
| | final_name_mapping[file] = self.model_name |
| |
|
| | return final_name_mapping |
| |
|
| |
|
def check_dependencies_and_create_import_node(
    file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str
) -> tuple[set[str], dict[str, cst.CSTNode]]:
    """Verify that every class node in `new_dependencies` belongs to the given `file_type`. Any class that
    does not is removed from the returned dependency set, and a new import statement pointing to its correct
    file is created for it instead.
    This scenario may appear in the following case:
    If a new class in the `modular_xxx.py` file does not belong to `type_xxx.py`, but is used somewhere in `other_type_xxx.py`
    (e.g. as a type hint), but none of the visited files had a similar class, then it would be imported in `type_xxx.py` as
    part of the standard dependency graph (because we never encountered an import towards this new class in any file).
    For example imagine the following `modular.py`:
    ```
    from ..llama.modeling_llama import LlamaModel

    class NewNameTextConfig(PretrainedConfig):
        ...

    class NewNameConfig(PretrainedConfig):
        ...

    class NewNameModel(LlamaModel):
        config = NewNameConfig()
        text_config = NewNameTextConfig()
        ...
    ```
    then without the help of this function, `NewNameTextConfig` would be imported in the `modeling_newname.py` as well as
    `configuration_newname.py`, because `modeling_llama.py` tells us to not import `NewNameConfig`, but has no
    knowledge of `NewNameTextConfig`.
    """
    corrected_dependencies = new_dependencies.copy()
    new_imports = {}
    for dependency in new_dependencies:
        # Only class definitions can live in the "wrong" file type; skip everything else (functions, assignments...)
        if not m.matches(mapper.global_nodes[dependency], m.ClassDef()):
            continue
        dependency_file_type = find_file_type(dependency)
        # The class belongs to another generated file: drop it from the dependencies and import it instead
        if dependency_file_type != file_type:
            corrected_dependencies.discard(dependency)
            new_imports[dependency] = cst.parse_statement(
                f"from .{dependency_file_type}_{new_name} import {dependency}"
            )

    return corrected_dependencies, new_imports
| |
|
| |
|
def get_class_node_and_dependencies(
    modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict]
) -> tuple[dict, str, dict]:
    """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new
    class node based on the inherited classes if needed. Also returns any new imports of a new class defined in
    the modular that we may need.

    Returns a tuple `(nodes_to_add, file_type, new_imports)` where `nodes_to_add` maps each dependency name to
    `(relative_order_index, node)`, `file_type` is the generated file this class belongs to, and `new_imports`
    maps class names to the import statements to add for dependencies living in another generated file.
    """
    # Bases that come from a model-specific modeling file (e.g. `LlamaModel`); more than one such base is
    # unsupported and an exception is raised upstream, so at most one element is expected here in practice.
    model_specific_bases = [
        k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects
    ]
    super_class = model_specific_bases[0] if len(model_specific_bases) == 1 else None

    # Which generated file (modeling/configuration/...) this class ends up in, and what is already there
    file_type = find_file_type(class_name)
    file_to_update = files[file_type]
    model_name = modular_mapper.model_name

    # Objects already imported in the target file type (used to avoid re-adding them as dependencies)
    imported_objects = modular_mapper.imported_objects_per_file[file_type]

    # We need to replace the class node with the transformers (modeling file) super class node
    if super_class is not None:
        super_file_name = modular_mapper.model_specific_imported_objects[super_class]

        # Get the mapper corresponding to the inherited class's source module
        mapper = modular_mapper.visited_modules[super_file_name]
        # Rename the super class name according to the prefix chosen for that source file
        renamer = modular_mapper.renamers[super_file_name]
        renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.cased_new_name)

        # Create the new class node by merging the modular class with the (renamed) inherited one
        updated_node = replace_class_node(mapper, node, renamed_super_class, super_class)

        # Grab all immediate dependencies of the new node (resolved against the super class's module)
        new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects)

        # At this point, if any class dependency is found, but belongs to another file, it means that we need
        # to remove it from the dependencies, and add a new import of it instead
        new_node_dependencies, new_imports = check_dependencies_and_create_import_node(
            file_type, new_node_dependencies, mapper, model_name
        )

        # Recursively expand dependencies, skipping anything already present in the file being updated
        all_dependencies_to_add = find_all_dependencies(
            dependency_mapping=mapper.class_dependency_mapping,
            initial_dependencies=new_node_dependencies,
            initial_checked_dependencies=set(file_to_update.keys()),
        )

        relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add)
        nodes_to_add = {
            dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add
        }

    # No transformers (modeling file) super class, just check functions and assignments dependencies
    # (resolved against the modular file itself instead of a visited source module)
    else:
        updated_node = node
        # The node was NOT modified -> no need to look for recursive dependencies
        all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper, imported_objects)

        # At this point, if any class dependency is found, but belongs to another file, it means that we need
        # to remove it from the dependencies, and add a new import of it instead
        all_dependencies_to_add, new_imports = check_dependencies_and_create_import_node(
            file_type, all_dependencies_to_add, modular_mapper, model_name
        )

        relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add)
        nodes_to_add = {
            dep: (relative_dependency_order[dep], modular_mapper.global_nodes[dep])
            for dep in all_dependencies_to_add
            if dep not in file_to_update.keys()
        }

    # Add the class node itself to the nodes to add, after all its dependencies
    class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0
    nodes_to_add[class_name] = (class_idx, updated_node)

    return nodes_to_add, file_type, new_imports
| |
|
| |
|
def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]:
    """Create all the new modules based on visiting the modular file. It replaces all classes as necessary."""
    files = defaultdict(dict)
    # Next insertion index per file type (special "beginning" variables use negative indices instead)
    current_file_indices = defaultdict(int)

    # Visit all classes declared in the modular file, gathering each one's node and full dependency closure
    for class_name, class_node in modular_mapper.classes.items():
        nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(
            modular_mapper, class_name, class_node, files
        )

        # Record the extra imports discovered for this class so later passes know about them
        modular_mapper.imported_objects_per_file[file_type].update(new_imports.keys())
        modular_mapper.imports.extend(list(new_imports.values()))

        # Insert the nodes in dependency order, keeping a stable per-file counter for insertion indices
        for dependency, (_, dependency_node) in sorted(nodes_to_add.items(), key=lambda item: item[1][0]):
            if dependency in VARIABLES_AT_THE_BEGINNING:
                # Pin these well-known variables to the very top of the generated file
                insert_idx = -1000 + VARIABLES_AT_THE_BEGINNING.index(dependency)
            else:
                insert_idx = current_file_indices[file_type]
                current_file_indices[file_type] += 1
            files[file_type][dependency] = {"insert_idx": insert_idx, "node": dependency_node}

    # `__all__` assignments go last in each generated file
    for file_type, all_assignment_node in modular_mapper.all_all_to_add.items():
        files[file_type]["__all__"] = {"insert_idx": current_file_indices[file_type], "node": all_assignment_node}

    # Collect every import seen across the modular file and all visited source modules, deduplicated
    # by their exact source code representation
    all_imports = modular_mapper.imports.copy()
    all_imports_code = {modular_mapper.python_module.code_for_node(imp).strip() for imp in all_imports}
    for mapper in modular_mapper.visited_modules.values():
        for imp in mapper.imports:
            import_code = mapper.python_module.code_for_node(imp).strip()
            if import_code not in all_imports_code:
                all_imports.append(imp)
                all_imports_code.add(import_code)

    # Assemble each file: needed imports first, then the nodes sorted by insertion index
    for file, body in files.items():
        ordered_nodes = [entry["node"] for _, entry in sorted(body.items(), key=lambda item: item[1]["insert_idx"])]
        files[file] = cst.Module(
            body=get_needed_imports(body, all_imports) + ordered_nodes,
            header=modular_mapper.python_module.header,
        )

    return files
| |
|
| |
|
def convert_modular_file(modular_file):
    """Convert a single `modular_xxx.py` file into the code for each generated file type.

    Returns a dict mapping each file type to `[formatted_code, unformatted_code]`, where the unformatted
    version is kept as a fallback in case ruff formatting failed. Returns an empty dict (after printing a
    message) if the filename does not match the `modular_*.py` pattern.
    """
    pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
    if pattern is None:
        print(f"modular pattern not found in {modular_file}, exiting")
        return {}

    model_name = pattern.groups()[0]
    output = {}
    with open(modular_file, "r", encoding="utf-8") as file:
        code = file.read()
    module = cst.parse_module(code)
    wrapper = MetadataWrapper(module)
    cst_transformers = ModularFileMapper(module, model_name)
    wrapper.visit(cst_transformers)
    for file, module in create_modules(cst_transformers).items():
        # NOTE: the previous `if module != {}` guard compared a `cst.Module` against a dict and was
        # therefore always True — `create_modules` only ever returns `cst.Module` values, so every
        # entry is emitted unconditionally.
        # Path relative to the repo root, used in the auto-generated header message
        relative_path = re.search(
            r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/")
        ).group(1)

        header = AUTO_GENERATED_MESSAGE.format(
            relative_path=relative_path, short_name=os.path.basename(relative_path)
        )
        # First pass fixes lint issues (e.g. unused imports), second pass formats the code
        ruffed_code = run_ruff(header + module.code, True)
        formatted_code = run_ruff(ruffed_code, False)
        output[file] = [formatted_code, ruffed_code]
    return output
| |
|
| |
|
def save_modeling_file(modular_file, converted_file):
    """Write the converted code for each file type to disk.

    `converted_file` maps each file type to `[formatted_code, unformatted_code]`. The formatted code is
    written when it is non-empty and contains at least one non-comment line; otherwise we fall back to the
    unformatted code (logging a warning), applying the same emptiness check to that fallback content.
    File types may embed a suffix after a `*` (e.g. `modeling*_text`), which is appended to the file name.
    """
    for file_type in converted_file.keys():
        file_name_prefix = file_type.split("*")[0]
        file_name_suffix = file_type.split("*")[-1] if "*" in file_type else ""
        new_file_name = modular_file.replace("modular_", f"{file_name_prefix}_").replace(
            ".py", f"{file_name_suffix}.py"
        )
        non_comment_lines = len(
            [line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
        )
        if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0:
            with open(new_file_name, "w", encoding="utf-8") as f:
                f.write(converted_file[file_type][0])
        else:
            # Bug fix: count the non-comment lines of the UNFORMATTED code (index 1), which is what this
            # fallback branch actually writes — the original recounted the formatted code (index 0),
            # which just failed the check above.
            non_comment_lines = len(
                [line for line in converted_file[file_type][1].strip().split("\n") if not line.strip().startswith("#")]
            )
            if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
                logger.warning("The modeling code contains errors, it's written without formatting")
                with open(new_file_name, "w", encoding="utf-8") as f:
                    f.write(converted_file[file_type][1])
| |
|
| |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--files_to_parse",
        default=["all"],
        nargs="+",
        help="A list of `modular_xxxx` files that should be converted to single model file",
    )
    args = parser.parse_args()
    # Sentinel values: "all" converts every modular file in the library, "examples" those in examples/
    if args.files_to_parse == ["all"]:
        args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
    if args.files_to_parse == ["examples"]:
        args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)

    # Convert in dependency order, so modular files inheriting from other modular files come after them.
    priority_list, _ = find_priority_list(args.files_to_parse)
    # Use a real exception (not `assert`, which is stripped under `python -O`) for this sanity check
    if len(priority_list) != len(args.files_to_parse):
        raise RuntimeError("Some files will not be converted")

    for file_name in priority_list:
        print(f"Converting {file_name} to a single model single file format")
        # Dropped the unused `module_path` computation and the binding of `save_modeling_file`'s
        # None return value from the original.
        converted_files = convert_modular_file(file_name)
        save_modeling_file(file_name, converted_files)
|