| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from typing import Optional |
|
|
| |
| |
| |
|
|
# Standard genetic code over the DNA alphabet: codon -> one-letter amino acid.
# '*' marks the three stop codons (TAA, TAG, TGA). All 64 codons are present.
GENETIC_CODE = {
    'ATA': 'I', 'ATC': 'I', 'ATT': 'I', 'ATG': 'M',
    'ACA': 'T', 'ACC': 'T', 'ACG': 'T', 'ACT': 'T',
    'AAC': 'N', 'AAT': 'N', 'AAA': 'K', 'AAG': 'K',
    'AGC': 'S', 'AGT': 'S', 'AGA': 'R', 'AGG': 'R',
    'CTA': 'L', 'CTC': 'L', 'CTG': 'L', 'CTT': 'L',
    'CCA': 'P', 'CCC': 'P', 'CCG': 'P', 'CCT': 'P',
    'CAC': 'H', 'CAT': 'H', 'CAA': 'Q', 'CAG': 'Q',
    'CGA': 'R', 'CGC': 'R', 'CGG': 'R', 'CGT': 'R',
    'GTA': 'V', 'GTC': 'V', 'GTG': 'V', 'GTT': 'V',
    'GCA': 'A', 'GCC': 'A', 'GCG': 'A', 'GCT': 'A',
    'GAC': 'D', 'GAT': 'D', 'GAA': 'E', 'GAG': 'E',
    # GGC encodes glycine; it is set correctly here rather than patched later.
    'GGA': 'G', 'GGC': 'G', 'GGG': 'G', 'GGT': 'G',
    'TCA': 'S', 'TCC': 'S', 'TCG': 'S', 'TCT': 'S',
    'TTC': 'F', 'TTT': 'F', 'TTA': 'L', 'TTG': 'L',
    'TAC': 'Y', 'TAT': 'Y', 'TAA': '*', 'TAG': '*',
    'TGC': 'C', 'TGT': 'C', 'TGA': '*', 'TGG': 'W',
}
|
|
# Nucleotide alphabet; its order fixes the codon numbering below.
BASES = ['A', 'C', 'G', 'T']

# Codon -> integer index in lexicographic order over BASES, i.e.
# index = 16 * pos0 + 4 * pos1 + pos2 ('AAA' -> 0, 'AAC' -> 1, ..., 'TTT' -> 63).
CODON_TO_INDEX = {
    BASES[idx // 16] + BASES[(idx // 4) % 4] + BASES[idx % 4]: idx
    for idx in range(64)
}

# Inverse lookup: integer index -> codon string (keys ordered 0..63).
INDEX_TO_CODON = {idx: codon for codon, idx in CODON_TO_INDEX.items()}
|
|
| |
| |
| |
|
|
def exp_map_zero(x: torch.Tensor, c: float = 1.0) -> torch.Tensor:
    """Poincare-ball exponential map at the origin with curvature parameter ``c``.

    Computes ``tanh(sqrt(c) * ||x||) * x / (sqrt(c) * ||x||)`` along the last
    dimension, so every output vector has Euclidean norm below ``1 / sqrt(c)``.

    Args:
        x: Tangent vectors at the origin; the last dimension holds each vector.
        c: Positive curvature parameter (``math.sqrt`` rejects negative values).

    Returns:
        Tensor with the same shape as ``x``. Zero vectors map to zero.
    """
    root_c = math.sqrt(c)
    # Clamp the norm so zero vectors do not trigger a 0/0 division.
    safe_norm = x.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-15)
    scaled = root_c * safe_norm
    return torch.tanh(scaled) * x / scaled
|
|
def project_to_poincare(z: torch.Tensor, max_norm: float = 0.95, c: float = 1.0) -> torch.Tensor:
    """Radially shrink vectors whose Euclidean norm exceeds ``max_norm``.

    Vectors (along the last dimension) with ``||z|| > max_norm`` are rescaled
    to have norm exactly ``max_norm``; all other vectors are returned as-is.

    Args:
        z: Tensor whose last dimension holds the vectors.
        max_norm: Absolute Euclidean radius cap.
        c: Curvature. Currently unused: ``max_norm`` is NOT rescaled by
            ``1 / sqrt(c)``; the parameter is kept for signature symmetry.

    Returns:
        Tensor with the same shape as ``z``.
    """
    norm = torch.norm(z, p=2, dim=-1, keepdim=True)
    mask = norm > max_norm
    # torch.where back-propagates a zero gradient into the unselected branch;
    # dividing by a zero norm there would turn that into 0/0 = NaN gradients.
    # Use the true norm only where the row is actually projected (so values
    # are unchanged) and a harmless 1 elsewhere.
    divisor = torch.where(mask, norm, torch.ones_like(norm))
    projected = (z / divisor) * max_norm
    return torch.where(mask, projected, z)
