Yingtao-Zheng commited on
Commit
fd607ef
·
1 Parent(s): 2aa12be

Merge prepare_dataset.py from feature/ui-fix

Browse files
Files changed (1) hide show
  1. data_preparation/prepare_dataset.py +23 -18
data_preparation/prepare_dataset.py CHANGED
@@ -5,22 +5,25 @@ import numpy as np
5
  from sklearn.preprocessing import StandardScaler
6
  from sklearn.model_selection import train_test_split
7
 
8
- try:
9
- import torch
10
- from torch.utils.data import Dataset, DataLoader
11
- except ImportError: # pragma: no cover
12
- torch = None
13
 
14
- class Dataset: # type: ignore
15
- pass
16
 
17
- class _MissingTorchDataLoader: # type: ignore
18
- def __init__(self, *args, **kwargs):
19
- raise ImportError(
20
- "PyTorch not installed"
21
- )
 
 
 
22
 
23
- DataLoader = _MissingTorchDataLoader # type: ignore
 
 
 
 
24
 
25
  DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
26
 
@@ -38,8 +41,9 @@ SELECTED_FEATURES = {
38
 
39
  class FeatureVectorDataset(Dataset):
40
  def __init__(self, features: np.ndarray, labels: np.ndarray):
41
- self.features = torch.tensor(features, dtype=torch.float32)
42
- self.labels = torch.tensor(labels, dtype=torch.long)
 
43
 
44
  def __len__(self):
45
  return len(self.labels)
@@ -217,6 +221,7 @@ def get_numpy_splits(model_name: str, split_ratios=(0.7, 0.15, 0.15), seed: int
217
 
218
  def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=(0.7, 0.15, 0.15), seed: int = 42, scale: bool = True):
219
  """Return PyTorch DataLoaders for neural-network models."""
 
220
  features, labels = _load_real_data(model_name)
221
  num_features = features.shape[1]
222
  num_classes = int(labels.max()) + 1
@@ -228,9 +233,9 @@ def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=(0.7, 0.
228
  val_ds = FeatureVectorDataset(splits["X_val"], splits["y_val"])
229
  test_ds = FeatureVectorDataset(splits["X_test"], splits["y_test"])
230
 
231
- train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
232
- val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
233
- test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
234
 
235
  return train_loader, val_loader, test_loader, num_features, num_classes, scaler
236
 
 
5
  from sklearn.preprocessing import StandardScaler
6
  from sklearn.model_selection import train_test_split
7
 
8
+ torch = None
9
+ Dataset = object # type: ignore
10
+ DataLoader = None
 
 
11
 
 
 
12
 
13
+ def _require_torch():
14
+ global torch, Dataset, DataLoader
15
+ if torch is None:
16
+ try:
17
+ import torch as _torch
18
+ from torch.utils.data import Dataset as _Dataset, DataLoader as _DataLoader
19
+ except ImportError as exc: # pragma: no cover
20
+ raise ImportError("PyTorch not installed") from exc
21
 
22
+ torch = _torch
23
+ Dataset = _Dataset # type: ignore
24
+ DataLoader = _DataLoader # type: ignore
25
+
26
+ return torch, Dataset, DataLoader
27
 
28
  DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
29
 
 
41
 
42
  class FeatureVectorDataset(Dataset):
43
  def __init__(self, features: np.ndarray, labels: np.ndarray):
44
+ torch_mod, _, _ = _require_torch()
45
+ self.features = torch_mod.tensor(features, dtype=torch_mod.float32)
46
+ self.labels = torch_mod.tensor(labels, dtype=torch_mod.long)
47
 
48
  def __len__(self):
49
  return len(self.labels)
 
221
 
222
  def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=(0.7, 0.15, 0.15), seed: int = 42, scale: bool = True):
223
  """Return PyTorch DataLoaders for neural-network models."""
224
+ _, _, dataloader_cls = _require_torch()
225
  features, labels = _load_real_data(model_name)
226
  num_features = features.shape[1]
227
  num_classes = int(labels.max()) + 1
 
233
  val_ds = FeatureVectorDataset(splits["X_val"], splits["y_val"])
234
  test_ds = FeatureVectorDataset(splits["X_test"], splits["y_test"])
235
 
236
+ train_loader = dataloader_cls(train_ds, batch_size=batch_size, shuffle=True)
237
+ val_loader = dataloader_cls(val_ds, batch_size=batch_size, shuffle=False)
238
+ test_loader = dataloader_cls(test_ds, batch_size=batch_size, shuffle=False)
239
 
240
  return train_loader, val_loader, test_loader, num_features, num_classes, scaler
241