| | import torch |
| | import numpy as np |
| | import random |
| | import os |
| | import sqlite3 |
| | from typing import List, Tuple, Dict, Optional |
| | from glob import glob |
| | from pandas import read_csv, read_excel |
| | from datasets import load_dataset, Dataset |
| | from dataclasses import dataclass |
| |
|
| | try: |
| | from utils import print_message |
| | from seed_utils import get_global_seed |
| | from embedder import get_embedding_filename |
| | except ImportError: |
| | from ..utils import print_message |
| | from ..seed_utils import get_global_seed |
| | from ..embedder import get_embedding_filename |
| | from .supported_datasets import supported_datasets, standard_data_benchmark, vector_benchmark |
| | from .utils import ( |
| | AA_SET, |
| | CODON_SET, |
| | DNA_SET, |
| | RNA_SET, |
| | NONCANONICAL_AMINO_ACIDS, |
| | AMINO_ACID_TO_HUMAN_CODON, |
| | NONCANONICAL_ALANINE_CODON, |
| | AA_TO_CODON_TOKEN, |
| | CODON_TO_AA, |
| | DNA_CODON_TO_AA, |
| | RNA_CODON_TO_AA, |
| | ) |
| |
|
| |
|
| |
|
| |
|
@dataclass
class DataArguments:
    """
    Arguments controlling which datasets are loaded and how sequences are
    preprocessed.

    Args:
        data_names: List[str]
            dataset names: keys of ``supported_datasets``, the sentinels
            'standard_benchmark' / 'vector_benchmark' / 'protein_gym', or raw
            huggingface paths (attempted as-is with a warning)
        delimiter: str
            delimiter used when reading local csv/tsv/txt split files
        col_names: List[str]
            column names of interest (defaults to ['seqs', 'labels'])
        max_length: int
            max length of sequences
        trim: bool
            if True, over-long sequences are dropped downstream;
            if False, they are truncated to max_length
        data_dirs: Optional[List[str]]
            local directories containing train.*/valid.*/test.* splits;
            each must exist or FileNotFoundError is raised
        multi_column: Optional[List[str]]
            when set, all of these columns are treated as sequence inputs
        aa_to_dna, aa_to_rna, dna_to_aa, rna_to_aa, codon_to_aa, aa_to_codon: bool
            mutually exclusive sequence translation modes (enforced downstream)
        **kwargs:
            ignored; accepted so extra config keys do not raise
    """
    # NOTE: @dataclass does not overwrite an explicitly defined __init__, so
    # the decorator is effectively inert here; kept for interface stability.
    def __init__(
        self,
        data_names: List[str],
        delimiter: str = ',',
        col_names: Optional[List[str]] = None,
        max_length: int = 1024,
        trim: bool = False,
        data_dirs: Optional[List[str]] = None,
        multi_column: Optional[List[str]] = None,
        aa_to_dna: bool = False,
        aa_to_rna: bool = False,
        dna_to_aa: bool = False,
        rna_to_aa: bool = False,
        codon_to_aa: bool = False,
        aa_to_codon: bool = False,
        **kwargs
    ):
        # None sentinels replace the original mutable default arguments
        # ([] and ['seqs', 'labels']), which were shared between instances.
        self.data_names = data_names
        self.data_dirs = [] if data_dirs is None else data_dirs
        self.delimiter = delimiter
        self.col_names = ['seqs', 'labels'] if col_names is None else col_names
        self.max_length = max_length
        self.trim = trim
        self.protein_gym = False
        self.multi_column = multi_column
        self.aa_to_dna = aa_to_dna
        self.aa_to_rna = aa_to_rna
        self.dna_to_aa = dna_to_aa
        self.rna_to_aa = rna_to_aa
        self.codon_to_aa = codon_to_aa
        self.aa_to_codon = aa_to_codon

        if len(data_names) > 0:
            if data_names[0] == 'standard_benchmark':
                self.data_paths = [supported_datasets[data_name] for data_name in standard_data_benchmark]
            elif data_names[0] == 'vector_benchmark':
                self.data_paths = [supported_datasets[data_name] for data_name in vector_benchmark]
            else:
                self.data_paths = []
                for data_name in data_names:
                    if data_name == 'protein_gym':
                        # protein_gym is a flag, not a resolvable path
                        self.protein_gym = True
                        continue
                    if data_name in supported_datasets:
                        self.data_paths.append(supported_datasets[data_name])
                    else:
                        print(f'{data_name} not found in supported datasets')
                        print('We will attempt to load it from huggingface anyways, but this may not work')
                        self.data_paths.append(data_name)
        else:
            self.data_paths = []

        # Fail fast on missing local data directories (self.data_dirs is
        # always a list now, so no None check is needed).
        for data_dir in self.data_dirs:
            if not os.path.exists(data_dir):
                raise FileNotFoundError(f'{data_dir} does not exist')
