| | import torch
|
| | import torch.nn as nn
|
| | import numpy as np
|
| |
|
| |
|
class PretrainingPDeepPP:
    """Build peptide representations by blending pretrained ESM features
    with a freshly initialized token-embedding model.

    The mixing weight between the two sources is ``esm_ratio``, which is
    assigned externally before calling :meth:`create_embeddings`.
    """

    def __init__(self, embedding_dim=1280, target_length=33, esm_ratio=None, device=None):
        """
        Initialize the PretrainingPDeepPP helper.

        Args:
            embedding_dim: Embedding dimension size.
            target_length: Target sequence length (rows per sequence in outputs).
            esm_ratio: Weight of the ESM representation versus the learned
                embedding (assigned externally; may still be None here).
            device: Torch device; defaults to CUDA when available, else CPU.
        """
        self.embedding_dim = embedding_dim
        self.target_length = target_length
        self.esm_ratio = esm_ratio
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def extract_esm_representations(self, sequences, esm_model, batch_converter, batch_size=32):
        """
        Extract ESM representations shaped (num_sequences, target_length, embedding_dim).

        Fix: token representations longer than ``target_length`` are truncated
        (as before), and shorter ones are now zero-padded so the per-sequence
        tensors always stack — the original crashed in ``torch.stack`` when
        sequence lengths differed.

        Args:
            sequences: Iterable of amino-acid strings.
            esm_model: Pretrained ESM model; invoked with ``repr_layers=[33]``.
            batch_converter: Callable mapping (label, seq) pairs to token tensors.
            batch_size: Number of sequences per forward pass.

        Returns:
            Tensor of shape (len(sequences), target_length, embedding_dim).

        Raises:
            ValueError: If no representations were produced.
        """
        sequence_representations = []
        print("Sequences to process:", sequences)
        print("Batch size:", batch_size)

        # The converter expects (label, sequence) pairs; labels are unused here.
        labeled_sequences = [(None, seq) for seq in sequences]

        for start in range(0, len(labeled_sequences), batch_size):
            batch = labeled_sequences[start:start + batch_size]
            if len(batch) == 0:
                continue

            _, _batch_strs, batch_tokens = batch_converter(batch)
            batch_tokens = batch_tokens.to(self.device)

            with torch.no_grad():
                results = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)

            # NOTE(review): in standard ESM tokenization, row 0 of each token
            # representation is the BOS token — confirm whether it should be
            # skipped before slicing to target_length.
            for token_repr in results["representations"][33]:
                length = token_repr.size(0)
                if length >= self.target_length:
                    fixed = token_repr[:self.target_length]
                else:
                    # Zero-pad short sequences so torch.stack succeeds below.
                    pad = token_repr.new_zeros(self.target_length - length, token_repr.size(1))
                    fixed = torch.cat([token_repr, pad], dim=0)
                sequence_representations.append(fixed)

        if len(sequence_representations) == 0:
            raise ValueError("No ESM representations were generated. Check your input sequences and batch processing logic.")

        return torch.stack(sequence_representations)

    def pad_sequences(self, sequences, max_len=None, pad_value=0):
        """
        Pad (and truncate) integer index sequences to a common length.

        Fixes two defects: ``pad_value`` was previously ignored (zeros were
        always written), and sequences longer than ``max_len`` raised a
        shape-mismatch error instead of being truncated. An empty input list
        now yields an empty tensor rather than a ``max()`` ValueError.

        Args:
            sequences: List of lists of token indices.
            max_len: Target width; defaults to the longest sequence length.
            pad_value: Fill value for positions past each sequence's end.

        Returns:
            LongTensor of shape (len(sequences), max_len).
        """
        if max_len is None:
            max_len = max((len(seq) for seq in sequences), default=0)
        padded_sequences = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
        for i, seq in enumerate(sequences):
            n = min(len(seq), max_len)
            padded_sequences[i, :n] = torch.tensor(seq[:n], dtype=torch.long)
        return padded_sequences

    def seq_to_indices(self, seq, vocab_dict):
        """Map each character of ``seq`` to its vocab index (0 for unknown)."""
        return [vocab_dict.get(char, 0) for char in seq]

    def create_embeddings(self, sequences, vocab, esm_model, esm_alphabet, batch_size=16):
        """
        Create embeddings, using the class's ``esm_ratio`` to weight the mix.

        Args:
            sequences: List of input sequences.
            vocab: Character vocabulary.
            esm_model: Pretrained ESM model.
            esm_alphabet: ESM alphabet (provides the batch converter).
            batch_size: Batch size for ESM extraction.

        Returns:
            Tensor combining ESM features and learned embeddings:
            ``esm_ratio * esm + (1 - esm_ratio) * embedding``.

        Raises:
            ValueError: If ``esm_ratio`` has not been assigned.
        """
        if self.esm_ratio is None:
            raise ValueError("esm_ratio is not set. Please assign a value before creating embeddings.")

        vocab_dict = {char: i for i, char in enumerate(vocab)}

        indices = [self.seq_to_indices(seq, vocab_dict) for seq in sequences]
        indices_padded = self.pad_sequences(indices, max_len=self.target_length)

        class EmbeddingPretrainedModel(nn.Module):
            """Randomly initialized token embedding followed by a linear layer."""

            def __init__(self, vocab_size, embedding_dim, max_len):
                super(EmbeddingPretrainedModel, self).__init__()
                self.embedding = nn.Embedding(vocab_size, embedding_dim)
                self.fc = nn.Linear(embedding_dim, embedding_dim)

            def forward(self, x):
                x = self.embedding(x)
                x = self.fc(x)
                return x

        embedding_model = EmbeddingPretrainedModel(len(vocab), self.embedding_dim, self.target_length).to(self.device)

        esm_representations = self.extract_esm_representations(
            sequences,
            esm_model,
            esm_alphabet.get_batch_converter(),
            batch_size=batch_size,
        )

        with torch.no_grad():
            embedding_output = embedding_model(indices_padded.to(self.device))

        combined_representations = self.esm_ratio * esm_representations + (1 - self.esm_ratio) * embedding_output

        return combined_representations