import argparse
import os
import sys
import uuid

import tensorflow as tf
from sklearn.model_selection import StratifiedShuffleSplit
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm import tqdm

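# This script splits a folder of images (one subfolder per class label) into
# stratified train/val/test sets and writes them to disk, optionally adding
# augmented copies of each training image.
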
def parse_arguments():
    parser = argparse.ArgumentParser(description='Image Data Loader with Augmentation and Splits')
    parser.add_argument('--path', type=str, required=True, help='Path to the folder containing images')
    parser.add_argument('--dim', type=int, default=224, help='Target image dimension (images are resized to dim x dim)')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of parallel calls for data loading')
    parser.add_argument('--target_folder', type=str, required=True, help='Folder to store the train, test, and val splits')
    parser.add_argument('--augment_data', action='store_true', help='Apply data augmentation')
    return parser.parse_args()

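# Example invocation (the script filename and paths below are hypothetical):
#   python split_images.py --path ./raw_images --target_folder ./splits \
#       --dim 224 --batch_size 32 --augment_data
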
def create_datagens():
    """Return one single-transform ImageDataGenerator per augmentation."""
    # Note: random_transform() applies only the geometric transforms below;
    # a rescale-only generator would just emit an exact duplicate of the
    # original (rescale is applied by standardize(), not random_transform()),
    # so none is included here — images are already float in [0, 1].
    return [
        ImageDataGenerator(rotation_range=20),
        ImageDataGenerator(width_shift_range=0.2),
        ImageDataGenerator(height_shift_range=0.2),
        ImageDataGenerator(horizontal_flip=True)
    ]

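# Note: newer TF releases deprecate ImageDataGenerator in favor of Keras
# preprocessing layers. A rough equivalent sketch (not used by this script):
#
#   augmenter = tf.keras.Sequential([
#       tf.keras.layers.RandomRotation(20 / 360),      # ~rotation_range=20
#       tf.keras.layers.RandomTranslation(0.2, 0.2),   # ~height/width_shift_range
#       tf.keras.layers.RandomFlip("horizontal"),      # ~horizontal_flip
#   ])
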
def process_image(file_path, image_size):
    """Eagerly read, decode, and resize a single image to float32 in [0, 1]."""
    file_path = file_path.numpy().decode('utf-8')
    image = tf.io.read_file(file_path)
    # expand_animations=False guarantees a 3-D tensor even for animated formats.
    image = tf.image.decode_image(image, channels=3, dtype=tf.float32, expand_animations=False)
    image = tf.image.resize(image, [image_size, image_size])
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image

def save_image(image, file_path):
    """Encode a float image in [0, 1] as JPEG and write it to `file_path`."""
    # convert_image_dtype rescales float [0, 1] to uint8 [0, 255].
    image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
    image = tf.io.encode_jpeg(image)  # encode_jpeg lives under tf.io in TF 2.x
    tf.io.write_file(file_path, image)

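# Caveat: PNG inputs are re-encoded as JPEG on save, which is lossy.
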
def load_data(path, image_size, batch_size, num_parallel_calls=tf.data.AUTOTUNE):
    """Walk `path`, stratify-split the images 80/10/10 by label, and return
    batched (image, label) tf.data pipelines for train, val, and test."""
    all_images = []
    labels = []

    for subdir, _, files in os.walk(path):
        label = os.path.basename(subdir)
        for fname in files:
            if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                all_images.append(os.path.join(subdir, fname))
                labels.append(label)

    if len(all_images) == 0:
        raise ValueError(f"No images found in the specified path: {path}")

    unique_labels = sorted(set(labels))
    print(f"Found {len(all_images)} images in {path}\n")
    print(f"Labels found ({len(unique_labels)}): {unique_labels}\n")

    # First split: 80% train, 20% held out.
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_indices, test_indices = next(sss.split(all_images, labels))

    train_files = [all_images[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    test_files = [all_images[i] for i in test_indices]
    test_labels = [labels[i] for i in test_indices]

    # Second split: divide the held-out 20% evenly into validation and test.
    sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
    val_indices, test_indices = next(sss_val.split(test_files, test_labels))

    val_files = [test_files[i] for i in val_indices]
    val_labels = [test_labels[i] for i in val_indices]
    test_files = [test_files[i] for i in test_indices]
    test_labels = [test_labels[i] for i in test_indices]

    print(f"Data split into {len(train_files)} train, {len(val_files)} validation, and {len(test_files)} test images.\n")

    def tf_load_image(file_path, label):
        # Wrap the eager `process_image` so it can run inside the tf.data graph.
        image = tf.py_function(func=lambda x: process_image(x, image_size), inp=[file_path], Tout=tf.float32)
        image.set_shape([image_size, image_size, 3])
        return image, label

    train_dataset = tf.data.Dataset.from_tensor_slices((train_files, train_labels))
    val_dataset = tf.data.Dataset.from_tensor_slices((val_files, val_labels))
    test_dataset = tf.data.Dataset.from_tensor_slices((test_files, test_labels))

    train_dataset = train_dataset.map(tf_load_image, num_parallel_calls=num_parallel_calls)
    val_dataset = val_dataset.map(tf_load_image, num_parallel_calls=num_parallel_calls)
    test_dataset = test_dataset.map(tf_load_image, num_parallel_calls=num_parallel_calls)

    train_dataset = train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    test_dataset = test_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return train_dataset, val_dataset, test_dataset

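# A minimal consumption sketch (an assumed downstream workflow, not part of
# this script): the pipelines yield string labels, so map them to integer
# ids before feeding a model, e.g. with a StringLookup layer:
#
#   train_ds, val_ds, test_ds = load_data('/data/images', 224, 32)
#   lookup = tf.keras.layers.StringLookup(num_oov_indices=0)
#   lookup.adapt(train_ds.map(lambda x, y: y))
#   train_ds = train_ds.map(lambda x, y: (x, lookup(y)))
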
def save_datasets_to_folders(dataset, folder_path, datagens=None):
    """Write every image in `dataset` to `folder_path/<label>/<uuid>.jpg`.
    If `datagens` is given, one extra augmented copy per generator is
    written alongside each original. Returns the number of files saved."""
    os.makedirs(folder_path, exist_ok=True)

    count = 0
    for batch_images, batch_labels in tqdm(dataset, desc=f"Saving to {folder_path}"):
        for i in range(batch_images.shape[0]):
            image = batch_images[i].numpy()
            label = batch_labels[i].numpy().decode('utf-8')
            label_folder = os.path.join(folder_path, label)
            os.makedirs(label_folder, exist_ok=True)

            file_path = os.path.join(label_folder, f"{uuid.uuid4().hex}.jpg")
            save_image(image, file_path)
            count += 1

            if datagens:
                for datagen in datagens:
                    # random_transform applies the generator's random
                    # geometric transforms to a single HWC array.
                    aug_image = datagen.random_transform(image)
                    file_path = os.path.join(label_folder, f"{uuid.uuid4().hex}.jpg")
                    save_image(aug_image, file_path)
                    count += 1

    print(f"Saved {count} images to {folder_path}\n")
    return count

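# With --augment_data, each training image is written once plus once per
# generator, so N originals yield N * (1 + len(datagens)) training files.
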
def main():
    args = parse_arguments()

    os.makedirs(args.target_folder, exist_ok=True)

    train_folder = os.path.join(args.target_folder, 'train')
    val_folder = os.path.join(args.target_folder, 'val')
    test_folder = os.path.join(args.target_folder, 'test')

    datagens = create_datagens() if args.augment_data else None

    train_dataset, val_dataset, test_dataset = load_data(
        args.path,
        args.dim,
        args.batch_size,
        num_parallel_calls=args.num_workers
    )

    # Only the training split receives augmented copies.
    train_count = save_datasets_to_folders(train_dataset, train_folder, datagens)
    val_count = save_datasets_to_folders(val_dataset, val_folder)
    test_count = save_datasets_to_folders(test_dataset, test_folder)

    print(f"Train dataset saved to: {train_folder}\n")
    print(f"Validation dataset saved to: {val_folder}\n")
    print(f"Test dataset saved to: {test_folder}\n")

    print('-' * 20)

    print(f"Number of images in training set: {train_count}\n")
    print(f"Number of images in validation set: {val_count}\n")
    print(f"Number of images in test set: {test_count}\n")


if __name__ == "__main__":
    # Force line-buffered UTF-8 output so tqdm and prints interleave cleanly.
    sys.stdout.reconfigure(encoding='utf-8', line_buffering=True)
    sys.stderr.reconfigure(encoding='utf-8', line_buffering=True)
    main()
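# Resulting layout under --target_folder:
#   target_folder/
#       train/<label>/<uuid>.jpg   (originals, plus augmented copies if enabled)
#       val/<label>/<uuid>.jpg
#       test/<label>/<uuid>.jpg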