| import argparse |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import os |
| from data.datautils import build_medmnist_dataset |
| from torchvision import transforms |
| from utils.tools import * |
| |
| from BetaMixture import BetaMixtureModel |
| from clip.custom_clip import get_coop |
| from data.cls_to_names import * |
| from tqdm import tqdm |
| from sklearn.metrics import roc_auc_score |
| from medmnistc_data import * |
| import copy |
| from datetime import datetime |
| import warnings |
| import gc |
| from baselines import * |
| warnings.filterwarnings("ignore") |
| import random |
|
|
# Fixed seed for Python's `random` module, applied at import time.
random.seed(0)

# MedIMeta test set -> [task name, MedMNIST dataset]. The task name (with "_"
# turned into spaces) selects the MedIMeta task; the MedMNIST name selects the
# fine-tuned checkpoint and links the test set to a `--testset` entry.
medimeta_testset_task_dict = dict(
    pbc=["cell_class", "bloodmnist"],
    mammo_mass=["pathology", "breastmnist"],
    pneumonia=["disease_class", "pneumoniamnist"],
    fundus=["disease_presence", "retinamnist"],
    oct=["disease_class", "octmnist"],
)

# Methods whose accuracies are tracked in `dyn_v_stat_plot`; the values appear
# to be human-readable labels (not read anywhere in this file's visible code).
method_names = dict(
    model_ensemble='Model Ensemble',
    wise_ft='Model Souping',
    tcube='Entropy-based',
    tcube_MI_bmm='Mutual Information',
)

# Diagnostics written by the lambda-computation functions; each call replaces
# the stored values rather than appending to them.
ent_mi_dict = {
    field: []
    for field in ('entropy', 'mi', 'agreement_diff', 'correct_pt', 'correct_ft', 'x_entropy')
}

# Per-method accuracy history (one entry appended per evaluated condition),
# plus an initially empty 'conditions' list.
dyn_v_stat_plot = {**{name: [] for name in method_names}, 'conditions': []}
|
|
def fetch_keys_for_value(dictionary, target_value):
    """Return the keys whose value's second element equals *target_value*.

    Used with ``medimeta_testset_task_dict`` to find the MedIMeta test sets
    that map to a given MedMNIST dataset. Keys come back in insertion order.
    """
    matches = []
    for name, entry in dictionary.items():
        if entry[1] == target_value:
            matches.append(name)
    return matches
| |
def load_models(args, classnames, set_id=None):
    """Build the pre-trained (PT) and fine-tuned (FT) CLIP models for *set_id*.

    Args:
        args: parsed CLI namespace; reads ``arch``, ``gpu``, ``n_ctx``,
            ``ctx_init`` and ``ft_path``.
        classnames: class names forwarded to ``get_coop``.
        set_id: dataset identifier. Keys of ``medimeta_testset_task_dict`` are
            mapped to their MedMNIST counterpart to locate the checkpoint
            ``fine_tuned_clip_<name>.pth`` under ``args.ft_path``.

    Returns:
        ``(clip_pt, sd_pt, clip_ft, sd_ft)``: both models and their state dicts.
    """
    clip_pt = get_coop(args.arch, None, args.gpu, args.n_ctx, args.ctx_init, classnames)
    sd_pt = clip_pt.state_dict()

    # MedIMeta test sets reuse the checkpoint fine-tuned on their MedMNIST twin.
    if set_id in medimeta_testset_task_dict:
        checkpoint_name = medimeta_testset_task_dict[set_id][1]
    else:
        checkpoint_name = set_id
    ft_path = os.path.join(args.ft_path, f'fine_tuned_clip_{checkpoint_name}.pth')

    checkpoint = torch.load(ft_path, map_location='cpu')
    # Checkpoints whose path contains "pub" keep their weights under 'state_dict'.
    if 'pub' in ft_path.lower():
        checkpoint = checkpoint['state_dict']
    clip_ft = get_coop(args.arch, None, args.gpu, args.n_ctx, args.ctx_init,
                       state_dict=checkpoint, classnames=classnames)
    # The raw CPU copy is no longer needed once the model has been built.
    del checkpoint
    return clip_pt, sd_pt, clip_ft, clip_ft.state_dict()
def get_logits(model, dataloader, args, return_feats=False, normalize=True):
    """Run *model* over *dataloader* without gradients and collect its outputs.

    Inputs and labels are moved to GPU ``args.gpu``. When ``return_feats`` is
    truthy the model is called as
    ``model(inputs, return_logits=return_feats, normalize=normalize)`` and must
    return ``(logits, image_features, text_features)``; otherwise it is called
    with the inputs only and ``normalize`` is ignored.

    Returns:
        ``(logits, labels)`` concatenated along dim 0, followed by
        ``(image_features, text_features)`` when ``return_feats`` is truthy.
    """
    model.eval()
    collected = {'logits': [], 'labels': [], 'img': [], 'txt': []}
    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(dataloader):
            batch_inputs = batch_inputs.cuda(args.gpu, non_blocking=True)
            batch_labels = batch_labels.cuda(args.gpu, non_blocking=True)
            if not return_feats:
                batch_logits = model(batch_inputs)
            else:
                batch_logits, img_feats, txt_feats = model(
                    batch_inputs, return_logits=return_feats, normalize=normalize)
                collected['img'].append(img_feats)
                collected['txt'].append(txt_feats)
            collected['logits'].append(batch_logits)
            collected['labels'].append(batch_labels)

    all_logits = torch.cat(collected['logits'])
    all_labels = torch.cat(collected['labels'])
    if not return_feats:
        return all_logits, all_labels
    return all_logits, all_labels, torch.cat(collected['img']), torch.cat(collected['txt'])