| |
|
| |
|
| | class DataMixin: |
| | def __init__(self, data_args: Optional[DataArguments] = None): |
| | |
| | self._sql = False |
| | self._full = False |
| | self._max_length = 1024 |
| | self._trim = False |
| | self._delimiter = ',' |
| | self._col_names = ['seqs', 'labels'] |
| | self._aa_to_dna = False |
| | self._aa_to_rna = False |
| | self._dna_to_aa = False |
| | self._rna_to_aa = False |
| | self._codon_to_aa = False |
| | self._aa_to_codon = False |
| | self.data_args = data_args |
| | self._multi_column = None if data_args is None else getattr(data_args, 'multi_column', None) |
| | if data_args is not None: |
| | self._aa_to_dna = data_args.aa_to_dna |
| | self._aa_to_rna = data_args.aa_to_rna |
| | self._dna_to_aa = data_args.dna_to_aa |
| | self._rna_to_aa = data_args.rna_to_aa |
| | self._codon_to_aa = data_args.codon_to_aa |
| | self._aa_to_codon = data_args.aa_to_codon |
| |
|
| | def _not_regression(self, labels): |
| | if isinstance(labels, list): |
| | |
| | if isinstance(labels[0], list): |
| | |
| | return all(isinstance(element, (int, float)) and element == int(element) |
| | for label in labels for element in label) |
| | else: |
| | |
| | return all(isinstance(label, (int, float)) and label == int(label) |
| | for label in labels) |
| | else: |
| | |
| | return all(isinstance(label, (int, float)) and label == int(label) for label in labels) |
| |
|
| | def _encode_labels(self, labels, tag2id): |
| | return [torch.tensor([tag2id[tag] for tag in doc], dtype=torch.long) for doc in labels] |
| |
|
| | def _label_type_checker(self, labels): |
| | ex = labels[0] |
| | assert len(labels) > 0, f'Labels is empty: {labels}' |
| | if self._not_regression(labels): |
| | if isinstance(ex, list): |
| | label_type = 'multilabel' |
| | elif isinstance(ex, int) or isinstance(ex, float): |
| | label_type = 'singlelabel' |
| | elif isinstance(ex, str): |
| | label_type = 'string' |
| | else: |
| | label_type = 'regression' |
| | return label_type |
| |
|
| | def _is_sigmoid_regression(self, labels) -> bool: |
| | """Heuristic: labels within [0, 1] and cover the range approximately. |
| | Uses 10-bin histogram coverage and span threshold. |
| | """ |
| | arr = [] |
| | for label in labels: |
| | try: |
| | arr.extend(label) |
| | except: |
| | arr.append(label) |
| | arr = np.array(arr, dtype=float).flatten() |
| |
|
| | min_val, max_val = float(arr.min()), float(arr.max()) |
| | cond1 = min_val > 0.0 - 1e-6 and max_val < 1.0 + 1e-6 |
| |
|
| | |
| | cond2 = (max_val - min_val) > 0.75 |
| |
|
| | |
| | hist, _ = np.histogram(arr, bins=10, range=(0.0, 1.0)) |
| | cond3 = int((hist > 0).sum()) >= 7 |
| |
|
| | sigmoid_regression_status = cond1 and cond2 and cond3 |
| | return sigmoid_regression_status |
| |
|
| | def _select_from_sql(self, c, seq, cast_to_torch=True): |
| | c.execute("SELECT embedding FROM embeddings WHERE sequence = ?", (seq,)) |
| | embedding = np.frombuffer(c.fetchone()[0], dtype=np.float32).reshape(1, -1) |
| | if self._full: |
| | embedding = embedding.reshape(len(seq), -1) |
| | if cast_to_torch: |
| | embedding = torch.tensor(embedding) |
| | return embedding |
| |
|
| | def _select_from_pth(self, emb_dict, seq, cast_to_np=False): |
| | embedding = emb_dict[seq].reshape(1, -1) |
| | if self._full: |
| | embedding = embedding.reshape(len(seq), -1) |
| | if cast_to_np: |
| | embedding = embedding.numpy() |
| | return embedding |
| |
|
| | def _labels_to_numpy(self, labels): |
| | if isinstance(labels[0], list): |
| | return np.array(labels).flatten() |
| | else: |
| | return np.array([labels]).flatten() |
| |
|
| | def _random_order(self, seq_a, seq_b): |
| | if random.random() < 0.5: |
| | return seq_a, seq_b |
| | else: |
| | return seq_b, seq_a |
| |
|
| | def _truncate_pairs(self, ex): |
| | |
| | seq_a, seq_b = ex['SeqA'], ex['SeqB'] |
| | trunc_a, trunc_b = seq_a, seq_b |
| | while len(trunc_a) + len(trunc_b) > self._max_length: |
| | if len(trunc_a) > len(trunc_b): |
| | trunc_a = trunc_a[:-1] |
| | else: |
| | trunc_b = trunc_b[:-1] |
| | ex['SeqA'] = trunc_a |
| | ex['SeqB'] = trunc_b |
| | return ex |
| |
|
| | def _active_translation_mode(self): |
| | mode_to_flag = { |
| | 'aa_to_dna': self._aa_to_dna, |
| | 'aa_to_rna': self._aa_to_rna, |
| | 'dna_to_aa': self._dna_to_aa, |
| | 'rna_to_aa': self._rna_to_aa, |
| | 'codon_to_aa': self._codon_to_aa, |
| | 'aa_to_codon': self._aa_to_codon, |
| | } |
| | active_modes = [mode for mode, enabled in mode_to_flag.items() if enabled] |
| | assert len(active_modes) <= 1, f'Only one translation mode can be enabled at a time, found: {active_modes}' |
| | return active_modes[0] if len(active_modes) == 1 else None |
| |
|
| | def _assert_characters_in_set(self, seq, allowed_chars, mode): |
| | bad_chars = sorted({char for char in seq if char.upper() not in allowed_chars}) |
| | assert len(bad_chars) == 0, f'Invalid characters for {mode}: {bad_chars}.' |
| |
|
| | def _validate_translated_output(self, translated_seq, allowed_chars, mode): |
| | bad_chars = sorted({char for char in translated_seq if char not in allowed_chars}) |
| | assert len(bad_chars) == 0, f'Translation output for {mode} contains unexpected characters: {bad_chars}.' |
| |
|
| | def _normalize_aa_for_nucleotide_translation(self, seq): |
| | canonical_aas = set(AMINO_ACID_TO_HUMAN_CODON.keys()) |
| | normalized = [] |
| | for residue in seq: |
| | residue = residue.upper() |
| | if residue in canonical_aas: |
| | normalized.append(residue) |
| | else: |
| | normalized.append('X') |
| | return ''.join(normalized) |
| |
|
| | def _translate_aa_to_dna(self, seq): |
| | seq = self._normalize_aa_for_nucleotide_translation(seq) |
| | dna_codons = [] |
| | for residue in seq: |
| | residue = residue.upper() |
| | if residue in AMINO_ACID_TO_HUMAN_CODON: |
| | dna_codons.append(AMINO_ACID_TO_HUMAN_CODON[residue]) |
| | elif residue in NONCANONICAL_AMINO_ACIDS: |
| | dna_codons.append(NONCANONICAL_ALANINE_CODON) |
| | else: |
| | raise AssertionError(f'Unexpected amino acid token "{residue}" while converting aa_to_dna.') |
| | translated = ''.join(dna_codons) |
| | self._validate_translated_output(translated, DNA_SET, 'aa_to_dna') |
| | return translated |
| |
|
| | def _translate_aa_to_rna(self, seq): |
| | dna_translated = self._translate_aa_to_dna(seq) |
| | translated = dna_translated.replace('T', 'U') |
| | self._validate_translated_output(translated, RNA_SET, 'aa_to_rna') |
| | return translated |
| |
|
| | def _translate_dna_to_aa(self, seq): |
| | dna_seq = seq.upper() |
| | self._assert_characters_in_set(dna_seq, DNA_SET, 'dna_to_aa') |
| | assert len(dna_seq) % 3 == 0, f'dna_to_aa requires sequence length multiple of 3, got {len(dna_seq)}.' |
| | aa_seq = [] |
| | for idx in range(0, len(dna_seq), 3): |
| | codon = dna_seq[idx:idx + 3] |
| | assert codon in DNA_CODON_TO_AA, f'Unknown DNA codon for dna_to_aa: {codon}' |
| | translated_char = DNA_CODON_TO_AA[codon] |
| | if translated_char != '*': |
| | aa_seq.append(translated_char) |
| | translated = ''.join(aa_seq) |
| | self._validate_translated_output(translated, AA_SET - {'*'}, 'dna_to_aa') |
| | return translated |
| |
|
| | def _translate_rna_to_aa(self, seq): |
| | rna_seq = seq.upper() |
| | self._assert_characters_in_set(rna_seq, RNA_SET, 'rna_to_aa') |
| | assert len(rna_seq) % 3 == 0, f'rna_to_aa requires sequence length multiple of 3, got {len(rna_seq)}.' |
| | aa_seq = [] |
| | for idx in range(0, len(rna_seq), 3): |
| | codon = rna_seq[idx:idx + 3] |
| | assert codon in RNA_CODON_TO_AA, f'Unknown RNA codon for rna_to_aa: {codon}' |
| | translated_char = RNA_CODON_TO_AA[codon] |
| | if translated_char != '*': |
| | aa_seq.append(translated_char) |
| | translated = ''.join(aa_seq) |
| | self._validate_translated_output(translated, AA_SET - {'*'}, 'rna_to_aa') |
| | return translated |
| |
|
| | def _translate_codon_to_aa(self, seq): |
| | aa_seq = [] |
| | for token in seq: |
| | assert token in CODON_TO_AA, f'Unknown codon token for codon_to_aa: {token}' |
| | translated_char = CODON_TO_AA[token] |
| | if translated_char != '*': |
| | aa_seq.append(translated_char) |
| | translated = ''.join(aa_seq) |
| | self._validate_translated_output(translated, AA_SET - {'*'}, 'codon_to_aa') |
| | return translated |
| |
|
| | def _translate_aa_to_codon(self, seq): |
| | codon_tokens = [] |
| | for residue in seq: |
| | residue = residue.upper() |
| | if residue in AA_TO_CODON_TOKEN: |
| | codon_tokens.append(AA_TO_CODON_TOKEN[residue]) |
| | elif residue in NONCANONICAL_AMINO_ACIDS: |
| | codon_tokens.append(AA_TO_CODON_TOKEN['A']) |
| | else: |
| | raise AssertionError(f'Unexpected amino acid token "{residue}" while converting aa_to_codon.') |
| | translated = ''.join(codon_tokens) |
| | self._validate_translated_output(translated, CODON_SET, 'aa_to_codon') |
| | return translated |
| |
|
| | def _translate_sequence_for_mode(self, seq, mode): |
| | if mode == 'aa_to_dna': |
| | return self._translate_aa_to_dna(seq) |
| | if mode == 'aa_to_rna': |
| | return self._translate_aa_to_rna(seq) |
| | if mode == 'dna_to_aa': |
| | return self._translate_dna_to_aa(seq) |
| | if mode == 'rna_to_aa': |
| | return self._translate_rna_to_aa(seq) |
| | if mode == 'codon_to_aa': |
| | return self._translate_codon_to_aa(seq) |
| | if mode == 'aa_to_codon': |
| | return self._translate_aa_to_codon(seq) |
| | raise AssertionError(f'Unsupported translation mode: {mode}') |
| |
|
| | def _find_first_present_column(self, available_columns, candidates_ordered): |
| | """Return the first column from candidates_ordered that exists in available_columns (case-insensitive).""" |
| | lowercase_to_actual = {col.lower(): col for col in available_columns} |
| | for candidate in candidates_ordered: |
| | actual = lowercase_to_actual.get(candidate.lower()) |
| | if actual is not None: |
| | return actual |
| | raise KeyError(f"None of the candidate columns were found. Candidates: {candidates_ordered}. Available: {available_columns}") |
| |
|
| | def _is_ppi_from_columns(self, available_columns): |
| | """Detect if dataset contains paired sequence inputs (SeqA/SeqB variants).""" |
| | lowercase_columns = set(col.lower() for col in available_columns) |
| | base_candidates = ['seqs', 'seq', 'sequence', 'sequences'] |
| | for base in base_candidates: |
| | if (base + 'a') in lowercase_columns and (base + 'b') in lowercase_columns: |
| | return True |
| | return False |
| |
|
| | def _find_ppi_sequence_columns(self, available_columns): |
| | """Return the actual column names for A and B sequences in PPI datasets based on priority.""" |
| | lowercase_to_actual = {col.lower(): col for col in available_columns} |
| | |
| | specific_pairs = [ |
| | ('SeqA', 'SeqB'), |
| | ('seqa', 'seqb'), |
| | ('SeqsA', 'SeqsB'), |
| | ] |
| | for cand_a, cand_b in specific_pairs: |
| | a_actual = lowercase_to_actual.get(cand_a.lower()) |
| | b_actual = lowercase_to_actual.get(cand_b.lower()) |
| | if a_actual is not None and b_actual is not None: |
| | return a_actual, b_actual |
| |
|
| | |
| | base_candidates = ['seqs', 'seq', 'sequence', 'sequences'] |
| | for base in base_candidates: |
| | a_key = (base + 'a').lower() |
| | b_key = (base + 'b').lower() |
| | a_actual = lowercase_to_actual.get(a_key) |
| | b_actual = lowercase_to_actual.get(b_key) |
| | if a_actual is not None and b_actual is not None: |
| | return a_actual, b_actual |
| |
|
| | raise KeyError(f"Could not find paired sequence columns for PPI. Available: {available_columns}") |
| |
|
| | def _is_missing_value(self, v): |
| | if v is None: |
| | return True |
| | |
| | try: |
| | if isinstance(v, float) and np.isnan(v): |
| | return True |
| | except Exception: |
| | pass |
| | |
| | if isinstance(v, (list, tuple, np.ndarray)): |
| | for el in v: |
| | if el is None: |
| | return True |
| | if isinstance(el, float) and np.isnan(el): |
| | return True |
| | return False |
| |
|
    def process_datasets(
        self,
        hf_datasets: List[Tuple[Dataset, Dataset, Dataset, bool]],
        data_names: List[str],
    )-> Tuple[Dict[str, Tuple[Dataset, Dataset, Dataset, int, str, bool]], List[str]]:
        """Clean, length-limit, optionally translate, and label-type each dataset.

        Args:
            hf_datasets: one (train, valid, test, ppi) tuple per dataset, where
                ppi marks paired-sequence (SeqA/SeqB) datasets.
            data_names: dataset names aligned with hf_datasets.

        Returns:
            datasets: name -> (train, valid, test, num_labels, label_type, ppi).
            all_seqs: all unique sequences seen, sorted longest first
                (presumably to help batching/embedding downstream — not
                established by this method itself).
        """
        max_length = self._max_length
        datasets, all_seqs = {}, set()
        # At most one translation mode may be enabled (asserted in the helper).
        translation_mode = self._active_translation_mode()
        for dataset, data_name in zip(hf_datasets, data_names):
            print_message(f'Processing {data_name}')
            train_set, valid_set, test_set, ppi = dataset
            print(train_set)
            print(valid_set)
            print(test_set)

            # --- 1) Drop rows with missing (None/NaN) sequences or labels ---
            before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set)
            if ppi:
                train_set = train_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels'])))
                valid_set = valid_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels'])))
                test_set = test_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels'])))
            elif self.data_args.multi_column:
                cols = self.data_args.multi_column
                # each requested column must exist in at least one split
                for col in cols:
                    assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}"

                def _filter_row(x):
                    # keep rows where the label and every sequence column are present
                    return (not self._is_missing_value(x['labels'])) and all(not self._is_missing_value(x[col]) for col in cols)

                train_set = train_set.filter(_filter_row)
                valid_set = valid_set.filter(_filter_row)
                test_set = test_set.filter(_filter_row)
            else:
                train_set = train_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels'])))
                valid_set = valid_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels'])))
                test_set = test_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels'])))
            if any([
                len(train_set) != before_train,
                len(valid_set) != before_valid,
                len(test_set) != before_test,
            ]):
                print_message(
                    f"Removed None / NaN rows - train: {before_train - len(train_set)}, valid: {before_valid - len(valid_set)}, test: {before_test - len(test_set)}"
                )

            # --- 2) Strip characters outside AA_SET, but only when no
            #        translation mode is active (translators validate their
            #        own alphabets and may consume non-AA characters). ---
            if translation_mode is None:
                if ppi:
                    train_set = train_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET),
                                                         'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)})
                    valid_set = valid_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET),
                                                         'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)})
                    test_set = test_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET),
                                                       'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)})
                elif self.data_args.multi_column:
                    cols = self.data_args.multi_column
                    for col in cols:
                        # _col=col binds the loop variable (avoids the
                        # late-binding-closure pitfall inside .map lambdas)
                        train_set = train_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)})
                        valid_set = valid_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)})
                        test_set = test_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)})
                else:
                    train_set = train_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)})
                    valid_set = valid_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)})
                    test_set = test_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)})

            # --- 3) Drop rows whose sequences became (or were) empty ---
            before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set)
            if ppi:
                train_set = train_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0)
                valid_set = valid_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0)
                test_set = test_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0)
            elif self.data_args.multi_column:
                cols = self.data_args.multi_column
                train_set = train_set.filter(lambda x: all(len(x[col]) > 0 for col in cols))
                valid_set = valid_set.filter(lambda x: all(len(x[col]) > 0 for col in cols))
                test_set = test_set.filter(lambda x: all(len(x[col]) > 0 for col in cols))
            else:
                train_set = train_set.filter(lambda x: len(x['seqs']) > 0)
                valid_set = valid_set.filter(lambda x: len(x['seqs']) > 0)
                test_set = test_set.filter(lambda x: len(x['seqs']) > 0)

            if any([
                len(train_set) != before_train,
                len(valid_set) != before_valid,
                len(test_set) != before_test,
            ]):
                print_message(
                    f"Removed length 0 rows - train: {before_train - len(train_set)}, valid: {before_valid - len(valid_set)}, test: {before_test - len(test_set)}"
                )

            # --- 4) Enforce max_length: trim mode DROPS over-long rows,
            #        otherwise sequences are TRUNCATED in place. ---
            before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set)
            if self._trim:
                if ppi:
                    # for pairs, the combined length must fit
                    train_set = train_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length)
                    valid_set = valid_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length)
                    test_set = test_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length)
                elif self.data_args.multi_column:
                    cols = self.data_args.multi_column
                    train_set = train_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols))
                    valid_set = valid_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols))
                    test_set = test_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols))
                else:
                    train_set = train_set.filter(lambda x: len(x['seqs']) <= max_length)
                    valid_set = valid_set.filter(lambda x: len(x['seqs']) <= max_length)
                    test_set = test_set.filter(lambda x: len(x['seqs']) <= max_length)

            else:
                if ppi:
                    # greedy pairwise truncation balances SeqA/SeqB lengths
                    train_set = train_set.map(self._truncate_pairs)
                    valid_set = valid_set.map(self._truncate_pairs)
                    test_set = test_set.map(self._truncate_pairs)
                elif self.data_args.multi_column:
                    cols = self.data_args.multi_column
                    for col in cols:
                        train_set = train_set.map(lambda x, _col=col: { _col: x[_col][:max_length] })
                        valid_set = valid_set.map(lambda x, _col=col: { _col: x[_col][:max_length] })
                        test_set = test_set.map(lambda x, _col=col: { _col: x[_col][:max_length] })
                else:
                    train_set = train_set.map(lambda x: {'seqs': x['seqs'][:max_length]})
                    valid_set = valid_set.map(lambda x: {'seqs': x['seqs'][:max_length]})
                    test_set = test_set.map(lambda x: {'seqs': x['seqs'][:max_length]})

            # NOTE(review): in the truncation (non-trim) branch row counts do
            # not change, so this message only fires in trim mode.
            if any([
                len(train_set) != before_train,
                len(valid_set) != before_valid,
                len(test_set) != before_test,
            ]):
                print_message(
                    f"Trimmed rows - train: {(before_train - len(train_set)) / before_train * 100:.2f}%, \
                    valid: {(before_valid - len(valid_set)) / before_valid * 100:.2f}%, \
                    test: {(before_test - len(test_set)) / before_test * 100:.2f}%"
                )

            # --- 5) Apply the translation mode AFTER length handling, so
            #        limits apply to the pre-translation sequences. ---
            if translation_mode is not None:
                if ppi:
                    train_set = train_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode),
                                                         'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)})
                    valid_set = valid_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode),
                                                         'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)})
                    test_set = test_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode),
                                                       'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)})
                elif self.data_args.multi_column:
                    cols = self.data_args.multi_column
                    for col in cols:
                        train_set = train_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)})
                        valid_set = valid_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)})
                        test_set = test_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)})
                else:
                    train_set = train_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)})
                    valid_set = valid_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)})
                    test_set = test_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)})
                print_message(f"Translated sequences with mode {translation_mode} (post-trim/truncate).")

            # --- 6) Collect every unique sequence across all splits ---
            if ppi:
                all_seqs.update(list(train_set['SeqA']) + list(train_set['SeqB']))
                all_seqs.update(list(valid_set['SeqA']) + list(valid_set['SeqB']))
                all_seqs.update(list(test_set['SeqA']) + list(test_set['SeqB']))
            elif self.data_args.multi_column:
                cols = self.data_args.multi_column
                for col in cols:
                    all_seqs.update(list(train_set[col]))
                    all_seqs.update(list(valid_set[col]))
                    all_seqs.update(list(test_set[col]))
            else:
                all_seqs.update(list(train_set['seqs']))
                all_seqs.update(list(valid_set['seqs']))
                all_seqs.update(list(test_set['seqs']))

            # --- 7) Determine label type (inspects the validation split) ---
            check_labels = list(valid_set['labels'])
            label_type = self._label_type_checker(check_labels)

            # Stringified Python lists (e.g. "[0, 1, 0]") are parsed back into
            # real lists and re-typed as multilabel; any parse failure leaves
            # the labels as strings.
            if label_type == 'string':
                example = list(valid_set['labels'])[0]
                try:
                    import ast
                    new_ex = ast.literal_eval(example)
                    if isinstance(new_ex, list):
                        label_type = 'multilabel'
                        train_set = train_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])})
                        valid_set = valid_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])})
                        test_set = test_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])})
                except:
                    label_type = 'string'

            if label_type == 'string':
                # Genuine string labels: build a tag vocabulary from the train
                # split and encode per-position ids ('tokenwise' labels).
                train_labels = list(train_set['labels'])
                unique_tags = set(tag for doc in train_labels for tag in doc)
                tag2id = {tag: id for id, tag in enumerate(sorted(unique_tags))}
                # NOTE(review): _encode_labels expects a list of docs but each
                # ex['labels'] here is a single doc — confirm intended shape.
                train_set = train_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)})
                valid_set = valid_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)})
                test_set = test_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)})
                label_type = 'tokenwise'
                num_labels = len(unique_tags)
            else:
                if label_type == 'regression':
                    # values confined to [0, 1] with broad coverage are treated
                    # as sigmoid-regression targets
                    if self._is_sigmoid_regression(list(train_set['labels'])):
                        label_type = 'sigmoid_regression'
                    num_labels = 1
                else:
                    try:
                        # multilabel: label vector length is the label count
                        train_labels_list = list(train_set['labels'])
                        num_labels = len(train_labels_list[0])
                    except:
                        # singlelabel scalars: len() fails, so count classes
                        # as 0..max_label inclusive
                        unique = np.unique(list(train_set['labels']))
                        max_label = max(unique)
                        full_list = np.arange(0, max_label+1)
                        num_labels = len(full_list)
            datasets[data_name] = (train_set, valid_set, test_set, num_labels, label_type, ppi)

            print(f'Label type: {label_type}')
            print(f'Number of labels: {num_labels}')

        all_seqs = list(all_seqs)
        all_seqs = sorted(all_seqs, key=len, reverse=True)
        return datasets, all_seqs
