import entrypoint_setup
import os
import torch
import warnings
import sqlite3
import gzip
import shutil
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from dataclasses import dataclass
from typing import Optional, Callable, List, Any
from huggingface_hub import hf_hub_download
try:
from seed_utils import seed_worker, dataloader_generator, get_global_seed
from data.dataset_classes import SimpleProteinDataset
from base_models.get_base_models import get_base_model
from pooler import Pooler
from utils import torch_load, print_message, maybe_compile
except ImportError:
from .seed_utils import seed_worker, dataloader_generator, get_global_seed
from .data.dataset_classes import SimpleProteinDataset
from .base_models.get_base_models import get_base_model
from .pooler import Pooler
from .utils import torch_load, print_message, maybe_compile
def build_collator(tokenizer) -> Callable[[List[str]], dict]:
    def _collate_fn(sequences: List[str]) -> dict:
        """Tokenize and pad a batch of sequences; returns the tokenizer's dict-like BatchEncoding.
        pad_to_multiple_of=8 keeps padded lengths friendly to fp16 tensor-core kernels."""
        return tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
return _collate_fn
def get_embedding_filename(model_name: str, matrix_embed: bool, pooling_types: List[str], extension: str = 'pth') -> str:
"""
Generate embedding filename with pooling types for vector embeddings.
Args:
model_name: Name of the model
matrix_embed: Whether embeddings are matrices (True) or vectors (False)
pooling_types: List of pooling types used (only relevant for vector embeddings)
extension: File extension ('pth' or 'db')
Returns:
Filename string in format: {model_name}_{matrix_embed}[_{pooling_types}].{extension}
"""
base_name = f'{model_name}_{matrix_embed}'
if not matrix_embed and pooling_types:
# For vector embeddings, include pooling types in filename
pooling_str = '_'.join(sorted(pooling_types)) # Sort for consistency
base_name = f'{base_name}_{pooling_str}'
return f'{base_name}.{extension}'
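# Illustrative examples (hypothetical model name):
#   get_embedding_filename('esm2', False, ['var', 'mean']) -> 'esm2_False_mean_var.pth'
#   get_embedding_filename('esm2', True, ['mean'])         -> 'esm2_True.pth'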
class EmbeddingArguments:
    """Configuration for Embedder. Extra **kwargs are accepted and ignored so this
    can be constructed from a larger argument namespace."""
    def __init__(
        self,
        embedding_batch_size: int = 4,
        embedding_num_workers: int = 0,
        download_embeddings: bool = False,
        download_dir: str = 'Synthyra/vector_embeddings',
        matrix_embed: bool = False,
        embedding_pooling_types: Optional[List[str]] = None,  # avoid a mutable default; falls back to ['mean']
        save_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        model_dtype: Optional[torch.dtype] = None,
        sql: bool = False,
        embedding_save_dir: str = 'embeddings',
        **kwargs
    ):
self.batch_size = embedding_batch_size
self.num_workers = embedding_num_workers
self.download_embeddings = download_embeddings
self.download_dir = download_dir
self.matrix_embed = matrix_embed
        self.pooling_types = embedding_pooling_types if embedding_pooling_types is not None else ['mean']
self.save_embeddings = save_embeddings
self.embed_dtype = embed_dtype
self.model_dtype = model_dtype
self.sql = sql
self.embedding_save_dir = embedding_save_dir
class Embedder:
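    """Embeds protein sequences with a base model, caching results on disk as .pth or SQLite."""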
def __init__(self, args: EmbeddingArguments, all_seqs: List[str]):
self.args = args
self.all_seqs = all_seqs
self.batch_size = args.batch_size
self.num_workers = args.num_workers
self.matrix_embed = args.matrix_embed
self.pooling_types = args.pooling_types
self.download_embeddings = args.download_embeddings
self.download_dir = args.download_dir
self.save_embeddings = args.save_embeddings
self.embed_dtype = args.embed_dtype
self.model_dtype = args.model_dtype
self.sql = args.sql
self.embedding_save_dir = args.embedding_save_dir
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print_message(f'Device {self.device} found')
    def _download_embeddings(self, model_name: str) -> Optional[str]:
        """Download embeddings from download_dir on the Hub, unzip them, and move them
        into embedding_save_dir, merging with any existing embeddings."""
filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'pth')
try:
local_path = hf_hub_download(
repo_id=self.download_dir,
filename=f'embeddings/{filename}.gz',
repo_type='dataset'
)
        except Exception as e:
            print_message(f'No embeddings found for {model_name} in {self.download_dir} ({e})')
            return
# unzip
print_message(f'Unzipping {local_path}')
        with gzip.open(local_path, 'rb') as f_in:
            with open(local_path.replace('.gz', ''), 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)  # stream instead of reading the whole file into memory
        # move to embedding_save_dir
        unzipped_path = local_path.replace('.gz', '')
        os.makedirs(self.embedding_save_dir, exist_ok=True)
        final_path = os.path.join(self.embedding_save_dir, filename)
if os.path.exists(final_path):
print_message(f'Found existing embeddings in {final_path}')
# Load downloaded embeddings
downloaded_embeddings = torch_load(unzipped_path)
existing_embeddings = torch_load(final_path)
download_dtype = torch.float16
if self.embed_dtype != download_dtype:
print_message(f"Warning:\nDownloaded embeddings are {download_dtype} but the current setting is {self.embed_dtype}\nWhen combining with existing embeddings, this could result in unintended biases or reductions in performance")
# Combine with existing embeddings
print_message('Combining and casting')
downloaded_embeddings.update(existing_embeddings)
# Cast all embeddings to the correct dtype
for seq in downloaded_embeddings:
downloaded_embeddings[seq] = downloaded_embeddings[seq].to(self.embed_dtype)
# Save the combined embeddings
print_message(f'Saving combined embeddings to {final_path}')
torch.save(downloaded_embeddings, final_path)
else:
print_message(f'Downloading embeddings from {self.download_dir}, no previous embeddings found')
            downloaded_embeddings = torch_load(unzipped_path)
torch.save(downloaded_embeddings, final_path)
return final_path
    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Read the set of already-embedded sequences from a SQLite database."""
        with sqlite3.connect(db_path) as conn:
            c = conn.cursor()
            c.execute("SELECT sequence FROM embeddings")
            return {row[0] for row in c.fetchall()}
def _read_embeddings_from_disk(self, model_name: str):
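        """Return (sequences still to embed, on-disk store path, cached embeddings).
        For the SQL backend the cache stays in the database, so the dict is always empty."""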
if self.sql:
filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'db')
save_path = os.path.join(self.embedding_save_dir, filename)
if os.path.exists(save_path):
                conn = sqlite3.connect(save_path)
                c = conn.cursor()
                c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
                conn.commit()
                conn.close()
already_embedded = self._read_sequences_from_db(save_path)
to_embed = [seq for seq in self.all_seqs if seq not in already_embedded]
print_message(f"Loaded {len(already_embedded)} already embedded sequences from {save_path}\nEmbedding {len(to_embed)} new sequences")
return to_embed, save_path, {}
else:
print_message(f"No embeddings found in {save_path}")
return self.all_seqs, save_path, {}
else:
embeddings_dict = {}
filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'pth')
save_path = os.path.join(self.embedding_save_dir, filename)
if os.path.exists(save_path):
print_message(f"Loading embeddings from {save_path}")
embeddings_dict = torch_load(save_path)
print_message(f"Loaded {len(embeddings_dict)} embeddings from {save_path}")
                # Casting existing embeddings to self.embed_dtype is intentionally left disabled:
                # for seq in embeddings_dict:
                #     embeddings_dict[seq] = embeddings_dict[seq].to(self.embed_dtype)
to_embed = [seq for seq in self.all_seqs if seq not in embeddings_dict]
return to_embed, save_path, embeddings_dict
else:
print_message(f"No embeddings found in {save_path}")
return self.all_seqs, save_path, {}
@torch.inference_mode()
def _embed_sequences(
self,
to_embed: List[str],
save_path: str,
        embedding_model: Any,
        tokenizer: Any,
        embeddings_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
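        """Embed to_embed with embedding_model, streaming results into SQLite when
        self.sql is set, otherwise accumulating them in embeddings_dict (optionally saved)."""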
os.makedirs(self.embedding_save_dir, exist_ok=True)
model = embedding_model.to(self.device).eval()
model = maybe_compile(model)
device = self.device
collate_fn = build_collator(tokenizer)
print_message(f'Pooling types: {self.pooling_types}')
if self.matrix_embed:
pooler = None
else:
pooler = Pooler(self.pooling_types)
def _get_embeddings(
residue_embeddings: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attentions: Optional[torch.Tensor] = None
) -> torch.Tensor:
            if residue_embeddings.ndim == 2 or self.matrix_embed:  # some models already return pooled (vector) embeddings
return residue_embeddings
else:
return pooler(emb=residue_embeddings, attention_mask=attention_mask, attentions=attentions)
dataset = SimpleProteinDataset(to_embed)
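        # Note: shuffle must stay False so each batch lines up with the
        # to_embed[i * batch_size : (i + 1) * batch_size] slice used below;
        # worker seeding is pinned to the global seed for reproducibility.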
dataloader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
prefetch_factor=2 if self.num_workers > 0 else None,
collate_fn=collate_fn,
shuffle=False,
pin_memory=True,
worker_init_fn=seed_worker,
generator=dataloader_generator(get_global_seed())
)
if self.sql:
conn = sqlite3.connect(save_path)
c = conn.cursor()
c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
seqs = to_embed[i * self.batch_size:(i + 1) * self.batch_size]
batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
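            # Derive an attention mask: prefer the tokenizer's mask, fall back to
            # sequence_ids (where -1 marks padding), else assume no padding.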
if 'attention_mask' in batch:
attention_mask = batch['attention_mask']
elif 'sequence_ids' in batch:
attention_mask = (batch['sequence_ids'] != -1).long().to(device)
else:
attention_mask = torch.ones_like(batch['input_ids'], device=device)
if 'parti' in self.pooling_types:
try:
residue_embeddings, attentions = model(**batch, output_attentions=True)
embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask, attentions=attentions).cpu()
except Exception as e:
print_message(f"Error in parti pooling: {e}\nDefaulting to mean pooling")
self.pooling_types = ['mean']
pooler = Pooler(self.pooling_types)
residue_embeddings = model(**batch)
embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask).cpu()
else:
residue_embeddings = model(**batch)
embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask).cpu()
for seq, emb, mask in zip(seqs, embeddings, attention_mask.cpu()):
if self.matrix_embed:
emb = emb[mask.bool()]
if self.sql:
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
(seq, emb.numpy().tobytes())) # only supports float32
else:
embeddings_dict[seq] = emb.to(self.embed_dtype)
if (i + 1) % 100 == 0 and self.sql:
conn.commit()
if self.sql:
conn.commit()
conn.close()
return embeddings_dict
if self.save_embeddings:
print_message(f"Saving embeddings to {save_path}")
torch.save(embeddings_dict, save_path)
return embeddings_dict
    def __call__(self, model_name: str, model_type: Optional[str] = None, model_path: Optional[str] = None):
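        """Optionally download cached embeddings, then embed whatever is missing from the on-disk store."""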
if self.download_embeddings:
self._download_embeddings(model_name)
        if self.device.type == 'cpu':
            warnings.warn("Downloading embeddings is recommended for CPU usage - embedding on CPU will be extremely slow!")
to_embed, save_path, embeddings_dict = self._read_embeddings_from_disk(model_name)
if len(to_embed) > 0:
print_message(f"Embedding {len(to_embed)} sequences with {model_name}")
dispatch_name = model_type or model_name
model, tokenizer = get_base_model(dispatch_name, dtype=self.model_dtype, model_path=model_path)
return self._embed_sequences(to_embed, save_path, model, tokenizer, embeddings_dict)
else:
print_message(f"No sequences to embed with {model_name}")
return embeddings_dict
if __name__ == '__main__':
### Embed all supported datasets with all supported models
# py -m embedder
import argparse
from huggingface_hub import upload_file, login
from data.supported_datasets import vector_benchmark
from data.data_mixin import DataArguments, DataMixin
    from base_models.get_base_models import BaseModelArguments
from seed_utils import set_global_seed
parser = argparse.ArgumentParser()
parser.add_argument('--token', default=None, help='Huggingface token')
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--embed_dtype', type=str, default='float16')
parser.add_argument('--model_names', nargs='+', default=['standard'])
    parser.add_argument('--models_to_skip', nargs='+', default=[], help='Skip downloading existing embeddings for these models.')
parser.add_argument('--embedding_save_dir', type=str, default='embeddings')
parser.add_argument('--download_dir', type=str, default='Synthyra/vector_embeddings')
parser.add_argument('--embedding_pooling_types', nargs='+', default=['mean', 'var'], help='Pooling types for embeddings.')
args = parser.parse_args()
chosen_seed = set_global_seed()
if args.token is not None:
login(args.token)
    dtype_map = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}
    if args.embed_dtype not in dtype_map:
        raise ValueError(f"Invalid embedding dtype: {args.embed_dtype}")
    dtype = dtype_map[args.embed_dtype]
# Get data
data_args = DataArguments(
data_names=vector_benchmark,
max_length=1024,
trim=False
)
all_seqs = DataMixin(data_args).get_data()[1]
# Embed for each model
model_args = BaseModelArguments(model_names=args.model_names)
for model_name in model_args.model_names:
        embedder_args = EmbeddingArguments(
            embedding_batch_size=args.batch_size,
            embedding_num_workers=args.num_workers,
            download_embeddings=model_name not in args.models_to_skip,
            download_dir=args.download_dir,
            matrix_embed=False,
            embedding_pooling_types=args.embedding_pooling_types,
            save_embeddings=True,
            embed_dtype=dtype,
            sql=False,
            embedding_save_dir=args.embedding_save_dir
        )
embedder = Embedder(embedder_args, all_seqs)
_ = embedder(model_name)
filename = get_embedding_filename(model_name, False, embedder_args.pooling_types, 'pth')
save_path = os.path.join(args.embedding_save_dir, filename)
compressed_path = f"{save_path}.gz"
print(f"Compressing {save_path} to {compressed_path}")
        with open(save_path, 'rb') as f_in:
            with gzip.open(compressed_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)  # stream instead of reading the whole file into memory
upload_path = compressed_path
path_in_repo = f'embeddings/{filename}.gz'
upload_file(
path_or_fileobj=upload_path,
path_in_repo=path_in_repo,
repo_id=args.download_dir,
repo_type='dataset'
)
print('Done')