"""Dataset classes for protein sequence and embedding tasks.

Includes plain string datasets, in-memory embedding datasets built from a
pre-computed embedding dictionary, and SQLite-backed datasets that stream
embeddings from disk in shuffled chunks.
"""
import random
import sqlite3
from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset as TorchDataset
from tqdm.auto import tqdm

try:
    from utils import print_message
except ImportError:
    from ..utils import print_message

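# The *FromDisk datasets below stream embeddings from a SQLite database. Based on the
# queries they issue, the database is assumed to contain a table roughly like
#     CREATE TABLE embeddings (sequence TEXT PRIMARY KEY, embedding BLOB)
# where `embedding` holds the raw float32 bytes of a (seq_len, input_size) matrix,
# e.g. written elsewhere with something like `emb.astype(np.float32).tobytes()`.
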
class PairEmbedsLabelsDatasetFromDisk(TorchDataset):
    """Pair-sequence dataset that streams pre-computed embeddings from a SQLite database.

    Embeddings for (SeqA, SeqB) pairs are read from disk in chunks of
    ``read_scaler * batch_size`` items and served from an in-memory buffer; the data
    order is reshuffled at the start of every new pass over the dataset. When
    ``random_pair_flipping`` is enabled during training, the pair order is swapped
    at random.
    """
    def __init__(
        self,
        hf_dataset,
        col_a='SeqA',
        col_b='SeqB',
        label_col='labels',
        full=False,
        db_path='embeddings.db',
        batch_size=64,
        read_scaler=100,
        input_size=768,
        task_type='regression',
        train=True,
        random_pair_flipping=False,
        **kwargs
    ):
        self.seqs_a, self.seqs_b, self.labels = list(hf_dataset[col_a]), list(hf_dataset[col_b]), list(hf_dataset[label_col])
        self.db_file = db_path
        self.batch_size = batch_size
        self.input_size = input_size
        self.full = full
        self.length = len(self.labels)
        self.read_amt = read_scaler * self.batch_size
        self.embeddings_a, self.embeddings_b, self.current_labels = [], [], []
        self.count, self.index = 0, 0
        self.task_type = task_type
        self.train = train
        self.random_pair_flipping = random_pair_flipping

    def __len__(self):
        return self.length

    def check_seqs(self, all_seqs):
        missing_seqs = [seq for seq in self.seqs_a + self.seqs_b if seq not in all_seqs]
        if missing_seqs:
            print_message(f'Sequences not found in embeddings: {missing_seqs}')
        else:
            print_message('All sequences in embeddings')

    def reset_epoch(self):
        # Reshuffle the pairs and labels together and clear the read buffer.
        data = list(zip(self.seqs_a, self.seqs_b, self.labels))
        random.shuffle(data)
        self.seqs_a, self.seqs_b, self.labels = map(list, zip(*data))
        self.embeddings_a, self.embeddings_b, self.current_labels = [], [], []
        self.count, self.index = 0, 0

    def get_embedding(self, c, seq):
        result = c.execute("SELECT embedding FROM embeddings WHERE sequence=?", (seq,))
        row = result.fetchone()
        if row is None:
            raise ValueError(f"Embedding not found for sequence: {seq}")
        emb_data = row[0]
        emb = torch.tensor(np.frombuffer(emb_data, dtype=np.float32).reshape(-1, self.input_size))
        return emb

    def read_embeddings(self):
        # Refill the buffer with the next `read_amt` pairs, starting a fresh shuffled
        # pass once the previous one has been exhausted.
        if self.count >= self.length:
            self.reset_epoch()
        start = self.count
        self.count += self.read_amt

        embeddings_a, embeddings_b, labels = [], [], []
        conn = sqlite3.connect(self.db_file)
        c = conn.cursor()
        for i in range(start, min(start + self.read_amt, self.length)):
            embeddings_a.append(self.get_embedding(c, self.seqs_a[i]))
            embeddings_b.append(self.get_embedding(c, self.seqs_b[i]))
            labels.append(self.labels[i])
        conn.close()

        self.index = 0
        self.embeddings_a = embeddings_a
        self.embeddings_b = embeddings_b
        self.current_labels = labels

    def __getitem__(self, idx):
        if self.index >= len(self.current_labels):
            self.read_embeddings()

        emb_a = self.embeddings_a[self.index]
        emb_b = self.embeddings_b[self.index]
        label = self.current_labels[self.index]
        self.index += 1

        # Randomly swap the pair order during training when pair flipping is enabled.
        if self.train and self.random_pair_flipping and random.random() < 0.5:
            emb_a, emb_b = emb_b, emb_a

        if self.task_type in ['multilabel', 'regression', 'sigmoid_regression']:
            label = torch.tensor(label, dtype=torch.float)
        else:
            label = torch.tensor(label, dtype=torch.long)

        return emb_a, emb_b, label


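# Illustrative usage of the disk-backed datasets (dataset names and paths here are
# placeholders, not part of this module):
#     train_set = PairEmbedsLabelsDatasetFromDisk(hf_dataset, db_path='embeddings.db',
#                                                 batch_size=64, task_type='regression')
#     loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=False)
# These datasets ignore the index passed to __getitem__ and serve items sequentially
# from an internal buffer that already reshuffles between passes, so the DataLoader
# should not shuffle or randomly sample on top of them.
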
class PairEmbedsLabelsDataset(TorchDataset):
    """In-memory pair-sequence dataset built from a dictionary of pre-computed embeddings.

    All required embeddings are looked up in ``emb_dict`` at construction time, and a
    ``ValueError`` is raised if any sequence is missing.
    """
    def __init__(
        self,
        hf_dataset,
        emb_dict,
        col_a='SeqA',
        col_b='SeqB',
        full=False,
        label_col='labels',
        input_size=768,
        task_type='regression',
        train=True,
        random_pair_flipping=False,
        **kwargs
    ):
        self.seqs_a = list(hf_dataset[col_a])
        self.seqs_b = list(hf_dataset[col_b])
        self.labels = list(hf_dataset[label_col])
        self.input_size = input_size // 2 if not full else input_size
        self.task_type = task_type
        self.full = full
        self.train = train
        self.random_pair_flipping = random_pair_flipping

        # Keep only the embeddings that are actually needed and fail fast if any are missing.
        needed_seqs = set(self.seqs_a + self.seqs_b)
        self.emb_dict = {seq: emb_dict[seq] for seq in needed_seqs if seq in emb_dict}
        missing_seqs = needed_seqs - self.emb_dict.keys()
        if missing_seqs:
            raise ValueError(f"Embeddings not found for sequences: {missing_seqs}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        seq_a, seq_b = self.seqs_a[idx], self.seqs_b[idx]
        emb_a = self.emb_dict[seq_a].reshape(-1, self.input_size)
        emb_b = self.emb_dict[seq_b].reshape(-1, self.input_size)

        # Randomly swap the pair order during training when pair flipping is enabled.
        if self.train and self.random_pair_flipping and random.random() < 0.5:
            emb_a, emb_b = emb_b, emb_a

        if self.task_type in ['multilabel', 'regression', 'sigmoid_regression']:
            label = torch.tensor(self.labels[idx], dtype=torch.float)
        else:
            label = torch.tensor(self.labels[idx], dtype=torch.long)

        return emb_a, emb_b, label


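# Illustrative construction of `emb_dict` for the in-memory datasets; the mapping from
# sequence string to embedding tensor is the only requirement, and `compute_embedding`
# is a placeholder for however the embeddings were actually produced:
#     emb_dict = {seq: compute_embedding(seq) for seq in all_sequences}
#     dataset = PairEmbedsLabelsDataset(hf_dataset, emb_dict, input_size=768)
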
class EmbedsLabelsDatasetFromDisk(TorchDataset):
    """Single-sequence dataset that streams pre-computed embeddings from a SQLite database.

    Works like ``PairEmbedsLabelsDatasetFromDisk`` but for a single sequence column.
    When ``full=True``, per-residue embedding matrices are zero-padded to the longest
    sequence in the dataset.
    """
    def __init__(
        self,
        hf_dataset,
        col_name='seqs',
        label_col='labels',
        full=False,
        db_path='embeddings.db',
        batch_size=64,
        read_scaler=100,
        input_size=768,
        task_type='singlelabel',
        **kwargs
    ):
        self.seqs, self.labels = list(hf_dataset[col_name]), list(hf_dataset[label_col])
        self.length = len(self.labels)
        self.max_length = len(max(self.seqs, key=len))
        print_message(f'Max length: {self.max_length}')

        self.db_file = db_path
        self.batch_size = batch_size
        self.input_size = input_size
        self.full = full

        self.task_type = task_type
        self.read_amt = read_scaler * self.batch_size
        self.embeddings, self.current_labels = [], []
        self.count, self.index = 0, 0

        self.reset_epoch()

    def __len__(self):
        return self.length

    def check_seqs(self, all_seqs):
        if any(seq not in all_seqs for seq in self.seqs):
            print_message('Sequences not found in embeddings')
        else:
            print_message('All sequences in embeddings')

    def reset_epoch(self):
        # Reshuffle the sequences and labels together and clear the read buffer.
        data = list(zip(self.seqs, self.labels))
        random.shuffle(data)
        self.seqs, self.labels = map(list, zip(*data))
        self.embeddings, self.current_labels = [], []
        self.count, self.index = 0, 0

    def read_embeddings(self):
        # Refill the buffer with the next `read_amt` items, starting a fresh shuffled
        # pass once the previous one has been exhausted.
        if self.count >= self.length:
            self.reset_epoch()
        start = self.count
        self.count += self.read_amt

        embeddings, labels = [], []
        conn = sqlite3.connect(self.db_file)
        c = conn.cursor()
        for i in range(start, min(start + self.read_amt, self.length)):
            result = c.execute("SELECT embedding FROM embeddings WHERE sequence=?", (self.seqs[i],))
            row = result.fetchone()
            if row is None:
                raise ValueError(f"Embedding not found for sequence: {self.seqs[i]}")
            emb = torch.tensor(np.frombuffer(row[0], dtype=np.float32).reshape(-1, self.input_size))
            if self.full:
                padding_needed = self.max_length - emb.size(0)
                emb = F.pad(emb, (0, 0, 0, padding_needed), value=0)
            embeddings.append(emb)
            labels.append(self.labels[i])
        conn.close()

        self.index = 0
        self.embeddings = embeddings
        self.current_labels = labels

    def __getitem__(self, idx):
        if self.index >= len(self.current_labels):
            self.read_embeddings()

        emb = self.embeddings[self.index]
        label = self.current_labels[self.index]
        self.index += 1

        if self.task_type in ['multilabel', 'regression', 'sigmoid_regression']:
            label = torch.tensor(label, dtype=torch.float)
        else:
            label = torch.tensor(label, dtype=torch.long)

        return emb.squeeze(0), label


class EmbedsLabelsDataset(TorchDataset):
    """In-memory single-sequence dataset built from a dictionary of pre-computed embeddings."""
    def __init__(self, hf_dataset, emb_dict, col_name='seqs', label_col='labels', task_type='singlelabel', full=False, **kwargs):
        self.embeddings = self.get_embs(emb_dict, list(hf_dataset[col_name]))
        self.full = full
        self.labels = list(hf_dataset[label_col])
        self.task_type = task_type
        self.max_length = len(max(list(hf_dataset[col_name]), key=len))
        print_message(f'Max length: {self.max_length}')

    def __len__(self):
        return len(self.labels)

    def get_embs(self, emb_dict, seqs):
        embeddings = []
        for seq in tqdm(seqs, desc='Loading Embeddings'):
            embeddings.append(emb_dict[seq])
        return embeddings

    def __getitem__(self, idx):
        if self.task_type in ['multilabel', 'regression', 'sigmoid_regression']:
            label = torch.tensor(self.labels[idx], dtype=torch.float)
        else:
            label = torch.tensor(self.labels[idx], dtype=torch.long)
        emb = self.embeddings[idx].float()
        if self.full:
            # Zero-pad per-residue embeddings to the longest sequence in the dataset.
            padding_needed = self.max_length - emb.size(0)
            emb = F.pad(emb, (0, 0, 0, padding_needed), value=0)
        return emb.squeeze(0), label


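# Illustrative tensor shapes for the embedding datasets above (assuming input_size=768):
#     full=False -> emb_dict[seq] is a pooled vector, e.g. shape (1, 768) or (768,),
#                   and __getitem__ returns a (768,) tensor after squeeze(0).
#     full=True  -> emb_dict[seq] is a per-residue matrix of shape (len(seq), 768),
#                   zero-padded to (max_length, 768) so that batches can be stacked.
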
class StringLabelDataset(TorchDataset):
    """Dataset that returns raw sequence strings paired with their labels."""
    def __init__(self, hf_dataset, col_name='seqs', label_col='labels', **kwargs):
        self.seqs = list(hf_dataset[col_name])
        self.labels = list(hf_dataset[label_col])
        self.lengths = [len(seq) for seq in self.seqs]

    def avg(self):
        # Average sequence length.
        return sum(self.lengths) / len(self.lengths)

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        return self.seqs[idx], self.labels[idx]


class PairStringLabelDataset(TorchDataset):
    """Dataset that returns raw (SeqA, SeqB) string pairs with their labels."""
    def __init__(self, hf_dataset, col_a='SeqA', col_b='SeqB', label_col='labels', train=True, random_pair_flipping=False, **kwargs):
        self.seqs_a, self.seqs_b = list(hf_dataset[col_a]), list(hf_dataset[col_b])
        self.labels = list(hf_dataset[label_col])
        self.train = train
        self.random_pair_flipping = random_pair_flipping

    def avg(self):
        # Average combined length of a pair.
        return sum(len(seq_a) + len(seq_b) for seq_a, seq_b in zip(self.seqs_a, self.seqs_b)) / len(self.seqs_a)

    def __len__(self):
        return len(self.seqs_a)

    def __getitem__(self, idx):
        seq_a, seq_b = self.seqs_a[idx], self.seqs_b[idx]
        # Randomly swap the pair order during training when pair flipping is enabled.
        if self.train and self.random_pair_flipping and random.random() < 0.5:
            seq_a, seq_b = seq_b, seq_a
        return seq_a, seq_b, self.labels[idx]


class SimpleProteinDataset(TorchDataset):
    """Simple dataset for protein sequences."""
    def __init__(self, sequences: List[str]):
        self.sequences = sequences

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> str:
        return self.sequences[idx]


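# The two Multi* datasets below combine embeddings for the columns listed in `seq_cols`
# in one of two ways:
#     full=True  -> per-residue matrices are concatenated along the length dimension,
#                   with a single all-zero separator row between consecutive sequences,
#                   then zero-padded to the longest combined length.
#     full=False -> pooled per-sequence vectors are concatenated along the feature
#                   dimension, so each column contributes input_size // len(seq_cols)
#                   features.
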
class MultiEmbedsLabelsDatasetFromDisk(TorchDataset):
    """Multi-sequence dataset that streams pre-computed embeddings from a SQLite database."""
    def __init__(
        self,
        hf_dataset,
        seq_cols: List[str],
        label_col: str = 'labels',
        full: bool = False,
        db_path: str = 'embeddings.db',
        batch_size: int = 64,
        read_scaler: int = 100,
        input_size: int = 768,
        task_type: str = 'singlelabel',
        train: bool = True,
        **kwargs,
    ):
        self.seq_cols = seq_cols
        self.labels = list(hf_dataset[label_col])
        self.length = len(self.labels)
        self.full = full
        self.db_file = db_path
        self.batch_size = batch_size
        self.read_amt = read_scaler * self.batch_size
        self.input_size = input_size // len(seq_cols) if not full else input_size
        self.task_type = task_type
        self.train = train

        self.col_to_seqs = {col: list(hf_dataset[col]) for col in seq_cols}

        if self.full:
            # Longest combined length: the sequences themselves plus one separator row
            # between each pair of consecutive sequences.
            def combined_len_at(i: int) -> int:
                return sum(len(self.col_to_seqs[c][i]) for c in self.seq_cols) + (len(self.seq_cols) - 1)
            self.max_length = max(combined_len_at(i) for i in range(self.length)) if self.length > 0 else 0

        self.embeddings, self.current_labels = [], []
        self.count, self.index = 0, 0

    def __len__(self):
        return self.length

    def reset_epoch(self):
        # Shuffle all sequence columns and the labels with the same permutation.
        idxs = list(range(self.length))
        random.shuffle(idxs)
        for col in self.seq_cols:
            self.col_to_seqs[col] = [self.col_to_seqs[col][i] for i in idxs]
        self.labels = [self.labels[i] for i in idxs]
        self.embeddings, self.current_labels = [], []
        self.count, self.index = 0, 0

    def _get_embedding(self, c, seq: str) -> torch.Tensor:
        result = c.execute("SELECT embedding FROM embeddings WHERE sequence=?", (seq,))
        row = result.fetchone()
        if row is None:
            raise ValueError(f"Embedding not found for sequence: {seq}")
        emb_data = row[0]
        emb = torch.tensor(np.frombuffer(emb_data, dtype=np.float32).reshape(-1, self.input_size))
        return emb

    def _combine_matrix(self, parts: List[torch.Tensor]) -> torch.Tensor:
        # Stack per-residue matrices along the length dimension with a zero separator row.
        if len(parts) == 0:
            return torch.zeros(0, self.input_size)
        sep = torch.zeros(1, self.input_size, dtype=parts[0].dtype)
        out = []
        for i, p in enumerate(parts):
            out.append(p)
            if i < len(parts) - 1:
                out.append(sep)
        return torch.cat(out, dim=0)

    def read_embeddings(self):
        # Refill the buffer with the next `read_amt` items, starting a fresh shuffled
        # pass once the previous one has been exhausted.
        if self.count >= self.length:
            self.reset_epoch()
        start = self.count
        self.count += self.read_amt

        embeddings, labels = [], []
        conn = sqlite3.connect(self.db_file)
        c = conn.cursor()
        for i in range(start, min(start + self.read_amt, self.length)):
            parts = [self._get_embedding(c, self.col_to_seqs[col][i]) for col in self.seq_cols]
            if self.full:
                emb = self._combine_matrix(parts)
                if self.max_length:
                    pad_needed = self.max_length - emb.size(0)
                    if pad_needed > 0:
                        emb = F.pad(emb, (0, 0, 0, pad_needed), value=0)
            else:
                # Pooled vectors: concatenate along the feature dimension.
                emb = torch.cat([p.reshape(1, -1) for p in parts], dim=-1)
            embeddings.append(emb)
            labels.append(self.labels[i])
        conn.close()

        self.index = 0
        self.embeddings = embeddings
        self.current_labels = labels

    def __getitem__(self, idx):
        if self.index >= len(self.current_labels):
            self.read_embeddings()

        emb = self.embeddings[self.index]
        label = self.current_labels[self.index]
        self.index += 1

        if self.task_type in ['multilabel', 'regression', 'sigmoid_regression']:
            label = torch.tensor(label, dtype=torch.float)
        else:
            label = torch.tensor(label, dtype=torch.long)

        return emb.squeeze(0), label


class MultiEmbedsLabelsDataset(TorchDataset):
    """In-memory multi-sequence dataset built from a dictionary of pre-computed embeddings."""
    def __init__(
        self,
        hf_dataset,
        seq_cols: List[str],
        label_col: str = 'labels',
        full: bool = False,
        emb_dict: dict = None,
        input_size: int = 768,
        task_type: str = 'singlelabel',
        train: bool = True,
        **kwargs,
    ):
        self.seq_cols = seq_cols
        self.labels = list(hf_dataset[label_col])
        self.full = full
        self.input_size = input_size // len(seq_cols) if not full else input_size
        self.task_type = task_type
        self.train = train

        self.col_to_seqs = {col: list(hf_dataset[col]) for col in seq_cols}

        self.embeddings = []
        if self.full:
            # Longest combined length: the sequences themselves plus one separator row
            # between each pair of consecutive sequences.
            def combined_len_at(i: int) -> int:
                return sum(len(self.col_to_seqs[c][i]) for c in self.seq_cols) + (len(self.seq_cols) - 1)
            self.max_length = max(combined_len_at(i) for i in range(len(self.labels))) if len(self.labels) > 0 else 0

        for i in tqdm(range(len(self.labels)), desc='Loading Multi-Embeddings'):
            parts = []
            for col in self.seq_cols:
                seq = self.col_to_seqs[col][i]
                parts.append(emb_dict[seq].reshape(-1, self.input_size))
            if self.full:
                emb = self._combine_matrix(parts)
                if self.max_length:
                    pad_needed = self.max_length - emb.size(0)
                    if pad_needed > 0:
                        emb = F.pad(emb, (0, 0, 0, pad_needed), value=0)
            else:
                # Pooled vectors: concatenate along the feature dimension.
                emb = torch.cat([p.reshape(1, -1) for p in parts], dim=-1)
            self.embeddings.append(emb)

    def _combine_matrix(self, parts: List[torch.Tensor]) -> torch.Tensor:
        # Stack per-residue matrices along the length dimension with a zero separator row.
        if len(parts) == 0:
            return torch.zeros(0, self.input_size)
        sep = torch.zeros(1, self.input_size, dtype=parts[0].dtype)
        out = []
        for i, p in enumerate(parts):
            out.append(p)
            if i < len(parts) - 1:
                out.append(sep)
        return torch.cat(out, dim=0)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self.task_type in ['multilabel', 'regression', 'sigmoid_regression']:
            label = torch.tensor(self.labels[idx], dtype=torch.float)
        else:
            label = torch.tensor(self.labels[idx], dtype=torch.long)
        emb = self.embeddings[idx].float()
        return emb.squeeze(0), label