import json
import os
from utils import find_optimal_coef
from src.args import parse_arguments
from src.eval import evaluate_task_vector, evaluate_task_vector_at_coef
from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
from src.attention_only_finetune import AttentionOnlyFinetuneEncoder
# Parse CLI arguments and resolve the checkpoint directory for this run.
args = parse_arguments()

if "ortho" in args.finetuning_mode:
    # Orthogonality-regularized runs encode the regularization strength
    # (ortho_lambda) in the directory name.
    # NOTE(review): the original always built a seeded path here, yielding
    # "checkpoints_None/..." when no seed was given; fall back to the bare
    # "checkpoints" root in that case, consistent with the branch below.
    seed_root = f"checkpoints_{args.seed}" if args.seed is not None else "checkpoints"
    args.save = f"{seed_root}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}"
else:
    if args.seed is not None:
        args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}"
    else:
        args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}"

# Zero-shot (pretrained) accuracies are stored alongside the base model
# checkpoint, not alongside the finetuned runs.
if args.seed is not None:
    base_model_save_path = f"checkpoints_{args.seed}/{args.model}"
else:
    base_model_save_path = f"checkpoints/{args.model}"

with open(os.path.join(base_model_save_path, "zeroshot_accuracies.json")) as f:
    pretrained_accuracies = json.load(f)

# Target datasets whose task vectors will be negated and evaluated.
eval_datasets = [
    "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SUN397", "SVHN",
]
print("*" * 100)
mode_labels = {
"standard": "Evaluating non-linear FT models.",
"standard_ortho": "Evaluating standard FT models with orthogonality regularization.",
"linear": "Evaluating linear FT models.",
"linear_ortho": "Evaluating linear FT models with orthogonality regularization.",
"linear-2": "Evaluating Attention-Only Finetune models.",
"linear-2_ortho": "Evaluating Attention-Only Finetune models with orthogonality regularization.",
}
ft_accuracies_name_map = {
"standard": "ft_accuracies.json",
"standard_ortho": "standard_ortho_ft_accuracies.json",
"linear": "linear_ft_accuracies.json",
"linear_ortho": "linear_ortho_ft_accuracies.json",
"linear-2": "linear-2_ft_accuracies.json",
"linear-2_ortho": "linear-2_ortho_ft_accuracies.json",
}
print(mode_labels.get(args.finetuning_mode, f"Evaluating {args.finetuning_mode} models."))
print("*" * 100)
ft_accuracies_path = os.path.join(args.save, ft_accuracies_name_map[args.finetuning_mode])
with open(ft_accuracies_path) as f:
args.finetuning_accuracies = json.load(f)
control_dataset = "ImageNet"
negation_accuracies = {}
mode = args.finetuning_mode
# For each target dataset: build the task vector from its checkpoint pair,
# NEGATE it, search for the scaling coefficient on the validation split
# (minimizing target accuracy while keeping control accuracy above a
# threshold), then evaluate at that coefficient on the test split.
for dataset in eval_datasets:
    # Checkpoint filename prefix and task-vector class are fully determined
    # by the finetuning mode; this replaces five near-identical if/elif
    # branches from the original (each of which also re-checked a path that
    # it had just verified — that dead re-check is dropped here).
    prefix = "" if mode == "standard" else f"{mode}_"
    vector_cls = (
        LinearizedTaskVector if mode in ("linear", "linear_ortho") else NonLinearTaskVector
    )

    pretrained_checkpoint = f"{args.save}/{dataset}Val/{prefix}zeroshot.pt"
    finetuned_checkpoint = f"{args.save}/{dataset}Val/{prefix}finetuned.pt"
    if not (os.path.exists(pretrained_checkpoint) and os.path.exists(finetuned_checkpoint)):
        print(f"Warning: Missing checkpoints for {dataset}. Skipping.")
        continue

    # Negation: subtracting the task vector is intended to "forget" the task.
    task_vector = -vector_cls(pretrained_checkpoint, finetuned_checkpoint)

    # --- Coefficient search on the validation splits ---
    args.eval_datasets = [dataset + "Val"]
    args.control_dataset = control_dataset + "Val"
    val_metrics = evaluate_task_vector(
        task_vector,
        pretrained_checkpoint,
        args,
        posthoc_linearization=False,
    )
    optimal_coef = find_optimal_coef(
        val_metrics,
        metric=f"{dataset}Val:top1",
        minimize=True,  # negation: lower target accuracy is better
        control_metric=f"{control_dataset}Val:top1",
        control_metric_threshold=args.control_threshold * pretrained_accuracies[control_dataset + "Val"],
    )

    # --- Final evaluation on the test splits at the chosen coefficient ---
    args.eval_datasets = [dataset]
    args.control_dataset = control_dataset
    test_metrics = evaluate_task_vector_at_coef(
        task_vector,
        pretrained_checkpoint,
        args,
        optimal_coef,
        posthoc_linearization=False,
    )

    print("=" * 100)
    print(f"Results for dataset: {dataset}")
    print(f"Optimal Coefficient: {optimal_coef}")
    print(f"Test accuracy: {test_metrics.get(f'{dataset}:top1', 'N/A')}")
    print(f"Control accuracy on {control_dataset}: {test_metrics.get(f'{control_dataset}:top1', 'N/A')}")

    negation_accuracies[dataset] = {
        "test": test_metrics.get(f"{dataset}:top1"),
        "test_control": test_metrics.get(f"{control_dataset}:top1"),
        "val": val_metrics,
        "optimal_coef": optimal_coef,
    }
# Output filename per mode: plain "standard" keeps the historical bare name,
# every other mode is prefixed with the mode string.
result_files = {
    m: ("negations.json" if m == "standard" else f"{m}_negations.json")
    for m in ("standard", "standard_ortho", "linear", "linear_ortho", "linear-2", "linear-2_ortho")
}

destination = os.path.join(args.save, result_files[mode])
with open(destination, "w") as out_f:
    json.dump(negation_accuracies, out_f, indent=4)
print(f"Negation results saved to {destination}")
|