from typing import Optional, Tuple

import torch

# One-letter amino-acid alphabet: the 20 canonical residues plus the
# non-standard/ambiguity codes X, B, U, O, Z and '*' (stop).
AA_SET = set('LAGVSERTIPDKQNFYMHWCXBUOZ*')
# Single-character codon tokens; exactly the keys of CODON_TO_AA below.
CODON_SET = set('aA@bB#$%rRnNdDcCeEqQ^G&ghHiIj+MmlJLkK(fFpPoO=szZwSXTtxWyYuvUV]})')
DNA_SET = set('ATCG')
RNA_SET = set('AUCG')
# Members of AA_SET that have no entry in AMINO_ACID_TO_HUMAN_CODON.
NONCANONICAL_AMINO_ACIDS = set('XBUOZ*')

# One DNA codon per canonical amino acid (the name indicates a human
# codon-usage choice). Every value translates back to its key via
# DNA_CODON_TO_AA.
AMINO_ACID_TO_HUMAN_CODON = {
    'A': 'GCC', 'R': 'CGC', 'N': 'AAC', 'D': 'GAC', 'C': 'TGC',
    'Q': 'CAG', 'E': 'GAG', 'G': 'GGC', 'H': 'CAC', 'I': 'ATC',
    'L': 'CTG', 'K': 'AAG', 'M': 'ATG', 'F': 'TTC', 'P': 'CCC',
    'S': 'AGC', 'T': 'ACC', 'W': 'TGG', 'Y': 'TAC', 'V': 'GTG',
}
# An alanine codon different from AMINO_ACID_TO_HUMAN_CODON['A'] ('GCC').
# NOTE(review): presumably used as a stand-in for non-canonical residues —
# confirm against callers.
NONCANONICAL_ALANINE_CODON = 'GCT'

# One representative codon token per canonical amino acid; each token maps
# back to its amino acid through CODON_TO_AA.
AA_TO_CODON_TOKEN = {
    'A': 'A', 'R': 'B', 'N': 'N', 'D': 'D', 'C': 'C',
    'Q': 'Q', 'E': 'E', 'G': 'G', 'H': 'H', 'I': 'I',
    'L': 'L', 'K': 'K', 'M': '(', 'F': 'F', 'P': 'P',
    'S': 'S', 'T': 'T', 'W': 'W', 'Y': 'Y', 'V': 'V',
}

# Codon token -> amino acid (or '*' for stop). Tokens per residue match the
# number of synonymous codons in DNA_CODON_TO_AA (e.g. 6 for S, 4 for T,
# 1 for W).
CODON_TO_AA = {
    'a': 'A', 'A': 'A', '@': 'A', 'b': 'A',
    'B': 'R', '#': 'R', '$': 'R', '%': 'R', 'r': 'R', 'R': 'R',
    'n': 'N', 'N': 'N',
    'd': 'D', 'D': 'D',
    'c': 'C', 'C': 'C',
    'e': 'E', 'E': 'E',
    'q': 'Q', 'Q': 'Q',
    '^': 'G', 'G': 'G', '&': 'G', 'g': 'G',
    'h': 'H', 'H': 'H',
    'i': 'I', 'I': 'I', 'j': 'I',
    '+': 'L', 'M': 'L', 'm': 'L', 'l': 'L', 'J': 'L', 'L': 'L',
    'k': 'K', 'K': 'K',
    '(': 'M',
    'f': 'F', 'F': 'F',
    'p': 'P', 'P': 'P', 'o': 'P', 'O': 'P',
    '=': 'S', 's': 'S', 'z': 'S', 'Z': 'S', 'w': 'S', 'S': 'S',
    'X': 'T', 'T': 'T', 't': 'T', 'x': 'T',
    # Tryptophan has a single codon (TGG); 'W' is the token
    # AA_TO_CODON_TOKEN assigns to it.
    'W': 'W',
    'y': 'Y', 'Y': 'Y',
    'u': 'V', 'v': 'V', 'U': 'V', 'V': 'V',
    ']': '*', '}': '*', ')': '*',
}

# Standard genetic code: DNA codon -> amino acid ('*' = stop).
DNA_CODON_TO_AA = {
    'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
    'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
    'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
    'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
    'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
    'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
    'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
    'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
    'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
    'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
    'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
    'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
    'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
    'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
    'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
    'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G',
}

# Same table with RNA codons (T replaced by U).
RNA_CODON_TO_AA = {
    codon.replace('T', 'U'): aa for codon, aa in DNA_CODON_TO_AA.items()
}


def pad_and_concatenate_dimer(
    A: torch.Tensor,
    B: torch.Tensor,
    a_mask: Optional[torch.Tensor] = None,
    b_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Concatenate the valid positions of two batched sequences into one padded batch.

    For each batch row ``i``, the first ``a_mask[i].sum()`` positions of
    ``A[i]`` are followed directly by the first ``b_mask[i].sum()`` positions
    of ``B[i]``. The rest of the row is zero padding. Only the mask sums
    are used, so valid positions are assumed to form a leading prefix of
    each row (right padding).

    Args:
        A: Tensor of shape ``(batch, len_a, d)``.
        B: Tensor of shape ``(batch, len_b, d)``; ``len_b`` may differ from ``len_a``.
        a_mask: Mask of shape ``(batch, len_a)``. Defaults to all ones.
        b_mask: Mask of shape ``(batch, len_b)``. Defaults to all ones.

    Returns:
        ``(combined, combined_mask)``:

        - ``combined`` has shape ``(batch, max_len, d)``, with ``A``'s dtype and device.
        - ``combined_mask`` is a float tensor of shape ``(batch, max_len)``, holding 1
          on filled positions.

        ``max_len`` is the largest per-row total valid length; it is 0 for an
        empty batch.
    """
    batch_size, len_a, d = A.size()
    len_b = B.size(1)
    if a_mask is None:
        a_mask = torch.ones(batch_size, len_a, device=A.device)
    if b_mask is None:
        # Size from B, not A: the two sequences need not share a length.
        b_mask = torch.ones(batch_size, len_b, device=A.device)

    # Valid length per row, computed with one host transfer per mask rather
    # than one per row.
    a_lens = [int(n) for n in a_mask.flatten(start_dim=1).sum(dim=1).tolist()]
    b_lens = [int(n) for n in b_mask.flatten(start_dim=1).sum(dim=1).tolist()]
    # default=0 keeps an empty batch from raising inside max().
    max_len = max((a + b for a, b in zip(a_lens, b_lens)), default=0)

    # Preserve the input dtype instead of silently casting to float32.
    combined = torch.zeros(batch_size, max_len, d, dtype=A.dtype, device=A.device)
    combined_mask = torch.zeros(batch_size, max_len, device=A.device)
    for i, (a_len, b_len) in enumerate(zip(a_lens, b_lens)):
        combined[i, :a_len] = A[i, :a_len]
        combined[i, a_len:a_len + b_len] = B[i, :b_len]
        combined_mask[i, :a_len + b_len] = 1
    return combined, combined_mask