| | import yaml |
| | import os |
| | import torch |
| | import random |
| | import numpy as np |
| |
|
class BaseConfig:
    """Load ``config.yml`` from this module's directory and prepare the runtime.

    On construction the YAML file is parsed into ``self.config`` and
    ``setup_environment`` is called, which seeds the random number generators
    and sets ``self.device``.

    Configuration keys read by this class:
        ``gpu.visible_device``: value exported as ``CUDA_VISIBLE_DEVICES``.
        ``gpu.device``: argument passed to ``torch.device``.
        ``data.collate``: target length of dimension 0 used by
            ``custom_collate``.
    """

    def __init__(self):
        """Parse ``config.yml`` and run ``setup_environment``."""
        config_path = os.path.join(os.path.dirname(__file__), 'config.yml')
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(config_path, 'r', encoding='utf-8') as file:
            self.config = yaml.safe_load(file)

        self.setup_environment()

    def setup_environment(self, seed: int = 42) -> None:
        """Seed the RNGs, export the visible GPUs and select the device.

        Args:
            seed: Value used to seed ``random``, NumPy and torch (and all CUDA
                devices when CUDA is available). Defaults to 42.
        """
        # Export the device mask before any torch call so it is in place
        # before CUDA could be initialised.
        visible = self.config["gpu"]["visible_device"]
        if isinstance(visible, (list, tuple)):
            # A YAML list such as [0, 1] becomes the "0,1" form CUDA expects.
            visible = ",".join(str(index) for index in visible)
        # os.environ only accepts strings, while YAML parses an unquoted
        # ``0`` as an int, so coerce explicitly.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(visible)

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.device = torch.device(self.config["gpu"]["device"])
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            torch.set_float32_matmul_precision("medium")

    def custom_collate(self, batch) -> dict:
        """Zero-pad every image along dimension 0 and stack the batch.

        Args:
            batch: Sequence of mappings, each holding a tensor under
                ``'image'`` and a tensor under ``'label'``.

        Returns:
            A dict with ``"image"`` (the stacked, padded images) and
            ``"label"`` (the stacked labels).

        Images whose first dimension is already at least
        ``config["data"]["collate"]`` are passed through unchanged, so
        ``torch.stack`` raises if the resulting shapes differ.
        """
        max_len = self.config["data"]["collate"]
        padded_images = []

        for item in batch:
            img = item['image']
            pad_size = max_len - img.shape[0]
            if pad_size > 0:
                # new_zeros inherits dtype and device from the image, so the
                # padded result is not promoted to float32 and can be
                # concatenated on whatever device the image lives on.
                padding = img.new_zeros((pad_size, *img.shape[1:]))
                img = torch.cat([img, padding], dim=0)
            padded_images.append(img)

        labels = [item['label'] for item in batch]
        return {"image": torch.stack(padded_images, dim=0), "label": torch.stack(labels)}

    def get_config(self):
        """Return the parsed configuration mapping (the same object, not a copy)."""
        return self.config