| import numpy as np |
| from numpy.fft import rfft |
| from numpy.lib.stride_tricks import as_strided |
| from scipy.signal import get_window |
|
|
| def stft_multi( |
| x, |
| fs: float, |
| win_s: float = 0.032, |
| hop_s: float = 0.010, |
| nfft: int | None = None, |
| window: str | tuple | np.ndarray = "hann", |
| center: bool = True, |
| pad_mode: str = "reflect", |
| out_dtype = np.complex64, |
| ): |
| """ |
| Multichannel STFT (vectorized). |
| Args |
| ---- |
| x : np.ndarray, shape (N, C) time-domain signal |
| fs : float, sampling rate (Hz) |
| win_s : float, window length in seconds (default 32 ms) |
| hop_s : float, hop length in seconds (default 10 ms) |
| nfft : int or None. If None, uses next power of two >= frame_len |
| window : str/tuple/array for scipy.signal.get_window or a length-L array |
| center : if True, pad by L//2 on both sides (librosa-style) |
| pad_mode: np.pad mode (e.g., "reflect", "constant") |
| out_dtype: dtype for STFT output (complex64 recommended) |
| |
| Returns |
| ------- |
| X : np.ndarray, shape (T, C, F) complex STFT |
| freqs: np.ndarray, shape (F,) frequency bins in Hz |
| times: np.ndarray, shape (T,) frame center times in seconds |
| """ |
| x = np.asarray(x) |
| if x.ndim == 1: |
| x = x[:, None] |
| assert x.ndim == 2, "x must be (samples, channels)" |
| N, C = x.shape |
|
|
| |
| frame_len = int(round(win_s * fs)) |
| hop = int(round(hop_s * fs)) |
| if frame_len <= 0 or hop <= 0: |
| raise ValueError("win_s and hop_s must be > 0") |
|
|
| |
| def _next_pow2(n): |
| return 1 << (int(n - 1).bit_length()) |
| nfft = _next_pow2(frame_len) if nfft is None else int(nfft) |
| if nfft < frame_len: |
| raise ValueError("nfft must be >= frame_len") |
|
|
| |
| if isinstance(window, np.ndarray): |
| w = window.astype(float, copy=False) |
| else: |
| w = get_window(window, frame_len, fftbins=True).astype(float) |
| if w.shape[0] != frame_len: |
| raise ValueError("Provided window length != frame_len") |
|
|
| |
| pad = frame_len // 2 if center else 0 |
| if pad > 0: |
| x_pad = np.pad(x, ((pad, pad), (0, 0)), mode=pad_mode) |
| else: |
| x_pad = x |
|
|
| Np = x_pad.shape[0] |
| if Np < frame_len: |
| |
| x_pad = np.pad(x_pad, ((0, frame_len - Np), (0, 0)), mode=pad_mode) |
| Np = x_pad.shape[0] |
|
|
| |
| T = 1 + (Np - frame_len) // hop |
| if T <= 0: |
| raise ValueError("Signal too short for given window/hop") |
|
|
| |
| s_t, s_c = x_pad.strides |
| frames = as_strided( |
| x_pad, |
| shape=(T, frame_len, C), |
| strides=(hop * s_t, s_t, s_c), |
| writeable=False, |
| ) |
| |
| frames = np.transpose(frames, (0, 2, 1)) |
|
|
| |
| frames = frames * w[None, None, :] |
|
|
| |
| X = rfft(frames, n=nfft, axis=-1).astype(out_dtype, copy=False) |
|
|
| |
| F = X.shape[-1] |
| freqs = (fs / nfft) * np.arange(F) |
| |
| if center: |
| |
| times = (np.arange(T) * hop) / fs |
| else: |
| |
| times = (np.arange(T) * hop + frame_len / 2.0) / fs |
|
|
| return X, freqs, times |
|
|
|
|
|
|
def _wrap_to_2pi(x: np.ndarray) -> np.ndarray:
    """
    Wrap angles in radians to the half-open interval [0, 2π).

    Args
    ----
    x : array-like of angles in radians

    Returns
    -------
    Angles congruent to x modulo 2π, each strictly below 2π. The result keeps
    the dtype that np.mod produces for x.
    """
    two_pi = 2.0 * np.pi
    wrapped = np.mod(x, two_pi)
    # For a tiny negative input (e.g. -1e-20), np.mod computes x + 2π, which
    # rounds to the modulus itself. That angle is equivalent to 0, so fold it
    # back to keep the upper bound exclusive. The zero is typed to match
    # `wrapped` so the dtype is not promoted.
    zero = wrapped.dtype.type(0)
    # `[()]` turns a 0-d result back into a NumPy scalar, as np.mod returned.
    return np.where(wrapped >= two_pi, zero, wrapped)[()]
