import copy
import csv
import json
import os
from typing import Any, Dict, List, Optional, Tuple

import yaml

from utils import torch_load, print_message
from embedder import get_embedding_filename
from base_models.get_base_models import get_tokenizer

# wandb is only imported when the environment opts in. Otherwise the name is
# bound to None so sweep entry points can fail with a clear message instead of
# an opaque NameError (non-sweep paths keep working without wandb).
if os.environ.get('WANDB_AVAILABLE') == 'true':
    import wandb
else:
    wandb = None


def _require_wandb() -> None:
    """Raise ImportError if wandb was not imported at module load time."""
    if wandb is None:
        raise ImportError(
            "wandb is required for hyperparameter sweeps; "
            "set WANDB_AVAILABLE=true and make sure wandb is installed."
        )


def _sweep_metric_name(full_args: Any, label_type: str) -> Optional[str]:
    """Return the metric name to optimize for a task type.

    Classification tasks ("singlelabel"/"multilabel") use
    ``full_args.sweep_metric_cls``; everything else uses
    ``full_args.sweep_metric_reg``. Missing attributes yield None.
    """
    if label_type in ["singlelabel", "multilabel"]:
        return getattr(full_args, 'sweep_metric_cls', None)
    return getattr(full_args, 'sweep_metric_reg', None)


def _load_embeddings(mp: Any, model_name: str, tokenizer: Any, ppi: Any) -> Optional[Dict]:
    """Locate precomputed embeddings for ``model_name`` and size the probe input.

    The embedding filename is built from ``mp.embedding_args.pooling_types``
    ('db' extension when ``mp._sql`` is set, otherwise 'pth'). The embedding
    dimension is measured on ``mp.all_seqs[0]`` and written to
    ``mp.probe_args.input_size`` (doubled when ``ppi`` is truthy and
    ``mp._full`` is not).

    Returns:
        The loaded embedding dict for the 'pth' backend, or None for the SQL
        backend.
    """
    test_seq = mp.all_seqs[0]
    save_dir = mp.embedding_args.embedding_save_dir
    pooling_types = mp.embedding_args.pooling_types
    if mp._sql:
        filename = get_embedding_filename(model_name, mp._full, pooling_types, 'db')
        save_path = os.path.join(save_dir, filename)
        emb_dict = None
        input_dim = mp.get_embedding_dim_sql(save_path, test_seq, tokenizer)
    else:
        filename = get_embedding_filename(model_name, mp._full, pooling_types, 'pth')
        save_path = os.path.join(save_dir, filename)
        emb_dict = torch_load(save_path)
        input_dim = mp.get_embedding_dim_pth(emb_dict, test_seq, tokenizer)
    # NOTE(review): presumably PPI pairs concatenate both partners' pooled
    # embeddings, hence the doubling when not using full embeddings — confirm.
    mp.probe_args.input_size = input_dim * 2 if (ppi and not mp._full) else input_dim
    return emb_dict


def _write_sweep_results(path: str, results_list: List[Dict[str, Any]], dataset_metric: str) -> None:
    """Write ranked sweep results to ``path`` as CSV.

    Columns: rank (1-based), wandb run id, the optimized metric, and the
    JSON-encoded config / validation metrics / test metrics of each run.
    """
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(["rank", "wandb_run_id", dataset_metric, "config", "valid_metrics", "test_metrics"])
        writer.writerows(
            [
                rank,
                res['wandb_run_id'],
                res[dataset_metric],
                json.dumps(res['config']),
                json.dumps(res['valid_metrics']),
                json.dumps(res['test_metrics']),
            ]
            for rank, res in enumerate(results_list, start=1)
        )


