| | """
|
| | Training script for surgical instrument classification
|
| | """
|
| |
|
| | import os
|
| | import pickle
|
| | import cv2
|
| | import pandas as pd
|
| | import numpy as np
|
| | from utils.utils import extract_features_from_image, fit_pca_transformer, train_svm_model, augment_image
|
| | from utils.utils import extract_features_from_image, fit_pca_transformer, augment_image
|
| | from sklearn.model_selection import GridSearchCV, train_test_split
|
| | from sklearn.svm import SVC
|
| | from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report
|
| |
|
def _extract_augmented_features(df, images_dir, augmentations_per_image,
                                rotation_range, brightness_range, noise_sigma):
    """Extract features for every readable image and its augmented variants.

    Args:
        df: Ground-truth table with ``file_name`` and ``category_id`` columns.
        images_dir: Directory the ``file_name`` entries are joined onto.
        augmentations_per_image: Number of augmented variants per image.
        rotation_range: Passed through to ``augment_image``.
        brightness_range: Passed through to ``augment_image``.
        noise_sigma: Passed through to ``augment_image``.

    Returns:
        Tuple ``(features, labels, groups)`` of equally long lists.
        ``groups[k]`` is the row position of the image that sample ``k``
        (original or augmented) was derived from.
    """
    features = []
    labels = []
    groups = []
    n_images = len(df)

    for i, (image_name, label) in enumerate(zip(df["file_name"], df["category_id"])):
        if i % 500 == 0:
            # The counter is an image index, so report it against the image
            # count rather than the augmented sample count.
            print(f" Processing {i}/{n_images}...")

        path_to_image = os.path.join(images_dir, image_name)

        # Best effort: one unreadable or failing image must not abort the run.
        try:
            image = cv2.imread(path_to_image)
            if image is None:
                print(f" Warning: Could not read {image_name}, skipping...")
                continue

            features.append(extract_features_from_image(image))
            labels.append(label)
            groups.append(i)

            for aug_idx in range(augmentations_per_image):
                aug_image = augment_image(
                    image,
                    rotation_range=rotation_range,
                    brightness_range=brightness_range,
                    add_noise=True,
                    noise_sigma=noise_sigma,
                    # Only the first variant of each image is flipped.
                    horizontal_flip=(aug_idx == 0),
                )
                features.append(extract_features_from_image(aug_image))
                labels.append(label)
                groups.append(i)

        except Exception as e:
            print(f" Error processing {image_name}: {e}")
            continue

    return features, labels, groups


def _split_by_source_image(features, labels, groups, test_size=0.2, random_state=42):
    """Hold out a test set such that no source image appears on both sides.

    The split is drawn over source images (``groups``), not over individual
    samples, so an image and its augmented copies always stay together.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as NumPy arrays.
    """
    unique_groups, first_idx = np.unique(groups, return_index=True)
    # All samples of a group share one label; take it from the first sample.
    group_labels = labels[first_idx]

    _, images_per_class = np.unique(group_labels, return_counts=True)
    if images_per_class.min() >= 2:
        stratify = group_labels
    else:
        # A stratified split needs at least two members per class.
        print(" Warning: a class has fewer than 2 source images; "
              "hold-out split is not stratified")
        stratify = None

    _, test_groups = train_test_split(
        unique_groups,
        test_size=test_size,
        random_state=random_state,
        stratify=stratify,
    )

    test_mask = np.isin(groups, test_groups)
    return (features[~test_mask], features[test_mask],
            labels[~test_mask], labels[test_mask])


def _macro_metrics(y_true, y_pred):
    """Return accuracy and macro-averaged precision/recall/F1, in print order."""
    return {
        "Accuracy": accuracy_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred, average='macro'),
        "Recall": recall_score(y_true, y_pred, average='macro'),
        "F1-score": f1_score(y_true, y_pred, average='macro'),
    }


