MatteoFasulo commited on
Commit
7be20ea
·
1 Parent(s): f7b4d24

refactor: improve Ninapro DB5 with progress indication for data augmentation

Browse files
Files changed (1) hide show
  1. scripts/db5.py +14 -6
scripts/db5.py CHANGED
@@ -6,6 +6,8 @@ import numpy as np
6
  import scipy.io
7
  import scipy.signal as signal
8
  from scipy.signal import iirnotch
 
 
9
 
10
  def sequence_to_seconds(seq_len: int, fs: float) -> float:
11
  """Converts a sequence length in samples to time in seconds.
@@ -20,7 +22,9 @@ def sequence_to_seconds(seq_len: int, fs: float) -> float:
20
  return seq_len / fs
21
 
22
 
23
- def random_amplitude_scale(sig: np.ndarray, scale_range: Tuple[float, float] = (0.9, 1.1)) -> np.ndarray:
 
 
24
  """Applies random amplitude scaling to the input signal.
25
 
26
  Args:
@@ -85,7 +89,9 @@ def augment_one_sample(seg: np.ndarray) -> np.ndarray:
85
  return out
86
 
87
 
88
- def augment_train_data(data: np.ndarray, labels: np.ndarray, factor: int = 3) -> Tuple[np.ndarray, np.ndarray]:
 
 
89
  """Augments the training dataset by creating multiple versions of each sample.
90
 
91
  Args:
@@ -104,7 +110,7 @@ def augment_train_data(data: np.ndarray, labels: np.ndarray, factor: int = 3) ->
104
  aug_segs = [data]
105
  aug_lbls = [labels]
106
  N = data.shape[0]
107
- for i in range(N):
108
  seg = data[i] # [window_size, n_ch]
109
  lab = labels[i]
110
  for _ in range(factor):
@@ -115,7 +121,9 @@ def augment_train_data(data: np.ndarray, labels: np.ndarray, factor: int = 3) ->
115
  return new_data, new_labels
116
 
117
 
118
- def notch_filter(data: np.ndarray, notch_freq: float = 50.0, Q: float = 30.0, fs: float = 200.0) -> np.ndarray:
 
 
119
  """Applies a notch filter to remove power line interference.
120
 
121
  Args:
@@ -140,7 +148,7 @@ def bandpass_filter_emg(
140
  lowcut: float = 20.0,
141
  highcut: float = 90.0,
142
  fs: float = 200.0,
143
- order: int = 4
144
  ) -> np.ndarray:
145
  """Applies a Butterworth bandpass filter to the EMG signal.
146
 
@@ -169,7 +177,7 @@ def process_emg_features(
169
  label: np.ndarray,
170
  rerep: np.ndarray,
171
  window_size: int = 1024,
172
- stride: int = 512
173
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
174
  """Segments raw EMG signals into overlapping windows.
175
 
 
6
  import scipy.io
7
  import scipy.signal as signal
8
  from scipy.signal import iirnotch
9
+ from tqdm import tqdm
10
+
11
 
12
  def sequence_to_seconds(seq_len: int, fs: float) -> float:
13
  """Converts a sequence length in samples to time in seconds.
 
22
  return seq_len / fs
23
 
24
 
25
+ def random_amplitude_scale(
26
+ sig: np.ndarray, scale_range: Tuple[float, float] = (0.9, 1.1)
27
+ ) -> np.ndarray:
28
  """Applies random amplitude scaling to the input signal.
29
 
30
  Args:
 
89
  return out
90
 
91
 
92
+ def augment_train_data(
93
+ data: np.ndarray, labels: np.ndarray, factor: int = 3
94
+ ) -> Tuple[np.ndarray, np.ndarray]:
95
  """Augments the training dataset by creating multiple versions of each sample.
96
 
97
  Args:
 
110
  aug_segs = [data]
111
  aug_lbls = [labels]
112
  N = data.shape[0]
113
+ for i in tqdm(range(N), desc="Augmenting training data"):
114
  seg = data[i] # [window_size, n_ch]
115
  lab = labels[i]
116
  for _ in range(factor):
 
121
  return new_data, new_labels
122
 
123
 
124
+ def notch_filter(
125
+ data: np.ndarray, notch_freq: float = 50.0, Q: float = 30.0, fs: float = 200.0
126
+ ) -> np.ndarray:
127
  """Applies a notch filter to remove power line interference.
128
 
129
  Args:
 
148
  lowcut: float = 20.0,
149
  highcut: float = 90.0,
150
  fs: float = 200.0,
151
+ order: int = 4,
152
  ) -> np.ndarray:
153
  """Applies a Butterworth bandpass filter to the EMG signal.
154
 
 
177
  label: np.ndarray,
178
  rerep: np.ndarray,
179
  window_size: int = 1024,
180
+ stride: int = 512,
181
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
182
  """Segments raw EMG signals into overlapping windows.
183