| | """ |
| | Separate test classes for each BranchSBM experiment with specific plotting styles. |
| | Each class handles testing and visualization for: LiDAR, Mouse, Clonidine, Trametinib, Veres. |
| | """ |
| |
|
| | import os |
| | import json |
| | import csv |
| | import torch |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import pytorch_lightning as pl |
| | import random |
| | import ot |
| | from torchdyn.core import NeuralODE |
| | from matplotlib.colors import LinearSegmentedColormap |
| | from matplotlib.collections import LineCollection |
| | from .networks.utils import flow_model_torch_wrapper |
| | from .branch_flow_net_train import BranchFlowNetTrainBase |
| | from .branch_growth_net_train import GrowthNetTrain |
| | from .utils import wasserstein, mix_rbf_mmd2, plot_lidar |
| | import json |
| |
|
def evaluate_model(gt_data, model_data, a, b):
    """Return the exact optimal-transport cost between two sample sets.

    The ground cost is the pairwise Euclidean distance (``torch.cdist`` with
    ``p=2``) between the rows of ``gt_data`` and the rows of ``model_data``.

    Args:
        gt_data: Ground-truth samples; a tensor or anything accepted by
            ``torch.tensor``.
        model_data: Model samples; a tensor or anything accepted by
            ``torch.tensor``.
        a: Marginal weights for ``gt_data``, forwarded unchanged to
            ``ot.emd2``.
        b: Marginal weights for ``model_data``, forwarded unchanged to
            ``ot.emd2``.

    Returns:
        The value returned by ``ot.emd2``, or ``np.nan`` when the cost matrix
        contains a NaN or infinite entry.
    """
    if not isinstance(gt_data, torch.Tensor):
        gt_data = torch.tensor(gt_data, dtype=torch.float32)
    if not isinstance(model_data, torch.Tensor):
        model_data = torch.tensor(model_data, dtype=torch.float32)

    # Prefer the model's device when it is an accelerator, otherwise use the
    # ground truth's device, so both operands of ``cdist`` live together.
    device = model_data.device if model_data.device.type != 'cpu' else gt_data.device

    gt = gt_data.to(device=device, dtype=torch.float32)
    md = model_data.to(device=device, dtype=torch.float32)

    # ``detach`` lets inputs that still carry autograd history be converted
    # to numpy without raising.
    M = torch.cdist(gt, md, p=2).detach().cpu().numpy()
    if not np.isfinite(M).all():
        return np.nan
    # POT documents ``numItermax`` as an int; ``1e7`` is a float literal.
    return ot.emd2(a, b, M, numItermax=int(1e7))
| |
|
def compute_distribution_distances(pred, true, pred_full=None, true_full=None):
    """Compute W1, W2 and a multi-bandwidth RBF MMD between two sample sets.

    Args:
        pred: Predicted samples used for both ``wasserstein`` calls.
        true: Ground-truth samples used for both ``wasserstein`` calls.
        pred_full: Optional predicted samples used for the MMD instead of
            ``pred``.
        true_full: Optional ground-truth samples used for the MMD instead of
            ``true``.

    Returns:
        A dict with keys ``"W1"``, ``"W2"`` and ``"MMD"``.
    """
    distances = {
        "W1": wasserstein(pred, true, power=1),
        "W2": wasserstein(pred, true, power=2),
    }

    mmd_a = pred if pred_full is None else pred_full
    mmd_b = true if true_full is None else true_full

    # The MMD is evaluated on equally sized sets: the larger one is randomly
    # subsampled down to the size of the smaller one.
    common = min(mmd_a.shape[0], mmd_b.shape[0])
    if mmd_a.shape[0] > common:
        mmd_a = mmd_a[torch.randperm(mmd_a.shape[0])[:common]]
    elif mmd_b.shape[0] > common:
        mmd_b = mmd_b[torch.randperm(mmd_b.shape[0])[:common]]

    bandwidths = [0.01, 0.1, 1, 10, 100]
    distances["MMD"] = mix_rbf_mmd2(mmd_a, mmd_b, sigma_list=bandwidths).item()
    return distances
| |
|
| |
|
def compute_tmv_from_mass_over_time(mass_over_time, all_endpoints, time_points=None, timepoint_data=None, time_index=None, target_time=None, gt_key_template='t1_{}', weights_over_time=None):
    """Compare predicted and ground-truth branch mass fractions.

    The returned ``tmv`` is ``0.5 * sum(|pred_frac - gt_frac|)`` over
    branches, where every mass is divided by the same initial count.

    Args:
        mass_over_time: Per-branch sequences of mean particle weights
            (indexed ``[branch][time]``), or ``None``.
        all_endpoints: Per-branch endpoint collections; only their length
            (``shape[0]`` or ``len``) is used, as the particle count.
        time_points: Optional time grid, used with ``target_time`` to pick
            the closest index.
        timepoint_data: Optional mapping of timepoint keys to arrays; the
            row count of ``'t0'`` is the normaliser and the keys produced by
            ``gt_key_template`` give the ground-truth branch masses.
        time_index: Explicit index into the per-branch histories.
        target_time: Time whose nearest entry in ``time_points`` is used when
            ``time_index`` is not given.
        gt_key_template: Format string for the per-branch ground-truth key.
        weights_over_time: Per-branch sequences of per-particle weights
            (indexed ``[branch][time]``); when usable, their sum is the
            branch mass and takes precedence over ``mass_over_time``.

    Returns:
        A dict with ``time_index``, ``pred_masses``, ``gt_masses``,
        ``pred_fracs``, ``gt_fracs`` and ``tmv``.
    """
    if time_index is None:
        if weights_over_time is None and mass_over_time is None:
            time_index = -1
        elif target_time is not None and time_points is not None:
            offsets = np.abs(np.array(time_points) - float(target_time))
            time_index = int(np.argmin(offsets))
        else:
            ref_list = weights_over_time if weights_over_time is not None else mass_over_time
            time_index = len(ref_list[0]) - 1

    n_branches = len(all_endpoints)

    n_initial = None
    if timepoint_data is not None and 't0' in timepoint_data:
        try:
            n_initial = int(timepoint_data['t0'].shape[0])
        except Exception:
            n_initial = None
    # An empty t0 would make every fraction a division by zero; treat it the
    # same as a missing t0 and fall back to the mass totals below.
    if n_initial is not None and n_initial <= 0:
        n_initial = None

    pred_masses = []
    for i in range(n_branches):
        if weights_over_time is not None:
            # Deliberately best-effort: any failure to read the per-particle
            # weights falls through to the mean-weight estimate.
            try:
                pred_masses.append(float(weights_over_time[i][time_index].sum().item()))
                continue
            except Exception:
                pass

        mean_w = 1.0
        if mass_over_time is not None:
            try:
                mean_w = float(mass_over_time[i][time_index])
            except Exception:
                mean_w = 1.0

        try:
            endpoints = all_endpoints[i]
            if hasattr(endpoints, 'shape'):
                num_particles = int(endpoints.shape[0])
            else:
                num_particles = int(len(endpoints))
        except Exception:
            num_particles = 0

        pred_masses.append(mean_w * float(num_particles))

    if timepoint_data is None:
        gt_masses = [0.0] * n_branches
    else:
        # Key without the per-branch suffix, e.g. 't1' for 't1_{}'.
        base_key = gt_key_template.split("_")[0]
        gt_masses = []
        for i in range(n_branches):
            # NOTE(review): the template is filled with the 0-based branch
            # index, whereas other code in this file looks up 1-based keys
            # such as f't1_{i+1}' - confirm which convention callers expect.
            branch_key = gt_key_template.format(i)
            if branch_key in timepoint_data:
                gt_masses.append(float(timepoint_data[branch_key].shape[0]))
            elif base_key in timepoint_data:
                gt_masses.append(float(timepoint_data[base_key].shape[0]))
            else:
                gt_masses.append(0.0)

    if n_initial is None:
        gt_total = float(sum(gt_masses))
        pred_total = float(sum(pred_masses))
        if gt_total > 0:
            n_initial = gt_total
        else:
            n_initial = pred_total if pred_total > 0 else 1.0

    pred_fracs = [m / float(n_initial) for m in pred_masses]
    gt_fracs = [m / float(n_initial) for m in gt_masses]

    tmv = 0.5 * float(np.sum(np.abs(np.array(pred_fracs) - np.array(gt_fracs))))

    return {
        'time_index': time_index,
        'pred_masses': pred_masses,
        'gt_masses': gt_masses,
        'pred_fracs': pred_fracs,
        'gt_fracs': gt_fracs,
        'tmv': tmv,
    }