def self_entropy(logits, temperature=0.95):
    """Shannon entropy (natural log) of the temperature-scaled softmax, per row.

    Args:
        logits: tensor whose class scores lie along dim 1.
        temperature: logits are divided by this before the softmax.

    Returns:
        One entropy per row (dim 1 is reduced); ``1e-9`` guards ``log(0)``.
    """
    probs = F.softmax(logits / temperature, dim=1)
    log_probs = torch.log(probs + 1e-9)
    return -torch.sum(probs * log_probs, dim=1)
def interpolation(lambdas, sd_pt, sd_ft):
    """Blend two state dicts: ``lambdas[0] * sd_pt + lambdas[1] * sd_ft``.

    Iterates over the keys of *sd_ft* (each must also exist in *sd_pt*) and
    returns a new dict in that key order; the inputs are only read.
    """
    w_pt, w_ft = lambdas[0], lambdas[1]
    return {name: sd_pt[name] * w_pt + sd_ft[name] * w_ft for name in sd_ft}
|
|
def compute_samplewise_tcube_weights(clip_pt, clip_ft, dataloader, args):
    """Entropy-based (TCube) interpolation coefficients.

    A model's per-sample "expertise" is ``exp(-entropy)`` of its prediction
    (see ``self_entropy``); the FT weight is the FT share of the total
    expertise, optionally shifted by a data-dependent offset.

    Args:
        clip_pt, clip_ft: pre-trained / fine-tuned models.
        dataloader: evaluation loader. The PT and FT logits come from two
            separate passes, so it must yield samples in a fixed order.
        args: reads ``gpu``, ``offset`` and ``batch_wise``.

    Returns:
        Tensor of shape ``(2, K)`` whose rows are ``(lambda_pt, lambda_ft)``.
        With ``args.batch_wise``, ``K == len(dataloader)`` and each column is
        the mean ``a / (a + b)`` of one fitted Beta-mixture component;
        otherwise ``K`` is the number of samples.
    """
    logits_pt, _ = get_logits(clip_pt, dataloader, args, return_feats=False)
    logits_ft, _ = get_logits(clip_ft, dataloader, args, return_feats=False)
    ent_pt = self_entropy(logits_pt)
    ent_ft = self_entropy(logits_ft)
    expertise_pt = (-ent_pt).exp()
    expertise_ft = (-ent_ft).exp()
    total_expertise = expertise_pt + expertise_ft

    if args.offset:
        # Mean coefficient of variation (std / mean) of both entropy
        # distributions, scaled by mean_pt / (mean_pt + mean_ft) when added
        # to the numerator.
        coef_bias = (ent_pt.std() / ent_pt.mean() + ent_ft.std() / ent_ft.mean()) / 2
        coef_biasw = (ent_pt.mean() + ent_ft.mean()) / ent_pt.mean()
        lambda_ft = (expertise_ft + coef_bias / coef_biasw) / (total_expertise + coef_bias)
    else:
        lambda_ft = expertise_ft / total_expertise

    # Keep the per-sample FT weights for later analysis (mutating the
    # module-level dict needs no `global` declaration).
    ent_mi_dict['entropy'] = lambda_ft

    if args.batch_wise:
        # Summarize the per-sample weights with one Beta component per batch.
        # NOTE(review): nothing here ties component i to batch i of the loader,
        # yet evaluate_tcube applies column i to batch i -- confirm intended.
        bmm = BetaMixtureModel(n_mixtures=len(dataloader))
        bmm.fit(lambda_ft.cpu().numpy().reshape(-1, 1))
        component_means = []
        for i in range(bmm.n_mixtures):
            a, b = bmm.beta_params_[i, 0], bmm.beta_params_[i, 1]
            component_means.append(a / (a + b))
        lambda_ft_bmm = torch.tensor(component_means)
        return torch.stack([1 - lambda_ft_bmm, lambda_ft_bmm], dim=0)

    return torch.stack([1 - lambda_ft, lambda_ft])
