Upload embedding_mixin.py with huggingface_hub
Browse files- embedding_mixin.py +26 -24
embedding_mixin.py
CHANGED
|
@@ -4,16 +4,16 @@ import networkx as nx
|
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
from tqdm.auto import tqdm
|
| 7 |
-
from typing import Callable, List, Optional
|
| 8 |
from torch.utils.data import DataLoader
|
| 9 |
from torch.utils.data import Dataset as TorchDataset
|
| 10 |
from transformers import PreTrainedTokenizerBase
|
| 11 |
|
| 12 |
|
| 13 |
class Pooler:
|
| 14 |
-
def __init__(self, pooling_types: List[str]):
|
| 15 |
self.pooling_types = pooling_types
|
| 16 |
-
self.pooling_options = {
|
| 17 |
'mean': self.mean_pooling,
|
| 18 |
'max': self.max_pooling,
|
| 19 |
'norm': self.norm_pooling,
|
|
@@ -25,10 +25,11 @@ class Pooler:
|
|
| 25 |
}
|
| 26 |
|
| 27 |
def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 28 |
maxed_attentions = torch.max(attentions, dim=1)[0]
|
| 29 |
return maxed_attentions
|
| 30 |
|
| 31 |
-
def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
|
| 32 |
# Run PageRank on the attention matrix converted to a graph.
|
| 33 |
# Raises exceptions if the graph doesn't match the token sequence or has no edges.
|
| 34 |
# Returns the PageRank scores for each token node.
|
|
@@ -41,13 +42,13 @@ class Pooler:
|
|
| 41 |
|
| 42 |
return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
|
| 43 |
|
| 44 |
-
def _convert_to_graph(self, matrix):
|
| 45 |
# Convert a matrix (e.g., attention scores) to a directed graph using networkx.
|
| 46 |
# Each element in the matrix represents a directed edge with a weight.
|
| 47 |
G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
|
| 48 |
return G
|
| 49 |
|
| 50 |
-
def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
|
| 51 |
# Remove keys where attention_mask is 0
|
| 52 |
if attention_mask is not None:
|
| 53 |
for k in list(dict_importance.keys()):
|
|
@@ -59,7 +60,7 @@ class Pooler:
|
|
| 59 |
total = sum(dict_importance.values())
|
| 60 |
return np.array([v / total for _, v in dict_importance.items()])
|
| 61 |
|
| 62 |
-
def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
|
| 63 |
maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
|
| 64 |
# emb is (b, L, d), maxed_attentions is (b, L, L)
|
| 65 |
emb_pooled = []
|
|
@@ -71,35 +72,35 @@ class Pooler:
|
|
| 71 |
pooled = torch.tensor(np.array(emb_pooled))
|
| 72 |
return pooled
|
| 73 |
|
| 74 |
-
def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
|
| 75 |
if attention_mask is None:
|
| 76 |
return emb.mean(dim=1)
|
| 77 |
else:
|
| 78 |
attention_mask = attention_mask.unsqueeze(-1)
|
| 79 |
return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
|
| 80 |
|
| 81 |
-
def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
|
| 82 |
if attention_mask is None:
|
| 83 |
return emb.max(dim=1).values
|
| 84 |
else:
|
| 85 |
attention_mask = attention_mask.unsqueeze(-1)
|
| 86 |
return (emb * attention_mask).max(dim=1).values
|
| 87 |
|
| 88 |
-
def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
|
| 89 |
if attention_mask is None:
|
| 90 |
return emb.norm(dim=1, p=2)
|
| 91 |
else:
|
| 92 |
attention_mask = attention_mask.unsqueeze(-1)
|
| 93 |
return (emb * attention_mask).norm(dim=1, p=2)
|
| 94 |
|
| 95 |
-
def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
|
| 96 |
if attention_mask is None:
|
| 97 |
return emb.median(dim=1).values
|
| 98 |
else:
|
| 99 |
attention_mask = attention_mask.unsqueeze(-1)
|
| 100 |
return (emb * attention_mask).median(dim=1).values
|
| 101 |
|
| 102 |
-
def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
|
| 103 |
if attention_mask is None:
|
| 104 |
return emb.std(dim=1)
|
| 105 |
else:
|
|
@@ -107,7 +108,7 @@ class Pooler:
|
|
| 107 |
var = self.var_pooling(emb, attention_mask, **kwargs)
|
| 108 |
return torch.sqrt(var)
|
| 109 |
|
| 110 |
-
def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
|
| 111 |
if attention_mask is None:
|
| 112 |
return emb.var(dim=1)
|
| 113 |
else:
|
|
@@ -122,7 +123,7 @@ class Pooler:
|
|
| 122 |
var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
|
| 123 |
return var
|
| 124 |
|
| 125 |
-
def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
|
| 126 |
return emb[:, 0, :]
|
| 127 |
|
| 128 |
def __call__(
|
|
@@ -130,8 +131,8 @@ class Pooler:
|
|
| 130 |
emb: torch.Tensor,
|
| 131 |
attention_mask: Optional[torch.Tensor] = None,
|
| 132 |
attentions: Optional[torch.Tensor] = None
|
| 133 |
-
): # [mean, max]
|
| 134 |
-
final_emb = []
|
| 135 |
for pooling_type in self.pooling_types:
|
| 136 |
final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d)
|
| 137 |
return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
|
|
@@ -139,7 +140,7 @@ class Pooler:
|
|
| 139 |
|
| 140 |
class ProteinDataset(TorchDataset):
|
| 141 |
"""Simple dataset for protein sequences."""
|
| 142 |
-
def __init__(self, sequences:
|
| 143 |
self.sequences = sequences
|
| 144 |
|
| 145 |
def __len__(self) -> int:
|
|
@@ -149,8 +150,8 @@ class ProteinDataset(TorchDataset):
|
|
| 149 |
return self.sequences[idx]
|
| 150 |
|
| 151 |
|
| 152 |
-
def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[
|
| 153 |
-
def _collate_fn(sequences:
|
| 154 |
return tokenizer(sequences, return_tensors="pt", padding='longest')
|
| 155 |
return _collate_fn
|
| 156 |
|
|
@@ -184,7 +185,7 @@ class EmbeddingMixin:
|
|
| 184 |
"""Get the device of the model."""
|
| 185 |
return next(self.parameters()).device
|
| 186 |
|
| 187 |
-
def _read_sequences_from_db(self, db_path: str) ->
|
| 188 |
"""Read sequences from SQLite database."""
|
| 189 |
sequences = []
|
| 190 |
with sqlite3.connect(db_path) as conn:
|
|
@@ -216,7 +217,7 @@ class EmbeddingMixin:
|
|
| 216 |
cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
|
| 217 |
conn.commit()
|
| 218 |
|
| 219 |
-
def load_embeddings_from_pth(self, save_path: str) ->
|
| 220 |
assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}"
|
| 221 |
payload = torch.load(save_path, map_location="cpu", weights_only=True)
|
| 222 |
assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary."
|
|
@@ -225,9 +226,9 @@ class EmbeddingMixin:
|
|
| 225 |
assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors."
|
| 226 |
return payload
|
| 227 |
|
| 228 |
-
def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) ->
|
| 229 |
assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
|
| 230 |
-
loaded:
|
| 231 |
with sqlite3.connect(db_path) as conn:
|
| 232 |
self._ensure_embeddings_table(conn)
|
| 233 |
cursor = conn.cursor()
|
|
@@ -277,7 +278,7 @@ class EmbeddingMixin:
|
|
| 277 |
save_path: str = 'embeddings.pth',
|
| 278 |
fasta_path: Optional[str] = None,
|
| 279 |
**kwargs,
|
| 280 |
-
) -> Optional[
|
| 281 |
"""
|
| 282 |
Embed a dataset of protein sequences.
|
| 283 |
|
|
@@ -306,6 +307,7 @@ class EmbeddingMixin:
|
|
| 306 |
device = None
|
| 307 |
|
| 308 |
def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
|
| 309 |
if full_embeddings or residue_embeddings.ndim == 2:
|
| 310 |
return residue_embeddings
|
| 311 |
return pooler(residue_embeddings, attention_mask)
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
from tqdm.auto import tqdm
|
| 7 |
+
from typing import Callable, Dict, List, Optional, Set
|
| 8 |
from torch.utils.data import DataLoader
|
| 9 |
from torch.utils.data import Dataset as TorchDataset
|
| 10 |
from transformers import PreTrainedTokenizerBase
|
| 11 |
|
| 12 |
|
| 13 |
class Pooler:
|
| 14 |
+
def __init__(self, pooling_types: List[str]) -> None:
|
| 15 |
self.pooling_types = pooling_types
|
| 16 |
+
self.pooling_options: Dict[str, Callable] = {
|
| 17 |
'mean': self.mean_pooling,
|
| 18 |
'max': self.max_pooling,
|
| 19 |
'norm': self.norm_pooling,
|
|
|
|
| 25 |
}
|
| 26 |
|
| 27 |
def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
|
| 28 |
+
assert isinstance(attentions, torch.Tensor)
|
| 29 |
maxed_attentions = torch.max(attentions, dim=1)[0]
|
| 30 |
return maxed_attentions
|
| 31 |
|
| 32 |
+
def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]:
|
| 33 |
# Run PageRank on the attention matrix converted to a graph.
|
| 34 |
# Raises exceptions if the graph doesn't match the token sequence or has no edges.
|
| 35 |
# Returns the PageRank scores for each token node.
|
|
|
|
| 42 |
|
| 43 |
return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
|
| 44 |
|
| 45 |
+
def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph:
    """Build a weighted directed graph from a square score matrix.

    Each entry ``matrix[i, j]`` becomes the weight of the directed edge
    ``i -> j`` (networkx stores it under the ``'weight'`` key).
    """
    return nx.from_numpy_array(matrix, create_using=nx.DiGraph)
|
| 50 |
|
| 51 |
+
def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray:
|
| 52 |
# Remove keys where attention_mask is 0
|
| 53 |
if attention_mask is not None:
|
| 54 |
for k in list(dict_importance.keys()):
|
|
|
|
| 60 |
total = sum(dict_importance.values())
|
| 61 |
return np.array([v / total for _, v in dict_importance.items()])
|
| 62 |
|
| 63 |
+
def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # (b, L, d) -> (b, d)
|
| 64 |
maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
|
| 65 |
# emb is (b, L, d), maxed_attentions is (b, L, L)
|
| 66 |
emb_pooled = []
|
|
|
|
| 72 |
pooled = torch.tensor(np.array(emb_pooled))
|
| 73 |
return pooled
|
| 74 |
|
| 75 |
+
def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
    """Average token embeddings over the sequence dimension.

    When ``attention_mask`` is given, padded positions (mask == 0) are
    excluded from both the numerator and the divisor.
    """
    if attention_mask is None:
        return emb.mean(dim=1)
    mask = attention_mask.unsqueeze(-1)  # (b, L, 1), broadcasts over d
    token_sum = (emb * mask).sum(dim=1)
    return token_sum / mask.sum(dim=1)
|
| 81 |
|
| 82 |
+
def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
    """Element-wise max over the sequence dimension, ignoring padding.

    Bug fix: the previous implementation multiplied by the mask, turning
    padded positions into 0. A padded 0 then wrongly won the max whenever
    every real value in a dimension was negative. Padded positions are now
    set to -inf before the reduction so they can never be selected.
    """
    if attention_mask is None:
        return emb.max(dim=1).values
    mask = attention_mask.unsqueeze(-1).bool()  # (b, L, 1), broadcasts over d
    return emb.masked_fill(~mask, float("-inf")).max(dim=1).values
|
| 88 |
|
| 89 |
+
def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
    """L2 norm of token embeddings over the sequence dimension.

    Zeroing padded positions via the mask is correct here: a zero entry
    contributes nothing to the sum of squares.
    """
    if attention_mask is not None:
        emb = emb * attention_mask.unsqueeze(-1)
    return emb.norm(dim=1, p=2)
|
| 95 |
|
| 96 |
+
def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
    """Per-dimension median over the sequence dimension, ignoring padding.

    Bug fix: the previous implementation multiplied by the mask, so every
    padded position entered the median computation as 0 and skewed the
    result. Padded positions are now replaced with NaN and excluded via
    ``nanmedian``. Assumes ``emb`` is a floating-point tensor (NaN fill).
    """
    if attention_mask is None:
        return emb.median(dim=1).values
    mask = attention_mask.unsqueeze(-1).bool()  # (b, L, 1), broadcasts over d
    return emb.masked_fill(~mask, float("nan")).nanmedian(dim=1).values
|
| 102 |
|
| 103 |
+
def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
|
| 104 |
if attention_mask is None:
|
| 105 |
return emb.std(dim=1)
|
| 106 |
else:
|
|
|
|
| 108 |
var = self.var_pooling(emb, attention_mask, **kwargs)
|
| 109 |
return torch.sqrt(var)
|
| 110 |
|
| 111 |
+
def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
|
| 112 |
if attention_mask is None:
|
| 113 |
return emb.var(dim=1)
|
| 114 |
else:
|
|
|
|
| 123 |
var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
|
| 124 |
return var
|
| 125 |
|
| 126 |
+
def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
    """Return the first token's embedding for each sequence.

    The mask is accepted for interface uniformity with the other poolers
    but is intentionally unused: position 0 is assumed never padded.
    """
    return emb[:, 0]
|
| 128 |
|
| 129 |
def __call__(
|
|
|
|
| 131 |
emb: torch.Tensor,
|
| 132 |
attention_mask: Optional[torch.Tensor] = None,
|
| 133 |
attentions: Optional[torch.Tensor] = None
|
| 134 |
+
) -> torch.Tensor: # [mean, max]
|
| 135 |
+
final_emb: List[torch.Tensor] = []
|
| 136 |
for pooling_type in self.pooling_types:
|
| 137 |
final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d)
|
| 138 |
return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
|
|
|
|
| 140 |
|
| 141 |
class ProteinDataset(TorchDataset):
|
| 142 |
"""Simple dataset for protein sequences."""
|
| 143 |
+
def __init__(self, sequences: List[str]) -> None:
    """Store the list of protein sequence strings to be served by index."""
    # Stored by reference (no copy): later mutation of the caller's list
    # is visible to this dataset.
    self.sequences = sequences
|
| 145 |
|
| 146 |
def __len__(self) -> int:
|
|
|
|
| 150 |
return self.sequences[idx]
|
| 151 |
|
| 152 |
|
| 153 |
+
def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
    """Create a DataLoader ``collate_fn`` that tokenizes a batch of sequences.

    The returned closure calls ``tokenizer`` with ``return_tensors="pt"`` and
    ``padding='longest'``, so each batch is padded only to its own longest
    sequence.
    """
    def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
        return tokenizer(sequences, return_tensors="pt", padding='longest')

    return _collate_fn
|
| 157 |
|
|
|
|
| 185 |
"""Get the device of the model."""
|
| 186 |
return next(self.parameters()).device
|
| 187 |
|
| 188 |
+
def _read_sequences_from_db(self, db_path: str) -> Set[str]:
|
| 189 |
"""Read sequences from SQLite database."""
|
| 190 |
sequences = []
|
| 191 |
with sqlite3.connect(db_path) as conn:
|
|
|
|
| 217 |
cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
|
| 218 |
conn.commit()
|
| 219 |
|
| 220 |
+
def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]:
|
| 221 |
assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}"
|
| 222 |
payload = torch.load(save_path, map_location="cpu", weights_only=True)
|
| 223 |
assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary."
|
|
|
|
| 226 |
assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors."
|
| 227 |
return payload
|
| 228 |
|
| 229 |
+
def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
|
| 230 |
assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
|
| 231 |
+
loaded: Dict[str, torch.Tensor] = {}
|
| 232 |
with sqlite3.connect(db_path) as conn:
|
| 233 |
self._ensure_embeddings_table(conn)
|
| 234 |
cursor = conn.cursor()
|
|
|
|
| 278 |
save_path: str = 'embeddings.pth',
|
| 279 |
fasta_path: Optional[str] = None,
|
| 280 |
**kwargs,
|
| 281 |
+
) -> Optional[Dict[str, torch.Tensor]]:
|
| 282 |
"""
|
| 283 |
Embed a dataset of protein sequences.
|
| 284 |
|
|
|
|
| 307 |
device = None
|
| 308 |
|
| 309 |
def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 310 |
+
assert isinstance(residue_embeddings, torch.Tensor)
|
| 311 |
if full_embeddings or residue_embeddings.ndim == 2:
|
| 312 |
return residue_embeddings
|
| 313 |
return pooler(residue_embeddings, attention_mask)
|