import torch import numpy as np import random import os import sqlite3 from typing import List, Tuple, Dict, Optional from glob import glob from pandas import read_csv, read_excel from datasets import load_dataset, Dataset from dataclasses import dataclass try: from utils import print_message from seed_utils import get_global_seed from embedder import get_embedding_filename except ImportError: from ..utils import print_message from ..seed_utils import get_global_seed from ..embedder import get_embedding_filename from .supported_datasets import supported_datasets, standard_data_benchmark, vector_benchmark from .utils import ( AA_SET, CODON_SET, DNA_SET, RNA_SET, NONCANONICAL_AMINO_ACIDS, AMINO_ACID_TO_HUMAN_CODON, NONCANONICAL_ALANINE_CODON, AA_TO_CODON_TOKEN, CODON_TO_AA, DNA_CODON_TO_AA, RNA_CODON_TO_AA, ) @dataclass class DataArguments: """ Args: data_paths: List[str] paths to the datasets max_length: int max length of sequences trim: bool whether to trim sequences to max_length """ def __init__( self, data_names: List[str], delimiter: str = ',', col_names: List[str] = ['seqs', 'labels'], max_length: int = 1024, trim: bool = False, data_dirs: Optional[List[str]] = [], multi_column: Optional[List[str]] = None, aa_to_dna: bool = False, aa_to_rna: bool = False, dna_to_aa: bool = False, rna_to_aa: bool = False, codon_to_aa: bool = False, aa_to_codon: bool = False, **kwargs ): self.data_names = data_names self.data_dirs = data_dirs self.delimiter = delimiter self.col_names = col_names self.max_length = max_length self.trim = trim self.protein_gym = False self.multi_column = multi_column self.aa_to_dna = aa_to_dna self.aa_to_rna = aa_to_rna self.dna_to_aa = dna_to_aa self.rna_to_aa = rna_to_aa self.codon_to_aa = codon_to_aa self.aa_to_codon = aa_to_codon if len(data_names) > 0: if data_names[0] == 'standard_benchmark': self.data_paths = [supported_datasets[data_name] for data_name in standard_data_benchmark] elif data_names[0] == 'vector_benchmark': self.data_paths = [supported_datasets[data_name] for data_name in vector_benchmark] else: self.data_paths = [] for data_name in data_names: if data_name == 'protein_gym': # For special handling in the main workflow self.protein_gym = True continue if data_name in supported_datasets: self.data_paths.append(supported_datasets[data_name]) else: print(f'{data_name} not found in supported datasets') print('We will attempt to load it from huggingface anyways, but this may not work') self.data_paths.append(data_name) else: self.data_paths = [] if data_dirs is not None: for dir in data_dirs: if not os.path.exists(dir): raise FileNotFoundError(f'{dir} does not exist') class DataMixin: def __init__(self, data_args: Optional[DataArguments] = None): # intialize defaults self._sql = False self._full = False self._max_length = 1024 self._trim = False self._delimiter = ',' self._col_names = ['seqs', 'labels'] self._aa_to_dna = False self._aa_to_rna = False self._dna_to_aa = False self._rna_to_aa = False self._codon_to_aa = False self._aa_to_codon = False self.data_args = data_args self._multi_column = None if data_args is None else getattr(data_args, 'multi_column', None) if data_args is not None: self._aa_to_dna = data_args.aa_to_dna self._aa_to_rna = data_args.aa_to_rna self._dna_to_aa = data_args.dna_to_aa self._rna_to_aa = data_args.rna_to_aa self._codon_to_aa = data_args.codon_to_aa self._aa_to_codon = data_args.aa_to_codon def _not_regression(self, labels): # not a great assumption but works most of the time if isinstance(labels, list): # Check if first element is itself a list (multilabel case) if isinstance(labels[0], list): # For multilabel: check all elements in all sublists return all(isinstance(element, (int, float)) and element == int(element) for label in labels for element in label) else: # For single label: check all elements in the list return all(isinstance(label, (int, float)) and label == int(label) for label in labels) else: # Fallback for non-list input return all(isinstance(label, (int, float)) and label == int(label) for label in labels) def _encode_labels(self, labels, tag2id): return [torch.tensor([tag2id[tag] for tag in doc], dtype=torch.long) for doc in labels] def _label_type_checker(self, labels): ex = labels[0] assert len(labels) > 0, f'Labels is empty: {labels}' if self._not_regression(labels): if isinstance(ex, list): label_type = 'multilabel' elif isinstance(ex, int) or isinstance(ex, float): label_type = 'singlelabel' # binary or multiclass elif isinstance(ex, str): label_type = 'string' else: label_type = 'regression' return label_type def _is_sigmoid_regression(self, labels) -> bool: """Heuristic: labels within [0, 1] and cover the range approximately. Uses 10-bin histogram coverage and span threshold. """ arr = [] for label in labels: try: arr.extend(label) except: arr.append(label) arr = np.array(arr, dtype=float).flatten() min_val, max_val = float(arr.min()), float(arr.max()) cond1 = min_val > 0.0 - 1e-6 and max_val < 1.0 + 1e-6 # Require substantial span across [0,1] cond2 = (max_val - min_val) > 0.75 # Histogram coverage: at least 7 of 10 bins non-empty hist, _ = np.histogram(arr, bins=10, range=(0.0, 1.0)) cond3 = int((hist > 0).sum()) >= 7 sigmoid_regression_status = cond1 and cond2 and cond3 return sigmoid_regression_status def _select_from_sql(self, c, seq, cast_to_torch=True): c.execute("SELECT embedding FROM embeddings WHERE sequence = ?", (seq,)) embedding = np.frombuffer(c.fetchone()[0], dtype=np.float32).reshape(1, -1) if self._full: embedding = embedding.reshape(len(seq), -1) if cast_to_torch: embedding = torch.tensor(embedding) return embedding def _select_from_pth(self, emb_dict, seq, cast_to_np=False): embedding = emb_dict[seq].reshape(1, -1) if self._full: embedding = embedding.reshape(len(seq), -1) if cast_to_np: embedding = embedding.numpy() return embedding def _labels_to_numpy(self, labels): if isinstance(labels[0], list): return np.array(labels).flatten() else: return np.array([labels]).flatten() def _random_order(self, seq_a, seq_b): if random.random() < 0.5: return seq_a, seq_b else: return seq_b, seq_a def _truncate_pairs(self, ex): # Truncate longest first, but if that makes it shorter than the other, truncate that one seq_a, seq_b = ex['SeqA'], ex['SeqB'] trunc_a, trunc_b = seq_a, seq_b while len(trunc_a) + len(trunc_b) > self._max_length: if len(trunc_a) > len(trunc_b): trunc_a = trunc_a[:-1] else: trunc_b = trunc_b[:-1] ex['SeqA'] = trunc_a ex['SeqB'] = trunc_b return ex def _active_translation_mode(self): mode_to_flag = { 'aa_to_dna': self._aa_to_dna, 'aa_to_rna': self._aa_to_rna, 'dna_to_aa': self._dna_to_aa, 'rna_to_aa': self._rna_to_aa, 'codon_to_aa': self._codon_to_aa, 'aa_to_codon': self._aa_to_codon, } active_modes = [mode for mode, enabled in mode_to_flag.items() if enabled] assert len(active_modes) <= 1, f'Only one translation mode can be enabled at a time, found: {active_modes}' return active_modes[0] if len(active_modes) == 1 else None def _assert_characters_in_set(self, seq, allowed_chars, mode): bad_chars = sorted({char for char in seq if char.upper() not in allowed_chars}) assert len(bad_chars) == 0, f'Invalid characters for {mode}: {bad_chars}.' def _validate_translated_output(self, translated_seq, allowed_chars, mode): bad_chars = sorted({char for char in translated_seq if char not in allowed_chars}) assert len(bad_chars) == 0, f'Translation output for {mode} contains unexpected characters: {bad_chars}.' def _normalize_aa_for_nucleotide_translation(self, seq): canonical_aas = set(AMINO_ACID_TO_HUMAN_CODON.keys()) normalized = [] for residue in seq: residue = residue.upper() if residue in canonical_aas: normalized.append(residue) else: normalized.append('X') return ''.join(normalized) def _translate_aa_to_dna(self, seq): seq = self._normalize_aa_for_nucleotide_translation(seq) dna_codons = [] for residue in seq: residue = residue.upper() if residue in AMINO_ACID_TO_HUMAN_CODON: dna_codons.append(AMINO_ACID_TO_HUMAN_CODON[residue]) elif residue in NONCANONICAL_AMINO_ACIDS: dna_codons.append(NONCANONICAL_ALANINE_CODON) else: raise AssertionError(f'Unexpected amino acid token "{residue}" while converting aa_to_dna.') translated = ''.join(dna_codons) self._validate_translated_output(translated, DNA_SET, 'aa_to_dna') return translated def _translate_aa_to_rna(self, seq): dna_translated = self._translate_aa_to_dna(seq) translated = dna_translated.replace('T', 'U') self._validate_translated_output(translated, RNA_SET, 'aa_to_rna') return translated def _translate_dna_to_aa(self, seq): dna_seq = seq.upper() self._assert_characters_in_set(dna_seq, DNA_SET, 'dna_to_aa') assert len(dna_seq) % 3 == 0, f'dna_to_aa requires sequence length multiple of 3, got {len(dna_seq)}.' aa_seq = [] for idx in range(0, len(dna_seq), 3): codon = dna_seq[idx:idx + 3] assert codon in DNA_CODON_TO_AA, f'Unknown DNA codon for dna_to_aa: {codon}' translated_char = DNA_CODON_TO_AA[codon] if translated_char != '*': aa_seq.append(translated_char) translated = ''.join(aa_seq) self._validate_translated_output(translated, AA_SET - {'*'}, 'dna_to_aa') return translated def _translate_rna_to_aa(self, seq): rna_seq = seq.upper() self._assert_characters_in_set(rna_seq, RNA_SET, 'rna_to_aa') assert len(rna_seq) % 3 == 0, f'rna_to_aa requires sequence length multiple of 3, got {len(rna_seq)}.' aa_seq = [] for idx in range(0, len(rna_seq), 3): codon = rna_seq[idx:idx + 3] assert codon in RNA_CODON_TO_AA, f'Unknown RNA codon for rna_to_aa: {codon}' translated_char = RNA_CODON_TO_AA[codon] if translated_char != '*': aa_seq.append(translated_char) translated = ''.join(aa_seq) self._validate_translated_output(translated, AA_SET - {'*'}, 'rna_to_aa') return translated def _translate_codon_to_aa(self, seq): aa_seq = [] for token in seq: assert token in CODON_TO_AA, f'Unknown codon token for codon_to_aa: {token}' translated_char = CODON_TO_AA[token] if translated_char != '*': aa_seq.append(translated_char) translated = ''.join(aa_seq) self._validate_translated_output(translated, AA_SET - {'*'}, 'codon_to_aa') return translated def _translate_aa_to_codon(self, seq): codon_tokens = [] for residue in seq: residue = residue.upper() if residue in AA_TO_CODON_TOKEN: codon_tokens.append(AA_TO_CODON_TOKEN[residue]) elif residue in NONCANONICAL_AMINO_ACIDS: codon_tokens.append(AA_TO_CODON_TOKEN['A']) else: raise AssertionError(f'Unexpected amino acid token "{residue}" while converting aa_to_codon.') translated = ''.join(codon_tokens) self._validate_translated_output(translated, CODON_SET, 'aa_to_codon') return translated def _translate_sequence_for_mode(self, seq, mode): if mode == 'aa_to_dna': return self._translate_aa_to_dna(seq) if mode == 'aa_to_rna': return self._translate_aa_to_rna(seq) if mode == 'dna_to_aa': return self._translate_dna_to_aa(seq) if mode == 'rna_to_aa': return self._translate_rna_to_aa(seq) if mode == 'codon_to_aa': return self._translate_codon_to_aa(seq) if mode == 'aa_to_codon': return self._translate_aa_to_codon(seq) raise AssertionError(f'Unsupported translation mode: {mode}') def _find_first_present_column(self, available_columns, candidates_ordered): """Return the first column from candidates_ordered that exists in available_columns (case-insensitive).""" lowercase_to_actual = {col.lower(): col for col in available_columns} for candidate in candidates_ordered: actual = lowercase_to_actual.get(candidate.lower()) if actual is not None: return actual raise KeyError(f"None of the candidate columns were found. Candidates: {candidates_ordered}. Available: {available_columns}") def _is_ppi_from_columns(self, available_columns): """Detect if dataset contains paired sequence inputs (SeqA/SeqB variants).""" lowercase_columns = set(col.lower() for col in available_columns) base_candidates = ['seqs', 'seq', 'sequence', 'sequences'] for base in base_candidates: if (base + 'a') in lowercase_columns and (base + 'b') in lowercase_columns: return True return False def _find_ppi_sequence_columns(self, available_columns): """Return the actual column names for A and B sequences in PPI datasets based on priority.""" lowercase_to_actual = {col.lower(): col for col in available_columns} # Try specific common pairs first (in order) specific_pairs = [ ('SeqA', 'SeqB'), ('seqa', 'seqb'), ('SeqsA', 'SeqsB'), ] for cand_a, cand_b in specific_pairs: a_actual = lowercase_to_actual.get(cand_a.lower()) b_actual = lowercase_to_actual.get(cand_b.lower()) if a_actual is not None and b_actual is not None: return a_actual, b_actual # Generalized search using base tokens base_candidates = ['seqs', 'seq', 'sequence', 'sequences'] for base in base_candidates: a_key = (base + 'a').lower() b_key = (base + 'b').lower() a_actual = lowercase_to_actual.get(a_key) b_actual = lowercase_to_actual.get(b_key) if a_actual is not None and b_actual is not None: return a_actual, b_actual raise KeyError(f"Could not find paired sequence columns for PPI. Available: {available_columns}") def _is_missing_value(self, v): if v is None: return True # float/np.nan handling try: if isinstance(v, float) and np.isnan(v): return True except Exception: pass # list/array handling (check any element) if isinstance(v, (list, tuple, np.ndarray)): for el in v: if el is None: return True if isinstance(el, float) and np.isnan(el): return True return False def process_datasets( self, hf_datasets: List[Tuple[Dataset, Dataset, Dataset, bool]], data_names: List[str], )-> Tuple[Dict[str, Tuple[Dataset, Dataset, Dataset, int, str, bool]], List[str]]: max_length = self._max_length datasets, all_seqs = {}, set() translation_mode = self._active_translation_mode() for dataset, data_name in zip(hf_datasets, data_names): print_message(f'Processing {data_name}') train_set, valid_set, test_set, ppi = dataset print(train_set) print(valid_set) print(test_set) ### sanitize # 1) Drop rows with None or NaN in any sequence column(s) or labels before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set) if ppi: train_set = train_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels']))) valid_set = valid_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels']))) test_set = test_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels']))) elif self.data_args.multi_column: cols = self.data_args.multi_column # assert columns exist for col in cols: assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}" def _filter_row(x): return (not self._is_missing_value(x['labels'])) and all(not self._is_missing_value(x[col]) for col in cols) train_set = train_set.filter(_filter_row) valid_set = valid_set.filter(_filter_row) test_set = test_set.filter(_filter_row) else: train_set = train_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels']))) valid_set = valid_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels']))) test_set = test_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels']))) if any([ len(train_set) != before_train, len(valid_set) != before_valid, len(test_set) != before_test, ]): print_message( f"Removed None / NaN rows - train: {before_train - len(train_set)}, valid: {before_valid - len(valid_set)}, test: {before_test - len(test_set)}" ) # 2) Legacy sanitization for non-translation workflows if translation_mode is None: if ppi: train_set = train_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET), 'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)}) valid_set = valid_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET), 'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)}) test_set = test_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET), 'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)}) elif self.data_args.multi_column: cols = self.data_args.multi_column for col in cols: train_set = train_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)}) valid_set = valid_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)}) test_set = test_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)}) else: train_set = train_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)}) valid_set = valid_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)}) test_set = test_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)}) # 3) Remove any length 0 sequences before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set) if ppi: train_set = train_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0) valid_set = valid_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0) test_set = test_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0) elif self.data_args.multi_column: cols = self.data_args.multi_column train_set = train_set.filter(lambda x: all(len(x[col]) > 0 for col in cols)) valid_set = valid_set.filter(lambda x: all(len(x[col]) > 0 for col in cols)) test_set = test_set.filter(lambda x: all(len(x[col]) > 0 for col in cols)) else: train_set = train_set.filter(lambda x: len(x['seqs']) > 0) valid_set = valid_set.filter(lambda x: len(x['seqs']) > 0) test_set = test_set.filter(lambda x: len(x['seqs']) > 0) if any([ len(train_set) != before_train, len(valid_set) != before_valid, len(test_set) != before_test, ]): print_message( f"Removed length 0 rows - train: {before_train - len(train_set)}, valid: {before_valid - len(valid_set)}, test: {before_test - len(test_set)}" ) # 4) Trim or truncate by length if necessary before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set) if self._trim: # trim by length if ppi: train_set = train_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length) valid_set = valid_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length) test_set = test_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length) elif self.data_args.multi_column: cols = self.data_args.multi_column train_set = train_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols)) valid_set = valid_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols)) test_set = test_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols)) else: train_set = train_set.filter(lambda x: len(x['seqs']) <= max_length) valid_set = valid_set.filter(lambda x: len(x['seqs']) <= max_length) test_set = test_set.filter(lambda x: len(x['seqs']) <= max_length) else: # truncate to max_length if ppi: train_set = train_set.map(self._truncate_pairs) valid_set = valid_set.map(self._truncate_pairs) test_set = test_set.map(self._truncate_pairs) elif self.data_args.multi_column: cols = self.data_args.multi_column for col in cols: train_set = train_set.map(lambda x, _col=col: { _col: x[_col][:max_length] }) valid_set = valid_set.map(lambda x, _col=col: { _col: x[_col][:max_length] }) test_set = test_set.map(lambda x, _col=col: { _col: x[_col][:max_length] }) else: train_set = train_set.map(lambda x: {'seqs': x['seqs'][:max_length]}) valid_set = valid_set.map(lambda x: {'seqs': x['seqs'][:max_length]}) test_set = test_set.map(lambda x: {'seqs': x['seqs'][:max_length]}) if any([ len(train_set) != before_train, len(valid_set) != before_valid, len(test_set) != before_test, ]): print_message( f"Trimmed rows - train: {(before_train - len(train_set)) / before_train * 100:.2f}%, \ valid: {(before_valid - len(valid_set)) / before_valid * 100:.2f}%, \ test: {(before_test - len(test_set)) / before_test * 100:.2f}%" ) # 5) Optional sequence translation (post-trim/truncate) if translation_mode is not None: if ppi: train_set = train_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode), 'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)}) valid_set = valid_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode), 'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)}) test_set = test_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode), 'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)}) elif self.data_args.multi_column: cols = self.data_args.multi_column for col in cols: train_set = train_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)}) valid_set = valid_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)}) test_set = test_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)}) else: train_set = train_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)}) valid_set = valid_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)}) test_set = test_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)}) print_message(f"Translated sequences with mode {translation_mode} (post-trim/truncate).") # 6) Record all_seqs if ppi: all_seqs.update(list(train_set['SeqA']) + list(train_set['SeqB'])) all_seqs.update(list(valid_set['SeqA']) + list(valid_set['SeqB'])) all_seqs.update(list(test_set['SeqA']) + list(test_set['SeqB'])) elif self.data_args.multi_column: cols = self.data_args.multi_column for col in cols: all_seqs.update(list(train_set[col])) all_seqs.update(list(valid_set[col])) all_seqs.update(list(test_set[col])) else: all_seqs.update(list(train_set['seqs'])) all_seqs.update(list(valid_set['seqs'])) all_seqs.update(list(test_set['seqs'])) # confirm the type of labels check_labels = list(valid_set['labels']) label_type = self._label_type_checker(check_labels) if label_type == 'string': # might be string or multilabel example = list(valid_set['labels'])[0] try: import ast new_ex = ast.literal_eval(example) if isinstance(new_ex, list): # if ast runs correctly and is now a list it is multilabel labels label_type = 'multilabel' train_set = train_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])}) valid_set = valid_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])}) test_set = test_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])}) except: label_type = 'string' # if ast throws error it is actually string if label_type == 'string': # if still string, it's for tokenwise classification train_labels = list(train_set['labels']) unique_tags = set(tag for doc in train_labels for tag in doc) tag2id = {tag: id for id, tag in enumerate(sorted(unique_tags))} # add cls token to labels train_set = train_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)}) valid_set = valid_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)}) test_set = test_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)}) label_type = 'tokenwise' num_labels = len(unique_tags) else: if label_type == 'regression': # Detect sigmoid_regression (values in [0,1] covering the range) if self._is_sigmoid_regression(list(train_set['labels'])): label_type = 'sigmoid_regression' num_labels = 1 else: # if classification, get the total number of leabels try: train_labels_list = list(train_set['labels']) num_labels = len(train_labels_list[0]) except: unique = np.unique(list(train_set['labels'])) max_label = max(unique) # sometimes there are missing labels full_list = np.arange(0, max_label+1) num_labels = len(full_list) datasets[data_name] = (train_set, valid_set, test_set, num_labels, label_type, ppi) print(f'Label type: {label_type}') print(f'Number of labels: {num_labels}') all_seqs = list(all_seqs) all_seqs = sorted(all_seqs, key=len, reverse=True) # longest first return datasets, all_seqs def get_data(self): """ Supports .csv, .tsv, .txt TODO fasta, fa, fna, etc. """ datasets, data_names = [], [] label_candidates = ['labels', 'label', 'Labels', 'Label'] seq_candidates = ['seqs', 'Seqs', 'seq', 'Seq', 'sequence', 'Sequence', 'sequences', 'Sequences'] for data_path in self.data_args.data_paths: data_name = data_path.split('/')[-1] print_message(f'Loading {data_name}') dataset = load_dataset(data_path) if 'inverse' in data_name.lower(): dataset = dataset.rename_columns({'seqs': 'labels', 'labels': 'seqs'}) ppi = 'SeqA' in dataset['train'].column_names # Fallback PPI detection based on available columns if not ppi: ppi = self._is_ppi_from_columns(dataset['train'].column_names) print_message(f'PPI (or dual sequence input dataset): {ppi}') ### TODO, add better handling for valid, validation, test, testing, etc. assert 'train' in dataset, f'{data_name} does not have a train set' assert 'valid' in dataset or 'test' in dataset, f'{data_name} does not have a valid or test set, needs at least one' if 'valid' not in dataset: seed = get_global_seed() if get_global_seed() is not None else 42 train_set = dataset['train'] train_valid_set = train_set.train_test_split(test_size=0.1, seed=seed + 1) train_set = train_valid_set['train'] valid_set = train_valid_set['test'] test_set = dataset['test'] print_message(f'{data_name} does not have a valid set, created a 10% validation set') elif 'test' not in dataset: seed = get_global_seed() if get_global_seed() is not None else 42 train_set = dataset['train'] train_test_set = train_set.train_test_split(test_size=0.1, seed=seed + 2) test_set = train_test_set['test'] train_set = train_test_set['train'] valid_set = dataset['valid'] print_message(f'{data_name} does not have a test set, created a 10% test set') else: train_set, valid_set, test_set = dataset['train'], dataset['valid'], dataset['test'] print_message(f'{data_name} has a valid and test set') print_message(f'Train set: {len(train_set)}, Valid set: {len(valid_set)}, Test set: {len(test_set)}') if ppi: # Standardize PPI columns to 'SeqA', 'SeqB', and 'labels' print('Standardizing PPI column names') try: a_col, b_col = self._find_ppi_sequence_columns(train_set.column_names) except KeyError: # Retry with validation/test in case train is empty or missing columns try: a_col, b_col = self._find_ppi_sequence_columns(valid_set.column_names) except KeyError: a_col, b_col = self._find_ppi_sequence_columns(test_set.column_names) try: lbl_col = self._find_first_present_column(train_set.column_names, label_candidates) except KeyError: try: lbl_col = self._find_first_present_column(valid_set.column_names, label_candidates) except KeyError: lbl_col = self._find_first_present_column(test_set.column_names, label_candidates) train_set = train_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) valid_set = valid_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) test_set = test_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) print('Removing extras') train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) else: print('Standardizing column names') use_multi = self.data_args.multi_column is not None if not use_multi: try: seq_col = self._find_first_present_column(train_set.column_names, seq_candidates) except KeyError: try: seq_col = self._find_first_present_column(valid_set.column_names, seq_candidates) except KeyError: seq_col = self._find_first_present_column(test_set.column_names, seq_candidates) try: label_col = self._find_first_present_column(train_set.column_names, label_candidates) except KeyError: try: label_col = self._find_first_present_column(valid_set.column_names, label_candidates) except KeyError: label_col = self._find_first_present_column(test_set.column_names, label_candidates) # Always standardize label column to 'labels' train_set = train_set.rename_columns({label_col: 'labels'}) valid_set = valid_set.rename_columns({label_col: 'labels'}) test_set = test_set.rename_columns({label_col: 'labels'}) if not use_multi: train_set = train_set.rename_columns({seq_col: 'seqs'}) valid_set = valid_set.rename_columns({seq_col: 'seqs'}) test_set = test_set.rename_columns({seq_col: 'seqs'}) # drop everything else print('Removing extras') train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['seqs', 'labels']]) valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['seqs', 'labels']]) test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['seqs', 'labels']]) else: # Validate requested multi columns exist (assert exact match) for col in self.data_args.multi_column: assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}" # Keep only requested columns and labels keep_cols = set(self.data_args.multi_column + ['labels']) train_set = train_set.remove_columns([col for col in train_set.column_names if col not in keep_cols]) valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in keep_cols]) test_set = test_set.remove_columns([col for col in test_set.column_names if col not in keep_cols]) datasets.append((train_set, valid_set, test_set, ppi)) data_names.append(data_name) for data_dir in self.data_args.data_dirs: # local_data/taxon data_name = data_dir.split ('/')[-1] # Determine PPI by directory hint or columns ppi = 'ppi' in data_dir.lower() train_path = glob(os.path.join(data_dir, 'train.*'))[0] valid_path = glob(os.path.join(data_dir, 'valid.*'))[0] test_path = glob(os.path.join(data_dir, 'test.*'))[0] if '.xlsx' in train_path: train_set = read_excel(train_path) valid_set = read_excel(valid_path) test_set = read_excel(test_path) else: train_set = read_csv(train_path, delimiter=self._delimiter) valid_set = read_csv(valid_path, delimiter=self._delimiter) test_set = read_csv(test_path, delimiter=self._delimiter) train_set = Dataset.from_pandas(train_set) valid_set = Dataset.from_pandas(valid_set) test_set = Dataset.from_pandas(test_set) # If not indicated by directory, infer from columns if not ppi: ppi = self._is_ppi_from_columns(train_set.column_names) if ppi: print('Standardizing PPI column names') try: a_col, b_col = self._find_ppi_sequence_columns(train_set.column_names) except KeyError: try: a_col, b_col = self._find_ppi_sequence_columns(valid_set.column_names) except KeyError: a_col, b_col = self._find_ppi_sequence_columns(test_set.column_names) try: lbl_col = self._find_first_present_column(train_set.column_names, label_candidates) except KeyError: try: lbl_col = self._find_first_present_column(valid_set.column_names, label_candidates) except KeyError: lbl_col = self._find_first_present_column(test_set.column_names, label_candidates) train_set = train_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) valid_set = valid_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) test_set = test_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) print('Removing extras') train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) else: print('Standardizing column names') use_multi = self.data_args.multi_column is not None if not use_multi: try: seq_col = self._find_first_present_column(train_set.column_names, seq_candidates) except KeyError: try: seq_col = self._find_first_present_column(valid_set.column_names, seq_candidates) except KeyError: seq_col = self._find_first_present_column(test_set.column_names, seq_candidates) try: label_col = self._find_first_present_column(train_set.column_names, label_candidates) except KeyError: try: label_col = self._find_first_present_column(valid_set.column_names, label_candidates) except KeyError: label_col = self._find_first_present_column(test_set.column_names, label_candidates) # Always standardize label column to 'labels' train_set = train_set.rename_columns({label_col: 'labels'}) valid_set = valid_set.rename_columns({label_col: 'labels'}) test_set = test_set.rename_columns({label_col: 'labels'}) if not use_multi: train_set = train_set.rename_columns({seq_col: 'seqs'}) valid_set = valid_set.rename_columns({seq_col: 'seqs'}) test_set = test_set.rename_columns({seq_col: 'seqs'}) # drop everything else print('Removing extras') train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['seqs', 'labels']]) valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['seqs', 'labels']]) test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['seqs', 'labels']]) else: # Validate requested multi columns exist for col in self.data_args.multi_column: assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}" # Keep only requested columns and labels keep_cols = set(self.data_args.multi_column + ['labels']) train_set = train_set.remove_columns([col for col in train_set.column_names if col not in keep_cols]) valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in keep_cols]) test_set = test_set.remove_columns([col for col in test_set.column_names if col not in keep_cols]) datasets.append((train_set, valid_set, test_set, ppi)) data_names.append(data_name) return self.process_datasets(hf_datasets=datasets, data_names=data_names) def get_embedding_dim_sql(self, save_path, test_seq, tokenizer): import sqlite3 test_seq_len = len(tokenizer(test_seq, return_tensors='pt')['input_ids'][0]) with sqlite3.connect(save_path) as conn: c = conn.cursor() c.execute("SELECT embedding FROM embeddings WHERE sequence = ?", (test_seq,)) test_embedding = c.fetchone()[0] test_embedding = torch.tensor(np.frombuffer(test_embedding, dtype=np.float32).reshape(1, -1)) if self._full: test_embedding = test_embedding.reshape(test_seq_len, -1) embedding_dim = test_embedding.shape[-1] return embedding_dim def get_embedding_dim_pth(self, emb_dict, test_seq, tokenizer): test_seq_len = len(tokenizer(test_seq, return_tensors='pt')['input_ids'][0]) test_embedding = emb_dict[test_seq] if self._full: test_embedding = test_embedding.reshape(test_seq_len, -1) embedding_dim = test_embedding.shape[-1] return embedding_dim def build_vector_numpy_dataset_from_embeddings( self, model_name, train_seqs, valid_seqs, test_seqs, ): save_dir = self.embedding_args.embedding_save_dir train_array, valid_array, test_array = [], [], [] # Get pooling types from embedding_args, default to ['mean'] if not available pooling_types = self.embedding_args.pooling_types if self._sql: import sqlite3 filename = get_embedding_filename(model_name, self._full, pooling_types, 'db') save_path = os.path.join(save_dir, filename) with sqlite3.connect(save_path) as conn: c = conn.cursor() for seq in train_seqs: embedding = self._select_from_sql(c, seq, cast_to_torch=False) train_array.append(embedding) for seq in valid_seqs: embedding = self._select_from_sql(c, seq, cast_to_torch=False) valid_array.append(embedding) for seq in test_seqs: embedding = self._select_from_sql(c, seq, cast_to_torch=False) test_array.append(embedding) else: filename = get_embedding_filename(model_name, self._full, pooling_types, 'pth') save_path = os.path.join(save_dir, filename) emb_dict = torch.load(save_path) for seq in train_seqs: embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True) train_array.append(embedding) for seq in valid_seqs: embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True) valid_array.append(embedding) for seq in test_seqs: embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True) test_array.append(embedding) del emb_dict train_array = np.concatenate(train_array, axis=0) valid_array = np.concatenate(valid_array, axis=0) test_array = np.concatenate(test_array, axis=0) if self._full: # average over the length of the sequence train_array = np.mean(train_array, axis=1) valid_array = np.mean(valid_array, axis=1) test_array = np.mean(test_array, axis=1) print_message('Numpy dataset shapes') print_message(f'Train: {train_array.shape}') print_message(f'Valid: {valid_array.shape}') print_message(f'Test: {test_array.shape}') return train_array, valid_array, test_array def build_pair_vector_numpy_dataset_from_embeddings( self, model_name, train_seqs_a, train_seqs_b, valid_seqs_a, valid_seqs_b, test_seqs_a, test_seqs_b, ): save_dir = self.embedding_args.embedding_save_dir train_array, valid_array, test_array = [], [], [] pooling_types = self.embedding_args.pooling_types if self._sql: filename = get_embedding_filename(model_name, self._full, pooling_types, 'db') save_path = os.path.join(save_dir, filename) with sqlite3.connect(save_path) as conn: c = conn.cursor() for seq_a, seq_b in zip(train_seqs_a, train_seqs_b): seq_a, seq_b = self._random_order(seq_a, seq_b) embedding_a = self._select_from_sql(c, seq_a, cast_to_torch=False) embedding_b = self._select_from_sql(c, seq_b, cast_to_torch=False) train_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) for seq_a, seq_b in zip(valid_seqs_a, valid_seqs_b): seq_a, seq_b = self._random_order(seq_a, seq_b) embedding_a = self._select_from_sql(c, seq_a, cast_to_torch=False) embedding_b = self._select_from_sql(c, seq_b, cast_to_torch=False) valid_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) for seq_a, seq_b in zip(test_seqs_a, test_seqs_b): seq_a, seq_b = self._random_order(seq_a, seq_b) embedding_a = self._select_from_sql(c, seq_a, cast_to_torch=False) embedding_b = self._select_from_sql(c, seq_b, cast_to_torch=False) test_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) else: filename = get_embedding_filename(model_name, self._full, pooling_types, 'pth') save_path = os.path.join(save_dir, filename) emb_dict = torch.load(save_path) for seq_a, seq_b in zip(train_seqs_a, train_seqs_b): seq_a, seq_b = self._random_order(seq_a, seq_b) embedding_a = self._select_from_pth(emb_dict, seq_a, cast_to_np=True) embedding_b = self._select_from_pth(emb_dict, seq_b, cast_to_np=True) train_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) for seq_a, seq_b in zip(valid_seqs_a, valid_seqs_b): seq_a, seq_b = self._random_order(seq_a, seq_b) embedding_a = self._select_from_pth(emb_dict, seq_a, cast_to_np=True) embedding_b = self._select_from_pth(emb_dict, seq_b, cast_to_np=True) valid_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) for seq_a, seq_b in zip(test_seqs_a, test_seqs_b): seq_a, seq_b = self._random_order(seq_a, seq_b) embedding_a = self._select_from_pth(emb_dict, seq_a, cast_to_np=True) embedding_b = self._select_from_pth(emb_dict, seq_b, cast_to_np=True) test_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) del emb_dict train_array = np.concatenate(train_array, axis=0) valid_array = np.concatenate(valid_array, axis=0) test_array = np.concatenate(test_array, axis=0) if self._full: # average over the length of the sequence train_array = np.mean(train_array, axis=1) valid_array = np.mean(valid_array, axis=1) test_array = np.mean(test_array, axis=1) print_message('Numpy dataset shapes') print_message(f'Train: {train_array.shape}') print_message(f'Valid: {valid_array.shape}') print_message(f'Test: {test_array.shape}') return train_array, valid_array, test_array def prepare_scikit_dataset(self, model_name, dataset): train_set, valid_set, test_set, _, label_type, ppi = dataset if ppi: X_train, X_valid, X_test = self.build_pair_vector_numpy_dataset_from_embeddings( model_name, list(train_set['SeqA']), list(train_set['SeqB']), list(valid_set['SeqA']), list(valid_set['SeqB']), list(test_set['SeqA']), list(test_set['SeqB']), ) else: X_train, X_valid, X_test = self.build_vector_numpy_dataset_from_embeddings( model_name, list(train_set['seqs']), list(valid_set['seqs']), list(test_set['seqs']), ) y_train = self._labels_to_numpy(list(train_set['labels'])) y_valid = self._labels_to_numpy(list(valid_set['labels'])) y_test = self._labels_to_numpy(list(test_set['labels'])) print_message('Numpy dataset shapes with labels') print_message(f'Train: {X_train.shape}, {y_train.shape}') print_message(f'Valid: {X_valid.shape}, {y_valid.shape}') print_message(f'Test: {X_test.shape}, {y_test.shape}') return X_train, y_train, X_valid, y_valid, X_test, y_test, label_type