|
|
def compute_mag_phase(
    X: np.ndarray,
    dtype=np.float32,
):
    """
    Per-channel magnitude and absolute phase (wrapped to [0, 2π)).

    Args
    ----
    X    : array-like, shape (T, C, F), complex STFT
    dtype: output dtype

    Returns
    -------
    mag   : np.ndarray, shape (T, C, F) = |X|
    phase : np.ndarray, shape (T, C, F) = angle(X) in [0, 2π)

    Raises
    ------
    ValueError
        If X is not 3-dimensional.
    """
    X = np.asarray(X)
    # Explicit raise instead of `assert`: asserts vanish under `python -O`.
    if X.ndim != 3:
        raise ValueError("X must be (T, C, F)")

    mag = np.abs(X).astype(dtype, copy=False)
    phase = _wrap_to_2pi(np.angle(X)).astype(dtype, copy=False)

    # Wrapping in low precision, or casting to `dtype`, can round a phase just
    # below 2π onto a value >= 2π (float32(2π) is larger than 2π). Those
    # angles are equivalent to 0; fold them back so the bound stays exclusive.
    # The comparison is done in float64 so the threshold itself is not rounded.
    # `phase` is a fresh array from the wrap step, so in-place writes are safe.
    overshoot = phase.astype(np.float64, copy=False) >= 2.0 * np.pi
    phase[overshoot] = 0
    return mag, phase
|
|
def compute_mag_phase_cos_sin(
    X: np.ndarray,
    dtype=np.float32,
):
    """
    Stack magnitude, cos(phase) and sin(phase) along the channel axis.

    Args
    ----
    X    : np.ndarray, shape (T, C, F), complex STFT
    dtype: output dtype

    Returns
    -------
    feats : np.ndarray, shape (T, 3*C, F)
            Channel-axis layout is [mag (C), cos(phase) (C), sin(phase) (C)],
            with magnitude and phase taken from compute_mag_phase (phase is
            angle(X) wrapped to [0, 2π)).
    """
    magnitude, angle = compute_mag_phase(X, dtype=dtype)
    # Evaluate both trig functions on the wrapped phase, each cast to `dtype`.
    trig_parts = [fn(angle).astype(dtype, copy=False) for fn in (np.cos, np.sin)]
    # Axis 1 is the channel axis, so the three groups sit side by side there.
    return np.concatenate([magnitude, *trig_parts], axis=1)
|
|
def compute_real_imag_features(
    X: np.ndarray,
    dtype=np.float32,
):
    """
    Concatenate per-channel real and imaginary parts.

    Args
    ----
    X    : array-like, shape (T, C, F), complex STFT
    dtype: output dtype

    Returns
    -------
    feats : np.ndarray, shape (T, 2*C, F)
            Layout = [Re (C), Im (C)]

    Raises
    ------
    ValueError
        If X is not 3-dimensional.
    """
    X = np.asarray(X)
    # Explicit raise instead of `assert`: asserts vanish under `python -O`.
    if X.ndim != 3:
        raise ValueError("X must be (T, C, F)")

    real = X.real.astype(dtype, copy=False)
    imag = X.imag.astype(dtype, copy=False)
    # Axis 1 is the channel axis: all real parts first, then all imaginary.
    return np.concatenate([real, imag], axis=1)
|
|
|
|