| | from typing import List, Dict, Any |
| | from openfold.utils import rigid_utils as ru |
| | from data import residue_constants |
| | import numpy as np |
| | import collections |
| | import string |
| | import pickle |
| | import os |
| | import torch |
| | from torch_scatter import scatter_add, scatter |
| | from Bio.PDB.Chain import Chain |
| | from data import protein |
| | import dataclasses |
| | from Bio import PDB |
| |
|
# Short module-level aliases for the rigid-frame and protein container types.
Rigid, Protein = ru.Rigid, protein.Protein
| |
|
| | |
# Characters allowed in a chain identifier; their position defines the integer code.
ALPHANUMERIC = string.ascii_letters + string.digits + ' '
# Forward and inverse lookups between chain characters and integer codes.
CHAIN_TO_INT = {char: code for code, char in enumerate(ALPHANUMERIC)}
INT_TO_CHAIN = dict(enumerate(ALPHANUMERIC))

# Conversion factors between nanometres and Angstroms.
NM_TO_ANG_SCALE = 10.0
ANG_TO_NM_SCALE = 1 / NM_TO_ANG_SCALE

# Per-chain feature names kept when parsing a PDB chain.
CHAIN_FEATS = [
    'atom_positions', 'aatype', 'atom_mask', 'residue_index', 'b_factors'
]
| |
|
def to_numpy(x):
    """Detach a torch tensor from autograd, move it to CPU and return a NumPy array."""
    return x.detach().cpu().numpy()


def aatype_to_seq(aatype):
    """Convert residue-type indices to a one-letter amino-acid sequence string.

    Each index is looked up in ``residue_constants.restypes_with_x``.
    """
    return ''.join(residue_constants.restypes_with_x[x] for x in aatype)
| |
|
| |
|
class CPU_Unpickler(pickle.Unpickler):
    """Pytorch pickle loading workaround.

    https://github.com/pytorch/pytorch/issues/16797

    Torch storages inside a pickle are rebuilt through
    ``torch.storage._load_from_bytes``. This unpickler intercepts that lookup
    and deserializes the storage bytes with ``map_location='cpu'``, so pickles
    containing tensors can be loaded onto the CPU.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, redirecting torch storage loading to the CPU."""
        if module == 'torch.storage' and name == '_load_from_bytes':
            # Local import: ``io`` is required for BytesIO below and is not
            # otherwise guaranteed to be in scope; without it this branch
            # raised NameError on the first storage encountered.
            import io
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)
| |
|
| |
|
def create_rigid(rots, trans):
    """Build a ``Rigid`` frame from rotation matrices and translations.

    Args:
        rots: rotation matrices, wrapped via ``ru.Rotation(rot_mats=...)``.
        trans: translations passed straight through to ``Rigid``.
    """
    return Rigid(rots=ru.Rotation(rot_mats=rots), trans=trans)
