| | import torch |
| | from typing import Optional, Tuple |
| |
|
| |
|
# --- Sequence alphabets ------------------------------------------------------
# One-letter amino-acid codes, including the non-canonical codes listed in
# NONCANONICAL_AMINO_ACIDS below.
AA_SET = set("LAGVSERTIPDKQNFYMHWCXBUOZ*")
# The 64 single-character codon tokens (the same characters that key
# CODON_TO_AA).
CODON_SET = set(
    "aA@bB#$%rRnNdDcCeEqQ^G&ghHiIj+MmlJLkK(fFpPoO=szZwSXTtxWyYuvUV]})"
)
# Nucleotide alphabets: DNA uses T where RNA uses U.
DNA_SET = {"A", "T", "C", "G"}
RNA_SET = {"A", "U", "C", "G"}
# Members of AA_SET that are not one of the 20 canonical residues.
NONCANONICAL_AMINO_ACIDS = {"X", "B", "U", "O", "Z", "*"}

# --- Back-translation table --------------------------------------------------
# A single DNA codon for each of the 20 canonical amino acids.
# NOTE(review): per the name these are the codons preferred for human
# expression -- confirm against the source of the table.
AMINO_ACID_TO_HUMAN_CODON = {
    "A": "GCC", "R": "CGC", "N": "AAC", "D": "GAC", "C": "TGC",
    "Q": "CAG", "E": "GAG", "G": "GGC", "H": "CAC", "I": "ATC",
    "L": "CTG", "K": "AAG", "M": "ATG", "F": "TTC", "P": "CCC",
    "S": "AGC", "T": "ACC", "W": "TGG", "Y": "TAC", "V": "GTG",
}
# An alanine codon that differs from the "GCC" used for "A" above.
# NOTE(review): presumably the placeholder emitted for non-canonical
# residues -- verify against callers.
NONCANONICAL_ALANINE_CODON = "GCT"
| |
|
# Representative codon token for each canonical amino acid.  Most residues use
# their own one-letter code as the token; arginine ('B') and methionine ('(')
# do not.  Every token here must decode back to its residue via CODON_TO_AA.
AA_TO_CODON_TOKEN = {
    'A': 'A', 'R': 'B', 'N': 'N', 'D': 'D', 'C': 'C',
    'Q': 'Q', 'E': 'E', 'G': 'G', 'H': 'H', 'I': 'I',
    'L': 'L', 'K': 'K', 'M': '(', 'F': 'F', 'P': 'P',
    'S': 'S', 'T': 'T', 'W': 'W', 'Y': 'Y', 'V': 'V',
}

# Decoding table: each of the 64 single-character codon tokens -> the
# one-letter amino acid it encodes ('*' marks a stop).  The number of tokens
# per residue follows the degeneracy of the standard genetic code, shown in
# the trailing comments.
CODON_TO_AA = {
    'a': 'A', 'A': 'A', '@': 'A', 'b': 'A',                      # Ala (4)
    'B': 'R', '#': 'R', '$': 'R', '%': 'R', 'r': 'R', 'R': 'R',  # Arg (6)
    'n': 'N', 'N': 'N',                                          # Asn (2)
    'd': 'D', 'D': 'D',                                          # Asp (2)
    'c': 'C', 'C': 'C',                                          # Cys (2)
    'e': 'E', 'E': 'E',                                          # Glu (2)
    'q': 'Q', 'Q': 'Q',                                          # Gln (2)
    '^': 'G', 'G': 'G', '&': 'G', 'g': 'G',                      # Gly (4)
    'h': 'H', 'H': 'H',                                          # His (2)
    'i': 'I', 'I': 'I', 'j': 'I',                                # Ile (3)
    '+': 'L', 'M': 'L', 'm': 'L', 'l': 'L', 'J': 'L', 'L': 'L',  # Leu (6)
    'k': 'K', 'K': 'K',                                          # Lys (2)
    '(': 'M',                                                    # Met (1)
    'f': 'F', 'F': 'F',                                          # Phe (2)
    'p': 'P', 'P': 'P', 'o': 'P', 'O': 'P',                      # Pro (4)
    '=': 'S', 's': 'S', 'z': 'S', 'Z': 'S', 'w': 'S', 'S': 'S',  # Ser (6)
    # 'X' is the fourth threonine token; serine already has its six above.
    'X': 'T', 'T': 'T', 't': 'T', 'x': 'T',                      # Thr (4)
    # 'W' is the token AA_TO_CODON_TOKEN assigns to tryptophan, so it must
    # decode to 'W'; it is the only tryptophan token.
    'W': 'W',                                                    # Trp (1)
    'y': 'Y', 'Y': 'Y',                                          # Tyr (2)
    'u': 'V', 'v': 'V', 'U': 'V', 'V': 'V',                      # Val (4)
    ']': '*', '}': '*', ')': '*',                                # stop (3)
}
# DNA codon -> one-letter amino acid ('*' marks a stop codon).
# Codons are enumerated with each base position cycling through T, C, A, G
# (third base fastest); the amino-acid string lists the residues in that same
# order, four per first-two-bases prefix.
DNA_CODON_TO_AA = dict(zip(
    (
        first + second + third
        for first in 'TCAG'
        for second in 'TCAG'
        for third in 'TCAG'
    ),
    'FFLL' 'SSSS' 'YY**' 'CC*W'   # TTx TCx TAx TGx
    'LLLL' 'PPPP' 'HHQQ' 'RRRR'   # CTx CCx CAx CGx
    'IIIM' 'TTTT' 'NNKK' 'SSRR'   # ATx ACx AAx AGx
    'VVVV' 'AAAA' 'DDEE' 'GGGG',  # GTx GCx GAx GGx
))

