| | |
| | """ |
| | This script is a simple example of how to fine-tune a Synthyra FastPLM model for a protein sequence regression or classification task. |
| | For regression we look at the binding affinity of two proteins (pkd) |
| | For classification we look at the solubility of a protein (membrane bound or not) |
| | """ |
| | import torch |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | from datasets import load_dataset |
| | from torch.utils.data import Dataset as TorchDataset |
| | from typing import List, Tuple, Dict, Union, Any |
| | from transformers import ( |
| | AutoModelForSequenceClassification, |
| | Trainer, |
| | TrainingArguments, |
| | EarlyStoppingCallback, |
| | EvalPrediction |
| | ) |
| | from peft import LoraConfig, get_peft_model |
| | from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay |
| | from scipy.stats import spearmanr |
| |
|
| |
|
| | |
# Shared HuggingFace TrainingArguments settings used by both the regression and
# classification runs; task-specific values (output dir, epochs, batch size, lr)
# are supplied at the call site and merged via **BASE_TRAINER_KWARGS.
BASE_TRAINER_KWARGS = {
    "warmup_steps": 500,                   # LR warmup over the first 500 optimizer steps
    "weight_decay": 0.01,
    "logging_steps": 100,
    "eval_strategy": "steps",              # evaluate every eval_steps rather than per epoch
    "eval_steps": 500,
    "save_strategy": "steps",              # save cadence matches eval cadence, required for load_best_model_at_end
    "save_steps": 500,
    "load_best_model_at_end": True,
    "metric_for_best_model": "eval_loss",  # "best" checkpoint = lowest validation loss
    "greater_is_better": False,
    "report_to": "none",                   # disable wandb/tensorboard reporting
    "label_names": ["labels"]              # tells Trainer which batch key holds the targets
}
| |
|
| |
|
| | |
class PairDatasetHF(TorchDataset):
    """
    Torch-style dataset over pairs of protein sequences (e.g. interaction partners).

    Args:
        dataset: Source dataset (HF dataset or mapping) holding sequences and labels
        col_a: Column name of the first protein sequence
        col_b: Column name of the second protein sequence
        label_col: Column name of the labels
        max_length: Each sequence is truncated to this many characters
    """
    def __init__(self, dataset: Any, col_a: str, col_b: str, label_col: str, max_length: int = 2048):
        self.seqs_a = dataset[col_a]
        self.seqs_b = dataset[col_b]
        self.labels = dataset[label_col]
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.seqs_a)

    def __getitem__(self, idx: int) -> Tuple[str, str, Union[float, int]]:
        # Truncate each sequence of the pair independently; pairing is positional.
        cutoff = self.max_length
        return self.seqs_a[idx][:cutoff], self.seqs_b[idx][:cutoff], self.labels[idx]
| |
|
| |
|
class SequenceDatasetHF(TorchDataset):
    """
    Torch-style dataset over single protein sequences.

    Args:
        dataset: Source dataset (HF dataset or mapping) holding sequences and labels
        col_name: Column name of the protein sequences
        label_col: Column name of the labels
        max_length: Sequences are truncated to this many characters
    """
    def __init__(self, dataset: Any, col_name: str = 'seqs', label_col: str = 'labels', max_length: int = 2048):
        self.seqs = dataset[col_name]
        self.labels = dataset[label_col]
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.seqs)

    def __getitem__(self, idx: int) -> Tuple[str, Union[float, int]]:
        # Truncate to the configured maximum before tokenization downstream.
        return self.seqs[idx][:self.max_length], self.labels[idx]
