import copy
import csv
import json
import os
from typing import Any, Dict, List, Optional, Tuple

import yaml

from base_models.get_base_models import get_tokenizer
from embedder import get_embedding_filename
from utils import torch_load, print_message
| |
|
| |
|
# wandb is an optional dependency: import it only when the environment opts in.
# NOTE(review): the sweep code below uses `wandb` unconditionally, so calling
# HyperoptModule.run_wandb_hyperopt without WANDB_AVAILABLE='true' raises
# NameError — confirm callers guard on the same flag.
if os.environ.get('WANDB_AVAILABLE') == 'true':
    import wandb
| |
|
| |
|
class HyperoptModule:
    """Coordinates Weights & Biases hyperparameter sweeps for one model/dataset pair.

    Wraps a main-process object whose ``probe_args`` / ``trainer_args`` /
    ``embedding_args`` namespaces are mutated per sweep trial and restored from
    the snapshots taken in this constructor.
    """

    def __init__(
        self,
        main_process,
        model_name: str,
        data_name: str,
        dataset: Tuple,
        emb_dict: Any,
        sweep_config: Dict[str, Any],
        results_list: List[Dict[str, Any]],
        swept_param_keys: Optional[List[str]] = None
    ):
        """Snapshot baseline args and record which sweep keys map to which namespace.

        Args:
            main_process: owner of the argument namespaces and the
                ``_run_full_finetuning`` / ``_run_hybrid_probe`` / ``_run_nn_probe``
                training entry points.
            model_name: base model identifier used for tokenizer/embedding lookup.
            data_name: dataset identifier (used in run names and log paths).
            dataset: 6-tuple ``(train, valid, test, num_labels, label_type, ppi)``.
            emb_dict: precomputed embeddings, or None (SQL-backed / finetuning).
            sweep_config: parsed sweep configuration passed to ``wandb.init``.
            results_list: shared list that ``objective`` appends trial results to.
            swept_param_keys: names of parameters actually being swept; None
                (the default) means none — fixed from the buggy ``List[str] = None``
                annotation.
        """
        self.mp = main_process
        self.model_name = model_name
        self.data_name = data_name
        self.dataset = dataset
        self.emb_dict = emb_dict
        self.sweep_config = sweep_config
        self.results_list = results_list
        self.swept_param_keys = swept_param_keys or []

        # Snapshots of the untouched argument namespaces; apply_config() restores
        # from these before layering a trial's config on top.
        self.base_probe_args = copy.deepcopy(self.mp.probe_args.__dict__)
        self.base_trainer_args = copy.deepcopy(self.mp.trainer_args.__dict__)

        # Which sweep-config keys are routed to which argument namespace.
        self.probe_keys = {
            'hidden_size', 'transformer_hidden_size', 'dropout', 'n_layers', 'pre_ln', 'classifier_size',
            'classifier_dropout', 'n_heads', 'rotary', 'use_bias', 'probe_pooling_types',
            'lora', 'lora_r', 'lora_alpha', 'lora_dropout', 'probe_type', 'tokenwise', 'pooling_types'
        }
        self.trainer_keys = {
            'lr', 'weight_decay', 'num_epochs', 'probe_batch_size',
            'base_batch_size', 'probe_grad_accum', 'base_grad_accum',
            'patience', 'seed'
        }
        self.embedding_keys = {
            'embedding_pooling_types'
        }
        # Keys coerced to int in apply_config() — sweep backends may deliver
        # them as floats or strings.
        self.int_keys = {
            'hidden_size', 'transformer_hidden_size', 'n_layers', 'classifier_size', 'n_heads',
            'lora_r', 'lora_alpha', 'num_epochs', 'probe_batch_size',
            'base_batch_size', 'probe_grad_accum', 'base_grad_accum',
            'patience', 'seed'
        }
