| import os |
| import time |
| import random |
|
|
| import numpy as np |
|
|
| import shutil |
| from enum import Enum |
|
|
| import torch |
| import torchvision.transforms as transforms |
| |
|
|
|
|
def set_random_seed(seed):
    """Seed every random number generator used by this module.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's CPU generator and PyTorch's CUDA
    generators (all devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
class Summary(Enum):
    """Selects which statistic a meter reports in its summary string."""

    NONE = 0      # report nothing
    AVERAGE = 1   # report the running mean
    SUM = 2       # report the running total
    COUNT = 3     # report the number of samples seen
|
|
class AverageMeter(object):
    """Track the latest value, running sum, sample count and mean of a metric."""

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Reset every tracked statistic to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted as ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Build e.g. '{name} {val:6.2f} ({avg:6.2f})' from the stored format spec.
        template = ''.join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(**self.__dict__)

    def summary(self):
        """Return the summary string selected by ``summary_type``.

        Raises:
            ValueError: if ``summary_type`` is not a ``Summary`` member.
        """
        templates = (
            (Summary.NONE, ''),
            (Summary.AVERAGE, '{name} {avg:.3f}'),
            (Summary.SUM, '{name} {sum:.3f}'),
            (Summary.COUNT, '{name} {count:.3f}'),
        )
        for kind, template in templates:
            if self.summary_type is kind:
                return template.format(**self.__dict__)
        raise ValueError('invalid summary type %r' % self.summary_type)
|
|
|
|
class ProgressMeter(object):
    """Print a group of meters, either per batch or as one summary line."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the batch counter and every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *[str(meter) for meter in self.meters]]))

    def display_summary(self):
        """Print every meter's summary string on one space-separated line."""
        parts = [" *"]
        parts.extend([meter.summary() for meter in self.meters])
        print(' '.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running batch index to the width of the total batch count,
        # producing a template such as '[{:3d}/100]'.
        width = len(str(num_batches // 1))
        slot = '{:' + str(width) + 'd}'
        return '[{}/{}]'.format(slot, slot.format(num_batches))
|
|
|
|
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: 2-D score tensor indexed as (sample, class).
        target: 1-D tensor of ground-truth class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per requested k, holding the
        percentage of samples whose target is among the k highest scores.
    """
    with torch.no_grad():
        # Request the largest k that is needed, but never more predictions
        # than there are classes (topk would raise otherwise).
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        # pred becomes (maxk, batch): row r holds each sample's (r+1)-th guess.
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # A sample is a top-k hit if any of its first k guesses is correct.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
| |
| from sklearn.metrics import precision_score, recall_score, f1_score |
def macro_prf(output, target):
    """
    Return ``[precision, recall, f1]``, macro-averaged and in percent.

    Predictions are the arg-max of ``output`` along dimension 1; classes
    with an undefined score contribute 0 (``zero_division=0``).
    """
    predicted = output.argmax(dim=1).cpu().numpy()
    actual = target.cpu().numpy()

    scorers = (precision_score, recall_score, f1_score)
    return [
        scorer(actual, predicted, average='macro', zero_division=0) * 100
        for scorer in scorers
    ]
|
|
def load_model_weight(load_path, model, device, args):
    """Restore model weights and the starting epoch from a checkpoint file.

    Args:
        load_path: path to a checkpoint holding ``'state_dict'`` and ``'epoch'``.
        model: module to load into. If loading into ``model`` itself fails,
            the weights are loaded non-strictly into ``model.prompt_generator``.
        device: ``map_location`` passed to ``torch.load``.
        args: namespace; its ``start_epoch`` attribute is set from the checkpoint.

    If ``load_path`` is not a file, a message is printed and nothing is loaded.
    """
    if not os.path.isfile(load_path):
        print("=> no checkpoint found at '{}'".format(load_path))
        return

    print("=> loading checkpoint '{}'".format(load_path))
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(load_path, map_location=device)
    state_dict = checkpoint['state_dict']

    # These entries are dropped before loading, as in the original code.
    for key in ("token_prefix", "token_suffix"):
        if key in state_dict:
            del state_dict[key]

    args.start_epoch = checkpoint['epoch']

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Key/shape mismatch with the full model: fall back to loading
        # whatever matches into the prompt generator only.
        model.prompt_generator.load_state_dict(state_dict, strict=False)

    print("=> loaded checkpoint '{}' (epoch {})"
          .format(load_path, checkpoint['epoch']))
    del checkpoint
    torch.cuda.empty_cache()
|
|
|
|
def validate(val_loader, model, criterion, args, output_mask=None):
    """Evaluate ``model`` on ``val_loader`` and return the mean top-1 accuracy.

    Args:
        val_loader: sized iterable yielding ``(images, target)`` batches.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss called as ``criterion(output, target)``.
        args: namespace providing ``gpu`` and ``print_freq``.
        output_mask: optional index (list, array or tensor) applied to the
            class dimension of the model output before loss and accuracy.
            ``None`` or an empty mask leaves the output untouched.

    Returns:
        The running average of the top-1 accuracy (``top1.avg``).
    """
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # Truthiness of a multi-element tensor/ndarray is ambiguous and raises,
    # so test for "mask provided and non-empty" explicitly.
    use_mask = output_mask is not None and len(output_mask) > 0

    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # Forward pass and loss under mixed precision.
            with torch.cuda.amp.autocast():
                output = model(images)
                if use_mask:
                    output = output[:, output_mask]
                loss = criterion(output, target)

            # Update the running statistics, weighted by batch size.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
        progress.display_summary()

    return top1.avg
|
|
|
|
| import matplotlib.pyplot as plt |
def plot_img(image, save_path='saved_plot.png', target=None, predicted=None):
    """Min-max normalise an image and save it as an axis-free 3x3 inch figure.

    Args:
        image: a tensor (squeezed, then permuted from channel-first to
            channel-last) or an array that is already in a layout
            ``plt.imshow`` accepts.
        save_path: where the figure is written.
        target: unused; kept for interface compatibility.
        predicted: unused; kept for interface compatibility.
    """
    # isinstance also covers Tensor subclasses (e.g. nn.Parameter).
    if isinstance(image, torch.Tensor):
        image_array = image.detach().to('cpu').squeeze().permute(1, 2, 0).numpy()
    else:
        image_array = image

    lowest = image_array.min()
    value_range = image_array.max() - lowest
    if value_range > 0:
        image_array = (image_array - lowest) / value_range
    else:
        # A constant image would otherwise divide by zero and become NaN.
        image_array = np.zeros_like(image_array, dtype=float)

    plt.figure(figsize=(3, 3), tight_layout=True)
    plt.imshow(image_array)
    plt.axis('off')
    plt.savefig(save_path)
    plt.close()
|
|
| from torchvision.transforms import ToPILImage |
| from PIL import Image |
# Shared converter reused by every plot_pil_img call.
to_pil = ToPILImage()


def plot_pil_img(image, save_path='saved_plot.png'):
    """Save ``image`` to ``save_path``, converting it to a PIL image if needed."""
    pil_image = image if isinstance(image, Image.Image) else to_pil(image)
    pil_image.save(save_path)
|
|
| import seaborn as sns |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from scipy.stats import pearsonr |
|
|
def plot_entropy_vs_mi(
    entropies: np.ndarray,
    mi_values: np.ndarray,
    agreement_diff: np.ndarray = None,
    entropy_thresh: float = None,
    mi_thresh: float = None,
    figsize: tuple = (4.5, 4.5),
    save_path: str = 'mi_vs_entropy.png',
):
    """
    Plot MI vs. Predictive Entropy with optional coloring by agreement.

    Array arguments may be NumPy arrays or tensors; tensors are detached
    and moved to the CPU before plotting.

    Args:
        entropies (np.ndarray): Consensus predictive entropy values (x-axis).
        mi_values (np.ndarray): Mutual information values (y-axis).
        agreement_diff (np.ndarray, optional): Per-sample value used to
            color the scatter points.
        entropy_thresh (float, optional): Vertical threshold line.
        mi_thresh (float, optional): Horizontal threshold line.
        figsize (tuple): Only ``figsize[0]`` is used, as the JointGrid height.
        save_path (str): Where to save the figure.
    """
    def _as_numpy(x):
        # Tensors expose .cpu(); everything else is treated as array-like.
        if hasattr(x, 'cpu'):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    entropies = _as_numpy(entropies)
    mi_values = _as_numpy(mi_values)
    if agreement_diff is not None:
        agreement_diff = _as_numpy(agreement_diff)

    corr, _ = pearsonr(entropies, mi_values)

    g = sns.JointGrid(
        x=entropies,
        y=mi_values,
        height=figsize[0],
        ratio=4,
        space=0.15
    )

    # Joint scatter, colored by agreement_diff when it is provided.
    if agreement_diff is not None:
        cmap = sns.color_palette("coolwarm", as_cmap=True)
        g.plot_joint(
            sns.scatterplot,
            hue=agreement_diff,
            palette=cmap,
            s=18,
            linewidth=0.3,
            edgecolor="black",
            alpha=0.8
        )
        legend = g.ax_joint.get_legend()
        if legend is not None:
            legend.remove()
    else:
        g.plot_joint(sns.scatterplot, s=20, color='tab:blue', alpha=0.7)

    # Marginal histograms with KDE overlays.
    g.plot_marginals(sns.histplot, kde=True, color='grey', alpha=0.5)

    # Dashed regression line over the joint axes.
    sns.regplot(
        x=entropies,
        y=mi_values,
        scatter=False,
        ax=g.ax_joint,
        color='black',
        line_kws={"linestyle": "--", "linewidth": 1}
    )

    if entropy_thresh is not None:
        g.ax_joint.axvline(entropy_thresh, ls='--', color='grey', lw=1)
    if mi_thresh is not None:
        g.ax_joint.axhline(mi_thresh, ls='--', color='grey', lw=1)

    # Annotate the low-entropy / high-MI corner of the data.
    x_text = np.percentile(entropies, 5)
    y_text = np.percentile(mi_values, 95)
    g.ax_joint.text(x_text, y_text, 'High MI\nLow Entropy',
                    fontsize=10, fontweight='bold', color='black')

    g.set_axis_labels('Self-Entropy', 'Mutual Information', fontsize=11)
    g.ax_joint.set_title(f'Pearson ρ = {corr:.2f}', fontsize=12)
    g.ax_joint.tick_params(labelsize=9)

    plt.tight_layout()
    if os.path.dirname(save_path):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path, dpi=300)
    plt.close()
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import seaborn as sns |
|
|
# Display names of the compared methods, keyed by their result-dict keys.
method_names = dict(
    model_ensemble='Model Ensemble',
    wise_ft='Model Souping',
    tcube='Entropy-based',
    tcube_MI_bmm='Mutual Information',
)
|
|
def plot_delta_performance(
    dyn_v_stat_plot: dict,
    dyn_key: str = 'tcube_MI_bmm',
    figsize: tuple = (3, 3),
    save_path: str = 'delta_performance.png'
):
    """
    Bar-plot the margin of one method over the best of the other methods.

    For every entry of ``dyn_v_stat_plot['conditions']`` the bar height is
    ``dyn_v_stat_plot[dyn_key]`` minus the element-wise maximum over all
    other keys listed in ``method_names``.

    Returns:
        fig, ax (the figure has already been saved and closed).
    """
    sns.set_style('white')
    n_conditions = len(np.array(dyn_v_stat_plot['conditions']))

    fig, ax = plt.subplots(1, 1, figsize=figsize, constrained_layout=True)

    # Difference between the chosen method and the strongest alternative.
    dynamic_scores = np.array(dyn_v_stat_plot[dyn_key])
    baseline_rows = [
        dyn_v_stat_plot[name] for name in method_names if name != dyn_key
    ]
    best_baseline = np.vstack(baseline_rows).max(axis=0)
    delta = dynamic_scores - best_baseline

    bar_colors = sns.color_palette("rocket", n_colors=len(delta))
    positions = np.arange(n_conditions)
    ax.bar(
        x=positions,
        height=delta,
        width=1.0,
        color=bar_colors,
        linewidth=0,
        edgecolor=None,
        alpha=0.85,
    )
    ax.axhline(0, color='grey', linewidth=1)
    ax.set_ylabel(r'$\Delta$ (%)', fontsize=10)
    ax.set_xlabel('Distribution Shifts', fontsize=10)

    # One unlabeled tick per condition.
    ax.set_xticks(positions)
    ax.set_xticklabels([''] * n_conditions)
    ax.tick_params(axis='x', length=3, width=1)
    ax.tick_params(axis='y', labelsize=9)

    for side, shown in (('top', False), ('right', False),
                        ('left', True), ('bottom', True)):
        ax.spines[side].set_visible(shown)
    ax.grid(False)

    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    return fig, ax
|
|
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| import torch |
|
|
def plot_lambda_histogram(
    lambda_dict: dict,
    bins: int = 50,
    figsize: tuple = (3, 3),
    save_path: str = None
):
    """
    Plot a single-condition histogram of sample-wise interpolation coefficients.

    Bars are colored with a "Blues" gradient, a dashed black KDE curve is
    overlaid, ticks point outward, all four spines are shown and there is
    no grid. The figure is shown with ``plt.show()`` before returning.

    Args:
        lambda_dict (dict): one-entry dict e.g. {'clean': tensor([...])}
        bins (int): number of histogram bins
        figsize (tuple): figure size in inches (w, h)
        save_path (str): optional path to save the figure; nothing is
            written when it is None

    Returns:
        fig, ax

    Raises:
        ValueError: if ``lambda_dict`` does not have exactly one entry, its
            value is not a ``torch.Tensor``, or the tensor is empty.
    """
    if len(lambda_dict) != 1:
        raise ValueError("lambda_dict must contain exactly one key.")
    condition, data = next(iter(lambda_dict.items()))
    if not isinstance(data, torch.Tensor):
        raise ValueError(f"lambda_dict['{condition}'] must be a torch.Tensor")

    values = data.detach().cpu().numpy().ravel()
    if values.size == 0:
        raise ValueError(f"lambda_dict['{condition}'] must not be empty")

    sns.set_style("white")
    fig, ax = plt.subplots(figsize=figsize)

    # One shade per bin, light to dark.
    cm = sns.color_palette("Blues", n_colors=bins)

    plot = sns.histplot(
        values,
        bins=bins,
        ax=ax,
        edgecolor=None,
        alpha=0.85,
        kde=True,
        linewidth=0
    )
    # Restyle the KDE curve, when one was drawn.
    if plot.lines:
        plot.lines[0].set_color('black')
        plot.lines[0].set_linestyle('--')
        plot.lines[0].set_linewidth(0.5)

    for bin_, shade in zip(plot.patches, cm):
        bin_.set_facecolor(shade)

    ax.set_xlabel("Coefficient", fontsize=9)
    ax.set_ylabel("Frequency", fontsize=9)

    # Six evenly spaced ticks spanning the observed value range.
    ax.set_xticks(np.round(np.linspace(values.min(), values.max(), num=6), 2))
    ax.tick_params(axis='x', labelsize=8)
    ax.tick_params(
        axis='x', which='both',
        bottom=True, top=False,
        length=4, direction='out'
    )
    ax.tick_params(
        axis='y', which='both',
        left=True, right=False,
        length=4, direction='out',
        labelsize=8
    )

    for spine in ['top', 'right', 'bottom', 'left']:
        ax.spines[spine].set_visible(True)

    plt.tight_layout()
    # save_path defaults to None, in which case the figure is only shown.
    if save_path is not None:
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()
    return fig, ax
|
|
| import os |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from scipy.stats import pearsonr |
|
|
def plot_entropy_vs_mi_by_correctness(
    entropies: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'mi_vs_entropy_by_correctness_all.png',
):
    """
    Plot MI-derived values against entropies in 5 JointGrid-style panels:
    the entire set plus the four correctness splits of ``correct_pt`` and
    ``correct_ft`` (TrueTrue / TrueFalse / FalseTrue / FalseFalse).

    Each panel limits its axes to the 1-99 percentile range of its data,
    shows grey marginal histograms, a violet scatter with a dashed
    regression line, Pearson rho inside the joint axes and no tick labels.
    Inputs may be NumPy arrays or tensors. The figure is saved to
    ``save_path`` and closed.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    e = to_np(entropies)
    m = to_np(mi_values)
    # NOTE(review): the plotted y-values are NOT the raw `mi_values`; a random
    # 5-10% share of the x-values is mixed in first. This couples y to x
    # (which affects the reported Pearson rho) and makes the figure differ
    # between runs unless NumPy's global RNG is seeded. Confirm this is intended.
    alpha = np.random.uniform(0.05, 0.1)
    m = alpha * e + (1 - alpha) * m
    # Cast to bool so `~` below is a logical NOT. On integer 0/1 arrays it is
    # a bitwise NOT (-1 / -2), which is always truthy and breaks the splits.
    cpt = to_np(correct_pt).astype(bool)
    cft = to_np(correct_ft).astype(bool)

    masks = {
        'Entire Set': np.ones_like(e, dtype=bool),
        'TrueTrue': np.logical_and(cpt, cft),
        'TrueFalse': np.logical_and(cpt, ~cft),
        'FalseTrue': np.logical_and(~cpt, cft),
        'FalseFalse': np.logical_and(~cpt, ~cft),
    }

    # Two grid columns per panel: the joint axes and its right-hand marginal.
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2, 10,
        width_ratios=[4, 1] * 5,
        height_ratios=[0.2, 1],
        wspace=0.075,
        hspace=0.2
    )

    for i, (label, mask) in enumerate(masks.items()):
        xe = e[mask]
        ym = m[mask]
        valid = np.isfinite(xe) & np.isfinite(ym)
        xe, ym = xe[valid], ym[valid]

        # Axis limits: trim outliers, or fall back to the global range when
        # the subset is too small for percentiles.
        if len(xe) > 1:
            xlow, xhigh = np.percentile(xe, [1, 99])
            ylow, yhigh = np.percentile(ym, [1, 99])
        else:
            xlow, xhigh = np.min(e), np.max(e)
            ylow, yhigh = np.min(m), np.max(m)

        # Top marginal (x distribution).
        ax_marg_x = fig.add_subplot(gs[0, 2*i])
        sns.histplot(
            xe, bins=25, kde=True,
            ax=ax_marg_x, color='grey', alpha=0.4
        )
        ax_marg_x.set_xlim(xlow, xhigh)
        ax_marg_x.axis('off')

        # Joint scatter with regression line.
        ax_joint = fig.add_subplot(gs[1, 2*i])
        sns.scatterplot(
            x=xe, y=ym,
            s=25, color='violet',
            edgecolor='k', linewidth=0.2, alpha=0.7,
            ax=ax_joint
        )
        sns.regplot(
            x=xe, y=ym, scatter=False, ax=ax_joint,
            line_kws={'linestyle': '--', 'color': 'black', 'linewidth': 1.25}
        )
        ax_joint.set_xlim(xlow, xhigh)
        ax_joint.set_ylim(ylow, yhigh)
        ax_joint.set_xticklabels([])
        ax_joint.set_yticklabels([])

        # Right marginal (y distribution).
        ax_marg_y = fig.add_subplot(gs[1, 2*i+1])
        sns.histplot(
            y=ym, bins=25, kde=True,
            ax=ax_marg_y, color='grey', alpha=0.4,
            orientation='horizontal'
        )
        ax_marg_y.set_ylim(ylow, yhigh)
        ax_marg_y.axis('off')

        if len(xe) > 1:
            rho, _ = pearsonr(xe, ym)
            ax_joint.text(
                0.05, 0.90, f"$\\rho$={rho:.2f}",
                transform=ax_joint.transAxes,
                fontsize=12,
                bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.6)
            )

        ax_joint.set_xlabel(r"$\mathbf{\frac{H(P_{ft})}{H(P_{ft})+H(P_{pt})}}$", fontsize=14)
        # Only the left-most panel carries the shared y-axis label.
        if i == 0:
            ax_joint.set_ylabel(r"$\mathbf{\sigma\left(\mathrm{JS}(P_{pt},P_{ft})\right)}$", fontsize=11)

        ax_joint.set_title(label, fontsize=14)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
|
def plot_Xentropy_vs_mi_by_correctness(
    x_entropies: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'mi_vs_entropy_by_correctness_all.png',
):
    """
    Plot MI-derived values against cross-entropy-based x-values in 5
    JointGrid-style panels: the entire set plus the four correctness splits
    of ``correct_pt`` and ``correct_ft`` (TrueTrue / TrueFalse / FalseTrue /
    FalseFalse).

    Each panel limits its axes to the 1-99 percentile range of its data,
    shows grey marginal histograms, a violet scatter with a dashed
    regression line, Pearson rho inside the joint axes and no tick labels.
    Inputs may be NumPy arrays or tensors. The figure is saved to
    ``save_path`` and closed.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    x_e = to_np(x_entropies)
    m = to_np(mi_values)
    # NOTE(review): the plotted y-values are NOT the raw `mi_values`; a random
    # 5-10% share of the x-values is mixed in first. This couples y to x
    # (which affects the reported Pearson rho) and makes the figure differ
    # between runs unless NumPy's global RNG is seeded. Confirm this is intended.
    alpha = np.random.uniform(0.05, 0.1)
    m = alpha * x_e + (1 - alpha) * m
    # Cast to bool so `~` below is a logical NOT. On integer 0/1 arrays it is
    # a bitwise NOT (-1 / -2), which is always truthy and breaks the splits.
    cpt = to_np(correct_pt).astype(bool)
    cft = to_np(correct_ft).astype(bool)

    masks = {
        'Entire Set': np.ones_like(x_e, dtype=bool),
        'TrueTrue': np.logical_and(cpt, cft),
        'TrueFalse': np.logical_and(cpt, ~cft),
        'FalseTrue': np.logical_and(~cpt, cft),
        'FalseFalse': np.logical_and(~cpt, ~cft),
    }

    # Two grid columns per panel: the joint axes and its right-hand marginal.
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2, 10,
        width_ratios=[4, 1] * 5,
        height_ratios=[0.2, 1],
        wspace=0.075,
        hspace=0.2
    )

    for i, (label, mask) in enumerate(masks.items()):
        xe = x_e[mask]
        ym = m[mask]
        valid = np.isfinite(xe) & np.isfinite(ym)
        xe, ym = xe[valid], ym[valid]

        # Axis limits: trim outliers, or fall back to the global range when
        # the subset is too small for percentiles.
        if len(xe) > 1:
            xlow, xhigh = np.percentile(xe, [1, 99])
            ylow, yhigh = np.percentile(ym, [1, 99])
        else:
            xlow, xhigh = np.min(x_e), np.max(x_e)
            ylow, yhigh = np.min(m), np.max(m)

        # Top marginal (x distribution).
        ax_marg_x = fig.add_subplot(gs[0, 2*i])
        sns.histplot(
            xe, bins=25, kde=True,
            ax=ax_marg_x, color='grey', alpha=0.4
        )
        ax_marg_x.set_xlim(xlow, xhigh)
        ax_marg_x.axis('off')

        # Joint scatter with regression line.
        ax_joint = fig.add_subplot(gs[1, 2*i])
        sns.scatterplot(
            x=xe, y=ym,
            s=25, color='violet',
            edgecolor='k', linewidth=0.2, alpha=0.7,
            ax=ax_joint
        )
        sns.regplot(
            x=xe, y=ym, scatter=False, ax=ax_joint,
            line_kws={'linestyle': '--', 'color': 'black', 'linewidth': 1.25}
        )
        ax_joint.set_xlim(xlow, xhigh)
        ax_joint.set_ylim(ylow, yhigh)
        ax_joint.set_xticklabels([])
        ax_joint.set_yticklabels([])

        # Right marginal (y distribution).
        ax_marg_y = fig.add_subplot(gs[1, 2*i+1])
        sns.histplot(
            y=ym, bins=25, kde=True,
            ax=ax_marg_y, color='grey', alpha=0.4,
            orientation='horizontal'
        )
        ax_marg_y.set_ylim(ylow, yhigh)
        ax_marg_y.axis('off')

        if len(xe) > 1:
            rho, _ = pearsonr(xe, ym)
            ax_joint.text(
                0.05, 0.90, f"$\\rho$={rho:.2f}",
                transform=ax_joint.transAxes,
                fontsize=12,
                bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.6)
            )

        ax_joint.set_xlabel(r"$\mathbf{\frac{CE(P_{ft},Y)}{CE(P_{ft},Y)+CE(P_{pt},Y)}}$", fontsize=14)
        # Only the left-most panel carries the shared y-axis label.
        if i == 0:
            ax_joint.set_ylabel(r"$\mathbf{\sigma\left(\mathrm{JS}(P_{pt},P_{ft})\right)}$", fontsize=11)

        ax_joint.set_title(label, fontsize=14)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
| |
def plot_xentropy_vs_mi_entire(
    x_entropies: np.ndarray,
    mi_values: np.ndarray,
    figsize: tuple = (5, 5),
    save_path: str = 'xent_vs_mi_entire.png',
):
    """
    Draw one JointGrid-style panel for the whole data set.

    Layout: histogram of the x-values on top, scatter plus dashed regression
    line in the centre, histogram of the y-values on the right. Axes are
    limited to the 1-99 percentile range, histograms are grey, the scatter
    is violet, Pearson rho is printed inside the joint axes and tick labels
    are hidden. The figure is saved to ``save_path`` and closed.
    """
    def to_np(x):
        if hasattr(x, 'cpu'):
            return x.cpu().numpy()
        return x

    xs = to_np(x_entropies)
    ys = to_np(mi_values)
    # NOTE(review): y is blended with a random 5-10% share of x before
    # plotting, so the plotted values are not the raw `mi_values` and vary
    # between runs unless NumPy's RNG is seeded. Confirm this is intended.
    blend = np.random.uniform(0.05, 0.1)
    ys = blend * xs + (1 - blend) * ys

    # Keep only pairs where both coordinates are finite.
    finite = np.isfinite(xs) & np.isfinite(ys)
    xs, ys = xs[finite], ys[finite]

    enough_points = len(xs) > 1
    if enough_points:
        x_lo, x_hi = np.percentile(xs, [1, 99])
        y_lo, y_hi = np.percentile(ys, [1, 99])
    else:
        x_lo, x_hi = np.min(xs), np.max(xs)
        y_lo, y_hi = np.min(ys), np.max(ys)

    fig = plt.figure(figsize=figsize)
    grid = fig.add_gridspec(
        2, 2,
        width_ratios=[4, 1],
        height_ratios=[0.2, 1],
        wspace=0.05,
        hspace=0.05
    )
    hist_style = dict(bins=25, kde=True, color='grey', alpha=0.4)

    # Top marginal.
    ax_top = fig.add_subplot(grid[0, 0])
    sns.histplot(xs, ax=ax_top, **hist_style)
    ax_top.set_xlim(x_lo, x_hi)
    ax_top.axis('off')

    # Central scatter and regression line.
    ax_main = fig.add_subplot(grid[1, 0])
    sns.scatterplot(
        x=xs, y=ys,
        s=25, color='violet',
        edgecolor='k', linewidth=0.2, alpha=0.7,
        ax=ax_main
    )
    sns.regplot(
        x=xs, y=ys, scatter=False, ax=ax_main,
        line_kws={'linestyle': '--', 'color': 'black', 'linewidth': 1.25}
    )
    ax_main.set_xlim(x_lo, x_hi)
    ax_main.set_ylim(y_lo, y_hi)
    ax_main.set_xticklabels([])
    ax_main.set_yticklabels([])

    # Right marginal.
    ax_right = fig.add_subplot(grid[1, 1])
    sns.histplot(y=ys, ax=ax_right, orientation='horizontal', **hist_style)
    ax_right.set_ylim(y_lo, y_hi)
    ax_right.axis('off')

    if enough_points:
        rho, _ = pearsonr(xs, ys)
        ax_main.text(
            0.05, 0.90, f"$\\rho$ = {rho:.2f}",
            transform=ax_main.transAxes,
            fontsize=10,
            bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.6)
        )

    ax_main.set_xlabel(r"$\mathbf{\frac{CE(P_{ft},Y)}{CE(P_{ft},Y)+CE(P_{pt},Y)}}$", fontsize=14)
    ax_main.set_ylabel(r"$\mathbf{\sigma\left(\mathrm{JS}(P_{pt},P_{ft})\right)}$", fontsize=11)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
|
|
| import os |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
|
|
def plot_stacked_ce_vs_mi_bins(
    mi_values,
    ce_values_pt,
    ce_values_ft,
    bins: int = 12,
    figsize: tuple = (10, 5),
    save_path: str = 'ce_vs_mi_stacked_bins.png',
):
    """
    Plot stacked average cross-entropy CE for pretrained and fine-tuned models
    as a function of binned Mutual Information.

    MI is min-max normalised to [0, 1] and split into ``bins`` equal-width
    bins. Pretrained bars use a "Reds" palette and the fine-tuned bars
    stacked on top use "Blues". Bins without samples have NaN height.

    Args:
        mi_values (array-like): Mutual information per sample.
        ce_values_pt (array-like): Cross-entropy for pretrained model per sample.
        ce_values_ft (array-like): Cross-entropy for fine-tuned model per sample.
        bins (int): Number of bins.
        figsize (tuple): Figure size.
        save_path (str): Path to save the plot.

    Raises:
        ValueError: if ``mi_values`` is empty.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    mi = to_np(mi_values).ravel()
    if mi.size == 0:
        raise ValueError("mi_values must contain at least one sample.")
    lowest = mi.min()
    span = mi.max() - lowest
    if span > 0:
        mi = (mi - lowest) / span
    else:
        # All MI values identical: avoid 0/0 and put every sample at 0.
        mi = np.zeros_like(mi, dtype=float)
    ce_pt = to_np(ce_values_pt).ravel()
    ce_ft = to_np(ce_values_ft).ravel()

    # Right-closed, equal-width bins over the normalised range; the clip
    # folds the minimum (index -1) into the first bin.
    edges = np.linspace(0.0, 1.0, bins + 1)
    bin_idx = np.digitize(mi, edges, right=True) - 1
    bin_idx = np.clip(bin_idx, 0, bins - 1)

    # Per-bin mean CE; NaN marks an empty bin.
    mean_pt = []
    mean_ft = []
    for i in range(bins):
        mask = (bin_idx == i)
        has_samples = mask.any()
        mean_pt.append(ce_pt[mask].mean() if has_samples else np.nan)
        mean_ft.append(ce_ft[mask].mean() if has_samples else np.nan)

    labels = [f"({edges[i]:.2f},{edges[i+1]:.2f}]" for i in range(bins)]

    bottom_colors = sns.color_palette("Reds", bins)
    top_colors = sns.color_palette("Blues", bins)

    plt.figure(figsize=figsize)
    x = np.arange(bins)
    plt.bar(x, mean_pt, color=bottom_colors, label='CE Pretrained')
    plt.bar(x, mean_ft, bottom=mean_pt, color=top_colors, label='CE Fine-tuned')

    plt.xticks(x, labels, rotation=45, ha='right', fontsize=10)
    plt.xlabel("Mutual Information Bins", fontsize=12)
    plt.ylabel("Cross-Entropy Loss (CE)", fontsize=12)
    plt.legend(loc='upper right')
    sns.despine(trim=True)
    plt.tight_layout()

    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=300)
    plt.close()
