import ast
import torch
import numpy as np
import random
import os
import sqlite3
from typing import List, Tuple, Dict, Optional
from glob import glob
from pandas import read_csv, read_excel
from datasets import load_dataset, Dataset
try:
from utils import print_message
from seed_utils import get_global_seed
from embedder import get_embedding_filename
except ImportError:
from ..utils import print_message
from ..seed_utils import get_global_seed
from ..embedder import get_embedding_filename
from .supported_datasets import supported_datasets, standard_data_benchmark, vector_benchmark
from .utils import (
AA_SET,
CODON_SET,
DNA_SET,
RNA_SET,
NONCANONICAL_AMINO_ACIDS,
AMINO_ACID_TO_HUMAN_CODON,
NONCANONICAL_ALANINE_CODON,
AA_TO_CODON_TOKEN,
CODON_TO_AA,
DNA_CODON_TO_AA,
RNA_CODON_TO_AA,
)
class DataArguments:
    """
    Args:
        data_names: List[str]
            names of the datasets to load; keys of supported_datasets
            or Hugging Face hub paths
        delimiter: str
            delimiter for delimited files read from data_dirs
        col_names: List[str]
            column names for sequences and labels
        max_length: int
            max length of sequences
        trim: bool
            if True, drop sequences longer than max_length;
            if False, truncate them to max_length
        data_dirs: Optional[List[str]]
            local directories containing train/valid/test splits
        multi_column: Optional[List[str]]
            sequence columns to keep for multi-input datasets
        aa_to_dna / aa_to_rna / dna_to_aa / rna_to_aa / codon_to_aa / aa_to_codon: bool
            mutually exclusive sequence translation modes
    """
    def __init__(
        self,
        data_names: List[str],
        delimiter: str = ',',
        col_names: Optional[List[str]] = None,
        max_length: int = 1024,
        trim: bool = False,
        data_dirs: Optional[List[str]] = None,
        multi_column: Optional[List[str]] = None,
        aa_to_dna: bool = False,
        aa_to_rna: bool = False,
        dna_to_aa: bool = False,
        rna_to_aa: bool = False,
        codon_to_aa: bool = False,
        aa_to_codon: bool = False,
        **kwargs
    ):
        self.data_names = data_names
        # avoid shared mutable defaults by normalizing None to fresh objects
        self.data_dirs = data_dirs if data_dirs is not None else []
        self.delimiter = delimiter
        self.col_names = col_names if col_names is not None else ['seqs', 'labels']
        self.max_length = max_length
        self.trim = trim
self.protein_gym = False
self.multi_column = multi_column
self.aa_to_dna = aa_to_dna
self.aa_to_rna = aa_to_rna
self.dna_to_aa = dna_to_aa
self.rna_to_aa = rna_to_aa
self.codon_to_aa = codon_to_aa
self.aa_to_codon = aa_to_codon
if len(data_names) > 0:
if data_names[0] == 'standard_benchmark':
self.data_paths = [supported_datasets[data_name] for data_name in standard_data_benchmark]
elif data_names[0] == 'vector_benchmark':
self.data_paths = [supported_datasets[data_name] for data_name in vector_benchmark]
else:
self.data_paths = []
for data_name in data_names:
if data_name == 'protein_gym':
# For special handling in the main workflow
self.protein_gym = True
continue
if data_name in supported_datasets:
self.data_paths.append(supported_datasets[data_name])
else:
                        print(f'{data_name} not found in supported datasets')
                        print('We will attempt to load it from Hugging Face anyway, but this may not work')
self.data_paths.append(data_name)
else:
self.data_paths = []
        for data_dir in self.data_dirs:
            if not os.path.exists(data_dir):
                raise FileNotFoundError(f'{data_dir} does not exist')
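# A minimal usage sketch (the constructor call below is illustrative; any name
# missing from supported_datasets falls through to a direct Hugging Face hub load):
#   args = DataArguments(data_names=['standard_benchmark'], max_length=512, trim=True)
#   args.data_paths  # -> hub paths resolved from standard_data_benchmark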
class DataMixin:
    def __init__(self, data_args: Optional[DataArguments] = None):
        # initialize defaults
        self._sql = False
        self._full = False
        self._max_length = 1024
        self._trim = False
        self._delimiter = ','
        self._col_names = ['seqs', 'labels']
        self._aa_to_dna = False
        self._aa_to_rna = False
        self._dna_to_aa = False
        self._rna_to_aa = False
        self._codon_to_aa = False
        self._aa_to_codon = False
        self.data_args = data_args
        self._multi_column = None if data_args is None else getattr(data_args, 'multi_column', None)
        if data_args is not None:
            # mirror the user-provided arguments over the defaults
            self._max_length = getattr(data_args, 'max_length', self._max_length)
            self._trim = getattr(data_args, 'trim', self._trim)
            self._delimiter = getattr(data_args, 'delimiter', self._delimiter)
            self._col_names = getattr(data_args, 'col_names', self._col_names)
            self._aa_to_dna = data_args.aa_to_dna
            self._aa_to_rna = data_args.aa_to_rna
            self._dna_to_aa = data_args.dna_to_aa
            self._rna_to_aa = data_args.rna_to_aa
            self._codon_to_aa = data_args.codon_to_aa
            self._aa_to_codon = data_args.aa_to_codon
    def _not_regression(self, labels):  # not a great assumption but works most of the time
        def _is_discrete(value):
            # string tags (tokenwise labels or serialized lists) and
            # integer-valued numbers are treated as non-regression labels
            if isinstance(value, str):
                return True
            return isinstance(value, (int, float)) and value == int(value)
        if isinstance(labels, list) and len(labels) > 0 and isinstance(labels[0], list):
            # Multilabel case: check every element of every sublist
            return all(_is_discrete(element) for label in labels for element in label)
        # Single-label case (also covers non-list iterables)
        return all(_is_discrete(label) for label in labels)
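    # Illustrative behavior of the heuristic above (example values only):
    #   [0, 1, 1, 0]     -> True  (integer-valued, treated as classification)
    #   [0.0, 2.0, 1.0]  -> True  (floats equal to their int cast still pass)
    #   [0.12, 0.7]      -> False (treated as regression)
    #   ['HHHEEC', ...]  -> True  (string tags are resolved downstream)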
def _encode_labels(self, labels, tag2id):
return [torch.tensor([tag2id[tag] for tag in doc], dtype=torch.long) for doc in labels]
    def _label_type_checker(self, labels):
        assert len(labels) > 0, f'Labels is empty: {labels}'
        ex = labels[0]
        if self._not_regression(labels):
            if isinstance(ex, list):
                label_type = 'multilabel'
            elif isinstance(ex, (int, float)):
                label_type = 'singlelabel'  # binary or multiclass
            elif isinstance(ex, str):
                label_type = 'string'
            else:
                raise ValueError(f'Unsupported label example type: {type(ex)}')
        else:
            label_type = 'regression'
        return label_type
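    # Examples of the resulting tags (illustrative labels):
    #   [1, 0, 1]          -> 'singlelabel'
    #   [[1, 0], [0, 1]]   -> 'multilabel'
    #   ['HHHEEC', 'CCEE'] -> 'string' (tokenwise tags or serialized lists)
    #   [0.37, 1.2]        -> 'regression'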
def _is_sigmoid_regression(self, labels) -> bool:
"""Heuristic: labels within [0, 1] and cover the range approximately.
Uses 10-bin histogram coverage and span threshold.
"""
arr = []
        for label in labels:
            try:
                arr.extend(label)
            except TypeError:  # scalar label, not iterable
                arr.append(label)
arr = np.array(arr, dtype=float).flatten()
min_val, max_val = float(arr.min()), float(arr.max())
cond1 = min_val > 0.0 - 1e-6 and max_val < 1.0 + 1e-6
# Require substantial span across [0,1]
cond2 = (max_val - min_val) > 0.75
# Histogram coverage: at least 7 of 10 bins non-empty
hist, _ = np.histogram(arr, bins=10, range=(0.0, 1.0))
cond3 = int((hist > 0).sum()) >= 7
sigmoid_regression_status = cond1 and cond2 and cond3
return sigmoid_regression_status
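    # Example (made-up values): labels spread over [0, 1] such as
    # [0.02, 0.11, 0.25, ..., 0.97] satisfy all three conditions (bounded in
    # [0, 1], span > 0.75, >= 7 of 10 histogram bins occupied) and are tagged
    # sigmoid regression; values clustered near 0.5 fail the span/coverage checks.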
    def _select_from_sql(self, c, seq, cast_to_torch=True):
        c.execute("SELECT embedding FROM embeddings WHERE sequence = ?", (seq,))
        row = c.fetchone()
        assert row is not None, f'No embedding stored for sequence of length {len(seq)}'
        embedding = np.frombuffer(row[0], dtype=np.float32).reshape(1, -1)
        if self._full:
            embedding = embedding.reshape(len(seq), -1)
        if cast_to_torch:
            embedding = torch.tensor(embedding)
        return embedding
def _select_from_pth(self, emb_dict, seq, cast_to_np=False):
embedding = emb_dict[seq].reshape(1, -1)
if self._full:
embedding = embedding.reshape(len(seq), -1)
if cast_to_np:
embedding = embedding.numpy()
return embedding
def _labels_to_numpy(self, labels):
if isinstance(labels[0], list):
return np.array(labels).flatten()
else:
return np.array([labels]).flatten()
def _random_order(self, seq_a, seq_b):
if random.random() < 0.5:
return seq_a, seq_b
else:
return seq_b, seq_a
    def _truncate_pairs(self, ex):
        # Repeatedly trim one residue from the longer sequence until the pair fits max_length
seq_a, seq_b = ex['SeqA'], ex['SeqB']
trunc_a, trunc_b = seq_a, seq_b
while len(trunc_a) + len(trunc_b) > self._max_length:
if len(trunc_a) > len(trunc_b):
trunc_a = trunc_a[:-1]
else:
trunc_b = trunc_b[:-1]
ex['SeqA'] = trunc_a
ex['SeqB'] = trunc_b
return ex
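    # Worked example (toy sequences): with self._max_length = 8,
    # SeqA = 'AAAAAA' (6) and SeqB = 'BBBBB' (5) are trimmed alternately
    # from the longer side until 4 + 4 <= 8, giving 'AAAA' and 'BBBB'.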
def _active_translation_mode(self):
mode_to_flag = {
'aa_to_dna': self._aa_to_dna,
'aa_to_rna': self._aa_to_rna,
'dna_to_aa': self._dna_to_aa,
'rna_to_aa': self._rna_to_aa,
'codon_to_aa': self._codon_to_aa,
'aa_to_codon': self._aa_to_codon,
}
active_modes = [mode for mode, enabled in mode_to_flag.items() if enabled]
assert len(active_modes) <= 1, f'Only one translation mode can be enabled at a time, found: {active_modes}'
return active_modes[0] if len(active_modes) == 1 else None
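    # Example: DataArguments(..., aa_to_dna=True) yields 'aa_to_dna' here;
    # with no flags set the method returns None, and enabling two modes at
    # once trips the assertion.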
def _assert_characters_in_set(self, seq, allowed_chars, mode):
bad_chars = sorted({char for char in seq if char.upper() not in allowed_chars})
assert len(bad_chars) == 0, f'Invalid characters for {mode}: {bad_chars}.'
def _validate_translated_output(self, translated_seq, allowed_chars, mode):
bad_chars = sorted({char for char in translated_seq if char not in allowed_chars})
assert len(bad_chars) == 0, f'Translation output for {mode} contains unexpected characters: {bad_chars}.'
def _normalize_aa_for_nucleotide_translation(self, seq):
canonical_aas = set(AMINO_ACID_TO_HUMAN_CODON.keys())
normalized = []
for residue in seq:
residue = residue.upper()
if residue in canonical_aas:
normalized.append(residue)
else:
normalized.append('X')
return ''.join(normalized)
def _translate_aa_to_dna(self, seq):
seq = self._normalize_aa_for_nucleotide_translation(seq)
dna_codons = []
for residue in seq:
residue = residue.upper()
if residue in AMINO_ACID_TO_HUMAN_CODON:
dna_codons.append(AMINO_ACID_TO_HUMAN_CODON[residue])
elif residue in NONCANONICAL_AMINO_ACIDS:
dna_codons.append(NONCANONICAL_ALANINE_CODON)
else:
raise AssertionError(f'Unexpected amino acid token "{residue}" while converting aa_to_dna.')
translated = ''.join(dna_codons)
self._validate_translated_output(translated, DNA_SET, 'aa_to_dna')
return translated
def _translate_aa_to_rna(self, seq):
dna_translated = self._translate_aa_to_dna(seq)
translated = dna_translated.replace('T', 'U')
self._validate_translated_output(translated, RNA_SET, 'aa_to_rna')
return translated
def _translate_dna_to_aa(self, seq):
dna_seq = seq.upper()
self._assert_characters_in_set(dna_seq, DNA_SET, 'dna_to_aa')
assert len(dna_seq) % 3 == 0, f'dna_to_aa requires sequence length multiple of 3, got {len(dna_seq)}.'
aa_seq = []
for idx in range(0, len(dna_seq), 3):
codon = dna_seq[idx:idx + 3]
assert codon in DNA_CODON_TO_AA, f'Unknown DNA codon for dna_to_aa: {codon}'
translated_char = DNA_CODON_TO_AA[codon]
if translated_char != '*':
aa_seq.append(translated_char)
translated = ''.join(aa_seq)
self._validate_translated_output(translated, AA_SET - {'*'}, 'dna_to_aa')
return translated
def _translate_rna_to_aa(self, seq):
rna_seq = seq.upper()
self._assert_characters_in_set(rna_seq, RNA_SET, 'rna_to_aa')
assert len(rna_seq) % 3 == 0, f'rna_to_aa requires sequence length multiple of 3, got {len(rna_seq)}.'
aa_seq = []
for idx in range(0, len(rna_seq), 3):
codon = rna_seq[idx:idx + 3]
assert codon in RNA_CODON_TO_AA, f'Unknown RNA codon for rna_to_aa: {codon}'
translated_char = RNA_CODON_TO_AA[codon]
if translated_char != '*':
aa_seq.append(translated_char)
translated = ''.join(aa_seq)
self._validate_translated_output(translated, AA_SET - {'*'}, 'rna_to_aa')
return translated
def _translate_codon_to_aa(self, seq):
aa_seq = []
for token in seq:
assert token in CODON_TO_AA, f'Unknown codon token for codon_to_aa: {token}'
translated_char = CODON_TO_AA[token]
if translated_char != '*':
aa_seq.append(translated_char)
translated = ''.join(aa_seq)
self._validate_translated_output(translated, AA_SET - {'*'}, 'codon_to_aa')
return translated
def _translate_aa_to_codon(self, seq):
codon_tokens = []
for residue in seq:
residue = residue.upper()
if residue in AA_TO_CODON_TOKEN:
codon_tokens.append(AA_TO_CODON_TOKEN[residue])
elif residue in NONCANONICAL_AMINO_ACIDS:
codon_tokens.append(AA_TO_CODON_TOKEN['A'])
else:
raise AssertionError(f'Unexpected amino acid token "{residue}" while converting aa_to_codon.')
translated = ''.join(codon_tokens)
self._validate_translated_output(translated, CODON_SET, 'aa_to_codon')
return translated
def _translate_sequence_for_mode(self, seq, mode):
if mode == 'aa_to_dna':
return self._translate_aa_to_dna(seq)
if mode == 'aa_to_rna':
return self._translate_aa_to_rna(seq)
if mode == 'dna_to_aa':
return self._translate_dna_to_aa(seq)
if mode == 'rna_to_aa':
return self._translate_rna_to_aa(seq)
if mode == 'codon_to_aa':
return self._translate_codon_to_aa(seq)
if mode == 'aa_to_codon':
return self._translate_aa_to_codon(seq)
raise AssertionError(f'Unsupported translation mode: {mode}')
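    # Illustrative round trip (toy sequence): 'MK' under aa_to_dna becomes one
    # human codon per residue (e.g. 'ATG' plus a lysine codon), and dna_to_aa
    # maps the result back to 'MK'; stop codons ('*') are dropped when
    # translating nucleotides to amino acids, so round trips are not
    # guaranteed lossless.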
def _find_first_present_column(self, available_columns, candidates_ordered):
"""Return the first column from candidates_ordered that exists in available_columns (case-insensitive)."""
lowercase_to_actual = {col.lower(): col for col in available_columns}
for candidate in candidates_ordered:
actual = lowercase_to_actual.get(candidate.lower())
if actual is not None:
return actual
raise KeyError(f"None of the candidate columns were found. Candidates: {candidates_ordered}. Available: {available_columns}")
def _is_ppi_from_columns(self, available_columns):
"""Detect if dataset contains paired sequence inputs (SeqA/SeqB variants)."""
lowercase_columns = set(col.lower() for col in available_columns)
base_candidates = ['seqs', 'seq', 'sequence', 'sequences']
for base in base_candidates:
if (base + 'a') in lowercase_columns and (base + 'b') in lowercase_columns:
return True
return False
def _find_ppi_sequence_columns(self, available_columns):
"""Return the actual column names for A and B sequences in PPI datasets based on priority."""
lowercase_to_actual = {col.lower(): col for col in available_columns}
# Try specific common pairs first (in order)
specific_pairs = [
('SeqA', 'SeqB'),
('seqa', 'seqb'),
('SeqsA', 'SeqsB'),
]
for cand_a, cand_b in specific_pairs:
a_actual = lowercase_to_actual.get(cand_a.lower())
b_actual = lowercase_to_actual.get(cand_b.lower())
if a_actual is not None and b_actual is not None:
return a_actual, b_actual
# Generalized search using base tokens
base_candidates = ['seqs', 'seq', 'sequence', 'sequences']
for base in base_candidates:
a_key = (base + 'a').lower()
b_key = (base + 'b').lower()
a_actual = lowercase_to_actual.get(a_key)
b_actual = lowercase_to_actual.get(b_key)
if a_actual is not None and b_actual is not None:
return a_actual, b_actual
raise KeyError(f"Could not find paired sequence columns for PPI. Available: {available_columns}")
def _is_missing_value(self, v):
if v is None:
return True
# float/np.nan handling
try:
if isinstance(v, float) and np.isnan(v):
return True
except Exception:
pass
# list/array handling (check any element)
if isinstance(v, (list, tuple, np.ndarray)):
for el in v:
if el is None:
return True
if isinstance(el, float) and np.isnan(el):
return True
return False
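    # Examples: None -> True; float('nan') -> True; [1.0, None] -> True;
    # 'MKV' -> False; [] -> False (an empty list has no missing elements).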
def process_datasets(
self,
hf_datasets: List[Tuple[Dataset, Dataset, Dataset, bool]],
data_names: List[str],
    ) -> Tuple[Dict[str, Tuple[Dataset, Dataset, Dataset, int, str, bool]], List[str]]:
max_length = self._max_length
datasets, all_seqs = {}, set()
translation_mode = self._active_translation_mode()
for dataset, data_name in zip(hf_datasets, data_names):
print_message(f'Processing {data_name}')
train_set, valid_set, test_set, ppi = dataset
print(train_set)
print(valid_set)
print(test_set)
### sanitize
# 1) Drop rows with None or NaN in any sequence column(s) or labels
before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set)
if ppi:
train_set = train_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels'])))
valid_set = valid_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels'])))
test_set = test_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels'])))
elif self.data_args.multi_column:
cols = self.data_args.multi_column
# assert columns exist
for col in cols:
assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}"
def _filter_row(x):
return (not self._is_missing_value(x['labels'])) and all(not self._is_missing_value(x[col]) for col in cols)
train_set = train_set.filter(_filter_row)
valid_set = valid_set.filter(_filter_row)
test_set = test_set.filter(_filter_row)
else:
train_set = train_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels'])))
valid_set = valid_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels'])))
test_set = test_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels'])))
if any([
len(train_set) != before_train,
len(valid_set) != before_valid,
len(test_set) != before_test,
]):
print_message(
f"Removed None / NaN rows - train: {before_train - len(train_set)}, valid: {before_valid - len(valid_set)}, test: {before_test - len(test_set)}"
)
# 2) Legacy sanitization for non-translation workflows
if translation_mode is None:
if ppi:
train_set = train_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET),
'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)})
valid_set = valid_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET),
'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)})
test_set = test_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET),
'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)})
elif self.data_args.multi_column:
cols = self.data_args.multi_column
for col in cols:
train_set = train_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)})
valid_set = valid_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)})
test_set = test_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)})
else:
train_set = train_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)})
valid_set = valid_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)})
test_set = test_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)})
# 3) Remove any length 0 sequences
before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set)
if ppi:
train_set = train_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0)
valid_set = valid_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0)
test_set = test_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0)
elif self.data_args.multi_column:
cols = self.data_args.multi_column
train_set = train_set.filter(lambda x: all(len(x[col]) > 0 for col in cols))
valid_set = valid_set.filter(lambda x: all(len(x[col]) > 0 for col in cols))
test_set = test_set.filter(lambda x: all(len(x[col]) > 0 for col in cols))
else:
train_set = train_set.filter(lambda x: len(x['seqs']) > 0)
valid_set = valid_set.filter(lambda x: len(x['seqs']) > 0)
test_set = test_set.filter(lambda x: len(x['seqs']) > 0)
if any([
len(train_set) != before_train,
len(valid_set) != before_valid,
len(test_set) != before_test,
]):
print_message(
f"Removed length 0 rows - train: {before_train - len(train_set)}, valid: {before_valid - len(valid_set)}, test: {before_test - len(test_set)}"
)
# 4) Trim or truncate by length if necessary
before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set)
if self._trim: # trim by length
if ppi:
train_set = train_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length)
valid_set = valid_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length)
test_set = test_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length)
elif self.data_args.multi_column:
cols = self.data_args.multi_column
train_set = train_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols))
valid_set = valid_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols))
test_set = test_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols))
else:
train_set = train_set.filter(lambda x: len(x['seqs']) <= max_length)
valid_set = valid_set.filter(lambda x: len(x['seqs']) <= max_length)
test_set = test_set.filter(lambda x: len(x['seqs']) <= max_length)
else: # truncate to max_length
if ppi:
train_set = train_set.map(self._truncate_pairs)
valid_set = valid_set.map(self._truncate_pairs)
test_set = test_set.map(self._truncate_pairs)
elif self.data_args.multi_column:
cols = self.data_args.multi_column
for col in cols:
train_set = train_set.map(lambda x, _col=col: { _col: x[_col][:max_length] })
valid_set = valid_set.map(lambda x, _col=col: { _col: x[_col][:max_length] })
test_set = test_set.map(lambda x, _col=col: { _col: x[_col][:max_length] })
else:
train_set = train_set.map(lambda x: {'seqs': x['seqs'][:max_length]})
valid_set = valid_set.map(lambda x: {'seqs': x['seqs'][:max_length]})
test_set = test_set.map(lambda x: {'seqs': x['seqs'][:max_length]})
if any([
len(train_set) != before_train,
len(valid_set) != before_valid,
len(test_set) != before_test,
]):
print_message(
f"Trimmed rows - train: {(before_train - len(train_set)) / before_train * 100:.2f}%, \
valid: {(before_valid - len(valid_set)) / before_valid * 100:.2f}%, \
test: {(before_test - len(test_set)) / before_test * 100:.2f}%"
)
# 5) Optional sequence translation (post-trim/truncate)
if translation_mode is not None:
if ppi:
train_set = train_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode),
'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)})
valid_set = valid_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode),
'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)})
test_set = test_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode),
'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)})
elif self.data_args.multi_column:
cols = self.data_args.multi_column
for col in cols:
train_set = train_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)})
valid_set = valid_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)})
test_set = test_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)})
else:
train_set = train_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)})
valid_set = valid_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)})
test_set = test_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)})
print_message(f"Translated sequences with mode {translation_mode} (post-trim/truncate).")
# 6) Record all_seqs
if ppi:
all_seqs.update(list(train_set['SeqA']) + list(train_set['SeqB']))
all_seqs.update(list(valid_set['SeqA']) + list(valid_set['SeqB']))
all_seqs.update(list(test_set['SeqA']) + list(test_set['SeqB']))
elif self.data_args.multi_column:
cols = self.data_args.multi_column
for col in cols:
all_seqs.update(list(train_set[col]))
all_seqs.update(list(valid_set[col]))
all_seqs.update(list(test_set[col]))
else:
all_seqs.update(list(train_set['seqs']))
all_seqs.update(list(valid_set['seqs']))
all_seqs.update(list(test_set['seqs']))
# confirm the type of labels
check_labels = list(valid_set['labels'])
label_type = self._label_type_checker(check_labels)
if label_type == 'string': # might be string or multilabel
            example = check_labels[0]
            try:
                new_ex = ast.literal_eval(example)
                if isinstance(new_ex, list):  # parses to a list, so these are multilabel labels
                    label_type = 'multilabel'
                    train_set = train_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])})
                    valid_set = valid_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])})
                    test_set = test_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])})
            except (ValueError, SyntaxError):
                label_type = 'string'  # literal_eval failed, so these are genuinely string labels
if label_type == 'string': # if still string, it's for tokenwise classification
train_labels = list(train_set['labels'])
unique_tags = set(tag for doc in train_labels for tag in doc)
tag2id = {tag: id for id, tag in enumerate(sorted(unique_tags))}
            # map per-residue string tags to integer ids
train_set = train_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)})
valid_set = valid_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)})
test_set = test_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)})
label_type = 'tokenwise'
num_labels = len(unique_tags)
else:
if label_type == 'regression':
# Detect sigmoid_regression (values in [0,1] covering the range)
if self._is_sigmoid_regression(list(train_set['labels'])):
label_type = 'sigmoid_regression'
num_labels = 1
            else:  # if classification, get the total number of labels
                try:
                    train_labels_list = list(train_set['labels'])
                    num_labels = len(train_labels_list[0])
                except (TypeError, IndexError):  # scalar labels have no len()
                    unique = np.unique(list(train_set['labels']))
                    max_label = int(max(unique))  # sometimes there are missing labels
                    num_labels = max_label + 1
datasets[data_name] = (train_set, valid_set, test_set, num_labels, label_type, ppi)
print(f'Label type: {label_type}')
print(f'Number of labels: {num_labels}')
all_seqs = list(all_seqs)
all_seqs = sorted(all_seqs, key=len, reverse=True) # longest first
return datasets, all_seqs
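    # Sketch of the expected call pattern: get_data below assembles
    # (train, valid, test, ppi) tuples and passes them here. The returned dict
    # maps each dataset name to
    # (train_set, valid_set, test_set, num_labels, label_type, ppi), and
    # all_seqs holds every unique sequence, longest first, ready for embedding.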
def get_data(self):
"""
Supports .csv, .tsv, .txt
TODO fasta, fa, fna, etc.
"""
datasets, data_names = [], []
label_candidates = ['labels', 'label', 'Labels', 'Label']
seq_candidates = ['seqs', 'Seqs', 'seq', 'Seq', 'sequence', 'Sequence', 'sequences', 'Sequences']
for data_path in self.data_args.data_paths:
data_name = data_path.split('/')[-1]
print_message(f'Loading {data_name}')
dataset = load_dataset(data_path)
if 'inverse' in data_name.lower():
dataset = dataset.rename_columns({'seqs': 'labels', 'labels': 'seqs'})
ppi = 'SeqA' in dataset['train'].column_names
# Fallback PPI detection based on available columns
if not ppi:
ppi = self._is_ppi_from_columns(dataset['train'].column_names)
print_message(f'PPI (or dual sequence input dataset): {ppi}')
### TODO, add better handling for valid, validation, test, testing, etc.
assert 'train' in dataset, f'{data_name} does not have a train set'
assert 'valid' in dataset or 'test' in dataset, f'{data_name} does not have a valid or test set, needs at least one'
if 'valid' not in dataset:
                global_seed = get_global_seed()
                seed = global_seed if global_seed is not None else 42
train_set = dataset['train']
train_valid_set = train_set.train_test_split(test_size=0.1, seed=seed + 1)
train_set = train_valid_set['train']
valid_set = train_valid_set['test']
test_set = dataset['test']
print_message(f'{data_name} does not have a valid set, created a 10% validation set')
elif 'test' not in dataset:
                global_seed = get_global_seed()
                seed = global_seed if global_seed is not None else 42
train_set = dataset['train']
train_test_set = train_set.train_test_split(test_size=0.1, seed=seed + 2)
test_set = train_test_set['test']
train_set = train_test_set['train']
valid_set = dataset['valid']
print_message(f'{data_name} does not have a test set, created a 10% test set')
else:
train_set, valid_set, test_set = dataset['train'], dataset['valid'], dataset['test']
print_message(f'{data_name} has a valid and test set')
print_message(f'Train set: {len(train_set)}, Valid set: {len(valid_set)}, Test set: {len(test_set)}')
if ppi:
# Standardize PPI columns to 'SeqA', 'SeqB', and 'labels'
print('Standardizing PPI column names')
try:
a_col, b_col = self._find_ppi_sequence_columns(train_set.column_names)
except KeyError:
# Retry with validation/test in case train is empty or missing columns
try:
a_col, b_col = self._find_ppi_sequence_columns(valid_set.column_names)
except KeyError:
a_col, b_col = self._find_ppi_sequence_columns(test_set.column_names)
try:
lbl_col = self._find_first_present_column(train_set.column_names, label_candidates)
except KeyError:
try:
lbl_col = self._find_first_present_column(valid_set.column_names, label_candidates)
except KeyError:
lbl_col = self._find_first_present_column(test_set.column_names, label_candidates)
train_set = train_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'})
valid_set = valid_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'})
test_set = test_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'})
print('Removing extras')
train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['SeqA', 'SeqB', 'labels']])
valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['SeqA', 'SeqB', 'labels']])
test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['SeqA', 'SeqB', 'labels']])
else:
print('Standardizing column names')
use_multi = self.data_args.multi_column is not None
if not use_multi:
try:
seq_col = self._find_first_present_column(train_set.column_names, seq_candidates)
except KeyError:
try:
seq_col = self._find_first_present_column(valid_set.column_names, seq_candidates)
except KeyError:
seq_col = self._find_first_present_column(test_set.column_names, seq_candidates)
try:
label_col = self._find_first_present_column(train_set.column_names, label_candidates)
except KeyError:
try:
label_col = self._find_first_present_column(valid_set.column_names, label_candidates)
except KeyError:
label_col = self._find_first_present_column(test_set.column_names, label_candidates)
# Always standardize label column to 'labels'
train_set = train_set.rename_columns({label_col: 'labels'})
valid_set = valid_set.rename_columns({label_col: 'labels'})
test_set = test_set.rename_columns({label_col: 'labels'})
if not use_multi:
train_set = train_set.rename_columns({seq_col: 'seqs'})
valid_set = valid_set.rename_columns({seq_col: 'seqs'})
test_set = test_set.rename_columns({seq_col: 'seqs'})
# drop everything else
print('Removing extras')
train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['seqs', 'labels']])
valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['seqs', 'labels']])
test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['seqs', 'labels']])
else:
# Validate requested multi columns exist (assert exact match)
for col in self.data_args.multi_column:
assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}"
# Keep only requested columns and labels
keep_cols = set(self.data_args.multi_column + ['labels'])
train_set = train_set.remove_columns([col for col in train_set.column_names if col not in keep_cols])
valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in keep_cols])
test_set = test_set.remove_columns([col for col in test_set.column_names if col not in keep_cols])
datasets.append((train_set, valid_set, test_set, ppi))
data_names.append(data_name)
for data_dir in self.data_args.data_dirs:
            # e.g. local_data/taxon
            data_name = data_dir.rstrip('/').split('/')[-1]
# Determine PPI by directory hint or columns
ppi = 'ppi' in data_dir.lower()
train_path = glob(os.path.join(data_dir, 'train.*'))[0]
valid_path = glob(os.path.join(data_dir, 'valid.*'))[0]
test_path = glob(os.path.join(data_dir, 'test.*'))[0]
            if train_path.endswith('.xlsx'):
train_set = read_excel(train_path)
valid_set = read_excel(valid_path)
test_set = read_excel(test_path)
else:
train_set = read_csv(train_path, delimiter=self._delimiter)
valid_set = read_csv(valid_path, delimiter=self._delimiter)
test_set = read_csv(test_path, delimiter=self._delimiter)
train_set = Dataset.from_pandas(train_set)
valid_set = Dataset.from_pandas(valid_set)
test_set = Dataset.from_pandas(test_set)
# If not indicated by directory, infer from columns
if not ppi:
ppi = self._is_ppi_from_columns(train_set.column_names)
if ppi:
print('Standardizing PPI column names')
try:
a_col, b_col = self._find_ppi_sequence_columns(train_set.column_names)
except KeyError:
try:
a_col, b_col = self._find_ppi_sequence_columns(valid_set.column_names)
except KeyError:
a_col, b_col = self._find_ppi_sequence_columns(test_set.column_names)
try:
lbl_col = self._find_first_present_column(train_set.column_names, label_candidates)
except KeyError:
try:
lbl_col = self._find_first_present_column(valid_set.column_names, label_candidates)
except KeyError:
lbl_col = self._find_first_present_column(test_set.column_names, label_candidates)
train_set = train_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'})
valid_set = valid_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'})
test_set = test_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'})
print('Removing extras')
train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['SeqA', 'SeqB', 'labels']])
valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['SeqA', 'SeqB', 'labels']])
test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['SeqA', 'SeqB', 'labels']])
else:
print('Standardizing column names')
use_multi = self.data_args.multi_column is not None
if not use_multi:
try:
seq_col = self._find_first_present_column(train_set.column_names, seq_candidates)
except KeyError:
try:
seq_col = self._find_first_present_column(valid_set.column_names, seq_candidates)
except KeyError:
seq_col = self._find_first_present_column(test_set.column_names, seq_candidates)
try:
label_col = self._find_first_present_column(train_set.column_names, label_candidates)
except KeyError:
try:
label_col = self._find_first_present_column(valid_set.column_names, label_candidates)
except KeyError:
label_col = self._find_first_present_column(test_set.column_names, label_candidates)
# Always standardize label column to 'labels'
train_set = train_set.rename_columns({label_col: 'labels'})
valid_set = valid_set.rename_columns({label_col: 'labels'})
test_set = test_set.rename_columns({label_col: 'labels'})
if not use_multi:
train_set = train_set.rename_columns({seq_col: 'seqs'})
valid_set = valid_set.rename_columns({seq_col: 'seqs'})
test_set = test_set.rename_columns({seq_col: 'seqs'})
# drop everything else
print('Removing extras')
train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['seqs', 'labels']])
valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['seqs', 'labels']])
test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['seqs', 'labels']])
else:
# Validate requested multi columns exist
for col in self.data_args.multi_column:
assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}"
# Keep only requested columns and labels
keep_cols = set(self.data_args.multi_column + ['labels'])
train_set = train_set.remove_columns([col for col in train_set.column_names if col not in keep_cols])
valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in keep_cols])
test_set = test_set.remove_columns([col for col in test_set.column_names if col not in keep_cols])
datasets.append((train_set, valid_set, test_set, ppi))
data_names.append(data_name)
return self.process_datasets(hf_datasets=datasets, data_names=data_names)
    def get_embedding_dim_sql(self, save_path, test_seq, tokenizer):
test_seq_len = len(tokenizer(test_seq, return_tensors='pt')['input_ids'][0])
with sqlite3.connect(save_path) as conn:
c = conn.cursor()
c.execute("SELECT embedding FROM embeddings WHERE sequence = ?", (test_seq,))
test_embedding = c.fetchone()[0]
test_embedding = torch.tensor(np.frombuffer(test_embedding, dtype=np.float32).reshape(1, -1))
if self._full:
test_embedding = test_embedding.reshape(test_seq_len, -1)
embedding_dim = test_embedding.shape[-1]
return embedding_dim
def get_embedding_dim_pth(self, emb_dict, test_seq, tokenizer):
test_seq_len = len(tokenizer(test_seq, return_tensors='pt')['input_ids'][0])
test_embedding = emb_dict[test_seq]
if self._full:
test_embedding = test_embedding.reshape(test_seq_len, -1)
embedding_dim = test_embedding.shape[-1]
return embedding_dim
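    # Example (hypothetical stored shapes): with self._full == True, a buffer
    # saved flat as (seq_len * dim,) is reshaped to (seq_len, dim) using the
    # tokenizer's length, so both helpers report the per-token hidden size
    # rather than the flattened size.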
def build_vector_numpy_dataset_from_embeddings(
self,
model_name,
train_seqs,
valid_seqs,
test_seqs,
):
save_dir = self.embedding_args.embedding_save_dir
train_array, valid_array, test_array = [], [], []
        # Pooling types from embedding_args; fall back to ['mean'] if the attribute is absent
        pooling_types = getattr(self.embedding_args, 'pooling_types', ['mean'])
        if self._sql:
            filename = get_embedding_filename(model_name, self._full, pooling_types, 'db')
save_path = os.path.join(save_dir, filename)
with sqlite3.connect(save_path) as conn:
c = conn.cursor()
for seq in train_seqs:
embedding = self._select_from_sql(c, seq, cast_to_torch=False)
train_array.append(embedding)
for seq in valid_seqs:
embedding = self._select_from_sql(c, seq, cast_to_torch=False)
valid_array.append(embedding)
for seq in test_seqs:
embedding = self._select_from_sql(c, seq, cast_to_torch=False)
test_array.append(embedding)
else:
filename = get_embedding_filename(model_name, self._full, pooling_types, 'pth')
save_path = os.path.join(save_dir, filename)
emb_dict = torch.load(save_path)
for seq in train_seqs:
embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True)
train_array.append(embedding)
for seq in valid_seqs:
embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True)
valid_array.append(embedding)
for seq in test_seqs:
embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True)
test_array.append(embedding)
del emb_dict
        if self._full:
            # full (per-residue) embeddings are (len, dim) with variable lengths;
            # average each sequence over its length so everything stacks to (N, dim)
            train_array = [emb.mean(axis=0, keepdims=True) for emb in train_array]
            valid_array = [emb.mean(axis=0, keepdims=True) for emb in valid_array]
            test_array = [emb.mean(axis=0, keepdims=True) for emb in test_array]
        train_array = np.concatenate(train_array, axis=0)
        valid_array = np.concatenate(valid_array, axis=0)
        test_array = np.concatenate(test_array, axis=0)
print_message('Numpy dataset shapes')
print_message(f'Train: {train_array.shape}')
print_message(f'Valid: {valid_array.shape}')
print_message(f'Test: {test_array.shape}')
return train_array, valid_array, test_array
def build_pair_vector_numpy_dataset_from_embeddings(
self,
model_name,
train_seqs_a,
train_seqs_b,
valid_seqs_a,
valid_seqs_b,
test_seqs_a,
test_seqs_b,
):
save_dir = self.embedding_args.embedding_save_dir
train_array, valid_array, test_array = [], [], []
pooling_types = self.embedding_args.pooling_types
        def _pool(embedding):
            # full (per-residue) embeddings are (len, dim); mean-pool to (1, dim)
            # so pairs of different lengths can be concatenated feature-wise
            return embedding.mean(axis=0, keepdims=True) if self._full else embedding
        if self._sql:
            filename = get_embedding_filename(model_name, self._full, pooling_types, 'db')
            save_path = os.path.join(save_dir, filename)
            with sqlite3.connect(save_path) as conn:
                c = conn.cursor()
                for seq_a, seq_b in zip(train_seqs_a, train_seqs_b):
                    seq_a, seq_b = self._random_order(seq_a, seq_b)
                    embedding_a = _pool(self._select_from_sql(c, seq_a, cast_to_torch=False))
                    embedding_b = _pool(self._select_from_sql(c, seq_b, cast_to_torch=False))
                    train_array.append(np.concatenate([embedding_a, embedding_b], axis=-1))
                for seq_a, seq_b in zip(valid_seqs_a, valid_seqs_b):
                    seq_a, seq_b = self._random_order(seq_a, seq_b)
                    embedding_a = _pool(self._select_from_sql(c, seq_a, cast_to_torch=False))
                    embedding_b = _pool(self._select_from_sql(c, seq_b, cast_to_torch=False))
                    valid_array.append(np.concatenate([embedding_a, embedding_b], axis=-1))
                for seq_a, seq_b in zip(test_seqs_a, test_seqs_b):
                    seq_a, seq_b = self._random_order(seq_a, seq_b)
                    embedding_a = _pool(self._select_from_sql(c, seq_a, cast_to_torch=False))
                    embedding_b = _pool(self._select_from_sql(c, seq_b, cast_to_torch=False))
                    test_array.append(np.concatenate([embedding_a, embedding_b], axis=-1))
        else:
            filename = get_embedding_filename(model_name, self._full, pooling_types, 'pth')
            save_path = os.path.join(save_dir, filename)
            emb_dict = torch.load(save_path)
            for seq_a, seq_b in zip(train_seqs_a, train_seqs_b):
                seq_a, seq_b = self._random_order(seq_a, seq_b)
                embedding_a = _pool(self._select_from_pth(emb_dict, seq_a, cast_to_np=True))
                embedding_b = _pool(self._select_from_pth(emb_dict, seq_b, cast_to_np=True))
                train_array.append(np.concatenate([embedding_a, embedding_b], axis=-1))
            for seq_a, seq_b in zip(valid_seqs_a, valid_seqs_b):
                seq_a, seq_b = self._random_order(seq_a, seq_b)
                embedding_a = _pool(self._select_from_pth(emb_dict, seq_a, cast_to_np=True))
                embedding_b = _pool(self._select_from_pth(emb_dict, seq_b, cast_to_np=True))
                valid_array.append(np.concatenate([embedding_a, embedding_b], axis=-1))
            for seq_a, seq_b in zip(test_seqs_a, test_seqs_b):
                seq_a, seq_b = self._random_order(seq_a, seq_b)
                embedding_a = _pool(self._select_from_pth(emb_dict, seq_a, cast_to_np=True))
                embedding_b = _pool(self._select_from_pth(emb_dict, seq_b, cast_to_np=True))
                test_array.append(np.concatenate([embedding_a, embedding_b], axis=-1))
            del emb_dict
        train_array = np.concatenate(train_array, axis=0)
        valid_array = np.concatenate(valid_array, axis=0)
        test_array = np.concatenate(test_array, axis=0)
print_message('Numpy dataset shapes')
print_message(f'Train: {train_array.shape}')
print_message(f'Valid: {valid_array.shape}')
print_message(f'Test: {test_array.shape}')
return train_array, valid_array, test_array
def prepare_scikit_dataset(self, model_name, dataset):
train_set, valid_set, test_set, _, label_type, ppi = dataset
if ppi:
X_train, X_valid, X_test = self.build_pair_vector_numpy_dataset_from_embeddings(
model_name,
list(train_set['SeqA']),
list(train_set['SeqB']),
list(valid_set['SeqA']),
list(valid_set['SeqB']),
list(test_set['SeqA']),
list(test_set['SeqB']),
)
else:
X_train, X_valid, X_test = self.build_vector_numpy_dataset_from_embeddings(
model_name,
list(train_set['seqs']),
list(valid_set['seqs']),
list(test_set['seqs']),
)
y_train = self._labels_to_numpy(list(train_set['labels']))
y_valid = self._labels_to_numpy(list(valid_set['labels']))
y_test = self._labels_to_numpy(list(test_set['labels']))
print_message('Numpy dataset shapes with labels')
print_message(f'Train: {X_train.shape}, {y_train.shape}')
print_message(f'Valid: {X_valid.shape}, {y_valid.shape}')
print_message(f'Test: {X_test.shape}, {y_test.shape}')
return X_train, y_train, X_valid, y_valid, X_test, y_test, label_type
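# End-to-end sketch (class and dataset names are hypothetical; the subclass
# must also provide embedding_args plus the _sql/_full flags used above):
#   class Runner(DataMixin):
#       ...
#   runner = Runner(DataArguments(data_names=['my_dataset']))
#   datasets, all_seqs = runner.get_data()
#   X_tr, y_tr, X_va, y_va, X_te, y_te, label_type = runner.prepare_scikit_dataset(
#       'model_name', datasets['my_dataset'])
# The returned arrays feed directly into scikit-learn estimators.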