Keras
File size: 7,452 Bytes
108cad9
1
{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "authorship_tag": "ABX9TyM+QR5SfGeJQHQQhljZ6cWZ"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {"id": "6FsMtyMkFQBH"},
   "source": [
    "# Fine-tuning a DeepPanel-style panel segmentation model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "UPzIJ402VVMd"},
   "source": [
    "Train a binary panel-segmentation model (a small U-Net inspired by DeepPanel) on a custom dataset.\n",
    "\n",
    "Refer to https://github.com/pedrovgs/DeepPanel for how to build a dataset. The zip on Google Drive is expected to contain `training/raw`, `training/segmentation_mask`, `test/raw` and `test/segmentation_mask`, and every mask must share its file stem with its raw image (e.g. `page_01.jpg` and `page_01.png`).\n",
    "\n",
    "Pipeline: mount Drive -> unzip the dataset -> build `tf.data` pipelines -> train with checkpointing on a validation split held out from `training/` -> evaluate the best checkpoint on `test/` -> copy it to Drive."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "JFtsOMxP6dJ7"},
   "outputs": [],
   "source": [
    "# %pip (not !pip) installs into the running kernel's environment.\n",
    "%pip install -q tensorflow\n",
    "\n",
    "# DeepPanel is cloned for reference only (dataset tooling and docs); nothing below imports from it.\n",
    "# The existence guard keeps this cell safe to re-run.\n",
    "![ -d /content/DeepPanel ] || git clone https://github.com/pedrovgs/DeepPanel.git /content/DeepPanel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import shutil\n",
    "import zipfile\n",
    "\n",
    "import tensorflow as tf\n",
    "from google.colab import drive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Configuration: everything a reader might tune lives here ---\n",
    "SEED = 42\n",
    "\n",
    "DATASET_ZIP = '/content/drive/MyDrive/mia_panel_dataset.zip'\n",
    "DATASET_DIR = '/content/mia_panel_dataset'\n",
    "CHECKPOINT_DIR = '/content/checkpoints'\n",
    "DRIVE_MODEL_PATH = '/content/drive/MyDrive/deeppanel_model.keras'\n",
    "\n",
    "# The zip is expected to extract training/ and test/ directly under DATASET_DIR.\n",
    "TRAIN_RAW_PATH = os.path.join(DATASET_DIR, 'training', 'raw')\n",
    "TRAIN_MASK_PATH = os.path.join(DATASET_DIR, 'training', 'segmentation_mask')\n",
    "TEST_RAW_PATH = os.path.join(DATASET_DIR, 'test', 'raw')\n",
    "TEST_MASK_PATH = os.path.join(DATASET_DIR, 'test', 'segmentation_mask')\n",
    "IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')\n",
    "\n",
    "\n",
    "class Config:\n",
    "    INPUT_SHAPE = (256, 256, 3)  # (height, width, channels) every image is resized to; H and W must be divisible by 4\n",
    "    BATCH_SIZE = 5\n",
    "    EPOCHS = 200\n",
    "    LEARNING_RATE = 1e-4\n",
    "    VAL_FRACTION = 0.1  # share of training/ held out for checkpoint selection; must be in (0, 1)\n",
    "    MODEL_PATH = os.path.join(CHECKPOINT_DIR, 'model.keras')\n",
    "\n",
    "\n",
    "# Seeds Python, NumPy and TensorFlow so the split, shuffling and weight init are reproducible.\n",
    "tf.keras.utils.set_random_seed(SEED)\n",
    "os.makedirs(CHECKPOINT_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "drive.mount('/content/drive')\n",
    "\n",
    "# Extract once: re-running the notebook skips the slow unzip when the data is already there.\n",
    "if not os.path.isdir(TRAIN_RAW_PATH):\n",
    "    with zipfile.ZipFile(DATASET_ZIP, 'r') as zip_ref:\n",
    "        zip_ref.extractall(DATASET_DIR)\n",
    "\n",
    "# Zips created on macOS may carry a __MACOSX metadata folder; drop it so it is never mistaken for data.\n",
    "shutil.rmtree(os.path.join(DATASET_DIR, '__MACOSX'), ignore_errors=True)\n",
    "sorted(os.listdir(DATASET_DIR))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def list_images(folder):\n",
    "    \"\"\"Return the sorted paths of files in `folder` whose extension is in IMAGE_EXTENSIONS (case-insensitive).\"\"\"\n",
    "    return sorted(\n",
    "        os.path.join(folder, name)\n",
    "        for name in os.listdir(folder)\n",
    "        if name.lower().endswith(IMAGE_EXTENSIONS)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count image files only (`ls -l | wc -l` would also count its `total` header line).\n",
    "{\n",
    "    'training/raw': len(list_images(TRAIN_RAW_PATH)),\n",
    "    'training/segmentation_mask': len(list_images(TRAIN_MASK_PATH)),\n",
    "    'test/raw': len(list_images(TEST_RAW_PATH)),\n",
    "    'test/segmentation_mask': len(list_images(TEST_MASK_PATH)),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input pipeline\n",
    "\n",
    "Images are scaled to [0, 1]. Each mask is divided by its own maximum and thresholded at 0.5, so masks stored as 0/1 or as 0/255 both end up as exactly {0.0, 1.0}."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_image_and_mask(image_path, mask_path, target_size):\n",
    "    \"\"\"Read one (image, mask) pair and resize both to `target_size[:2]` (height, width).\n",
    "\n",
    "    Returns a 3-channel float32 image in [0, 1] and a 1-channel float32 mask whose\n",
    "    values are exactly 0.0 or 1.0.\n",
    "    \"\"\"\n",
    "    image = tf.io.read_file(image_path)\n",
    "    # decode_image accepts PNG and JPEG alike; expand_animations=False guarantees a\n",
    "    # rank-3 tensor, which tf.image.resize needs.\n",
    "    image = tf.io.decode_image(image, channels=3, expand_animations=False)\n",
    "    image = tf.image.resize(image, target_size[:2]) / 255.0\n",
    "\n",
    "    mask = tf.io.read_file(mask_path)\n",
    "    mask = tf.io.decode_image(mask, channels=1, expand_animations=False)\n",
    "    # Nearest-neighbour resizing never invents intermediate label values.\n",
    "    mask = tf.image.resize(mask, target_size[:2], method='nearest')\n",
    "    mask = tf.cast(mask, tf.float32)\n",
    "    # divide_no_nan maps an all-zero mask to zeros instead of computing 0/0 = NaN.\n",
    "    mask = tf.math.divide_no_nan(mask, tf.reduce_max(mask))\n",
    "    mask = tf.where(mask > 0.5, 1.0, 0.0)\n",
    "    return image, mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _index_by_stem(paths):\n",
    "    \"\"\"Map file stem -> path, refusing two files with the same stem (e.g. a.png and a.jpg).\"\"\"\n",
    "    by_stem = {}\n",
    "    for path in paths:\n",
    "        stem = os.path.splitext(os.path.basename(path))[0]\n",
    "        assert stem not in by_stem, f'Ambiguous files for {stem!r}: {by_stem[stem]} and {path}'\n",
    "        by_stem[stem] = path\n",
    "    return by_stem\n",
    "\n",
    "\n",
    "def pair_image_and_mask_files(raw_path, mask_path):\n",
    "    \"\"\"Return (image_files, mask_files): parallel lists matched by file stem.\n",
    "\n",
    "    Pairing by name rather than by sorted position makes a missing, extra or\n",
    "    differently named file fail loudly instead of silently putting later images\n",
    "    next to the wrong masks.\n",
    "    \"\"\"\n",
    "    images = _index_by_stem(list_images(raw_path))\n",
    "    masks = _index_by_stem(list_images(mask_path))\n",
    "    print(f'Found {len(images)} images and {len(masks)} masks')\n",
    "    assert images, 'No images found in dataset'\n",
    "    unmatched = sorted(images.keys() ^ masks.keys())\n",
    "    assert not unmatched, f'Images and masks do not pair up by name, e.g. {unmatched[:5]}'\n",
    "    stems = sorted(images)\n",
    "    return [images[s] for s in stems], [masks[s] for s in stems]\n",
    "\n",
    "\n",
    "def make_dataset(image_files, mask_files, batch_size, input_shape, is_train=True):\n",
    "    \"\"\"Build a batched, prefetched tf.data pipeline from parallel lists of image and mask paths.\"\"\"\n",
    "    assert image_files, 'No files to build a dataset from'\n",
    "    assert len(image_files) == len(mask_files), 'Number of images and masks must match'\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((image_files, mask_files))\n",
    "    if is_train:\n",
    "        # Shuffle the path pairs *before* decoding: the buffer then holds short strings\n",
    "        # instead of decoded float32 images, so it can cover the whole split.\n",
    "        dataset = dataset.shuffle(buffer_size=len(image_files))\n",
    "    dataset = dataset.map(\n",
    "        lambda x, y: load_image_and_mask(x, y, input_shape),\n",
    "        num_parallel_calls=tf.data.AUTOTUNE,\n",
    "    )\n",
    "    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)\n",
    "\n",
    "\n",
    "def create_dataset(raw_path, mask_path, batch_size, input_shape, is_train=True):\n",
    "    \"\"\"Pair the files in `raw_path` and `mask_path` by stem and build their pipeline.\"\"\"\n",
    "    image_files, mask_files = pair_image_and_mask_files(raw_path, mask_path)\n",
    "    return make_dataset(image_files, mask_files, batch_size, input_shape, is_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out part of training/ for checkpoint selection so test/ is only used once, for the\n",
    "# final evaluation: choosing the checkpoint on test/ would make its score optimistic.\n",
    "assert 0 < Config.VAL_FRACTION < 1, 'VAL_FRACTION must be in (0, 1)'\n",
    "train_pairs = list(zip(*pair_image_and_mask_files(TRAIN_RAW_PATH, TRAIN_MASK_PATH)))\n",
    "random.Random(SEED).shuffle(train_pairs)\n",
    "n_val = max(1, round(len(train_pairs) * Config.VAL_FRACTION))\n",
    "val_pairs, fit_pairs = train_pairs[:n_val], train_pairs[n_val:]\n",
    "assert fit_pairs, 'Not enough training images left after the validation split'\n",
    "\n",
    "fit_images, fit_masks = map(list, zip(*fit_pairs))\n",
    "val_images, val_masks = map(list, zip(*val_pairs))\n",
    "\n",
    "train_dataset = make_dataset(fit_images, fit_masks, Config.BATCH_SIZE, Config.INPUT_SHAPE, is_train=True)\n",
    "val_dataset = make_dataset(val_images, val_masks, Config.BATCH_SIZE, Config.INPUT_SHAPE, is_train=False)\n",
    "test_dataset = create_dataset(TEST_RAW_PATH, TEST_MASK_PATH, Config.BATCH_SIZE, Config.INPUT_SHAPE, is_train=False)\n",
    "print(f'{len(fit_pairs)} training / {len(val_pairs)} validation pairs')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check one batch: masks should contain only 0.0 and 1.0.\n",
    "sample_image, sample_mask = next(iter(train_dataset))\n",
    "print(\n",
    "    f'Image batch {sample_image.shape}, mask batch {sample_mask.shape}, '\n",
    "    f'mask min {float(tf.reduce_min(sample_mask))}, max {float(tf.reduce_max(sample_mask))}'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "A two-level U-Net: two convolution blocks with max-pooling on the way down, a bottleneck, then stride-2 transposed convolutions with skip connections on the way up, ending in a per-pixel sigmoid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model(input_shape):\n",
    "    \"\"\"Build a two-level U-Net with a 1-channel sigmoid output (binary segmentation).\n",
    "\n",
    "    Input height and width must be divisible by 4: the two 2x max-poolings must be\n",
    "    undone exactly by the two stride-2 transposed convolutions, otherwise the skip\n",
    "    connections cannot be concatenated.\n",
    "    \"\"\"\n",
    "    inputs = tf.keras.Input(shape=input_shape)\n",
    "\n",
    "    # Encoder\n",
    "    c1 = tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu')(inputs)\n",
    "    c1 = tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu')(c1)\n",
    "    p1 = tf.keras.layers.MaxPooling2D()(c1)\n",
    "\n",
    "    c2 = tf.keras.layers.Conv2D(128, 3, padding='same', activation='relu')(p1)\n",
    "    c2 = tf.keras.layers.Conv2D(128, 3, padding='same', activation='relu')(c2)\n",
    "    p2 = tf.keras.layers.MaxPooling2D()(c2)\n",
    "\n",
    "    # Bottleneck\n",
    "    b = tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu')(p2)\n",
    "    b = tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu')(b)\n",
    "\n",
    "    # Decoder: upsample, then concatenate the encoder features at the same resolution\n",
    "    u1 = tf.keras.layers.Conv2DTranspose(128, 2, strides=2, padding='same')(b)\n",
    "    u1 = tf.keras.layers.Concatenate()([u1, c2])\n",
    "    c3 = tf.keras.layers.Conv2D(128, 3, padding='same', activation='relu')(u1)\n",
    "    c3 = tf.keras.layers.Conv2D(128, 3, padding='same', activation='relu')(c3)\n",
    "\n",
    "    u2 = tf.keras.layers.Conv2DTranspose(64, 2, strides=2, padding='same')(c3)\n",
    "    u2 = tf.keras.layers.Concatenate()([u2, c1])\n",
    "    c4 = tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu')(u2)\n",
    "    c4 = tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu')(c4)\n",
    "\n",
    "    outputs = tf.keras.layers.Conv2D(1, 1, activation='sigmoid')(c4)  # per-pixel probability of label 1\n",
    "    return tf.keras.Model(inputs, outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = build_model(Config.INPUT_SHAPE)\n",
    "model.compile(\n",
    "    optimizer=tf.keras.optimizers.Adam(learning_rate=Config.LEARNING_RATE),\n",
    "    loss='binary_crossentropy',  # switch to a categorical loss if masks become multi-class\n",
    "    # Pixel accuracy can look high whenever one label dominates the masks; IoU of the\n",
    "    # foreground label (1) is less sensitive to that imbalance.\n",
    "    metrics=[\n",
    "        'accuracy',\n",
    "        tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=0.5, name='foreground_iou'),\n",
    "    ],\n",
    ")\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "history = model.fit(\n",
    "    train_dataset,\n",
    "    validation_data=val_dataset,\n",
    "    epochs=Config.EPOCHS,\n",
    "    callbacks=[\n",
    "        # Keep only the model with the lowest validation loss seen so far.\n",
    "        tf.keras.callbacks.ModelCheckpoint(\n",
    "            Config.MODEL_PATH, save_best_only=True, monitor='val_loss'\n",
    "        )\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate and export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the checkpointed (lowest val_loss) model rather than the last-epoch weights.\n",
    "best_model = tf.keras.models.load_model(Config.MODEL_PATH)\n",
    "test_metrics = best_model.evaluate(test_dataset, return_dict=True)\n",
    "test_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the best checkpoint to Google Drive.\n",
    "shutil.copy(Config.MODEL_PATH, DRIVE_MODEL_PATH)"
   ]
  }
 ]
}