| | from typing import Any, Dict, List, Optional, Tuple, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import yaml |
| | from torch.nn.modules.batchnorm import _BatchNorm |
| |
|
# Public API of this utility module; keep this list in sync with the
# definitions below.
__all__ = [
    "make_divisible",
    "load_state_dict_from_file",
    "list_mean",
    "list_sum",
    "parse_unknown_args",
    "partial_update_config",
    "remove_bn",
    "get_same_padding",
    "torch_random_choices",
]
| |
|
| |
|
def make_divisible(
    v: Union[int, float], divisor: Optional[int], min_val=None
) -> Union[int, float]:
    """Round ``v`` to the nearest multiple of ``divisor``.

    Adapted from the original TensorFlow MobileNet repo; it ensures that
    all layers have a channel count divisible by ``divisor``:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    :param v: value to round.
    :param divisor: divisor to round to; ``None`` disables rounding.
    :param min_val: lower bound for the result (defaults to ``divisor``).
    :return: the rounded value.
    """
    if divisor is None:
        return v

    floor = divisor if min_val is None else min_val
    rounded = int(v + divisor / 2) // divisor * divisor
    result = max(floor, rounded)
    # Make sure rounding down does not shrink the value by more than 10%.
    if result < 0.9 * v:
        result += divisor
    return result
| |
|
| |
|
def load_state_dict_from_file(file: str) -> Dict[str, torch.Tensor]:
    """Load a checkpoint on CPU and return its state dict.

    :param file: path to a file produced by ``torch.save``.
    :return: the ``"state_dict"`` entry when present, otherwise the
        loaded object itself (assumed to already be a state dict).
    """
    checkpoint = torch.load(file, map_location="cpu")
    # Checkpoints may wrap the weights under a "state_dict" key.
    return checkpoint.get("state_dict", checkpoint)
| |
|
| |
|
def list_sum(x: List) -> Any:
    """Sum the elements of a non-empty list with ``+``.

    Works for any element type supporting ``+`` (numbers, tensors,
    lists), unlike ``sum()`` whose implicit ``0`` start value would
    break non-numeric elements.

    The previous recursive implementation (``x[0] + list_sum(x[1:])``)
    sliced on every step — O(n^2) — and hit ``RecursionError`` for lists
    longer than the interpreter recursion limit; this iterative fold is
    O(n) and depth-free while producing the identical left-associated
    result.

    :param x: non-empty list of summable elements.
    :return: the left-fold of ``+`` over ``x``.
    """
    total = x[0]
    for item in x[1:]:
        total = total + item
    return total
| |
|
| |
|
def list_mean(x: List) -> Any:
    """Return the arithmetic mean of a non-empty list via ``list_sum``.

    :param x: non-empty list of summable elements.
    :return: ``list_sum(x)`` divided by ``len(x)``.
    """
    count = len(x)
    total = list_sum(x)
    return total / count
| |
|
| |
|
def parse_unknown_args(unknown: List) -> Dict:
    """Parse a flat ``[--key, value, --key, value, ...]`` token list.

    Values are interpreted with ``yaml.safe_load`` so that numbers,
    booleans, lists and inline dicts come back as Python objects; a value
    that YAML cannot parse is kept as the raw string.

    Fix: ``yaml.safe_load`` raises ``yaml.YAMLError`` (e.g. ScannerError)
    on malformed input, not ``ValueError`` — the old handler let those
    escape and crash instead of falling back to the raw string.

    :param unknown: alternating key/value tokens (keys may carry a
        leading ``--``); length is assumed to be even.
    :return: dict mapping stripped keys to parsed values.
    """
    index = 0
    parsed_dict = {}
    while index < len(unknown):
        key, val = unknown[index], unknown[index + 1]
        index += 2
        if key.startswith("--"):
            key = key[2:]
        try:
            # Inline dicts like "{a:1}" need a space after ':' to be
            # valid YAML mappings.
            if "{" in val and "}" in val and ":" in val:
                val = val.replace(":", ": ")
            out_val = yaml.safe_load(val)
        except (ValueError, yaml.YAMLError):
            # Not parseable — keep the raw string.
            out_val = val
        parsed_dict[key] = out_val
    return parsed_dict
| |
|
| |
|
def partial_update_config(config: Dict, partial_config: Dict) -> None:
    """Recursively merge ``partial_config`` into ``config`` in place.

    Nested dicts are merged key by key; any other value (or a key absent
    from ``config``) overwrites/creates the entry wholesale.

    Fix: isinstance checks now use the builtin ``dict`` instead of
    ``typing.Dict`` — passing typing aliases to ``isinstance`` is
    deprecated; behavior is unchanged.

    :param config: dict updated in place.
    :param partial_config: overrides to fold into ``config``.
    """
    for key in partial_config:
        if (
            key in config
            and isinstance(partial_config[key], dict)
            and isinstance(config[key], dict)
        ):
            # Both sides are dicts: merge one level deeper.
            partial_update_config(config[key], partial_config[key])
        else:
            config[key] = partial_config[key]
| |
|
| |
|
def remove_bn(model: nn.Module) -> None:
    """Neutralize every batch-norm layer of ``model`` in place.

    Each ``_BatchNorm`` module loses its affine parameters and its
    ``forward`` becomes the identity, so the layer passes inputs through
    untouched.

    :param model: module whose BN layers are disabled in place.
    """
    for module in model.modules():
        if not isinstance(module, _BatchNorm):
            continue
        # Drop affine parameters and short-circuit the forward pass.
        module.weight = None
        module.bias = None
        module.forward = lambda x: x
| |
|
| |
|
def get_same_padding(kernel_size: Union[int, Tuple[int, int]]) -> Union[int, tuple]:
    """Return the padding that keeps spatial size under stride-1 conv.

    :param kernel_size: odd int, or a 2-tuple of odd ints (handled
        per-dimension).
    :return: ``kernel_size // 2`` for an int, or a tuple of per-dim
        paddings for a tuple input.
    """
    if isinstance(kernel_size, int):
        # SAME padding is only well-defined for odd kernels.
        assert kernel_size % 2 > 0, "kernel size should be odd number"
        return kernel_size // 2
    assert isinstance(
        kernel_size, tuple
    ), "kernel size should be either `int` or `tuple`"
    assert len(kernel_size) == 2, f"invalid kernel size: {kernel_size}"
    return tuple(get_same_padding(k) for k in kernel_size)
| |
|
| |
|
def torch_random_choices(
    src_list: List[Any],
    generator: Optional[torch.Generator],
    k=1,
) -> Union[Any, List[Any]]:
    """Sample ``k`` items from ``src_list`` with replacement.

    :param src_list: non-empty list to draw from.
    :param generator: torch RNG for reproducible draws (``None`` uses
        the global RNG).
    :param k: number of draws.
    :return: a single item when ``k == 1``, otherwise a list of ``k``
        items.
    """
    indices = torch.randint(0, len(src_list), size=(k,), generator=generator)
    picked = [src_list[int(i)] for i in indices]
    if k == 1:
        return picked[0]
    return picked
| |
|