"""
HuggingFace-compatible vec2vec implementation for embedding translation.
Based on: "Harnessing the Universal Geometry of Embeddings" (arXiv:2505.12540)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import ModelOutput
# =============================================================================
# Configuration
# =============================================================================
class Vec2VecConfig(PretrainedConfig):
"""Configuration for Vec2Vec model."""
model_type = "vec2vec"
def __init__(
self,
encoder_names: List[str] = None,
encoder_dims: List[int] = None,
d_adapter: int = 1024,
d_hidden: int = 1024,
d_transform: int = 1024,
adapter_depth: int = 3,
transform_depth: int = 4,
disc_dim: int = 1024,
disc_depth: int = 5,
weight_init: str = "kaiming",
norm_style: str = "batch",
normalize_embeddings: bool = True,
# Loss coefficients
loss_coefficient_rec: float = 1.0,
loss_coefficient_vsp: float = 1.0,
loss_coefficient_cc_trans: float = 10.0,
loss_coefficient_cc_vsp: float = 10.0,
loss_coefficient_cc_rec: float = 0.0,
loss_coefficient_gen: float = 1.0,
loss_coefficient_latent_gen: float = 1.0,
loss_coefficient_similarity_gen: float = 0.0,
loss_coefficient_disc: float = 1.0,
loss_coefficient_r1_penalty: float = 0.0,
# Training settings
noise_level: float = 0.0,
max_grad_norm: float = 1000.0,
**kwargs,
):
super().__init__(**kwargs)
self.encoder_names = encoder_names or ["model_a", "model_b"]
self.encoder_dims = encoder_dims or [768, 768]
self.d_adapter = d_adapter
self.d_hidden = d_hidden
self.d_transform = d_transform
self.adapter_depth = adapter_depth
self.transform_depth = transform_depth
self.disc_dim = disc_dim
self.disc_depth = disc_depth
self.weight_init = weight_init
self.norm_style = norm_style
self.normalize_embeddings = normalize_embeddings
# Loss coefficients
self.loss_coefficient_rec = loss_coefficient_rec
self.loss_coefficient_vsp = loss_coefficient_vsp
self.loss_coefficient_cc_trans = loss_coefficient_cc_trans
self.loss_coefficient_cc_vsp = loss_coefficient_cc_vsp
self.loss_coefficient_cc_rec = loss_coefficient_cc_rec
self.loss_coefficient_gen = loss_coefficient_gen
self.loss_coefficient_latent_gen = loss_coefficient_latent_gen
self.loss_coefficient_similarity_gen = loss_coefficient_similarity_gen
self.loss_coefficient_disc = loss_coefficient_disc
self.loss_coefficient_r1_penalty = loss_coefficient_r1_penalty
self.noise_level = noise_level
self.max_grad_norm = max_grad_norm
def get_encoder_dims_dict(self) -> Dict[str, int]:
"""Return encoder dimensions as a dictionary."""
return dict(zip(self.encoder_names, self.encoder_dims))
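# Illustrative config construction (names and dims are placeholders matching
# the defaults, not a real checkpoint):
#   config = Vec2VecConfig(encoder_names=["model_a", "model_b"], encoder_dims=[768, 768])
#   config.get_encoder_dims_dict()  # {"model_a": 768, "model_b": 768}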
# =============================================================================
# Model Outputs
# =============================================================================
@dataclass
class Vec2VecOutput(ModelOutput):
"""Output type for Vec2Vec forward pass."""
loss: Optional[torch.FloatTensor] = None
reconstructions: Optional[Dict[str, torch.Tensor]] = None
translations: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
latents: Optional[Dict[str, torch.Tensor]] = None
metrics: Optional[Dict[str, float]] = None
# =============================================================================
# Model Components
# =============================================================================
def add_residual(input_x: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""Add residual connection with dimension matching."""
if input_x.shape[1] < x.shape[1]:
        padding = torch.zeros(x.shape[0], x.shape[1] - input_x.shape[1], device=x.device, dtype=x.dtype)
input_x = torch.cat([input_x, padding], dim=1)
elif input_x.shape[1] > x.shape[1]:
input_x = input_x[:, :x.shape[1]]
return x + input_x
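# Example of the matching rule above: input_x of shape (B, 512) with x of
# shape (B, 1024) zero-pads input_x up to 1024 features; with the shapes
# reversed, input_x is truncated to the first 512 features before the sum.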
class MLPWithResidual(nn.Module):
"""MLP with residual connections."""
def __init__(
self,
depth: int,
in_dim: int,
hidden_dim: int,
out_dim: int,
norm_style: str = "batch",
weight_init: str = "kaiming",
):
super().__init__()
self.layers = nn.ModuleList()
norm_layer = nn.BatchNorm1d if norm_style == "batch" else nn.LayerNorm
for layer_idx in range(depth):
if layer_idx == 0:
h_dim = out_dim if depth == 1 else hidden_dim
self.layers.append(nn.Sequential(nn.Linear(in_dim, h_dim), nn.SiLU()))
elif layer_idx < depth - 1:
self.layers.append(nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
norm_layer(hidden_dim),
nn.Dropout(p=0.1),
))
else:
self.layers.append(nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.Dropout(p=0.1),
nn.SiLU(),
nn.Linear(hidden_dim, out_dim),
))
self._initialize_weights(weight_init)
def _initialize_weights(self, weight_init: str):
for module in self.modules():
if isinstance(module, nn.Linear):
if weight_init == "kaiming":
nn.init.kaiming_normal_(module.weight, a=0, mode="fan_in", nonlinearity="relu")
elif weight_init == "xavier":
nn.init.xavier_normal_(module.weight)
elif weight_init == "orthogonal":
nn.init.orthogonal_(module.weight)
module.bias.data.fill_(0)
elif isinstance(module, nn.BatchNorm1d):
nn.init.normal_(module.weight, mean=1.0, std=0.02)
nn.init.normal_(module.bias, mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
input_x = x
x = layer(x)
x = add_residual(input_x, x)
return x
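# Shape sketch: a 3-layer block mapping 768 -> 1024; the residual on the
# first layer is zero-padded by add_residual to match the wider output.
#   mlp = MLPWithResidual(depth=3, in_dim=768, hidden_dim=1024, out_dim=1024)
#   mlp(torch.randn(8, 768)).shape  # torch.Size([8, 1024])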
class Discriminator(nn.Module):
"""Discriminator network for adversarial training."""
def __init__(
self,
latent_dim: int,
hidden_dim: int = 1024,
depth: int = 5,
weight_init: str = "kaiming",
):
super().__init__()
self.layers = nn.ModuleList()
if depth >= 2:
layers = [nn.Linear(latent_dim, hidden_dim), nn.Dropout(0.0)]
for _ in range(depth - 2):
layers.extend([
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.Dropout(0.0),
])
layers.extend([nn.SiLU(), nn.Linear(hidden_dim, 1)])
self.layers.append(nn.Sequential(*layers))
else:
self.layers.append(nn.Linear(latent_dim, 1))
self._initialize_weights(weight_init)
def _initialize_weights(self, weight_init: str):
for module in self.modules():
if isinstance(module, nn.Linear):
if weight_init == "kaiming":
nn.init.kaiming_normal_(module.weight, a=0, mode="fan_in", nonlinearity="relu")
elif weight_init == "xavier":
nn.init.xavier_normal_(module.weight)
elif weight_init == "orthogonal":
nn.init.orthogonal_(module.weight)
module.bias.data.fill_(0)
elif isinstance(module, nn.LayerNorm):
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = layer(x)
return x
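# Note: the discriminator returns one unnormalized logit per example, shape
# (batch, 1); a sigmoid/BCE-with-logits loss would typically be applied downstream.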
# =============================================================================
# Main Model
# =============================================================================
class Vec2VecModel(PreTrainedModel):
"""
Vec2Vec model for embedding translation between different spaces.
Architecture:
Input -> In Adapter -> Transform -> Out Adapter -> Output
"""
config_class = Vec2VecConfig
all_tied_weights_keys = {}
def __init__(self, config: Vec2VecConfig):
super().__init__(config)
self.config = config
encoder_dims = config.get_encoder_dims_dict()
# Shared transform
self.transform = MLPWithResidual(
depth=config.transform_depth,
in_dim=config.d_adapter,
hidden_dim=config.d_transform,
out_dim=config.d_adapter,
norm_style=config.norm_style,
weight_init=config.weight_init,
)
# Adapters for each encoder
self.in_adapters = nn.ModuleDict()
self.out_adapters = nn.ModuleDict()
for name, dim in encoder_dims.items():
self.in_adapters[name] = MLPWithResidual(
config.adapter_depth, dim, config.d_hidden, config.d_adapter,
config.norm_style, config.weight_init,
)
self.out_adapters[name] = MLPWithResidual(
config.adapter_depth, config.d_adapter, config.d_hidden, dim,
config.norm_style, config.weight_init,
)
# Discriminators
self.discriminators = nn.ModuleDict()
for name, dim in encoder_dims.items():
self.discriminators[name] = Discriminator(
dim, config.disc_dim, config.disc_depth, config.weight_init
)
self.discriminators["latent"] = Discriminator(
config.d_adapter, config.disc_dim, config.disc_depth, config.weight_init
)
self.post_init()
def add_encoder(self, name: str, dim: int, overwrite: bool = False):
"""Add a new encoder to the model."""
if name in self.in_adapters and not overwrite:
print(f"Encoder {name} already exists, skipping...")
return
self.in_adapters[name] = MLPWithResidual(
self.config.adapter_depth, dim, self.config.d_hidden, self.config.d_adapter,
self.config.norm_style, self.config.weight_init,
)
self.out_adapters[name] = MLPWithResidual(
self.config.adapter_depth, self.config.d_adapter, self.config.d_hidden, dim,
self.config.norm_style, self.config.weight_init,
)
self.discriminators[name] = Discriminator(
dim, self.config.disc_dim, self.config.disc_depth, self.config.weight_init
)
# Update config
if name not in self.config.encoder_names:
self.config.encoder_names.append(name)
self.config.encoder_dims.append(dim)
def _get_latent(self, emb: torch.Tensor, encoder_name: str) -> torch.Tensor:
"""Get latent representation from embedding."""
z = self.in_adapters[encoder_name](emb)
return self.transform(z)
def _decode(self, latent: torch.Tensor, encoder_name: str) -> torch.Tensor:
"""Decode latent to target embedding space."""
out = self.out_adapters[encoder_name](latent)
if self.config.normalize_embeddings:
out = F.normalize(out, p=2, dim=1)
return out
def translate(self, embeddings: torch.Tensor, src: str, tgt: str) -> torch.Tensor:
"""Translate embeddings from source to target space."""
latent = self._get_latent(embeddings, src)
return self._decode(latent, tgt)
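    # Illustrative use of translate (encoder names come from the checkpoint's
    # config; "model_a"/"model_b" are the defaults):
    #   emb_b_hat = model.translate(emb_a, src="model_a", tgt="model_b")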
def forward(
self,
inputs: Dict[str, torch.Tensor],
        noise_level: Optional[float] = None,
return_latents: bool = False,
) -> Vec2VecOutput:
"""
Forward pass computing reconstructions and translations.
Args:
inputs: Dict mapping encoder names to embeddings
noise_level: Optional noise for training
return_latents: Whether to return latent representations
"""
noise_level = noise_level if noise_level is not None else self.config.noise_level
reconstructions = {}
translations = {}
latents = {}
for src_name, emb in inputs.items():
# Add noise during training
if self.training and noise_level > 0.0:
emb = emb + torch.randn_like(emb) * noise_level
emb = F.normalize(emb, p=2, dim=1)
latent = self._get_latent(emb, src_name)
if return_latents:
latents[src_name] = latent
for tgt_name in inputs.keys():
decoded = self._decode(latent, tgt_name)
if tgt_name == src_name:
reconstructions[src_name] = decoded
else:
if tgt_name not in translations:
translations[tgt_name] = {}
translations[tgt_name][src_name] = decoded
return Vec2VecOutput(
reconstructions=reconstructions,
translations=translations,
latents=latents if return_latents else None,
)
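# Illustrative forward pass with the default two-encoder config (random
# vectors stand in for real encoder outputs; not from a trained checkpoint):
#   model = Vec2VecModel(Vec2VecConfig())
#   out = model({"model_a": torch.randn(4, 768),
#                "model_b": torch.randn(4, 768)}, return_latents=True)
#   out.reconstructions["model_a"].shape    # torch.Size([4, 768])
#   out.translations["model_b"]["model_a"]  # model_a embeddings in model_b space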
# =============================================================================
# Loss Functions
# =============================================================================
def reconstruction_loss(inputs: Dict[str, torch.Tensor], recons: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Reconstruction loss (1 - cosine similarity)."""
loss = sum(1 - F.cosine_similarity(inputs[k], recons[k], dim=1).mean() for k in inputs)
return loss / len(inputs)
def translation_loss(inputs: Dict[str, torch.Tensor], translations: Dict[str, Dict[str, torch.Tensor]]) -> torch.Tensor:
"""Translation loss (1 - cosine similarity)."""
loss = 0.0
count = 0
for tgt, emb in inputs.items():
for trans in translations[tgt].values():
loss += 1 - F.cosine_similarity(emb, trans, dim=1).mean()
count += 1
return loss / max(count, 1)
def vsp_loss(inputs: Dict[str, torch.Tensor], translations: Dict[str, Dict[str, torch.Tensor]]) -> torch.Tensor:
"""Vector Space Preservation (VSP) loss."""
loss = 0.0
count = 0
for out_name in inputs:
for in_name in translations[out_name]:
B = F.normalize(inputs[out_name].detach(), p=2, dim=1)
A = F.normalize(translations[out_name][in_name], p=2, dim=1)
in_sims = B @ B.T
out_sims = A @ A.T
out_sims_reflected = A @ B.T
loss += (in_sims - out_sims).abs().mean()
loss += (in_sims - out_sims_reflected).abs().mean()
count += 1
return loss / max(count, 1)
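# A minimal sketch of combining the terms above into a single generator-side
# objective. The pairing of config coefficients to losses here is an
# assumption; the original training loop (with adversarial updates) is not
# reproduced in this file.
def generator_loss_sketch(
    config: Vec2VecConfig,
    inputs: Dict[str, torch.Tensor],
    output: Vec2VecOutput,
) -> torch.Tensor:
    """Weighted sum of the reconstruction, translation, and VSP losses."""
    return (
        config.loss_coefficient_rec * reconstruction_loss(inputs, output.reconstructions)
        + config.loss_coefficient_cc_trans * translation_loss(inputs, output.translations)
        + config.loss_coefficient_vsp * vsp_loss(inputs, output.translations)
    )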
# =============================================================================
# Synthyra Integration
# =============================================================================
from transformers import AutoModel, AutoTokenizer
from .base_tokenizer import BaseSequenceTokenizer
from .supported_models import all_presets_with_paths
from pooler import Pooler
presets = {
'vec2vec-ESM2-8-ESM2-35': 'Synthyra/ESM2-8-ESM2-35-sequence-sequence',
'vec2vec-ESM2-8-ESM2-150': 'Synthyra/ESM2-8-ESM2-150-sequence-sequence',
'vec2vec-ESM2-8-ESM2-650': 'Synthyra/ESM2-8-ESM2-650-sequence-sequence',
'vec2vec-ESM2-8-ESM2-3B': 'Synthyra/ESM2-8-ESM2-3B-sequence-sequence',
'vec2vec-ESM2-35-ESM2-150': 'Synthyra/ESM2-35-ESM2-150-sequence-sequence',
'vec2vec-ESM2-35-ESM2-650': 'Synthyra/ESM2-35-ESM2-650-sequence-sequence',
'vec2vec-ESM2-35-ESM2-3B': 'Synthyra/ESM2-35-ESM2-3B-sequence-sequence',
'vec2vec-ESM2-150-ESM2-650': 'Synthyra/ESM2-150-ESM2-650-sequence-sequence',
'vec2vec-ESM2-150-ESM2-3B': 'Synthyra/ESM2-150-ESM2-3B-sequence-sequence',
'vec2vec-ESM2-650-ESM2-3B': 'Synthyra/ESM2-650-ESM2-3B-sequence-sequence',
}
class Vec2VecTokenizerWrapper(BaseSequenceTokenizer):
def __init__(self, tokenizer: AutoTokenizer):
super().__init__(tokenizer)
def __call__(self, sequences: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
if isinstance(sequences, str):
sequences = [sequences]
kwargs.setdefault('return_tensors', 'pt')
kwargs.setdefault('padding', 'longest')
kwargs.setdefault('add_special_tokens', True)
tokenized = self.tokenizer(sequences, **kwargs)
return tokenized
class Vec2VecForEmbedding(nn.Module):
def __init__(
self,
config: Vec2VecConfig,
base_model: AutoModel,
vec2vec_model: Vec2VecModel,
model_name_a: str,
model_name_b: str,
):
super().__init__()
self.base_model = base_model
self.vec2vec_model = vec2vec_model
self.config = config
self.pooler = Pooler(['mean', 'var'])
self.model_name_a = model_name_a
self.model_name_b = model_name_b
self.normalize = config.normalize_embeddings
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = False,
**kwargs,
) -> torch.Tensor:
        # Vector embeddings only; output_attentions/output_hidden_states are
        # accepted for interface compatibility but intentionally unused here.
base_state = self.base_model(input_ids, attention_mask=attention_mask).last_hidden_state
base_vec = self.pooler(base_state, attention_mask=attention_mask)
if self.normalize:
base_vec = F.normalize(base_vec, p=2, dim=1)
translated_ab = self.vec2vec_model.translate(base_vec, src=self.model_name_a, tgt=self.model_name_b)
return translated_ab
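    # Illustrative call (ids come from the matching Vec2VecTokenizerWrapper;
    # the sequence is hypothetical):
    #   batch = tokenizer(["MKTAYIAKQR"])
    #   emb = model(batch["input_ids"], attention_mask=batch["attention_mask"])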
def get_vec2vec_tokenizer(preset: str, model_path: Optional[str] = None):
# TODO work with new Vec2Vec .tokenizer_a and .tokenizer_b
path = model_path or all_presets_with_paths[preset]
try:
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    except Exception:
        # Fall back to the tokenizer referenced by the base model's config.
        model = AutoModel.from_pretrained(path, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(model.config.tokenizer_name)
return Vec2VecTokenizerWrapper(tokenizer)
def build_vec2vec_model(preset: str, masked_lm: bool = False, dtype: Optional[torch.dtype] = None, model_path: Optional[str] = None, **kwargs):
    if masked_lm:
        raise ValueError("Masked LM is not supported for Vec2VecForEmbedding")
    model_path = model_path or presets[preset]
    config = Vec2VecConfig.from_pretrained(model_path)
    encoder_names = config.encoder_names
    encoder_dims = config.encoder_dims
    # Pick the higher-dimensional encoder as the source (base) model.
    if encoder_dims[0] >= encoder_dims[1]:
        model_name_a = encoder_names[0]
        model_name_b = encoder_names[1]
    else:
        model_name_a = encoder_names[1]
        model_name_b = encoder_names[0]
    base_model = AutoModel.from_pretrained(all_presets_with_paths[model_name_a], dtype=dtype, trust_remote_code=True)
    # Relies on the remote-code model exposing its tokenizer as an attribute.
    base_tokenizer = base_model.tokenizer
    # from_pretrained is a classmethod; instantiating Vec2VecModel first would
    # just be discarded work.
    vec2vec_model = Vec2VecModel.from_pretrained(model_path)
    model = Vec2VecForEmbedding(config, base_model, vec2vec_model, model_name_a, model_name_b)
    tokenizer = Vec2VecTokenizerWrapper(base_tokenizer)
    return model, tokenizer
def get_vec2vec_for_training(preset: str, tokenwise: bool = False, num_labels: Optional[int] = None, hybrid: bool = False):
    raise NotImplementedError("Vec2VecForTraining is not supported yet")
if __name__ == '__main__':
# py -m src.protify.base_models.vec2vec
    model, tokenizer = build_vec2vec_model('vec2vec-ESM2-8-ESM2-35')
print(model)
print(tokenizer)
print(tokenizer('MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL'))