Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision.transforms import v2 | |
# Side length, in pixels, of the square image each model nickname expects.
# Looked up by ``get_preprocessing`` to size the ``Resize`` step.
RESIZE = dict(
    effnet=384,
    resnet=224,
    mbnet=224,
    swin=256,
)
def get_preprocessing(model_type: str) -> v2.Compose:
    """
    Gets the right image preprocessing transform for each model

    The returned pipeline, applied in order:

    1. converts the input to a tensor image (``v2.ToImage``),
    2. resizes it to a square of side ``RESIZE[model_type]`` pixels
       (aspect ratio is not preserved),
    3. converts it to float, rescaling integer pixel values into [0, 1],
    4. normalizes each channel with the fixed mean/std below (the commonly
       used ImageNet statistics),
    5. converts the result to grayscale replicated over 3 channels.

    Parameters
    ----------
    model_type : str
        Model nickname; must be one of the keys of ``RESIZE``

    Returns
    -------
    v2.Compose
        Preprocessing transform

    Raises
    ------
    NotImplementedError
        If it's an invalid model_type
    """
    try:
        resize = RESIZE[model_type]
    except KeyError:
        # A bare dict lookup would raise KeyError. Raise the exception this
        # function documents instead, and list the nicknames it accepts.
        raise NotImplementedError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(RESIZE)}"
        ) from None

    transform = v2.Compose(
        [
            v2.ToImage(),
            v2.Resize((resize, resize)),
            v2.ToDtype(torch.float, scale=True),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            # NOTE(review): grayscale conversion runs on already-normalized
            # values, so each output channel is a weighted mix of the
            # normalized R/G/B channels. This is kept as-is because it
            # presumably matches how the models were trained; confirm before
            # moving it ahead of Normalize.
            v2.Grayscale(num_output_channels=3),
        ]
    )
    return transform