| import difflib |
|
|
| import torch |
|
|
|
|
def get_layer(l_name, library=torch.nn):
    """Return the layer object handler named ``l_name`` from ``library``.

    The lookup is case-insensitive, e.g. ``get_layer("elu")`` returns
    ``torch.nn.ELU``.

    Args:
        l_name (str): Case-insensitive name of the layer in ``library``
            (e.g. ``"elu"``).
        library (module): Module (or any object) searched for an attribute
            whose name matches ``l_name``, e.g. ``torch.nn``. Defaults to
            ``torch.nn``.

    Returns:
        object: The handler for the requested layer (e.g. ``torch.nn.ELU``).

    Raises:
        NotImplementedError: If no attribute of ``library`` matches
            ``l_name`` (the message lists the closest names), or if more than
            one attribute matches it case-insensitively (the message lists
            all of them).
    """
    # Search the requested library (not always torch.nn) so the candidate
    # names agree with the object handed to getattr below.
    all_names = dir(library)
    wanted = l_name.lower()
    match = [name for name in all_names if name.lower() == wanted]

    if not match:
        # difflib is case-sensitive: compare lowercased forms on both sides,
        # then report suggestions with their original capitalization, keeping
        # difflib's best-first ordering.
        by_lower = {}
        for name in all_names:
            by_lower.setdefault(name.lower(), []).append(name)
        close_matches = [
            name
            for key in difflib.get_close_matches(wanted, list(by_lower))
            for name in by_lower[key]
        ]
        raise NotImplementedError(
            "Layer with name {} not found in {}.\n Closest matches: {}".format(
                l_name, str(library), close_matches
            )
        )

    if len(match) > 1:
        # Names differing only by case are ambiguous; report the real matches.
        raise NotImplementedError(
            "Multiple matches for layer with name {} found in {}.\n "
            "All matches: {}".format(l_name, str(library), match)
        )

    return getattr(library, match[0])