import gc
import logging
from typing import List, TypeVar

import torch
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)

# Generic element type for ListDataset.
T = TypeVar("T")
| |
|
| |
|
def get_torch_device(device: str = "auto") -> str:
    """
    Return the device string to be used by PyTorch.

    Args:
        device: An explicit device string (e.g. "cpu", "cuda:1", "mps"),
            or "auto" (the default). With "auto", resolution order is:
            - "cuda:0" if CUDA is available,
            - else "mps" if the Apple Metal backend is available,
            - else "cpu".
            Explicit device strings are returned unchanged.

    Returns:
        The resolved device string.
    """
    if device == "auto":
        if torch.cuda.is_available():
            device = "cuda:0"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("Using device: %s", device)

    return device
| |
|
| |
|
def tear_down_torch():
    """
    Teardown for PyTorch.

    Runs a garbage-collection pass, then releases cached accelerator memory
    for whichever backends are actually available.

    The availability guards matter: calling ``torch.mps.empty_cache()`` on a
    build without MPS support (e.g. a Linux/CUDA install) raises a
    RuntimeError, so each cache clear only runs when its backend is present.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
| |
|
| |
|
class ListDataset(Dataset[T]):
    """Map-style ``Dataset`` backed by an in-memory list of elements."""

    def __init__(self, elements: List[T]):
        """Wrap *elements*; the list is referenced, not copied."""
        self.elements = elements

    def __len__(self) -> int:
        """Return the number of wrapped elements."""
        return len(self.elements)

    def __getitem__(self, idx: int) -> T:
        """Return the element at position *idx* (negative indices work too)."""
        return self.elements[idx]
| |
|