| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| from torchvision import transforms | |
# Output width of the classifier head: six defect categories plus one "normal" class.
NUM_CLASSES: int = 6 + 1
def get_model(pretrained=True, *, num_classes=None):
    """Build a ResNet-18 whose final fully connected layer is sized for this task.

    Args:
        pretrained: If true, start the backbone from torchvision's ImageNet
            weights; otherwise start from random initialisation.
        num_classes: Number of output logits. ``None`` (the default) uses the
            module-level ``NUM_CLASSES``, looked up at call time.

    Returns:
        A torchvision ResNet-18 whose ``fc`` layer has been replaced by a freshly
        initialised ``nn.Linear(in_features, num_classes)``.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13 deprecates the boolean `pretrained` flag in favour
        # of `weights`. IMAGENET1K_V1 is what `pretrained=True` used to load.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        model = models.resnet18(weights=weights)
    else:
        # Older torchvision releases only understand the boolean flag.
        model = models.resnet18(pretrained=pretrained)

    # Swap the ImageNet head for one sized to our classes, keeping its input width.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
def get_transforms(image_size=(224, 224)):
    """Return the preprocessing pipeline applied to input images.

    The pipeline has three steps:

    1. Resize to a fixed ``image_size``; aspect ratio is not preserved.
    2. Convert to a tensor with ``ToTensor``.
    3. Normalise each channel with the ImageNet mean/std that torchvision's
       pretrained weights expect.

    Args:
        image_size: Target ``(height, width)``, or a single int for a square
            output. Defaults to ``(224, 224)``.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    if isinstance(image_size, int):
        # transforms.Resize(int) would scale only the shorter side; expand to a
        # square so the output shape stays fixed, as with the default.
        image_size = (image_size, image_size)
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])