directionality_probe / protify /FastPLMs /fine_tuning_example.py
nikraf's picture
Upload folder using huggingface_hub
714cf46 verified
#! /usr/bin/env python3
"""
This script is a simple example of how to fine-tune a Synthyra FastPLM model for a protein sequence regression or classification task.
For regression we look at the binding affinity of two proteins (pkd)
For classification we look at the solubility of a protein (membrane bound or not)
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from torch.utils.data import Dataset as TorchDataset
from typing import List, Tuple, Dict, Union, Any
from transformers import (
AutoModelForSequenceClassification,
Trainer,
TrainingArguments,
EarlyStoppingCallback,
EvalPrediction
)
from peft import LoraConfig, get_peft_model
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from scipy.stats import spearmanr
# Shared arguments for the trainer
# Keyword arguments merged into every TrainingArguments built below, so both
# the regression and classification runs use the same eval/checkpoint cadence.
BASE_TRAINER_KWARGS = {
    "warmup_steps": 500,                  # linear LR warmup before the scheduler's decay
    "weight_decay": 0.01,                 # AdamW weight decay
    "logging_steps": 100,                 # log training loss every 100 optimizer steps
    "eval_strategy": "steps",             # evaluate on a step interval rather than per epoch
    "eval_steps": 500,                    # evaluate every 500 steps
    "save_strategy": "steps",             # checkpoint on the same cadence as evaluation
    "save_steps": 500,                    # required to match eval_steps for load_best_model_at_end
    "load_best_model_at_end": True,       # restore the best checkpoint after training
    "metric_for_best_model": "eval_loss", # "best" means lowest validation loss...
    "greater_is_better": False,           # ...so lower is better
    "report_to": "none",                  # disable wandb/tensorboard reporting
    "label_names": ["labels"]             # tells Trainer which batch key holds the targets
}
# Dataset classes
class PairDatasetHF(TorchDataset):
    """
    Torch-style dataset over a column-indexable dataset of protein pairs
    (e.g. protein-protein interactions).

    Args:
        dataset: Source dataset holding sequence and label columns.
        col_a: Column name for the first protein sequence.
        col_b: Column name for the second protein sequence.
        label_col: Column name for the labels.
        max_length: Each sequence is truncated to at most this many characters.
    """

    def __init__(self, dataset: Any, col_a: str, col_b: str, label_col: str, max_length: int = 2048):
        self.seqs_a = dataset[col_a]
        self.seqs_b = dataset[col_b]
        self.labels = dataset[label_col]
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.seqs_a)

    def __getitem__(self, idx: int) -> Tuple[str, str, Union[float, int]]:
        # Truncate both sequences independently to the length budget
        cut = self.max_length
        return self.seqs_a[idx][:cut], self.seqs_b[idx][:cut], self.labels[idx]
class SequenceDatasetHF(TorchDataset):
    """
    Torch-style dataset over a column-indexable dataset of single protein
    sequences.

    Args:
        dataset: Source dataset holding sequence and label columns.
        col_name: Column name for the protein sequences.
        label_col: Column name for the labels.
        max_length: Each sequence is truncated to at most this many characters.
    """

    def __init__(self, dataset: Any, col_name: str = 'seqs', label_col: str = 'labels', max_length: int = 2048):
        self.seqs = dataset[col_name]
        self.labels = dataset[label_col]
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.seqs)

    def __getitem__(self, idx: int) -> Tuple[str, Union[float, int]]:
        # Truncate the sequence to the length budget before returning
        return self.seqs[idx][:self.max_length], self.labels[idx]
class PairCollator:
    """
    Batch collator for (seq_a, seq_b, label) triples: tokenizes the pair of
    sequences jointly and stacks labels into a tensor.

    Args:
        tokenizer: Tokenizer used to encode the sequence pairs.
        regression: If True labels become float tensors, otherwise long.
    """

    def __init__(self, tokenizer: Any, regression: bool = False):
        self.tokenizer = tokenizer
        self.regression = regression

    def __call__(self, batch: List[Tuple[str, str, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        seqs_a, seqs_b, raw_labels = zip(*batch)
        # Float targets for regression, class indices (long) for classification
        label_dtype = torch.float if self.regression else torch.long
        labels = torch.tensor(raw_labels, dtype=label_dtype)
        # Dynamic padding to the longest pair in the batch, rounded up to a
        # multiple of 8
        encoded = self.tokenizer(
            seqs_a, seqs_b,
            padding='longest',
            pad_to_multiple_of=8,
            return_tensors='pt'
        )
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'labels': labels,
        }
class SequenceCollator:
    """
    Batch collator for (sequence, label) pairs: tokenizes the sequences and
    stacks labels into a tensor.

    Args:
        tokenizer: Tokenizer used to encode the sequences.
        regression: If True labels become float tensors, otherwise long.
    """

    def __init__(self, tokenizer: Any, regression: bool = False):
        self.tokenizer = tokenizer
        self.regression = regression

    def __call__(self, batch: List[Tuple[str, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        seqs, raw_labels = zip(*batch)
        # Float targets for regression, class indices (long) for classification
        label_dtype = torch.float if self.regression else torch.long
        labels = torch.tensor(raw_labels, dtype=label_dtype)
        # Dynamic padding to the longest sequence in the batch, rounded up to
        # a multiple of 8
        encoded = self.tokenizer(
            seqs,
            padding='longest',
            pad_to_multiple_of=8,
            return_tensors='pt'
        )
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'labels': labels,
        }
# Get the model ready, with or without LoRA
def initialize_model(model_name: str, num_labels: int, use_lora: bool = True, lora_config: Any = None):
    """
    Load a sequence-classification model and optionally wrap it with LoRA.

    Args:
        model_name: Name or path of the pretrained model.
        num_labels: Number of labels for the task (1 for regression).
        use_lora: Whether to apply LoRA adapters for fine-tuning.
        lora_config: Custom LoraConfig; a default is built when None.

    Returns:
        model: The initialized (possibly PEFT-wrapped) model.
        tokenizer: The model's tokenizer (exposed as a `.tokenizer` attribute
            by these Synthyra checkpoints).
    """
    print(f"Loading model {model_name} with {num_labels} labels...")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        trust_remote_code=True,
        num_labels=num_labels
    )
    tokenizer = model.tokenizer

    if use_lora:
        if lora_config is None:
            # Default adapter config; target module names cover ESM++ and
            # ESM2 attention/projection layers
            lora_config = LoraConfig(
                r=8,
                lora_alpha=16,
                lora_dropout=0.01,
                bias="none",
                target_modules=["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"],
            )
        model = get_peft_model(model, lora_config)
        # PEFT freezes the base model; the task head must stay trainable
        for param in model.classifier.parameters():
            param.requires_grad = True

    # Report how much of the network is actually being trained
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total}")
    print(f"Trainable parameters: {trainable}")
    print(f"Non-trainable parameters: {total - trainable}")
    print(f"Percentage of parameters being trained: {100 * trainable / total:.2f}%")
    return model, tokenizer
# For computing performance metrics, it's fairly straightforward to add more metrics here
def compute_metrics_regression(p: EvalPrediction) -> Dict[str, float]:
    """Compute Spearman correlation (and its p-value) for regression tasks."""
    preds = p.predictions
    if isinstance(preds, tuple):
        # Some models return (logits, ...) tuples; keep the first element
        preds = preds[0]
    rho, pval = spearmanr(preds.flatten(), p.label_ids.flatten())
    return {
        "spearman_correlation": rho,
        "p_value": pval
    }
def compute_metrics_classification(p: EvalPrediction) -> Dict[str, float]:
    """Compute accuracy for classification tasks."""
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    # Predicted class is the argmax over the last (label) dimension
    predicted = np.argmax(logits, axis=-1).flatten()
    return {"accuracy": (predicted == p.label_ids.flatten()).mean()}
# For plotting the results, it's fairly straightforward to add more plots here
def plot_regression_results(preds: np.ndarray, labels: np.ndarray, task_name: str = "Regression") -> float:
    """
    Scatter-plot predictions against true values with a fitted trend line,
    annotated with the Spearman correlation. The figure is saved to
    '<task_name>_results.png' (lowercased, spaces -> underscores) and shown.

    Args:
        preds: Predicted values.
        labels: Ground-truth values.
        task_name: Used for the plot title and output filename.

    Returns:
        The Spearman correlation coefficient.
    """
    rho, pval = spearmanr(preds, labels)

    plt.figure(figsize=(10, 8))
    sns.scatterplot(x=labels, y=preds, alpha=0.6)
    # Overlay a regression (trend) line without re-drawing the points
    sns.regplot(x=labels, y=preds, scatter=False, color='red')
    plt.title(f'{task_name} - Spearman Correlation: {rho:.3f} (p={pval:.3e})')
    plt.xlabel('True Values')
    plt.ylabel('Predicted Values')
    # Correlation badge in the top-left corner of the axes
    plt.annotate(
        f'ρ = {rho:.3f}',
        xy=(0.05, 0.95),
        xycoords='axes fraction',
        fontsize=12,
        bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8),
    )
    plt.tight_layout()
    plt.savefig(f'{task_name.lower().replace(" ", "_")}_results.png')
    plt.show()
    return rho
def plot_classification_results(trainer: Trainer, test_dataset: Any, task_name: str = "Classification") -> float:
    """
    Plot classification results as a confusion matrix. The figure is saved to
    '<task_name>_results.png' (lowercased, spaces -> underscores) and shown.

    Args:
        trainer: The trained model trainer
        test_dataset: Dataset to evaluate on
        task_name: Name of the task for plot title and filename

    Returns:
        accuracy: Classification accuracy
    """
    # Get predictions
    predictions, labels, _ = trainer.predict(test_dataset)
    preds = predictions[0] if isinstance(predictions, tuple) else predictions
    pred_values = np.argmax(preds, axis=1)
    # Calculate accuracy
    accuracy = (pred_values == labels).mean()
    # Create confusion matrix
    cm = confusion_matrix(labels, pred_values)
    # BUG FIX: previously plt.figure(figsize=(10, 8)) was followed by
    # disp.plot() without an `ax`, which makes ConfusionMatrixDisplay open a
    # brand-new figure — the sized figure stayed blank and figsize was
    # ignored. Create the axes explicitly and hand it to disp.plot().
    fig, ax = plt.subplots(figsize=(10, 8))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap=plt.cm.Blues, ax=ax)
    ax.set_title(f'{task_name} - Accuracy: {accuracy:.3f}')
    fig.tight_layout()
    fig.savefig(f'{task_name.lower().replace(" ", "_")}_results.png')
    plt.show()
    return accuracy
# Training functions
def train_regression_model(
    model_name: str = 'Synthyra/ESMplusplus_small',
    use_lora: bool = True,
    custom_lora_config: Any = None,
    batch_size: int = 8,
    learning_rate: float = 5e-5,
    num_epochs: int = 10,
    max_length: int = 1024,
    gradient_accumulation_steps: int = 1,
    patience: int = 3
) -> Tuple[Trainer, Any]:
    """
    Fine-tune a regression model for protein-protein binding affinity (pKd).

    Args:
        model_name: Name or path of the pretrained model.
        use_lora: Whether to use LoRA for fine-tuning.
        custom_lora_config: Custom LoRA configuration (optional).
        batch_size: Per-device batch size for training and evaluation.
        learning_rate: Learning rate for training.
        num_epochs: Number of epochs for training.
        max_length: Maximum sequence length to consider.
        gradient_accumulation_steps: Number of gradient accumulation steps.
        patience: Evaluation calls with no improvement before early stopping.

    Returns:
        trainer: The fitted Trainer.
        test_dataset: The held-out test dataset used for evaluation.
    """
    print("Loading datasets for regression task...")

    # Keep only pairs whose combined sequence length fits the budget
    def _keep_pair(example: Any) -> bool:
        return len(example['SeqA']) + len(example['SeqB']) <= max_length

    raw = {
        'train': load_dataset('Synthyra/ProteinProteinAffinity', split='train').filter(_keep_pair),
        'valid': load_dataset('Synthyra/AffinityBenchmarkv5.5', split='train').filter(_keep_pair),
        'test': load_dataset('Synthyra/haddock_benchmark', split='train').filter(_keep_pair),
    }
    train_dataset, valid_dataset, test_dataset = (
        PairDatasetHF(raw[split], 'SeqA', 'SeqB', 'labels', max_length=max_length)
        for split in ('train', 'valid', 'test')
    )

    # num_labels=1 selects the regression head
    model, tokenizer = initialize_model(
        model_name=model_name,
        num_labels=1,
        use_lora=use_lora,
        lora_config=custom_lora_config
    )
    data_collator = PairCollator(tokenizer, regression=True)

    # Output/log directories distinguish LoRA from full fine-tuning runs
    suffix = "_lora" if use_lora else ""
    training_args = TrainingArguments(
        output_dir=f"./results_regression{suffix}",
        num_train_epochs=num_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        logging_dir=f"./logs_regression{suffix}",
        learning_rate=learning_rate,
        **BASE_TRAINER_KWARGS
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics_regression,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)]
    )

    # Baseline test metrics before any fine-tuning
    metrics = trainer.evaluate(test_dataset)
    print(f"Initial metrics: {metrics}")

    print("Training regression model...")
    trainer.train()

    # Final test-set evaluation and scatter plot
    print("Evaluating and visualizing results...")
    predictions, labels, metrics = trainer.predict(test_dataset)
    preds = predictions[0] if isinstance(predictions, tuple) else predictions
    correlation = plot_regression_results(preds.flatten(), labels.flatten(), "Protein-Protein Affinity")
    print(f"Final Spearman correlation on test set: {correlation:.3f}")
    return trainer, test_dataset
def train_classification_model(
    model_name: str = 'Synthyra/ESMplusplus_small',
    use_lora: bool = True,
    custom_lora_config: Any = None,
    batch_size: int = 8,
    learning_rate: float = 5e-5,
    num_epochs: int = 10,
    max_length: int = 512,
    gradient_accumulation_steps: int = 1,
    patience: int = 3
) -> Tuple[Trainer, Any]:
    """
    Fine-tune a classification model for protein solubility prediction.

    Args:
        model_name: Name or path of the pretrained model.
        use_lora: Whether to use LoRA for fine-tuning.
        custom_lora_config: Custom LoRA configuration (optional).
        batch_size: Per-device batch size for training and evaluation.
        learning_rate: Learning rate for training.
        num_epochs: Number of epochs for training.
        max_length: Maximum sequence length to consider.
        gradient_accumulation_steps: Number of gradient accumulation steps.
        patience: Evaluation calls with no improvement before early stopping.

    Returns:
        trainer: The fitted Trainer.
    """
    print("Loading datasets for classification task...")

    # Keep only sequences that fit the length budget
    def _keep_seq(example: Any) -> bool:
        return len(example['seqs']) <= max_length

    data = load_dataset('GleghornLab/DL2_reg')
    train_dataset = SequenceDatasetHF(data['train'].filter(_keep_seq), 'seqs', 'labels', max_length=max_length)
    valid_dataset = SequenceDatasetHF(data['valid'].filter(_keep_seq), 'seqs', 'labels', max_length=max_length)
    test_dataset = SequenceDatasetHF(data['test'].filter(_keep_seq), 'seqs', 'labels', max_length=max_length)

    # Infer the label count from the (already filtered) training split
    num_labels = len(set(data['train'].filter(_keep_seq)['labels']))

    model, tokenizer = initialize_model(
        model_name=model_name,
        num_labels=num_labels,
        use_lora=use_lora,
        lora_config=custom_lora_config
    )
    data_collator = SequenceCollator(tokenizer, regression=False)

    # Output/log directories distinguish LoRA from full fine-tuning runs
    suffix = "_lora" if use_lora else ""
    training_args = TrainingArguments(
        output_dir=f"./results_classification{suffix}",
        num_train_epochs=num_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        logging_dir=f"./logs_classification{suffix}",
        learning_rate=learning_rate,
        **BASE_TRAINER_KWARGS
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics_classification,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)]
    )

    # Baseline test metrics before any fine-tuning
    metrics = trainer.evaluate(test_dataset)
    print(f"Initial metrics: {metrics}")

    print("Training classification model...")
    trainer.train()

    # Final test-set evaluation and confusion-matrix plot
    print("Evaluating and visualizing results...")
    accuracy = plot_classification_results(trainer, test_dataset, "Protein Solubility")
    print(f"Final accuracy on test set: {accuracy:.3f}")
    return trainer
# Main function
if __name__ == "__main__":
    """
    With default arguments on 4070 laptop GPU
    py -m fine_tuning_example --task classification --batch_size 8 --epochs 2
    Runs in 80 seconds with test accuracy of ~89%
    py -m fine_tuning_example --task regression --batch_size 2 --max_length 1024 --grad_accum 4 --epochs 2
    Runs in 7 minutes with test Spearman correlation of ~0.72
    """
    import argparse

    # Examples of PLMs with efficient implementations offered by Synthyra
    MODEL_LIST = [
        'Synthyra/ESMplusplus_small',
        'Synthyra/ESMplusplus_large',
        'Synthyra/ESM2-8M',
        'Synthyra/ESM2-35M',
        'Synthyra/ESM2-150M',
        'Synthyra/ESM2-650M',
    ]

    parser = argparse.ArgumentParser(description="Train models for protein tasks")
    parser.add_argument("--task", type=str, choices=["regression", "classification", "both"],
                        default="both", help="Task to train model for")
    parser.add_argument("--model_path", type=str, default="Synthyra/ESM2-8M",
                        help="Path to the model to train")
    # BUG FIX: --use_lora was store_true with default=True, so LoRA could
    # never be switched off from the command line. --use_lora is kept for
    # backward compatibility; --no_lora writes False into the same dest.
    parser.add_argument("--use_lora", action="store_true", default=True,
                        help="Use LoRA for fine-tuning (default: enabled)")
    parser.add_argument("--no_lora", dest="use_lora", action="store_false",
                        help="Disable LoRA and fine-tune all model weights")
    parser.add_argument("--batch_size", type=int, default=2,
                        help="Batch size for training")
    parser.add_argument("--lr", type=float, default=5e-5,
                        help="Learning rate for training")
    parser.add_argument("--epochs", type=float, default=1.0,
                        help="Number of epochs for training")
    parser.add_argument("--max_length", type=int, default=512,
                        help="Maximum length of input sequences")
    parser.add_argument("--grad_accum", type=int, default=1,
                        help="Number of gradient accumulation steps")
    parser.add_argument("--patience", type=int, default=3,
                        help="Early stopping patience - number of evaluation calls with no improvement after which training will be stopped")
    args = parser.parse_args()

    # Print training configuration
    print("\n" + "="*50)
    print("TRAINING CONFIGURATION")
    print("="*50)
    print(f"Task: {args.task}")
    print(f"Using LoRA: {args.use_lora}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.lr}")
    print(f"Number of epochs: {args.epochs}")
    print(f"Max sequence length: {args.max_length}")
    print(f"Gradient Accumulation Steps: {args.grad_accum}")
    print(f"Early stopping patience: {args.patience}")
    print("="*50 + "\n")

    # Train regression model if required
    if args.task in ["regression", "both"]:
        print("\n" + "="*50)
        print("TRAINING REGRESSION MODEL")
        print("="*50)
        regression_trainer, test_dataset = train_regression_model(
            model_name=args.model_path,
            use_lora=args.use_lora,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            num_epochs=args.epochs,
            max_length=args.max_length,
            gradient_accumulation_steps=args.grad_accum,
            patience=args.patience
        )

    # Train classification model if required
    if args.task in ["classification", "both"]:
        print("\n" + "="*50)
        print("TRAINING CLASSIFICATION MODEL")
        print("="*50)
        classification_trainer = train_classification_model(
            model_name=args.model_path,
            use_lora=args.use_lora,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            num_epochs=args.epochs,
            max_length=args.max_length,
            gradient_accumulation_steps=args.grad_accum,
            patience=args.patience
        )

    print("\nTraining completed!")