def compute_samplewise_tcube_weights_MI(clip_pt, clip_ft, dataloader, args, delta=0.5, batch_wise=True):
    """Mutual-information-based interpolation coefficients.

    The per-sample MI is the Jensen-Shannon divergence between the PT and FT
    softmax predictions. It is mapped into ``[lam_min, lam_max]`` with
    ``sigmoid(gamma * MI)``, shifted by ``+delta`` when the FT entropy is below
    ``entropy_thresh_ft`` (else by ``-delta`` when the PT entropy is below
    ``entropy_thresh_pt``), and clamped to ``[0, 1.5]``.

    Args:
        clip_pt, clip_ft: pre-trained / fine-tuned models.
        dataloader: evaluation loader. PT logits, FT logits and the later
            interpolated evaluation are separate passes, so it must yield the
            same samples in the same batches every time (no shuffling).
        args: reads ``gpu`` and, when ``batch_wise``, ``lambda_mean_type``;
            optional ``lam_min`` (0.01), ``lam_max`` (0.99), ``gamma`` (0.5),
            ``entropy_thresh_ft`` (0.05), ``entropy_thresh_pt`` (0.65).
        delta: size of the entropy-gated shift.
        batch_wise: aggregate the per-sample weights into one pair per batch.

    Returns:
        With ``batch_wise`` and ``lambda_mean_type`` ``'mean'``/``'bmm'``: a
        ``(2, num_batches)`` tensor of ``(lambda_pt, lambda_ft)`` rows.
        Otherwise ``(weights, lambda_plot)``: ``weights`` is ``(2, N)`` and
        ``lambda_plot`` is the sigmoid-mapped MI before the entropy shift.
    """
    logits_pt, labels = get_logits(clip_pt, dataloader, args, return_feats=False)
    logits_ft, _ = get_logits(clip_ft, dataloader, args, return_feats=False)
    # reshape(-1) (unlike squeeze) keeps a 1-D target even for a single sample.
    labels = labels.reshape(-1)

    p_pt = torch.softmax(logits_pt, dim=1)
    p_ft = torch.softmax(logits_ft, dim=1)
    correct_pt = p_pt.argmax(dim=1).eq(labels)
    correct_ft = p_ft.argmax(dim=1).eq(labels)

    # Jensen-Shannon divergence in log-space: probabilities that underflow to 0
    # contribute 0 instead of 0 * log(0) = NaN.
    log_p_bar = torch.log((p_pt + p_ft) / 2.0 + 1e-8)
    kl_pt = torch.sum(p_pt * (F.log_softmax(logits_pt, dim=1) - log_p_bar), dim=1)
    kl_ft = torch.sum(p_ft * (F.log_softmax(logits_ft, dim=1) - log_p_bar), dim=1)
    MI = 0.5 * (kl_pt + kl_ft)

    lam_min = getattr(args, 'lam_min', 0.01)
    lam_max = getattr(args, 'lam_max', 0.99)
    gamma = getattr(args, 'gamma', 0.5)
    lambda_plot = lam_min + (lam_max - lam_min) * torch.sigmoid(gamma * MI)

    ent_pt = self_entropy(logits_pt)
    ent_ft = self_entropy(logits_ft)
    entropy_thresh_ft = getattr(args, 'entropy_thresh_ft', 0.05)
    entropy_thresh_pt = getattr(args, 'entropy_thresh_pt', 0.65)
    # Low FT entropy pushes towards FT; otherwise low PT entropy pushes towards PT.
    lambda_ft = torch.where(
        ent_ft < entropy_thresh_ft,
        lambda_plot + delta,
        torch.where(ent_pt < entropy_thresh_pt, lambda_plot - delta, lambda_plot),
    )
    # Upper bound 1.5 > 1 allows extrapolating past the FT weights.
    lambda_ft = torch.clamp(lambda_ft, 0.0, 1.5)

    # Diagnostics for plotting/analysis (mutation needs no `global`).
    ce_pt = F.cross_entropy(logits_pt, labels, reduction='none')
    ce_ft = F.cross_entropy(logits_ft, labels, reduction='none')
    ent_mi_dict['mi'] = MI
    ent_mi_dict['Ppt'] = p_pt
    ent_mi_dict['Pft'] = p_ft
    ent_mi_dict['correct_pt'] = correct_pt
    ent_mi_dict['correct_ft'] = correct_ft
    ent_mi_dict['x_entropy'] = ce_ft / (ce_pt + ce_ft + 1e-9)
    ent_mi_dict['CE_pt'] = ce_pt
    ent_mi_dict['CE_ft'] = ce_ft

    if batch_wise:
        num_batches = len(dataloader)
        if args.lambda_mean_type == 'mean':
            # Average over the loader's actual batches (last one may be partial;
            # slicing to num_batches drops it when drop_last=True).
            # NOTE(review): assumes the loader was built with an int batch_size.
            chunks = lambda_ft.split(dataloader.batch_size)[:num_batches]
            lambda_ft_batchwise = torch.stack([chunk.mean() for chunk in chunks])
            return torch.stack([1 - lambda_ft_batchwise, lambda_ft_batchwise], dim=0)
        if args.lambda_mean_type == 'bmm':
            # One Beta component per batch; column i is component i's mean.
            bmm = BetaMixtureModel(n_mixtures=num_batches)
            bmm.fit(lambda_ft.cpu().numpy().reshape(-1, 1))
            component_means = []
            for i in range(bmm.n_mixtures):
                a, b = bmm.beta_params_[i, 0], bmm.beta_params_[i, 1]
                component_means.append(a / (a + b))
            lambda_ft_bmm = torch.tensor(component_means)
            return torch.stack([1 - lambda_ft_bmm, lambda_ft_bmm], dim=0)
        # Any other lambda_mean_type falls through to the per-sample result.

    return torch.stack([1 - lambda_ft, lambda_ft]), lambda_plot
