import os
import torch
import json
import glob
import collections
import random

import numpy as np

from tqdm import tqdm

import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader, Sampler


class SubsetSampler(Sampler):
    """Sampler that yields a fixed collection of indices in the order given."""

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder whose items are dicts that also carry the image path.

    Args:
        path: Root directory handed to ``ImageFolder``.
        transform: Transform handed to ``ImageFolder``.
        flip_label_prob: Probability with which each sample's label is
            replaced by a label drawn uniformly from all classes. The draw
            may return the original label. ``0.0`` leaves labels untouched.
    """

    def __init__(self, path, transform, flip_label_prob=0.0):
        super().__init__(path, transform)
        self.flip_label_prob = flip_label_prob
        if self.flip_label_prob > 0:
            print(f'Flipping labels with probability {self.flip_label_prob}')
            num_classes = len(self.classes)
            # NOTE(review): only ``self.samples`` is rewritten here; any other
            # label list kept by the base class is left as loaded -- confirm
            # whether callers rely on it.
            for i, (sample_path, _) in enumerate(self.samples):
                if random.random() < self.flip_label_prob:
                    new_label = random.randint(0, num_classes - 1)
                    self.samples[i] = (sample_path, new_label)

    def __getitem__(self, index):
        """Return ``{'images', 'labels', 'image_paths'}`` for sample ``index``."""
        image, label = super().__getitem__(index)
        return {
            'images': image,
            'labels': label,
            'image_paths': self.samples[index][0],
        }


def maybe_dictionarize(batch):
    """Normalize a batch to a dict keyed by ``'images'`` / ``'labels'``.

    A dict is returned unchanged. A 2-element batch becomes
    ``{'images', 'labels'}`` and a 3-element batch additionally gets
    ``'metadata'``.

    Raises:
        ValueError: If the batch is not a dict and has neither 2 nor 3
            elements.
    """
    if isinstance(batch, dict):
        return batch

    if len(batch) == 2:
        batch = {'images': batch[0], 'labels': batch[1]}
    elif len(batch) == 3:
        batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
    else:
        raise ValueError(f'Unexpected number of elements: {len(batch)}')

    return batch


def get_features_helper(image_encoder, dataloader, device):
    """Run ``image_encoder`` over ``dataloader`` and collect the outputs.

    The encoder is moved to ``device``, wrapped in ``DataParallel`` over all
    visible CUDA devices, and evaluated in eval mode without gradients.

    Args:
        image_encoder: Module mapping a batch of images to features.
        dataloader: Iterable of batches accepted by ``maybe_dictionarize``.
        device: Device the encoder and the input images are moved to.

    Returns:
        A ``defaultdict`` with a ``'features'`` entry plus one entry per
        non-image batch key. Entries gathered from tensors are concatenated
        into NumPy arrays; all other entries stay flat Python lists.
    """
    all_data = collections.defaultdict(list)

    image_encoder = image_encoder.to(device)
    image_encoder = torch.nn.DataParallel(
        image_encoder,
        device_ids=list(range(torch.cuda.device_count())),
    )
    image_encoder.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = maybe_dictionarize(batch)
            # Move inputs to the device the encoder was moved to instead of
            # unconditionally assuming CUDA is available.
            features = image_encoder(batch['images'].to(device))

            all_data['features'].append(features.cpu())

            for key, val in batch.items():
                if key == 'images':
                    continue
                if hasattr(val, 'cpu'):
                    # Tensor-like values are kept per batch and concatenated
                    # after the loop.
                    all_data[key].append(val.cpu())
                else:
                    # Non-tensor values (e.g. lists of paths) are flattened.
                    all_data[key].extend(val)

    for key, val in all_data.items():
        # Guard against keys whose batches only ever contributed empty lists.
        if val and torch.is_tensor(val[0]):
            all_data[key] = torch.cat(val).numpy()

    return all_data


def get_features(is_train, image_encoder, dataset, device):
    """Return encoder features for ``dataset``, using an on-disk cache if set.

    When ``image_encoder.cache_dir`` is not ``None`` the cache lives at
    ``{cache_dir}/{dataset class name}/{'train' | 'val'}``. If that directory
    contains files, every file is loaded with ``torch.load`` and keyed by its
    base name without extension. Otherwise features are computed with
    ``get_features_helper`` and, if a cache directory is configured, each
    entry is written there as ``{name}.pt``.

    Args:
        is_train: Selects ``dataset.train_loader`` (True) or
            ``dataset.test_loader`` (False) and the cache split name.
        image_encoder: Encoder exposing a ``cache_dir`` attribute.
        dataset: Object exposing ``train_loader`` and ``test_loader``.
        device: Device forwarded to ``get_features_helper``.

    Returns:
        Mapping from entry name (e.g. ``'features'``) to the collected data.
    """
    split = 'train' if is_train else 'val'
    dname = type(dataset).__name__

    # Bind both names up front so the no-cache path never touches an
    # undefined variable.
    cache_dir = None
    cached_files = []
    if image_encoder.cache_dir is not None:
        cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}'
        cached_files = glob.glob(f'{cache_dir}/*')

    if cached_files:
        print(f'Getting features from {cache_dir}')
        data = {}
        for cached_file in cached_files:
            name = os.path.splitext(os.path.basename(cached_file))[0]
            # NOTE(review): torch.load unpickles the file; only point
            # cache_dir at trusted locations.
            data[name] = torch.load(cached_file)
        return data

    if cache_dir is not None:
        print(f'Did not find cached features at {cache_dir}. Building from scratch.')
    loader = dataset.train_loader if is_train else dataset.test_loader
    data = get_features_helper(image_encoder, loader, device)

    if cache_dir is None:
        print('Not caching because no cache directory was passed.')
    else:
        os.makedirs(cache_dir, exist_ok=True)
        print(f'Caching data at {cache_dir}')
        for name, val in data.items():
            torch.save(val, f'{cache_dir}/{name}.pt')
    return data


class FeatureDataset(Dataset):
    """Dataset over the features returned by ``get_features``."""

    def __init__(self, is_train, image_encoder, dataset, device):
        self.data = get_features(is_train, image_encoder, dataset, device)

    def __len__(self):
        return len(self.data['features'])

    def __getitem__(self, idx):
        """Return every stored entry at ``idx``; features as a float tensor."""
        data = {k: v[idx] for k, v in self.data.items()}
        data['features'] = torch.from_numpy(data['features']).float()
        return data


def get_dataloader(dataset, is_train, args, image_encoder=None):
    """Return a loader over raw batches or over encoder features.

    Args:
        dataset: Object exposing ``train_loader`` and ``test_loader``.
        is_train: Selects the train or test split; also enables shuffling
            for the feature loader.
        args: Namespace providing ``device`` and ``batch_size`` (read only
            when ``image_encoder`` is given).
        image_encoder: If given, features are computed (or loaded from cache)
            and wrapped in a new ``DataLoader``; otherwise the dataset's own
            loader is returned.
    """
    if image_encoder is not None:
        feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device)
        dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train)
    else:
        dataloader = dataset.train_loader if is_train else dataset.test_loader
    return dataloader