| |
|
| |
|
class FlowNetTestLidar(GrowthNetTrain):
    """Test-time evaluation and figure generation for the LiDAR experiment.

    ``test_step`` pushes the full validation ``x0`` set through
    ``get_mass_and_position``, compares every branch's endpoints (and their
    union) with the ground-truth ``t1`` samples over several random
    subsamples, writes the metrics as JSON and CSV, and saves 3D trajectory
    figures under ``<working_dir>/results/<run_name>``.
    """

    @staticmethod
    def _normalize_metric_samples(metric_samples):
        """Turn the metric samples of a batch into a flat list."""
        if isinstance(metric_samples, dict):
            return list(metric_samples.values())
        if isinstance(metric_samples, (list, tuple)):
            return [
                m[0] if isinstance(m, (list, tuple)) and len(m) == 1 else m
                for m in metric_samples
            ]
        return [metric_samples]

    @classmethod
    def _split_batch(cls, batch):
        """Split a test batch into ``(main_batch, metric_batch)``."""
        if isinstance(batch, (list, tuple)) and len(batch) == 1:
            batch = batch[0]

        if isinstance(batch, dict) and "test_samples" in batch:
            test_samples = batch["test_samples"]
            metric_samples = batch["metric_samples"]

            # A trailing int marks a (payload, ..., int) wrapper; keep the payload.
            if isinstance(test_samples, (list, tuple)) and len(test_samples) >= 2 and isinstance(test_samples[-1], int):
                test_samples = test_samples[0]
            if isinstance(metric_samples, (list, tuple)) and len(metric_samples) >= 2 and isinstance(metric_samples[-1], int):
                metric_samples = metric_samples[0]

            if isinstance(test_samples, (list, tuple)) and len(test_samples) == 1:
                test_samples = test_samples[0]
            return test_samples, cls._normalize_metric_samples(metric_samples)

        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            test_samples, metric_samples = batch[0], batch[1]
            if isinstance(test_samples, (list, tuple)):
                test_samples = test_samples[0]
            return test_samples, cls._normalize_metric_samples(metric_samples)

        return batch, []

    @staticmethod
    def _trial_metrics(pred, true, n_trials):
        """Collect W1/W2/MMD over ``n_trials`` size-matched random subsamples.

        W1 and W2 use only the first two coordinates; the MMD uses all of
        them (passed as ``pred_full`` / ``true_full``).
        """
        w1_vals, w2_vals, mmd_vals = [], [], []
        n_min = min(pred.shape[0], true.shape[0])
        for _ in range(n_trials):
            perm_pred = torch.randperm(pred.shape[0])[:n_min]
            perm_gt = torch.randperm(true.shape[0])[:n_min]
            m = compute_distribution_distances(
                pred[perm_pred, :2], true[perm_gt, :2],
                pred_full=pred[perm_pred], true_full=true[perm_gt],
            )
            w1_vals.append(m["W1"])
            w2_vals.append(m["W2"])
            mmd_vals.append(m["MMD"])
        return w1_vals, w2_vals, mmd_vals

    @staticmethod
    def _summarize(w1_vals, w2_vals, mmd_vals):
        """Return mean and sample standard deviation (``ddof=1``) per metric."""
        return {
            "W1_mean": float(np.mean(w1_vals)), "W1_std": float(np.std(w1_vals, ddof=1)),
            "W2_mean": float(np.mean(w2_vals)), "W2_std": float(np.std(w2_vals, ddof=1)),
            "MMD_mean": float(np.mean(mmd_vals)), "MMD_std": float(np.std(mmd_vals, ddof=1)),
        }

    @staticmethod
    def _new_lidar_axes():
        """Create a figure with one 3D axes at the fixed LiDAR viewing angle."""
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        return fig, ax

    def test_step(self, batch, batch_idx):
        """Evaluate all branches on the test batch and save metrics and figures.

        Args:
            batch: Test batch; dict, list or tuple layouts are accepted (see
                ``_split_batch``).
            batch_idx: Index of the batch (unused).
        """
        main_batch, metric_batch = self._split_batch(batch)
        # The sample dict may arrive wrapped in a sequence. Unwrap it once so
        # that both the "x0" and the "dataset" lookups below index the dict
        # (indexing the wrapper with "dataset" would raise TypeError).
        if not isinstance(main_batch, dict):
            main_batch = main_batch[0]

        timepoint_data = self.trainer.datamodule.get_timepoint_data()
        device = main_batch["x0"][0].device

        # Evaluate on the whole validation x0 set with unit initial weights.
        x0_all = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
        w0_all = torch.ones(x0_all.shape[0], 1, dtype=torch.float32).to(device)
        full_batch = {"x0": (x0_all, w0_all)}

        # get_mass_and_position also returns time points and the mass, energy
        # and weight histories; only endpoints and trajectories are used here.
        _, all_endpoints, all_trajs, *_ = self.get_mass_and_position(full_batch, metric_batch)

        cloud_points = main_batch["dataset"][0]

        n_trials = 5

        # Ground truth per branch: 't1_<k>' when present, otherwise the shared 't1'.
        gt_per_branch = []
        for i, endpoints in enumerate(all_endpoints):
            true_data_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            gt_per_branch.append(
                torch.tensor(timepoint_data[true_data_key], dtype=torch.float32).to(endpoints.device)
            )

        metrics_dict = {}
        for i, (endpoints, true_data) in enumerate(zip(all_endpoints, gt_per_branch)):
            stats = self._summarize(*self._trial_metrics(endpoints, true_data, n_trials))
            metrics_dict[f"branch_{i+1}"] = stats
            self.log(f"test/W1_branch{i+1}", stats["W1_mean"], on_epoch=True)
            print(f"Branch {i+1} — W1: {stats['W1_mean']:.6f}±{stats['W1_std']:.6f}, "
                  f"W2: {stats['W2_mean']:.6f}±{stats['W2_std']:.6f}, "
                  f"MMD: {stats['MMD_mean']:.6f}±{stats['MMD_std']:.6f}")

        # Union of all branches against the union of their ground truths.
        all_pred_combined = torch.cat(list(all_endpoints), dim=0)
        all_true_combined = torch.cat(
            [gt.to(all_pred_combined.device) for gt in gt_per_branch], dim=0
        )
        combined = self._summarize(
            *self._trial_metrics(all_pred_combined, all_true_combined, n_trials)
        )
        self.log("test/W1_combined", combined["W1_mean"], on_epoch=True)
        self.log("test/W2_combined", combined["W2_mean"], on_epoch=True)
        self.log("test/MMD_combined", combined["MMD_mean"], on_epoch=True)

        metrics_dict["combined"] = {**combined, "n_trials": n_trials}
        print(f"\n=== Combined ===")
        print(f"W1: {combined['W1_mean']:.6f} ± {combined['W1_std']:.6f}")
        print(f"W2: {combined['W2_mean']:.6f} ± {combined['W2_std']:.6f}")
        print(f"MMD: {combined['MMD_mean']:.6f} ± {combined['MMD_std']:.6f}")

        # Undo the datamodule's scaling so the point cloud is plotted in the
        # original coordinates.
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        figures_dir = f'{results_dir}/figures'
        os.makedirs(figures_dir, exist_ok=True)

        metrics_path = f'{results_dir}/metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        print(f"Metrics saved to {metrics_path}")

        detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
        with open(detailed_csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
            for key in sorted(metrics_dict):
                m = metrics_dict[key]
                writer.writerow([key,
                                 f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}',
                                 f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}',
                                 f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}'])
        print(f"Detailed metrics CSV saved to {detailed_csv_path}")

        # Each branch's trajectory is a list of per-time tensors; stack them
        # along a new dim 1.
        stacked_trajs = [torch.stack(traj_list, dim=1) for traj_list in all_trajs]

        if self.whiten:
            scaler = self.trainer.datamodule.scaler
            restored = []
            for traj in stacked_trajs:
                B, T, D = traj.shape
                # The scaler works on 2D arrays: flatten, invert, restore the shape.
                flat = traj.reshape(-1, D).cpu().detach().numpy()
                restored.append(torch.tensor(scaler.inverse_transform(flat)).reshape(B, T, D))
            stacked_trajs = restored

        # One figure with every branch overlaid.
        fig, ax = self._new_lidar_axes()
        for i, traj in enumerate(stacked_trajs):
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
        fig.savefig(f'{figures_dir}/{self.args.data_name}_all_branches.png', dpi=300)
        plt.close(fig)

        # One figure per branch.
        for i, traj in enumerate(stacked_trajs):
            fig, ax = self._new_lidar_axes()
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
            fig.savefig(f'{figures_dir}/{self.args.data_name}_branch_{i + 1}.png', dpi=300)
            plt.close(fig)

        print(f"LiDAR figures saved to {figures_dir}")