def compute_and_evaluate_model_ensemble(clip_pt, clip_ft, dataloaders, args):
    """Evaluate the output-space ensemble (mean logits) of the PT and FT models.

    Both models run over ``dataloaders[0]``. The labels are taken from the same
    pass that produced the PT logits instead of from an extra pass over the
    loader, which saves a pass and keeps logits and labels aligned.

    Returns:
        ``(acc, auc)`` from ``compute_metrics``.
    """
    logits_pt, labels = get_logits(clip_pt, dataloaders[0], args, return_feats=False, normalize=False)
    logits_ft, _ = get_logits(clip_ft, dataloaders[0], args, return_feats=False, normalize=False)
    # The PT and FT logits still come from two passes, so the loader must yield
    # samples in a fixed order for this element-wise average to be meaningful.
    logits_final = (logits_pt + logits_ft) / 2.0
    return compute_metrics(logits_final, labels)
def compute_samplewise_conf_weights(clip_pt, clip_ft, dataloader, device="cuda"):
    """Per-sample weights proportional to each model's top softmax probability.

    Both models are moved to *device* and switched to eval mode. For every
    sample the PT and FT maximum class probabilities are normalized so that
    they sum to one.

    Returns:
        Tensor of shape ``(2, N)``: row 0 holds the PT weights, row 1 the FT weights.
    """
    clip_pt.to(device).eval()
    clip_ft.to(device).eval()
    per_batch = []
    with torch.no_grad():
        for batch, _ in dataloader:
            batch = batch.to(device)
            top_pt, _ = F.softmax(clip_pt(batch), dim=1).max(dim=1)
            top_ft, _ = F.softmax(clip_ft(batch), dim=1).max(dim=1)
            confidences = torch.stack((top_pt, top_ft))
            per_batch.append(confidences / confidences.sum(dim=0, keepdim=True))
    return torch.cat(per_batch, dim=1)
|
|
def evaluate_zero_shot(clip, dataloaders, classnames, args):
    """Evaluate a single model as-is (no interpolation) on ``dataloaders[0]``.

    A deep copy is evaluated, so *clip* itself is not switched to eval mode.
    ``classnames`` is not used here.
    """
    return evaluate_model(copy.deepcopy(clip), dataloaders[0], args)
def evaluate_wise_ft(clip_pt, sd_pt, sd_ft, dataloaders, args, alpha=0.5):
    """Evaluate WiSE-FT: a fixed weight-space blend of the PT and FT models.

    Args:
        clip_pt: pre-trained model; a deep copy receives the blended weights.
        sd_pt, sd_ft: state dicts to blend. Only keys of *sd_ft* are used and
            each must exist in *sd_pt*. Both are only read.
        dataloaders: sequence whose first loader is evaluated.
        args: forwarded to ``evaluate_model``.
        alpha: weight of the FT model; the default 0.5 is the uniform average.

    Returns:
        ``(acc, auc)`` from ``evaluate_model``.
    """
    model = copy.deepcopy(clip_pt)
    # No defensive copies of the state dicts are needed: the comprehension
    # allocates new tensors and never writes to sd_pt / sd_ft, whereas deep
    # copies would duplicate both state dicts in memory.
    merged_sd = {key: alpha * sd_ft[key] + (1 - alpha) * sd_pt[key] for key in sd_ft}
    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloaders[0], args)
def evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, dataloaders, args, batch_wise=True):
    """Evaluate with per-batch weight interpolation (TCube and variants).

    Before the forward pass on batch ``i`` the model weights are set to
    ``lambdas[0, i] * sd_pt + lambdas[1, i] * sd_ft`` (non-strict load).

    Args:
        clip_pt: pre-trained model; a deep copy receives the merged weights.
        sd_pt, sd_ft: state dicts to interpolate; only read.
        lambdas: ``(2, K)`` coefficients, one column per batch of the loader used.
        dataloaders: ``dataloaders[0]`` is used when ``batch_wise`` is true,
            otherwise ``dataloaders[1]``.
        args: reads ``gpu``.
        batch_wise: selects the loader as described above.

    Returns:
        ``(acc, auc)`` from ``compute_metrics``.
    """
    model = copy.deepcopy(clip_pt)
    # load_state_dict does not change the train/eval mode, so set it once.
    model.eval()
    dataloader = dataloaders[0] if batch_wise else dataloaders[1]

    logits_final, labels_final = [], []
    with torch.no_grad():
        for i, (inputs, label) in enumerate(tqdm(dataloader)):
            inputs = inputs.cuda(args.gpu, non_blocking=True)
            label = label.cuda(args.gpu, non_blocking=True)
            # `interpolation` builds fresh tensors, so sd_pt / sd_ft need no
            # deep copies (which would double their memory footprint).
            merged_sd = interpolation(lambdas[:, i], sd_pt, sd_ft)
            model.load_state_dict(merged_sd, strict=False)
            logits_final.append(model(inputs))
            labels_final.append(label)

    logits_final = torch.cat(logits_final).cuda(args.gpu, non_blocking=True)
    labels_final = torch.cat(labels_final).cuda(args.gpu, non_blocking=True)
    return compute_metrics(logits_final, labels_final)