| |
|
| |
|
def batch_align_structures(pos_1, pos_2, mask=None):
    """Rigidly align each structure in ``pos_1`` onto its partner in ``pos_2``.

    The batched inputs are flattened into the sparse format (one batch index per
    position) expected by ``align_structures``, aligned there, and reshaped back
    to a batched layout.

    Args:
        pos_1: [B, N, 3] tensor of structures to rotate.
        pos_2: [B, N, 3] tensor of reference structures (same shape as pos_1).
        mask: optional tensor with B * N elements (e.g. [B, N]); non-zero
            entries select the positions used in the alignment. Defaults to all
            positions. Every structure must keep the same number of selected
            positions, because the outputs are reshaped to [B, M, 3].

    Returns:
        Tuple ``(aligned_pos_1, aligned_pos_2, align_rots)`` in the order
        returned by ``align_structures``; the two position tensors are reshaped
        to [B, M, 3], where M is the per-structure number of selected positions.

    Raises:
        ValueError: if the inputs differ in shape, are not 3-D, or the mask
            selects a different number of positions in different structures.
    """
    if pos_1.shape != pos_2.shape:
        raise ValueError('pos_1 and pos_2 must have the same shape.')
    if pos_1.ndim != 3:
        raise ValueError(
            f'Expected inputs to have shape [B, N, 3], got {tuple(pos_1.shape)}')
    num_batch, num_res = pos_1.shape[:2]
    device = pos_1.device

    # Batch index of every position: structure b -> b, repeated over its N rows.
    batch_indices = (
        torch.ones(num_batch, num_res, device=device, dtype=torch.int64)
        * torch.arange(num_batch, device=device)[:, None]
    )
    flat_pos_1 = pos_1.reshape(-1, 3)
    flat_pos_2 = pos_2.reshape(-1, 3)
    flat_batch_indices = batch_indices.reshape(-1)

    if mask is None:
        flat_mask = torch.ones(num_batch * num_res, dtype=torch.bool, device=device)
    else:
        flat_mask = mask.reshape(-1).bool()

    # The final reshape to [B, M, 3] is only meaningful when every structure
    # keeps the same number of positions; a ragged mask would otherwise fail
    # inside reshape or silently mix atoms of different structures.
    counts = flat_mask.reshape(num_batch, num_res).sum(dim=1)
    if counts.numel() > 0 and bool((counts != counts[0]).any()):
        raise ValueError(
            'mask must select the same number of positions in every structure, '
            f'got {counts.tolist()}')

    aligned_pos_1, aligned_pos_2, align_rots = align_structures(
        flat_pos_1[flat_mask], flat_batch_indices[flat_mask], flat_pos_2[flat_mask])
    aligned_pos_1 = aligned_pos_1.reshape(num_batch, -1, 3)
    aligned_pos_2 = aligned_pos_2.reshape(num_batch, -1, 3)
    return aligned_pos_1, aligned_pos_2, align_rots
| |
|
| |
|
| |
|
def adjust_oxygen_pos(
    atom_37: torch.Tensor, pos_is_known = None
) -> torch.Tensor:
    """Impute backbone carbonyl-oxygen positions from neighbouring backbone atoms.

    For a residue whose successor is known, the oxygen is placed
    ``c_o_bond_length`` (1.23) Angstrom from its carbonyl C, along the bisector
    of the C<-CA and C<-N(next) unit vectors, i.e. pointing away from the
    CA-C-N(next) triangle.

    When the next residue is unavailable (the C-terminus, or the next entry of
    ``pos_is_known`` is zero), the oxygen is instead placed the same distance
    from C along the normalized sum of the CA->C and CA->N unit vectors of the
    residue itself.

    Note: ``atom_37`` is modified in place (slot 4 is overwritten) and the same
    tensor is returned.

    Args:
        atom_37 (torch.Tensor): (N, 37, 3) positions in atom_37 ordering,
            which starts ['N', 'CA', 'C', 'CB', 'O', ...].
        pos_is_known (torch.Tensor): (N,) mask of known residues; defaults to
            all residues known.

    Returns:
        torch.Tensor: ``atom_37`` with imputed oxygen positions.
    """
    eps = 1e-7
    c_o_bond_length = 1.23

    def _unit(vec):
        # Row-wise normalisation; eps guards against zero-length vectors.
        return vec / (torch.norm(vec, keepdim=True, dim=1) + eps)

    num_res = atom_37.shape[0]
    assert atom_37.shape == (num_res, 37, 3)

    # Views onto the N / CA / C slots; only slot 4 (O) is ever written below.
    n_xyz = atom_37[:, 0, :]
    ca_xyz = atom_37[:, 1, :]
    c_xyz = atom_37[:, 2, :]

    # Residues 0..N-2: bisector of CA(i)->C(i) and N(i+1)->C(i).
    interior_dir = _unit(
        _unit(c_xyz[:-1] - ca_xyz[:-1]) + _unit(c_xyz[:-1] - n_xyz[1:])
    )
    atom_37[:-1, 4, :] = c_xyz[:-1] + interior_dir * c_o_bond_length

    # Fallback direction using only the residue's own N, CA and C.
    fallback_dir = _unit(_unit(c_xyz - ca_xyz) + _unit(n_xyz - ca_xyz))

    if pos_is_known is None:
        pos_is_known = torch.ones((num_res,), dtype=torch.int64, device=atom_37.device)

    # next_missing[i] is True when residue i+1 is unknown; the last residue
    # always qualifies because a True sentinel is appended before shifting.
    next_missing = torch.cat(
        [~pos_is_known.bool(), torch.ones((1,), device=pos_is_known.device).bool()],
        dim=0,
    )[1:]

    atom_37[next_missing, 4, :] = (
        c_xyz[next_missing, :]
        + fallback_dir[next_missing, :] * c_o_bond_length
    )

    return atom_37