| |
|
| |
|
class FlowNetTestMouse(GrowthNetTrain):
    """Test-time evaluation and figure generation for the Mouse experiment.

    ``test_step`` integrates every branch flow net from the validation
    ``x0`` set with an Euler ``NeuralODE``, compares the predictions with the
    ``t1`` samples (integration to t=0.5) and the ``t2`` samples (integration
    to t=1.0), writes the metrics as JSON and CSV, and saves 2D trajectory
    figures under ``<working_dir>/results/<run_name>``.
    """

    # End-of-branch colours; cycled when there are more branches than entries.
    _BRANCH_COLORS = ('#B83CFF', '#50B2D7')
    _START_COLOR = '#05009E'

    @staticmethod
    def _extract_main_batch(batch):
        """Pull the sample dict out of the dict/list/tuple batch layouts."""
        container = batch
        if isinstance(batch, (list, tuple)) and len(batch) >= 1:
            if not isinstance(batch[0], dict):
                return batch[0][0]
            container = batch[0]
        if isinstance(container, dict):
            main_batch = container.get("test_samples", container)
            return main_batch[0] if isinstance(main_batch, tuple) else main_batch
        return batch

    @staticmethod
    def _integrate(flow_net, x0, t_span, device):
        """Integrate ``flow_net`` from ``x0`` over ``t_span`` without gradients.

        Returns whatever ``NeuralODE.trajectory`` returns (indexed by time
        step first), still on ``device``.
        """
        node = NeuralODE(
            flow_model_torch_wrapper(flow_net),
            solver="euler",
            sensitivity="adjoint",
        ).to(device)
        with torch.no_grad():
            return node.trajectory(x0, t_span)

    @staticmethod
    def _trial_metrics(pred, true, n_trials):
        """Collect W1/W2/MMD on the first two coordinates over random subsamples."""
        w1_vals, w2_vals, mmd_vals = [], [], []
        n_min = min(pred.shape[0], true.shape[0])
        for _ in range(n_trials):
            perm_pred = torch.randperm(pred.shape[0])[:n_min]
            perm_gt = torch.randperm(true.shape[0])[:n_min]
            m = compute_distribution_distances(pred[perm_pred, :2], true[perm_gt, :2])
            w1_vals.append(m["W1"])
            w2_vals.append(m["W2"])
            mmd_vals.append(m["MMD"])
        return w1_vals, w2_vals, mmd_vals

    @staticmethod
    def _summarize(w1_vals, w2_vals, mmd_vals):
        """Return mean and sample standard deviation (``ddof=1``) per metric."""
        return {
            "W1_mean": float(np.mean(w1_vals)), "W1_std": float(np.std(w1_vals, ddof=1)),
            "W2_mean": float(np.mean(w2_vals)), "W2_std": float(np.std(w2_vals, ddof=1)),
            "MMD_mean": float(np.mean(mmd_vals)), "MMD_std": float(np.std(mmd_vals, ddof=1)),
        }

    @staticmethod
    def _trajs_to_numpy(all_trajs):
        """Convert trajectories (tensor, list of tensors or array) to numpy."""
        converted = []
        for traj in all_trajs:
            if isinstance(traj, list):
                traj = torch.stack(traj, dim=1)
            if torch.is_tensor(traj):
                traj = traj.cpu().detach().numpy()
            converted.append(traj)
        return converted

    @staticmethod
    def _timepoints_to_numpy(timepoint_data):
        """Return a copy of ``timepoint_data`` with tensors converted to numpy.

        A new dict is built so the caller's mapping is left untouched.
        """
        return {
            key: value.detach().cpu().numpy() if torch.is_tensor(value) else value
            for key, value in timepoint_data.items()
        }

    @staticmethod
    def _add_trajectory_lines(ax, traj_2d, cmap):
        """Draw every trajectory as line segments coloured by time step."""
        n_traj, n_time = traj_2d.shape[0], traj_2d.shape[1]
        color_vals = cmap(np.linspace(0, 1, n_time))
        # Consecutive point pairs, grouped trajectory by trajectory.
        segments = np.stack([traj_2d[:, :-1], traj_2d[:, 1:]], axis=2).reshape(-1, 2, 2)
        seg_colors = np.tile(color_vals[:-1], (n_traj, 1))
        ax.add_collection(LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8))

    def test_step(self, batch, batch_idx):
        """Evaluate all branches at t1 and t2 and save metrics and figures.

        Args:
            batch: Test batch; dict, list or tuple layouts are accepted (see
                ``_extract_main_batch``).
            batch_idx: Index of the batch (unused).
        """
        main_batch = self._extract_main_batch(batch)
        device = main_batch["x0"][0].device

        x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
        timepoint_data = self.trainer.datamodule.get_timepoint_data()
        data_t1 = torch.tensor(timepoint_data['t1'], dtype=torch.float32)

        custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", ["#05009E", "#A19EFF", "#B83CFF"])
        custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", ["#05009E", "#A19EFF", "#50B2D7"])

        n_trials = 5
        t_span_full = torch.linspace(0, 1.0, 100).to(device)
        t_span_metric_t1 = torch.linspace(0, 0.5, 50).to(device)

        # The Euler integration from a fixed x0 does not depend on the trial,
        # so it is done once per branch and time span; only the random
        # subsampling below differs between trials.
        # NOTE(review): assumes the flow nets are deterministic at test time
        # (e.g. no active dropout) - confirm.
        full_trajs = [
            self._integrate(flow_net, x0, t_span_full, device).cpu()
            for flow_net in self.flow_nets
        ]
        all_trajs = [torch.transpose(traj, 0, 1) for traj in full_trajs]
        t2_endpoints = [traj[-1] for traj in full_trajs]
        # t1 uses its own grid (50 steps up to 0.5), so it is integrated
        # separately rather than read off the full trajectory.
        t1_endpoints = [
            self._integrate(flow_net, x0, t_span_metric_t1, device)[-1].cpu()
            for flow_net in self.flow_nets
        ]

        # Ground truth at t2 per branch: 't2_<k>' when present; the first
        # branch may fall back to a shared 't2'.
        data_t2_branches = []
        for i in range(len(self.flow_nets)):
            key = f't2_{i+1}'
            if key in timepoint_data:
                data_t2_branches.append(torch.tensor(timepoint_data[key], dtype=torch.float32))
            elif i == 0 and 't2' in timepoint_data:
                data_t2_branches.append(torch.tensor(timepoint_data['t2'], dtype=torch.float32))
            else:
                data_t2_branches.append(None)

        data_t2_all_list = [d for d in data_t2_branches if d is not None]
        data_t2_combined = torch.cat(data_t2_all_list, dim=0) if data_t2_all_list else None

        # Combined metrics at t1: all branch predictions against a random
        # subset of the t1 samples no larger than the prediction set.
        preds_t1 = torch.cat(t1_endpoints, dim=0)
        w1_t1_trials, w2_t1_trials, mmd_t1_trials = [], [], []
        for _ in range(n_trials):
            perm = torch.randperm(data_t1.shape[0])[:preds_t1.shape[0]]
            metrics = compute_distribution_distances(preds_t1[:, :2], data_t1[perm][:, :2])
            w1_t1_trials.append(metrics["W1"])
            w2_t1_trials.append(metrics["W2"])
            mmd_t1_trials.append(metrics["MMD"])

        # Per-branch metrics at t2.
        branch_t2_metrics = {}
        for i, gt in enumerate(data_t2_branches):
            if gt is None:
                continue
            stats = self._summarize(*self._trial_metrics(t2_endpoints[i], gt, n_trials))
            branch_t2_metrics[f"branch_{i+1}_t2"] = stats
            print(f"Branch {i+1} @ t2 — W1: {stats['W1_mean']:.6f}±{stats['W1_std']:.6f}, "
                  f"W2: {stats['W2_mean']:.6f}±{stats['W2_std']:.6f}, "
                  f"MMD: {stats['MMD_mean']:.6f}±{stats['MMD_std']:.6f}")

        # Combined metrics at t2 (only when some t2 ground truth exists).
        t2_stats = None
        if data_t2_combined is not None:
            preds_t2 = torch.cat(t2_endpoints, dim=0)
            t2_stats = self._summarize(*self._trial_metrics(preds_t2, data_t2_combined, n_trials))

        t1_stats = self._summarize(w1_t1_trials, w2_t1_trials, mmd_t1_trials)
        self.log("test/W1_combined_t1", t1_stats["W1_mean"], on_epoch=True)
        self.log("test/W2_combined_t1", t1_stats["W2_mean"], on_epoch=True)
        self.log("test/MMD_combined_t1", t1_stats["MMD_mean"], on_epoch=True)

        metrics_dict = {"combined_t1": {**t1_stats, "n_trials": n_trials}}
        metrics_dict.update(branch_t2_metrics)

        if t2_stats is not None:
            self.log("test/W1_combined_t2", t2_stats["W1_mean"], on_epoch=True)
            self.log("test/W2_combined_t2", t2_stats["W2_mean"], on_epoch=True)
            self.log("test/MMD_combined_t2", t2_stats["MMD_mean"], on_epoch=True)
            metrics_dict["combined_t2"] = {**t2_stats, "n_trials": n_trials}

        print(f"\n=== Combined @ t1 ===")
        print(f"W1: {t1_stats['W1_mean']:.6f} ± {t1_stats['W1_std']:.6f}")
        print(f"W2: {t1_stats['W2_mean']:.6f} ± {t1_stats['W2_std']:.6f}")
        print(f"MMD: {t1_stats['MMD_mean']:.6f} ± {t1_stats['MMD_std']:.6f}")
        if t2_stats is not None:
            print(f"\n=== Combined @ t2 ===")
            print(f"W1: {t2_stats['W1_mean']:.6f} ± {t2_stats['W1_std']:.6f}")
            print(f"W2: {t2_stats['W2_mean']:.6f} ± {t2_stats['W2_std']:.6f}")
            print(f"MMD: {t2_stats['MMD_mean']:.6f} ± {t2_stats['MMD_std']:.6f}")

        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        figures_dir = f'{results_dir}/figures'
        os.makedirs(figures_dir, exist_ok=True)

        metrics_path = f'{results_dir}/metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        print(f"Metrics saved to {metrics_path}")

        detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
        with open(detailed_csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
            for key in sorted(metrics_dict):
                m = metrics_dict[key]
                writer.writerow([key,
                                 f'{m.get("W1_mean", 0):.6f}', f'{m.get("W1_std", 0):.6f}',
                                 f'{m.get("W2_mean", 0):.6f}', f'{m.get("W2_std", 0):.6f}',
                                 f'{m.get("MMD_mean", 0):.6f}', f'{m.get("MMD_std", 0):.6f}'])
        print(f"Detailed metrics CSV saved to {detailed_csv_path}")

        self._plot_mouse_branches(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2)
        self._plot_mouse_combined(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2)

        print(f"Mouse figures saved to {figures_dir}")

    def _plot_mouse_branches(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
        """Plot each branch separately with timepoint background.

        Args:
            all_trajs: One trajectory per branch; the first two coordinates of
                the last axis are plotted, axis 1 is the time step.
            timepoint_data: Mapping of timepoint keys to sample arrays or
                tensors; it is not modified.
            save_dir: Directory that receives ``<data_name>_branch<k>.png``.
            cmap1: Colormap for the line segments of even-indexed branches.
            cmap2: Colormap for the line segments of odd-indexed branches.
        """
        branch_cmaps = (cmap1, cmap2)
        trajs_np = self._trajs_to_numpy(all_trajs)
        tp_data = self._timepoints_to_numpy(timepoint_data)

        # Shared axis limits over every timepoint and trajectory, so that all
        # branch figures use the same frame.
        all_coords = [
            tp_data[key][:, :2]
            for key in ('t0', 't1', 't2', 't2_1', 't2_2')
            if key in tp_data
        ]
        all_coords.extend(traj.reshape(-1, traj.shape[-1])[:, :2] for traj in trajs_np)
        all_coords = np.concatenate(all_coords, axis=0)
        x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
        y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()

        # Pad the limits by 5% of the data range on every side.
        x_margin = 0.05 * (x_max - x_min)
        y_margin = 0.05 * (y_max - y_min)
        x_min, x_max = x_min - x_margin, x_max + x_margin
        y_min, y_max = y_min - y_margin, y_max + y_margin

        for i, traj in enumerate(trajs_np):
            fig, ax = plt.subplots(figsize=(10, 8))
            # Cycle so that more than two branches do not raise IndexError.
            cmap = branch_cmaps[i % len(branch_cmaps)]
            c_end = self._BRANCH_COLORS[i % len(self._BRANCH_COLORS)]

            t2_key = f't2_{i+1}' if f't2_{i+1}' in tp_data else 't2'
            coords_list = [tp_data['t0'], tp_data['t1'], tp_data[t2_key]]
            tp_colors = [self._START_COLOR, '#A19EFF', c_end]
            tp_labels = ["t=0", "t=1", f"t=2 (branch {i+1})"]

            for coords, color, label in zip(coords_list, tp_colors, tp_labels):
                alpha = 0.8 if color == self._START_COLOR else 0.6
                ax.scatter(coords[:, 0], coords[:, 1],
                           c=color, s=80, alpha=alpha, marker='x',
                           label=f'{label} cells', linewidth=1.5)

            traj_2d = traj[:, :, :2]
            self._add_trajectory_lines(ax, traj_2d, cmap)

            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c=self._START_COLOR, s=30, marker='o', label='Trajectory Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o', label='Trajectory End',
                       zorder=5, edgecolors='white', linewidth=1)

            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_xlabel("PC1", fontsize=12)
            ax.set_ylabel("PC2", fontsize=12)
            ax.set_title(f"Branch {i+1}: Trajectories with Timepoint Background", fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.legend(loc='upper right', fontsize=12, frameon=False)

            fig.tight_layout()
            fig.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
            plt.close(fig)

    def _plot_mouse_combined(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
        """Plot all branches together.

        Args:
            all_trajs: One trajectory per branch; the first two coordinates of
                the last axis are plotted, axis 1 is the time step.
            timepoint_data: Mapping of timepoint keys to sample arrays or
                tensors; it is not modified.
            save_dir: Directory that receives ``<data_name>_combined.png``.
            cmap1: Colormap for the line segments of even-indexed branches.
            cmap2: Colormap for the line segments of odd-indexed branches.
        """
        branch_cmaps = (cmap1, cmap2)

        if 't2_1' in timepoint_data:
            tp_keys = ['t0', 't1', 't2_1', 't2_2']
            tp_colors = ['#05009E', '#A19EFF', '#B83CFF', '#50B2D7']
            tp_labels = ['t=0', 't=1', 't=2 (branch 1)', 't=2 (branch 2)']
        else:
            tp_keys = ['t0', 't1', 't2']
            tp_colors = ['#05009E', '#A19EFF', '#B83CFF']
            tp_labels = ['t=0', 't=1', 't=2']

        trajs_np = self._trajs_to_numpy(all_trajs)
        tp_data = self._timepoints_to_numpy(timepoint_data)

        fig, ax = plt.subplots(figsize=(12, 10))

        for t_key, color, label in zip(tp_keys, tp_colors, tp_labels):
            if t_key in tp_data:
                coords = tp_data[t_key]
                ax.scatter(coords[:, 0], coords[:, 1],
                           c=color, s=80, alpha=0.4, marker='x',
                           label=f'{label} cells', linewidth=1.5)

        for i, traj in enumerate(trajs_np):
            traj_2d = traj[:, :, :2]
            # Cycle so that more than two branches do not raise IndexError.
            c_end = self._BRANCH_COLORS[i % len(self._BRANCH_COLORS)]
            self._add_trajectory_lines(ax, traj_2d, branch_cmaps[i % len(branch_cmaps)])

            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c=self._START_COLOR, s=30, marker='o',
                       label=f'Branch {i+1} Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o',
                       label=f'Branch {i+1} End',
                       zorder=5, edgecolors='white', linewidth=1)

        ax.set_xlabel("PC1", fontsize=14)
        ax.set_ylabel("PC2", fontsize=14)
        ax.set_title("All Branch Trajectories with Timepoint Background",
                     fontsize=16, weight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right', fontsize=12, frameon=False)

        fig.tight_layout()
        fig.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
        plt.close(fig)
| |
|
| |
|
| | class FlowNetTestClonidine(BranchFlowNetTrainBase): |
| | """Test class for Clonidine perturbation experiment (1 or 2 branches).""" |
| | |
| | def test_step(self, batch, batch_idx): |
| | |
| | if isinstance(batch, dict) and "test_samples" in batch: |
| | |
| | main_batch = batch["test_samples"] |
| | elif isinstance(batch, (list, tuple)) and len(batch) >= 1: |
| | |
| | test_samples = batch[0] |
| | if isinstance(test_samples, dict) and "test_samples" in test_samples: |
| | main_batch = test_samples["test_samples"][0] |
| | else: |
| | main_batch = test_samples |
| | else: |
| | |
| | main_batch = batch |
| | |
| | |
| | timepoint_data = self.trainer.datamodule.get_timepoint_data() |
| | device = main_batch["x0"][0].device |
| | |
| | |
| | x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device) |
| | t_span = torch.linspace(0, 1, 100).to(device) |
| | |
| | |
| | custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"] |
| | custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"] |
| | custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1) |
| | custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2) |
| | |
| | all_trajs = [] |
| | all_endpoints = [] |
| |
|
| | for i, flow_net in enumerate(self.flow_nets): |
| | node = NeuralODE( |
| | flow_model_torch_wrapper(flow_net), |
| | solver="euler", |
| | sensitivity="adjoint", |
| | ) |
| |
|
| | with torch.no_grad(): |
| | traj = node.trajectory(x0, t_span).cpu() |
| |
|
| | traj = torch.transpose(traj, 0, 1) |
| | all_trajs.append(traj) |
| | all_endpoints.append(traj[:, -1, :]) |
| |
|
| | |
| | n_trials = 5 |
| | n_branches = len(self.flow_nets) |
| |
|
| | |
| | gt_data_per_branch = [] |
| | for i in range(n_branches): |
| | if n_branches == 1: |
| | key = 't1' |
| | else: |
| | key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1' |
| | gt_data_per_branch.append(torch.tensor(timepoint_data[key], dtype=torch.float32)) |
| | gt_all = torch.cat(gt_data_per_branch, dim=0) |
| |
|
| | |
| | metrics_dict = {} |
| | for i in range(n_branches): |
| | w1_br, w2_br, mmd_br = [], [], [] |
| | pred = all_endpoints[i] |
| | gt = gt_data_per_branch[i] |
| | for trial in range(n_trials): |
| | n_min = min(pred.shape[0], gt.shape[0]) |
| | perm_pred = torch.randperm(pred.shape[0])[:n_min] |
| | perm_gt = torch.randperm(gt.shape[0])[:n_min] |
| | m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2]) |
| | w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"]) |
| | metrics_dict[f"branch_{i+1}"] = { |
| | "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)), |
| | "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)), |
| | "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)), |
| | } |
| | self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True) |
| | print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, " |
| | f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, " |
| | f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}") |
| |
|
| | |
| | pred_all = torch.cat(all_endpoints, dim=0) |
| | w1_trials, w2_trials, mmd_trials = [], [], [] |
| | for trial in range(n_trials): |
| | n_min = min(pred_all.shape[0], gt_all.shape[0]) |
| | perm_pred = torch.randperm(pred_all.shape[0])[:n_min] |
| | perm_gt = torch.randperm(gt_all.shape[0])[:n_min] |
| | m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2]) |
| | w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"]) |
| |
|
| | w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1) |
| | w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1) |
| | mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1) |
| | self.log("test/W1_t1_combined", w1_mean, on_epoch=True) |
| | self.log("test/W2_t1_combined", w2_mean, on_epoch=True) |
| | self.log("test/MMD_t1_combined", mmd_mean, on_epoch=True) |
| | metrics_dict['t1_combined'] = { |
| | "W1_mean": float(w1_mean), "W1_std": float(w1_std), |
| | "W2_mean": float(w2_mean), "W2_std": float(w2_std), |
| | "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std), |
| | "n_trials": n_trials, |
| | } |
| | print(f"\n=== Combined @ t1 ===") |
| | print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}") |
| | print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}") |
| | print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}") |
| |
|
| | |
| | run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name |
| | results_dir = os.path.join(self.args.working_dir, 'results', run_name) |
| | figures_dir = f'{results_dir}/figures' |
| | os.makedirs(figures_dir, exist_ok=True) |
| | |
| | |
| | metrics_path = f'{results_dir}/metrics.json' |
| | with open(metrics_path, 'w') as f: |
| | json.dump(metrics_dict, f, indent=2) |
| | print(f"Metrics saved to {metrics_path}") |
| |
|
| | |
| | detailed_csv_path = f'{results_dir}/metrics_detailed.csv' |
| | with open(detailed_csv_path, 'w', newline='') as csvfile: |
| | writer = csv.writer(csvfile) |
| | writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std']) |
| | for key in sorted(metrics_dict.keys()): |
| | m = metrics_dict[key] |
| | writer.writerow([key, |
| | f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}', |
| | f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}', |
| | f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}']) |
| | print(f"Detailed metrics CSV saved to {detailed_csv_path}") |
| |
|
| | |
| | self._plot_clonidine_branches(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2) |
| | self._plot_clonidine_combined(all_trajs, timepoint_data, figures_dir) |
| |
|
| | print(f"Clonidine figures saved to {figures_dir}") |
| |
|
def _plot_clonidine_branches(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
    """Plot each branch separately.

    Saves one PNG per branch showing the first two coordinates of that
    branch's trajectories drawn over the t=0 cells and the branch's t=1
    cells.  Every figure uses the same axis limits: the bounding box of the
    plotted timepoint data and all trajectories, padded by 5% per side.

    Args:
        all_trajs: One trajectory array per branch, indexed
            ``[sample, time, feature]`` and convertible with ``np.asarray``.
        timepoint_data: Mapping holding ``'t0'`` and, for each branch, either
            ``'t1_<i>'`` (1-based) or the shared ``'t1'``.
        save_dir: Directory receiving ``<data_name>_branch<i>.png``.
        cmap1: Colormap colouring branch-1 trajectories along time.
        cmap2: Colormap colouring branch-2 trajectories along time.
    """
    start_color = '#05009E'
    branch_colors = ['#B83CFF', '#50B2D7']
    cmaps = [cmap1, cmap2]

    # Resolve each branch's t=1 key once so the shared axis limits and the
    # per-figure scatter always use the same data (per-branch key when
    # present, otherwise the shared 't1').
    t1_keys = [
        f't1_{i + 1}' if f't1_{i + 1}' in timepoint_data else 't1'
        for i in range(len(all_trajs))
    ]
    trajs_2d = [np.asarray(traj)[:, :, :2] for traj in all_trajs]

    # dict.fromkeys drops duplicate keys (several branches may map to 't1')
    # while keeping their order.
    all_coords = [np.asarray(timepoint_data[key])[:, :2]
                  for key in dict.fromkeys(['t0'] + t1_keys)]
    all_coords.extend(xy.reshape(-1, xy.shape[-1]) for xy in trajs_2d)
    all_coords = np.concatenate(all_coords, axis=0)

    x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
    y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()
    x_margin = 0.05 * (x_max - x_min)
    y_margin = 0.05 * (y_max - y_min)
    x_min, x_max = x_min - x_margin, x_max + x_margin
    y_min, y_max = y_min - y_margin, y_max + y_margin

    for i, (xy, t1_key) in enumerate(zip(trajs_2d, t1_keys)):
        # Cycle through the palette so extra branches reuse colours instead
        # of raising IndexError.
        c_end = branch_colors[i % len(branch_colors)]
        cmap = cmaps[i % len(cmaps)]
        t1_label = f"t=1 (branch {i+1})" if len(all_trajs) > 1 else "t=1"

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            # Background: observed cells at t=0 and at this branch's t=1.
            background = (
                (timepoint_data['t0'], start_color, "t=0"),
                (timepoint_data[t1_key], c_end, t1_label),
            )
            for coords, color, label in background:
                ax.scatter(coords[:, 0], coords[:, 1],
                           c=color, s=80, alpha=0.4, marker='x',
                           label=f'{label} cells', linewidth=1.5)

            # One line segment per consecutive time pair of every sample,
            # ordered sample-major; each sample gets the same colour ramp.
            n_samples, n_time = xy.shape[:2]
            segments = np.stack([xy[:, :-1], xy[:, 1:]], axis=2).reshape(-1, 2, 2)
            seg_colors = np.tile(cmap(np.linspace(0, 1, n_time))[:-1], (n_samples, 1))
            ax.add_collection(
                LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8))

            ax.scatter(xy[:, 0, 0], xy[:, 0, 1],
                       c=start_color, s=30, marker='o', label='Trajectory Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(xy[:, -1, 0], xy[:, -1, 1],
                       c=c_end, s=30, marker='o', label='Trajectory End',
                       zorder=5, edgecolors='white', linewidth=1)

            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_xlabel("PC1", fontsize=12)
            ax.set_ylabel("PC2", fontsize=12)
            ax.set_title(f"Branch {i+1}: Trajectories with Timepoint Background", fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.legend(loc='upper right', fontsize=16, frameon=False)

            fig.tight_layout()
            fig.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
        finally:
            # Close this specific figure even if drawing or saving failed.
            plt.close(fig)
| |
|
def _plot_clonidine_combined(self, all_trajs, timepoint_data, save_dir):
    """Plot all branches together.

    Draws the first two coordinates of every branch's trajectories in one
    figure, over the observed t=0 / t=1 cells, and saves it as
    ``<data_name>_combined.png`` in ``save_dir``.

    Args:
        all_trajs: One trajectory array per branch, indexed
            ``[sample, time, feature]`` and convertible with ``np.asarray``.
        timepoint_data: Mapping holding ``'t0'`` and either per-branch keys
            ``'t1_1'``, ``'t1_2'``, ... or a single ``'t1'``.
        save_dir: Output directory for the PNG.
    """
    start_color, mid_color = '#05009E', '#A19EFF'
    branch_colors = ['#B83CFF', '#50B2D7']
    n_branches = len(all_trajs)

    def end_color(index):
        # Cycle through the palette so extra branches reuse colours instead
        # of raising IndexError or being silently dropped.
        return branch_colors[index % len(branch_colors)]

    if 't1_1' in timepoint_data:
        background = [('t0', start_color, 't=0')] + [
            (f't1_{j+1}', end_color(j), f't=1 (branch {j+1})')
            for j in range(n_branches)
        ]
    else:
        background = [('t0', start_color, 't=0'), ('t1', end_color(0), 't=1')]

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        # Background: observed cells per timepoint.
        for t_key, color, label in background:
            coords = timepoint_data[t_key]
            ax.scatter(coords[:, 0], coords[:, 1],
                       c=color, s=80, alpha=0.4, marker='x',
                       label=f'{label} cells', linewidth=1.5)

        for i, traj in enumerate(all_trajs):
            xy = np.asarray(traj)[:, :, :2]
            c_end = end_color(i)
            # Colour ramp from the shared start colour to this branch's end colour.
            cmap = LinearSegmentedColormap.from_list(
                f"clon_cmap{i+1}", [start_color, mid_color, c_end])

            # One line segment per consecutive time pair of every sample,
            # ordered sample-major; each sample gets the same colour ramp.
            n_samples, n_time = xy.shape[:2]
            segments = np.stack([xy[:, :-1], xy[:, 1:]], axis=2).reshape(-1, 2, 2)
            seg_colors = np.tile(cmap(np.linspace(0, 1, n_time))[:-1], (n_samples, 1))
            ax.add_collection(
                LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8))

            ax.scatter(xy[:, 0, 0], xy[:, 0, 1],
                       c=start_color, s=30, marker='o',
                       label=f'Branch {i+1} Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(xy[:, -1, 0], xy[:, -1, 1],
                       c=c_end, s=30, marker='o',
                       label=f'Branch {i+1} End',
                       zorder=5, edgecolors='white', linewidth=1)

        ax.set_xlabel("PC1", fontsize=14)
        ax.set_ylabel("PC2", fontsize=14)
        ax.set_title("All Branch Trajectories with Timepoint Background",
                     fontsize=16, weight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right', fontsize=12, frameon=False)

        fig.tight_layout()
        fig.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
    finally:
        # Close this specific figure even if drawing or saving failed.
        plt.close(fig)
| |
|
| |
|
class FlowNetTestTrametinib(BranchFlowNetTrainBase):
    """Test class for Trametinib perturbation experiment (1 or 3 branches).

    ``test_step`` integrates every branch flow network from the validation
    ``x0`` samples, scores the t=1 endpoints against the timepoint data
    (W1 / W2 / MMD on the first two coordinates, averaged over random
    equal-size subsamples), writes the metrics to JSON and CSV, and saves
    the per-branch and combined trajectory figures.
    """

    # Single source of truth for the figure palette.  Every branch colormap
    # runs start -> mid -> that branch's end colour, so the trajectory ramp,
    # the endpoint markers and the t=1 background always agree.
    _TRAM_START_COLOR = '#05009E'
    _TRAM_MID_COLOR = '#A19EFF'
    _TRAM_BRANCH_COLORS = ('#9793F8', '#50B2D7', '#B83CFF')
    # Number of random subsamples averaged for every reported metric.
    _TRAM_N_TRIALS = 5

    def test_step(self, batch, batch_idx):
        """Evaluate all branches at t=1, persist the metrics and draw the figures.

        Args:
            batch: Test batch; a dict with ``"test_samples"``, a list/tuple
                whose first item is such a dict or the batch itself, or the
                batch itself.
            batch_idx: Unused; required by the Lightning hook signature.
        """
        main_batch = self._trametinib_main_batch(batch)

        timepoint_data = self.trainer.datamodule.get_timepoint_data()
        device = main_batch["x0"][0].device

        # Trajectories are always started from the full validation x0 set.
        x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
        t_span = torch.linspace(0, 1, 100).to(device)

        all_trajs = []
        for flow_net in self.flow_nets:
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )
            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()
            # Swap the first two axes so trajectories are indexed [sample, time, feature].
            all_trajs.append(torch.transpose(traj, 0, 1))
        all_endpoints = [traj[:, -1, :] for traj in all_trajs]

        n_trials = self._TRAM_N_TRIALS
        n_branches = len(self.flow_nets)

        # Ground truth per branch: 't1_<i>' when present, otherwise the shared
        # 't1'.  A single-branch model always uses 't1'.
        gt_data_per_branch = []
        for i in range(n_branches):
            key = f't1_{i+1}'
            if n_branches == 1 or key not in timepoint_data:
                key = 't1'
            # as_tensor accepts arrays and tensors alike; .cpu() matches the
            # CPU-resident predictions.
            gt_data_per_branch.append(
                torch.as_tensor(timepoint_data[key], dtype=torch.float32).cpu())
        gt_all = torch.cat(gt_data_per_branch, dim=0)

        metrics_dict = {}
        for i, (pred, gt) in enumerate(zip(all_endpoints, gt_data_per_branch)):
            stats = self._trametinib_trial_stats(pred, gt, n_trials)
            metrics_dict[f"branch_{i+1}"] = stats
            self.log(f"test/W1_branch{i+1}", stats["W1_mean"], on_epoch=True)
            print(f"Branch {i+1} — W1: {stats['W1_mean']:.6f}±{stats['W1_std']:.6f}, "
                  f"W2: {stats['W2_mean']:.6f}±{stats['W2_std']:.6f}, "
                  f"MMD: {stats['MMD_mean']:.6f}±{stats['MMD_std']:.6f}")

        # All branches pooled against all ground truth pooled.
        pred_all = torch.cat(all_endpoints, dim=0)
        combined = self._trametinib_trial_stats(pred_all, gt_all, n_trials)
        self.log("test/W1_t1_combined", combined["W1_mean"], on_epoch=True)
        self.log("test/W2_t1_combined", combined["W2_mean"], on_epoch=True)
        self.log("test/MMD_t1_combined", combined["MMD_mean"], on_epoch=True)
        metrics_dict['t1_combined'] = {**combined, "n_trials": n_trials}
        print("\n=== Combined @ t1 ===")
        print(f"W1: {combined['W1_mean']:.6f} ± {combined['W1_std']:.6f}")
        print(f"W2: {combined['W2_mean']:.6f} ± {combined['W2_std']:.6f}")
        print(f"MMD: {combined['MMD_mean']:.6f} ± {combined['MMD_std']:.6f}")

        # Results live under <working_dir>/results/<run_name or data_name>.
        run_name = getattr(self.args, 'run_name', None) or self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        figures_dir = f'{results_dir}/figures'
        os.makedirs(figures_dir, exist_ok=True)

        self._trametinib_save_metrics(metrics_dict, results_dir)

        custom_cmap_1, custom_cmap_2, custom_cmap_3 = self._trametinib_cmaps("cmap")
        self._plot_trametinib_branches(all_trajs, timepoint_data, figures_dir,
                                       custom_cmap_1, custom_cmap_2, custom_cmap_3)
        self._plot_trametinib_combined(all_trajs, timepoint_data, figures_dir)

        print(f"Trametinib figures saved to {figures_dir}")

    @staticmethod
    def _trametinib_main_batch(batch):
        """Return the test-sample batch from whichever container was passed in."""
        if isinstance(batch, dict) and "test_samples" in batch:
            return batch["test_samples"]
        if isinstance(batch, (list, tuple)) and len(batch) >= 1:
            first = batch[0]
            if isinstance(first, dict) and "test_samples" in first:
                return first["test_samples"][0]
            return first
        return batch

    @classmethod
    def _trametinib_cmaps(cls, prefix):
        """Build one start -> mid -> end colormap per palette colour."""
        return [
            LinearSegmentedColormap.from_list(
                f"{prefix}{k+1}", [cls._TRAM_START_COLOR, cls._TRAM_MID_COLOR, end])
            for k, end in enumerate(cls._TRAM_BRANCH_COLORS)
        ]

    @staticmethod
    def _trametinib_trial_stats(pred, true, n_trials):
        """Mean/std of W1, W2 and MMD over random equal-size subsamples.

        Each trial draws ``min(len(pred), len(true))`` rows from both sets
        without replacement and compares their first two coordinates.
        """
        n_min = min(pred.shape[0], true.shape[0])
        draws = {"W1": [], "W2": [], "MMD": []}
        for _ in range(n_trials):
            idx_pred = torch.randperm(pred.shape[0])[:n_min]
            idx_true = torch.randperm(true.shape[0])[:n_min]
            m = compute_distribution_distances(pred[idx_pred, :2], true[idx_true, :2])
            for name, values in draws.items():
                values.append(m[name])

        stats = {}
        for name, values in draws.items():
            stats[f"{name}_mean"] = float(np.mean(values))
            # Sample standard deviation across trials.
            stats[f"{name}_std"] = float(np.std(values, ddof=1))
        return stats

    @staticmethod
    def _trametinib_save_metrics(metrics_dict, results_dir):
        """Write ``metrics.json`` and ``metrics_detailed.csv`` into ``results_dir``."""
        metrics_path = f'{results_dir}/metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        print(f"Metrics saved to {metrics_path}")

        detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
        with open(detailed_csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
            for key in sorted(metrics_dict):
                m = metrics_dict[key]
                row = [key]
                for name in ('W1', 'W2', 'MMD'):
                    # Fall back to an un-aggregated value, then to 0, when a
                    # group carries no mean/std entry.
                    mean_value = m.get(f"{name}_mean", m.get(name, 0))
                    std_value = m.get(f"{name}_std", 0)
                    row.extend([f'{mean_value:.6f}', f'{std_value:.6f}'])
                writer.writerow(row)
        print(f"Detailed metrics CSV saved to {detailed_csv_path}")

    @staticmethod
    def _trametinib_draw_trajs(ax, xy, cmap, end_color, start_color, start_label, end_label):
        """Draw time-coloured trajectories plus their start and end markers on ``ax``."""
        # One line segment per consecutive time pair of every sample, ordered
        # sample-major; each sample gets the same colour ramp.
        n_samples, n_time = xy.shape[:2]
        segments = np.stack([xy[:, :-1], xy[:, 1:]], axis=2).reshape(-1, 2, 2)
        seg_colors = np.tile(cmap(np.linspace(0, 1, n_time))[:-1], (n_samples, 1))
        ax.add_collection(
            LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8))

        ax.scatter(xy[:, 0, 0], xy[:, 0, 1],
                   c=start_color, s=30, marker='o', label=start_label,
                   zorder=5, edgecolors='white', linewidth=1)
        ax.scatter(xy[:, -1, 0], xy[:, -1, 1],
                   c=end_color, s=30, marker='o', label=end_label,
                   zorder=5, edgecolors='white', linewidth=1)

    def _plot_trametinib_branches(self, all_trajs, timepoint_data, save_dir,
                                  cmap1, cmap2, cmap3):
        """Plot each branch separately.

        Saves ``<data_name>_branch<i>.png`` per branch: the first two
        coordinates of the branch's trajectories over the t=0 cells and the
        branch's t=1 cells.  All figures share axis limits (bounding box of
        the plotted data padded by 5% per side).
        """
        start_color = self._TRAM_START_COLOR
        branch_colors = self._TRAM_BRANCH_COLORS
        cmaps = [cmap1, cmap2, cmap3]

        # Resolve each branch's t=1 key once so the shared axis limits and the
        # per-figure scatter always use the same data.
        t1_keys = [
            f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            for i in range(len(all_trajs))
        ]
        trajs_2d = [np.asarray(traj)[:, :, :2] for traj in all_trajs]

        # dict.fromkeys drops duplicate keys while keeping their order.
        all_coords = [np.asarray(timepoint_data[key])[:, :2]
                      for key in dict.fromkeys(['t0'] + t1_keys)]
        all_coords.extend(xy.reshape(-1, xy.shape[-1]) for xy in trajs_2d)
        all_coords = np.concatenate(all_coords, axis=0)

        x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
        y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()
        x_margin = 0.05 * (x_max - x_min)
        y_margin = 0.05 * (y_max - y_min)
        x_min, x_max = x_min - x_margin, x_max + x_margin
        y_min, y_max = y_min - y_margin, y_max + y_margin

        for i, (xy, t1_key) in enumerate(zip(trajs_2d, t1_keys)):
            # Cycle through the palette so extra branches reuse colours
            # instead of raising IndexError.
            c_end = branch_colors[i % len(branch_colors)]
            cmap = cmaps[i % len(cmaps)]
            t1_label = f"t=1 (branch {i+1})" if len(all_trajs) > 1 else "t=1"

            fig, ax = plt.subplots(figsize=(10, 8))
            try:
                background = (
                    (timepoint_data['t0'], start_color, "t=0"),
                    (timepoint_data[t1_key], c_end, t1_label),
                )
                for coords, color, label in background:
                    ax.scatter(coords[:, 0], coords[:, 1],
                               c=color, s=80, alpha=0.4, marker='x',
                               label=f'{label} cells', linewidth=1.5)

                self._trametinib_draw_trajs(
                    ax, xy, cmap, c_end, start_color,
                    'Trajectory Start', 'Trajectory End')

                ax.set_xlim(x_min, x_max)
                ax.set_ylim(y_min, y_max)
                ax.set_xlabel("PC1", fontsize=12)
                ax.set_ylabel("PC2", fontsize=12)
                ax.set_title(f"Branch {i+1}: Trajectories with Timepoint Background", fontsize=14)
                ax.grid(True, alpha=0.3)
                ax.legend(loc='upper right', fontsize=16, frameon=False)

                fig.tight_layout()
                fig.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
            finally:
                # Close this specific figure even if drawing or saving failed.
                plt.close(fig)

    def _plot_trametinib_combined(self, all_trajs, timepoint_data, save_dir):
        """Plot all branches together and save ``<data_name>_combined.png``."""
        start_color = self._TRAM_START_COLOR
        branch_colors = self._TRAM_BRANCH_COLORS
        # Same colormaps as the per-branch figures, so each ramp ends on the
        # colour used for that branch's markers.
        cmaps = self._trametinib_cmaps("tram_cmap")
        n_branches = len(all_trajs)

        if 't1_1' in timepoint_data:
            background = [('t0', start_color, 't=0')] + [
                (f't1_{j+1}', branch_colors[j % len(branch_colors)], f't=1 (branch {j+1})')
                for j in range(n_branches)
            ]
        else:
            background = [('t0', start_color, 't=0'), ('t1', branch_colors[0], 't=1')]

        fig, ax = plt.subplots(figsize=(12, 10))
        try:
            for t_key, color, label in background:
                coords = timepoint_data[t_key]
                ax.scatter(coords[:, 0], coords[:, 1],
                           c=color, s=80, alpha=0.4, marker='x',
                           label=f'{label} cells', linewidth=1.5)

            for i, traj in enumerate(all_trajs):
                self._trametinib_draw_trajs(
                    ax, np.asarray(traj)[:, :, :2],
                    cmaps[i % len(cmaps)],
                    branch_colors[i % len(branch_colors)],
                    start_color,
                    f'Branch {i+1} Start', f'Branch {i+1} End')

            ax.set_xlabel("PC1", fontsize=14)
            ax.set_ylabel("PC2", fontsize=14)
            ax.set_title("All Branch Trajectories with Timepoint Background",
                         fontsize=16, weight='bold')
            ax.grid(True, alpha=0.3)
            ax.legend(loc='upper right', fontsize=12, frameon=False)

            fig.tight_layout()
            fig.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
        finally:
            # Close this specific figure even if drawing or saving failed.
            plt.close(fig)