|
|
| import os |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from scipy.stats import pearsonr |
|
|
def plot_ce_vs_mi_by_correctness(
    ce_pt: np.ndarray,
    ce_ft: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'ce_vs_mi_by_correctness.png',
):
    """
    Plot CE vs. Mutual Information across 5 subsets: All, TT, TF, FT, FF.
    For each panel: red scatter/regression for pretrained CE vs. MI,
    blue scatter/regression for fine-tuned CE vs. MI. Annotate Pearson ρ_pt and ρ_ft.

    Inputs may be NumPy arrays or tensors. The correctness flags are cast to
    bool before the subsets are built. The figure is saved to ``save_path``
    and closed.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    ce_pt = to_np(ce_pt)
    ce_ft = to_np(ce_ft)
    mi = to_np(mi_values)
    # Cast to bool so `~` below is a logical NOT. On integer 0/1 arrays it is
    # a bitwise NOT (-1 / -2), which is always truthy and breaks the splits.
    cpt = to_np(correct_pt).astype(bool)
    cft = to_np(correct_ft).astype(bool)

    masks = {
        'All': np.ones_like(mi, dtype=bool),
        'TrueTrue': np.logical_and(cpt, cft),
        'TrueFalse': np.logical_and(cpt, ~cft),
        'FalseTrue': np.logical_and(~cpt, cft),
        'FalseFalse': np.logical_and(~cpt, ~cft),
    }

    color_pt = 'tab:red'
    color_ft = 'tab:blue'

    fig, axs = plt.subplots(1, 5, figsize=figsize, sharey=False)
    for ax, (label, mask) in zip(axs, masks.items()):
        x_pt = ce_pt[mask]
        x_ft = ce_ft[mask]
        y = mi[mask]

        # Pretrained model: scatter and regression line.
        ax.scatter(x_pt, y, c=color_pt, s=20, alpha=0.7, edgecolor='k', linewidth=0.2)
        sns.regplot(x=x_pt, y=y, scatter=False, ax=ax,
                    line_kws={'color': color_pt, 'linestyle': '--', 'linewidth': 1.5})

        # Fine-tuned model: scatter and regression line.
        ax.scatter(x_ft, y, c=color_ft, s=20, alpha=0.7, edgecolor='k', linewidth=0.2)
        sns.regplot(x=x_ft, y=y, scatter=False, ax=ax,
                    line_kws={'color': color_ft, 'linestyle': '--', 'linewidth': 1.5})

        # Pearson correlation needs at least two points.
        if len(x_pt) > 1:
            rho_pt, _ = pearsonr(x_pt, y)
            ax.text(0.05, 0.90, f"$\\rho_{{pt}}={rho_pt:.2f}$",
                    transform=ax.transAxes, color=color_pt,
                    fontsize=10, bbox=dict(boxstyle="round,pad=0.2", fc="white", alpha=0.6, ec="none"))
        if len(x_ft) > 1:
            rho_ft, _ = pearsonr(x_ft, y)
            ax.text(0.05, 0.80, f"$\\rho_{{ft}}={rho_ft:.2f}$",
                    transform=ax.transAxes, color=color_ft,
                    fontsize=10, bbox=dict(boxstyle="round,pad=0.2", fc="white", alpha=0.6, ec="none"))

        ax.set_title(label, fontsize=12)
        ax.set_xlabel('Cross-Entropy Error', fontsize=11)
        # Only the 'All' panel carries the y-axis label.
        if label == 'All':
            ax.set_ylabel('Mutual Information (JSD)', fontsize=11)
        else:
            ax.set_ylabel('')

        ax.tick_params(labelsize=9)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300)
    plt.close(fig)
|
|
|
|
| import torch |
| import matplotlib.pyplot as plt |
| from torchvision.utils import make_grid |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
|
|
| import os |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from matplotlib.ticker import MaxNLocator, FormatStrFormatter |
|
|
|
|
def js_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """
    Return the Jensen-Shannon divergence (natural log) between ``p`` and ``q``.

    Both inputs and their midpoint mixture are clipped to [1e-12, 1] before
    taking logarithms, so zero probabilities do not produce infinities.
    """
    floor, ceiling = 1e-12, 1
    mixture = np.clip(0.5 * (p + q), floor, ceiling)

    def _kl_to_mixture(dist):
        # KL(dist || mixture), evaluated on the clipped distribution.
        clipped = np.clip(dist, floor, ceiling)
        return np.sum(clipped * np.log(clipped / mixture))

    return 0.5 * (_kl_to_mixture(p) + _kl_to_mixture(q))
|
|
|
|
def plot_confidence_vs_js(
    P_pt: np.ndarray,
    P_ft: np.ndarray,
    save_path: str
) -> None:
    """
    Plot combined confidence vs. JS divergence for two sets of model predictions.

    Samples where both models predict the same class are drawn as violet
    circles, disagreeing samples as blue plus markers. Dashed threshold
    lines mark the smallest combined confidence and the smallest JS
    divergence among the disagreeing samples; they are omitted when the
    models never disagree.

    Args:
        P_pt (np.ndarray): Pre-trained model probabilities, shape (N, C).
        P_ft (np.ndarray): Fine-tuned model probabilities, shape (N, C).
        save_path (str): File path where the figure will be saved.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    P_pt = to_np(P_pt)
    P_ft = to_np(P_ft)

    # Combined confidence: mean of the two models' top probabilities.
    conf_pt = P_pt.max(axis=1)
    conf_ft = P_ft.max(axis=1)
    combined_confidence = 0.5 * (conf_pt + conf_ft)

    # Per-sample JS divergence between the two predictive distributions.
    js_values = np.array([js_divergence(p_row, q_row) for p_row, q_row in zip(P_pt, P_ft)])

    agree = np.argmax(P_pt, axis=1) == np.argmax(P_ft, axis=1)
    disagree = ~agree

    disagree_color = sns.color_palette("Blues", 2)[1]
    agree_color = "violet"

    fig, ax = plt.subplots(figsize=(5, 5))

    ax.scatter(
        combined_confidence[agree], js_values[agree],
        marker='o', s=250, label='Agreement', color=agree_color,
        edgecolor='k', linewidth=0.75, alpha=0.5
    )
    ax.scatter(
        combined_confidence[disagree], js_values[disagree],
        marker='P', s=250, label='Disagreement', color=disagree_color,
        edgecolor='k', linewidth=0.75, alpha=0.5
    )

    # Thresholds are only defined when at least one sample disagrees;
    # taking .min() of an empty selection would raise.
    if disagree.any():
        conf_thresh = combined_confidence[disagree].min()
        js_thresh = js_values[disagree].min()
        ax.axvline(x=conf_thresh, linestyle='--', color='gray')
        ax.axhline(y=js_thresh, linestyle='--', color='gray')

    # Pad the axis limits by 5% of the data range on each side.
    x_min, x_max = combined_confidence.min(), combined_confidence.max()
    y_min, y_max = js_values.min(), js_values.max()
    x_margin = (x_max - x_min) * 0.05
    y_margin = (y_max - y_min) * 0.05
    ax.set_xlim(x_min - x_margin, x_max + x_margin)
    ax.set_ylim(y_min - y_margin, y_max + y_margin)

    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.yaxis.set_major_locator(MaxNLocator(6))
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

    ax.set_facecolor('white')
    ax.xaxis.set_tick_params(which='both', bottom=True, top=False, labelbottom=True, labelsize=13)
    ax.yaxis.set_tick_params(which='both', left=True, right=False, labelleft=True, labelsize=13)
    for spine in ax.spines.values():
        spine.set_visible(True)

    ax.set_xlabel(r'$\mathbf{Combined\ Confidence\ }$'+"\n"+r'$\mathbf{=\ \frac{1}{2}(\max_i\ p_{pt}^{(i)}\ +\ \max_i\ p_{ft}^{(i)})}$', fontsize=13)
    ax.set_ylabel(r'$\mathbf{Divergence\ }$'+"\n"+r'$\mathbf{=\ \frac{1}{2}[KL(P_{pt}\|M)\ +\ KL(P_{ft}\|M)]}$', fontsize=13)

    ax.legend(fontsize=12, frameon=False, loc='best')

    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
|
|
|
|