| | from torch.utils.data import DataLoader
|
| | import torchvision
|
| |
|
| |
|
# Pretrained EfficientNet-B2 weight enum; the preprocessing pipeline it
# provides via ``transforms()`` is applied to every image in both splits.
weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
transform = weights.transforms()

# Dataset roots (relative paths, so they resolve against the current working
# directory of the process, not the location of this file).
train_dir, test_dir = "intel_image/seg_train", "intel_image/seg_test"

# One ImageFolder dataset per split, built train-first, sharing `transform`.
train_data, test_data = (
    torchvision.datasets.ImageFolder(root=split_root, transform=transform)
    for split_root in (train_dir, test_dir)
)

# Batches of 32; only the training split is shuffled so evaluation order is
# stable between runs.
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=32, shuffle=False)
|
| |
|
def create_dataloaders(*, batch_size=None, num_workers=0):
    """Return the ``(train_loader, test_loader)`` pair for the Intel image splits.

    Called with no arguments, this returns the module-level loaders built at
    import time (shuffled training loader, unshuffled test loader, batch size
    32) -- the very same objects, so existing callers see no change.

    Passing ``batch_size`` and/or ``num_workers`` builds a fresh pair of
    loaders over the already-constructed ``train_data`` / ``test_data``
    datasets, keeping the same shuffle settings (train shuffled, test not).
    The datasets are reused, so the image directories are not re-scanned.

    Args:
        batch_size: Samples per batch for both loaders. ``None`` keeps the
            batch size of the module-level loaders.
        num_workers: Worker processes used by each new loader for loading;
            ``0`` loads in the calling process.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    if batch_size is None and num_workers == 0:
        # Default configuration: hand back the import-time loaders unchanged.
        return train_loader, test_loader

    # Fall back to each existing loader's batch size when none is given so a
    # call that only changes num_workers keeps the original batching.
    train_batch = train_loader.batch_size if batch_size is None else batch_size
    test_batch = test_loader.batch_size if batch_size is None else batch_size

    fresh_train = DataLoader(
        train_data, shuffle=True, batch_size=train_batch, num_workers=num_workers
    )
    fresh_test = DataLoader(
        test_data, shuffle=False, batch_size=test_batch, num_workers=num_workers
    )
    return fresh_train, fresh_test
|
| |
|