AIOmarRehan committed
Commit 7526f34 · verified · 1 Parent(s): 4bf5b66

Delete Notebook and Py File/inceptionv3_image_classification.py

Notebook and Py File/inceptionv3_image_classification.py DELETED
@@ -1,594 +0,0 @@
- # -*- coding: utf-8 -*-
- """InceptionV3_Image_Classification.ipynb
-
- Automatically generated by Colab.
-
- Original file is located at
-     https://colab.research.google.com/drive/1qSA_Rg-2gDZ0lKcAuLivNsCgrBmJOcfz
-
- # **Import Dependencies**
- """
-
- import warnings
- warnings.filterwarnings('ignore')
-
- import zipfile
- import hashlib
- import matplotlib.pyplot as plt
- import pandas as pd
- import os
- import uuid
- import re
- import random
- import cv2
- import numpy as np
- import tensorflow as tf
- import seaborn as sns
- from google.colab import drive
- from google.colab import files
- from pathlib import Path
- from PIL import Image, ImageStat, UnidentifiedImageError, ImageEnhance
- from matplotlib import patches
- from tqdm import tqdm
- from collections import defaultdict
- from sklearn.preprocessing import LabelEncoder, label_binarize
- from sklearn.model_selection import train_test_split
- from sklearn.utils import resample
- from tensorflow import keras
- from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
- from tensorflow.keras.utils import to_categorical
- from tensorflow.keras import layers, models, optimizers, callbacks, regularizers
- from tensorflow.keras.models import Sequential
- from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Dropout, Flatten, Dense, GlobalAveragePooling2D
- from tensorflow.keras.regularizers import l2
- from tensorflow.keras.optimizers import Adam, AdamW
- from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
- from tensorflow.keras import Input, Model
- from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img, array_to_img
- from tensorflow.keras.preprocessing import image
- from sklearn.metrics import (classification_report, confusion_matrix, roc_auc_score, roc_curve,
-                              precision_score, recall_score, f1_score, precision_recall_fscore_support, auc)
-
- print(tf.__version__)
-
- drive.mount('/content/drive')
- zip_path = '/content/drive/MyDrive/Animals.zip'
- extract_to = '/content/my_data'
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
-     zip_ref.extractall(extract_to)
-
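- # Quick sanity check (editor's addition, not in the original notebook): list the
- # extracted class folders and their file counts before building the DataFrame.
- # Assumes the archive unpacks to <extract_to>/Animals/<class>/... as used by
- # dataset_root further below.
- for p in sorted(Path(extract_to).glob('*/*')):
-     if p.is_dir():
-         print(p.name, '->', sum(1 for f in p.iterdir() if f.is_file()), 'files')
-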
- """# **Convert Dataset to a Data Frame**"""
-
- image_extensions = {'.jpg', '.jpeg', '.png'}
- paths = [(path.parts[-2], path.name, str(path))
-          for path in Path(extract_to).rglob('*.*')
-          if path.suffix.lower() in image_extensions]
-
- df = pd.DataFrame(paths, columns = ['class', 'image', 'full_path'])
- df = df.sort_values('class', ascending = True)
- df.reset_index(drop = True, inplace = True)
- df
-
- """# **EDA Process**"""
-
- class_count = df['class'].value_counts()
- for cls, count in class_count.items():
-     print(f'Class: {cls}, Count: {count} images')
-
- print(f"\nTotal dataset size is: {len(df)} images")
- print(f"Number of classes: {df['class'].nunique()} classes")
-
- plt.figure(figsize = (32, 16))
- class_count.plot(kind = 'bar', color = 'skyblue', edgecolor = 'black')
- plt.title('Number of Images per Class')
- plt.xlabel('Class')
- plt.ylabel('Count')
- plt.xticks(rotation = 45)
- plt.show()
-
- plt.figure(figsize = (32, 16))
- class_count.plot(kind = 'pie', autopct = '%1.1f%%', colors = plt.cm.Paired.colors)
- plt.title('Percentage of Images per Class')
- plt.ylabel('')
- plt.show()
-
- percentages = (class_count / len(df)) * 100
- imbalance_df = pd.DataFrame({'Count': class_count, 'Percentage %': percentages.round(2)})
- print(imbalance_df)
-
- plt.figure(figsize = (32, 16))
- class_count.plot(kind = 'bar', color = 'lightgreen', edgecolor = 'black')
- plt.title('Class Distribution Check')
- plt.xlabel('Class')
- plt.ylabel('Count')
- plt.xticks(rotation = 45)
- plt.axhline(y = class_count.mean(), color = 'red', linestyle = '--', label = 'Average Count')
- plt.legend()
- plt.show()
-
- image_sizes = []
-
- for file_path in df['full_path']:
-     with Image.open(file_path) as img:
-         image_sizes.append(img.size)
-
- sizes_df = pd.DataFrame(image_sizes, columns=['Width', 'Height'])
-
- #Width
- plt.figure(figsize=(8,5))
- plt.scatter(x = range(len(sizes_df)), y = sizes_df['Width'], color='skyblue', s=10)
- plt.title('Image Width Distribution')
- plt.xlabel('Image index')
- plt.ylabel('Width (pixels)')
- plt.show()
-
- #Height
- plt.figure(figsize=(8,5))
- plt.scatter(x = sizes_df['Height'], y = range(len(sizes_df)), color='lightgreen', s=10)
- plt.title('Image Height Distribution')
- plt.xlabel('Height (pixels)')
- plt.ylabel('Image index')
- plt.show()
-
- # Confirm exactly which sizes occur across all images
- unique_sizes = sizes_df.value_counts().reset_index(name='Count')
- print(unique_sizes)
-
- image_data = []
-
- for file_path in df['full_path']:
-     with Image.open(file_path) as img:
-         width, height = img.size
-         mode = img.mode # e.g., 'RGB', 'L', 'RGBA', etc.
-         channels = len(img.getbands()) # Number of channels
-         image_data.append((width, height, mode, channels))
-
- # Create DataFrame
- image_df = pd.DataFrame(image_data, columns=['Width', 'Height', 'Mode', 'Channels'])
-
- print("Image Mode Distribution:")
- print(image_df['Mode'].value_counts())
-
- print("\nNumber of Channels Distribution:")
- print(image_df['Channels'].value_counts())
-
- plt.figure(figsize=(6,4))
- image_df['Mode'].value_counts().plot(kind='bar', color='coral')
- plt.title("Image Mode Distribution")
- plt.xlabel("Mode")
- plt.ylabel("Count")
- plt.xticks(rotation=45)
- plt.tight_layout()
- plt.show()
-
- plt.figure(figsize=(6,4))
- image_df['Channels'].value_counts().sort_index().plot(kind='bar', color='slateblue')
- plt.title("Number of Channels per Image")
- plt.xlabel("Channels")
- plt.ylabel("Count")
- plt.xticks(rotation=0)
- plt.tight_layout()
- plt.show()
-
- sample_df = df.sample(n = 10, random_state = 42)
-
- plt.figure(figsize=(32, 16))
-
- for i, (cls, img_name, full_path) in enumerate(sample_df.values):
-     with Image.open(full_path) as img:
-         stat = ImageStat.Stat(img.convert("RGB")) # Convert to RGB before computing stats
-         brightness = stat.mean[0]
-         contrast = stat.stddev[0]
-
-         width, height = img.size
-         # Print size to console
-         print(f"Image: {img_name} | Class: {cls} | Size: {width}x{height} | Brightness: {brightness:.1f} | Contrast: {contrast:.1f}")
-
-         plt.subplot(2, 5, i + 1)
-         plt.imshow(img)
-         plt.axis('off')
-         plt.title(f"Class: {cls}\nImage: {img_name}\nBrightness: {brightness:.2f}\nContrast: {contrast:.2f} \nSize: {width}x{height}")
-
- plt.tight_layout()
- plt.show()
-
- # Sample 20 random images
- num_samples = 20
- sample_df = df.sample(num_samples, random_state=42)
-
- # Get sorted class list and color map
- classes = sorted(df['class'].unique())
- colors = plt.cm.tab10.colors
-
- # Grid setup
- cols = 4
- rows = num_samples // cols + int(num_samples % cols > 0)
-
- # Figure setup
- plt.figure(figsize=(15, 5 * rows))
-
- for idx, (cls, img_name, full_path) in enumerate(sample_df.values):
-     with Image.open(full_path) as img:
-         ax = plt.subplot(rows, cols, idx + 1)
-         ax.imshow(img)
-         ax.axis('off')
-
-         # Title with class info
-         ax.set_title(
-             f"Class: {cls} \nImage: {img_name} \nSize: {img.width} x {img.height}",
-             fontsize=10
-         )
-
-         # Rectangle in axes coords: full width, small height at top
-         label_height = 0.1 # 10% of image height
-         label_width = 1.0 # full width of the image
-
-         rect = patches.Rectangle(
-             (0, 1 - label_height), label_width, label_height,
-             transform=ax.transAxes,
-             linewidth=0,
-             edgecolor=None,
-             facecolor=colors[classes.index(cls) % len(colors)],
-             alpha=0.7
-         )
-         ax.add_patch(rect)
-
-         # Add class name text centered horizontally
-         ax.text(
-             0.5, 1 - label_height / 2,
-             cls,
-             transform=ax.transAxes,
-             fontsize=12,
-             color="white",
-             fontweight="bold",
-             va="center",
-             ha="center"
-         )
-
- # Figure title and layout
- plt.suptitle("Random Dataset Samples - Sanity Check", fontsize=18, fontweight="bold")
- plt.tight_layout(rect=[0, 0, 1, 0.96])
- plt.show()
-
- # Check for missing values
- print("Missing values per column: ")
- print(df.isnull().sum())
-
- # Check for duplicate rows and duplicate file names
- duplicate_names = df.duplicated().sum()
- print(f"\nNumber of duplicate rows: {duplicate_names}")
-
- duplicate_names = df[df.duplicated(subset = ['image'], keep = False)]
- print(f"Duplicate file names: {len(duplicate_names)}")
-
- # Check whether two or more images are identical even when their file names differ
- def get_hash(file_path):
-     with open(file_path, 'rb') as f:
-         return hashlib.md5(f.read()).hexdigest()
-
- df['file_hash'] = df['full_path'].apply(get_hash)
- duplicate_hashes = df[df.duplicated(subset = ['file_hash'], keep = False)]
- print(f"Duplicate image files: {len(duplicate_hashes)}")
-
- # The commented code below removes duplicates from the DataFrame only, so they are
- # not fed to the model; the files themselves remain in the directory.
- # Drop duplicates based on file_hash, keeping the first occurrence:
-
- # df_unique = df.drop_duplicates(subset='file_hash', keep='first')
- # print(f"After removing duplicates, unique images: {len(df_unique)}")
-
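- # MD5 only catches byte-identical copies. A hedged sketch (editor's addition) for
- # near-duplicates (resized or re-encoded versions) using the third-party
- # 'imagehash' package, which is NOT used in the original notebook (pip install imagehash):
- # import imagehash
- # df['phash'] = df['full_path'].apply(lambda p: str(imagehash.phash(Image.open(p))))
- # near_dupes = df[df.duplicated(subset=['phash'], keep=False)]
-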
- # Check image file extensions
- df['extension'] = df['image'].apply(lambda x: Path(x).suffix.lower())
- print("File type counts: ")
- print(df['extension'].value_counts())
-
- # Check resolution statistics per class
- df['Width'] = sizes_df['Width']
- df['Height'] = sizes_df['Height']
- print(df.groupby('class')[['Width', 'Height']].agg(['min', 'max', 'mean']))
-
- # Check class balance (relationship between label and count)
- class_summary = df['class'].value_counts(normalize = False).to_frame('Count')
- class_summary['Percentage %'] = round((class_summary['Count'] / len(df)) * 100, 2)
- print(class_summary)
-
- """# **Data Cleaning Process**"""
-
- corrupted_files = []
-
- for file_path in df['full_path']:
-     try:
-         with Image.open(file_path) as img:
-             img.verify()
-     except (UnidentifiedImageError, OSError):
-         corrupted_files.append(file_path)
-
- print(f"Found {len(corrupted_files)} corrupted images.")
-
- if corrupted_files:
-     df = df[~df['full_path'].isin(corrupted_files)].reset_index(drop = True)
-     print("Corrupted files removed.")
-
- # Outlier detection
- # Resolution-based outlier detection (3-sigma rule)
- width_mean = sizes_df['Width'].mean()
- width_std = sizes_df['Width'].std()
- height_mean = sizes_df['Height'].mean()
- height_std = sizes_df['Height'].std()
-
- outliers = df[(df['Width'] > width_mean + 3 * width_std) |
-               (df['Width'] < width_mean - 3 * width_std) |
-               (df['Height'] > height_mean + 3 * height_std) |
-               (df['Height'] < height_mean - 3 * height_std)]
- print(f"Found {len(outliers)} resolution outliers.")
-
- # Note: this loads every image into memory and replaces the 'image' file-name column with PIL objects
- df["image"] = df["full_path"].apply(lambda p: Image.open(p).convert('RGB')) # Convert to RGB for consistency
-
- too_dark = []
- too_bright = []
- blank_or_gray = []
-
- # Thresholds
- dark_threshold = 30 # Below this is too dark
- bright_threshold = 225 # Above this is too bright
- low_contrast_threshold = 5 # Low contrast ~ blank/gray
-
- for idx, img in enumerate(df["image"]):
-     gray = img.convert('L') # Grayscale for brightness/contrast analysis
-     stat = ImageStat.Stat(gray)
-     brightness = stat.mean[0]
-     contrast = stat.stddev[0]
-
-     if brightness < dark_threshold:
-         too_dark.append(idx)
-     elif brightness > bright_threshold:
-         too_bright.append(idx)
-     elif contrast < low_contrast_threshold:
-         blank_or_gray.append(idx)
-
- print(f"Too dark images: {len(too_dark)}")
- print(f"Too bright images: {len(too_bright)}")
- print(f"Blank/gray images: {len(blank_or_gray)}")
-
- # Optional alternative: drop the too-bright and blank/gray images together instead of enhancing them:
- # df = df.drop(index=too_bright + blank_or_gray).reset_index(drop=True)
-
- for idx, row in tqdm(df.iterrows(), total=len(df), desc="Enhancing images"):
-     img = row["image"]
-
-     # Brighten images flagged as too dark (too_dark holds DataFrame indices)
-     if idx in too_dark:
-         img = ImageEnhance.Brightness(img).enhance(1.5) # Increase brightness
-         img = ImageEnhance.Contrast(img).enhance(1.5) # Increase contrast
-
-     # Darken images flagged as too bright
-     if idx in too_bright:
-         img = ImageEnhance.Brightness(img).enhance(0.7) # Factor below 1 decreases brightness
-         img = ImageEnhance.Contrast(img).enhance(1.2) # Mild contrast boost
-
-     # Write the (possibly enhanced) image back into the DataFrame
-     df.at[idx, "image"] = img
-
- print(f"Enhanced images in memory: {len(df)}")
-
- # Lists to store paths of still too dark and too bright images
- still_dark = []
- still_bright = []
-
- # Re-scan using the same dark/bright thresholds defined above
- for idx, img in enumerate(df["image"]):
-     gray = img.convert('L') # Convert to grayscale for brightness analysis
-     stat = ImageStat.Stat(gray)
-     brightness = stat.mean[0]
-
-     # Check if the image is still too dark
-     if brightness < dark_threshold:
-         still_dark.append(df.loc[idx, 'full_path'])
-
-     # Check if the image is still too bright
-     if brightness > bright_threshold:
-         still_bright.append(df.loc[idx, 'full_path'])
-
- print(f"Still too dark after enhancement: {len(still_dark)} images")
- print(f"Still too bright after enhancement: {len(still_bright)} images")
-
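- # One possible follow-up (editor's sketch, not in the original notebook): images
- # that stay extreme after one enhancement pass could be dropped outright.
- # still_extreme = still_dark + still_bright
- # df = df[~df['full_path'].isin(still_extreme)].reset_index(drop=True)
- # print(f"Remaining images after dropping extremes: {len(df)}")
-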
- # Point to the extracted dataset, not the zip file location
- dataset_root = "/content/my_data/Animals"
-
- # Check mislabeled images
- mismatches = []
- for i, row in df.iterrows():
-     folder_name = os.path.basename(os.path.dirname(row["full_path"]))
-     if row["class"] != folder_name:
-         mismatches.append((row["full_path"], row["class"], folder_name))
-
- print(f"Found {len(mismatches)} mislabeled images (class vs folder mismatch).")
-
- # Compare classes vs folders
- classes_in_df = set(df["class"].unique())
- folders_in_fs = {f for f in os.listdir(dataset_root) if os.path.isdir(os.path.join(dataset_root, f))}
-
- print("Classes in DF but not in folders:", classes_in_df - folders_in_fs)
- print("Folders in FS but not in DF:", folders_in_fs - classes_in_df)
-
- def check_file_naming_issues(df):
-     issues = {"invalid_chars": [], "spaces": [], "long_paths": [], "case_conflicts": [], "duplicate_names_across_classes": []}
-
-     seen_names = {}
-
-     for _, row in df.iterrows():
-         fpath = row["full_path"] # full path
-         fname = os.path.basename(fpath) # just filename
-         cls = row["class"]
-
-         if re.search(r'[<>:"/\\|?*]', fname): # Windows restricted chars
-             issues["invalid_chars"].append(fpath)
-
-         if " " in fname or fname.startswith(" ") or fname.endswith(" "):
-             issues["spaces"].append(fpath)
-
-         if len(fpath) > 255:
-             issues["long_paths"].append(fpath)
-
-         lower_name = fname.lower()
-         if lower_name in seen_names and seen_names[lower_name] != cls:
-             issues["case_conflicts"].append((fpath, seen_names[lower_name]))
-         else:
-             seen_names[lower_name] = cls
-
-     duplicates = df.groupby(df["full_path"].apply(os.path.basename))["class"].nunique()
-     duplicates = duplicates[duplicates > 1].index.tolist()
-     for dup in duplicates:
-         dup_paths = df[df["full_path"].str.endswith(dup)]["full_path"].tolist()
-         issues["duplicate_names_across_classes"].extend(dup_paths)
-
-     return issues
-
- # Run the check
- naming_issues = check_file_naming_issues(df)
-
- for issue_type, file_list in naming_issues.items(): # 'file_list' avoids shadowing the google.colab 'files' module used later
-     print(f"\n{issue_type.upper()} ({len(file_list)})")
-     for f in file_list[:10]: # preview first 10
-         print(f)
-
- """# **Data Preprocessing Process**"""
-
- def preprocess_image(path, target_size=(256, 256), augment=True):
-     img = tf.io.read_file(path)
-     img = tf.image.decode_image(img, channels=3, expand_animations=False)
-     img = tf.image.resize(img, target_size)
-     img = tf.cast(img, tf.float32) / 255.0
-
-     if augment and tf.random.uniform(()) < 0.1: # Only a 10% chance of augmenting
-         img = tf.image.random_flip_left_right(img)
-         img = tf.image.random_flip_up_down(img)
-         img = tf.image.random_brightness(img, max_delta=0.1)
-         img = tf.image.random_contrast(img, lower=0.9, upper=1.1)
-
-     return img
-
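- # Note: the function above scales pixels to [0, 1], while the ImageNet weights for
- # InceptionV3 were trained with inputs scaled to [-1, 1]. A hedged alternative
- # (editor's sketch, not what the original run used) with the already-imported preprocess_input:
- # def preprocess_image_inception(path, target_size=(256, 256)):
- #     img = tf.io.read_file(path)
- #     img = tf.image.decode_image(img, channels=3, expand_animations=False)
- #     img = tf.image.resize(img, target_size)
- #     return preprocess_input(tf.cast(img, tf.float32))  # maps pixels to [-1, 1]
-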
- le = LabelEncoder()
- df['label'] = le.fit_transform(df['class'])
-
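- # For readability later (editor's sketch): map the encoded integers back to class names.
- # int_to_class = dict(enumerate(le.classes_))
- # print(int_to_class)
-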
- # Prepare paths and labels
- paths = df['full_path'].values
- labels = df['label'].values
-
- AUTOTUNE = tf.data.AUTOTUNE
- batch_size = 32
-
- # Split data into train+val and test (10% test)
- train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
-     paths, labels, test_size=0.1, random_state=42, stratify=labels
- )
-
- # Split train+val into train and val (10% of train_val as val)
- train_paths, val_paths, train_labels, val_labels = train_test_split(
-     train_val_paths, train_val_labels, test_size=0.1, random_state=42, stratify=train_val_labels
- )
-
- # Create datasets
- def load_and_preprocess(path, label):
-     return preprocess_image(path, augment=False), label # no augmentation for val/test
-
- train_ds = tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
- train_ds = train_ds.map(lambda x, y: (preprocess_image(x, augment=True), y), num_parallel_calls=AUTOTUNE)
- train_ds = train_ds.shuffle(1024).batch(batch_size).prefetch(AUTOTUNE)
-
- val_ds = tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
- val_ds = val_ds.map(load_and_preprocess, num_parallel_calls=AUTOTUNE)
- val_ds = val_ds.batch(batch_size).prefetch(AUTOTUNE)
-
- test_ds = tf.data.Dataset.from_tensor_slices((test_paths, test_labels))
- test_ds = test_ds.map(load_and_preprocess, num_parallel_calls=AUTOTUNE)
- test_ds = test_ds.batch(batch_size).prefetch(AUTOTUNE)
-
- print("Dataset sizes:")
- print(f"Train: {len(train_paths)} images")
- print(f"Validation: {len(val_paths)} images")
- print(f"Test: {len(test_paths)} images")
- print("--------------------------------------------------")
- print("Train labels sample:", train_labels[:10])
- print("Validation labels sample:", val_labels[:10])
- print("Test labels sample:", test_labels[:10])
-
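- # The EDA above measured class imbalance; one optional mitigation (editor's sketch,
- # not applied in the original run) is passing class weights to fit():
- # from sklearn.utils.class_weight import compute_class_weight
- # weights = compute_class_weight(class_weight='balanced',
- #                                classes=np.unique(train_labels), y=train_labels)
- # class_weight = dict(enumerate(weights))  # then: model.fit(..., class_weight=class_weight)
-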
- # Preview normalized image stats and visualization
- for image_batch, label_batch in train_ds.take(1):
-     # Print pixel value stats for first image in the batch
-     image = image_batch[0]
-     label = label_batch[0]
-     print("Image dtype:", image.dtype)
-     print("Min pixel value:", tf.reduce_min(image).numpy())
-     print("Max pixel value:", tf.reduce_max(image).numpy())
-     print("Label:", label.numpy())
-
-     # Show the image
-     plt.imshow(image.numpy())
-     plt.title(f"Label: {label.numpy()}")
-     plt.axis('off')
-     plt.show()
-
- print("---------------------------------------------------")
- print("Number of Classes: ", len(le.classes_))
-
- # After train_ds is defined
- for image_batch, label_batch in train_ds.take(1):
-     print("Image batch shape:", image_batch.shape) # full batch shape
-     print("Label batch shape:", label_batch.shape) # labels shape
-
-     input_shape = image_batch.shape[1:] # shape of a single image
-     print("Single image shape:", input_shape)
-     break
-
- """# **Model Loading**"""
-
- inception = InceptionV3(input_shape=input_shape, weights='imagenet', include_top=False)
-
- # don't train existing weights
- for layer in inception.layers:
-     layer.trainable = False
-
- # Number of classes
- print("Number of Classes: ", len(le.classes_))
-
- x = GlobalAveragePooling2D()(inception.output)
- x = Dense(512, activation='relu')(x)
- x = Dropout(0.5)(x)
- prediction = Dense(len(le.classes_), activation='softmax')(x)
-
- # create a model object
- model = Model(inputs=inception.input, outputs=prediction)
-
- # view the structure of the model
- model.summary()
-
- # tell the model what cost and optimization method to use
- model.compile(
-     loss='sparse_categorical_crossentropy',
-     optimizer='adam',
-     metrics=['accuracy']
- )
-
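- # A common second phase (editor's sketch, not part of the original run): after the
- # new head converges, unfreeze the top of the backbone and fine-tune at a low LR.
- # The '-50' layer count is an illustrative choice, not a tuned value.
- # for layer in inception.layers[-50:]:
- #     layer.trainable = True
- # model.compile(loss='sparse_categorical_crossentropy',
- #               optimizer=Adam(learning_rate=1e-5),
- #               metrics=['accuracy'])
- # model.fit(train_ds, validation_data=val_ds, epochs=5)
-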
- """# **Model Training**"""
-
- # 'cb_list' avoids shadowing the keras 'callbacks' module imported above
- cb_list = [
-     EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True, verbose = 1),
-     ModelCheckpoint("best_model.h5", save_best_only=True, monitor='val_loss', verbose = 1),
-     # LR-reduction patience must be shorter than EarlyStopping's, or it can never trigger
-     ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-5, verbose=1)
- ]
-
- history = model.fit(train_ds, validation_data=val_ds, epochs=10, callbacks=cb_list, verbose = 1)
-
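- # The 'history' object is captured but never used; a small sketch (editor's
- # addition) plotting the learning curves with the already-imported matplotlib:
- # plt.figure(figsize=(10, 4))
- # for metric in ('accuracy', 'loss'):
- #     plt.plot(history.history[metric], label=f'train_{metric}')
- #     plt.plot(history.history[f'val_{metric}'], label=f'val_{metric}')
- # plt.legend()
- # plt.show()
-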
- """# **Model Evaluation**"""
-
- model.evaluate(test_ds)
-
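- # The sklearn.metrics imports above are never used; a hedged sketch (editor's
- # addition) of a fuller evaluation with them. test_ds is not shuffled, so the
- # prediction order lines up with test_labels:
- # y_prob = model.predict(test_ds)
- # y_pred = np.argmax(y_prob, axis=1)
- # print(classification_report(test_labels, y_pred, target_names=le.classes_))
- # print(confusion_matrix(test_labels, y_pred))
-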
- """# **Saving the Model**"""
-
- model.save("InceptionV3_Image_Classification.h5") # the model is InceptionV3-based, not a simple CNN
- files.download("InceptionV3_Image_Classification.h5")
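- # Keras 3 recommends the native format over legacy HDF5 (editor's note, not in the original run):
- # model.save("InceptionV3_Image_Classification.keras")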