# RNA version of the same table: identical entries with every T written as U.
RNA_CODON_TO_AA = dict(
    (dna_codon.replace('T', 'U'), residue)
    for dna_codon, residue in DNA_CODON_TO_AA.items()
)
| |
|
| |
|
| |
|
def pad_and_concatenate_dimer(
    A: torch.Tensor,
    B: torch.Tensor,
    a_mask: Optional[torch.Tensor] = None,
    b_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Concatenate the valid prefixes of two batched sequences, row by row.

    For every batch row ``i`` the first ``a_len`` positions of ``A[i]`` are
    followed immediately by the first ``b_len`` positions of ``B[i]``, where
    each length is the sum of the corresponding mask row.  Rows shorter than
    the longest combined row are zero-padded on the right.

    Only the *count* of valid positions is taken from a mask, so masks are
    assumed to mark a contiguous prefix (valid positions first, padding last).

    Args:
        A: Tensor of shape ``(batch, len_a, d)``.
        B: Tensor of shape ``(batch, len_b, d)``; ``len_b`` may differ from
            ``len_a``.
        a_mask: Optional ``(batch, len_a)`` mask with 1 at valid positions.
            ``None`` means every position of ``A`` is valid.
        b_mask: Optional ``(batch, len_b)`` mask with 1 at valid positions.
            ``None`` means every position of ``B`` is valid.

    Returns:
        ``(combined, combined_mask)``: ``combined`` has shape
        ``(batch, max_len, d)`` with the dtype and device of ``A``;
        ``combined_mask`` has shape ``(batch, max_len)`` in the default float
        dtype, holding 1 at filled positions and 0 at padding.  ``max_len`` is
        the largest ``a_len + b_len`` in the batch (0 for an empty batch).
    """
    batch_size, a_seq_len, d = A.size()
    if a_mask is None:
        a_mask = torch.ones(batch_size, a_seq_len, device=A.device)
    if b_mask is None:
        # Size the default mask from B itself: B need not be as long as A.
        b_mask = torch.ones(batch_size, B.size(1), device=A.device)

    # One reduction and one host transfer per mask rather than per-row
    # ``.item()`` calls.  ``flatten(1)`` keeps the "sum everything in the row"
    # semantics even if a mask carries trailing singleton dimensions.
    a_lens = a_mask.flatten(1).sum(dim=1).long().tolist()
    b_lens = b_mask.flatten(1).sum(dim=1).long().tolist()
    # ``default=0`` lets an empty batch yield empty outputs instead of raising.
    max_len = max((a + b for a, b in zip(a_lens, b_lens)), default=0)

    # Keep A's dtype so half/double inputs are not silently cast to float32.
    combined = torch.zeros(
        batch_size, max_len, d, dtype=A.dtype, device=A.device
    )
    combined_mask = torch.zeros(batch_size, max_len, device=A.device)
    for i, (a_len, b_len) in enumerate(zip(a_lens, b_lens)):
        end = a_len + b_len
        combined[i, :a_len] = A[i, :a_len]
        combined[i, a_len:end] = B[i, :b_len]
        combined_mask[i, :end] = 1
    return combined, combined_mask
| |
|