| import pandas as pd |
| import albumentations as A |
| from typing import Optional, List |
| from sklearn.model_selection import train_test_split |
| from torch.utils.data import DataLoader |
| from torchgeo.datamodules import NonGeoDataModule |
| from methane_classification_dataset import MethaneClassificationDataset |
|
|
class MethaneClassificationDataModule(NonGeoDataModule):
    """LightningDataModule for methane plume classification.

    Reads a summary file (CSV or Excel) whose ``Filename`` column lists the
    sample files, splits them into train/validation subsets with a fixed
    seed, and builds ``MethaneClassificationDataset`` instances plus their
    dataloaders. Training samples receive geometric augmentations;
    validation/test samples are left untransformed.

    NOTE(review): the test/predict stage currently reuses the *validation*
    split — there is no held-out test set. Confirm this is intentional.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        **kwargs,
    ):
        """Initialize the datamodule.

        Args:
            data_root: Root directory containing the sample files.
            excel_file: Path to the summary file (``.csv`` or Excel) with a
                ``Filename`` column listing all samples.
            batch_size: Batch size for every dataloader.
            num_workers: Number of dataloader worker processes.
            val_split: Fraction of samples held out for validation.
            seed: Random state for the train/validation split
                (reproducible splits across runs).
            **kwargs: Forwarded to ``NonGeoDataModule``.
        """
        super().__init__(MethaneClassificationDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.val_split = val_split
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Populated in setup() from the summary file.
        self.train_paths: List[str] = []
        self.val_paths: List[str] = []

    def _get_training_transforms(self):
        """Return the geometric augmentation pipeline used for training only."""
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5),
        ])

    def setup(self, stage: Optional[str] = None):
        """Load the summary file, split the samples, and build datasets.

        Args:
            stage: One of ``"fit"``, ``"train"``, ``"validate"``, ``"val"``,
                ``"test"``, ``"predict"``, or ``None``. When ``None`` (e.g. a
                manual ``dm.setup()`` call), every dataset is created.

        Raises:
            RuntimeError: If the summary file cannot be read.
        """
        try:
            # The "excel_file" may actually be a CSV; dispatch on extension.
            df = pd.read_csv(self.excel_file) if self.excel_file.endswith('.csv') else pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load summary file: {e}") from e

        all_paths = df['Filename'].tolist()

        # Deterministic split: the fixed random_state guarantees the same
        # train/val partition every time setup() runs.
        self.train_paths, self.val_paths = train_test_split(
            all_paths,
            test_size=self.val_split,
            random_state=self.seed,
        )

        # stage is None means "set up everything" (Lightning convention).
        if stage is None or stage in ("fit", "train"):
            self.train_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )

        if stage is None or stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

        if stage is None or stage in ("test", "predict"):
            # No dedicated test split exists; evaluate on the validation
            # subset without augmentation.
            self.test_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )
            # NonGeoDataModule's predict_dataloader reads predict_dataset.
            self.predict_dataset = self.test_dataset

    def train_dataloader(self):
        """Shuffled training dataloader; partial final batches are dropped
        to keep batch statistics (e.g. batch norm) stable."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self):
        """Validation dataloader. drop_last=False so every sample is
        evaluated — dropping the partial batch would bias metrics."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
        )

    def test_dataloader(self):
        """Test dataloader. drop_last=False so every sample is evaluated."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
        )
|
|