| |
|
| | class FlowNetTestVeres(GrowthNetTrain): |
| | """Test class for Veres pancreatic endocrinogenesis experiment (3 or 5 branches).""" |
| | |
def test_step(self, batch, batch_idx):
    """Score the Veres branches against the timepoint data and save results.

    Steps:
      1. Run ``self.get_mass_and_position`` from every validation ``x0``
         sample (unit initial weights) to obtain per-branch trajectories.
      2. When ``self.whiten`` is set, map the trajectories back through the
         datamodule scaler's ``inverse_transform``.
      3. Compute W1 / W2 / MMD, both on the first two coordinates and on
         all coordinates ("full"), averaged over random equal-size
         subsamples, for: six evenly spaced interior snapshots (branches
         pooled) vs. the intermediate timepoints, each branch endpoint vs.
         ``t7_<i>``, and the pooled endpoints vs. all ``t7_*`` data.
      4. Write ``metrics.json`` / ``metrics_detailed.csv`` and the figures.

    Args:
        batch: Dict with ``"test_samples"`` / ``"metric_samples"``, a
            sequence whose first item is such a dict, or a
            ``(test_samples, metric_samples)`` sequence.
        batch_idx: Unused; required by the Lightning hook signature.
    """
    n_trials = 5
    metric_names = ("W1", "W2", "MMD")

    def as_cpu_float(data):
        # as_tensor accepts arrays and tensors alike (no copy-construct
        # warning); .cpu() matches the CPU-resident predictions.
        return torch.as_tensor(data, dtype=torch.float32).cpu()

    def trial_stats(pred, true):
        """Mean/std over ``n_trials`` subsamples, for 2-D and full-dim metrics."""
        n_min = min(pred.shape[0], true.shape[0])
        draws = {f"{name}{suffix}": []
                 for suffix in ("", "_full") for name in metric_names}
        for _ in range(n_trials):
            idx_pred = torch.randperm(pred.shape[0])[:n_min]
            idx_true = torch.randperm(true.shape[0])[:n_min]
            # The same subsample is scored in 2-D first, then in full dimension.
            m_2d = compute_distribution_distances(pred[idx_pred, :2], true[idx_true, :2])
            m_full = compute_distribution_distances(pred[idx_pred], true[idx_true])
            for name in metric_names:
                draws[name].append(m_2d[name])
                draws[f"{name}_full"].append(m_full[name])
        stats = {}
        for key, values in draws.items():
            stats[f"{key}_mean"] = float(np.mean(values))
            # Sample standard deviation across trials.
            stats[f"{key}_std"] = float(np.std(values, ddof=1))
        return stats

    def summary(stats, suffix=""):
        """Format ``W1: m±s, W2: m±s, MMD: m±s`` for one metric family."""
        return ", ".join(
            f"{name}: {stats[f'{name}{suffix}_mean']:.6f}±{stats[f'{name}{suffix}_std']:.6f}"
            for name in metric_names
        )

    def timepoint_order(key):
        # Order 't<number>' keys numerically (t2 before t10); anything else
        # keeps its lexicographic place after them.
        suffix = key[1:]
        return (0, int(suffix), key) if suffix.isdigit() else (1, 0, key)

    # Unwrap whichever container the dataloader produced.
    if isinstance(batch, dict):
        test_entry, metric_entry = batch["test_samples"], batch["metric_samples"]
    elif isinstance(batch[0], dict):
        test_entry, metric_entry = batch[0]["test_samples"], batch[0]["metric_samples"]
    else:
        test_entry, metric_entry = batch[0], batch[1]
    main_batch, metric_batch = test_entry[0], metric_entry[0]

    timepoint_data = self.trainer.datamodule.get_timepoint_data()
    device = main_batch["x0"][0].device

    # Evaluate from the full validation x0 set with unit initial weights.
    x0_all = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
    w0_all = torch.ones(x0_all.shape[0], 1, dtype=torch.float32).to(device)
    full_batch = {"x0": (x0_all, w0_all)}

    # Only the per-branch trajectory lists (third return value) are used here.
    _, _, branch_traj_lists, _, _, _ = self.get_mass_and_position(full_batch, metric_batch)

    n_branches = len(self.flow_nets)

    # Six interior fractions of the time axis: 1/7, 2/7, ..., 6/7.
    inter_times = np.linspace(0.0, 1.0, 8)[1:-1]

    all_trajs = []
    all_endpoints = []
    all_intermediates = []
    for branch_list in branch_traj_lists:
        # Stack the per-time tensors, then swap to [sample, time, feature].
        traj = torch.stack(branch_list, dim=0).permute(1, 0, 2)

        if self.whiten:
            # Undo the datamodule's scaling so metrics are computed in the
            # scaler's original space.
            traj_np = traj.detach().cpu().numpy()
            n_samples, n_time, n_dims = traj_np.shape
            restored = self.trainer.datamodule.scaler.inverse_transform(
                traj_np.reshape(-1, n_dims))
            traj = torch.tensor(restored.reshape(n_samples, n_time, n_dims),
                                dtype=torch.float32)

        all_trajs.append(traj)

        n_T = traj.shape[1]
        inter_indices = [int(round(t * (n_T - 1))) for t in inter_times]
        all_intermediates.append(
            torch.stack([traj[:, idx, :] for idx in inter_indices], dim=1))
        all_endpoints.append(traj[:, -1, :])

    metrics_dict = {}

    # Intermediate timepoints: keys such as 't1', 't2', ... (no underscore).
    intermediate_keys = sorted(
        (k for k in timepoint_data
         if k.startswith('t') and '_' not in k and k != 't0'),
        key=timepoint_order,
    )
    # At most six snapshots exist per trajectory; snapshot j is paired with
    # the j-th intermediate key.
    for j, intermediate_key in enumerate(intermediate_keys[:6]):
        true_data_intermediate = as_cpu_float(timepoint_data[intermediate_key])
        # Pool snapshot j of every branch.
        all_raw_concat = torch.cat(
            [branch[:, j, :] for branch in all_intermediates], dim=0).cpu()

        stats = trial_stats(all_raw_concat, true_data_intermediate)
        metrics_dict[f'{intermediate_key}_combined'] = stats
        self.log(f"test/W1_{intermediate_key}_combined", stats["W1_mean"], on_epoch=True)
        self.log(f"test/W1_full_{intermediate_key}_combined", stats["W1_full_mean"], on_epoch=True)
        print(f"{intermediate_key} combined — {summary(stats)}")
        print(f"{intermediate_key} combined (full) — {summary(stats, '_full')}")

    # Per-branch endpoints vs. 't7_<i>' (zero-based branch index).
    for i, endpoints in enumerate(all_endpoints):
        true_data_key = f"t7_{i}"
        if true_data_key not in timepoint_data:
            print(f"Warning: {true_data_key} not found in timepoint_data")
            continue

        stats = trial_stats(endpoints.cpu(), as_cpu_float(timepoint_data[true_data_key]))
        metrics_dict[f"branch_{i}"] = stats
        self.log(f"test/W1_branch{i}", stats["W1_mean"], on_epoch=True)
        self.log(f"test/W1_full_branch{i}", stats["W1_full_mean"], on_epoch=True)
        print(f"Branch {i} — {summary(stats)}")
        print(f"Branch {i} (full) — {summary(stats, '_full')}")

    # Pooled endpoints vs. every 't7_*' entry pooled.
    gt_keys = sorted(k for k in timepoint_data if k.startswith('t7_'))
    gt_list = [as_cpu_float(timepoint_data[k]) for k in gt_keys]
    if len(gt_list) > 0 and len(all_endpoints) > 0:
        gt_all = torch.cat(gt_list, dim=0)
        pred_all = torch.cat([e.cpu() for e in all_endpoints], dim=0)

        stats = trial_stats(pred_all, gt_all)
        for name in metric_names:
            self.log(f"test/{name}_t7_combined", stats[f"{name}_mean"], on_epoch=True)
        for name in metric_names:
            self.log(f"test/{name}_full_t7_combined", stats[f"{name}_full_mean"], on_epoch=True)
        metrics_dict['t7_combined'] = {**stats, "n_trials": n_trials}

        print("\n=== Combined @ t7 ===")
        for name in metric_names:
            print(f"{name}: {stats[f'{name}_mean']:.6f} ± {stats[f'{name}_std']:.6f}")
        for name in metric_names:
            print(f"{name} (full): {stats[f'{name}_full_mean']:.6f} ± {stats[f'{name}_full_std']:.6f}")

    # Results live under <working_dir>/results/<run_name or data_name>.
    run_name = getattr(self.args, 'run_name', None) or self.args.data_name
    results_dir = os.path.join(self.args.working_dir, 'results', run_name)
    figures_dir = f'{results_dir}/figures'
    os.makedirs(figures_dir, exist_ok=True)

    metrics_path = f'{results_dir}/metrics.json'
    with open(metrics_path, 'w') as f:
        json.dump(metrics_dict, f, indent=2)
    print(f"Metrics saved to {metrics_path}")

    # CSV column order mirrors the header: 2-D metrics first, then full-dim.
    value_columns = [f"{name}{suffix}_{stat}"
                     for suffix in ("", "_full")
                     for name in metric_names
                     for stat in ("mean", "std")]
    detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
    with open(detailed_csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['Metric_Group',
                         'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std',
                         'W1_Full_Mean', 'W1_Full_Std', 'W2_Full_Mean', 'W2_Full_Std', 'MMD_Full_Mean', 'MMD_Full_Std'])
        for key in sorted(metrics_dict):
            m = metrics_dict[key]
            writer.writerow([key] + [f'{m.get(column, 0):.6f}' for column in value_columns])
    print(f"Detailed metrics CSV saved to {detailed_csv_path}")

    self._plot_veres_branches(all_trajs, timepoint_data, figures_dir, n_branches)
    self._plot_veres_combined(all_trajs, timepoint_data, figures_dir, n_branches)

    print(f"Veres figures saved to {figures_dir}")
