"""ResNet-18 classifier and input preprocessing for 7-class image classification."""

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms

NUM_CLASSES = 7  # 6 defects + normal


def get_model(pretrained: bool = True) -> nn.Module:
    """Build a ResNet-18 whose final layer outputs ``NUM_CLASSES`` logits.

    Args:
        pretrained: If truthy, initialise the backbone from torchvision's
            pretrained ResNet-18 checkpoint; otherwise use random
            initialisation.

    Returns:
        The ResNet-18 model with ``model.fc`` replaced by a freshly
        initialised ``nn.Linear(in_features, NUM_CLASSES)``.
    """
    # torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``.
    # Use the new API when it exists and fall back to the legacy keyword on
    # older releases so the function works on both.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        # IMAGENET1K_V1 is the checkpoint the legacy ``pretrained=True`` maps to.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        model = models.resnet18(weights=weights)
    else:
        model = models.resnet18(pretrained=pretrained)

    # Swap the original classification head for one sized to our label set.
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, NUM_CLASSES)
    return model


def get_transforms() -> transforms.Compose:
    """Return the preprocessing pipeline applied to each input image.

    The pipeline resizes to 224x224, converts the image to a tensor, and
    normalises each of the three channels with the fixed mean/std values
    below (the per-channel statistics conventionally used with torchvision's
    pretrained ImageNet models).
    """
    return transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )