| | import os |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | import numpy as np |
| | from scipy.stats import spearmanr, pearsonr |
| | from sklearn.metrics import r2_score |
| |
|
| | from visualization.pauc_plot import plot_roc_with_ci |
| |
|
| |
|
| | def regression_ci_plot(y_true, y_pred, save_path, title=None): |
| | """ |
| | Calculate the spearman rho and p-value of the regression model. |
| | Plot the line of best fit with 95% confidence intervals for spearman rho. |
| | Display the R-squared value, spearman rho, pearson rho, and p-values. |
| | """ |
| | |
| | y_true, y_pred = y_true.flatten(), y_pred.flatten() |
| | mask = np.isfinite(y_true) & np.isfinite(y_pred) |
| | y_true, y_pred = y_true[mask], y_pred[mask] |
| | r2 = r2_score(y_true, y_pred) |
| | r_s, p_s = spearmanr(y_true, y_pred) |
| | r_p, p_p = pearsonr(y_true, y_pred) |
| |
|
| | |
| | fig, ax = plt.subplots(figsize=(8, 6)) |
| | sns.scatterplot(x=y_true, y=y_pred, ax=ax) |
| | sns.regplot( |
| | x=y_true, y=y_pred, |
| | ci=95, ax=ax, scatter=False, |
| | line_kws={'color': 'red'} |
| | ) |
| |
|
| | ax.set_xlabel('True Values') |
| | ax.set_ylabel('Predicted Values') |
| | if title: |
| | ax.set_title(title) |
| | else: |
| | ax.set_title('Regression Plot with 95% Confidence Interval') |
| |
|
| | |
| | stats_text = ( |
| | f"$R^2$ = {r2:.2f}\n" |
| | f"Spearman $\\rho$ = {r_s:.2f} (p = {p_s:.2e})\n" |
| | f"Pearson $\\rho$ = {r_p:.2f} (p = {p_p:.2e})" |
| | ) |
| | ax.text( |
| | 0.05, 0.95, stats_text, |
| | transform=ax.transAxes, |
| | fontsize=12, verticalalignment='top' |
| | ) |
| |
|
| | |
| | fig.savefig(save_path, dpi=300, bbox_inches='tight') |
| | plt.close(fig) |
| |
|
| |
|
| | def classification_ci_plot(y_true, y_pred, save_path, title=None): |
| | """ |
| | Use pauc to display classification plot |
| | """ |
| | if len(y_pred.shape) == 3 and len(y_true.shape) == 2: |
| | y_pred = y_pred.reshape(-1, y_pred.shape[-1]) |
| | y_true = y_true.reshape(-1) |
| |
|
| | |
| | if len(y_pred.shape) == 2 and len(y_true.shape) == 2: |
| | y_pred = y_pred.flatten() |
| | y_true = y_true.flatten() |
| |
|
| | |
| | |
| | if y_true.shape[0] > 10000: |
| | y_pred = y_pred[:10000] |
| | y_true = y_true[:10000] |
| |
|
| | print(y_true.shape, y_pred.shape) |
| |
|
| | try: |
| | plot_roc_with_ci(y_true, y_pred, save_path, fig_title=title) |
| | except Exception as e: |
| | print(f"Error plotting pAUC curve, likely the wrong version: {e}") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | |
| | import os |
| | os.makedirs("plots/test_plots", exist_ok=True) |
| | y_true = np.random.rand(100) |
| | y_pred = np.random.rand(100) |
| | regression_ci_plot(y_true, y_pred, "plots/test_plots/regression.png", title="Regression Plot") |
| |
|
| | y_true = np.random.randint(0, 2, (50, 514)) |
| | y_pred = np.random.rand(50, 514, 4) |
| | classification_ci_plot(y_true, y_pred, "plots/test_plots/classification.png", title="Classification Plot") |
| |
|