|
|
| import torch.nn as nn
|
|
|
| from ..utils.registry import Registry
|
|
|
|
|
# Root registry of model wrappers. ``is_model_wrapper`` uses it as its
# default search root, also walking its child registries.
MODEL_WRAPPERS = Registry('model_wrapper')
|
|
|
def is_model_wrapper(model: nn.Module,
                     registry: Registry = MODEL_WRAPPERS) -> bool:
    """Check if a module is a model wrapper.

    A model is considered a wrapper when it is an instance of any class
    registered in ``registry`` or, recursively, in any of its child
    registries.

    Args:
        model (nn.Module): The model to be checked.
        registry (Registry): The parent registry to search for model
            wrappers. Defaults to ``MODEL_WRAPPERS``.

    Returns:
        bool: True if the input model is a model wrapper.
    """
    # ``isinstance`` only accepts classes (or tuples of classes). Keep only
    # the class entries so that a non-class entry in the registry cannot
    # make this check raise ``TypeError``.
    wrapper_classes = tuple(
        entry for entry in registry.module_dict.values()
        if isinstance(entry, type))
    # ``isinstance(model, ())`` is simply False, so an empty tuple is fine.
    if isinstance(model, wrapper_classes):
        return True

    # ``any`` over an empty collection is False, so a registry without
    # children needs no separate guard.
    return any(
        is_model_wrapper(model, child) for child in registry.children.values())
|
|
|