def _save_pickle(obj, path):
    """Pickle ``obj`` to ``path``."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def train_and_save_model(base_path, images_folder, gt_csv, save_dir, n_components=100):
    """
    Complete training pipeline that saves everything needed for submission

    Steps: read the ground-truth CSV, extract features from every image and
    its augmented variants, reduce them with PCA, grid-search an RBF SVM,
    report train/test metrics and pickle the best model and the PCA
    parameters into ``save_dir``.

    Args:
        base_path: Base directory path
        images_folder: Folder name containing images (joined onto base_path)
        gt_csv: Ground truth CSV filename (joined onto base_path); must have
            ``file_name`` and ``category_id`` columns
        save_dir: Directory to save model artifacts (created if missing)
        n_components: Number of PCA components

    Returns:
        None. Writes ``multiclass_model.pkl`` and ``pca_params.pkl``.

    Raises:
        ValueError: If no features could be extracted from any image.
    """
    print("="*80)
    print("TRAINING SURGICAL INSTRUMENT CLASSIFIER")
    print("="*80)

    PATH_TO_GT = os.path.join(base_path, gt_csv)
    PATH_TO_IMAGES = os.path.join(base_path, images_folder)

    print(f"\nConfiguration:")
    print(f" Ground Truth: {PATH_TO_GT}")
    print(f" Images: {PATH_TO_IMAGES}")
    print(f" PCA Components: {n_components}")
    print(f" Save directory: {save_dir}")

    # Create the output directory up front so a bad path fails now rather
    # than after the long grid search. An empty save_dir means "current dir".
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    df = pd.read_csv(PATH_TO_GT)
    print(f"\nLoaded {len(df)} training samples")
    print(f"\nLabel distribution:")
    print(df['category_id'].value_counts().sort_index())

    print(f"\n{'='*80}")
    print("STEP 1: FEATURE EXTRACTION WITH AUGMENTATION")
    print("="*80)

    AUGMENTATIONS_PER_IMAGE = 2
    rotation_range = (-10, 10)
    brightness_range = (0.9, 1.1)
    noise_sigma = 3

    print(f"\nAugmentation settings:")
    print(f" Augmentations per image: {AUGMENTATIONS_PER_IMAGE}")
    print(f" Rotation range: {rotation_range[0]}° to +{rotation_range[1]}°")
    print(f" Brightness range: {brightness_range[0]} to {brightness_range[1]}")
    print(f" Horizontal flip: Yes")
    print(f" Gaussian noise: σ={noise_sigma}")
    print(f" Expected total samples: {len(df) * (1 + AUGMENTATIONS_PER_IMAGE)}")

    features, labels, groups = _extract_augmented_features(
        df,
        PATH_TO_IMAGES,
        AUGMENTATIONS_PER_IMAGE,
        rotation_range,
        brightness_range,
        noise_sigma,
    )

    if not features:
        raise ValueError(
            f"No features could be extracted: none of the {len(df)} images "
            f"listed in {PATH_TO_GT} could be processed from {PATH_TO_IMAGES}"
        )

    features_array = np.array(features)
    labels_array = np.array(labels)
    groups_array = np.array(groups)

    print(f"\nFeature extraction complete!")
    print(f" Original samples: {len(df)}")
    print(f" Total samples (with augmentation): {len(features)}")
    print(f" Features shape: {features_array.shape}")
    print(f" Labels shape: {labels_array.shape}")
    print(f" Feature dimension: {features_array.shape[1]}")

    print(f"\n{'='*80}")
    print("STEP 2: DIMENSIONALITY REDUCTION (PCA)")
    print("="*80)

    # NOTE(review): PCA is fitted on all samples, including those later held
    # out for testing, so the test metrics are still mildly optimistic.
    pca_params, features_reduced = fit_pca_transformer(features_array, n_components)

    print(f" Reduced from {features_array.shape[1]} to {n_components} dimensions")
    print(f" Explained variance: {pca_params['cumulative_variance'][-1]:.4f}")

    print(f"\n{'='*80}")
    print("STEP 3: TRAINING SVM CLASSIFIER WITH GRID SEARCH")
    print("="*80)

    # Split by source image: a per-sample split would put augmented copies of
    # the same image into both train and test and inflate the test scores.
    X_train, X_test, y_train, y_test = _split_by_source_image(
        features_reduced,
        labels_array,
        groups_array,
        test_size=0.2,
        random_state=42,
    )

    print(f"\nData split:")
    print(f" Training samples: {len(X_train)}")
    print(f" Test samples: {len(X_test)}")

    param_grid = {
        'C': [1, 10, 50, 100],
        'gamma': ['scale', 0.001, 0.01, 0.1],
        'kernel': ['rbf'],
    }
    cv_folds = 3

    print(f"\nGrid Search parameters:")
    print(f" C values: {param_grid['C']}")
    print(f" Gamma values: {param_grid['gamma']}")
    print(f" Kernel: {param_grid['kernel']}")
    print(f" Total combinations: {len(param_grid['C']) * len(param_grid['gamma'])}")
    print(f" Cross-validation folds: {cv_folds}")
    print(f"\nThis will take 15-30 minutes...")

    # NOTE(review): the CV folds are not grouped by source image, so the
    # cross-validation score can still be optimistic.
    grid_search = GridSearchCV(
        SVC(),
        param_grid,
        cv=cv_folds,
        scoring='f1_macro',
        n_jobs=-1,
        verbose=2,
    )

    print("\nStarting Grid Search...")
    grid_search.fit(X_train, y_train)

    svm_model = grid_search.best_estimator_
    best_params = grid_search.best_params_

    print(f"\n{'='*80}")
    print("GRID SEARCH COMPLETE!")
    print("="*80)
    print(f"\nBest parameters found:")
    print(f" C: {best_params['C']}")
    print(f" Gamma: {best_params['gamma']}")
    print(f" Kernel: {best_params['kernel']}")
    print(f"\nBest cross-validation F1-score: {grid_search.best_score_:.4f}")

    print(f"\n{'='*80}")
    print("FINAL MODEL EVALUATION")
    print("="*80)

    y_test_pred = svm_model.predict(X_test)
    train_metrics = _macro_metrics(y_train, svm_model.predict(X_train))
    test_metrics = _macro_metrics(y_test, y_test_pred)

    print(f"\nTraining Set Performance:")
    for metric_name, value in train_metrics.items():
        print(f" {metric_name}: {value:.4f}")

    print(f"\nTest Set Performance:")
    for metric_name, value in test_metrics.items():
        print(f" {metric_name}: {value:.4f}")

    # Pass labels explicitly so target_names always lines up with the classes,
    # even if some class is absent from the test split.
    class_ids = np.unique(labels_array)
    print(f"\nDetailed Classification Report:")
    print(classification_report(
        y_test,
        y_test_pred,
        labels=class_ids,
        target_names=[f'Class {i}' for i in class_ids],
    ))

    print(f"\nModel Details:")
    print(f" Support vectors: {len(svm_model.support_)}")
    print(f" Support vectors per class: {svm_model.n_support_}")

    model_path = os.path.join(save_dir, "multiclass_model.pkl")
    _save_pickle(svm_model, model_path)
    print(f" ✓ Saved SVM model: {model_path}")

    pca_path = os.path.join(save_dir, "pca_params.pkl")
    _save_pickle(pca_params, pca_path)
    print(f" ✓ Saved PCA params: {pca_path}")

    print(f"\n{'='*80}")
    print("TRAINING COMPLETE!")
    print("="*80)
    print(f"\nFinal Optimized Results:")
    print(f" Best Parameters: C={best_params['C']}, gamma={best_params['gamma']}")
    print(f" CV F1-score: {grid_search.best_score_:.4f}")
    print(f" Test F1-score: {test_metrics['F1-score']:.4f}")
    print(f" Test Precision: {test_metrics['Precision']:.4f}")
    print(f" Test Recall: {test_metrics['Recall']:.4f}")
    print(f"\nFiles saved to: {save_dir}")
    print(f"\nNext steps:")
    print(f" 1. Create a 'utils' folder in your HuggingFace repository")
    print(f" 2. Copy utils.py into the 'utils' folder")
    print(f" 3. Copy script.py, multiclass_model.pkl, and pca_params.pkl to the repository root")
    print(f" 4. Create an empty __init__.py file in the 'utils' folder")
    print(f" 5. Submit to competition!")
|
| |
|
| |
|
if __name__ == "__main__":
    # Machine-specific locations; edit these before running elsewhere.
    # IMAGES_FOLDER and GT_CSV are absolute paths. train_and_save_model joins
    # them onto BASE_PATH with os.path.join, which discards the earlier
    # component when the later one is absolute (true for these drive-letter
    # paths on Windows), so they are used as written.
    BASE_PATH = "C:/Users/anna2/ISM/ANNA/phase1a-data-augmentation"
    IMAGES_FOLDER = "C:/Users/anna2/ISM/Images"
    GT_CSV = "C:/Users/anna2/ISM/Baselines/phase_1a/gt_for_classification_multiclass_from_filenames_0_index.csv"

    # Where the pickled SVM model and PCA parameters are written.
    SAVE_DIR = "C:/Users/anna2/ISM/ANNA/phase1a-data-augmentation"

    # Number of PCA components (overrides the function default of 100).
    N_COMPONENTS = 250

    train_and_save_model(
        base_path=BASE_PATH,
        images_folder=IMAGES_FOLDER,
        gt_csv=GT_CSV,
        save_dir=SAVE_DIR,
        n_components=N_COMPONENTS,
    )