lhallee committed (verified)
Commit 651ee15 · Parent(s): 8b892c5

Upload embedding_mixin.py with huggingface_hub

Files changed (1):
  1. embedding_mixin.py (+26, -24)
embedding_mixin.py CHANGED
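For orientation, the diff below mostly swaps built-in generic annotations (list[str], dict[str, torch.Tensor], set[str]) for their typing equivalents (List, Dict, Set), adds explicit return types, and inserts a couple of defensive asserts. A minimal sketch of the annotation pattern, assuming the motivation is keeping the annotations evaluable on Python versions before 3.9, where subscripting the built-ins (PEP 585) raises a TypeError at import time:

from typing import Dict, List

# Runs on Python 3.8+; the equivalent list[str] / dict[str, int] annotations
# would only work from Python 3.9 onward.
def sequence_lengths(sequences: List[str]) -> Dict[str, int]:
    return {seq: len(seq) for seq in sequences}

print(sequence_lengths(["MKTAYIAKQR", "MSILVTRPSP"]))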
@@ -4,16 +4,16 @@ import networkx as nx
 import numpy as np
 import torch
 from tqdm.auto import tqdm
-from typing import Callable, List, Optional
+from typing import Callable, Dict, List, Optional, Set
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset as TorchDataset
 from transformers import PreTrainedTokenizerBase


 class Pooler:
-    def __init__(self, pooling_types: List[str]):
+    def __init__(self, pooling_types: List[str]) -> None:
         self.pooling_types = pooling_types
-        self.pooling_options = {
+        self.pooling_options: Dict[str, Callable] = {
             'mean': self.mean_pooling,
             'max': self.max_pooling,
             'norm': self.norm_pooling,
@@ -25,10 +25,11 @@ class Pooler:
         }

     def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
+        assert isinstance(attentions, torch.Tensor)
         maxed_attentions = torch.max(attentions, dim=1)[0]
         return maxed_attentions

-    def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
+    def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]:
         # Run PageRank on the attention matrix converted to a graph.
         # Raises exceptions if the graph doesn't match the token sequence or has no edges.
         # Returns the PageRank scores for each token node.
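A toy sketch (not from the file) of the attention-to-PageRank pipeline these helpers implement, assuming attentions is stacked as (batch, layers or heads, L, L) so that the max over dim 1 yields one L x L matrix per sequence:

import networkx as nx
import torch

attentions = torch.rand(2, 12, 10, 10)      # assumed shape: (b, layers/heads, L, L)
maxed = torch.max(attentions, dim=1)[0]     # element-wise max over dim 1 -> (b, L, L)

# One sequence's L x L matrix as a weighted directed graph, scored with PageRank
G = nx.from_numpy_array(maxed[0].numpy(), create_using=nx.DiGraph)
scores = nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', max_iter=100)
print(len(scores))                          # 10 token nodes, one importance score each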
@@ -41,13 +42,13 @@ class Pooler:

         return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)

-    def _convert_to_graph(self, matrix):
+    def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph:
         # Convert a matrix (e.g., attention scores) to a directed graph using networkx.
         # Each element in the matrix represents a directed edge with a weight.
         G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
         return G

-    def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
+    def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray:
         # Remove keys where attention_mask is 0
         if attention_mask is not None:
             for k in list(dict_importance.keys()):
@@ -59,7 +60,7 @@ class Pooler:
         total = sum(dict_importance.values())
         return np.array([v / total for _, v in dict_importance.items()])

-    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # (b, L, d) -> (b, d)
         maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
         # emb is (b, L, d), maxed_attentions is (b, L, L)
         emb_pooled = []
@@ -71,35 +72,35 @@ class Pooler:
         pooled = torch.tensor(np.array(emb_pooled))
         return pooled

-    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
+    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.mean(dim=1)
         else:
             attention_mask = attention_mask.unsqueeze(-1)
             return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

-    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
+    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.max(dim=1).values
         else:
             attention_mask = attention_mask.unsqueeze(-1)
             return (emb * attention_mask).max(dim=1).values

-    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
+    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.norm(dim=1, p=2)
         else:
             attention_mask = attention_mask.unsqueeze(-1)
             return (emb * attention_mask).norm(dim=1, p=2)

-    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
+    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.median(dim=1).values
         else:
             attention_mask = attention_mask.unsqueeze(-1)
             return (emb * attention_mask).median(dim=1).values

-    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
+    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.std(dim=1)
         else:
@@ -107,7 +108,7 @@ class Pooler:
             var = self.var_pooling(emb, attention_mask, **kwargs)
             return torch.sqrt(var)

-    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
+    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.var(dim=1)
         else:
@@ -122,7 +123,7 @@ class Pooler:
             var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
             return var

-    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
+    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
         return emb[:, 0, :]

     def __call__(
@@ -130,8 +131,8 @@ class Pooler:
         emb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         attentions: Optional[torch.Tensor] = None
-    ): # [mean, max]
-        final_emb = []
+    ) -> torch.Tensor: # [mean, max]
+        final_emb: List[torch.Tensor] = []
         for pooling_type in self.pooling_types:
             final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d)
         return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
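As a usage sketch (not taken from the file), the re-annotated __call__ concatenates one (b, d) vector per requested pooling type; the shapes and pooling choices below are illustrative, and the Pooler class from the diff is assumed to be in scope:

import torch

pooler = Pooler(pooling_types=['mean', 'max'])

emb = torch.randn(2, 10, 64)                     # (b, L, d) residue embeddings
attention_mask = torch.ones(2, 10, dtype=torch.long)
attention_mask[1, 7:] = 0                        # pretend the second sequence is padded

pooled = pooler(emb, attention_mask=attention_mask)
print(pooled.shape)                              # torch.Size([2, 128]) == (b, 2 * d)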
@@ -139,7 +140,7 @@ class Pooler:

 class ProteinDataset(TorchDataset):
     """Simple dataset for protein sequences."""
-    def __init__(self, sequences: list[str]):
+    def __init__(self, sequences: List[str]) -> None:
         self.sequences = sequences

     def __len__(self) -> int:
@@ -149,8 +150,8 @@ class ProteinDataset(TorchDataset):
         return self.sequences[idx]


-def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]], dict[str, torch.Tensor]]:
-    def _collate_fn(sequences: list[str]) -> dict[str, torch.Tensor]:
+def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
+    def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
         return tokenizer(sequences, return_tensors="pt", padding='longest')
     return _collate_fn
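A sketch of how ProteinDataset and build_collator could feed a DataLoader; the tokenizer checkpoint and sequences are placeholders, not part of the commit:

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")  # placeholder checkpoint

dataset = ProteinDataset(["MKTAYIAKQR", "MSILVTRPSP"])
loader = DataLoader(dataset, batch_size=2, collate_fn=build_collator(tokenizer))

for batch in loader:
    # batch holds tensors such as input_ids and attention_mask, padded to the longest sequence
    print(batch["input_ids"].shape, batch["attention_mask"].shape)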
@@ -184,7 +185,7 @@ class EmbeddingMixin:
         """Get the device of the model."""
         return next(self.parameters()).device

-    def _read_sequences_from_db(self, db_path: str) -> set[str]:
+    def _read_sequences_from_db(self, db_path: str) -> Set[str]:
         """Read sequences from SQLite database."""
         sequences = []
         with sqlite3.connect(db_path) as conn:
@@ -216,7 +217,7 @@ class EmbeddingMixin:
             cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
         conn.commit()

-    def load_embeddings_from_pth(self, save_path: str) -> dict[str, torch.Tensor]:
+    def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]:
         assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}"
         payload = torch.load(save_path, map_location="cpu", weights_only=True)
         assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary."
@@ -225,9 +226,9 @@ class EmbeddingMixin:
             assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors."
         return payload

-    def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> dict[str, torch.Tensor]:
+    def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
         assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
-        loaded: dict[str, torch.Tensor] = {}
+        loaded: Dict[str, torch.Tensor] = {}
         with sqlite3.connect(db_path) as conn:
             self._ensure_embeddings_table(conn)
             cursor = conn.cursor()
@@ -277,7 +278,7 @@ class EmbeddingMixin:
         save_path: str = 'embeddings.pth',
         fasta_path: Optional[str] = None,
         **kwargs,
-    ) -> Optional[dict[str, torch.Tensor]]:
+    ) -> Optional[Dict[str, torch.Tensor]]:
         """
         Embed a dataset of protein sequences.

@@ -306,6 +307,7 @@ class EmbeddingMixin:
             device = None

         def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+            assert isinstance(residue_embeddings, torch.Tensor)
             if full_embeddings or residue_embeddings.ndim == 2:
                 return residue_embeddings
             return pooler(residue_embeddings, attention_mask)
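Finally, a sketch of reading embeddings back with the re-annotated loaders; `model` stands in for any object that mixes in EmbeddingMixin, and the file paths are illustrative rather than taken from the commit:

# Hypothetical object mixing in EmbeddingMixin; paths below are illustrative.
pth_embeddings = model.load_embeddings_from_pth("embeddings.pth")   # Dict[str, torch.Tensor]
db_embeddings = model.load_embeddings_from_db("embeddings.db", sequences=["MKTAYIAKQR"])

for sequence, tensor in pth_embeddings.items():
    print(sequence[:10], tuple(tensor.shape))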
 