File size: 5,885 Bytes
3589275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import json
import os

from utils import find_optimal_coef

from src.args import parse_arguments
from src.eval import evaluate_task_vector, evaluate_task_vector_at_coef
from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
from src.attention_only_finetune import AttentionOnlyFinetuneEncoder

args = parse_arguments()

# Build the checkpoint directory for the current finetuning mode.
# Ortho modes embed the orthogonality-regularization strength (lambda) in the
# directory name so runs with different lambdas do not collide on disk.
if "ortho" in args.finetuning_mode:
    if args.seed is not None:
        args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}"
    else:
        # BUG FIX: the ortho branch previously interpolated the seed
        # unconditionally, yielding a bogus "checkpoints_None/..." path when
        # no seed was set. Mirror the seed handling of the non-ortho branch.
        args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}"
else:
    if args.seed is not None:
        args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}"
    else:
        args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}"

print("*" * 100)
mode_labels = {
    "standard": "Evaluating non-linear FT models.",
    "standard_ortho": "Evaluating standard FT models with orthogonality regularization.",
    "linear": "Evaluating linear FT models.",
    "linear_ortho": "Evaluating linear FT models with orthogonality regularization.",
    "linear-2": "Evaluating Attention-Only Finetune models.",
    "linear-2_ortho": "Evaluating Attention-Only Finetune models with orthogonality regularization.",
}
ft_accuracies_name_map = {
    "standard": "ft_accuracies.json",
    "standard_ortho": "standard_ortho_ft_accuracies.json",
    "linear": "linear_ft_accuracies.json",
    "linear_ortho": "linear_ortho_ft_accuracies.json",
    "linear-2": "linear-2_ft_accuracies.json",
    "linear-2_ortho": "linear-2_ortho_ft_accuracies.json",
}
print(mode_labels.get(args.finetuning_mode, f"Evaluating {args.finetuning_mode} models."))
print("*" * 100)

ft_accuracies_path = os.path.join(args.save, ft_accuracies_name_map[args.finetuning_mode])
with open(ft_accuracies_path) as f:
    args.finetuning_accuracies = json.load(f)

# Zero-shot (pretrained) accuracies are stored alongside the base model
# checkpoint, which lives in a seed-specific directory when a seed was given.
base_model_save_path = (
    f"checkpoints_{args.seed}/{args.model}"
    if args.seed is not None
    else f"checkpoints/{args.model}"
)

zeroshot_json = os.path.join(base_model_save_path, "zeroshot_accuracies.json")
with open(zeroshot_json) as f:
    pretrained_accuracies = json.load(f)

# The eight benchmark datasets evaluated for task-vector addition.
eval_datasets = [
    "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SVHN", "SUN397",
]

mode = args.finetuning_mode
task_vectors = []

# Build one task vector (finetuned minus zeroshot weights) per dataset.
# Linear modes use LinearizedTaskVector; all others use NonLinearTaskVector.
for dataset in eval_datasets:
    ckpt_dir = f"{args.save}/{dataset}Val"

    if mode == "linear":
        task_vectors.append(
            LinearizedTaskVector(
                f"{ckpt_dir}/linear_zeroshot.pt",
                f"{ckpt_dir}/linear_finetuned.pt",
            )
        )
    elif mode == "linear_ortho":
        task_vectors.append(
            LinearizedTaskVector(
                f"{ckpt_dir}/linear_ortho_zeroshot.pt",
                f"{ckpt_dir}/linear_ortho_finetuned.pt",
            )
        )
    elif mode == "standard_ortho":
        task_vectors.append(
            NonLinearTaskVector(
                f"{ckpt_dir}/standard_ortho_zeroshot.pt",
                f"{ckpt_dir}/standard_ortho_finetuned.pt",
            )
        )
    elif mode in ("linear-2", "linear-2_ortho"):
        # Attention-only checkpoints may be absent for some datasets; skip
        # those rather than crashing.
        zs_ckpt = f"{ckpt_dir}/{mode}_zeroshot.pt"
        ft_ckpt = f"{ckpt_dir}/{mode}_finetuned.pt"
        if not (os.path.exists(zs_ckpt) and os.path.exists(ft_ckpt)):
            print(f"Warning: Missing checkpoints for {dataset}. Skipping.")
            continue
        task_vectors.append(NonLinearTaskVector(zs_ckpt, ft_ckpt))
    else:  # "standard" (and any unrecognized mode falls through here)
        task_vectors.append(
            NonLinearTaskVector(
                f"{ckpt_dir}/zeroshot.pt",
                f"{ckpt_dir}/finetuned.pt",
            )
        )

if not task_vectors:
    print("No task vectors were created. Exiting.")
    exit()

# Task arithmetic: the "addition" vector is the elementwise sum of all
# per-dataset task vectors.
task_vector = sum(task_vectors)

# Determine the base pretrained checkpoint: every mode saved its zeroshot
# weights with a mode-specific filename prefix (empty for "standard").
mode_prefix_map = {
    "standard": "",
    "standard_ortho": "standard_ortho_",
    "linear": "linear_",
    "linear_ortho": "linear_ortho_",
    "linear-2": "linear-2_",
    "linear-2_ortho": "linear-2_ortho_",
}
pretrained_checkpoint = (
    f"{args.save}/{eval_datasets[0]}Val/{mode_prefix_map[mode]}zeroshot.pt"
)

if not os.path.exists(pretrained_checkpoint):
    print(f"Error: Base pretrained checkpoint not found at {pretrained_checkpoint}")
    exit()

# Coefficient search runs on the validation splits; no control dataset is used.
args.eval_datasets = [f"{dataset}Val" for dataset in eval_datasets]
args.control_dataset = None

# Sweep scaling coefficients on the validation splits...
val_metrics = evaluate_task_vector(
    task_vector, pretrained_checkpoint, args, posthoc_linearization=False
)

# ...pick the one maximizing average normalized top-1 accuracy...
optimal_coef = find_optimal_coef(
    val_metrics, metric="avg_normalized_top1", minimize=False
)

# ...then evaluate that single coefficient on the test splits.
args.eval_datasets = list(eval_datasets)
test_metrics = evaluate_task_vector_at_coef(
    task_vector,
    pretrained_checkpoint,
    args,
    float(optimal_coef),
    posthoc_linearization=False,
)

print("=" * 100)
print(f"Optimal Coefficient: {optimal_coef}")
print(f"Test normalized accuracy: {test_metrics['avg_normalized_top1']}")
print(f"Test absolute accuracy: {test_metrics['avg_top1']}")
additive_accuracies = {"test": test_metrics, "val": val_metrics, "optimal_coef": optimal_coef}

save_name_map = {
    "standard": "additions.json",
    "standard_ortho": "standard_ortho_additions.json",
    "linear": "linear_additions.json",
    "linear_ortho": "linear_ortho_additions.json",
    "linear-2": "linear-2_additions.json",
    "linear-2_ortho": "linear-2_ortho_additions.json",
}
save_file = os.path.join(args.save, save_name_map[mode])
with open(save_file, "w") as f:
    json.dump(additive_accuracies, f, indent=4)
print(f"Addition results saved to {save_file}")