| |
|
| |
|
class PairCollator:
    """
    Collates (seq_a, seq_b, label) tuples into a tokenized model batch.

    Args:
        tokenizer: Tokenizer used to jointly encode the paired sequences
        regression: If True labels become float tensors, otherwise long class indices
    """
    def __init__(self, tokenizer: Any, regression: bool = False):
        self.tokenizer = tokenizer
        self.regression = regression

    def __call__(self, batch: List[Tuple[str, str, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        firsts, seconds, raw_labels = zip(*batch)
        label_tensor = torch.tensor(raw_labels)
        # Regression heads expect float targets; classification heads expect class indices.
        label_tensor = label_tensor.float() if self.regression else label_tensor.long()
        encoded = self.tokenizer(
            firsts, seconds,
            padding='longest',
            pad_to_multiple_of=8,  # pad width rounded up to a multiple of 8
            return_tensors='pt'
        )
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'labels': label_tensor
        }
| |
|
| |
|
class SequenceCollator:
    """
    Collates (sequence, label) tuples into a tokenized model batch.

    Args:
        tokenizer: Tokenizer used to encode the sequences
        regression: If True labels become float tensors, otherwise long class indices
    """
    def __init__(self, tokenizer: Any, regression: bool = False):
        self.tokenizer = tokenizer
        self.regression = regression

    def __call__(self, batch: List[Tuple[str, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        sequences, raw_labels = zip(*batch)
        label_tensor = torch.tensor(raw_labels)
        # Choose target dtype based on the task type.
        label_tensor = label_tensor.float() if self.regression else label_tensor.long()
        encoded = self.tokenizer(
            sequences,
            padding='longest',
            pad_to_multiple_of=8,  # pad width rounded up to a multiple of 8
            return_tensors='pt'
        )
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'labels': label_tensor
        }
| |
|
| |
|
| | |
def initialize_model(model_name: str, num_labels: int, use_lora: bool = True, lora_config: Any = None):
    """
    Load a pretrained sequence-classification model, optionally wrapping it with LoRA adapters.

    Args:
        model_name: Name or path of the pretrained model
        num_labels: Number of labels for the task (1 for regression)
        use_lora: Whether to attach LoRA adapters for parameter-efficient fine-tuning
        lora_config: Custom LoRA configuration; a default is built when None

    Returns:
        model: The initialized (possibly PEFT-wrapped) model
        tokenizer: The model's tokenizer
    """
    print(f"Loading model {model_name} with {num_labels} labels...")

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        trust_remote_code=True,  # Synthyra checkpoints ship custom modeling code
        num_labels=num_labels
    )
    # NOTE(review): the tokenizer is read off the model object itself — assumes the
    # custom model class exposes a `tokenizer` attribute; confirm for other checkpoints.
    tokenizer = model.tokenizer

    if use_lora:
        if lora_config is None:
            # Module names cover both ESM++ style ("layernorm_qkv.1", "out_proj")
            # and ESM2 style ("query"/"key"/"value"/"dense") attention projections.
            lora_config = LoraConfig(
                r=8,
                lora_alpha=16,
                lora_dropout=0.01,
                bias="none",
                target_modules=["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"],
            )
        model = get_peft_model(model, lora_config)
        # Keep the classification head trainable alongside the adapters.
        for param in model.classifier.parameters():
            param.requires_grad = True

    # Report the trainable-parameter budget so LoRA savings are visible.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params
    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {non_trainable_params}")
    print(f"Percentage of parameters being trained: {100 * trainable_params / total_params:.2f}%")

    return model, tokenizer
| |
|
| |
|
| | |
def compute_metrics_regression(p: EvalPrediction) -> Dict[str, float]:
    """Return the Spearman rank correlation (and its p-value) between predictions and targets."""
    raw = p.predictions
    # Trainer may hand back a tuple (logits, ...); the scores come first.
    preds = raw[0] if isinstance(raw, tuple) else raw

    rho, pval = spearmanr(np.ravel(preds), np.ravel(p.label_ids))

    return {
        "spearman_correlation": rho,
        "p_value": pval
    }
| |
|
| |
|
def compute_metrics_classification(p: EvalPrediction) -> Dict[str, float]:
    """Return plain accuracy computed from the argmax over prediction logits."""
    raw = p.predictions
    # Trainer may hand back a tuple (logits, ...); the scores come first.
    logits = raw[0] if isinstance(raw, tuple) else raw
    predicted = np.argmax(logits, axis=-1)

    accuracy = (np.ravel(predicted) == np.ravel(p.label_ids)).mean()

    return {
        "accuracy": accuracy
    }
| |
|
| |
|
| | |
def plot_regression_results(preds: np.ndarray, labels: np.ndarray, task_name: str = "Regression") -> float:
    """
    Scatter-plot predictions against true values, annotate the Spearman correlation,
    and save the figure to '<task_name>_results.png'.

    Args:
        preds: Predicted values
        labels: True values
        task_name: Task name used in the plot title and output filename

    Returns:
        correlation: Spearman correlation coefficient between preds and labels
    """
    rho, pval = spearmanr(preds, labels)

    plt.figure(figsize=(10, 8))
    sns.scatterplot(x=labels, y=preds, alpha=0.6)
    # Overlay a red least-squares trend line on top of the scatter.
    sns.regplot(x=labels, y=preds, scatter=False, color='red')

    plt.title(f'{task_name} - Spearman Correlation: {rho:.3f} (p={pval:.3e})')
    plt.xlabel('True Values')
    plt.ylabel('Predicted Values')

    # Rho annotation pinned to the upper-left corner of the axes.
    plt.annotate(f'ρ = {rho:.3f}', xy=(0.05, 0.95), xycoords='axes fraction',
                 fontsize=12, bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))

    plt.tight_layout()
    out_path = f'{task_name.lower().replace(" ", "_")}_results.png'
    plt.savefig(out_path)
    plt.show()
    return rho
| |
|
| |
|
def plot_classification_results(trainer: Trainer, test_dataset: Any, task_name: str = "Classification") -> float:
    """
    Run prediction on a dataset, plot the confusion matrix, and save the figure
    to '<task_name>_results.png'.

    Args:
        trainer: The trained model trainer
        test_dataset: Dataset to evaluate on
        task_name: Task name used in the plot title and output filename

    Returns:
        accuracy: Classification accuracy on the dataset
    """
    raw_preds, true_labels, _ = trainer.predict(test_dataset)
    logits = raw_preds[0] if isinstance(raw_preds, tuple) else raw_preds
    predicted = np.argmax(logits, axis=1)

    accuracy = (predicted == true_labels).mean()

    cm = confusion_matrix(true_labels, predicted)

    plt.figure(figsize=(10, 8))
    ConfusionMatrixDisplay(confusion_matrix=cm).plot(cmap=plt.cm.Blues)

    plt.title(f'{task_name} - Accuracy: {accuracy:.3f}')
    plt.tight_layout()
    plt.savefig(f'{task_name.lower().replace(" ", "_")}_results.png')
    plt.show()

    return accuracy
| |
|
| |
|
| | |
def train_regression_model(
    model_name: str = 'Synthyra/ESMplusplus_small',
    use_lora: bool = True,
    custom_lora_config: Any = None,
    batch_size: int = 8,
    learning_rate: float = 5e-5,
    num_epochs: float = 10,
    max_length: int = 1024,
    gradient_accumulation_steps: int = 1,
    patience: int = 3
) -> Tuple[Trainer, Any]:
    """
    Train a regression model for protein-protein binding affinity (pKd) prediction.

    Args:
        model_name: Name or path of the pretrained model
        use_lora: Whether to use LoRA for fine-tuning
        custom_lora_config: Custom LoRA configuration (optional)
        batch_size: Batch size for training
        learning_rate: Learning rate for training
        num_epochs: Number of epochs for training; may be fractional (the CLI
            passes a float and TrainingArguments.num_train_epochs accepts one)
        max_length: Maximum sequence length to consider
        gradient_accumulation_steps: Number of gradient accumulation steps
        patience: Number of evaluation calls with no improvement after which training will be stopped

    Returns:
        trainer: The trained model trainer
        test_dataset: The test dataset used for evaluation
    """
    print("Loading datasets for regression task...")

    def _filter_pair_by_length(example: Any) -> bool:
        # Keep only pairs whose combined length fits within the model context.
        return len(example['SeqA']) + len(example['SeqB']) <= max_length

    # Train/valid/test come from three distinct affinity benchmark datasets.
    train_data = load_dataset('Synthyra/ProteinProteinAffinity', split='train').filter(_filter_pair_by_length)
    valid_data = load_dataset('Synthyra/AffinityBenchmarkv5.5', split='train').filter(_filter_pair_by_length)
    test_data = load_dataset('Synthyra/haddock_benchmark', split='train').filter(_filter_pair_by_length)

    train_dataset = PairDatasetHF(train_data, 'SeqA', 'SeqB', 'labels', max_length=max_length)
    valid_dataset = PairDatasetHF(valid_data, 'SeqA', 'SeqB', 'labels', max_length=max_length)
    test_dataset = PairDatasetHF(test_data, 'SeqA', 'SeqB', 'labels', max_length=max_length)

    # num_labels=1 puts the classification head into regression mode.
    model, tokenizer = initialize_model(
        model_name=model_name,
        num_labels=1,
        use_lora=use_lora,
        lora_config=custom_lora_config
    )

    data_collator = PairCollator(tokenizer, regression=True)

    output_dir = "./results_regression_lora" if use_lora else "./results_regression"
    logging_dir = "./logs_regression_lora" if use_lora else "./logs_regression"

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        logging_dir=logging_dir,
        learning_rate=learning_rate,
        **BASE_TRAINER_KWARGS
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics_regression,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)]
    )

    # Baseline metrics on the held-out test set before any fine-tuning.
    metrics = trainer.evaluate(test_dataset)
    print(f"Initial metrics: {metrics}")
    print("Training regression model...")
    trainer.train()

    print("Evaluating and visualizing results...")
    predictions, labels, metrics = trainer.predict(test_dataset)
    preds = predictions[0] if isinstance(predictions, tuple) else predictions
    correlation = plot_regression_results(preds.flatten(), labels.flatten(), "Protein-Protein Affinity")
    print(f"Final Spearman correlation on test set: {correlation:.3f}")
    return trainer, test_dataset