def evaluate_model(model, dataloader, args):
    """Evaluate *model* unchanged on *dataloader* and return ``(acc, auc)``.

    The model is put in eval mode; inputs and labels are moved to GPU
    ``args.gpu`` and forward passes run without gradients. Metrics come from
    ``compute_metrics``.
    """
    model.eval()
    outputs, targets = [], []
    for batch_inputs, batch_labels in tqdm(dataloader):
        batch_inputs = batch_inputs.cuda(args.gpu, non_blocking=True)
        targets.append(batch_labels.cuda(args.gpu, non_blocking=True))
        with torch.no_grad():
            outputs.append(model(batch_inputs))
    return compute_metrics(torch.cat(outputs), torch.cat(targets))
def compute_metrics(logits_final, labels_final):
    """Compute top-1 accuracy and ROC-AUC (in percent).

    Args:
        logits_final: ``(N, C)`` logits (softmax is taken over dim 1).
        labels_final: integer class labels; flattened to ``(N,)`` for the AUC.

    Returns:
        ``(acc, auc)``. ``acc`` is whatever ``accuracy`` returns (callers read
        ``acc[0]``). ``auc`` is a float in percent:

        * ``C <= 2``: AUC of the score for class 1;
        * ``C > 2``, all classes present: macro one-vs-rest AUC;
        * ``C > 2``, some classes absent: mean one-vs-rest AUC over the present
          classes;
        * fewer than two distinct labels: 50 (AUC is undefined there and
          ``roc_auc_score`` would raise).
    """
    acc = accuracy(logits_final, labels_final)
    probs = F.softmax(logits_final, dim=1).cpu().numpy()
    labels = labels_final.view(-1).cpu().numpy()
    present_classes = np.unique(labels)

    if len(present_classes) < 2:
        # Degenerate split: report chance level instead of crashing.
        auc = 0.5
    elif probs.shape[1] <= 2:
        auc = roc_auc_score(labels, probs[:, 1])
    elif len(present_classes) < probs.shape[1]:
        # roc_auc_score's multi-class mode does not accept labels that miss some
        # score columns, so average per-class one-vs-rest AUCs instead. Every
        # present class has positives and (since >= 2 classes) negatives.
        auc = np.mean([
            roc_auc_score((labels == cls).astype(int), probs[:, cls])
            for cls in present_classes
        ])
    else:
        auc = roc_auc_score(labels, probs, multi_class='ovr', average='macro')
    return acc, auc * 100