| |
|
    def apply_config(self, cfg: Dict[str, Any]) -> None:
        """Reset probe/trainer args to their baseline snapshots, then apply a trial config.

        NOTE: ``cfg`` is mutated in place (int coercion, derived keys such as
        ``n_heads`` / ``transformer_dropout`` / ``pooling_types``) — callers that
        reuse the dict afterwards observe these changes.
        """
        # Restore the pristine baselines captured in __init__ so trials don't leak
        # settings into each other.
        self.mp.probe_args.__dict__.update(copy.deepcopy(self.base_probe_args))
        self.mp.trainer_args.__dict__.update(copy.deepcopy(self.base_trainer_args))

        # Sweep backends may deliver numeric params as floats/strings; coerce first
        # so the derived values below use ints.
        for key in self.int_keys:
            if key in cfg:
                cfg[key] = int(cfg[key])

        # Derive n_heads from the hidden size (one head per 64 dims, minimum 1).
        if 'hidden_size' in cfg:
            val = cfg['hidden_size']
            n_heads = max(1, val // 64)
            cfg['n_heads'] = n_heads

        # NOTE(review): if both hidden sizes are present, this overwrites the
        # n_heads derived above — presumably intentional precedence; confirm.
        if 'transformer_hidden_size' in cfg:
            val = cfg['transformer_hidden_size']
            n_heads = max(1, val // 64)
            cfg['n_heads'] = n_heads

        # Mirror a generic 'dropout' onto the transformer-specific key.
        if 'dropout' in cfg:
            cfg['transformer_dropout'] = cfg['dropout']

        # The probe args use the key 'pooling_types'; alias the sweep's name for it.
        if 'probe_pooling_types' in cfg:
            cfg['pooling_types'] = cfg['probe_pooling_types']

        # Route each config entry to the namespace(s) that declare the attribute.
        for k, v in cfg.items():
            if k in self.probe_keys and hasattr(self.mp.probe_args, k):
                setattr(self.mp.probe_args, k, v)
            if k in self.trainer_keys and hasattr(self.mp.trainer_args, k):
                setattr(self.mp.trainer_args, k, v)

            if k in self.embedding_keys:
                if k == 'embedding_pooling_types':
                    # Normalize a single pooling type to a one-element list.
                    if isinstance(v, str):
                        v = [v]
                    self.mp.embedding_args.pooling_types = v
| |
|
| | def train_model(self, sweep_mode=True): |
| | train_set, valid_set, test_set, _, _, ppi = self.dataset |
| | |
| | if self.mp.full_args.full_finetuning: |
| | model, valid_metrics, test_metrics = self.mp._run_full_finetuning( |
| | self.model_name, self.data_name, |
| | train_set, valid_set, test_set, |
| | ppi=ppi, sweep_mode=sweep_mode |
| | ) |
| | return model, valid_metrics, test_metrics |
| |
|
| | elif self.mp.full_args.hybrid_probe: |
| | tokenizer = get_tokenizer(self.model_name) |
| | model, valid_metrics, test_metrics = self.mp._run_hybrid_probe( |
| | self.model_name, self.data_name, |
| | train_set, valid_set, test_set, |
| | tokenizer, |
| | emb_dict=self.emb_dict, |
| | ppi=ppi, |
| | sweep_mode=sweep_mode |
| | ) |
| | return model, valid_metrics, test_metrics |
| |
|
| | else: |
| | tokenizer = get_tokenizer(self.model_name) |
| | probe, valid_metrics, test_metrics = self.mp._run_nn_probe( |
| | self.model_name, self.data_name, |
| | train_set, valid_set, test_set, |
| | tokenizer, |
| | emb_dict=self.emb_dict, |
| | ppi=ppi, |
| | sweep_mode=sweep_mode |
| | ) |
| | return probe, valid_metrics, test_metrics |
| |
|
| | def select_metric(self, valid_metrics: Dict[str, Any], test_metrics: Dict[str, Any], sweep_metric: str) -> float: |
| | if valid_metrics and sweep_metric in valid_metrics: |
| | return float(valid_metrics[sweep_metric]) |
| | elif test_metrics and sweep_metric in test_metrics: |
| | return float(test_metrics[sweep_metric]) |
| | |
| | |
| | available_keys = [] |
| | if valid_metrics: available_keys.extend(valid_metrics.keys()) |
| | if test_metrics: available_keys.extend(test_metrics.keys()) |
| | raise KeyError(f"Metric '{sweep_metric}' not found in validation or test metrics. Available metrics: {available_keys}") |
| |
|
    def objective(self):
        """One W&B sweep trial: apply the trial config, train, log, and record results.

        Intended to be passed as the ``function`` argument of ``wandb.agent``.
        Returns the selected sweep metric as a float.
        """
        run = wandb.init(
            project=self.mp.full_args.wandb_project,
            entity=self.mp.full_args.wandb_entity,
            config=self.sweep_config,
            reinit=True,
            tags=["sweep", f"model:{self.model_name}", f"data:{self.data_name}"],
        )
        run.name = f"sweep-{self.model_name}_{self.data_name}-{run.id[:6]}"

        # wandb.config carries this trial's sampled hyperparameters.
        full_config = dict(wandb.config)
        self.apply_config(full_config)

        # Only the keys actually being swept are persisted with the result.
        applied_config = {k: v for k, v in full_config.items() if k in self.swept_param_keys}
        self.mp.trainer_args.make_plots = False

        # If the trial changes the embedding pooling, the embeddings (and hence the
        # probe input dimension) must be re-resolved for the new pooling types.
        if 'embedding_pooling_types' in full_config and not self.mp.full_args.full_finetuning:
            _, _, _, _, _, ppi = self.dataset
            tokenizer = get_tokenizer(self.model_name)
            test_seq = self.mp.all_seqs[0]

            if self.mp._sql:
                # SQL-backed embeddings are read lazily from the .db file.
                filename = get_embedding_filename(self.model_name, self.mp._full,
                                                  self.mp.embedding_args.pooling_types, 'db')
                save_path = os.path.join(self.mp.embedding_args.embedding_save_dir, filename)
                input_dim = self.mp.get_embedding_dim_sql(save_path, test_seq, tokenizer)
                self.emb_dict = None
            else:
                # .pth embeddings are loaded fully into memory.
                filename = get_embedding_filename(self.model_name, self.mp._full,
                                                  self.mp.embedding_args.pooling_types, 'pth')
                save_path = os.path.join(self.mp.embedding_args.embedding_save_dir, filename)
                self.emb_dict = torch_load(save_path)
                input_dim = self.mp.get_embedding_dim_pth(self.emb_dict, test_seq, tokenizer)

            # PPI tasks concatenate two sequence embeddings, doubling the input size.
            self.mp.probe_args.input_size = input_dim * 2 if (ppi and not self.mp._full) else input_dim

        _, valid_metrics, test_metrics = self.train_model(sweep_mode=True)

        # Classification and regression tasks are scored with different metrics.
        label_type = self.mp.probe_args.task_type
        metric_cls = getattr(self.mp.full_args, 'sweep_metric_cls', None)
        metric_reg = getattr(self.mp.full_args, 'sweep_metric_reg', None)
        dataset_metric = metric_cls if label_type in ["singlelabel", "multilabel"] else metric_reg

        # NOTE(review): keys shared by both dicts collide here — test values
        # overwrite validation values in the logged payload. Confirm intended.
        all_metrics = {}
        if isinstance(valid_metrics, dict):
            for k, v in valid_metrics.items():
                all_metrics[f"{k}"] = v
        if isinstance(test_metrics, dict):
            for k, v in test_metrics.items():
                all_metrics[f"{k}"] = v
        wandb.log(all_metrics)

        metric_value = self.select_metric(valid_metrics, test_metrics, dataset_metric)

        # Accumulate into the shared list so run_wandb_hyperopt can rank trials.
        self.results_list.append({
            "wandb_run_id": run.id,
            dataset_metric: metric_value,
            "config": applied_config,
            "valid_metrics": valid_metrics,
            "test_metrics": test_metrics,
        })

        run.finish()
        return float(metric_value)
| |
|
| | @classmethod |
| | def run_wandb_hyperopt(cls, mp): |
| | mp.logger.info("Called method: run_wandb_hyperopt") |
| |
|
| | sweep_config = {} |
| | sweep_config_path = mp.full_args.sweep_config_path |
| | |
| | if os.path.exists(sweep_config_path): |
| | with open(sweep_config_path, 'r') as f: |
| | sweep_config = yaml.safe_load(f) |
| | else: |
| | raise ValueError(f"Sweep config file not found: {sweep_config_path}") |
| |
|
| | params_to_hyperopt = sweep_config.get("parameters", {}) |
| | |
| | |
| | probe_type = getattr(mp.probe_args, 'probe_type', 'linear') |
| | use_lora = getattr(mp.probe_args, 'lora', False) |
| | |
| | |
| | linear_probe_params = {'lr', 'weight_decay', 'hidden_size', 'n_layers', 'dropout', 'pre_ln', 'use_bias', 'probe_batch_size'} |
| | transformer_probe_params = {'lr', 'weight_decay', 'transformer_hidden_size', 'n_layers', 'transformer_dropout', 'pre_ln', |
| | 'classifier_dropout', 'classifier_size', 'use_bias', 'probe_pooling_types', 'embedding_pooling_types', 'probe_batch_size'} |
| | lora_params = {'lora_r', 'lora_alpha', 'lora_dropout'} |
| | |
| | |
| | if probe_type == 'linear': |
| | relevant_params = linear_probe_params |
| | elif probe_type == 'transformer': |
| | relevant_params = transformer_probe_params |
| | else: |
| | |
| | relevant_params = linear_probe_params | transformer_probe_params |
| | |
| | |
| | if use_lora: |
| | relevant_params = relevant_params | lora_params |
| | |
| | |
| | filtered_params = {k: v for k, v in params_to_hyperopt.items() if k in relevant_params} |
| | params_to_hyperopt = filtered_params |
| | |
| | |
| | mp.logger.info(f"Probe type: {probe_type}, LoRA enabled: {use_lora}") |
| | mp.logger.info(f"Sweeping over {len(params_to_hyperopt)} parameters: {list(params_to_hyperopt.keys())}") |
| |
|
| | method = mp.full_args.sweep_method |
| | early_term = sweep_config.get("early_terminate", None) |
| |
|
| | total_combinations = len(mp.model_args.model_names) * len(mp.datasets) |
| | mp.logger.info(f"Hyperopt over {total_combinations} model/dataset combinations") |
| | for model_name in mp.model_args.model_names: |
| | tokenizer = get_tokenizer(model_name) |
| | test_seq = mp.all_seqs[0] |
| |
|
| | if "random" in model_name.lower() or "onehot" in model_name.lower(): |
| | print_message(f"Skipping hyperparameter optimization for {model_name}.") |
| |
|
| | for data_name, dataset in mp.datasets.items(): |
| | train_set, valid_set, test_set, num_labels, label_type, ppi = dataset |
| | mp.probe_args.num_labels = num_labels |
| | mp.probe_args.task_type = label_type |
| | mp.trainer_args.task_type = label_type |
| | mp.trainer_args.make_plots = True |
| |
|
| | emb_dict = None |
| | if not mp.full_args.full_finetuning: |
| | if mp._sql: |
| | filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'db') |
| | save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename) |
| | input_dim = mp.get_embedding_dim_sql(save_path, test_seq, tokenizer) |
| | else: |
| | filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'pth') |
| | save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename) |
| | emb_dict = torch_load(save_path) |
| | input_dim = mp.get_embedding_dim_pth(emb_dict, test_seq, tokenizer) |
| | mp.probe_args.input_size = input_dim * 2 if (ppi and not mp._full) else input_dim |
| | if mp.full_args.full_finetuning: |
| | _ = mp._run_full_finetuning(model_name, data_name, train_set, valid_set, test_set, ppi, sweep_mode=False) |
| | elif mp.full_args.hybrid_probe: |
| | _ = mp._run_hybrid_probe(model_name, data_name, train_set, valid_set, test_set, tokenizer, emb_dict=emb_dict, ppi=ppi, sweep_mode=False) |
| | else: |
| | _ = mp._run_nn_probe(model_name, data_name, train_set, valid_set, test_set, tokenizer, emb_dict=emb_dict, ppi=ppi, sweep_mode=False) |
| | continue |
| |
|
| | for data_name, dataset in mp.datasets.items(): |
| | mp.logger.info(f"Sweeping over {data_name} with {model_name}") |
| | train_set, _, _, num_labels, label_type, ppi = dataset |
| | mp.probe_args.num_labels = num_labels |
| | mp.probe_args.task_type = label_type |
| | mp.trainer_args.task_type = label_type |
| |
|
| | emb_dict = None |
| | if not mp.full_args.full_finetuning: |
| | if mp._sql: |
| | filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'db') |
| | save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename) |
| | input_dim = mp.get_embedding_dim_sql(save_path, test_seq, tokenizer) |
| | else: |
| | filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'pth') |
| | save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename) |
| | emb_dict = torch_load(save_path) |
| | input_dim = mp.get_embedding_dim_pth(emb_dict, test_seq, tokenizer) |
| | mp.probe_args.input_size = input_dim * 2 if (ppi and not mp._full) else input_dim |
| |
|
| | |
| | base_probe = copy.deepcopy(mp.probe_args.__dict__) |
| | base_trainer = copy.deepcopy(mp.trainer_args.__dict__) |
| |
|
| | results_list = [] |
| | |
| | metric_cls = getattr(mp.full_args, 'sweep_metric_cls', None) |
| | metric_reg = getattr(mp.full_args, 'sweep_metric_reg', None) |
| | dataset_metric = metric_cls if label_type in ["singlelabel", "multilabel"] else metric_reg |
| | |
| | hyperopt_module = cls( |
| | main_process=mp, |
| | model_name=model_name, |
| | data_name=data_name, |
| | dataset=dataset, |
| | emb_dict=emb_dict, |
| | sweep_config=sweep_config, |
| | results_list=results_list, |
| | swept_param_keys=list(params_to_hyperopt.keys()) |
| | ) |
| |
|
| | wb_sweep = { |
| | "method": method, |
| | "metric": {"name": dataset_metric, "goal": mp.full_args.sweep_goal}, |
| | "early_terminate": early_term, |
| | "parameters": params_to_hyperopt, |
| | } |
| | sweep_id = wandb.sweep(sweep=wb_sweep, project=mp.full_args.wandb_project, entity=mp.full_args.wandb_entity) |
| | wandb.agent(sweep_id, function=hyperopt_module.objective, count=mp.full_args.sweep_count) |
| |
|
| | |
| | reverse_flag = True if mp.full_args.sweep_goal == "maximize" else False |
| | results_list.sort(key=lambda x: x[dataset_metric], reverse=reverse_flag) |
| | sweep_log_path = os.path.join(mp.full_args.log_dir, f"{mp.random_id}_sweep_{data_name}_{model_name}.csv") |
| | with open(sweep_log_path, 'w', newline='', encoding='utf-8') as f: |
| | writer = csv.writer(f, delimiter=',') |
| | |
| | columns = ["rank","wandb_run_id",dataset_metric,"config","valid_metrics","test_metrics"] |
| | writer.writerow(columns) |
| | for idx, res in enumerate(results_list, start=1): |
| | writer.writerow([ |
| | idx, |
| | res['wandb_run_id'], |
| | res[dataset_metric], |
| | json.dumps(res['config']), |
| | json.dumps(res['valid_metrics']), |
| | json.dumps(res['test_metrics']), |
| | ]) |
| |
|
| | |
| | best = results_list[0] if results_list else None |
| | best_score = best[dataset_metric] |
| | best_config = best['config'] |
| | print_message(f"Best sweep result - {dataset_metric}: {best_score}") |
| | print_message(f"Best hyperparameters: {json.dumps(best_config, indent=2)}") |
| |
|
| | |
| | mp.probe_args.__dict__.update(copy.deepcopy(base_probe)) |
| | mp.trainer_args.__dict__.update(copy.deepcopy(base_trainer)) |
| | hyperopt_module.apply_config(best_config) |
| | mp.trainer_args.make_plots = True |
| | |
| | final_config = { |
| | **best_config, |
| | 'probe_batch_size': mp.trainer_args.probe_batch_size, |
| | 'seed': mp.trainer_args.seed, |
| | 'patience': mp.trainer_args.patience, |
| | 'num_epochs': mp.trainer_args.num_epochs, |
| | } |
| | print_message(f"Final training config: {json.dumps(final_config, indent=2)}") |
| |
|
| | |
| | final_run = wandb.init( |
| | project=mp.full_args.wandb_project, |
| | entity=mp.full_args.wandb_entity, |
| | config=final_config, |
| | reinit=True, |
| | tags=["final_model", f"model:{model_name}", f"data:{data_name}", f"best_sweep_score:{best_score}"], |
| | name=f"final-{model_name}_{data_name}-best", |
| | ) |
| |
|
| | |
| | _, valid_metrics, test_metrics = hyperopt_module.train_model(sweep_mode=False) |
| | |
| | |
| | all_final_metrics = {} |
| | if isinstance(valid_metrics, dict): |
| | for k, v in valid_metrics.items(): |
| | all_final_metrics[f"final_{k}"] = v |
| | if isinstance(test_metrics, dict): |
| | for k, v in test_metrics.items(): |
| | all_final_metrics[f"final_{k}"] = v |
| | wandb.log(all_final_metrics) |
| | |
| | final_run.finish() |