| import mim |
| from pathlib import Path |
| from mim.utils import get_installed_path, echo_success |
| from mmengine.config import Config |
|
|
| class Manager: |
|
|
| def __init__(self, path=None) -> None: |
| """ |
| Params: |
| - path: root path of projects to save checkpoints and configs |
| """ |
| if path: |
| self.path = Path(path) |
| else: |
| self.path = Path(__file__).parents[1] |
| self.keys = ['weight', 'config', 'model', 'training_data'] |
|
|
| def get_model_infos(self, package_name, keyword: str=None): |
| """ because mim search is too strict, |
| I want to search by keyword, not a strict match |
| """ |
| model_infos = mim.get_model_info(package_name) |
| model_names = model_infos.index |
| info_keys = model_infos.columns.tolist() |
| keys = self.intersect_keys(info_keys, |
| self.keys) |
| if keyword is None: |
| return model_infos[:, keys] |
| |
| valid_names = [name for name in model_names |
| if keyword in name] |
| filter_infos = model_infos.loc[valid_names, keys] |
| return filter_infos |
|
|
| def intersect_keys(self, keys1 , keys2): |
| return list(set(keys1) & set(keys2)) |
|
|
| def download(self, package, model, config_only=False): |
| """ Use model names to download checkpoints and configs. |
| Args: |
| - package: package name, e.g. mmdet |
| - model: model name, e.g. faster_rcnn or faster_rcnn_r50_fpn_1x_coco |
| - config_only: only download configs, which is helpful when you |
| already download checkpoints fast through other ways. |
| """ |
| infos = self.get_model_infos(package, model) |
| |
| for model, info in infos.iterrows(): |
| |
| hyper_name = info['model'] |
| dst_path = self.path / 'model_zoo' / hyper_name / model |
| dst_path.mkdir(parents=True, exist_ok=True) |
|
|
| if config_only: |
| |
| installed_path = Path(get_installed_path(package)) |
| config_path = info['config'] |
| config_path = installed_path / '.mim' / config_path |
| |
| config_obj = Config.fromfile(config_path) |
| saved_config_path = dst_path / f'{model}.py' |
| config_obj.dump(saved_config_path) |
| echo_success( |
| f'Successfully dumped {model}.py to {dst_path}') |
| else: |
| mim.download(package, [model], dest_root=dst_path) |
|
|
| if __name__ == '__main__': |
| m = Manager() |
| print(m.get_model_infos('mmdet', 'det')) |
| |
|
|