| | |
| | |
| |
|
| | from collections.abc import Callable |
| | from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union |
| | from rich.table import Table |
| | from rich.console import Console |
| |
|
| |
|
class Registry:
    """A registry to map strings to classes or functions.

    Registered objects can be built from the registry, and registered
    functions can be called through it.

    Args:
        name (str): Registry name, shown in ``repr`` and in error messages.
        build_func (callable, optional): Function used by :meth:`build`. It
            is called as ``build_func(cfg, *args, **kwargs, registry=self)``.
            Defaults to None, which selects ``build_from_cfg``.
    """

    def __init__(self,
                 name: str,
                 build_func: Optional[Callable] = None):
        self._name = name
        # Maps each registered name to its class or function.
        self._module_dict: Dict[str, Type] = dict()
        # NOTE(review): assigned but not read anywhere in this class.
        self._imported = False

        self.build_func: Callable
        if build_func is None:
            # ``build_from_cfg`` is defined later in this module; the name is
            # resolved when __init__ runs, so the forward reference is fine.
            self.build_func = build_from_cfg
        else:
            self.build_func = build_func

    def __len__(self):
        """Return the number of registered names (aliases count separately)."""
        return len(self._module_dict)

    def __repr__(self):
        """Render the registry contents as a ``rich`` table string."""
        table = Table(title=f'Registry of {self._name}')
        table.add_column('Names', justify='left', style='cyan')
        table.add_column('Objects', justify='left', style='green')

        for name, obj in sorted(self._module_dict.items()):
            table.add_row(name, str(obj))

        # Capture the rendered table instead of printing it to stdout.
        console = Console()
        with console.capture() as capture:
            console.print(table, end='')

        return capture.get()

    @property
    def name(self):
        """str: Name of this registry."""
        return self._name

    @property
    def module_dict(self):
        """dict: The underlying name-to-object mapping (not a copy)."""
        return self._module_dict

    def get(self, key: str) -> Optional[Type]:
        """Return the object registered under ``key``, or None if absent.

        ``build_from_cfg`` looks types up through this method and expects
        None (rather than an exception) for unknown names, so that it can
        raise its own descriptive ``KeyError``.
        """
        return self._module_dict.get(key)

    def switch_scope_and_registry(self, scope: Optional[str] = None):
        """Return a context manager yielding the registry to build from.

        This registry has no notion of scopes or child registries, so
        ``scope`` is accepted for interface compatibility with
        ``build_from_cfg`` and otherwise ignored; the context yields ``self``.
        """
        import contextlib
        return contextlib.nullcontext(self)

    def build(self, cfg: dict, *args, **kwargs) -> Any:
        """Build an instance by calling :attr:`build_func`.

        ``registry=self`` is always passed as a keyword argument, so
        ``kwargs`` must not itself contain ``registry``.
        """
        return self.build_func(cfg, *args, **kwargs, registry=self)

    def _register_module(self,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None
                         ) -> None:
        """Register a module under one or more names.

        Args:
            module (type): Module to be registered. Typically a class or a
                function, but generally all ``Callable`` are acceptable.
            module_name (str or list of str, optional): The name(s) to
                register ``module`` under. If not specified,
                ``module.__name__`` is used. Defaults to None.

        Raises:
            TypeError: If ``module`` is not callable.
            KeyError: If a name is already registered, or appears more than
                once in ``module_name``. Nothing is registered in that case.
        """
        if not callable(module):
            raise TypeError(f'module must be Callable, but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]

        # Validate every name before touching the registry so that a conflict
        # part-way through a list cannot leave it partially updated.
        new_entries: Dict[str, Type] = {}
        for name in module_name:
            if name in self._module_dict:
                existed_module = self._module_dict[name]
            elif name in new_entries:
                existed_module = module
            else:
                new_entries[name] = module
                continue
            location = getattr(existed_module, '__module__',
                               repr(existed_module))
            raise KeyError(f'{name} is already registered in {self.name} '
                           f'at {location}')
        self._module_dict.update(new_entries)

    def register_module(
            self,
            name: Optional[Union[str, List[str]]] = None,
            module: Optional[Type] = None) -> Union[type, Callable]:
        """Register a module.

        A record will be added to ``self._module_dict``, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Args:
            name (str or list of str, optional): Name(s) to register under.
                A list registers the same module under several aliases.
                Defaults to None, meaning ``module.__name__``.
            module (type, optional): Module to register immediately. If None,
                a decorator is returned instead. Defaults to None.

        Returns:
            ``module`` itself when it is given; otherwise a decorator that
            registers its argument and returns it unchanged.

        Raises:
            TypeError: If ``name`` is not None, a str, or a non-empty
                list/tuple of str.
        """
        # Accept the same name forms that ``_register_module`` handles.
        is_name_seq = (isinstance(name, (list, tuple)) and len(name) > 0
                       and all(isinstance(n, str) for n in name))
        if not (name is None or isinstance(name, str) or is_name_seq):
            raise TypeError(
                'name must be None, an instance of str, or a sequence of str, '
                f'but got {type(name)}')

        # Used as a normal function: register right away.
        if module is not None:
            self._register_module(module=module, module_name=name)
            return module

        # Used as a decorator: register whatever it is applied to.
        def _register(module):
            self._register_module(module=module, module_name=name)
            return module

        return _register
| |
|
| |
|
def build_from_cfg(
        cfg: dict,
        registry: Registry,
        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
    """Build a module from config dict when it is a class configuration, or
    call a function from config dict when it is a function configuration.

    The target is chosen by ``cfg['type']`` (falling back to
    ``default_args['type']``). A string type is looked up in ``registry``; a
    callable type is used directly. The remaining keys are passed as keyword
    arguments.

    Args:
        cfg (dict | ConfigDict | Config): Config dict. It must contain the key
            "type" unless ``default_args`` provides it. An optional
            ``_scope_`` key is popped and passed to
            ``registry.switch_scope_and_registry``.
        registry (Registry): The registry to look the type up in.
        default_args (dict | ConfigDict | Config, optional): Default keyword
            arguments; keys already present in ``cfg`` take precedence.
            Defaults to None.

    Returns:
        Any: The built object, or the return value of the called function.

    Raises:
        TypeError: If ``cfg``, ``default_args`` or ``registry`` has the wrong
            type, or "type" is neither a str nor a callable.
        KeyError: If no "type" is given or the type name is not registered.
        Exception: An exception raised while building is re-raised as the
            same type with the implementation location prepended, when that
            type can be rebuilt from a single message; otherwise the
            original exception propagates unchanged.
    """
    # ``inspect`` and ``logging`` are used below but are not imported at the
    # top of this module, so import them locally.
    import inspect
    import logging

    # NOTE(review): ConfigDict, Config, ManagerMixin and print_log are not
    # imported in this module's import block either -- confirm they are in
    # scope (presumably from mmengine) wherever this file is used.

    if not isinstance(cfg, (dict, ConfigDict, Config)):
        raise TypeError(
            f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}')

    # Validate default_args before the "type" check below probes it with `in`.
    if not (isinstance(default_args,
                       (dict, ConfigDict, Config)) or default_args is None):
        raise TypeError(
            'default_args should be a dict, ConfigDict, Config or None, '
            f'but got {type(default_args)}')

    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')

    if not isinstance(registry, Registry):
        raise TypeError('registry must be a mmengine.Registry object, '
                        f'but got {type(registry)}')

    # Merge defaults without overriding keys explicitly set in ``cfg``.
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    # ``_scope_`` selects the registry to build from; it is consumed here and
    # never forwarded to the constructor.
    scope = args.pop('_scope_', None)
    with registry.switch_scope_and_registry(scope) as registry:
        obj_type = args.pop('type')
        if isinstance(obj_type, str):
            obj_cls = registry.get(obj_type)
            if obj_cls is None:
                raise KeyError(
                    f'{obj_type} is not in the {registry.name} registry. '
                    f'Please check whether the value of `{obj_type}` is '
                    'correct or it was registered as expected. More details '
                    'can be found at '
                    'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'
                )
        elif callable(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        # Keep the ``try`` around construction only, so failures in the
        # logging below are not misreported as build failures.
        try:
            # ManagerMixin subclasses are created through ``get_instance``
            # instead of being called directly.
            if inspect.isclass(obj_cls) and \
                    issubclass(obj_cls, ManagerMixin):
                obj = obj_cls.get_instance(**args)
            else:
                obj = obj_cls(**args)
        except Exception as e:
            # getattr: arbitrary callables (e.g. partial objects) may lack
            # __name__/__module__, which must not mask the real error.
            obj_name = getattr(obj_cls, '__name__', repr(obj_cls))
            obj_module = getattr(obj_cls, '__module__', None) or '<unknown>'
            cls_location = '/'.join(obj_module.split('.'))
            try:
                # Same exception type, so callers catching it still work.
                new_exc = type(e)(f'class `{obj_name}` in '
                                  f'{cls_location}.py: {e}')
            except Exception:
                # Some exception types cannot be rebuilt from a single
                # message argument; keep the original in that case.
                new_exc = None
            if new_exc is None:
                raise
            raise new_exc from e

        if (inspect.isclass(obj_cls) or inspect.isfunction(obj_cls)
                or inspect.ismethod(obj_cls)):
            print_log(
                f'An `{obj_cls.__name__}` instance is built from '
                'registry, and its implementation can be found in '
                f'{obj_cls.__module__}',
                logger='current',
                level=logging.DEBUG)
        else:
            print_log(
                'An instance is built from registry, and its constructor '
                f'is {obj_cls}',
                logger='current',
                level=logging.DEBUG)
        return obj
| |
|