# protify/data/data_collators.py
import torch
from typing import List, Tuple, Dict, Union
from .utils import pad_and_concatenate_dimer
def _pad_matrix_embeds(embeds: List[torch.Tensor], max_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Pad each embedding to max_len along the sequence dimension, stack the
    # results, and return the padded embeddings with their attention mask.
    padded_embeds, attention_masks = [], []
    for embed in embeds:
        seq_len = embed.size(0)
        padding_size = max_len - seq_len
        # Create attention mask (1 for real tokens, 0 for padding)
        attention_mask = torch.ones(max_len, dtype=torch.long)
        if padding_size > 0:
            attention_mask[seq_len:] = 0
            # Pad along the sequence dimension (dim=0), matching dtype and device
            padding = torch.zeros((padding_size, embed.size(1)), dtype=embed.dtype, device=embed.device)
            padded_embed = torch.cat((embed, padding), dim=0)
        else:
            padded_embed = embed
        padded_embeds.append(padded_embed)
        attention_masks.append(attention_mask)
    return torch.stack(padded_embeds), torch.stack(attention_masks)
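
# Shape sketch: for embeds = [torch.randn(3, 8), torch.randn(5, 8)] and max_len=5,
# _pad_matrix_embeds returns a (2, 5, 8) embedding tensor and a (2, 5) mask whose
# rows are [1, 1, 1, 0, 0] and [1, 1, 1, 1, 1].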
class StringCollator:
    def __init__(self, tokenizer, **kwargs):
        self.tokenizer = tokenizer

    def __call__(self, batch: List[str]) -> Dict[str, torch.Tensor]:
        batch = self.tokenizer(batch,
                               padding='longest',
                               return_tensors='pt',
                               add_special_tokens=True)
        return batch
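
# Usage sketch (the dataset and tokenizer names are hypothetical; any Hugging
# Face tokenizer with padding support should behave the same way):
#   collator = StringCollator(tokenizer)
#   loader = torch.utils.data.DataLoader(string_dataset, batch_size=8, collate_fn=collator)
#   batch = next(iter(loader))  # dict with 'input_ids' and 'attention_mask'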
class StringLabelsCollator:
    def __init__(self, tokenizer, task_type='regression', tokenwise=False, **kwargs):
        self.tokenizer = tokenizer
        self.task_type = task_type
        self.tokenwise = tokenwise

    def __call__(self, batch: List[Tuple[str, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        seqs = [ex[0] for ex in batch]
        labels = [ex[1] for ex in batch]
        # Tokenize the sequences
        batch_encoding = self.tokenizer(
            seqs,
            padding='longest',
            truncation=False,
            return_tensors='pt',
            add_special_tokens=True
        )
        # Handle labels based on the tokenwise flag
        if self.tokenwise:
            # For token-wise labels, pad to match the tokenized sequence length.
            # Labels are assumed to already be aligned with the tokenizer output
            # (including any special tokens).
            attention_mask = batch_encoding['attention_mask']
            lengths = [int(attention_mask[i].sum().item()) for i in range(len(batch))]
            max_length = max(lengths)
            padded_labels = []
            for label in labels:
                if not isinstance(label, torch.Tensor):
                    label = torch.tensor(label)
                label = label.flatten()
                padding_size = max_length - len(label)
                # Pad or truncate labels to match the tokenized sequence length
                if padding_size > 0:
                    # Pad with -100 (ignored by loss functions)
                    padding = torch.full((padding_size,), -100, dtype=label.dtype)
                    padded_label = torch.cat((label, padding))
                else:
                    padded_label = label[:max_length]
                padded_labels.append(padded_label)
            # Stack all padded labels
            batch_encoding['labels'] = torch.stack(padded_labels)
        else:
            # For sequence-level labels, convert to tensors and stack
            batch_encoding['labels'] = torch.stack([torch.as_tensor(label) for label in labels])
        if self.task_type in ['multilabel', 'regression', 'sigmoid_regression']:
            batch_encoding['labels'] = batch_encoding['labels'].float()
        else:
            batch_encoding['labels'] = batch_encoding['labels'].long()
        return batch_encoding
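
# Example (sketch; 'binary' is a hypothetical task_type that simply falls into
# the integer-label branch): for a token-wise task with per-residue labels,
# shorter label vectors are right-padded with -100 so that
# CrossEntropyLoss(ignore_index=-100) skips the padded positions:
#   batch = [("ACDE", [0, 1, 0, 1]), ("AC", [1, 0])]
#   out = StringLabelsCollator(tokenizer, task_type='binary', tokenwise=True)(batch)
#   # out['labels'] rows are padded with -100 up to the tokenized length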
class EmbedsLabelsCollator:
    def __init__(self, full=False, task_type='regression', tokenwise=False, **kwargs):
        self.full = full
        self.task_type = task_type
        self.tokenwise = tokenwise

    def __call__(self, batch: List[Tuple[torch.Tensor, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        if self.full:
            embeds = [ex[0] for ex in batch]
            labels = [ex[1] for ex in batch]
            # Find max sequence length for padding
            max_length = max(embed.size(0) for embed in embeds)
            embeds, attention_mask = _pad_matrix_embeds(embeds, max_length)
            # Pad labels
            if self.tokenwise:
                padded_labels = []
                for label in labels:
                    if not isinstance(label, torch.Tensor):
                        label = torch.tensor(label)
                    label = label.flatten()
                    padding_size = max_length - len(label)
                    if padding_size > 0:
                        # Use -100 as padding value for labels (ignored by loss functions)
                        padding = torch.full((padding_size,), -100, dtype=label.dtype)
                        padded_label = torch.cat((label, padding))
                    else:
                        padded_label = label[:max_length]
                    padded_labels.append(padded_label)
            else:
                # Sequence-level labels may arrive as raw floats or ints, so
                # convert them to tensors before stacking
                padded_labels = [torch.as_tensor(label) for label in labels]
            labels = torch.stack(padded_labels)
            if self.task_type in ['multilabel', 'regression', 'sigmoid_regression']:
                labels = labels.float()
            else:
                labels = labels.long()
            return {
                'embeddings': embeds,
                'attention_mask': attention_mask,
                'labels': labels,
            }
        else:
            embeds = torch.stack([ex[0] for ex in batch])
            labels = torch.stack([torch.as_tensor(ex[1]) for ex in batch])
            if self.task_type in ['multilabel', 'regression', 'sigmoid_regression']:
                labels = labels.float()
            else:
                labels = labels.long()
            return {
                'embeddings': embeds,
                'labels': labels
            }
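
# Usage sketch: with full=False the dataset is expected to yield pooled
# (fixed-size) embeddings, so no padding is needed; with full=True it yields
# per-residue (seq_len x hidden) matrices that are padded and masked:
#   collator = EmbedsLabelsCollator(full=True, task_type='regression')
#   out = collator([(torch.randn(3, 8), 0.5), (torch.randn(5, 8), 1.2)])
#   # out['embeddings']: (2, 5, 8), out['attention_mask']: (2, 5), out['labels']: float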
class PairCollator_input_ids:
    def __init__(self, tokenizer, **kwargs):
        self.tokenizer = tokenizer

    def __call__(self, batch: List[Tuple[str, str, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        seqs_a, seqs_b, labels = zip(*batch)
        labels = torch.tensor(labels, dtype=torch.float)
        # Tokenize the pair jointly so the tokenizer inserts its own separator
        # between sequence A and sequence B
        tokenized = self.tokenizer(
            list(seqs_a), list(seqs_b),
            padding='longest',
            return_tensors='pt'
        )
        return {
            'input_ids': tokenized['input_ids'],
            'attention_mask': tokenized['attention_mask'],
            'labels': labels
        }
class PairCollator_ab:
    def __init__(self, tokenizer, **kwargs):
        self.tokenizer = tokenizer

    def __call__(self, batch: List[Tuple[str, str, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        seqs_a, seqs_b, labels = zip(*batch)
        labels = torch.tensor(labels, dtype=torch.float)
        # Tokenize each side of the pair separately
        tokenized_a = self.tokenizer(
            list(seqs_a),
            padding='longest',
            truncation=True,
            return_tensors='pt'
        )
        tokenized_b = self.tokenizer(
            list(seqs_b),
            padding='longest',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids_a': tokenized_a['input_ids'],
            'input_ids_b': tokenized_b['input_ids'],
            'attention_mask_a': tokenized_a['attention_mask'],
            'attention_mask_b': tokenized_b['attention_mask'],
            'labels': labels
        }
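
# Design note: PairCollator_input_ids packs both sequences into one input
# (cross-encoder style, relying on the tokenizer's own separator token), while
# PairCollator_ab keeps the two sequences as separate inputs for models that
# encode each side independently (dual-encoder style).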
class PairEmbedsLabelsCollator:
    def __init__(self, full=False, add_token_ids=False, **kwargs):
        self.full = full
        self.add_token_ids = add_token_ids

    def __call__(self, batch: List[Tuple[torch.Tensor, torch.Tensor, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        if self.full:
            embeds_a = [ex[0] for ex in batch]
            embeds_b = [ex[1] for ex in batch]
            max_len_a = max(embed.size(0) for embed in embeds_a)
            max_len_b = max(embed.size(0) for embed in embeds_b)
            embeds_a, attention_mask_a = _pad_matrix_embeds(embeds_a, max_len_a)
            embeds_b, attention_mask_b = _pad_matrix_embeds(embeds_b, max_len_b)
            embeds, attention_mask = pad_and_concatenate_dimer(embeds_a, embeds_b, attention_mask_a, attention_mask_b)
            labels = torch.stack([torch.as_tensor(ex[2]) for ex in batch])
            # For tasks requiring token type IDs, provide them so the model knows
            # which tokens belong to protein A vs protein B
            if self.add_token_ids:
                batch_size = embeds.size(0)
                max_len = embeds.size(1)
                token_type_ids = torch.zeros(batch_size, max_len, dtype=torch.long)
                for i in range(batch_size):
                    a_len = int(attention_mask_a[i].sum().item())
                    b_len = int(attention_mask_b[i].sum().item())
                    # type 0 for protein A, type 1 for protein B
                    token_type_ids[i, a_len:a_len + b_len] = 1
                return {
                    'embeddings': embeds,
                    'attention_mask': attention_mask,
                    'token_type_ids': token_type_ids,
                    'labels': labels
                }
            return {
                'embeddings': embeds,
                'attention_mask': attention_mask,
                'labels': labels
            }
        else:
            embeds_a = torch.stack([ex[0] for ex in batch])
            embeds_b = torch.stack([ex[1] for ex in batch])
            labels = torch.stack([torch.as_tensor(ex[2]) for ex in batch])
            # Pooled embeddings are combined by concatenation along the feature dimension
            embeds = torch.cat([embeds_a, embeds_b], dim=-1)
            return {
                'embeddings': embeds,
                'labels': labels
            }
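
# Usage sketch (full=True assumes pad_and_concatenate_dimer from .utils places
# protein A's tokens before protein B's in the concatenated sequence, which is
# what the token_type_ids logic above relies on):
#   collator = PairEmbedsLabelsCollator(full=True, add_token_ids=True)
#   out = collator([(torch.randn(3, 8), torch.randn(4, 8), 1.0)])
#   # out['token_type_ids'] is 0 over A's tokens and 1 over B's tokens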
class OneHotCollator:
    def __init__(self, alphabet="ACDEFGHIKLMNPQRSTVWY"):
        # Add X as a catch-all token for unknown / non-canonical amino acids
        alphabet = alphabet + "X"
        self.alphabet = list(alphabet)
        self.mapping = {token: idx for idx, token in enumerate(self.alphabet)}

    def __call__(self, batch):
        seqs = [ex[0] for ex in batch]
        labels = torch.stack([torch.as_tensor(ex[1]) for ex in batch])
        # Find the longest sequence in the batch
        max_len = max(len(seq) for seq in seqs)
        # One-hot encode and pad each sequence
        one_hot_tensors, attention_masks = [], []
        for seq in seqs:
            seq = list(seq)
            seq_len = len(seq)
            # Create the one-hot encoding for this sequence
            one_hot = torch.zeros(seq_len, len(self.alphabet))
            for pos, token in enumerate(seq):
                if token in self.mapping:
                    one_hot[pos, self.mapping[token]] = 1.0
                else:
                    # Map non-canonical amino acids to the X token
                    one_hot[pos, self.mapping["X"]] = 1.0
            # Create attention mask (1 for actual tokens, 0 for padding)
            attention_mask = torch.ones(seq_len)
            # Pad to the max length in this batch
            padding_size = max_len - seq_len
            if padding_size > 0:
                padding = torch.zeros(padding_size, len(self.alphabet))
                one_hot = torch.cat([one_hot, padding], dim=0)
                # Add zeros to the attention mask for the padded positions
                mask_padding = torch.zeros(padding_size)
                attention_mask = torch.cat([attention_mask, mask_padding], dim=0)
            one_hot_tensors.append(one_hot)
            attention_masks.append(attention_mask)
        # Stack all tensors in the batch
        embeddings = torch.stack(one_hot_tensors)
        attention_masks = torch.stack(attention_masks)
        return {
            'embeddings': embeddings,
            'attention_mask': attention_masks,
            'labels': labels,
        }
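
# Minimal smoke test for OneHotCollator (the toy sequences below are hypothetical):
if __name__ == "__main__":
    collator = OneHotCollator()
    batch = [("ACDZ", 0), ("ACDEFG", 1)]  # 'Z' is non-canonical and maps to 'X'
    out = collator(batch)
    print(out['embeddings'].shape)   # torch.Size([2, 6, 21]) with the default 20-letter alphabet + X
    print(out['attention_mask'])     # 1.0 for real residues, 0.0 for padding
    print(out['labels'])             # tensor([0, 1])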