import json
import os
from utils import find_optimal_coef
from src.args import parse_arguments
from src.eval import evaluate_task_vector, evaluate_task_vector_at_coef
from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
from src.attention_only_finetune import AttentionOnlyFinetuneEncoder
args = parse_arguments()

# Resolve the checkpoint root once: seeded runs live under checkpoints_<seed>/,
# unseeded runs under checkpoints/. Previously the ortho branch ignored a None
# seed and silently produced a "checkpoints_None/..." path.
if args.seed is not None:
    checkpoint_root = f"checkpoints_{args.seed}"
else:
    checkpoint_root = "checkpoints"

if 'ortho' in args.finetuning_mode:
    # Orthogonality-regularized runs encode the regularization strength
    # (lambda) in the directory name so different lambdas do not collide.
    args.save = f"{checkpoint_root}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}"
else:
    args.save = f"{checkpoint_root}/{args.finetuning_mode}_{args.lr}_{args.model}"
# Human-readable description of each supported finetuning mode.
mode_labels = {
    "standard": "Evaluating non-linear FT models.",
    "standard_ortho": "Evaluating standard FT models with orthogonality regularization.",
    "linear": "Evaluating linear FT models.",
    "linear_ortho": "Evaluating linear FT models with orthogonality regularization.",
    "linear-2": "Evaluating Attention-Only Finetune models.",
    "linear-2_ortho": "Evaluating Attention-Only Finetune models with orthogonality regularization.",
}

# Per-mode filename of the single-task finetuning accuracy record.
ft_accuracies_name_map = {
    "standard": "ft_accuracies.json",
    "standard_ortho": "standard_ortho_ft_accuracies.json",
    "linear": "linear_ft_accuracies.json",
    "linear_ortho": "linear_ortho_ft_accuracies.json",
    "linear-2": "linear-2_ft_accuracies.json",
    "linear-2_ortho": "linear-2_ortho_ft_accuracies.json",
}

banner = "*" * 100
print(banner)
# Unknown modes fall back to a generic message instead of raising.
print(mode_labels.get(args.finetuning_mode, f"Evaluating {args.finetuning_mode} models."))
print(banner)

# Load the per-dataset finetuned accuracies (used later for normalization).
ft_accuracies_path = os.path.join(args.save, ft_accuracies_name_map[args.finetuning_mode])
with open(ft_accuracies_path) as fh:
    args.finetuning_accuracies = json.load(fh)
# Zero-shot accuracies of the base pretrained model, stored alongside it.
base_model_save_path = (
    f"checkpoints_{args.seed}/{args.model}"
    if args.seed is not None
    else f"checkpoints/{args.model}"
)
with open(os.path.join(base_model_save_path, "zeroshot_accuracies.json")) as fh:
    pretrained_accuracies = json.load(fh)

# The eight target tasks of the task-arithmetic benchmark.
eval_datasets = [
    "Cars",
    "DTD",
    "EuroSAT",
    "GTSRB",
    "MNIST",
    "RESISC45",
    "SVHN",
    "SUN397",
]

task_vectors = []
mode = args.finetuning_mode
# Map each finetuning mode to its checkpoint filename prefix and the
# task-vector class to build from it. "standard" checkpoints carry no
# prefix (zeroshot.pt / finetuned.pt). This replaces a duplicated
# if/elif chain whose branches differed only in these two values.
_mode_checkpoint_spec = {
    "standard": ("", NonLinearTaskVector),
    "standard_ortho": ("standard_ortho_", NonLinearTaskVector),
    "linear": ("linear_", LinearizedTaskVector),
    "linear_ortho": ("linear_ortho_", LinearizedTaskVector),
    "linear-2": ("linear-2_", NonLinearTaskVector),
    "linear-2_ortho": ("linear-2_ortho_", NonLinearTaskVector),
}
# Unknown modes fall back to the unprefixed "standard" checkpoints,
# matching the original else-branch behavior.
prefix, vector_cls = _mode_checkpoint_spec.get(mode, ("", NonLinearTaskVector))

for dataset in eval_datasets:
    pretrained_checkpoint = f"{args.save}/{dataset}Val/{prefix}zeroshot.pt"
    finetuned_checkpoint = f"{args.save}/{dataset}Val/{prefix}finetuned.pt"
    # Skip datasets whose checkpoints are missing rather than crashing with a
    # raw traceback; previously only the linear-2 modes performed this check.
    if not (os.path.exists(pretrained_checkpoint) and os.path.exists(finetuned_checkpoint)):
        print(f"Warning: Missing checkpoints for {dataset}. Skipping.")
        continue
    task_vectors.append(vector_cls(pretrained_checkpoint, finetuned_checkpoint))

if not task_vectors:
    print("No task vectors were created. Exiting.")
    exit()
# Element-wise sum of the per-task vectors yields the multi-task edit.
task_vector = sum(task_vectors)

# Determine the base pretrained checkpoint for the current mode; all
# datasets share the same zero-shot weights, so any dataset dir works.
checkpoint_prefixes = {
    "standard": "",
    "standard_ortho": "standard_ortho_",
    "linear": "linear_",
    "linear_ortho": "linear_ortho_",
    "linear-2": "linear-2_",
    "linear-2_ortho": "linear-2_ortho_",
}
pretrained_checkpoint = (
    f"{args.save}/{eval_datasets[0]}Val/{checkpoint_prefixes[mode]}zeroshot.pt"
)
if not os.path.exists(pretrained_checkpoint):
    print(f"Error: Base pretrained checkpoint not found at {pretrained_checkpoint}")
    exit()
# Sweep scaling coefficients on the validation splits to pick the best one.
args.eval_datasets = [f"{name}Val" for name in eval_datasets]
args.control_dataset = None
val_metrics = evaluate_task_vector(
    task_vector,
    pretrained_checkpoint,
    args,
    posthoc_linearization=False,
)
optimal_coef = find_optimal_coef(
    val_metrics,
    metric="avg_normalized_top1",
    minimize=False,
)

# Re-evaluate at the chosen coefficient on the held-out test splits.
args.eval_datasets = list(eval_datasets)
test_metrics = evaluate_task_vector_at_coef(
    task_vector,
    pretrained_checkpoint,
    args,
    float(optimal_coef),
    posthoc_linearization=False,
)

print("=" * 100)
print(f"Optimal Coefficient: {optimal_coef}")
print(f"Test normalized accuracy: {test_metrics['avg_normalized_top1']}")
print(f"Test absolute accuracy: {test_metrics['avg_top1']}")
# Persist the full validation sweep, the test metrics, and the chosen
# coefficient so the run can be inspected without re-evaluating.
additive_accuracies = {
    "test": test_metrics,
    "val": val_metrics,
    "optimal_coef": optimal_coef,
}

# Per-mode output filename for the addition results.
results_filenames = {
    "standard": "additions.json",
    "standard_ortho": "standard_ortho_additions.json",
    "linear": "linear_additions.json",
    "linear_ortho": "linear_ortho_additions.json",
    "linear-2": "linear-2_additions.json",
    "linear-2_ortho": "linear-2_ortho_additions.json",
}
save_file = os.path.join(args.save, results_filenames[mode])
with open(save_file, "w") as out:
    json.dump(additive_accuracies, out, indent=4)
print(f"Addition results saved to {save_file}")