class HyperoptModule:
    """Run a wandb hyperparameter sweep for one (model, dataset) pair.

    The instance snapshots the main process's probe and trainer arguments at
    construction so that every trial starts from the same baseline.
    ``objective`` is passed to ``wandb.agent`` as the per-trial function, and
    ``run_wandb_hyperopt`` drives sweeps over all configured models/datasets.
    """

    def __init__(
        self,
        main_process,
        model_name: str,
        data_name: str,
        dataset: Tuple,
        emb_dict: Any,
        sweep_config: Dict[str, Any],
        results_list: List[Dict[str, Any]],
        swept_param_keys: Optional[List[str]] = None,
    ):
        """Store sweep context and snapshot baseline args.

        Args:
            main_process: Object exposing ``probe_args``, ``trainer_args``,
                ``embedding_args``, ``full_args`` and the ``_run_*`` trainers.
            model_name: Name of the base model being swept.
            data_name: Name of the dataset being swept.
            dataset: 6-tuple ``(train, valid, test, num_labels, label_type, ppi)``
                as unpacked elsewhere in this module.
            emb_dict: Preloaded embeddings (None for SQL or full finetuning).
            sweep_config: Parsed sweep YAML; passed to ``wandb.init`` as config.
            results_list: Shared list that each trial appends its result to.
            swept_param_keys: Names of the parameters actually being swept.
        """
        self.mp = main_process
        self.model_name = model_name
        self.data_name = data_name
        self.dataset = dataset
        self.emb_dict = emb_dict
        self.sweep_config = sweep_config
        self.results_list = results_list
        self.swept_param_keys = swept_param_keys or []
        self.base_probe_args = copy.deepcopy(self.mp.probe_args.__dict__)
        self.base_trainer_args = copy.deepcopy(self.mp.trainer_args.__dict__)
        # Keys routed to probe_args. 'transformer_dropout' must be listed
        # here: it is sweepable for transformer probes and is also derived
        # from 'dropout' in apply_config.
        self.probe_keys = {
            'hidden_size', 'transformer_hidden_size', 'dropout', 'transformer_dropout',
            'n_layers', 'pre_ln', 'classifier_size', 'classifier_dropout', 'n_heads',
            'rotary', 'use_bias', 'probe_pooling_types', 'lora', 'lora_r', 'lora_alpha',
            'lora_dropout', 'probe_type', 'tokenwise', 'pooling_types',
        }
        # Keys routed to trainer_args.
        self.trainer_keys = {
            'lr', 'weight_decay', 'num_epochs', 'probe_batch_size',
            'base_batch_size', 'probe_grad_accum', 'base_grad_accum',
            'patience', 'seed',
        }
        # Keys routed to embedding_args.
        self.embedding_keys = {
            'embedding_pooling_types',
        }
        # Hyperparameters that must be integers even if the sweep yields floats.
        self.int_keys = {
            'hidden_size', 'transformer_hidden_size', 'n_layers', 'classifier_size',
            'n_heads', 'lora_r', 'lora_alpha', 'num_epochs', 'probe_batch_size',
            'base_batch_size', 'probe_grad_accum', 'base_grad_accum', 'patience', 'seed',
        }

    def apply_config(self, cfg: Dict[str, Any]) -> None:
        """Reset probe/trainer args to the baseline, then apply ``cfg``.

        ``cfg`` is modified in place: integer hyperparameters are cast to
        ``int`` and derived keys (``n_heads``, ``transformer_dropout``,
        ``pooling_types``) are added before values are copied onto the args.
        Keys are only set on an args object that already has that attribute.
        """
        self.mp.probe_args.__dict__.update(copy.deepcopy(self.base_probe_args))
        self.mp.trainer_args.__dict__.update(copy.deepcopy(self.base_trainer_args))

        for key in self.int_keys:
            if key in cfg:
                cfg[key] = int(cfg[key])

        # n_heads is derived from the width at 64 dims per head (minimum 1);
        # transformer_hidden_size takes precedence when both are present.
        for width_key in ('hidden_size', 'transformer_hidden_size'):
            if width_key in cfg:
                cfg['n_heads'] = max(1, cfg[width_key] // 64)

        if 'dropout' in cfg:
            cfg['transformer_dropout'] = cfg['dropout']
        if 'probe_pooling_types' in cfg:
            cfg['pooling_types'] = cfg['probe_pooling_types']

        for k, v in cfg.items():
            if k in self.probe_keys and hasattr(self.mp.probe_args, k):
                setattr(self.mp.probe_args, k, v)
            if k in self.trainer_keys and hasattr(self.mp.trainer_args, k):
                setattr(self.mp.trainer_args, k, v)
            if k == 'embedding_pooling_types':
                # A single pooling type may arrive as a bare string.
                self.mp.embedding_args.pooling_types = [v] if isinstance(v, str) else v

    def _reload_embeddings(self) -> None:
        """Reload embeddings for the current ``embedding_args.pooling_types``.

        Updates ``self.emb_dict`` and ``probe_args.input_size``. No-op under
        full finetuning, which does not use precomputed embeddings.
        """
        if self.mp.full_args.full_finetuning:
            return
        _, _, _, _, _, ppi = self.dataset
        tokenizer = get_tokenizer(self.model_name)
        self.emb_dict = _load_embeddings(self.mp, self.model_name, tokenizer, ppi)

    def train_model(self, sweep_mode: bool = True):
        """Train with the current args via the mode selected in ``full_args``.

        Dispatches to full finetuning, the hybrid probe, or the NN probe.

        Returns:
            ``(model_or_probe, valid_metrics, test_metrics)``.
        """
        train_set, valid_set, test_set, _, _, ppi = self.dataset
        if self.mp.full_args.full_finetuning:
            model, valid_metrics, test_metrics = self.mp._run_full_finetuning(
                self.model_name, self.data_name, train_set, valid_set, test_set,
                ppi=ppi, sweep_mode=sweep_mode,
            )
            return model, valid_metrics, test_metrics

        tokenizer = get_tokenizer(self.model_name)
        if self.mp.full_args.hybrid_probe:
            runner = self.mp._run_hybrid_probe
        else:
            runner = self.mp._run_nn_probe
        model, valid_metrics, test_metrics = runner(
            self.model_name, self.data_name, train_set, valid_set, test_set, tokenizer,
            emb_dict=self.emb_dict, ppi=ppi, sweep_mode=sweep_mode,
        )
        return model, valid_metrics, test_metrics

    def select_metric(self, valid_metrics: Dict[str, Any], test_metrics: Dict[str, Any], sweep_metric: str) -> float:
        """Return ``sweep_metric`` as float, preferring validation over test metrics.

        Raises:
            KeyError: If the metric is in neither dict; the message lists the
                available metric names.
        """
        if valid_metrics and sweep_metric in valid_metrics:
            return float(valid_metrics[sweep_metric])
        elif test_metrics and sweep_metric in test_metrics:
            return float(test_metrics[sweep_metric])
        available_keys = []
        if valid_metrics:
            available_keys.extend(valid_metrics.keys())
        if test_metrics:
            available_keys.extend(test_metrics.keys())
        raise KeyError(f"Metric '{sweep_metric}' not found in validation or test metrics. Available metrics: {available_keys}")

    def objective(self) -> float:
        """Run one sweep trial; used as the ``wandb.agent`` function.

        Applies the sampled config, reloads embeddings if the embedding
        pooling type is swept, trains, logs metrics to wandb, and appends a
        result record to ``self.results_list``. The wandb run is always
        finished, even if training raises.

        Returns:
            The optimized metric value for this trial.
        """
        _require_wandb()
        run = wandb.init(
            project=self.mp.full_args.wandb_project,
            entity=self.mp.full_args.wandb_entity,
            config=self.sweep_config,
            reinit=True,
            tags=["sweep", f"model:{self.model_name}", f"data:{self.data_name}"],
        )
        try:
            run.name = f"sweep-{self.model_name}_{self.data_name}-{run.id[:6]}"

            full_config = dict(wandb.config)
            self.apply_config(full_config)
            # Record only the hyperparameters that were actually swept.
            applied_config = {k: v for k, v in full_config.items() if k in self.swept_param_keys}

            self.mp.trainer_args.make_plots = False

            # Embedding files depend on the pooling type, so reload when swept.
            if 'embedding_pooling_types' in full_config:
                self._reload_embeddings()

            _, valid_metrics, test_metrics = self.train_model(sweep_mode=True)

            dataset_metric = _sweep_metric_name(self.mp.full_args, self.mp.probe_args.task_type)

            # Test metrics overwrite validation metrics on key collisions.
            all_metrics = {}
            for metrics in (valid_metrics, test_metrics):
                if isinstance(metrics, dict):
                    all_metrics.update({f"{k}": v for k, v in metrics.items()})
            wandb.log(all_metrics)

            metric_value = self.select_metric(valid_metrics, test_metrics, dataset_metric)
            self.results_list.append({
                "wandb_run_id": run.id,
                dataset_metric: metric_value,
                "config": applied_config,
                "valid_metrics": valid_metrics,
                "test_metrics": test_metrics,
            })
            return float(metric_value)
        finally:
            run.finish()

    @staticmethod
    def _filter_sweep_params(mp: Any, params: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only sweep parameters relevant to the probe type and LoRA setting."""
        probe_type = getattr(mp.probe_args, 'probe_type', 'linear')
        use_lora = getattr(mp.probe_args, 'lora', False)

        linear_probe_params = {'lr', 'weight_decay', 'hidden_size', 'n_layers', 'dropout', 'pre_ln', 'use_bias', 'probe_batch_size'}
        transformer_probe_params = {
            'lr', 'weight_decay', 'transformer_hidden_size', 'n_layers', 'transformer_dropout',
            'pre_ln', 'classifier_dropout', 'classifier_size', 'use_bias', 'probe_pooling_types',
            'embedding_pooling_types', 'probe_batch_size',
        }
        lora_params = {'lora_r', 'lora_alpha', 'lora_dropout'}

        if probe_type == 'linear':
            relevant_params = linear_probe_params
        elif probe_type == 'transformer':
            relevant_params = transformer_probe_params
        else:
            # Unknown probe types: allow every common parameter.
            relevant_params = linear_probe_params | transformer_probe_params
        if use_lora:
            relevant_params = relevant_params | lora_params

        filtered = {k: v for k, v in params.items() if k in relevant_params}
        mp.logger.info(f"Probe type: {probe_type}, LoRA enabled: {use_lora}")
        mp.logger.info(f"Sweeping over {len(filtered)} parameters: {list(filtered.keys())}")
        return filtered

    @staticmethod
    def _train_without_sweep(mp: Any, model_name: str, tokenizer: Any) -> None:
        """Train ``model_name`` once on every dataset with the current args (no sweep)."""
        for data_name, dataset in mp.datasets.items():
            train_set, valid_set, test_set, num_labels, label_type, ppi = dataset
            mp.probe_args.num_labels = num_labels
            mp.probe_args.task_type = label_type
            mp.trainer_args.task_type = label_type
            mp.trainer_args.make_plots = True

            if mp.full_args.full_finetuning:
                mp._run_full_finetuning(model_name, data_name, train_set, valid_set, test_set, ppi=ppi, sweep_mode=False)
                continue

            emb_dict = _load_embeddings(mp, model_name, tokenizer, ppi)
            if mp.full_args.hybrid_probe:
                mp._run_hybrid_probe(model_name, data_name, train_set, valid_set, test_set, tokenizer, emb_dict=emb_dict, ppi=ppi, sweep_mode=False)
            else:
                mp._run_nn_probe(model_name, data_name, train_set, valid_set, test_set, tokenizer, emb_dict=emb_dict, ppi=ppi, sweep_mode=False)

    @classmethod
    def _sweep_dataset(
        cls,
        mp: Any,
        model_name: str,
        tokenizer: Any,
        data_name: str,
        dataset: Tuple,
        sweep_config: Dict[str, Any],
        params_to_hyperopt: Dict[str, Any],
        method: Any,
        early_term: Any,
    ) -> None:
        """Sweep one (model, dataset) pair, write ranked results, and retrain the best config."""
        mp.logger.info(f"Sweeping over {data_name} with {model_name}")
        _, _, _, num_labels, label_type, ppi = dataset
        mp.probe_args.num_labels = num_labels
        mp.probe_args.task_type = label_type
        mp.trainer_args.task_type = label_type

        # Trials mutate embedding_args.pooling_types when it is swept; keep the
        # baseline so later datasets/models don't inherit a trial's value.
        swept_embedding_pooling = 'embedding_pooling_types' in params_to_hyperopt
        base_embedding_pooling = copy.deepcopy(mp.embedding_args.pooling_types) if swept_embedding_pooling else None

        emb_dict = None
        if not mp.full_args.full_finetuning:
            emb_dict = _load_embeddings(mp, model_name, tokenizer, ppi)

        # Baseline args to restore after the sweep.
        base_probe = copy.deepcopy(mp.probe_args.__dict__)
        base_trainer = copy.deepcopy(mp.trainer_args.__dict__)
        results_list: List[Dict[str, Any]] = []
        dataset_metric = _sweep_metric_name(mp.full_args, label_type)

        hyperopt_module = cls(
            main_process=mp,
            model_name=model_name,
            data_name=data_name,
            dataset=dataset,
            emb_dict=emb_dict,
            sweep_config=sweep_config,
            results_list=results_list,
            swept_param_keys=list(params_to_hyperopt.keys()),
        )

        _require_wandb()
        wb_sweep = {
            "method": method,
            "metric": {"name": dataset_metric, "goal": mp.full_args.sweep_goal},
            "early_terminate": early_term,
            "parameters": params_to_hyperopt,
        }
        sweep_id = wandb.sweep(sweep=wb_sweep, project=mp.full_args.wandb_project, entity=mp.full_args.wandb_entity)
        wandb.agent(sweep_id, function=hyperopt_module.objective, count=mp.full_args.sweep_count)

        # Rank results (best first) and save them.
        results_list.sort(key=lambda res: res[dataset_metric], reverse=mp.full_args.sweep_goal == "maximize")
        sweep_log_path = os.path.join(mp.full_args.log_dir, f"{mp.random_id}_sweep_{data_name}_{model_name}.csv")
        _write_sweep_results(sweep_log_path, results_list, dataset_metric)

        # Restore baseline args before (possibly) applying the best config.
        mp.probe_args.__dict__.update(copy.deepcopy(base_probe))
        mp.trainer_args.__dict__.update(copy.deepcopy(base_trainer))

        if not results_list:
            # Every trial failed or none ran; there is no best config to train.
            print_message(f"No completed sweep runs for {data_name} with {model_name}; skipping final training.")
            if swept_embedding_pooling:
                mp.embedding_args.pooling_types = base_embedding_pooling
            return

        best = results_list[0]
        best_score = best[dataset_metric]
        best_config = best['config']
        print_message(f"Best sweep result - {dataset_metric}: {best_score}")
        print_message(f"Best hyperparameters: {json.dumps(best_config, indent=2)}")

        hyperopt_module.apply_config(best_config)
        # The trials left the module holding the LAST trial's embeddings; reload
        # them for the best pooling type so data and input_size match.
        if 'embedding_pooling_types' in best_config:
            hyperopt_module._reload_embeddings()
        mp.trainer_args.make_plots = True

        final_config = {
            **best_config,
            'probe_batch_size': mp.trainer_args.probe_batch_size,
            'seed': mp.trainer_args.seed,
            'patience': mp.trainer_args.patience,
            'num_epochs': mp.trainer_args.num_epochs,
        }
        print_message(f"Final training config: {json.dumps(final_config, indent=2)}")

        # Separate wandb run to track the final model.
        final_run = wandb.init(
            project=mp.full_args.wandb_project,
            entity=mp.full_args.wandb_entity,
            config=final_config,
            reinit=True,
            tags=["final_model", f"model:{model_name}", f"data:{data_name}", f"best_sweep_score:{best_score}"],
            name=f"final-{model_name}_{data_name}-best",
        )
        try:
            _, valid_metrics, test_metrics = hyperopt_module.train_model(sweep_mode=False)
            all_final_metrics = {}
            for metrics in (valid_metrics, test_metrics):
                if isinstance(metrics, dict):
                    all_final_metrics.update({f"final_{k}": v for k, v in metrics.items()})
            wandb.log(all_final_metrics)
        finally:
            final_run.finish()

        if swept_embedding_pooling:
            mp.embedding_args.pooling_types = base_embedding_pooling

    @classmethod
    def run_wandb_hyperopt(cls, mp):
        """Run wandb sweeps for every configured model/dataset combination.

        Reads the sweep YAML at ``mp.full_args.sweep_config_path``, keeps only
        parameters relevant to the probe type, then sweeps each pair and
        retrains its best configuration. Models whose names contain "random"
        or "onehot" are trained once without a sweep.

        Raises:
            ValueError: If the sweep config file does not exist.
            ImportError: If a sweep is needed but wandb was not imported.
        """
        mp.logger.info("Called method: run_wandb_hyperopt")
        sweep_config_path = mp.full_args.sweep_config_path
        if not os.path.exists(sweep_config_path):
            raise ValueError(f"Sweep config file not found: {sweep_config_path}")
        with open(sweep_config_path, 'r', encoding='utf-8') as f:
            # An empty YAML file parses to None.
            sweep_config = yaml.safe_load(f) or {}

        params_to_hyperopt = cls._filter_sweep_params(mp, sweep_config.get("parameters") or {})

        method = mp.full_args.sweep_method
        early_term = sweep_config.get("early_terminate", None)

        total_combinations = len(mp.model_args.model_names) * len(mp.datasets)
        mp.logger.info(f"Hyperopt over {total_combinations} model/dataset combinations")

        for model_name in mp.model_args.model_names:
            tokenizer = get_tokenizer(model_name)
            lowered = model_name.lower()
            if "random" in lowered or "onehot" in lowered:
                print_message(f"Skipping hyperparameter optimization for {model_name}.")
                cls._train_without_sweep(mp, model_name, tokenizer)
                continue

            for data_name, dataset in mp.datasets.items():
                cls._sweep_dataset(
                    mp, model_name, tokenizer, data_name, dataset,
                    sweep_config, params_to_hyperopt, method, early_term,
                )