|
|
| |
| |
| |
|
|
class CodonEncoderMLP(nn.Module):
    """MLP with two hidden layers mapping a 12-dim codon one-hot to a latent vector.

    Layout of ``self.encoder``: ``[Linear, LayerNorm, SiLU, Dropout]`` twice
    (the first Linear takes 12 inputs), then ``Linear(hidden_dim, latent_dim)``.
    The output is unconstrained (no final activation).

    Args:
        latent_dim: Size of the output vector.
        hidden_dim: Width of both hidden layers.
        dropout: Dropout probability applied after each hidden activation.
    """

    def __init__(self, latent_dim=16, hidden_dim=64, dropout=0.1):
        super().__init__()
        layers = []
        width_in = 12  # 3 codon positions x 4 bases, one-hot
        for _ in range(2):
            layers.extend([
                nn.Linear(width_in, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.SiLU(),
                nn.Dropout(dropout),
            ])
            width_in = hidden_dim
        layers.append(nn.Linear(hidden_dim, latent_dim))
        # A single Sequential keeps parameter names (encoder.0, encoder.4,
        # encoder.8, ...) stable for saved checkpoints.
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP; the last dimension of ``x`` must be 12."""
        return self.encoder(x)
|
|
class TrainableCodonEncoder(nn.Module):
    """Learnable embedding of the 64 codons into a Poincare ball.

    Each codon is represented by a fixed 12-dim one-hot vector (3 positions x
    4 bases). Embedding happens in three steps:
    1. ``CodonEncoderMLP`` maps the one-hot to a tangent vector.
    2. ``exp_map_zero`` maps that vector into the ball.
    3. ``project_to_poincare`` caps its norm at ``max_radius``.

    Args:
        latent_dim: Dimension of the output embedding.
        hidden_dim: Width of the MLP hidden layers.
        curvature: Curvature ``c`` forwarded to ``exp_map_zero`` and
            ``project_to_poincare``.
        max_radius: Euclidean norm cap applied after the exponential map.
        dropout: Dropout probability inside the MLP.
    """

    def __init__(self, latent_dim=16, hidden_dim=64, curvature=1.0, max_radius=0.9, dropout=0.1):
        super().__init__()
        self.latent_dim = latent_dim
        self.curvature = curvature
        self.max_radius = max_radius
        self.encoder = CodonEncoderMLP(latent_dim, hidden_dim, dropout)

        # Fixed one-hot table: row i encodes INDEX_TO_CODON[i], column
        # pos * 4 + base_idx is set for each of the three positions.
        # Registered as a buffer so it moves with .to(device) and is saved in
        # the state_dict, but is not a trainable parameter.
        base_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
        onehots = torch.zeros(64, 12)
        for i in range(64):
            for pos, base in enumerate(INDEX_TO_CODON[i]):
                onehots[i, pos * 4 + base_to_idx[base]] = 1.0
        self.register_buffer('codon_onehots', onehots)

    def _embed(self, onehots):
        """Map rows of codon one-hots to ball points: MLP -> exp map -> radius cap."""
        z_tangent = self.encoder(onehots)
        z_hyp = exp_map_zero(z_tangent, c=self.curvature)
        return project_to_poincare(z_hyp, max_norm=self.max_radius, c=self.curvature)

    def encode_all(self):
        """Return embeddings of all 64 codons, shape ``(64, latent_dim)``.

        Row i corresponds to ``INDEX_TO_CODON[i]``.
        """
        return self._embed(self.codon_onehots)

    def forward(self, codon_indices):
        """Embed a tensor of codon indices.

        Args:
            codon_indices: Integer tensor of any shape (including 0-d) with
                values in [0, 63], indexing ``INDEX_TO_CODON``.

        Returns:
            Tensor of shape ``(*codon_indices.shape, latent_dim)``; a 0-d
            index yields shape ``(latent_dim,)``.
        """
        onehots = self.codon_onehots[codon_indices.reshape(-1)]
        z_hyp = self._embed(onehots)
        # Always restore the input's leading shape so 0-d inputs are handled
        # consistently with 1-d and n-d inputs.
        return z_hyp.reshape(*codon_indices.shape, self.latent_dim)
|
|