"""Train a binary image classifier on the ``data/chest_xray`` image folders.

The script builds an augmented training generator and a plain validation
generator from the ``train`` and ``val`` sub-directories, previews a few
augmented training images, fits the model returned by ``create_model()``
with early stopping on validation loss, and saves the trained model to
``xray_image_classifier_model.keras``.
"""

import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from model import create_model


# Shared by both generators so training and validation inputs cannot drift
# apart.
# NOTE(review): IMAGE_SIZE presumably has to match the input shape that
# create_model() builds -- model.py is not visible here, confirm.
IMAGE_SIZE = (150, 150)
BATCH_SIZE = 32
EPOCHS = 10
PREVIEW_COUNT = 5

base_dir = 'data/chest_xray'
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')


# Training images are rescaled to [0, 1] and randomly augmented on the fly.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
# Validation images are only rescaled; no augmentation is applied.
val_datagen = ImageDataGenerator(rescale=1./255)


# class_mode='binary' yields one scalar label per image, taken from the
# sub-directory each image is found in.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
)

val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
)


# Preview a few augmented training images as a visual sanity check.
# The count is clamped to the batch length so a first batch with fewer than
# PREVIEW_COUNT images does not raise IndexError.
sample_images, _ = next(train_generator)
preview_count = min(PREVIEW_COUNT, len(sample_images))
for i in range(preview_count):
    plt.subplot(1, preview_count, i + 1)
    plt.imshow(sample_images[i])
    plt.axis('off')
plt.show()


model = create_model()

# Stop once val_loss has not improved for 3 consecutive epochs and roll the
# model back to the weights of the best epoch seen.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# Step counts are derived from the generators (batches per full pass) rather
# than hard-coded: asking for more steps than the data can supply makes Keras
# run out of batches, which interrupts training or skips validation and
# leaves the monitored val_loss unavailable.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=EPOCHS,
    validation_data=val_generator,
    validation_steps=len(val_generator),
    callbacks=[early_stopping],
)


model.save('xray_image_classifier_model.keras')