"""Per-model image preprocessing transforms."""

import torch
from torchvision.transforms import v2

# Square input side length (in pixels) used for each supported model nickname.
RESIZE = {
    "effnet": 384,
    "resnet": 224,
    "mbnet": 224,
    "swin": 256,
}


class UnsupportedModelError(NotImplementedError, KeyError):
    """Raised by ``get_preprocessing`` for a ``model_type`` not in ``RESIZE``.

    It subclasses ``NotImplementedError``, the documented contract. It also
    subclasses ``KeyError``, which is what the bare dict lookup used to raise.
    Existing ``except`` clauses for either type therefore keep working.
    """


def get_preprocessing(model_type: str) -> v2.Compose:
    """
    Gets the right image preprocessing transform for each model

    Parameters
    ----------
    model_type : str
        Model nickname; must be one of the keys of ``RESIZE``.

    Returns
    -------
    v2.Compose
        Preprocessing transform, applied in this order:
        1. Convert the input to an image tensor.
        2. Resize to a ``RESIZE[model_type]`` x ``RESIZE[model_type]`` square.
        3. Convert to float with value scaling.
        4. Normalize per channel.
        5. Convert to 3-channel grayscale.

    Raises
    ------
    NotImplementedError
        If it's an invalid model_type. The concrete exception is
        ``UnsupportedModelError``, which is also a ``KeyError``.
    """
    try:
        resize = RESIZE[model_type]
    except KeyError:
        valid = ", ".join(sorted(RESIZE))
        raise UnsupportedModelError(
            f"Unsupported model_type {model_type!r}; expected one of: {valid}"
        ) from None

    return v2.Compose(
        [
            v2.ToImage(),
            v2.Resize((resize, resize)),
            # scale=True rescales integer pixel values (e.g. uint8 0-255)
            # into the float range [0, 1].
            v2.ToDtype(torch.float, scale=True),
            # The commonly used ImageNet per-channel mean / std values.
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            # NOTE(review): Grayscale runs *after* Normalize, so it takes a
            # weighted sum of already-normalized channels. Kept as-is to match
            # whatever the models were trained with; confirm before reordering.
            v2.Grayscale(3),
        ]
    )