|
|
class CustomDataset(Dataset):
    """Map-style dataset over in-memory images and their labels.

    ``images[idx]`` is converted with ``Image.fromarray`` and then passed
    through *transform* when one is given; ``labels[idx]`` is returned as-is.
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        # The dataset size is defined by the image collection.
        return len(self.images)

    def __getitem__(self, idx):
        sample = Image.fromarray(self.images[idx])
        if self.transform:
            sample = self.transform(sample)
        return sample, self.labels[idx]
def get_transform(args):
    """Build the evaluation preprocessing pipeline.

    Bicubic resize and center crop to ``args.resolution``, conversion to RGB,
    ``ToTensor``, then ``Normalize`` with a single mean/std value of 0.5.
    """
    size = args.resolution
    # NOTE(review): the lambda cannot be pickled, which matters only if
    # DataLoader workers are started with the 'spawn' method.
    steps = [
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(size),
        transforms.Lambda(lambda image: image.convert('RGB')),
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5]),
    ]
    return transforms.Compose(steps)
def get_medmnistc_dataloader(args, set_id, batch_size=32, num_workers=4, split='test', dataset=None, severity=None, shuffle=False):
    """Build a DataLoader over one MedMNIST-C ``.npz`` file.

    The file lives in ``<args.medmnistc_data>/<set_id>/<split>/`` and is
    ``clean.npz`` when *dataset* is ``"clean"`` or ``None``, otherwise
    ``<dataset>_severity_<severity>.npz``. It must contain ``images`` and
    ``labels`` arrays.

    Shuffling is off by default: evaluation iterates the same loader several
    times (PT logits, FT logits, interpolated evaluation) and combines the
    results position by position, which requires an identical sample order
    and batch composition on every pass. A shuffling loader draws a new order
    on each iteration.

    Raises:
        FileNotFoundError: if the expected ``.npz`` file does not exist.
    """
    data_root = os.path.join(args.medmnistc_data, set_id, split)
    if dataset in ("clean", None):
        file_name = "clean.npz"
    else:
        file_name = f'{dataset}_severity_{severity}.npz'
    path = os.path.join(data_root, file_name)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Dataset file not found: {path}")
    # Read both arrays fully, then close the archive instead of leaking it.
    with np.load(path) as data:
        images = data["images"]
        labels = data["labels"].squeeze()
    ds = CustomDataset(images, labels, transform=get_transform(args))
    return torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
def get_medimeta_dataloader(args, testset, batch_size=32, num_workers=4, split='test', shuffle=False):
    """Build a DataLoader for a MedIMeta test set.

    The task name is ``medimeta_testset_task_dict[testset][0]`` with
    underscores replaced by spaces. ``split`` is accepted but not used here.

    Shuffling is off by default: evaluation iterates the same loader several
    times and combines per-sample/per-batch results across passes, which
    requires the same sample order on every pass.
    """
    task_name = medimeta_testset_task_dict[testset][0].replace("_", " ")
    # NOTE(review): build_medimeta_dataset is not imported explicitly in this
    # file; presumably it arrives through one of the wildcard imports.
    dataset = build_medimeta_dataset(args.medimeta_data, testset, task_name, get_transform(args))
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
|
|
def _build_eval_dataloaders(args, set_id, dataset_name, severity, test_set):
    """Return ``[batched_loader, per_sample_loader]`` for one evaluation condition."""
    if test_set is not None:
        return [get_medimeta_dataloader(args, test_set, batch_size=args.bs),
                get_medimeta_dataloader(args, test_set, batch_size=1)]
    return [get_medmnistc_dataloader(args, set_id, batch_size=args.bs, dataset=dataset_name, severity=severity),
            get_medmnistc_dataloader(args, set_id, batch_size=1, dataset=dataset_name, severity=severity)]


def _evaluate_method(method_type, lambdas, clip_pt, sd_pt, clip_ft, sd_ft, dataloaders, classnames, args):
    """Run one evaluation method by name and return its ``(acc, auc)``.

    Raises:
        ValueError: for an unknown *method_type*, rather than continuing with
            undefined (or the previous method's) metrics.
    """
    dispatch = {
        'zero_shot_pt': lambda: evaluate_zero_shot(clip_pt, dataloaders, classnames, args),
        'zero_shot_ft': lambda: evaluate_zero_shot(clip_ft, dataloaders, classnames, args),
        'model_ensemble': lambda: compute_and_evaluate_model_ensemble(clip_pt, clip_ft, dataloaders, args),
        'wise_ft': lambda: evaluate_wise_ft(clip_pt, sd_pt, sd_ft, dataloaders, args),
        'slerp': lambda: evaluate_slerp(clip_pt, sd_pt, sd_ft, dataloaders[0], args),
        't_arithmetic': lambda: evaluate_task_arithmetic(clip_pt, sd_pt, sd_ft, dataloaders[0], args),
        'm3': lambda: evaluate_m3(clip_pt, sd_pt, sd_ft, dataloaders[0], args),
        'tcube': lambda: evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, dataloaders, args, batch_wise=args.batch_wise),
        'conf': lambda: evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, dataloaders, args, batch_wise=False),
        'tcube_MI': lambda: evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, dataloaders, args, batch_wise=False),
        'tcube_MI_bmm': lambda: evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, dataloaders, args, batch_wise=True),
    }
    try:
        run = dispatch[method_type]
    except KeyError:
        raise ValueError(f"Unknown interpolation method: {method_type!r}") from None
    return run()


def evaluate_on_test_set(args, set_id, _dataset, severities, clip_pt, sd_pt, clip_ft, sd_ft, classnames, results, test_set=None):
    """Evaluate the configured methods on one dataset/condition for each severity.

    For each severity this builds a batched loader (``args.bs``) and a
    per-sample loader, computes the MI-based per-batch weights (which also
    fills ``ent_mi_dict``), calls ``plot_confidence_vs_js`` on the PT/FT softmax
    outputs, and evaluates every method in ``lambdas_dict``. Metrics are stored
    in ``results[set_id][dataset][severity][method]`` and, for methods listed
    in ``method_names``, the accuracy is appended to ``dyn_v_stat_plot``.
    For MedIMeta, *test_set* replaces *_dataset* as the results key.

    Returns:
        The updated *results* dict (it is also mutated in place).

    Raises:
        ValueError: if ``lambdas_dict`` names an unknown method.
    """
    _dataset = test_set if test_set is not None else _dataset
    dataset_results = results.setdefault(set_id, {}).setdefault(_dataset, {})

    for severity in severities:
        print(f"\nEvaluating on _dataset: {_dataset}, severity: {severity}.....")
        _dataloaders = _build_eval_dataloaders(args, set_id, _dataset, severity, test_set)

        lambdas_tcube_MI_bmm = compute_samplewise_tcube_weights_MI(clip_pt, clip_ft, _dataloaders[0], args, batch_wise=True)
        plot_confidence_vs_js(ent_mi_dict['Ppt'], ent_mi_dict['Pft'], save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/conf_v_jsd/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')

        # method name -> precomputed lambdas (None for methods that need none).
        # NOTE(review): currently empty, so no method is evaluated and
        # lambdas_tcube_MI_bmm is unused -- presumably entries were disabled on purpose.
        lambdas_dict = {}

        severity_results = dataset_results.setdefault(severity, {})
        for method_type, lambdas in lambdas_dict.items():
            print("Interpolating and evaluating on - interpolation method: ", method_type)
            acc, auc = _evaluate_method(method_type, lambdas, clip_pt, sd_pt, clip_ft, sd_ft,
                                        _dataloaders, classnames, args)
            acc_value = acc[0].item()
            mean = (acc_value + auc) / 2
            print(f'Accuracy: {acc_value:.2f}%, AUC: {auc:.2f}%, Mean: {mean:.2f}%')
            severity_results[method_type] = {'accuracy': acc_value, 'auc': auc, 'mean': mean}
            if method_type in method_names:
                # Mutating the module-level dict needs no `global` declaration.
                dyn_v_stat_plot[method_type].append(acc_value)

        del _dataloaders, lambdas_dict
        gc.collect()

    return results
def _lookup_classnames(name):
    """Return the module-level ``<name>_classes`` list (e.g. from ``data.cls_to_names``).

    A plain globals lookup replaces ``eval`` so that dataset names coming from
    the command line cannot execute arbitrary code.
    """
    key = f"{name.lower()}_classes"
    try:
        return globals()[key]
    except KeyError:
        raise ValueError(f"No class-name list '{key}' is defined for dataset '{name}'") from None


def evaluate_on_datasets(args, datasets, default_datasets, default_severity_range):
    """Run ``evaluate_on_test_set`` for every dataset and evaluation condition.

    ``"clean"`` and ``"medimeta"`` are evaluated at severity 0 only; other
    conditions use the inclusive range ``default_severity_range``. For
    ``"medimeta"`` every MedIMeta test set mapped to the dataset (via
    ``medimeta_testset_task_dict``) is evaluated with its own class names.
    Models are loaded for each evaluation and released right after it.

    Returns:
        The nested results dict built by ``evaluate_on_test_set``.

    Raises:
        ValueError: if no ``<name>_classes`` list exists for a dataset.
    """
    results = {}
    for set_id in datasets:
        print(f"\nEvaluating on dataset: {set_id}\n")
        for _dataset in default_datasets:
            if _dataset in ["clean", "medimeta"]:
                severities = [0]
            else:
                severities = range(default_severity_range[0], default_severity_range[1] + 1)

            # None stands for "evaluate set_id itself" (MedMNIST-C conditions).
            if _dataset == "medimeta":
                test_sets = fetch_keys_for_value(medimeta_testset_task_dict, set_id)
            else:
                test_sets = [None]

            for test_set in test_sets:
                classnames = _lookup_classnames(set_id if test_set is None else test_set)
                clip_pt, sd_pt, clip_ft, sd_ft = load_models(args, classnames, set_id)
                results = evaluate_on_test_set(args, set_id, _dataset, severities, clip_pt, sd_pt,
                                               clip_ft, sd_ft, classnames, results, test_set)
                # Free both models (and their state dicts) before the next load.
                del clip_pt, sd_pt, clip_ft, sd_ft
                gc.collect()
    return results
|
|
def print_results(results):
    """Print *results* (``set_id -> dataset -> severity -> method -> metrics``) as text tables.

    Each ``metrics`` dict must provide numeric ``'accuracy'``, ``'auc'`` and
    ``'mean'`` entries. The heading carries the current local date/time.
    """
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"\nResults (Evaluated on: {stamp}):")
    heavy_rule, light_rule = "=" * 75, "-" * 80
    for set_id in results:
        print(f"\nDataset: {set_id}")
        print(heavy_rule)
        print(f"{'_dataset':<20}{'Severity':<10}{'Method':<20}{'Accuracy':<10}{'AUC':<10}{'Mean':<10}")
        for _dataset, per_severity in results[set_id].items():
            print(heavy_rule)
            for severity, per_method in per_severity.items():
                print(light_rule)
                for method_type, m in per_method.items():
                    print(f"{_dataset:<20}{severity:<10}{method_type:<20}"
                          f"{m['accuracy']:<10.2f}{m['auc']:<10.2f}{m['mean']:<10.2f}")
        print(heavy_rule)
def log_results(results, args):
    """Write the run arguments and a results table to ``args.log_path``.

    The file is overwritten (UTF-8). Parent directories are created only when
    the path has a directory component: ``os.makedirs('')`` raises for a bare
    file name.

    Args:
        results: ``set_id -> dataset -> severity -> method -> metrics`` where
            ``metrics`` has numeric ``'accuracy'``, ``'auc'`` and ``'mean'``.
        args: namespace providing ``log_path``; all of its attributes are logged.
    """
    formatted_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_dir = os.path.dirname(args.log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    rule = "-" * 80 + "\n"
    header = (f"{'_dataset':<20}{'Severity':<10}{'Method':<20}"
              f"{'Accuracy':<15}{'AUC':<15}{'Mean':<15}\n")
    with open(args.log_path, 'w', encoding='utf-8') as log_file:
        log_file.write(f"\nResults (Evaluated on: {formatted_date}):\n")
        log_file.write("Arguments:\n")
        for arg, value in vars(args).items():
            log_file.write(f"{arg}: {value}\n")
        log_file.write("\n")

        for set_id, result in results.items():
            log_file.write(f"\nDataset Group: {set_id}\n")
            log_file.write(rule)
            log_file.write(header)
            for _dataset, severity_dict in result.items():
                for severity, metrics_dict in severity_dict.items():
                    for method_type, metrics in metrics_dict.items():
                        log_file.write(f"{_dataset:<20}{str(severity):<10}{method_type:<20}"
                                       f"{metrics['accuracy']:<15.2f}{metrics['auc']:<15.2f}{metrics['mean']:<15.2f}\n")
                log_file.write(rule)
            log_file.write(rule)
def save_json_results(results, args):
    """Write *results* to ``args.json_path`` as JSON, regrouped by method.

    Layout: ``method -> dataset -> str(severity) -> {accuracy, auc, mean}``.
    Parent directories are created only when the path has a directory
    component (``os.makedirs('')`` raises for a bare file name).

    NOTE(review): the ``set_id`` level is dropped, so when several set_ids
    share a dataset name (e.g. ``"clean"``) later entries overwrite earlier
    ones. The layout is kept unchanged because other tools may read it.
    """
    # Local import: `json` is not imported explicitly at module level.
    import json

    json_results = {}
    for result in results.values():
        for dataset, severity_dict in result.items():
            for severity, metrics_dict in severity_dict.items():
                for method, metrics in metrics_dict.items():
                    json_results.setdefault(method, {}).setdefault(dataset, {})[str(severity)] = {
                        "accuracy": metrics["accuracy"],
                        "auc": metrics["auc"],
                        "mean": metrics["mean"],
                    }

    json_dir = os.path.dirname(args.json_path)
    if json_dir:
        os.makedirs(json_dir, exist_ok=True)
    with open(args.json_path, 'w', encoding='utf-8') as f:
        json.dump(json_results, f, indent=4)
|
|
def main():
    """Entry point: parse CLI arguments, run all evaluations and save the results.

    Each dataset in ``--testset`` ("/"-separated) is evaluated on every
    condition in ``default_datasets``; results are written as a text log
    (``--log_path``) and as JSON (``--json_path``), both defaulting to a
    timestamped folder under the hard-coded results root.
    """
    default_ft_path = [
        '/home/raza.imam/Documents/Umaima/TPT/finetuned_models/ViT-B_16'
    ]
    default_medmnistc_root = '/home/raza.imam/Documents/Umaima/TPT/MedMNIST-C'
    default_medimeta_root = '/home/raza.imam/Documents/Umaima/datasets/medimeta'
    default_testset = 'breastmnist/retinamnist/bloodmnist/octmnist'
    default_datasets = [
        "clean",
        "medimeta",
        "impulse_noise",
        "pixelate",
    ]
    default_seed = 42
    default_arch = 'ViT-B/16'
    default_ctx_init = 'a_photo_of_a'
    default_gpu = 1
    default_severity_range = [5, 5]
    default_batch_wise = True
    default_offset = False
    default_lambda_mean_type = 'mean'
    default_bs = 32
    save_time = datetime.now().strftime("%Y%m%d_%H%M")
    save_path = f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/{save_time}_{default_arch.replace("/", "_")}/'
    default_log_path = f'{save_path}log.txt'
    default_json_path = f'{save_path}dict.json'

    parser = argparse.ArgumentParser(description='Multi-Model Interpolation')
    parser.add_argument('medmnistc_data', metavar='DIR', nargs="?", default=default_medmnistc_root, help='path to medmnistc dataset root')
    parser.add_argument('medimeta_data', metavar='DIR', nargs="?", default=default_medimeta_root, help='path to medimeta dataset root')
    parser.add_argument('--ft_path', type=str, default=default_ft_path[0], help='Paths to FT model state dicts')
    parser.add_argument('--log_path', type=str, default=default_log_path, help='Path to save results')
    parser.add_argument('--json_path', type=str, default=default_json_path, help='Path to save results in json format')
    parser.add_argument('--testset', type=str, default=default_testset, help='Dataset name')
    parser.add_argument('--offset', action='store_true', default=default_offset, help='Use offset for TCube')
    parser.add_argument('--lambda_mean_type', type=str, default=default_lambda_mean_type, help='Type of lambda mean for TCube')
    parser.add_argument('--batch_wise', action='store_true', default=default_batch_wise)
    # With a True default, `--batch_wise` alone can never switch the option
    # off; this negative flag writes False to the same destination.
    parser.add_argument('--no_batch_wise', dest='batch_wise', action='store_false',
                        help='Use per-sample instead of per-batch interpolation weights')
    parser.add_argument('--seed', type=int, default=default_seed, help='Random seed')
    parser.add_argument('-a', '--arch', metavar='ARCH', default=default_arch, help='model architecture')
    parser.add_argument('--gpu', type=int, default=default_gpu, help='GPU ID')
    parser.add_argument('--n_ctx', default=4, type=int, help='number of tunable tokens')
    parser.add_argument('--ctx_init', default=default_ctx_init, type=str, help='init tunable prompts')
    parser.add_argument('--resolution', default=224, type=int, help='CLIP image resolution')
    parser.add_argument('--bs', default=default_bs, type=int, help='Batch size')
    args = parser.parse_args()
    print(args)

    torch.manual_seed(args.seed)
    # Seed NumPy as well; NumPy-based steps (presumably including the Beta
    # mixture fit) would otherwise differ from run to run.
    np.random.seed(args.seed)

    datasets = args.testset.split("/")
    results = evaluate_on_datasets(args, datasets, default_datasets, default_severity_range)

    log_results(results, args)
    save_json_results(results, args)


if __name__ == '__main__':
    main()
|
|