| |
|
| |
|
def write_pkl(
        save_path: str, pkl_data: Any, create_dir: bool = False, use_torch=False):
    """Serialize data into a pickle file.

    Args:
        save_path: destination file path.
        pkl_data: object to serialize.
        create_dir: if True, create the parent directory of ``save_path``
            (including intermediate directories) when it does not exist.
        use_torch: if True, write with ``torch.save``; otherwise with
            ``pickle.dump``. Both use ``pickle.HIGHEST_PROTOCOL``.
    """
    if create_dir:
        parent_dir = os.path.dirname(save_path)
        # A bare filename has no directory component, and os.makedirs('')
        # raises FileNotFoundError, so only create a non-empty parent.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    if use_torch:
        torch.save(pkl_data, save_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    else:
        with open(save_path, 'wb') as handle:
            pickle.dump(pkl_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
| |
|
| |
|
def read_pkl(read_path: str, verbose=True, use_torch=False, map_location=None):
    """Read data from a pickle file.

    Tries ``torch.load`` (when ``use_torch``) or ``pickle.load`` first. If that
    fails for any reason, retries with ``CPU_Unpickler``, which remaps torch
    storages to the CPU. If both attempts fail, optionally prints both errors
    and re-raises the first one.

    SECURITY: both loaders can execute arbitrary code embedded in the file;
    only call this on trusted files.
    """
    try:
        if not use_torch:
            with open(read_path, 'rb') as handle:
                return pickle.load(handle)
        return torch.load(read_path, map_location=map_location)
    except Exception as first_err:
        try:
            with open(read_path, 'rb') as handle:
                return CPU_Unpickler(handle).load()
        except Exception as second_err:
            if verbose:
                print(f'Failed to read {read_path}. First error: {first_err}\n Second error: {second_err}')
            raise first_err
| |
|
| |
|
def chain_str_to_int(chain_str: str):
    """Encode a chain identifier string as an integer.

    A single-character id maps directly through ``CHAIN_TO_INT``. A longer id
    maps to the sum, over character positions ``i``, of
    ``CHAIN_TO_INT[char] + i * len(ALPHANUMERIC)``; an empty string maps to 0.

    Note: this encoding is not injective. Permutations of the same characters
    collide (e.g. 'AB' and 'BA' give the same value).
    """
    if len(chain_str) == 1:
        return CHAIN_TO_INT[chain_str]
    stride = len(ALPHANUMERIC)
    return sum(
        CHAIN_TO_INT[char] + pos * stride
        for pos, char in enumerate(chain_str)
    )
| |
|
| |
|
def parse_chain_feats(chain_feats, scale_factor=1.):
    """Center a chain on its CA atoms, scale it, and add backbone features.

    ``chain_feats`` is modified in place and returned.

    Args:
        chain_feats: dict with at least 'atom_positions' (per-residue, per-atom
            xyz, i.e. [num_res, num_atoms, 3]) and 'atom_mask'
            ([num_res, num_atoms]).
        scale_factor: centered positions are divided by this value.

    Returns:
        The same dict with:
        - 'atom_positions': centered and scaled, with masked atoms zeroed.
        - 'bb_mask': the CA mask.
        - 'bb_positions': the processed CA positions.
    """
    ca_idx = residue_constants.atom_order['CA']
    chain_feats['bb_mask'] = chain_feats['atom_mask'][:, ca_idx]
    bb_pos = chain_feats['atom_positions'][:, ca_idx]
    # Weight by the CA mask so numerator and denominator both count only
    # residues with an observed CA; unmasked (e.g. padded) coordinates would
    # otherwise skew the centroid. Identical to an unweighted sum when masked
    # positions are already zero.
    bb_center = np.sum(bb_pos * chain_feats['bb_mask'][:, None], axis=0) / (
        np.sum(chain_feats['bb_mask']) + 1e-5)
    centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
    scaled_pos = centered_pos / scale_factor
    chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
    chain_feats['bb_positions'] = chain_feats['atom_positions'][:, ca_idx]
    return chain_feats
| |
|
| |
|
def concat_np_features(
        np_dicts: List[Dict[str, np.ndarray]], add_batch_dim: bool):
    """Concatenate a list of same-keyed feature dicts feature by feature.

    Args:
        np_dicts: list of dicts sharing the same keys, each mapping a feature
            name to a numpy array.
        add_batch_dim: if True, every array gets a new leading axis of size 1
            before concatenation, so the inputs are stacked; otherwise they are
            joined along their existing first axis.

    Returns:
        A ``collections.defaultdict(list)`` mapping each feature name (in
        first-seen order) to the concatenated array.
    """
    grouped = collections.defaultdict(list)
    for feats in np_dicts:
        for name, value in feats.items():
            grouped[name].append(value[None] if add_batch_dim else value)
    grouped.update(
        {name: np.concatenate(values, axis=0) for name, values in grouped.items()})
    return grouped
| |
|
| |
|
def center_zero(pos: torch.Tensor, batch_indexes: torch.LongTensor) -> torch.Tensor:
    """Translate each system in a sparse batch so its mean position is the origin.

    Args:
        pos: [N, 3] positions of all atoms across the batch (sparse format).
        batch_indexes: [N] index of the system each atom belongs to.

    Returns:
        [N, 3] positions with each system's mean subtracted from its atoms.
    """
    assert pos.dim() == 2 and pos.size(-1) == 3, "pos must have shape [N, 3]"
    # One mean per system index; each atom then subtracts its own system's row.
    system_means = scatter(pos, batch_indexes, dim=0, reduce="mean")
    return pos - system_means.index_select(0, batch_indexes)
| |
|
| |
|
@torch.no_grad()
def align_structures(
    batch_positions: torch.Tensor,
    batch_indices: torch.Tensor,
    reference_positions: torch.Tensor,
    broadcast_reference: bool = False,
):
    """
    Align structures in a ChemGraph batch to a reference, e.g. for RMSD computation. This uses the
    sparse formulation of pytorch geometric. If the ChemGraph is composed of a single system, then
    the reference can be given as a single structure and broadcasted. Returns the structure
    coordinates shifted to the geometric center and the batch structures rotated to match the
    reference structures. Uses the Kabsch algorithm (see e.g. [kabsch_align1]_). No permutation of
    atoms is carried out. Runs under ``torch.no_grad``, so outputs are not tracked by autograd.

    Args:
        batch_positions (Tensor): Batch of structures (e.g. from ChemGraph) which should be aligned
          to a reference, as [N, 3] rows in sparse batch format.
        batch_indices (Tensor): Index tensor mapping each node / atom in batch to the respective
          system (e.g. batch attribute of ChemGraph batch).
        reference_positions (Tensor): Reference structure. Can either be a batch of structures or a
          single structure. In the second case, broadcasting is possible if the input batch is
          composed exclusively of this structure.
        broadcast_reference (bool, optional): If reference batch contains only a single structure,
          broadcast this structure to match the ChemGraph batch. Defaults to False.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        - the centered batch positions rotated into the reference frame;
        - the centered reference positions;
        - the per-system rotation matrices, shape [num_systems, 3, 3] with
          num_systems = max(batch_indices) + 1.
        Positions are rotated as row vectors: ``p_rotated = p @ R``.

    Raises:
        ValueError: if the numbers of rows differ and ``broadcast_reference`` is False.

    References
    ----------
    .. [kabsch_align1] Lawrence, Bernal, Witzgall:
       A purely algebraic justification of the Kabsch-Umeyama algorithm.
       Journal of research of the National Institute of Standards and Technology, 124, 1. 2019.
    """
    # The reference needs one row per atom of the batch. If sizes differ, the
    # reference may be tiled once per system (presumably a single structure
    # with exactly as many atoms as each system -- TODO confirm at call sites).
    if batch_positions.shape[0] != reference_positions.shape[0]:
        if broadcast_reference:
            # Number of systems = largest batch index + 1.
            num_molecules = int(torch.max(batch_indices) + 1)
            reference_positions = reference_positions.repeat(num_molecules, 1)
        else:
            raise ValueError("Mismatch in batch dimensions.")

    # Remove each system's mean so the remaining fit is a pure rotation.
    batch_positions = center_zero(batch_positions, batch_indices)
    reference_positions = center_zero(reference_positions, batch_indices)

    # Per-system 3x3 cross-covariance:
    # cov[b, i, j] = sum over atoms of system b of reference[i] * batch[j].
    cov = scatter_add(
        batch_positions[:, None, :] * reference_positions[:, :, None], batch_indices, dim=0
    )

    # Batched SVD: cov = u @ diag(s) @ v_t.
    u, _, v_t = torch.linalg.svd(cov)

    u_t = u.transpose(1, 2)
    v = v_t.transpose(1, 2)

    # Guard against reflections: where det(v @ u_t) is negative, flip the last
    # row of u_t so the resulting matrix has determinant +1. (torch.sign
    # returns 0 for an exactly-zero determinant, i.e. a degenerate system.)
    sign_correction = torch.sign(torch.linalg.det(torch.bmm(v, u_t)))

    u_t[:, 2, :] = u_t[:, 2, :] * sign_correction[:, None]

    # R = v @ diag(1, 1, sign) @ u^T for every system.
    rotation_matrices = torch.bmm(v, u_t)

    # Cast to the input dtype (a no-op when they already agree).
    rotation_matrices = rotation_matrices.type(batch_positions.dtype)

    # Rotate each atom, treated as a [1, 3] row vector, by its system's matrix.
    batch_positions_rotated = torch.bmm(
        batch_positions[:, None, :],
        rotation_matrices[batch_indices],
    ).squeeze(1)

    return batch_positions_rotated, reference_positions, rotation_matrices
| |
|
| |
|
def parse_pdb_feats(
        pdb_name: str,
        pdb_path: str,
        scale_factor=1.,
        chain_id='A',
):
    """Parse per-chain features from a PDB file.

    Args:
        pdb_name: structure id passed to ``PDBParser.get_structure``.
        pdb_path: path to PDB file to read.
        scale_factor: factor by which centered atom positions are divided
            (applied by ``parse_chain_feats``).
        chain_id: chain(s) to parse.
            - A single chain id string returns that chain's feature dict.
            - A list or tuple of ids returns a dict keyed by chain id.
            - ``None`` parses every chain in the structure.

    Returns:
        Dict with CHAIN_FEATS features extracted from PDB with specified
        preprocessing (or a dict of such dicts keyed by chain id, see above).

    Raises:
        KeyError: if a requested chain id is not present in the structure.
        ValueError: if ``chain_id`` is of an unsupported type.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_name, pdb_path)
    struct_chains = {chain.id: chain for chain in structure.get_chains()}

    def _process_chain_id(cid):
        # Convert the Biopython chain, keep only CHAIN_FEATS, then center/scale.
        chain_dict = dataclasses.asdict(process_chain(struct_chains[cid], cid))
        feat_dict = {feat: chain_dict[feat] for feat in CHAIN_FEATS}
        return parse_chain_feats(feat_dict, scale_factor=scale_factor)

    if isinstance(chain_id, str):
        return _process_chain_id(chain_id)
    if isinstance(chain_id, (list, tuple)):
        return {cid: _process_chain_id(cid) for cid in chain_id}
    if chain_id is None:
        return {cid: _process_chain_id(cid) for cid in struct_chains}
    raise ValueError(f'Unrecognized chain list {chain_id}')
| |
|
def rigid_transform_3D(A, B, verbose=False):
    """Find the rigid transform (R, t) that best maps point set A onto B.

    Least-squares fit via SVD of the cross-covariance of the centered point
    sets (Kabsch). If the unconstrained solution is a reflection, it is
    corrected so that ``R`` is a proper rotation.

    Args:
        A: (N, 3) array of points to transform.
        B: (N, 3) array of target points, paired row-wise with ``A``.
        verbose: if True, print a message when a reflection is corrected.

    Returns:
        Tuple ``(aligned_A, R, t, reflection_detected)``:
        - ``aligned_A``: (N, 3) array of ``A`` transformed by ``R`` and ``t``.
        - ``R``: (3, 3) rotation.
        - ``t``: (3, 1) translation, in the column-vector convention
          ``b ~= R @ a + t``.
        - ``reflection_detected``: True if the initial solution had
          det(R) < 0 and was fixed.

    Raises:
        ValueError: if the transposed inputs are not 3xN (i.e. not 3 columns).
    """
    assert A.shape == B.shape
    # Work with 3xN column-point matrices.
    A = A.T
    B = B.T

    num_rows, num_cols = A.shape
    if num_rows != 3:
        raise ValueError(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")

    num_rows, num_cols = B.shape
    if num_rows != 3:
        raise ValueError(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")

    # Centroids as (3, 1) columns so they broadcast over all points.
    centroid_A = np.mean(A, axis=1).reshape(-1, 1)
    centroid_B = np.mean(B, axis=1).reshape(-1, 1)

    Am = A - centroid_A
    Bm = B - centroid_B

    # 3x3 cross-covariance of the centered point sets.
    H = Am @ Bm.T

    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # A negative determinant means the best orthogonal fit is a reflection;
    # flip the last singular direction to obtain a proper rotation.
    reflection_detected = False
    if np.linalg.det(R) < 0:
        if verbose:
            print("det(R) < 0, reflection detected!, correcting for it ...")
        Vt[2, :] *= -1
        R = Vt.T @ U.T
        reflection_detected = True

    t = -R @ centroid_A + centroid_B
    optimal_A = R @ A + t

    return optimal_A.T, R, t, reflection_detected
| |
|
def process_chain(chain: Chain, chain_id: str) -> Protein:
    """Convert a PDB chain object into a AlphaFold Protein instance.

    Forked from alphafold.common.protein.from_pdb_string

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

    Took out lines 94-97 which don't allow insertions in the PDB.
    Sabdab uses insertions for the chothia numbering so we need to allow them.

    Took out lines 110-112 since that would mess up CDR numbering.

    Args:
        chain: Instance of Biopython's chain class.
        chain_id: label stored once per residue in ``chain_index``.

    Returns:
        Protein object with protein features.
        - ``atom_positions`` has shape [num_res, atom_type_num, 3].
        - ``atom_mask`` and ``b_factors`` have shape [num_res, atom_type_num].
        - ``aatype`` and ``residue_index`` are integer arrays of length num_res.
        These shapes hold even for a chain with no residues.
    """
    num_atom_types = residue_constants.atom_type_num
    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    b_factors = []
    chain_ids = []
    for res in chain:
        res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
        restype_idx = residue_constants.restype_order.get(
            res_shortname, residue_constants.restype_num)
        pos = np.zeros((num_atom_types, 3))
        mask = np.zeros((num_atom_types,))
        res_b_factors = np.zeros((num_atom_types,))
        for atom in res:
            if atom.name not in residue_constants.atom_types:
                continue
            atom_idx = residue_constants.atom_order[atom.name]
            pos[atom_idx] = atom.coord
            mask[atom_idx] = 1.
            res_b_factors[atom_idx] = atom.bfactor
        aatype.append(restype_idx)
        atom_positions.append(pos)
        atom_mask.append(mask)
        # Only the sequence number is kept; insertion codes are dropped.
        residue_index.append(res.id[1])
        b_factors.append(res_b_factors)
        chain_ids.append(chain_id)

    # The explicit reshapes and integer dtypes keep the array ranks/dtypes
    # consistent for an empty chain, where np.array([]) would be a 1-D float
    # array; for non-empty chains they are no-ops.
    return Protein(
        atom_positions=np.array(atom_positions).reshape(-1, num_atom_types, 3),
        atom_mask=np.array(atom_mask).reshape(-1, num_atom_types),
        aatype=np.array(aatype, dtype=int),
        residue_index=np.array(residue_index, dtype=int),
        chain_index=np.array(chain_ids),
        b_factors=np.array(b_factors).reshape(-1, num_atom_types))
| |
|