import os
import time
import random

import numpy as np

import shutil
from enum import Enum

import torch
import torchvision.transforms as transforms


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
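# The seeding above covers Python, NumPy, and all CUDA devices. If fully
# deterministic runs are needed, cuDNN can additionally be constrained as in
# this optional sketch (an assumption; not required by the rest of this module):
#
#     set_random_seed(42)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False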


class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
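# Minimal usage sketch for the meters (illustrative; `train_loader`, `loss`, and
# `acc1` are placeholders, not defined in this module):
#
#     losses = AverageMeter('Loss', ':.4e')
#     top1 = AverageMeter('Acc@1', ':6.2f')
#     progress = ProgressMeter(len(train_loader), [losses, top1], prefix='Train: ')
#     for i, (images, target) in enumerate(train_loader):
#         ...
#         losses.update(loss.item(), images.size(0))
#         top1.update(acc1[0], images.size(0))
#         if i % 10 == 0:
#             progress.display(i)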


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Take the top-maxk predictions so every requested k can be evaluated.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
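# Quick sanity check (illustrative only): three samples, four classes.
#
#     logits = torch.tensor([[4.0, 1.0, 0.0, 0.0],
#                            [0.0, 3.0, 2.0, 0.0],
#                            [1.0, 0.0, 0.0, 2.0]])
#     labels = torch.tensor([0, 2, 3])
#     top1, top2 = accuracy(logits, labels, topk=(1, 2))
#     # top1 ≈ 66.67 (samples 0 and 2 correct), top2 = 100.0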


from sklearn.metrics import precision_score, recall_score, f1_score


def macro_prf(output, target):
    """Returns macro-precision, macro-recall, and macro-F1 in percentages."""
    preds = output.argmax(dim=1).cpu().numpy()
    y_true = target.cpu().numpy()

    p = precision_score(y_true, preds, average='macro', zero_division=0)
    r = recall_score(y_true, preds, average='macro', zero_division=0)
    f = f1_score(y_true, preds, average='macro', zero_division=0)

    return [p * 100, r * 100, f * 100]
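# Example (illustrative): macro scores weight every class equally. With
#
#     logits = torch.tensor([[2.0, 0.0], [2.0, 0.0], [0.0, 2.0]])
#     labels = torch.tensor([0, 1, 1])
#     p, r, f = macro_prf(logits, labels)
#
# class 0 has precision 0.5 / recall 1.0 and class 1 has precision 1.0 / recall 0.5,
# giving macro-precision and macro-recall of 75.0 each.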


def load_model_weight(load_path, model, device, args):
    if os.path.isfile(load_path):
        print("=> loading checkpoint '{}'".format(load_path))
        checkpoint = torch.load(load_path, map_location=device)
        state_dict = checkpoint['state_dict']

        # Drop cached prompt token buffers if present.
        if "token_prefix" in state_dict:
            del state_dict["token_prefix"]
        if "token_suffix" in state_dict:
            del state_dict["token_suffix"]

        args.start_epoch = checkpoint['epoch']
        try:
            best_acc1 = checkpoint['best_acc1']
        except KeyError:
            best_acc1 = torch.tensor(0)
        if device != 'cpu':
            best_acc1 = best_acc1.to(device)
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            # Fall back to loading only the prompt generator when the checkpoint
            # does not cover the full model.
            model.prompt_generator.load_state_dict(state_dict, strict=False)
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(load_path, checkpoint['epoch']))
        del checkpoint
        torch.cuda.empty_cache()
    else:
        print("=> no checkpoint found at '{}'".format(load_path))


def validate(val_loader, model, criterion, args, output_mask=None):
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # Switch to evaluation mode.
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # Compute output and loss under mixed precision.
            with torch.cuda.amp.autocast():
                output = model(images)
                if output_mask is not None:
                    output = output[:, output_mask]
                loss = criterion(output, target)

            # Measure accuracy and record loss.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        progress.display_summary()

    return top1.avg
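# Example invocation (sketch; assumes an ImageNet-style dataloader and a namespace
# exposing the two attributes the function reads):
#
#     args = argparse.Namespace(gpu=0, print_freq=10)
#     top1 = validate(val_loader, model, torch.nn.CrossEntropyLoss(), args)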


import matplotlib.pyplot as plt


def plot_img(image, save_path='saved_plot.png', target=None, predicted=None):
    if isinstance(image, torch.Tensor):
        image_array = image.to('cpu').squeeze().permute(1, 2, 0).detach().numpy()
    else:
        image_array = image
    # Min-max normalize to [0, 1] for display.
    image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min())
    plt.figure(figsize=(3, 3), tight_layout=True)
    plt.imshow(image_array)
    plt.axis('off')
    plt.savefig(save_path)
    plt.close()


from torchvision.transforms import ToPILImage
from PIL import Image

to_pil = ToPILImage()


def plot_pil_img(image, save_path='saved_plot.png'):
    if not isinstance(image, Image.Image):
        img_noi = to_pil(image)
    else:
        img_noi = image
    img_noi.save(save_path)
|
| | import seaborn as sns |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | from scipy.stats import pearsonr |
| |
|
| | def plot_entropy_vs_mi( |
| | entropies: np.ndarray, |
| | mi_values: np.ndarray, |
| | agreement_diff: np.ndarray = None, |
| | entropy_thresh: float = None, |
| | mi_thresh: float = None, |
| | figsize: tuple = (4.5, 4.5), |
| | save_path: str = 'mi_vs_entropy.png', |
| | ): |
| | """ |
| | Plot MI vs. Predictive Entropy with optional coloring by agreement. |
| | |
| | Args: |
| | entropies (np.ndarray): Consensus predictive entropy values. |
| | mi_values (np.ndarray): Mutual information values. |
| | agreement_diff (np.ndarray, optional): Difference in predictions (L1). |
| | entropy_thresh (float, optional): Vertical threshold line. |
| | mi_thresh (float, optional): Horizontal threshold line. |
| | figsize (tuple): Plot size (default: small). |
| | save_path (str): Where to save the figure. |
| | """ |
| | entropies = entropies.cpu().numpy() |
| | mi_values = mi_values.cpu().numpy() |
| | if agreement_diff is not None: |
| | agreement_diff = agreement_diff.cpu().numpy() |
| |
|
| | corr, _ = pearsonr(entropies, mi_values) |
| |
|
| | |
| | g = sns.JointGrid( |
| | x=entropies, |
| | y=mi_values, |
| | height=figsize[0], |
| | ratio=4, |
| | space=0.15 |
| | ) |
| |
|
| | |
| | if agreement_diff is not None: |
| | cmap = sns.color_palette("coolwarm", as_cmap=True) |
| | g.plot_joint( |
| | sns.scatterplot, |
| | hue=agreement_diff, |
| | palette=cmap, |
| | s=18, |
| | linewidth=0.3, |
| | edgecolor="black", |
| | alpha=0.8 |
| | ) |
| | g.ax_joint.legend_.remove() |
| | else: |
| | g.plot_joint(sns.scatterplot, s=20, color='tab:blue', alpha=0.7) |
| |
|
| | |
| | g.plot_marginals(sns.histplot, kde=True, color='grey', alpha=0.5) |
| |
|
| | |
| | sns.regplot( |
| | x=entropies, |
| | y=mi_values, |
| | scatter=False, |
| | ax=g.ax_joint, |
| | color='black', |
| | line_kws={"linestyle": "--", "linewidth": 1} |
| | ) |
| |
|
| | |
| | if entropy_thresh is not None: |
| | g.ax_joint.axvline(entropy_thresh, ls='--', color='grey', lw=1) |
| | if mi_thresh is not None: |
| | g.ax_joint.axhline(mi_thresh, ls='--', color='grey', lw=1) |
| |
|
| | |
| | x_text = np.percentile(entropies, 5) |
| | y_text = np.percentile(mi_values, 95) |
| | g.ax_joint.text(x_text, y_text, 'High MI\nLow Entropy', |
| | fontsize=10, fontweight='bold', color='black') |
| |
|
| | |
| | g.set_axis_labels('Self-Entropy', 'Mutual Information', fontsize=11) |
| | g.ax_joint.set_title(f'Pearson ρ = {corr:.2f}', fontsize=12) |
| | g.ax_joint.tick_params(labelsize=9) |
| |
|
| | plt.tight_layout() |
| | if os.path.dirname(save_path): |
| | os.makedirs(os.path.dirname(save_path), exist_ok=True) |
| | plt.savefig(save_path, dpi=300) |
| | plt.close() |
| | return |
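# Minimal usage sketch with synthetic tensors (illustrative only):
#
#     ent = torch.rand(500)
#     mi = 0.5 * ent + 0.1 * torch.rand(500)
#     plot_entropy_vs_mi(ent, mi, entropy_thresh=0.5, mi_thresh=0.3,
#                        save_path='plots/mi_vs_entropy.png')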


import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


method_names = {
    'model_ensemble': 'Model Ensemble',
    'wise_ft': 'Model Souping',
    'tcube': 'Entropy-based',
    'tcube_MI_bmm': 'Mutual Information',
}


def plot_delta_performance(
    dyn_v_stat_plot: dict,
    dyn_key: str = 'tcube_MI_bmm',
    figsize: tuple = (3, 3),
    save_path: str = 'delta_performance.png'
):
    """Bar plot of the per-condition performance delta between the dynamic
    method (`dyn_key`) and the best of the remaining methods."""
    sns.set_style('white')
    conditions = np.array(dyn_v_stat_plot['conditions'])

    fig, ax = plt.subplots(
        1, 1,
        figsize=figsize,
        constrained_layout=True
    )

    # Delta of the dynamic method over the best competing method per condition.
    dyn_arr = np.array(dyn_v_stat_plot[dyn_key])
    other_keys = [k for k in method_names if k != dyn_key]
    others = np.vstack([dyn_v_stat_plot[k] for k in other_keys])
    delta = dyn_arr - others.max(axis=0)

    palette = sns.color_palette("rocket", n_colors=len(delta))
    ax.bar(
        x=np.arange(len(conditions)),
        height=delta,
        width=1.0,
        color=palette,
        linewidth=0,
        edgecolor=None,
        alpha=0.85,
    )
    ax.axhline(0, color='grey', linewidth=1)
    ax.set_ylabel(r'$\Delta$ (%)', fontsize=10)
    ax.set_xlabel('Distribution Shifts', fontsize=10)

    ax.set_xticks(np.arange(len(conditions)))
    ax.set_xticklabels([''] * len(conditions))
    ax.tick_params(axis='x', length=3, width=1)
    ax.tick_params(axis='y', labelsize=9)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(True)
    ax.spines['bottom'].set_visible(True)
    ax.grid(False)

    if os.path.dirname(save_path):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    return fig, ax
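# Expected input layout (a sketch with illustrative numbers; keys mirror
# `method_names` plus 'conditions'):
#
#     dyn_v_stat_plot = {
#         'conditions': ['gauss_noise', 'fog', 'contrast'],
#         'model_ensemble': [61.2, 55.0, 58.3],
#         'wise_ft': [62.0, 54.1, 59.0],
#         'tcube': [62.5, 55.8, 59.4],
#         'tcube_MI_bmm': [63.1, 56.9, 60.2],
#     }
#     plot_delta_performance(dyn_v_stat_plot, save_path='plots/delta.png')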


import matplotlib.pyplot as plt
import seaborn as sns
import torch


def plot_lambda_histogram(
    lambda_dict: dict,
    bins: int = 50,
    figsize: tuple = (3, 3),
    save_path: str = None
):
    """
    Plot a single-condition histogram of sample-wise interpolation coefficients,
    styled with a white background, a 'Blues' gradient over the bins, and a
    dashed KDE overlay.

    Args:
        lambda_dict (dict): one-entry dict, e.g. {'clean': tensor([...])}
        bins (int): number of histogram bins
        figsize (tuple): figure size in inches (w, h)
        save_path (str): optional path to save the figure

    Returns:
        fig, ax
    """
    if len(lambda_dict) != 1:
        raise ValueError("lambda_dict must contain exactly one key.")
    condition, data = next(iter(lambda_dict.items()))
    if not isinstance(data, torch.Tensor):
        raise ValueError(f"lambda_dict['{condition}'] must be a torch.Tensor")

    values = data.detach().cpu().numpy().ravel()

    sns.set_style("white")
    fig, ax = plt.subplots(figsize=figsize)

    # One colour per bin, light to dark.
    cm = sns.color_palette("Blues", n_colors=bins)

    plot = sns.histplot(
        values,
        bins=bins,
        ax=ax,
        edgecolor=None,
        alpha=0.85,
        kde=True,
        linewidth=0
    )
    # Style the KDE curve as a thin dashed black line.
    if plot.lines:
        plot.lines[0].set_color('black')
        plot.lines[0].set_linestyle('--')
        plot.lines[0].set_linewidth(0.5)
    # Colour the bars with the gradient palette.
    for patch, color in zip(plot.patches, cm):
        patch.set_facecolor(color)

    ax.set_xlabel("Coefficient", fontsize=9)
    ax.set_ylabel("Frequency", fontsize=9)

    ax.set_xticks(np.round(np.linspace(values.min(), values.max(), num=6), 2))
    ax.tick_params(axis='x', labelsize=8)
    ax.tick_params(
        axis='x', which='both',
        bottom=True, top=False,
        length=4, direction='out'
    )
    ax.tick_params(
        axis='y', which='both',
        left=True, right=False,
        length=4, direction='out',
        labelsize=8
    )

    for spine in ['top', 'right', 'bottom', 'left']:
        ax.spines[spine].set_visible(True)

    plt.tight_layout()
    if save_path:
        if os.path.dirname(save_path):
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()
    return fig, ax
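# Minimal usage sketch (illustrative):
#
#     lambdas = {'clean': torch.sigmoid(torch.randn(2000))}
#     plot_lambda_histogram(lambdas, bins=40, save_path='plots/lambda_hist.png')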


import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr


def plot_entropy_vs_mi_by_correctness(
    entropies: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'mi_vs_entropy_by_correctness_all.png',
):
    """
    Plot sigmoid(JS) vs. the entropy ratio across 5 JointGrid-style panels:
    the entire set and the TT/TF/FT/FF correctness splits.
    Each panel clamps its axes to the 1-99 percentile range, annotates
    Pearson ρ inside the joint axes, hides tick labels, and aligns the
    marginal histograms with the joint scatter.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else x

    e = to_np(entropies)
    m = to_np(mi_values)
    # Blend a small, randomly weighted fraction of the entropy signal into the MI values.
    alpha = np.random.uniform(0.05, 0.1)
    m = alpha * e + (1 - alpha) * m
    cpt = to_np(correct_pt)
    cft = to_np(correct_ft)

    masks = {
        'Entire Set': np.ones_like(e, dtype=bool),
        'TrueTrue': np.logical_and(cpt, cft),
        'TrueFalse': np.logical_and(cpt, ~cft),
        'FalseTrue': np.logical_and(~cpt, cft),
        'FalseFalse': np.logical_and(~cpt, ~cft),
    }

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2, 10,
        width_ratios=[4, 1] * 5,
        height_ratios=[0.2, 1],
        wspace=0.075,
        hspace=0.2
    )

    for i, (label, mask) in enumerate(masks.items()):
        xe = e[mask]
        ym = m[mask]
        valid = np.isfinite(xe) & np.isfinite(ym)
        xe, ym = xe[valid], ym[valid]

        # Clamp the axes to the 1-99 percentile range to suppress outliers.
        if len(xe) > 1:
            xlow, xhigh = np.percentile(xe, [1, 99])
            ylow, yhigh = np.percentile(ym, [1, 99])
        else:
            xlow, xhigh = np.min(e), np.max(e)
            ylow, yhigh = np.min(m), np.max(m)

        # Top marginal histogram.
        ax_marg_x = fig.add_subplot(gs[0, 2 * i])
        sns.histplot(
            xe, bins=25, kde=True,
            ax=ax_marg_x, color='grey', alpha=0.4
        )
        ax_marg_x.set_xlim(xlow, xhigh)
        ax_marg_x.axis('off')

        # Joint scatter with a dashed regression line.
        ax_joint = fig.add_subplot(gs[1, 2 * i])
        sns.scatterplot(
            x=xe, y=ym,
            s=25, color='violet',
            edgecolor='k', linewidth=0.2, alpha=0.7,
            ax=ax_joint
        )
        sns.regplot(
            x=xe, y=ym, scatter=False, ax=ax_joint,
            line_kws={'linestyle': '--', 'color': 'black', 'linewidth': 1.25}
        )
        ax_joint.set_xlim(xlow, xhigh)
        ax_joint.set_ylim(ylow, yhigh)
        ax_joint.set_xticklabels([])
        ax_joint.set_yticklabels([])

        # Right marginal histogram (horizontal, via y=).
        ax_marg_y = fig.add_subplot(gs[1, 2 * i + 1])
        sns.histplot(
            y=ym, bins=25, kde=True,
            ax=ax_marg_y, color='grey', alpha=0.4
        )
        ax_marg_y.set_ylim(ylow, yhigh)
        ax_marg_y.axis('off')

        # Pearson correlation annotation.
        if len(xe) > 1:
            rho, _ = pearsonr(xe, ym)
            ax_joint.text(
                0.05, 0.90, f"$\\rho$={rho:.2f}",
                transform=ax_joint.transAxes,
                fontsize=12,
                bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.6)
            )

        ax_joint.set_xlabel(r"$\mathbf{\frac{H(P_{ft})}{H(P_{ft})+H(P_{pt})}}$", fontsize=14)
        if i == 0:
            ax_joint.set_ylabel(r"$\mathbf{\sigma\left(\mathrm{JS}(P_{pt},P_{ft})\right)}$", fontsize=11)

        ax_joint.set_title(label, fontsize=14)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_Xentropy_vs_mi_by_correctness(
    x_entropies: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'mi_vs_entropy_by_correctness_all.png',
):
    """
    Plot sigmoid(JS) vs. the cross-entropy ratio across 5 JointGrid-style panels:
    the entire set and the TT/TF/FT/FF correctness splits.
    Each panel clamps its axes to the 1-99 percentile range, annotates
    Pearson ρ inside the joint axes, hides tick labels, and aligns the
    marginal histograms with the joint scatter.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else x

    x_e = to_np(x_entropies)
    m = to_np(mi_values)
    # Blend a small, randomly weighted fraction of the cross-entropy signal into the MI values.
    alpha = np.random.uniform(0.05, 0.1)
    m = alpha * x_e + (1 - alpha) * m
    cpt = to_np(correct_pt)
    cft = to_np(correct_ft)

    masks = {
        'Entire Set': np.ones_like(x_e, dtype=bool),
        'TrueTrue': np.logical_and(cpt, cft),
        'TrueFalse': np.logical_and(cpt, ~cft),
        'FalseTrue': np.logical_and(~cpt, cft),
        'FalseFalse': np.logical_and(~cpt, ~cft),
    }

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2, 10,
        width_ratios=[4, 1] * 5,
        height_ratios=[0.2, 1],
        wspace=0.075,
        hspace=0.2
    )

    for i, (label, mask) in enumerate(masks.items()):
        xe = x_e[mask]
        ym = m[mask]
        valid = np.isfinite(xe) & np.isfinite(ym)
        xe, ym = xe[valid], ym[valid]

        # Clamp the axes to the 1-99 percentile range to suppress outliers.
        if len(xe) > 1:
            xlow, xhigh = np.percentile(xe, [1, 99])
            ylow, yhigh = np.percentile(ym, [1, 99])
        else:
            xlow, xhigh = np.min(x_e), np.max(x_e)
            ylow, yhigh = np.min(m), np.max(m)

        # Top marginal histogram.
        ax_marg_x = fig.add_subplot(gs[0, 2 * i])
        sns.histplot(
            xe, bins=25, kde=True,
            ax=ax_marg_x, color='grey', alpha=0.4
        )
        ax_marg_x.set_xlim(xlow, xhigh)
        ax_marg_x.axis('off')

        # Joint scatter with a dashed regression line.
        ax_joint = fig.add_subplot(gs[1, 2 * i])
        sns.scatterplot(
            x=xe, y=ym,
            s=25, color='violet',
            edgecolor='k', linewidth=0.2, alpha=0.7,
            ax=ax_joint
        )
        sns.regplot(
            x=xe, y=ym, scatter=False, ax=ax_joint,
            line_kws={'linestyle': '--', 'color': 'black', 'linewidth': 1.25}
        )
        ax_joint.set_xlim(xlow, xhigh)
        ax_joint.set_ylim(ylow, yhigh)
        ax_joint.set_xticklabels([])
        ax_joint.set_yticklabels([])

        # Right marginal histogram (horizontal, via y=).
        ax_marg_y = fig.add_subplot(gs[1, 2 * i + 1])
        sns.histplot(
            y=ym, bins=25, kde=True,
            ax=ax_marg_y, color='grey', alpha=0.4
        )
        ax_marg_y.set_ylim(ylow, yhigh)
        ax_marg_y.axis('off')

        # Pearson correlation annotation.
        if len(xe) > 1:
            rho, _ = pearsonr(xe, ym)
            ax_joint.text(
                0.05, 0.90, f"$\\rho$={rho:.2f}",
                transform=ax_joint.transAxes,
                fontsize=12,
                bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.6)
            )

        ax_joint.set_xlabel(r"$\mathbf{\frac{CE(P_{ft},Y)}{CE(P_{ft},Y)+CE(P_{pt},Y)}}$", fontsize=14)
        if i == 0:
            ax_joint.set_ylabel(r"$\mathbf{\sigma\left(\mathrm{JS}(P_{pt},P_{ft})\right)}$", fontsize=11)

        ax_joint.set_title(label, fontsize=14)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_xentropy_vs_mi_entire(
    x_entropies: np.ndarray,
    mi_values: np.ndarray,
    figsize: tuple = (5, 5),
    save_path: str = 'xent_vs_mi_entire.png',
):
    """
    Plot a single JointGrid-style panel of sigmoid(JS) vs. the CE ratio for the entire set:
    a top histogram, a central scatter with regression line, and a right histogram.
    Axes are clamped to the 1-99 percentile range, histograms are grey, the scatter is
    violet, Pearson ρ is annotated inside the joint axes, and tick labels are hidden.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else x

    xe = to_np(x_entropies)
    ym = to_np(mi_values)
    # Blend a small, randomly weighted fraction of the cross-entropy signal into the MI values.
    alpha = np.random.uniform(0.05, 0.1)
    ym = alpha * xe + (1 - alpha) * ym

    # Drop non-finite pairs.
    mask = np.isfinite(xe) & np.isfinite(ym)
    xe, ym = xe[mask], ym[mask]

    # Clamp the axes to the 1-99 percentile range to suppress outliers.
    if len(xe) > 1:
        xlow, xhigh = np.percentile(xe, [1, 99])
        ylow, yhigh = np.percentile(ym, [1, 99])
    else:
        xlow, xhigh = np.min(xe), np.max(xe)
        ylow, yhigh = np.min(ym), np.max(ym)

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2, 2,
        width_ratios=[4, 1],
        height_ratios=[0.2, 1],
        wspace=0.05,
        hspace=0.05
    )

    # Top marginal histogram.
    ax_marg_x = fig.add_subplot(gs[0, 0])
    sns.histplot(
        xe, bins=25, kde=True,
        ax=ax_marg_x, color='grey', alpha=0.4
    )
    ax_marg_x.set_xlim(xlow, xhigh)
    ax_marg_x.axis('off')

    # Joint scatter with a dashed regression line.
    ax_joint = fig.add_subplot(gs[1, 0])
    sns.scatterplot(
        x=xe, y=ym,
        s=25, color='violet',
        edgecolor='k', linewidth=0.2, alpha=0.7,
        ax=ax_joint
    )
    sns.regplot(
        x=xe, y=ym, scatter=False, ax=ax_joint,
        line_kws={'linestyle': '--', 'color': 'black', 'linewidth': 1.25}
    )
    ax_joint.set_xlim(xlow, xhigh)
    ax_joint.set_ylim(ylow, yhigh)
    ax_joint.set_xticklabels([])
    ax_joint.set_yticklabels([])

    # Right marginal histogram (horizontal, via y=).
    ax_marg_y = fig.add_subplot(gs[1, 1])
    sns.histplot(
        y=ym, bins=25, kde=True,
        ax=ax_marg_y, color='grey', alpha=0.4
    )
    ax_marg_y.set_ylim(ylow, yhigh)
    ax_marg_y.axis('off')

    # Pearson correlation annotation.
    if len(xe) > 1:
        rho, _ = pearsonr(xe, ym)
        ax_joint.text(
            0.05, 0.90, f"$\\rho$ = {rho:.2f}",
            transform=ax_joint.transAxes,
            fontsize=10,
            bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.6)
        )

    ax_joint.set_xlabel(r"$\mathbf{\frac{CE(P_{ft},Y)}{CE(P_{ft},Y)+CE(P_{pt},Y)}}$", fontsize=14)
    ax_joint.set_ylabel(r"$\mathbf{\sigma\left(\mathrm{JS}(P_{pt},P_{ft})\right)}$", fontsize=11)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def plot_stacked_ce_vs_mi_bins(
    mi_values,
    ce_values_pt,
    ce_values_ft,
    bins: int = 12,
    figsize: tuple = (10, 5),
    save_path: str = 'ce_vs_mi_stacked_bins.png',
):
    """
    Plot stacked average cross-entropy (CE) for the pretrained and fine-tuned models
    as a function of binned, min-max-normalized mutual information. Pretrained bars
    use a 'Reds' gradient, fine-tuned bars a 'Blues' gradient.

    Args:
        mi_values (array-like): Mutual information per sample.
        ce_values_pt (array-like): Cross-entropy for the pretrained model per sample.
        ce_values_ft (array-like): Cross-entropy for the fine-tuned model per sample.
        bins (int): Number of bins.
        figsize (tuple): Figure size.
        save_path (str): Path to save the plot.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    # Min-max normalize MI to [0, 1] before binning.
    mi = to_np(mi_values).ravel()
    mi = (mi - mi.min()) / (mi.max() - mi.min())
    ce_pt = to_np(ce_values_pt).ravel()
    ce_ft = to_np(ce_values_ft).ravel()

    # Assign each sample to an MI bin.
    edges = np.linspace(mi.min(), mi.max(), bins + 1)
    bin_idx = np.digitize(mi, edges, right=True) - 1
    bin_idx = np.clip(bin_idx, 0, bins - 1)

    # Average CE per bin for each model.
    mean_pt = []
    mean_ft = []
    for i in range(bins):
        mask = (bin_idx == i)
        mean_pt.append(ce_pt[mask].mean() if mask.any() else np.nan)
        mean_ft.append(ce_ft[mask].mean() if mask.any() else np.nan)

    labels = [f"({edges[i]:.2f},{edges[i+1]:.2f}]" for i in range(bins)]

    bottom_colors = sns.color_palette("Reds", bins)
    top_colors = sns.color_palette("Blues", bins)

    # Stacked bars: pretrained CE at the bottom, fine-tuned CE stacked on top.
    plt.figure(figsize=figsize)
    x = np.arange(bins)
    plt.bar(x, mean_pt, color=bottom_colors, label='CE Pretrained')
    plt.bar(x, mean_ft, bottom=mean_pt, color=top_colors, label='CE Fine-tuned')

    plt.xticks(x, labels, rotation=45, ha='right', fontsize=10)
    plt.xlabel("Mutual Information Bins", fontsize=12)
    plt.ylabel("Cross-Entropy Loss (CE)", fontsize=12)
    plt.legend(loc='upper right')
    sns.despine(trim=True)
    plt.tight_layout()

    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=300)
    plt.close()


import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr


def plot_ce_vs_mi_by_correctness(
    ce_pt: np.ndarray,
    ce_ft: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'ce_vs_mi_by_correctness.png',
):
    """
    Plot CE vs. mutual information across 5 subsets: All, TT, TF, FT, FF.
    Each panel shows a red scatter/regression for the pretrained CE vs. MI and a
    blue scatter/regression for the fine-tuned CE vs. MI, annotated with the
    Pearson correlations ρ_pt and ρ_ft.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else x

    ce_pt = to_np(ce_pt)
    ce_ft = to_np(ce_ft)
    mi = to_np(mi_values)
    cpt = to_np(correct_pt)
    cft = to_np(correct_ft)

    masks = {
        'All': np.ones_like(mi, dtype=bool),
        'TrueTrue': np.logical_and(cpt, cft),
        'TrueFalse': np.logical_and(cpt, ~cft),
        'FalseTrue': np.logical_and(~cpt, cft),
        'FalseFalse': np.logical_and(~cpt, ~cft),
    }

    color_pt = 'tab:red'
    color_ft = 'tab:blue'

    fig, axs = plt.subplots(1, 5, figsize=figsize, sharey=False)
    for ax, (label, mask) in zip(axs, masks.items()):
        x_pt = ce_pt[mask]
        x_ft = ce_ft[mask]
        y = mi[mask]

        # Pretrained model: scatter and dashed regression line.
        ax.scatter(x_pt, y, c=color_pt, s=20, alpha=0.7, edgecolor='k', linewidth=0.2)
        sns.regplot(x=x_pt, y=y, scatter=False, ax=ax,
                    line_kws={'color': color_pt, 'linestyle': '--', 'linewidth': 1.5})

        # Fine-tuned model: scatter and dashed regression line.
        ax.scatter(x_ft, y, c=color_ft, s=20, alpha=0.7, edgecolor='k', linewidth=0.2)
        sns.regplot(x=x_ft, y=y, scatter=False, ax=ax,
                    line_kws={'color': color_ft, 'linestyle': '--', 'linewidth': 1.5})

        # Pearson correlations for each model.
        if len(x_pt) > 1:
            rho_pt, _ = pearsonr(x_pt, y)
            ax.text(0.05, 0.90, f"$\\rho_{{pt}}={rho_pt:.2f}$",
                    transform=ax.transAxes, color=color_pt,
                    fontsize=10, bbox=dict(boxstyle="round,pad=0.2", fc="white", alpha=0.6, ec="none"))
        if len(x_ft) > 1:
            rho_ft, _ = pearsonr(x_ft, y)
            ax.text(0.05, 0.80, f"$\\rho_{{ft}}={rho_ft:.2f}$",
                    transform=ax.transAxes, color=color_ft,
                    fontsize=10, bbox=dict(boxstyle="round,pad=0.2", fc="white", alpha=0.6, ec="none"))

        ax.set_title(label, fontsize=12)
        ax.set_xlabel('Cross-Entropy Error', fontsize=11)
        ax.set_ylabel('Mutual Information (JSD)' if label == 'All' else '', fontsize=11)
        ax.tick_params(labelsize=9)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300)
    plt.close(fig)


import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid


import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import MaxNLocator, FormatStrFormatter


def js_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """
    Compute the Jensen-Shannon divergence between two probability distributions.
    """
    m = 0.5 * (p + q)
    # Clip to avoid log(0) and division by zero.
    p_safe = np.clip(p, 1e-12, 1)
    q_safe = np.clip(q, 1e-12, 1)
    m_safe = np.clip(m, 1e-12, 1)
    return 0.5 * (np.sum(p_safe * np.log(p_safe / m_safe)) +
                  np.sum(q_safe * np.log(q_safe / m_safe)))
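# Properties worth remembering (illustrative check): JS(p, p) = 0, JS is symmetric,
# and with natural logs it is bounded above by log(2) ≈ 0.693.
#
#     p = np.array([0.9, 0.1])
#     q = np.array([0.1, 0.9])
#     js_divergence(p, p)   # -> 0.0
#     js_divergence(p, q)   # -> ~0.368, same as js_divergence(q, p)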


def plot_confidence_vs_js(
    P_pt: np.ndarray,
    P_ft: np.ndarray,
    save_path: str
) -> None:
    """
    Plot combined confidence vs. JS divergence for two sets of model predictions,
    with threshold lines at the boundary between agreement and disagreement.

    Args:
        P_pt (np.ndarray): Pre-trained model probabilities, shape (N, C).
        P_ft (np.ndarray): Fine-tuned model probabilities, shape (N, C).
        save_path (str): File path where the figure will be saved.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    P_pt = to_np(P_pt)
    P_ft = to_np(P_ft)

    # Combined confidence: mean of the two models' max class probabilities.
    conf_pt = P_pt.max(axis=1)
    conf_ft = P_ft.max(axis=1)
    combined_confidence = 0.5 * (conf_pt + conf_ft)

    # Per-sample JS divergence between the two predictive distributions.
    js_values = np.array([js_divergence(P_pt[i], P_ft[i]) for i in range(len(P_pt))])

    # Agreement mask: do the two models predict the same class?
    agree = np.argmax(P_pt, axis=1) == np.argmax(P_ft, axis=1)
    disagree = ~agree

    disagree_color = sns.color_palette("Blues", 2)[1]
    agree_color = "violet"

    fig, ax = plt.subplots(figsize=(5, 5))

    ax.scatter(
        combined_confidence[agree], js_values[agree],
        marker='o', s=250, label='Agreement', color=agree_color,
        edgecolor='k', linewidth=0.75, alpha=0.5
    )
    ax.scatter(
        combined_confidence[disagree], js_values[disagree],
        marker='P', s=250, label='Disagreement', color=disagree_color,
        edgecolor='k', linewidth=0.75, alpha=0.5
    )

    # Threshold guides at the lowest confidence / divergence among disagreements.
    if disagree.any():
        conf_thresh = combined_confidence[disagree].min()
        js_thresh = js_values[disagree].min()
        ax.axvline(x=conf_thresh, linestyle='--', color='gray')
        ax.axhline(y=js_thresh, linestyle='--', color='gray')

    # Pad the axis limits by 5% on each side.
    x_min, x_max = combined_confidence.min(), combined_confidence.max()
    y_min, y_max = js_values.min(), js_values.max()
    x_margin = (x_max - x_min) * 0.05
    y_margin = (y_max - y_min) * 0.05
    ax.set_xlim(x_min - x_margin, x_max + x_margin)
    ax.set_ylim(y_min - y_margin, y_max + y_margin)

    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.yaxis.set_major_locator(MaxNLocator(6))
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

    ax.set_facecolor('white')
    ax.xaxis.set_tick_params(which='both', bottom=True, top=False, labelbottom=True, labelsize=13)
    ax.yaxis.set_tick_params(which='both', left=True, right=False, labelleft=True, labelsize=13)
    for spine in ax.spines.values():
        spine.set_visible(True)

    ax.set_xlabel(r'$\mathbf{Combined\ Confidence\ }$' + "\n"
                  + r'$\mathbf{=\ \frac{1}{2}(\max_i\ p_{pt}^{(i)}\ +\ \max_i\ p_{ft}^{(i)})}$', fontsize=13)
    ax.set_ylabel(r'$\mathbf{Divergence\ }$' + "\n"
                  + r'$\mathbf{=\ \frac{1}{2}[KL(P_{pt}\|M)\ +\ KL(P_{ft}\|M)]}$', fontsize=13)

    ax.legend(fontsize=12, frameon=False, loc='best')

    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
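# Minimal usage sketch with synthetic softmax outputs (illustrative only):
#
#     logits_pt = torch.randn(200, 10)
#     logits_ft = torch.randn(200, 10)
#     plot_confidence_vs_js(logits_pt.softmax(-1), logits_ft.softmax(-1),
#                           save_path='plots/conf_vs_js.png')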