NematodeClassifier / utils.py
VikramR's picture
Uploaded app
a9d56ef
raw
history blame contribute delete
834 Bytes
import torch
from torchvision.transforms import v2
# Lookup table: model nickname -> edge length used for the square resize
# applied during preprocessing (the same value is used for height and width).
RESIZE = {
    "effnet": 384,
    "resnet": 224,
    "mbnet": 224,
    "swin": 256,
}
def get_preprocessing(model_type: str) -> v2.Compose:
    """
    Build the image preprocessing transform for a given model.

    The pipeline converts the input to an image tensor, resizes it to the
    square size registered for ``model_type`` in ``RESIZE``, converts it to
    floating point with value scaling, normalizes each channel, and finally
    converts the result to a 3-channel grayscale image.

    Parameters
    ----------
    model_type : str
        Model nickname; must be one of the keys of ``RESIZE``.

    Returns
    -------
    v2.Compose
        Preprocessing transform

    Raises
    ------
    NotImplementedError
        If it's an invalid model_type
    """
    try:
        resize = RESIZE[model_type]
    except KeyError:
        # Surface the documented exception type instead of a bare KeyError,
        # and tell the caller which nicknames are actually supported.
        raise NotImplementedError(
            f"Unsupported model_type {model_type!r}; "
            f"expected one of {sorted(RESIZE)}"
        ) from None

    return v2.Compose(
        [
            v2.ToImage(),
            v2.Resize((resize, resize)),
            # scale=True rescales values into the float range while converting.
            v2.ToDtype(torch.float, scale=True),
            # These values match the commonly used ImageNet channel statistics.
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            # NOTE(review): Grayscale runs *after* Normalize, so the gray value
            # is computed from already-normalized channels. Order is kept
            # as-is because the trained weights presumably expect it -- confirm
            # against the training pipeline before changing.
            v2.Grayscale(3),
        ]
    )