Spaces: Sleeping
Commit fd607ef · 1 parent: 2aa12be
Merge prepare_dataset.py from feature/ui-fix
Browse files
data_preparation/prepare_dataset.py
CHANGED
|
@@ -5,22 +5,25 @@ import numpy as np
|
|
| 5 |
from sklearn.preprocessing import StandardScaler
|
| 6 |
from sklearn.model_selection import train_test_split
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
except ImportError: # pragma: no cover
|
| 12 |
-
torch = None
|
| 13 |
|
| 14 |
-
class Dataset: # type: ignore
|
| 15 |
-
pass
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
|
| 26 |
|
|
@@ -38,8 +41,9 @@ SELECTED_FEATURES = {
|
|
| 38 |
|
| 39 |
class FeatureVectorDataset(Dataset):
|
| 40 |
def __init__(self, features: np.ndarray, labels: np.ndarray):
|
| 41 |
-
|
| 42 |
-
self.
|
|
|
|
| 43 |
|
| 44 |
def __len__(self):
|
| 45 |
return len(self.labels)
|
|
@@ -217,6 +221,7 @@ def get_numpy_splits(model_name: str, split_ratios=(0.7, 0.15, 0.15), seed: int
|
|
| 217 |
|
| 218 |
def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=(0.7, 0.15, 0.15), seed: int = 42, scale: bool = True):
|
| 219 |
"""Return PyTorch DataLoaders for neural-network models."""
|
|
|
|
| 220 |
features, labels = _load_real_data(model_name)
|
| 221 |
num_features = features.shape[1]
|
| 222 |
num_classes = int(labels.max()) + 1
|
|
@@ -228,9 +233,9 @@ def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=(0.7, 0.
|
|
| 228 |
val_ds = FeatureVectorDataset(splits["X_val"], splits["y_val"])
|
| 229 |
test_ds = FeatureVectorDataset(splits["X_test"], splits["y_test"])
|
| 230 |
|
| 231 |
-
train_loader =
|
| 232 |
-
val_loader =
|
| 233 |
-
test_loader =
|
| 234 |
|
| 235 |
return train_loader, val_loader, test_loader, num_features, num_classes, scaler
|
| 236 |
|
|
|
|
| 5 |
from sklearn.preprocessing import StandardScaler
|
| 6 |
from sklearn.model_selection import train_test_split
|
| 7 |
|
| 8 |
+
torch = None
|
| 9 |
+
Dataset = object # type: ignore
|
| 10 |
+
DataLoader = None
|
|
|
|
|
|
|
| 11 |
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
def _require_torch():
|
| 14 |
+
global torch, Dataset, DataLoader
|
| 15 |
+
if torch is None:
|
| 16 |
+
try:
|
| 17 |
+
import torch as _torch
|
| 18 |
+
from torch.utils.data import Dataset as _Dataset, DataLoader as _DataLoader
|
| 19 |
+
except ImportError as exc: # pragma: no cover
|
| 20 |
+
raise ImportError("PyTorch not installed") from exc
|
| 21 |
|
| 22 |
+
torch = _torch
|
| 23 |
+
Dataset = _Dataset # type: ignore
|
| 24 |
+
DataLoader = _DataLoader # type: ignore
|
| 25 |
+
|
| 26 |
+
return torch, Dataset, DataLoader
|
| 27 |
|
| 28 |
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
|
| 29 |
|
|
|
|
| 41 |
|
| 42 |
class FeatureVectorDataset(Dataset):
|
| 43 |
def __init__(self, features: np.ndarray, labels: np.ndarray):
|
| 44 |
+
torch_mod, _, _ = _require_torch()
|
| 45 |
+
self.features = torch_mod.tensor(features, dtype=torch_mod.float32)
|
| 46 |
+
self.labels = torch_mod.tensor(labels, dtype=torch_mod.long)
|
| 47 |
|
| 48 |
def __len__(self):
|
| 49 |
return len(self.labels)
|
|
|
|
| 221 |
|
| 222 |
def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=(0.7, 0.15, 0.15), seed: int = 42, scale: bool = True):
|
| 223 |
"""Return PyTorch DataLoaders for neural-network models."""
|
| 224 |
+
_, _, dataloader_cls = _require_torch()
|
| 225 |
features, labels = _load_real_data(model_name)
|
| 226 |
num_features = features.shape[1]
|
| 227 |
num_classes = int(labels.max()) + 1
|
|
|
|
| 233 |
val_ds = FeatureVectorDataset(splits["X_val"], splits["y_val"])
|
| 234 |
test_ds = FeatureVectorDataset(splits["X_test"], splits["y_test"])
|
| 235 |
|
| 236 |
+
train_loader = dataloader_cls(train_ds, batch_size=batch_size, shuffle=True)
|
| 237 |
+
val_loader = dataloader_cls(val_ds, batch_size=batch_size, shuffle=False)
|
| 238 |
+
test_loader = dataloader_cls(test_ds, batch_size=batch_size, shuffle=False)
|
| 239 |
|
| 240 |
return train_loader, val_loader, test_loader, num_features, num_classes, scaler
|
| 241 |
|