| |
|
| |
|
def train_classification_model(
    model_name: str = 'Synthyra/ESMplusplus_small',
    use_lora: bool = True,
    custom_lora_config: Any = None,
    batch_size: int = 8,
    learning_rate: float = 5e-5,
    num_epochs: float = 10,
    max_length: int = 512,
    gradient_accumulation_steps: int = 1,
    patience: int = 3
) -> Trainer:
    """
    Train a classification model for protein solubility (membrane-bound or not) prediction.

    Note: the return annotation was corrected from Tuple[Trainer, Any] — this
    function returns only the trainer.

    Args:
        model_name: Name or path of the pretrained model
        use_lora: Whether to use LoRA for fine-tuning
        custom_lora_config: Custom LoRA configuration (optional)
        batch_size: Batch size for training
        learning_rate: Learning rate for training
        num_epochs: Number of epochs for training; may be fractional (the CLI
            passes a float and TrainingArguments.num_train_epochs accepts one)
        max_length: Maximum sequence length to consider
        gradient_accumulation_steps: Number of gradient accumulation steps
        patience: Number of evaluation calls with no improvement after which training will be stopped

    Returns:
        trainer: The trained model trainer
    """
    print("Loading datasets for classification task...")

    def _filter_by_length(example: Any) -> bool:
        # Drop sequences longer than the configured context length.
        return len(example['seqs']) <= max_length

    # This dataset ships its own train/valid/test splits.
    data = load_dataset('GleghornLab/DL2_reg')
    train_data = data['train'].filter(_filter_by_length)
    valid_data = data['valid'].filter(_filter_by_length)
    test_data = data['test'].filter(_filter_by_length)

    train_dataset = SequenceDatasetHF(train_data, 'seqs', 'labels', max_length=max_length)
    valid_dataset = SequenceDatasetHF(valid_data, 'seqs', 'labels', max_length=max_length)
    test_dataset = SequenceDatasetHF(test_data, 'seqs', 'labels', max_length=max_length)

    # Infer the class count from the labels present in the training split.
    num_labels = len(set(train_data['labels']))

    model, tokenizer = initialize_model(
        model_name=model_name,
        num_labels=num_labels,
        use_lora=use_lora,
        lora_config=custom_lora_config
    )

    data_collator = SequenceCollator(tokenizer, regression=False)

    output_dir = "./results_classification_lora" if use_lora else "./results_classification"
    logging_dir = "./logs_classification_lora" if use_lora else "./logs_classification"

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        logging_dir=logging_dir,
        learning_rate=learning_rate,
        **BASE_TRAINER_KWARGS
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics_classification,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)]
    )

    # Baseline metrics on the held-out test set before any fine-tuning.
    metrics = trainer.evaluate(test_dataset)
    print(f"Initial metrics: {metrics}")
    print("Training classification model...")
    trainer.train()

    print("Evaluating and visualizing results...")
    accuracy = plot_classification_results(trainer, test_dataset, "Protein Solubility")
    print(f"Final accuracy on test set: {accuracy:.3f}")

    return trainer
