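"""protify/hyperopt_utils.py

Weights & Biases hyperparameter-sweep utilities for Protify probes.

`HyperoptModule` wraps a single (model, dataset) pair: it applies sampled
sweep configurations onto the main process's probe/trainer/embedding
arguments, trains the requested probe (linear, transformer, hybrid, or full
finetuning), and reports a task-specific metric back to the wandb sweep
agent. `run_wandb_hyperopt` drives sweeps over every model/dataset
combination and retrains the best configuration at the end.
"""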
import copy
import os
import yaml
import json
import csv
from typing import Dict, Any, List, Optional, Tuple
from utils import torch_load, print_message
from embedder import get_embedding_filename
from base_models.get_base_models import get_tokenizer
# wandb is optional at import time; the sweep entry points below require it to be installed.
if os.environ.get('WANDB_AVAILABLE') == 'true':
import wandb
class HyperoptModule:
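"""Runs one wandb hyperparameter sweep for a single (model, dataset) pair.

Pristine copies of the probe and trainer arguments are kept so each trial
starts from the same baseline; sweep-config keys are routed to the matching
argument objects, and every trial's metrics are appended to `results_list`.
"""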
def __init__(
self,
main_process,
model_name: str,
data_name: str,
dataset: Tuple,
emb_dict: Any,
sweep_config: Dict[str, Any],
results_list: List[Dict[str, Any]],
swept_param_keys: Optional[List[str]] = None
):
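"""Capture sweep context and snapshot the baseline probe/trainer args.

`main_process` is the orchestrating object whose `probe_args`,
`trainer_args`, and `embedding_args` are mutated per trial; `dataset` is
the (train, valid, test, num_labels, label_type, ppi) tuple; `emb_dict`
holds precomputed embeddings (None when reading from SQL or when full
finetuning); `swept_param_keys` lists the hyperparameters actually tuned.
"""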
self.mp = main_process
self.model_name = model_name
self.data_name = data_name
self.dataset = dataset
self.emb_dict = emb_dict
self.sweep_config = sweep_config
self.results_list = results_list
self.swept_param_keys = swept_param_keys or []
self.base_probe_args = copy.deepcopy(self.mp.probe_args.__dict__)
self.base_trainer_args = copy.deepcopy(self.mp.trainer_args.__dict__)
# 'transformer_dropout' is listed so apply_config's dropout -> transformer_dropout mapping is applied to the probe args
self.probe_keys = {
'hidden_size','transformer_hidden_size','dropout','transformer_dropout','n_layers','pre_ln','classifier_size',
'classifier_dropout','n_heads','rotary','use_bias','probe_pooling_types',
'lora','lora_r','lora_alpha','lora_dropout','probe_type','tokenwise','pooling_types'
}
self.trainer_keys = {
'lr','weight_decay','num_epochs','probe_batch_size',
'base_batch_size','probe_grad_accum','base_grad_accum',
'patience','seed'
}
self.embedding_keys = {
'embedding_pooling_types'
}
self.int_keys = {
'hidden_size', 'transformer_hidden_size', 'n_layers', 'classifier_size', 'n_heads',
'lora_r', 'lora_alpha', 'num_epochs', 'probe_batch_size',
'base_batch_size', 'probe_grad_accum', 'base_grad_accum',
'patience', 'seed'
}
def apply_config(self, cfg: Dict[str, Any]):
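"""Reset args to the baseline, then apply one sweep configuration.

Integer-valued parameters are coerced to int, `n_heads` is derived from
the (transformer_)hidden_size assuming 64-dimensional attention heads,
`dropout` is mirrored to `transformer_dropout`, and embedding pooling
types are routed to `embedding_args`.
"""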
self.mp.probe_args.__dict__.update(copy.deepcopy(self.base_probe_args))
self.mp.trainer_args.__dict__.update(copy.deepcopy(self.base_trainer_args))
# Ensure integer parameters are actually integers
for key in self.int_keys:
if key in cfg:
cfg[key] = int(cfg[key])
if 'hidden_size' in cfg:
val = cfg['hidden_size']
# Automatically set n_heads based on hidden_size (linear probe)
n_heads = max(1, val // 64)
cfg['n_heads'] = n_heads
if 'transformer_hidden_size' in cfg:
val = cfg['transformer_hidden_size']
# Automatically set n_heads based on transformer_hidden_size (transformer probe)
n_heads = max(1, val // 64)
cfg['n_heads'] = n_heads
if 'dropout' in cfg:
cfg['transformer_dropout'] = cfg['dropout']
if 'probe_pooling_types' in cfg:
cfg['pooling_types'] = cfg['probe_pooling_types']
for k, v in cfg.items():
if k in self.probe_keys and hasattr(self.mp.probe_args, k):
setattr(self.mp.probe_args, k, v)
if k in self.trainer_keys and hasattr(self.mp.trainer_args, k):
setattr(self.mp.trainer_args, k, v)
# Handle embedding pooling types
if k in self.embedding_keys:
if k == 'embedding_pooling_types':
if isinstance(v, str):
v = [v]
self.mp.embedding_args.pooling_types = v
def train_model(self, sweep_mode=True):
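"""Train on the held dataset and return (model, valid_metrics, test_metrics).

Dispatches to full finetuning, a hybrid probe, or an embedding-based NN
probe depending on `full_args`.
"""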
train_set, valid_set, test_set, _, _, ppi = self.dataset
if self.mp.full_args.full_finetuning:
model, valid_metrics, test_metrics = self.mp._run_full_finetuning(
self.model_name, self.data_name,
train_set, valid_set, test_set,
ppi=ppi, sweep_mode=sweep_mode
)
return model, valid_metrics, test_metrics
elif self.mp.full_args.hybrid_probe:
tokenizer = get_tokenizer(self.model_name)
model, valid_metrics, test_metrics = self.mp._run_hybrid_probe(
self.model_name, self.data_name,
train_set, valid_set, test_set,
tokenizer,
emb_dict=self.emb_dict,
ppi=ppi,
sweep_mode=sweep_mode
)
return model, valid_metrics, test_metrics
else:
tokenizer = get_tokenizer(self.model_name)
probe, valid_metrics, test_metrics = self.mp._run_nn_probe(
self.model_name, self.data_name,
train_set, valid_set, test_set,
tokenizer,
emb_dict=self.emb_dict,
ppi=ppi,
sweep_mode=sweep_mode
)
return probe, valid_metrics, test_metrics
def select_metric(self, valid_metrics: Dict[str, Any], test_metrics: Dict[str, Any], sweep_metric: str) -> float:
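"""Return the sweep metric value, preferring validation over test metrics.

Raises KeyError (listing the available keys) if the metric is in neither.
"""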
if valid_metrics and sweep_metric in valid_metrics:
return float(valid_metrics[sweep_metric])
elif test_metrics and sweep_metric in test_metrics:
return float(test_metrics[sweep_metric])
# Raise a helpful error if metric was not found
available_keys = []
if valid_metrics: available_keys.extend(valid_metrics.keys())
if test_metrics: available_keys.extend(test_metrics.keys())
raise KeyError(f"Metric '{sweep_metric}' not found in validation or test metrics. Available metrics: {available_keys}")
def objective(self):
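"""Run a single wandb sweep trial: apply the sampled config, train, and log metrics.

Embeddings are reloaded when the trial changes the embedding pooling type.
Returns the task-specific sweep metric as a float so the wandb agent can
optimize it.
"""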
run = wandb.init(
project=self.mp.full_args.wandb_project,
entity=self.mp.full_args.wandb_entity,
config=self.sweep_config,
reinit=True,
tags=["sweep", f"model:{self.model_name}", f"data:{self.data_name}"],
)
run.name = f"sweep-{self.model_name}_{self.data_name}-{run.id[:6]}"
# Store only the actual hyperparameters used for this run
full_config = dict(wandb.config)
self.apply_config(full_config)
# Filter to only include the hyperparameters that were actually tuned
applied_config = {k: v for k, v in full_config.items() if k in self.swept_param_keys}
self.mp.trainer_args.make_plots = False
# Reload embeddings if pooling type changed
if 'embedding_pooling_types' in full_config and not self.mp.full_args.full_finetuning:
_, _, _, _, _, ppi = self.dataset
tokenizer = get_tokenizer(self.model_name)
test_seq = self.mp.all_seqs[0]
if self.mp._sql:
filename = get_embedding_filename(self.model_name, self.mp._full,
self.mp.embedding_args.pooling_types, 'db')
save_path = os.path.join(self.mp.embedding_args.embedding_save_dir, filename)
input_dim = self.mp.get_embedding_dim_sql(save_path, test_seq, tokenizer)
self.emb_dict = None
else:
filename = get_embedding_filename(self.model_name, self.mp._full,
self.mp.embedding_args.pooling_types, 'pth')
save_path = os.path.join(self.mp.embedding_args.embedding_save_dir, filename)
self.emb_dict = torch_load(save_path)
input_dim = self.mp.get_embedding_dim_pth(self.emb_dict, test_seq, tokenizer)
self.mp.probe_args.input_size = input_dim * 2 if (ppi and not self.mp._full) else input_dim
_, valid_metrics, test_metrics = self.train_model(sweep_mode=True)
# Choose task-specific metric to optimize
label_type = self.mp.probe_args.task_type
metric_cls = getattr(self.mp.full_args, 'sweep_metric_cls', None)
metric_reg = getattr(self.mp.full_args, 'sweep_metric_reg', None)
dataset_metric = metric_cls if label_type in ["singlelabel", "multilabel"] else metric_reg
# Merge metrics for logging; test metrics overwrite validation metrics on shared keys
all_metrics = {}
if isinstance(valid_metrics, dict):
all_metrics.update(valid_metrics)
if isinstance(test_metrics, dict):
all_metrics.update(test_metrics)
wandb.log(all_metrics)
metric_value = self.select_metric(valid_metrics, test_metrics, dataset_metric)
self.results_list.append({
"wandb_run_id": run.id,
dataset_metric: metric_value,
"config": applied_config,
"valid_metrics": valid_metrics,
"test_metrics": test_metrics,
})
run.finish()
return float(metric_value)
@classmethod
def run_wandb_hyperopt(cls, mp):
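"""Run wandb sweeps for every model/dataset combination, then retrain the best config.

Loads the sweep YAML, filters its parameters to those relevant for the
configured probe type (and LoRA, if enabled), launches one sweep per
model/dataset pair, writes ranked trial results to CSV, and finally trains
a model with the best hyperparameters under a fresh wandb run. Random and
one-hot baselines skip the sweep and are trained with default settings.
"""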
mp.logger.info("Called method: run_wandb_hyperopt")
sweep_config_path = mp.full_args.sweep_config_path
if not os.path.exists(sweep_config_path):
raise ValueError(f"Sweep config file not found: {sweep_config_path}")
with open(sweep_config_path, 'r') as f:
sweep_config = yaml.safe_load(f)
params_to_hyperopt = sweep_config.get("parameters", {})
# Filter parameters based on probe type and LoRA settings
probe_type = getattr(mp.probe_args, 'probe_type', 'linear')
use_lora = getattr(mp.probe_args, 'lora', False)
# Define which parameters are relevant for each probe type
linear_probe_params = {'lr', 'weight_decay', 'hidden_size', 'n_layers', 'dropout', 'pre_ln', 'use_bias', 'probe_batch_size'}
transformer_probe_params = {'lr', 'weight_decay', 'transformer_hidden_size', 'n_layers', 'transformer_dropout', 'pre_ln',
'classifier_dropout', 'classifier_size', 'use_bias', 'probe_pooling_types', 'embedding_pooling_types', 'probe_batch_size'}
lora_params = {'lora_r', 'lora_alpha', 'lora_dropout'}
# Determine which parameters to include
if probe_type == 'linear':
relevant_params = linear_probe_params
elif probe_type == 'transformer':
relevant_params = transformer_probe_params
else:
# For other probe types, include all common params
relevant_params = linear_probe_params | transformer_probe_params
# Add LoRA parameters only if LoRA is enabled
if use_lora:
relevant_params = relevant_params | lora_params
# Filter the parameters dictionary
filtered_params = {k: v for k, v in params_to_hyperopt.items() if k in relevant_params}
params_to_hyperopt = filtered_params
# Log which parameters are being swept
mp.logger.info(f"Probe type: {probe_type}, LoRA enabled: {use_lora}")
mp.logger.info(f"Sweeping over {len(params_to_hyperopt)} parameters: {list(params_to_hyperopt.keys())}")
method = mp.full_args.sweep_method
early_term = sweep_config.get("early_terminate", None)
total_combinations = len(mp.model_args.model_names) * len(mp.datasets)
mp.logger.info(f"Hyperopt over {total_combinations} model/dataset combinations")
for model_name in mp.model_args.model_names:
tokenizer = get_tokenizer(model_name)
test_seq = mp.all_seqs[0]
if "random" in model_name.lower() or "onehot" in model_name.lower():
print_message(f"Skipping hyperparameter optimization for {model_name}.")
for data_name, dataset in mp.datasets.items():
train_set, valid_set, test_set, num_labels, label_type, ppi = dataset
mp.probe_args.num_labels = num_labels
mp.probe_args.task_type = label_type
mp.trainer_args.task_type = label_type
mp.trainer_args.make_plots = True
emb_dict = None
if not mp.full_args.full_finetuning:
if mp._sql:
filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'db')
save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename)
input_dim = mp.get_embedding_dim_sql(save_path, test_seq, tokenizer)
else:
filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'pth')
save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename)
emb_dict = torch_load(save_path)
input_dim = mp.get_embedding_dim_pth(emb_dict, test_seq, tokenizer)
mp.probe_args.input_size = input_dim * 2 if (ppi and not mp._full) else input_dim
if mp.full_args.full_finetuning:
_ = mp._run_full_finetuning(model_name, data_name, train_set, valid_set, test_set, ppi, sweep_mode=False)
elif mp.full_args.hybrid_probe:
_ = mp._run_hybrid_probe(model_name, data_name, train_set, valid_set, test_set, tokenizer, emb_dict=emb_dict, ppi=ppi, sweep_mode=False)
else:
_ = mp._run_nn_probe(model_name, data_name, train_set, valid_set, test_set, tokenizer, emb_dict=emb_dict, ppi=ppi, sweep_mode=False)
continue
for data_name, dataset in mp.datasets.items():
mp.logger.info(f"Sweeping over {data_name} with {model_name}")
train_set, _, _, num_labels, label_type, ppi = dataset
mp.probe_args.num_labels = num_labels
mp.probe_args.task_type = label_type
mp.trainer_args.task_type = label_type
emb_dict = None
if not mp.full_args.full_finetuning:
if mp._sql:
filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'db')
save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename)
input_dim = mp.get_embedding_dim_sql(save_path, test_seq, tokenizer)
else:
filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'pth')
save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename)
emb_dict = torch_load(save_path)
input_dim = mp.get_embedding_dim_pth(emb_dict, test_seq, tokenizer)
mp.probe_args.input_size = input_dim * 2 if (ppi and not mp._full) else input_dim
# Save base args for restoring after each trial
base_probe = copy.deepcopy(mp.probe_args.__dict__)
base_trainer = copy.deepcopy(mp.trainer_args.__dict__)
results_list = []
# Choose task-specific metric to optimize
metric_cls = getattr(mp.full_args, 'sweep_metric_cls', None)
metric_reg = getattr(mp.full_args, 'sweep_metric_reg', None)
dataset_metric = metric_cls if label_type in ["singlelabel", "multilabel"] else metric_reg
hyperopt_module = cls(
main_process=mp,
model_name=model_name,
data_name=data_name,
dataset=dataset,
emb_dict=emb_dict,
sweep_config=sweep_config,
results_list=results_list,
swept_param_keys=list(params_to_hyperopt.keys())
)
wb_sweep = {
"method": method,
"metric": {"name": dataset_metric, "goal": mp.full_args.sweep_goal},
"early_terminate": early_term,
"parameters": params_to_hyperopt,
}
sweep_id = wandb.sweep(sweep=wb_sweep, project=mp.full_args.wandb_project, entity=mp.full_args.wandb_entity)
wandb.agent(sweep_id, function=hyperopt_module.objective, count=mp.full_args.sweep_count)
# Sort, write, and save sweep results
reverse_flag = mp.full_args.sweep_goal == "maximize"
results_list.sort(key=lambda x: x[dataset_metric], reverse=reverse_flag)
sweep_log_path = os.path.join(mp.full_args.log_dir, f"{mp.random_id}_sweep_{data_name}_{model_name}.csv")
with open(sweep_log_path, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f, delimiter=',')
# Columns
columns = ["rank","wandb_run_id",dataset_metric,"config","valid_metrics","test_metrics"]
writer.writerow(columns)
for idx, res in enumerate(results_list, start=1):
writer.writerow([
idx,
res['wandb_run_id'],
res[dataset_metric],
json.dumps(res['config']),
json.dumps(res['valid_metrics']),
json.dumps(res['test_metrics']),
])
# Log best hyperparameters
if not results_list:
mp.logger.warning(f"No sweep results recorded for {model_name} on {data_name}; skipping final training.")
continue
best = results_list[0]
best_score = best[dataset_metric]
best_config = best['config']
print_message(f"Best sweep result - {dataset_metric}: {best_score}")
print_message(f"Best hyperparameters: {json.dumps(best_config, indent=2)}")
# Restore base args then apply best
mp.probe_args.__dict__.update(copy.deepcopy(base_probe))
mp.trainer_args.__dict__.update(copy.deepcopy(base_trainer))
hyperopt_module.apply_config(best_config)
mp.trainer_args.make_plots = True
final_config = {
**best_config,
'probe_batch_size': mp.trainer_args.probe_batch_size,
'seed': mp.trainer_args.seed,
'patience': mp.trainer_args.patience,
'num_epochs': mp.trainer_args.num_epochs,
}
print_message(f"Final training config: {json.dumps(final_config, indent=2)}")
# Create a fresh wandb run for the final model to track it
final_run = wandb.init(
project=mp.full_args.wandb_project,
entity=mp.full_args.wandb_entity,
config=final_config,
reinit=True,
tags=["final_model", f"model:{model_name}", f"data:{data_name}", f"best_sweep_score:{best_score}"],
name=f"final-{model_name}_{data_name}-best",
)
# Run best model with the best hyperparameters, log metrics, create plots
_, valid_metrics, test_metrics = hyperopt_module.train_model(sweep_mode=False)
# Log final model metrics to wandb
all_final_metrics = {}
if isinstance(valid_metrics, dict):
for k, v in valid_metrics.items():
all_final_metrics[f"final_{k}"] = v
if isinstance(test_metrics, dict):
for k, v in test_metrics.items():
all_final_metrics[f"final_{k}"] = v
wandb.log(all_final_metrics)
final_run.finish()
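# Usage sketch (assumes `mp` is the pipeline's main-process object with
# `probe_args`, `trainer_args`, `embedding_args`, `model_args`, `full_args`,
# `datasets`, `all_seqs`, and a `logger` already initialised, and that the
# WANDB_AVAILABLE environment variable is set to 'true'):
#
#     HyperoptModule.run_wandb_hyperopt(mp)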