| |
|
| | def _plot_veres_branches(self, all_trajs, timepoint_data, save_dir, n_branches): |
| | """Plot each branch separately in PCA space (PC1 vs PC2).""" |
| | branch_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DFE6E9', |
| | '#74B9FF', '#A29BFE', '#FFB74D', '#AED581', '#F06292', '#BA68C8', |
| | '#4DB6AC', '#81C784', '#FFD54F', '#90A4AE', '#F48FB1', '#CE93D8', |
| | '#64B5F6', '#C5E1A5'] |
| | |
| | |
| | t0_2d = timepoint_data['t0'].cpu().numpy()[:, :2] |
| | t7_2d = [timepoint_data[f't7_{i}'].cpu().numpy()[:, :2] for i in range(n_branches)] |
| | |
| | |
| | trajs_2d = [] |
| | for traj in all_trajs: |
| | trajs_2d.append(traj.cpu().numpy()[:, :, :2]) |
| | |
| | |
| | all_coords = [t0_2d] + t7_2d |
| | for traj_2d in trajs_2d: |
| | all_coords.append(traj_2d.reshape(-1, 2)) |
| | |
| | all_coords = np.concatenate(all_coords, axis=0) |
| | x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max() |
| | y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max() |
| | |
| | x_margin = 0.05 * (x_max - x_min) |
| | y_margin = 0.05 * (y_max - y_min) |
| | x_min -= x_margin |
| | x_max += x_margin |
| | y_min -= y_margin |
| | y_max += y_margin |
| |
|
| | for i, traj_2d in enumerate(trajs_2d): |
| | fig, ax = plt.subplots(figsize=(10, 8)) |
| | c_end = branch_colors[i % len(branch_colors)] |
| | |
| | |
| | ax.scatter(t0_2d[:, 0], t0_2d[:, 1], |
| | c='#05009E', s=80, alpha=0.4, marker='x', |
| | label='t=0 cells', linewidth=1.5) |
| | ax.scatter(t7_2d[i][:, 0], t7_2d[i][:, 1], |
| | c=c_end, s=80, alpha=0.4, marker='x', |
| | label=f't=7 (branch {i+1}) cells', linewidth=1.5) |
| | |
| | |
| | cmap_colors = ["#05009E", "#A19EFF", c_end] |
| | cmap = LinearSegmentedColormap.from_list(f"veres_cmap_{i}", cmap_colors) |
| | n_time = traj_2d.shape[1] |
| | segments = [] |
| | seg_colors = [] |
| | color_vals = cmap(np.linspace(0, 1, n_time)) |
| | for j in range(traj_2d.shape[0]): |
| | pts = traj_2d[j] |
| | segs = np.stack([pts[:-1], pts[1:]], axis=1) |
| | segments.append(segs) |
| | seg_colors.append(color_vals[:-1]) |
| | segments = np.concatenate(segments, axis=0) |
| | seg_colors = np.concatenate(seg_colors, axis=0) |
| | lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8) |
| | ax.add_collection(lc) |
| | |
| | |
| | ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1], |
| | c='#05009E', s=30, marker='o', label='Trajectory start (t=0)', |
| | zorder=5, edgecolors='white', linewidth=1) |
| | ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1], |
| | c=c_end, s=30, marker='o', label='Trajectory end (t=1)', |
| | zorder=5, edgecolors='white', linewidth=1) |
| | |
| | ax.set_xlim(x_min, x_max) |
| | ax.set_ylim(y_min, y_max) |
| | ax.set_xlabel("PC 1", fontsize=12) |
| | ax.set_ylabel("PC 2", fontsize=12) |
| | ax.set_title(f"Branch {i+1}: Trajectories (PCA)", fontsize=14) |
| | ax.grid(True, alpha=0.3) |
| | ax.legend(loc='upper right', fontsize=9, frameon=False) |
| | |
| | plt.tight_layout() |
| | plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300) |
| | plt.close() |
| |
|
| | def _plot_veres_combined(self, all_trajs, timepoint_data, save_dir, n_branches): |
| | """Plot all branches together in PCA space (PC1 vs PC2).""" |
| | branch_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DFE6E9', |
| | '#74B9FF', '#A29BFE', '#FFB74D', '#AED581', '#F06292', '#BA68C8', |
| | '#4DB6AC', '#81C784', '#FFD54F', '#90A4AE', '#F48FB1', '#CE93D8', |
| | '#64B5F6', '#C5E1A5'] |
| | |
| | |
| | t0_2d = timepoint_data['t0'].cpu().numpy()[:, :2] |
| | t7_2d = [timepoint_data[f't7_{i}'].cpu().numpy()[:, :2] for i in range(n_branches)] |
| | |
| | |
| | trajs_2d = [] |
| | for traj in all_trajs: |
| | trajs_2d.append(traj.cpu().numpy()[:, :, :2]) |
| | |
| | |
| | all_coords_real = [t0_2d] + t7_2d |
| | all_coords_real = np.concatenate(all_coords_real, axis=0) |
| | x_min, x_max = all_coords_real[:, 0].min(), all_coords_real[:, 0].max() |
| | y_min, y_max = all_coords_real[:, 1].min(), all_coords_real[:, 1].max() |
| | x_margin = 0.05 * (x_max - x_min) |
| | y_margin = 0.05 * (y_max - y_min) |
| | x_min -= x_margin |
| | x_max += x_margin |
| | y_min -= y_margin |
| | y_max += y_margin |
| | |
| | fig, ax = plt.subplots(figsize=(14, 12)) |
| | ax.set_xlim(x_min, x_max) |
| | ax.set_ylim(y_min, y_max) |
| | |
| | |
| | ax.scatter(t0_2d[:, 0], t0_2d[:, 1], |
| | c='#05009E', s=60, alpha=0.3, marker='x', |
| | label='t=0 cells', linewidth=1.5) |
| | |
| | |
| | for i, traj_2d in enumerate(trajs_2d): |
| | c_end = branch_colors[i % len(branch_colors)] |
| | |
| | |
| | ax.scatter(t7_2d[i][:, 0], t7_2d[i][:, 1], |
| | c=c_end, s=60, alpha=0.3, marker='x', |
| | label=f't=7 (branch {i+1})', linewidth=1.5) |
| | |
| | |
| | cmap_colors = ["#05009E", "#A19EFF", c_end] |
| | cmap = LinearSegmentedColormap.from_list(f"veres_combined_cmap_{i}", cmap_colors) |
| | n_time = traj_2d.shape[1] |
| | segments = [] |
| | seg_colors = [] |
| | color_vals = cmap(np.linspace(0, 1, n_time)) |
| | for j in range(traj_2d.shape[0]): |
| | pts = traj_2d[j] |
| | segs = np.stack([pts[:-1], pts[1:]], axis=1) |
| | segments.append(segs) |
| | seg_colors.append(color_vals[:-1]) |
| | segments = np.concatenate(segments, axis=0) |
| | seg_colors = np.concatenate(seg_colors, axis=0) |
| | lc = LineCollection(segments, colors=seg_colors, linewidths=1.5, alpha=0.6) |
| | ax.add_collection(lc) |
| | |
| | |
| | ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1], |
| | c='#05009E', s=20, marker='o', |
| | zorder=5, edgecolors='white', linewidth=0.5, alpha=0.7) |
| | ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1], |
| | c=c_end, s=20, marker='o', |
| | zorder=5, edgecolors='white', linewidth=0.5, alpha=0.7) |
| | |
| | ax.set_xlabel("PC 1", fontsize=14) |
| | ax.set_ylabel("PC 2", fontsize=14) |
| | ax.set_title(f"All {n_branches} Branch Trajectories (Veres) - PCA Projection", |
| | fontsize=16, weight='bold') |
| | ax.grid(True, alpha=0.3) |
| | ax.legend(loc='upper right', fontsize=10, frameon=False, ncol=2) |
| | |
| | plt.tight_layout() |
| | plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300) |
| | plt.close() |