| |
|
| | def get_data(self): |
| | """ |
| | Supports .csv, .tsv, .txt |
| | TODO fasta, fa, fna, etc. |
| | """ |
| | datasets, data_names = [], [] |
| | label_candidates = ['labels', 'label', 'Labels', 'Label'] |
| | seq_candidates = ['seqs', 'Seqs', 'seq', 'Seq', 'sequence', 'Sequence', 'sequences', 'Sequences'] |
| |
|
| | for data_path in self.data_args.data_paths: |
| | data_name = data_path.split('/')[-1] |
| | print_message(f'Loading {data_name}') |
| | dataset = load_dataset(data_path) |
| | if 'inverse' in data_name.lower(): |
| | dataset = dataset.rename_columns({'seqs': 'labels', 'labels': 'seqs'}) |
| | ppi = 'SeqA' in dataset['train'].column_names |
| | |
| | if not ppi: |
| | ppi = self._is_ppi_from_columns(dataset['train'].column_names) |
| | print_message(f'PPI (or dual sequence input dataset): {ppi}') |
| |
|
| | |
| | assert 'train' in dataset, f'{data_name} does not have a train set' |
| | assert 'valid' in dataset or 'test' in dataset, f'{data_name} does not have a valid or test set, needs at least one' |
| | |
| | if 'valid' not in dataset: |
| | seed = get_global_seed() if get_global_seed() is not None else 42 |
| | train_set = dataset['train'] |
| | train_valid_set = train_set.train_test_split(test_size=0.1, seed=seed + 1) |
| | train_set = train_valid_set['train'] |
| | valid_set = train_valid_set['test'] |
| | test_set = dataset['test'] |
| | print_message(f'{data_name} does not have a valid set, created a 10% validation set') |
| | elif 'test' not in dataset: |
| | seed = get_global_seed() if get_global_seed() is not None else 42 |
| | train_set = dataset['train'] |
| | train_test_set = train_set.train_test_split(test_size=0.1, seed=seed + 2) |
| | test_set = train_test_set['test'] |
| | train_set = train_test_set['train'] |
| | valid_set = dataset['valid'] |
| | print_message(f'{data_name} does not have a test set, created a 10% test set') |
| | else: |
| | train_set, valid_set, test_set = dataset['train'], dataset['valid'], dataset['test'] |
| | print_message(f'{data_name} has a valid and test set') |
| |
|
| | print_message(f'Train set: {len(train_set)}, Valid set: {len(valid_set)}, Test set: {len(test_set)}') |
| | if ppi: |
| | |
| | print('Standardizing PPI column names') |
| | try: |
| | a_col, b_col = self._find_ppi_sequence_columns(train_set.column_names) |
| | except KeyError: |
| | |
| | try: |
| | a_col, b_col = self._find_ppi_sequence_columns(valid_set.column_names) |
| | except KeyError: |
| | a_col, b_col = self._find_ppi_sequence_columns(test_set.column_names) |
| | |
| | try: |
| | lbl_col = self._find_first_present_column(train_set.column_names, label_candidates) |
| | except KeyError: |
| | try: |
| | lbl_col = self._find_first_present_column(valid_set.column_names, label_candidates) |
| | except KeyError: |
| | lbl_col = self._find_first_present_column(test_set.column_names, label_candidates) |
| |
|
| | train_set = train_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
| | valid_set = valid_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
| | test_set = test_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
| |
|
| | print('Removing extras') |
| | train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| | valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| | test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| | else: |
| | print('Standardizing column names') |
| | use_multi = self.data_args.multi_column is not None |
| | if not use_multi: |
| | try: |
| | seq_col = self._find_first_present_column(train_set.column_names, seq_candidates) |
| | except KeyError: |
| | try: |
| | seq_col = self._find_first_present_column(valid_set.column_names, seq_candidates) |
| | except KeyError: |
| | seq_col = self._find_first_present_column(test_set.column_names, seq_candidates) |
| |
|
| | try: |
| | label_col = self._find_first_present_column(train_set.column_names, label_candidates) |
| | except KeyError: |
| | try: |
| | label_col = self._find_first_present_column(valid_set.column_names, label_candidates) |
| | except KeyError: |
| | label_col = self._find_first_present_column(test_set.column_names, label_candidates) |
| |
|
| | |
| | train_set = train_set.rename_columns({label_col: 'labels'}) |
| | valid_set = valid_set.rename_columns({label_col: 'labels'}) |
| | test_set = test_set.rename_columns({label_col: 'labels'}) |
| |
|
| | if not use_multi: |
| | train_set = train_set.rename_columns({seq_col: 'seqs'}) |
| | valid_set = valid_set.rename_columns({seq_col: 'seqs'}) |
| | test_set = test_set.rename_columns({seq_col: 'seqs'}) |
| | |
| | print('Removing extras') |
| | train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['seqs', 'labels']]) |
| | valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['seqs', 'labels']]) |
| | test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['seqs', 'labels']]) |
| | else: |
| | |
| | for col in self.data_args.multi_column: |
| | assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}" |
| | |
| | keep_cols = set(self.data_args.multi_column + ['labels']) |
| | train_set = train_set.remove_columns([col for col in train_set.column_names if col not in keep_cols]) |
| | valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in keep_cols]) |
| | test_set = test_set.remove_columns([col for col in test_set.column_names if col not in keep_cols]) |
| |
|
| | datasets.append((train_set, valid_set, test_set, ppi)) |
| | data_names.append(data_name) |
| |
|
| | for data_dir in self.data_args.data_dirs: |
| | |
| | data_name = data_dir.split ('/')[-1] |
| | |
| | ppi = 'ppi' in data_dir.lower() |
| | train_path = glob(os.path.join(data_dir, 'train.*'))[0] |
| | valid_path = glob(os.path.join(data_dir, 'valid.*'))[0] |
| | test_path = glob(os.path.join(data_dir, 'test.*'))[0] |
| | if '.xlsx' in train_path: |
| | train_set = read_excel(train_path) |
| | valid_set = read_excel(valid_path) |
| | test_set = read_excel(test_path) |
| | else: |
| | train_set = read_csv(train_path, delimiter=self._delimiter) |
| | valid_set = read_csv(valid_path, delimiter=self._delimiter) |
| | test_set = read_csv(test_path, delimiter=self._delimiter) |
| |
|
| | train_set = Dataset.from_pandas(train_set) |
| | valid_set = Dataset.from_pandas(valid_set) |
| | test_set = Dataset.from_pandas(test_set) |
| |
|
| | |
| | if not ppi: |
| | ppi = self._is_ppi_from_columns(train_set.column_names) |
| |
|
| | if ppi: |
| | print('Standardizing PPI column names') |
| | try: |
| | a_col, b_col = self._find_ppi_sequence_columns(train_set.column_names) |
| | except KeyError: |
| | try: |
| | a_col, b_col = self._find_ppi_sequence_columns(valid_set.column_names) |
| | except KeyError: |
| | a_col, b_col = self._find_ppi_sequence_columns(test_set.column_names) |
| |
|
| | try: |
| | lbl_col = self._find_first_present_column(train_set.column_names, label_candidates) |
| | except KeyError: |
| | try: |
| | lbl_col = self._find_first_present_column(valid_set.column_names, label_candidates) |
| | except KeyError: |
| | lbl_col = self._find_first_present_column(test_set.column_names, label_candidates) |
| |
|
| | train_set = train_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
| | valid_set = valid_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
| | test_set = test_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
| |
|
| | print('Removing extras') |
| | train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| | valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| | test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| | else: |
| | print('Standardizing column names') |
| | use_multi = self.data_args.multi_column is not None |
| | if not use_multi: |
| | try: |
| | seq_col = self._find_first_present_column(train_set.column_names, seq_candidates) |
| | except KeyError: |
| | try: |
| | seq_col = self._find_first_present_column(valid_set.column_names, seq_candidates) |
| | except KeyError: |
| | seq_col = self._find_first_present_column(test_set.column_names, seq_candidates) |
| |
|
| | try: |
| | label_col = self._find_first_present_column(train_set.column_names, label_candidates) |
| | except KeyError: |
| | try: |
| | label_col = self._find_first_present_column(valid_set.column_names, label_candidates) |
| | except KeyError: |
| | label_col = self._find_first_present_column(test_set.column_names, label_candidates) |
| |
|
| | |
| | train_set = train_set.rename_columns({label_col: 'labels'}) |
| | valid_set = valid_set.rename_columns({label_col: 'labels'}) |
| | test_set = test_set.rename_columns({label_col: 'labels'}) |
| |
|
| | if not use_multi: |
| | train_set = train_set.rename_columns({seq_col: 'seqs'}) |
| | valid_set = valid_set.rename_columns({seq_col: 'seqs'}) |
| | test_set = test_set.rename_columns({seq_col: 'seqs'}) |
| | |
| | print('Removing extras') |
| | train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['seqs', 'labels']]) |
| | valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['seqs', 'labels']]) |
| | test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['seqs', 'labels']]) |
| | else: |
| | |
| | for col in self.data_args.multi_column: |
| | assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}" |
| | |
| | keep_cols = set(self.data_args.multi_column + ['labels']) |
| | train_set = train_set.remove_columns([col for col in train_set.column_names if col not in keep_cols]) |
| | valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in keep_cols]) |
| | test_set = test_set.remove_columns([col for col in test_set.column_names if col not in keep_cols]) |
| |
|
| | datasets.append((train_set, valid_set, test_set, ppi)) |
| | data_names.append(data_name) |
| |
|
| | return self.process_datasets(hf_datasets=datasets, data_names=data_names) |
| |
|
| | def get_embedding_dim_sql(self, save_path, test_seq, tokenizer): |
| | import sqlite3 |
| | test_seq_len = len(tokenizer(test_seq, return_tensors='pt')['input_ids'][0]) |
| | |
| | with sqlite3.connect(save_path) as conn: |
| | c = conn.cursor() |
| | c.execute("SELECT embedding FROM embeddings WHERE sequence = ?", (test_seq,)) |
| | test_embedding = c.fetchone()[0] |
| | test_embedding = torch.tensor(np.frombuffer(test_embedding, dtype=np.float32).reshape(1, -1)) |
| | if self._full: |
| | test_embedding = test_embedding.reshape(test_seq_len, -1) |
| | embedding_dim = test_embedding.shape[-1] |
| | return embedding_dim |
| |
|
| | def get_embedding_dim_pth(self, emb_dict, test_seq, tokenizer): |
| | test_seq_len = len(tokenizer(test_seq, return_tensors='pt')['input_ids'][0]) |
| | test_embedding = emb_dict[test_seq] |
| | if self._full: |
| | test_embedding = test_embedding.reshape(test_seq_len, -1) |
| | embedding_dim = test_embedding.shape[-1] |
| | return embedding_dim |
| |
|
| | def build_vector_numpy_dataset_from_embeddings( |
| | self, |
| | model_name, |
| | train_seqs, |
| | valid_seqs, |
| | test_seqs, |
| | ): |
| | save_dir = self.embedding_args.embedding_save_dir |
| | train_array, valid_array, test_array = [], [], [] |
| | |
| | pooling_types = self.embedding_args.pooling_types |
| | if self._sql: |
| | import sqlite3 |
| | filename = get_embedding_filename(model_name, self._full, pooling_types, 'db') |
| | save_path = os.path.join(save_dir, filename) |
| | with sqlite3.connect(save_path) as conn: |
| | c = conn.cursor() |
| | for seq in train_seqs: |
| | embedding = self._select_from_sql(c, seq, cast_to_torch=False) |
| | train_array.append(embedding) |
| |
|
| | for seq in valid_seqs: |
| | embedding = self._select_from_sql(c, seq, cast_to_torch=False) |
| | valid_array.append(embedding) |
| |
|
| | for seq in test_seqs: |
| | embedding = self._select_from_sql(c, seq, cast_to_torch=False) |
| | test_array.append(embedding) |
| | else: |
| | filename = get_embedding_filename(model_name, self._full, pooling_types, 'pth') |
| | save_path = os.path.join(save_dir, filename) |
| | emb_dict = torch.load(save_path) |
| | for seq in train_seqs: |
| | embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True) |
| | train_array.append(embedding) |
| | |
| | for seq in valid_seqs: |
| | embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True) |
| | valid_array.append(embedding) |
| |
|
| | for seq in test_seqs: |
| | embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True) |
| | test_array.append(embedding) |
| | del emb_dict |
| |
|
| | train_array = np.concatenate(train_array, axis=0) |
| | valid_array = np.concatenate(valid_array, axis=0) |
| | test_array = np.concatenate(test_array, axis=0) |
| | |
| | if self._full: |
| | train_array = np.mean(train_array, axis=1) |
| | valid_array = np.mean(valid_array, axis=1) |
| | test_array = np.mean(test_array, axis=1) |
| |
|
| | print_message('Numpy dataset shapes') |
| | print_message(f'Train: {train_array.shape}') |
| | print_message(f'Valid: {valid_array.shape}') |
| | print_message(f'Test: {test_array.shape}') |
| | return train_array, valid_array, test_array |
| |
|
| | def build_pair_vector_numpy_dataset_from_embeddings( |
| | self, |
| | model_name, |
| | train_seqs_a, |
| | train_seqs_b, |
| | valid_seqs_a, |
| | valid_seqs_b, |
| | test_seqs_a, |
| | test_seqs_b, |
| | ): |
| | save_dir = self.embedding_args.embedding_save_dir |
| | train_array, valid_array, test_array = [], [], [] |
| | pooling_types = self.embedding_args.pooling_types |
| | if self._sql: |
| | filename = get_embedding_filename(model_name, self._full, pooling_types, 'db') |
| | save_path = os.path.join(save_dir, filename) |
| | with sqlite3.connect(save_path) as conn: |
| | c = conn.cursor() |
| | for seq_a, seq_b in zip(train_seqs_a, train_seqs_b): |
| | seq_a, seq_b = self._random_order(seq_a, seq_b) |
| | embedding_a = self._select_from_sql(c, seq_a, cast_to_torch=False) |
| | embedding_b = self._select_from_sql(c, seq_b, cast_to_torch=False) |
| | train_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
| |
|
| | for seq_a, seq_b in zip(valid_seqs_a, valid_seqs_b): |
| | seq_a, seq_b = self._random_order(seq_a, seq_b) |
| | embedding_a = self._select_from_sql(c, seq_a, cast_to_torch=False) |
| | embedding_b = self._select_from_sql(c, seq_b, cast_to_torch=False) |
| | valid_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
| |
|
| | for seq_a, seq_b in zip(test_seqs_a, test_seqs_b): |
| | seq_a, seq_b = self._random_order(seq_a, seq_b) |
| | embedding_a = self._select_from_sql(c, seq_a, cast_to_torch=False) |
| | embedding_b = self._select_from_sql(c, seq_b, cast_to_torch=False) |
| | test_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
| | else: |
| | filename = get_embedding_filename(model_name, self._full, pooling_types, 'pth') |
| | save_path = os.path.join(save_dir, filename) |
| | emb_dict = torch.load(save_path) |
| | for seq_a, seq_b in zip(train_seqs_a, train_seqs_b): |
| | seq_a, seq_b = self._random_order(seq_a, seq_b) |
| | embedding_a = self._select_from_pth(emb_dict, seq_a, cast_to_np=True) |
| | embedding_b = self._select_from_pth(emb_dict, seq_b, cast_to_np=True) |
| | train_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
| |
|
| | for seq_a, seq_b in zip(valid_seqs_a, valid_seqs_b): |
| | seq_a, seq_b = self._random_order(seq_a, seq_b) |
| | embedding_a = self._select_from_pth(emb_dict, seq_a, cast_to_np=True) |
| | embedding_b = self._select_from_pth(emb_dict, seq_b, cast_to_np=True) |
| | valid_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
| |
|
| | for seq_a, seq_b in zip(test_seqs_a, test_seqs_b): |
| | seq_a, seq_b = self._random_order(seq_a, seq_b) |
| | embedding_a = self._select_from_pth(emb_dict, seq_a, cast_to_np=True) |
| | embedding_b = self._select_from_pth(emb_dict, seq_b, cast_to_np=True) |
| | test_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
| | del emb_dict |
| |
|
| | train_array = np.concatenate(train_array, axis=0) |
| | valid_array = np.concatenate(valid_array, axis=0) |
| | test_array = np.concatenate(test_array, axis=0) |
| | |
| | if self._full: |
| | train_array = np.mean(train_array, axis=1) |
| | valid_array = np.mean(valid_array, axis=1) |
| | test_array = np.mean(test_array, axis=1) |
| |
|
| | print_message('Numpy dataset shapes') |
| | print_message(f'Train: {train_array.shape}') |
| | print_message(f'Valid: {valid_array.shape}') |
| | print_message(f'Test: {test_array.shape}') |
| | return train_array, valid_array, test_array |
| |
|
| | def prepare_scikit_dataset(self, model_name, dataset): |
| | train_set, valid_set, test_set, _, label_type, ppi = dataset |
| |
|
| | if ppi: |
| | X_train, X_valid, X_test = self.build_pair_vector_numpy_dataset_from_embeddings( |
| | model_name, |
| | list(train_set['SeqA']), |
| | list(train_set['SeqB']), |
| | list(valid_set['SeqA']), |
| | list(valid_set['SeqB']), |
| | list(test_set['SeqA']), |
| | list(test_set['SeqB']), |
| | ) |
| | else: |
| | X_train, X_valid, X_test = self.build_vector_numpy_dataset_from_embeddings( |
| | model_name, |
| | list(train_set['seqs']), |
| | list(valid_set['seqs']), |
| | list(test_set['seqs']), |
| | ) |
| |
|
| | y_train = self._labels_to_numpy(list(train_set['labels'])) |
| | y_valid = self._labels_to_numpy(list(valid_set['labels'])) |
| | y_test = self._labels_to_numpy(list(test_set['labels'])) |
| |
|
| | print_message('Numpy dataset shapes with labels') |
| | print_message(f'Train: {X_train.shape}, {y_train.shape}') |
| | print_message(f'Valid: {X_valid.shape}, {y_valid.shape}') |
| | print_message(f'Test: {X_test.shape}, {y_test.shape}') |
| | return X_train, y_train, X_valid, y_valid, X_test, y_test, label_type |
| |
|