| |
|
| |
|
| | |
| | if __name__ == "__main__": |
| | """ |
| | With default arguments on 4070 laptop GPU |
| | py -m fine_tuning_example --task classification --batch_size 8 --epochs 2 |
| | Runs in 80 seconds with test accuracy of ~89% |
| | py -m fine_tuning_example --task regression --batch_size 2 --max_length 1024 --grad_accum 4 --epochs 2 |
| | Runs in 7 minutes with test Spearman correlation of ~0.72 |
| | """ |
| | import argparse |
| | |
| | |
| | MODEL_LIST = [ |
| | 'Synthyra/ESMplusplus_small', |
| | 'Synthyra/ESMplusplus_large', |
| | 'Synthyra/ESM2-8M', |
| | 'Synthyra/ESM2-35M', |
| | 'Synthyra/ESM2-150M', |
| | 'Synthyra/ESM2-650M', |
| | ] |
| |
|
| | parser = argparse.ArgumentParser(description="Train models for protein tasks") |
| | parser.add_argument("--task", type=str, choices=["regression", "classification", "both"], |
| | default="both", help="Task to train model for") |
| | parser.add_argument("--model_path", type=str, default="Synthyra/ESM2-8M", |
| | help="Path to the model to train") |
| | parser.add_argument("--use_lora", action="store_true", default=True, |
| | help="Whether to use LoRA for fine-tuning") |
| | parser.add_argument("--batch_size", type=int, default=2, |
| | help="Batch size for training") |
| | parser.add_argument("--lr", type=float, default=5e-5, |
| | help="Learning rate for training") |
| | parser.add_argument("--epochs", type=float, default=1.0, |
| | help="Number of epochs for training") |
| | parser.add_argument("--max_length", type=int, default=512, |
| | help="Maximum length of input sequences") |
| | parser.add_argument("--grad_accum", type=int, default=1, |
| | help="Number of gradient accumulation steps") |
| | parser.add_argument("--patience", type=int, default=3, |
| | help="Early stopping patience - number of evaluation calls with no improvement after which training will be stopped") |
| | args = parser.parse_args() |
| |
|
| | |
| | print("\n" + "="*50) |
| | print("TRAINING CONFIGURATION") |
| | print("="*50) |
| | print(f"Task: {args.task}") |
| | print(f"Using LoRA: {args.use_lora}") |
| | print(f"Batch size: {args.batch_size}") |
| | print(f"Learning rate: {args.lr}") |
| | print(f"Number of epochs: {args.epochs}") |
| | print(f"Max sequence length: {args.max_length}") |
| | print(f"Gradient Accumulation Steps: {args.grad_accum}") |
| | print(f"Early stopping patience: {args.patience}") |
| | print("="*50 + "\n") |
| | |
| | |
| | if args.task in ["regression", "both"]: |
| | print("\n" + "="*50) |
| | print("TRAINING REGRESSION MODEL") |
| | print("="*50) |
| | regression_trainer, test_dataset = train_regression_model( |
| | model_name=args.model_path, |
| | use_lora=args.use_lora, |
| | batch_size=args.batch_size, |
| | learning_rate=args.lr, |
| | num_epochs=args.epochs, |
| | max_length=args.max_length, |
| | gradient_accumulation_steps=args.grad_accum, |
| | patience=args.patience |
| | ) |
| |
|
| | |
| | if args.task in ["classification", "both"]: |
| | print("\n" + "="*50) |
| | print("TRAINING CLASSIFICATION MODEL") |
| | print("="*50) |
| | classification_trainer = train_classification_model( |
| | model_name=args.model_path, |
| | use_lora=args.use_lora, |
| | batch_size=args.batch_size, |
| | learning_rate=args.lr, |
| | num_epochs=args.epochs, |
| | max_length=args.max_length, |
| | gradient_accumulation_steps=args.grad_accum, |
| | patience=args.patience |
| | ) |
| | |
| | print("\nTraining completed!") |
| |
|