---
title: Erav2s13
emoji: 🔥
colorFrom: yellow
colorTo: red
sdk: gradio
sdk_version: 4.27.0
app_file: app.py
pinned: false
license: mit
---
# Erav2s13 - SOUTRIK 🔥

## Overview
This repository is a Hugging Face Space that serves a Gradio user interface (UI). The model was trained in Google Colab, and the resulting model files are loaded for inference in the Gradio app.

- **Model Training**: `Main.ipynb` - Colab notebook used to build and train the model.
- **Inference**: The Gradio app (`app.py`) uses the same model definition and the trained model files.
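
The inference code in `app.py` is not reproduced in this README. The following is a minimal sketch, assuming the model returns log-softmax scores (as described under Model Architecture) and that the app displays the top classes with a Gradio `gr.Label` output, which accepts a `{label: confidence}` dictionary. `predict` and `CIFAR10_CLASSES` are illustrative names, not taken from the app.

```python
import torch

# Illustrative: CIFAR-10 class names in torchvision's label order (assumed to match training).
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def predict(model, image, top_k=3):
    """Hypothetical helper: map one normalized CHW image tensor to {class_name: probability}."""
    model.eval()
    with torch.no_grad():
        log_probs = model(image.unsqueeze(0))[0]  # model output is log-softmax scores
    probs = log_probs.exp()                       # back to probabilities that sum to 1
    values, indices = probs.topk(top_k)           # sorted in descending order
    return {CIFAR10_CLASSES[i]: v for v, i in zip(values.tolist(), indices.tolist())}
```

The returned dictionary is ordered from most to least likely class, so it can be handed directly to a `gr.Label` component.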
## Custom ResNet Model

The `custom_resnet.py` file defines a custom ResNet (Residual Network) model using PyTorch Lightning. It is designed for image classification on the CIFAR-10 dataset.
### Model Architecture

The custom ResNet model comprises the following components:

1. **Preparation Layer**: Convolutional layer with 64 filters, followed by batch normalization, ReLU activation, and dropout.
2. **Layer 1**: Convolutional layer with 128 filters, max pooling, batch normalization, ReLU activation, and dropout, followed by a residual block of two convolutional layers (128 filters each) with batch normalization, ReLU activation, and dropout, whose output is added to the block's input.
3. **Layer 2**: Convolutional layer with 256 filters, max pooling, batch normalization, ReLU activation, and dropout.
4. **Layer 3**: Convolutional layer with 512 filters, max pooling, batch normalization, ReLU activation, and dropout, followed by a residual block of two convolutional layers (512 filters each) with batch normalization, ReLU activation, and dropout, whose output is added to the block's input.
5. **Max Pooling**: Max pooling layer with a kernel size of 4.
6. **Fully Connected Layer**: The flattened output is passed through a fully connected layer with 10 output units (one per CIFAR-10 class).
7. **Log Softmax**: A log-softmax activation produces the log-probability of each class.
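
A minimal PyTorch sketch of the layer stack listed above, assuming 3x3 convolutions with padding 1, 2x2 max pooling inside Layers 1 to 3, the ordering conv, max pool, batch norm, ReLU, dropout as the list gives it, a dropout rate of 0.05, and 32x32 RGB inputs. These values and the names `conv_block`, `ResBlock`, and `CustomResNetSketch` are illustrative; `custom_resnet.py` may differ and additionally wraps the network in a PyTorch Lightning module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

DROPOUT = 0.05  # assumed rate; the README does not give one


def conv_block(in_ch, out_ch, pool=False):
    """3x3 conv -> (optional 2x2 max pool) -> BatchNorm -> ReLU -> Dropout."""
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)]
    if pool:
        layers.append(nn.MaxPool2d(2))
    layers += [nn.BatchNorm2d(out_ch), nn.ReLU(), nn.Dropout(DROPOUT)]
    return nn.Sequential(*layers)


class ResBlock(nn.Module):
    """Two conv blocks at constant width; their output is added to the block input."""

    def __init__(self, channels):
        super().__init__()
        self.body = nn.Sequential(conv_block(channels, channels), conv_block(channels, channels))

    def forward(self, x):
        return x + self.body(x)


class CustomResNetSketch(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.prep = conv_block(3, 64)                                                # 64 x 32 x 32
        self.layer1 = nn.Sequential(conv_block(64, 128, pool=True), ResBlock(128))   # 128 x 16 x 16
        self.layer2 = conv_block(128, 256, pool=True)                                # 256 x 8 x 8
        self.layer3 = nn.Sequential(conv_block(256, 512, pool=True), ResBlock(512))  # 512 x 4 x 4
        self.pool = nn.MaxPool2d(kernel_size=4)                                      # 512 x 1 x 1
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.layer3(self.layer2(self.layer1(self.prep(x))))
        x = torch.flatten(self.pool(x), 1)
        return F.log_softmax(self.fc(x), dim=1)  # log-probabilities per class
```

Because the network ends in log-softmax, pairing it with `F.nll_loss` gives the same value as cross-entropy on the raw logits.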
### Training and Evaluation

The model is trained with PyTorch Lightning, which provides a high-level interface for training, validation, and testing. Key components include:

- **Optimizer**: Adam, with the learning rate set by `PREFERRED_START_LR`.
- **Scheduler**: OneCycleLR, which raises the learning rate to a peak and then anneals it over the course of training.
- **Loss and Accuracy**: Cross-entropy loss and accuracy are computed and logged during training, validation, and testing.
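
A minimal sketch of how Adam and OneCycleLR might be wired together, returned in the dictionary format that a Lightning `configure_optimizers` hook accepts. The numeric values of `PREFERRED_START_LR` and `PREFERRED_WEIGHT_DECAY` are placeholders, and using the start learning rate as the one-cycle peak (`max_lr`) is an assumption; the README does not say how the two are related.

```python
import torch

# Placeholder values: the README does not state the actual hyperparameters.
PREFERRED_START_LR = 1e-3
PREFERRED_WEIGHT_DECAY = 1e-4


def configure_optimizers_sketch(model, epochs, steps_per_epoch):
    """Adam + OneCycleLR, returned in the dictionary format PyTorch Lightning accepts."""
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=PREFERRED_START_LR,
        weight_decay=PREFERRED_WEIGHT_DECAY,
    )
    # Assumption: the start LR doubles as the one-cycle peak. OneCycleLR overrides
    # the optimizer's LR with max_lr / div_factor (default 25) at construction time.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=PREFERRED_START_LR,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
    )
    # OneCycleLR is meant to be stepped once per batch, hence interval="step".
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }
```

Stepping per batch rather than per epoch matters here: with `interval="epoch"`, the scheduler would only advance `epochs` times and never complete its cycle.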
### Misclassified Images

During testing, misclassified images are collected in a dictionary together with their ground-truth and predicted labels, to support error analysis and guide model improvements.
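
The README does not say how that dictionary is keyed. The sketch below assumes a running integer index mapped to the image tensor and its ground-truth and predicted class names, and loops over a plain data loader rather than Lightning's `test_step`; `collect_misclassified` is an illustrative name.

```python
import torch


def collect_misclassified(model, loader, class_names):
    """Hypothetical helper: gather wrongly predicted test images with their labels.

    Returns {running_index: {"image": tensor, "ground_truth": str, "predicted": str}}.
    """
    model.eval()
    misclassified = {}
    with torch.no_grad():
        for images, targets in loader:
            preds = model(images).argmax(dim=1)
            # Positions within the batch where the prediction disagrees with the label.
            wrong = (preds != targets).nonzero(as_tuple=True)[0]
            for i in wrong.tolist():
                misclassified[len(misclassified)] = {
                    "image": images[i].cpu(),
                    "ground_truth": class_names[targets[i].item()],
                    "predicted": class_names[preds[i].item()],
                }
    return misclassified
```

Since log-softmax preserves the ordering of the logits, taking `argmax` of the model output gives the same prediction whether or not the network ends in a log-softmax.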
### Hyperparameters

Key hyperparameters include:

- `PREFERRED_START_LR`: Initial learning rate.
- `PREFERRED_WEIGHT_DECAY`: Weight decay for regularization.
### Model Summary

The `detailed_model_summary` function prints a per-layer summary of the model architecture: input size, kernel size, output size, number of parameters, and whether each layer is trainable.
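
The implementation of `detailed_model_summary` is not shown here. One common way to collect the same columns is to register forward hooks on every leaf layer and push a dummy batch through the model; the sketch below does that, assuming an input shape of `(C, H, W)` such as `(3, 32, 32)` for CIFAR-10. The name `detailed_model_summary_sketch` and the table layout are illustrative.

```python
import torch
import torch.nn as nn


def detailed_model_summary_sketch(model, input_size):
    """Print one row per leaf layer: input/output shape, kernel size, params, trainable."""
    rows, hooks = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            params = list(module.parameters(recurse=False))
            rows.append((
                name,
                type(module).__name__,
                tuple(inputs[0].shape),
                tuple(output.shape),
                getattr(module, "kernel_size", "-"),  # only conv/pool layers define this
                sum(p.numel() for p in params),
                any(p.requires_grad for p in params),  # False for parameter-free layers
            ))
        return hook

    # Hook only leaf modules so containers such as nn.Sequential are not double-counted.
    for name, module in model.named_modules():
        if not list(module.children()):
            hooks.append(module.register_forward_hook(make_hook(name)))

    was_training = model.training
    model.eval()  # avoid updating BatchNorm running statistics during the dummy pass
    try:
        with torch.no_grad():
            model(torch.zeros(1, *input_size))
    finally:
        for h in hooks:
            h.remove()
        model.train(was_training)

    fmt = "{:<12} {:<12} {:<18} {:<18} {:<8} {:>8} {}"
    print(fmt.format("Layer", "Type", "Input", "Output", "Kernel", "Params", "Trainable"))
    for row in rows:
        print(fmt.format(*map(str, row)))
    return rows
```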
## Lightning Dataset Module

The `lightning_dataset.py` file contains the `CIFARDataModule` class, a PyTorch Lightning `LightningDataModule` for the CIFAR-10 dataset. The class handles data preparation, splitting, and loading.

### CIFARDataModule Class

#### Parameters

- `data_path`: Directory in which the CIFAR-10 dataset is stored.
- `batch_size`: Batch size for the data loaders.
- `seed`: Random seed for reproducibility.
- `val_split`: Fraction of the training data used for validation (default: 0).
- `num_workers`: Number of worker processes for data loading (default: 0).

#### Methods

- `prepare_data`: Downloads the CIFAR-10 dataset if it is not already present.
- `setup`: Defines the data transformations and creates the training, validation, and testing datasets.
- `train_dataloader`: Returns the training data loader.
- `val_dataloader`: Returns the validation data loader.
- `test_dataloader`: Returns the testing data loader.

#### Utility Methods

- `_split_train_val`: Splits the training dataset into training and validation subsets.
- `_init_fn`: Seeds the random number generators of each data-loader worker process to ensure reproducibility.
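
The bodies of `_split_train_val` and `_init_fn` are not shown. A minimal sketch of the usual approach, assuming `torch.utils.data.random_split` with a seeded generator for the split and the worker-seeding pattern from the PyTorch reproducibility notes for the worker initializer; `split_train_val` and `init_fn` are standalone illustrative versions, not the class methods themselves.

```python
import random

import numpy as np
import torch
from torch.utils.data import random_split


def split_train_val(dataset, val_split, seed):
    """Reproducibly hold out a `val_split` fraction of `dataset` for validation."""
    n_val = int(len(dataset) * val_split)
    n_train = len(dataset) - n_val
    generator = torch.Generator().manual_seed(seed)  # same seed -> same split
    return random_split(dataset, [n_train, n_val], generator=generator)


def init_fn(worker_id):
    """Seed Python and NumPy RNGs in a DataLoader worker from torch's per-worker seed."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
```

In this sketch the loaders would be built with `DataLoader(train_set, batch_size=..., num_workers=..., worker_init_fn=init_fn, generator=torch.Generator().manual_seed(seed))`, so that both the shuffle order and any NumPy or Python randomness inside the transforms repeat across runs.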
## License

This project is licensed under the MIT License.
| --- | |