import entrypoint_setup

import os
import torch
import warnings
import sqlite3
import gzip
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from dataclasses import dataclass
from typing import Optional, Callable, List, Any
from huggingface_hub import hf_hub_download

try:
    from seed_utils import seed_worker, dataloader_generator, get_global_seed
    from data.dataset_classes import SimpleProteinDataset
    from base_models.get_base_models import get_base_model
    from pooler import Pooler
    from utils import torch_load, print_message, maybe_compile
except ImportError:
    from .seed_utils import seed_worker, dataloader_generator, get_global_seed
    from .data.dataset_classes import SimpleProteinDataset
    from .base_models.get_base_models import get_base_model
    from .pooler import Pooler
    from .utils import torch_load, print_message, maybe_compile


def build_collator(tokenizer) -> Callable[[List[str]], Any]:
    def _collate_fn(sequences: List[str]) -> Any:
        """Collate a batch of sequences into padded tensors (returns a dict-like BatchEncoding)."""
        # pad_to_multiple_of=8 keeps sequence lengths friendly for fp16 tensor cores
        return tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
    return _collate_fn
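
# Usage sketch for the collator (assumes a Hugging Face-style tokenizer; the
# returned BatchEncoding behaves like a dict of tensors):
#   collate = build_collator(tokenizer)
#   batch = collate(['MKTAYIAK', 'MSVLT'])
#   batch['input_ids'].shape[1] % 8 == 0  # padded out to a multiple of 8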


def get_embedding_filename(model_name: str, matrix_embed: bool, pooling_types: List[str], extension: str = 'pth') -> str:
    """
    Generate embedding filename with pooling types for vector embeddings.
    
    Args:
        model_name: Name of the model
        matrix_embed: Whether embeddings are matrices (True) or vectors (False)
        pooling_types: List of pooling types used (only relevant for vector embeddings)
        extension: File extension ('pth' or 'db')
    
    Returns:
        Filename string in format: {model_name}_{matrix_embed}[_{pooling_types}].{extension}
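
    Example (model name is illustrative):
        >>> get_embedding_filename('esmc', False, ['var', 'mean'])
        'esmc_False_mean_var.pth'
        >>> get_embedding_filename('esmc', True, ['mean'], 'db')
        'esmc_True.db'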
    """
    base_name = f'{model_name}_{matrix_embed}'
    if not matrix_embed and pooling_types:
        # For vector embeddings, include pooling types in filename
        pooling_str = '_'.join(sorted(pooling_types))  # Sort for consistency
        base_name = f'{base_name}_{pooling_str}'
    return f'{base_name}.{extension}'


class EmbeddingArguments:
    def __init__(
            self,
            embedding_batch_size: int = 4,
            embedding_num_workers: int = 0,
            download_embeddings: bool = False,
            download_dir: str = 'Synthyra/vector_embeddings',
            matrix_embed: bool = False,
            embedding_pooling_types: Optional[List[str]] = None,
            save_embeddings: bool = False,
            embed_dtype: torch.dtype = torch.float32,
            model_dtype: Optional[torch.dtype] = None,
            sql: bool = False,
            embedding_save_dir: str = 'embeddings',
            **kwargs
    ):
        self.batch_size = embedding_batch_size
        self.num_workers = embedding_num_workers
        self.download_embeddings = download_embeddings
        self.download_dir = download_dir
        self.matrix_embed = matrix_embed
        self.pooling_types = embedding_pooling_types if embedding_pooling_types is not None else ['mean']
        self.save_embeddings = save_embeddings
        self.embed_dtype = embed_dtype
        self.model_dtype = model_dtype
        self.sql = sql
        self.embedding_save_dir = embedding_save_dir
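
# Construction sketch (values are illustrative; note the embedding_-prefixed
# keyword names -- anything else is silently swallowed by **kwargs):
#   embed_args = EmbeddingArguments(
#       embedding_batch_size=16,
#       embedding_pooling_types=['mean', 'var'],
#       save_embeddings=True,
#   )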


class Embedder:
    def __init__(self, args: EmbeddingArguments, all_seqs: List[str]):
        self.args = args
        self.all_seqs = all_seqs
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.matrix_embed = args.matrix_embed
        self.pooling_types = args.pooling_types
        self.download_embeddings = args.download_embeddings
        self.download_dir = args.download_dir
        self.save_embeddings = args.save_embeddings
        self.embed_dtype = args.embed_dtype
        self.model_dtype = args.model_dtype
        self.sql = args.sql
        self.embedding_save_dir = args.embedding_save_dir

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print_message(f'Device {self.device} found')

    def _download_embeddings(self, model_name: str):
        # Download pre-computed embeddings from download_dir, unzip them,
        # and merge them into embedding_save_dir
        filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'pth')
        try:
            local_path = hf_hub_download(
                repo_id=self.download_dir,
                filename=f'embeddings/{filename}.gz',
                repo_type='dataset'
            )
        except Exception as e:
            print_message(f'No embeddings found for {model_name} in {self.download_dir} ({e})')
            return

        # unzip next to the download
        print_message(f'Unzipping {local_path}')
        unzipped_path = local_path.replace('.gz', '')
        with gzip.open(local_path, 'rb') as f_in, open(unzipped_path, 'wb') as f_out:
            f_out.write(f_in.read())
        # merge into embedding_save_dir (create it if this is the first save)
        os.makedirs(self.embedding_save_dir, exist_ok=True)
        final_path = os.path.join(self.embedding_save_dir, filename)

        if os.path.exists(final_path):
            print_message(f'Found existing embeddings in {final_path}')
            # Load downloaded embeddings
            downloaded_embeddings = torch_load(unzipped_path)
            existing_embeddings = torch_load(final_path)

            download_dtype = torch.float16
            if self.embed_dtype != download_dtype:
                print_message(f"Warning:\nDownloaded embeddings are {download_dtype} but the current setting is {self.embed_dtype}\nWhen combining with existing embeddings, this could result in unintended biases or reductions in performance")

            # Combine with existing embeddings
            print_message('Combining and casting')
            downloaded_embeddings.update(existing_embeddings)

            # Cast all embeddings to the correct dtype
            for seq in downloaded_embeddings:
                downloaded_embeddings[seq] = downloaded_embeddings[seq].to(self.embed_dtype)

            # Save the combined embeddings
            print_message(f'Saving combined embeddings to {final_path}')
            torch.save(downloaded_embeddings, final_path)
        else:
            print_message(f'Downloading embeddings from {self.download_dir}, no previous embeddings found')
            downloaded_embeddings = torch_load(unzipped_path)
            # cast to the requested dtype so the saved file matches embed_dtype,
            # mirroring the merge branch above
            for seq in downloaded_embeddings:
                downloaded_embeddings[seq] = downloaded_embeddings[seq].to(self.embed_dtype)
            torch.save(downloaded_embeddings, final_path)
        return final_path

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Read the set of already-embedded sequences from a SQLite database."""
        with sqlite3.connect(db_path) as conn:
            c = conn.cursor()
            c.execute("SELECT sequence FROM embeddings")
            return {row[0] for row in c}
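
    # Sketch: decoding a stored blob back into a tensor. Blobs are written in
    # _embed_sequences below as raw `emb.numpy().tobytes()`, so the reader must
    # supply the dtype (float32 here) and reshape if needed:
    #   import numpy as np
    #   vec = torch.from_numpy(np.frombuffer(blob, dtype=np.float32).copy())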

    def _read_embeddings_from_disk(self, model_name: str):
        if self.sql:
            filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'db')
            save_path = os.path.join(self.embedding_save_dir, filename)
            if os.path.exists(save_path):
                # ensure the table exists, then release the connection
                conn = sqlite3.connect(save_path)
                conn.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
                conn.commit()
                conn.close()
                already_embedded = self._read_sequences_from_db(save_path)
                to_embed = [seq for seq in self.all_seqs if seq not in already_embedded]
                print_message(f"Loaded {len(already_embedded)} already embedded sequences from {save_path}\nEmbedding {len(to_embed)} new sequences")
                return to_embed, save_path, {}
            else:
                print_message(f"No embeddings found in {save_path}")
                return self.all_seqs, save_path, {}

        else:
            embeddings_dict = {}
            filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'pth')
            save_path = os.path.join(self.embedding_save_dir, filename)
            if os.path.exists(save_path):
                print_message(f"Loading embeddings from {save_path}")
                embeddings_dict = torch_load(save_path)
                print_message(f"Loaded {len(embeddings_dict)} embeddings from {save_path}")
                # Existing embeddings keep their on-disk dtype; they are not
                # re-cast to embed_dtype here
                to_embed = [seq for seq in self.all_seqs if seq not in embeddings_dict]
                return to_embed, save_path, embeddings_dict
            else:
                print_message(f"No embeddings found in {save_path}")
                return self.all_seqs, save_path, {}

    @torch.inference_mode()
    def _embed_sequences(
            self,
            to_embed: List[str],
            save_path: str,
            embedding_model: torch.nn.Module,
            tokenizer: Any,
            embeddings_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        os.makedirs(self.embedding_save_dir, exist_ok=True)
        model = embedding_model.to(self.device).eval()
        model = maybe_compile(model)
        device = self.device
        collate_fn = build_collator(tokenizer)
        print_message(f'Pooling types: {self.pooling_types}')
        if self.matrix_embed:
            pooler = None
        else:
            pooler = Pooler(self.pooling_types)

        def _get_embeddings(
                residue_embeddings: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                attentions: Optional[torch.Tensor] = None
            ) -> torch.Tensor:
            if residue_embeddings.ndim == 2 or self.matrix_embed: # sometimes already vector emb
                return residue_embeddings
            else:
                return pooler(emb=residue_embeddings, attention_mask=attention_mask, attentions=attentions)

        dataset = SimpleProteinDataset(to_embed)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            prefetch_factor=2 if self.num_workers > 0 else None,
            collate_fn=collate_fn,
            shuffle=False,
            pin_memory=True,
            worker_init_fn=seed_worker,
            generator=dataloader_generator(get_global_seed())
        )

        if self.sql:
            conn = sqlite3.connect(save_path)
            c = conn.cursor()
            c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')

        for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
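            # shuffle=False in the DataLoader keeps batch order aligned with
            # to_embed, so the raw sequences for this batch can be recovered by index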
            seqs = to_embed[i * self.batch_size:(i + 1) * self.batch_size]
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
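            # Prefer the tokenizer's own attention mask; fall back to sequence_ids
            # (-1 marks padding), else assume every position is real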
            if 'attention_mask' in batch:
                attention_mask = batch['attention_mask']
            elif 'sequence_ids' in batch:
                attention_mask = (batch['sequence_ids'] != -1).long().to(device)
            else:
                attention_mask = torch.ones_like(batch['input_ids'], device=device)

            if 'parti' in self.pooling_types:
                try:
                    residue_embeddings, attentions = model(**batch, output_attentions=True)
                    embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask, attentions=attentions).cpu()
                except Exception as e:
                    print_message(f"Error in parti pooling: {e}\nDefaulting to mean pooling")
                    self.pooling_types = ['mean']
                    pooler = Pooler(self.pooling_types)
                    residue_embeddings = model(**batch)
                    embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask).cpu()
            else:
                residue_embeddings = model(**batch)
                embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask).cpu()

            for seq, emb, mask in zip(seqs, embeddings, attention_mask.cpu()):
                if self.matrix_embed:
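                    # keep only real residue rows; mask.bool() drops padded positions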
                    emb = emb[mask.bool()]
                
                if self.sql:
                    c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                              (seq, emb.numpy().tobytes()))  # raw blob; only supports float32
                else:
                    embeddings_dict[seq] = emb.to(self.embed_dtype)
            
            if (i + 1) % 100 == 0 and self.sql:
                conn.commit()

        if self.sql:
            conn.commit()
            conn.close()
            return embeddings_dict
        
        if self.save_embeddings:
            print_message(f"Saving embeddings to {save_path}")
            torch.save(embeddings_dict, save_path)
            
        return embeddings_dict

    def __call__(self, model_name: str, model_type: Optional[str] = None, model_path: Optional[str] = None):
        if self.download_embeddings:
            self._download_embeddings(model_name)

        if self.device.type == 'cpu':
            warnings.warn("Downloading embeddings is recommended for CPU usage - embedding on CPU will be extremely slow!")
        to_embed, save_path, embeddings_dict = self._read_embeddings_from_disk(model_name)
        
        if len(to_embed) > 0:
            print_message(f"Embedding {len(to_embed)} sequences with {model_name}")
            dispatch_name = model_type or model_name
            model, tokenizer = get_base_model(dispatch_name, dtype=self.model_dtype, model_path=model_path)

            return self._embed_sequences(to_embed, save_path, model, tokenizer, embeddings_dict)
        else:
            print_message(f"No sequences to embed with {model_name}")
            return embeddings_dict
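
# End-to-end usage sketch (model name and sequences are illustrative):
#   embed_args = EmbeddingArguments(embedding_batch_size=8, save_embeddings=True)
#   embedder = Embedder(embed_args, ['MKTAYIAK', 'MSVLTPLLL'])
#   embeddings = embedder('esmc')  # dict mapping sequence -> pooled tensor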


if __name__ == '__main__':
    ### Embed all supported datasets with all supported models
    # py -m embedder
    import argparse
    from huggingface_hub import upload_file, login
    from data.supported_datasets import vector_benchmark
    from data.data_mixin import DataArguments, DataMixin
    from base_models.get_base_models import BaseModelArguments
    from seed_utils import set_global_seed

    parser = argparse.ArgumentParser()
    parser.add_argument('--token', default=None, help='Huggingface token')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--embed_dtype', type=str, default='float16')
    parser.add_argument('--model_names', nargs='+', default=['standard'])
    parser.add_argument('--models_to_skip', nargs='+', default=[], help='Do not download pre-computed embeddings for these models.')
    parser.add_argument('--embedding_save_dir', type=str, default='embeddings')
    parser.add_argument('--download_dir', type=str, default='Synthyra/vector_embeddings')
    parser.add_argument('--embedding_pooling_types', nargs='+', default=['mean', 'var'], help='Pooling types for embeddings.')
    args = parser.parse_args()

    chosen_seed = set_global_seed()

    if args.token is not None:
        login(args.token)

    if args.embed_dtype == 'float16':
        dtype = torch.float16
    elif args.embed_dtype == 'bfloat16':
        dtype = torch.bfloat16
    elif args.embed_dtype == 'float32':
        dtype = torch.float32
    else:
        raise ValueError(f"Invalid embedding dtype: {args.embed_dtype}")

    # Get data
    data_args = DataArguments(
        data_names=vector_benchmark,
        max_length=1024,
        trim=False
    )
    all_seqs = DataMixin(data_args).get_data()[1]

    # Embed for each model
    model_args = BaseModelArguments(model_names=args.model_names)
    for model_name in model_args.model_names:

        embedder_args = EmbeddingArguments(
            # keyword names must match the embedding_-prefixed parameters;
            # unknown names would be silently swallowed by **kwargs
            embedding_batch_size=args.batch_size,
            embedding_num_workers=args.num_workers,
            download_embeddings=model_name not in args.models_to_skip,
            matrix_embed=False,
            embedding_pooling_types=args.embedding_pooling_types,
            save_embeddings=True,
            embed_dtype=dtype,
            sql=False,
            embedding_save_dir=args.embedding_save_dir
        )
        embedder = Embedder(embedder_args, all_seqs)

        _ = embedder(model_name)
        filename = get_embedding_filename(model_name, False, embedder_args.pooling_types, 'pth')
        save_path = os.path.join(args.embedding_save_dir, filename)
        
        compressed_path = f"{save_path}.gz"
        print(f"Compressing {save_path} to {compressed_path}")
        with open(save_path, 'rb') as f_in:
            with gzip.open(compressed_path, 'wb') as f_out:
                f_out.write(f_in.read())
        upload_path = compressed_path
        path_in_repo = f'embeddings/{filename}.gz'
            
        upload_file(
             path_or_fileobj=upload_path,
            path_in_repo=path_in_repo,
            repo_id=args.download_dir,
            repo_type='dataset'
        )

    print('Done')