| | import numpy as np |
| | from pathlib import Path |
| | from scipy.special import gamma |
| | from typing import Optional, Tuple, Dict, List, Union |
| | import torch |
| | import os |
| |
|
class GeneralizedGaussianMixture:
    r"""Dataset generator for a mixture of generalized Gaussian distributions.

    Each component ``k`` is a product of independent 1-D generalized Gaussians
    (one per dimension) with center ``c_k``, scale ``alpha_k`` and a shape
    parameter ``p`` shared by all components. Per dimension:

        P_{\theta_k}(x_i) = \eta_k \exp(-s_k d_k(x_i))
                          = \frac{p}{2\alpha_k \Gamma(1/p)}
                            \exp\left(-\left|\frac{x_i - c_k}{\alpha_k}\right|^p\right)

    ``p = 2`` gives a Gaussian with variance ``alpha**2 / 2``, ``p = 1`` a
    Laplace distribution, and large ``p`` approaches a uniform distribution
    on ``[c - alpha, c + alpha]``.
    """

    def __init__(self,
                 D: int = 2,
                 K: int = 3,
                 p: float = 2.0,
                 centers: Optional[np.ndarray] = None,
                 scales: Optional[np.ndarray] = None,
                 weights: Optional[np.ndarray] = None,
                 seed: int = 42):
        """Initialize the mixture.

        Args:
            D: Data dimensionality.
            K: Number of mixture components.
            p: Shape (power) parameter controlling the distribution's form;
                must be positive.
            centers: Component centers, shape (K, D). If None, drawn as
                ``2 * standard normal``.
            scales: Per-dimension scale parameters, shape (K, D); must be
                strictly positive. If None, drawn uniformly from [0.1, 0.5).
            weights: Mixture weights, shape (K,); normalized to sum to 1.
                If None, drawn from a flat Dirichlet distribution.
            seed: Seed for NumPy's *global* random state. Constructing an
                instance reseeds ``np.random`` as a side effect, and all
                sampling methods draw from that global state.

        Raises:
            ValueError: If ``p`` is not positive, an array argument has the
                wrong shape, a scale is not strictly positive, or the weights
                are negative or do not have a positive sum.
        """
        if p <= 0:
            raise ValueError(f"p must be positive, got {p}")
        self.D = D
        self.K = K
        self.p = p
        self.seed = seed
        # Kept for backward compatibility: callers may rely on the global
        # NumPy RNG being seeded when a generator is constructed.
        np.random.seed(seed)

        # Defaults are drawn in the same order as before (centers, scales,
        # weights) so seeded results stay reproducible.
        if centers is None:
            self.centers = np.random.randn(K, D) * 2
        else:
            self.centers = self._check_shape(centers, (K, D), 'centers')

        if scales is None:
            self.scales = np.random.uniform(0.1, 0.5, size=(K, D))
        else:
            self.scales = self._check_shape(scales, (K, D), 'scales')
            if np.any(self.scales <= 0):
                raise ValueError("scales must be strictly positive")

        if weights is None:
            self.weights = np.random.dirichlet(np.ones(K))
        else:
            weights = self._check_shape(weights, (K,), 'weights')
            total = weights.sum()
            if np.any(weights < 0) or total <= 0:
                raise ValueError("weights must be non-negative with a positive sum")
            self.weights = weights / total

    @staticmethod
    def _check_shape(arr, shape: Tuple[int, ...], name: str) -> np.ndarray:
        """Convert ``arr`` to a float array and verify that it has ``shape``."""
        arr = np.asarray(arr, dtype=float)
        if arr.shape != shape:
            raise ValueError(f"{name} must have shape {shape}, got {arr.shape}")
        return arr

    def component_pdf(self, x: np.ndarray, k: int) -> np.ndarray:
        """Evaluate the probability density of component ``k``.

        The component is a product of independent 1-D generalized Gaussians,
        so its normalization constant is the product of the per-dimension
        constants and its exponent is the sum over dimensions.

        Args:
            x: Input points, shape (N, D).
            k: Component index.
        Returns:
            Density values, shape (N,).
        """
        # Per-dimension normalization constants, shape (D,).
        norm_const = self.p / (2 * self.scales[k] * gamma(1 / self.p))

        z = np.abs(x - self.centers[k]) / self.scales[k]
        exp_term = np.exp(-np.sum(z ** self.p, axis=1))

        return np.prod(norm_const) * exp_term

    def pdf(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the mixture density ``sum_k w_k * p_k(x)``.

        Args:
            x: Input points, shape (N, D).
        Returns:
            Density values, shape (N,).
        """
        density = np.zeros(len(x))
        for k in range(self.K):
            density += self.weights[k] * self.component_pdf(x, k)
        return density

    def generate_component_samples(self, n: int, k: int) -> np.ndarray:
        """Draw ``n`` samples from component ``k``.

        Uses the standard construction for a generalized Gaussian. If
        ``G ~ Gamma(1/p, 1)`` and ``S`` is an independent random sign, then
        ``S * G**(1/p)`` has density ``p / (2 Gamma(1/p)) * exp(-|z|**p)``.
        Shifting by the center and multiplying by the scale yields exactly
        the density evaluated by ``component_pdf``. Dimensions are sampled
        independently.

        Args:
            n: Number of samples.
            k: Component index.
        Returns:
            Samples, shape (n, D).
        """
        size = (n, self.D)
        magnitude = np.random.gamma(shape=1 / self.p, scale=1.0, size=size) ** (1 / self.p)
        sign = np.random.choice([-1.0, 1.0], size=size)
        return self.centers[k] + self.scales[k] * sign * magnitude

    def generate_samples(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
        """Draw samples from the mixture.

        Args:
            N: Total number of samples.
        Returns:
            X: Generated points, shape (N, D), in random order.
            y: Mixture density evaluated at each point, shape (N,).
        """
        # Number of samples assigned to each component.
        n_samples = np.random.multinomial(N, self.weights)

        samples = [self.generate_component_samples(n_k, k)
                   for k, n_k in enumerate(n_samples)]

        # Shuffle so points are not grouped by component.
        X = np.vstack(samples)
        X = X[np.random.permutation(N)]

        y = self.pdf(X)
        return X, y

    def save_dataset(self, save_dir: Union[str, Path], name: str = 'gmm_dataset',
                     N: int = 1000) -> None:
        """Generate ``N`` samples and save them with the model parameters.

        Writes ``<save_dir>/<name>.npz`` with keys ``X``, ``y``, ``centers``,
        ``scales``, ``weights``, ``D``, ``K``, ``p`` and ``seed``.

        Args:
            save_dir: Output directory; created (with parents) if missing.
            name: Dataset name, used as the file stem.
            N: Number of samples to generate (default 1000).
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        X, y = self.generate_samples(N=N)
        np.savez(str(save_path / f'{name}.npz'),
                 X=X, y=y,
                 centers=self.centers,
                 scales=self.scales,
                 weights=self.weights,
                 D=self.D,
                 K=self.K,
                 p=self.p,
                 seed=self.seed)

    @classmethod
    def load_dataset(cls, file_path: Union[str, Path]) -> "GeneralizedGaussianMixture":
        """Rebuild a generator from an ``.npz`` file written by ``save_dataset``.

        Files written before the seed was stored fall back to seed 42, which
        was the previous behavior. Construction reseeds NumPy's global RNG.

        Args:
            file_path: Path to the ``.npz`` file.
        Returns:
            The reconstructed GeneralizedGaussianMixture.
        """
        # The context manager closes the archive's file handle. The arrays are
        # read into memory inside the block, before the file is closed.
        with np.load(str(file_path)) as data:
            seed = int(data['seed']) if 'seed' in data.files else 42
            return cls(
                D=int(data['D']),
                K=int(data['K']),
                p=float(data['p']),
                centers=data['centers'],
                scales=data['scales'],
                weights=data['weights'],
                seed=seed,
            )
| |
|
def test_gmm_dataset():
    """Smoke-test GeneralizedGaussianMixture: sampling plus a save/load round trip.

    The dataset is written to a temporary directory that is removed afterwards,
    so running the test leaves no files in the working directory.
    """
    import tempfile

    gmm = GeneralizedGaussianMixture(
        D=2,
        K=3,
        p=2.0,
        centers=np.array([[-2, -2], [0, 0], [2, 2]]),
        scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]]),
        weights=np.array([0.3, 0.4, 0.3]),
    )

    # Sampling returns points of shape (N, D) and one density value per point.
    X, y = gmm.generate_samples(1000)
    assert X.shape == (1000, 2)
    assert y.shape == (1000,)
    assert np.all(np.isfinite(y)) and np.all(y >= 0)

    with tempfile.TemporaryDirectory() as tmp_dir:
        gmm.save_dataset(tmp_dir)
        # Load before the directory is deleted on leaving the block.
        loaded_gmm = GeneralizedGaussianMixture.load_dataset(
            Path(tmp_dir) / 'gmm_dataset.npz')

    # The saved parameters must round-trip unchanged.
    assert loaded_gmm.D == gmm.D
    assert loaded_gmm.K == gmm.K
    assert loaded_gmm.p == gmm.p
    assert np.allclose(gmm.centers, loaded_gmm.centers)
    assert np.allclose(gmm.scales, loaded_gmm.scales)
    assert np.allclose(gmm.weights, loaded_gmm.weights)

    print("GMM数据集测试通过!")
| |
|
if __name__ == '__main__':
    # Runs the self-test only when this file is executed as a script, not on import.
    test_gmm_dataset()