| | import os |
| | import unittest |
| | from pathlib import Path |
| |
|
| | from transformers.utils.import_utils import define_import_structure, spread_import_structure |
| |
|
| |
|
# Directory of fixture modules consumed by `TestImportStructures.test_definition`.
# Anchored to this file rather than the current working directory so the lookup does not depend on
# where the test runner is launched from.
# NOTE(review): assumes the `import_structures` fixtures sit next to this test file -- confirm.
import_structures = Path(__file__).resolve().parent / "import_structures"
| |
|
| |
|
def _split_all_entries(fragment):
    """
    Splits a comma-separated fragment of a list literal into bare names.

    Surrounding quotes and whitespace are removed, and empty pieces (produced by trailing commas, empty
    lists or blank fragments) are dropped.
    """
    names = (piece.strip("\"' \t\r") for piece in fragment.split(","))
    return [name for name in names if name]


def fetch__all__(file_content):
    """
    Returns the content of the __all__ variable in the file content.
    Returns None if not defined, otherwise returns a list of strings.

    The lookup is purely textual: only a line starting with `__all__ = ` at column 0 is considered, and the
    value is expected to be a list literal of quoted names, written either on a single line or spread over
    several lines and closed by a line ending with `]`. Trailing `#` comments, blank lines, trailing commas
    and several names per line are tolerated. An unterminated list is treated as not defined.
    """
    lines = file_content.split("\n")
    for line_index, line in enumerate(lines):
        if not line.startswith("__all__ = "):
            continue

        # Everything after the first "=" with any trailing comment removed.
        opening = line.split("=", 1)[1].partition("#")[0].strip()

        # Single-line definition: __all__ = ["A", "B"]
        if opening.endswith("]"):
            return _split_all_entries(opening.strip("[]"))

        # Multi-line definition: names may already start on the opening line.
        entries = _split_all_entries(opening.lstrip("["))
        for continuation in lines[line_index + 1 :]:
            code = continuation.partition("#")[0].strip()
            if code.endswith("]"):
                # The closing bracket may share its line with the last names.
                entries.extend(_split_all_entries(code[:-1]))
                return entries
            entries.extend(_split_all_entries(code))

        # No closing bracket was found: keep scanning, which ends up returning None.

    return None
| |
|
| |
|
class TestImportStructures(unittest.TestCase):
    """
    Tests around `define_import_structure` / `spread_import_structure` and the consistency between the
    objects they report for each model module and that module's `__all__`.
    """

    # Three directory levels above this test file; presumably the repository root, since the models
    # directory is looked up under `src/transformers/models` from here.
    base_transformers_path = Path(__file__).parent.parent.parent
    models_path = base_transformers_path / "src" / "transformers" / "models"
    # Computed once, when the class body is executed, over the whole models directory.
    models_import_structure = spread_import_structure(define_import_structure(models_path))

    @unittest.skip(reason="failing")
    def test_definition(self):
        """
        Compares the import structure computed for the `import_structures` fixture directory with the
        expected mapping of requirement frozensets to `{module name: set of object names}`.
        """
        import_structure = define_import_structure(import_structures)
        import_structure_definition = {
            frozenset(()): {
                "import_structure_raw_register": {"A0", "a0", "A4"},
                "import_structure_register_with_comments": {"B0", "b0"},
            },
            frozenset(("tf", "torch")): {
                "import_structure_raw_register": {"A1", "a1", "A2", "a2", "A3", "a3"},
                "import_structure_register_with_comments": {"B1", "b1", "B2", "b2", "B3", "b3"},
            },
            frozenset(("torch",)): {
                "import_structure_register_with_duplicates": {"C0", "c0", "C1", "c1", "C2", "c2", "C3", "c3"},
            },
        }

        self.assertDictEqual(import_structure, import_structure_definition)

    def test_transformers_specific_model_import(self):
        """
        This test ensures that there is equivalence between what is written down in __all__ and what is
        written down with register().

        It doesn't test the backends attributed to register().
        """
        # Sorted so that subtests run in a deterministic order regardless of the filesystem.
        for architecture_path in sorted(self.models_path.iterdir()):
            architecture = architecture_path.name
            if architecture_path.is_file() or architecture.startswith("_") or architecture == "deprecated":
                continue

            with self.subTest(f"Testing arch {architecture}"):
                import_structure = define_import_structure(architecture_path)

                # Merge the per-requirement mappings into a single {module: [objects]} mapping; the
                # requirement keys themselves are deliberately ignored by this test.
                backend_agnostic_import_structure = {}
                for module_object_mapping in import_structure.values():
                    for module, objects in module_object_mapping.items():
                        backend_agnostic_import_structure.setdefault(module, []).extend(objects)

                for module, objects in backend_agnostic_import_structure.items():
                    module_path = architecture_path / f"{module}.py"
                    with open(module_path, encoding="utf-8") as f:
                        content = f.read()
                    _all = fetch__all__(content)

                    if _all is None:
                        raise ValueError(f"{module} doesn't have __all__ defined.")

                    error_message = (
                        f"{module_path} doesn't seem to be defined correctly:\n"
                        f"Defined in __all__: {sorted(_all)}\nDefined with register: {sorted(objects)}"
                    )
                    self.assertListEqual(sorted(objects), sorted(_all), msg=error_message)

    @unittest.skip(reason="failing")
    def test_export_backend_should_be_defined(self):
        """
        Placeholder expecting a ValueError mentioning the BACKENDS_MAPPING; the body currently raises
        nothing, so the test is skipped.
        """
        with self.assertRaisesRegex(ValueError, "Backend should be defined in the BACKENDS_MAPPING"):
            pass
| |
|