# data_loader.py — dataset loading and augmentation

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from src.utils import get_logger

logger = get_logger("data_loader")


def get_data_generators(cfg: dict):
    """Build train/validation/test image generators from a nested config dict.

    Expects cfg["data"], cfg["augmentation"], and cfg["project"]["seed"].
    Returns a (train_data, val_data, test_data) tuple of directory iterators.
    """
    data_cfg = cfg["data"]
    aug_cfg = cfg["augmentation"]

    image_size  = tuple(data_cfg["image_size"])
    batch_size  = data_cfg["batch_size"]
    val_split   = data_cfg["validation_split"]
    train_dir   = data_cfg["train_dir"]
    test_dir    = data_cfg["test_dir"]

    train_gen = ImageDataGenerator(
        rescale=1./255,
        validation_split=val_split,
        rotation_range=aug_cfg["rotation_range"],
        width_shift_range=aug_cfg["width_shift_range"],
        height_shift_range=aug_cfg["height_shift_range"],
        zoom_range=aug_cfg["zoom_range"],
        horizontal_flip=aug_cfg["horizontal_flip"],
        brightness_range=aug_cfg["brightness_range"],
    )

    test_gen = ImageDataGenerator(rescale=1./255)

    train_data = train_gen.flow_from_directory(
        train_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode="categorical",
        subset="training",
        seed=cfg["project"]["seed"],
        shuffle=True
    )

    val_data = train_gen.flow_from_directory(
        train_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode="categorical",
        subset="validation",
        seed=cfg["project"]["seed"],
        shuffle=False
    )

    test_data = test_gen.flow_from_directory(
        test_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode="categorical",
        shuffle=False
    )

    logger.info(f"Train samples : {train_data.samples}")
    logger.info(f"Val samples   : {val_data.samples}")
    logger.info(f"Test samples  : {test_data.samples}")
    logger.info(f"Classes       : {train_data.class_indices}")

    return train_data, val_data, test_data
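

# Usage sketch: the minimal config shape this module reads. Every key below
# mirrors a lookup performed in get_data_generators(); all concrete values
# (sizes, paths, augmentation ranges) are illustrative placeholders, not the
# project's actual settings.
if __name__ == "__main__":
    example_cfg = {
        "project": {"seed": 42},
        "data": {
            "image_size": [224, 224],       # passed to target_size as a tuple
            "batch_size": 32,
            "validation_split": 0.2,        # fraction carved out of train_dir
            "train_dir": "data/train",      # hypothetical path
            "test_dir": "data/test",        # hypothetical path
        },
        "augmentation": {
            "rotation_range": 15,
            "width_shift_range": 0.1,
            "height_shift_range": 0.1,
            "zoom_range": 0.1,
            "horizontal_flip": True,
            "brightness_range": [0.8, 1.2],
        },
    }
    # Requires the directories above to exist and contain one subfolder
    # per class, as flow_from_directory expects.
    train_data, val_data, test_data = get_data_generators(example_cfg)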