|
|
|
|
|
import numpy as np |
|
|
from collections import defaultdict |
|
|
import prody as pr |
|
|
import os |
|
|
|
|
|
from datasets.constants import chi, atom_order, aa_long2short, aa_short2aa_idx, aa_idx2aa_short |
|
|
|
|
|
|
|
|
def get_dihedral_indices(resname, chi_num):
    """Look up where the atoms of one chi dihedral sit in a residue's atom list.

    NOTE(review): a one-argument function with this same name is defined
    later in this file and rebinds the name once the module has finished
    loading; this two-argument version is only reachable by the module-level
    table construction that runs directly after it.

    Parameters
    ----------
    resname
        Residue key used to index both ``chi`` and ``atom_order``.
    chi_num
        Key into ``chi[resname]`` selecting the dihedral.

    Returns
    -------
    numpy.ndarray
        Length-4 array. Holds the positions, within ``atom_order[resname]``,
        of the atoms named in ``chi[resname][chi_num]``; holds four NaNs when
        the residue or the requested dihedral has no entry in ``chi``.
    """
    # Either lookup missing means the dihedral is undefined for this residue.
    if resname not in chi or chi_num not in chi[resname]:
        return np.full(4, np.nan)

    return np.array(
        [atom_order[resname].index(atom_name) for atom_name in chi[resname][chi_num]]
    )
|
|
|
|
|
|
|
|
# Lookup table built once at import time: residue key -> array with one row
# per chi angle (1 through 4), each row being the four atom positions returned
# by the two-argument get_dihedral_indices (NaN rows where a chi is undefined).
# This must run before the later one-argument get_dihedral_indices definition
# rebinds the name.
dihedral_indices = defaultdict(list)
for aa in atom_order:
    dihedral_indices[aa] = np.array(
        [get_dihedral_indices(aa, chi_number) for chi_number in range(1, 5)]
    )
|
|
|
|
|
|
|
|
def vector_batch(a, b):
    """Return the element-wise difference ``a - b``.

    Used to form bond vectors from two equally shaped batches of points;
    the operands are combined with the plain ``-`` operator, so normal
    broadcasting rules of the operand types apply.
    """
    difference = a - b
    return difference
|
|
|
|
|
|
|
|
def unit_vector_batch(v):
    """Scale every row of ``v`` to unit Euclidean length.

    Parameters
    ----------
    v : numpy.ndarray
        Array whose axis 1 holds the vector components.

    Returns
    -------
    numpy.ndarray
        ``v`` divided by its per-row norm (norm taken along axis 1 and kept
        as a length-1 axis so the division broadcasts). Zero-length rows are
        not special-cased.
    """
    lengths = np.linalg.norm(v, axis=1, keepdims=True)
    return v / lengths
|
|
|
|
|
|
|
|
def dihedral_angle_batch(p):
    """Compute one dihedral angle per row of four points.

    Parameters
    ----------
    p : numpy.ndarray
        Batch of point quadruples indexed as ``p[:, k]`` for ``k`` in 0..3;
        the caller in this file passes shape ``(M, 4, 3)``.

    Returns
    -------
    numpy.ndarray
        One angle per row, in degrees, shifted into ``[0, 360)``. Rows that
        contain NaN coordinates yield NaN.
    """
    first, second, third, fourth = p[:, 0], p[:, 1], p[:, 2], p[:, 3]

    # Consecutive bond vectors along the four-point chain.
    bond_a = vector_batch(first, second)
    bond_b = vector_batch(second, third)
    bond_c = vector_batch(third, fourth)

    # Normals of the two planes that share the central bond.
    normal_ab = np.cross(bond_a, bond_b)
    normal_bc = np.cross(bond_b, bond_c)

    # In-plane axis perpendicular to both the first normal and the central
    # bond; together with normal_ab it forms the frame for arctan2.
    central_unit = bond_b / np.linalg.norm(bond_b, axis=1, keepdims=True)
    in_plane = np.cross(normal_ab, central_unit)

    cos_like = np.sum(normal_ab * normal_bc, axis=1)
    sin_like = np.sum(in_plane * normal_bc, axis=1)

    angle = np.degrees(np.arctan2(sin_like, cos_like))

    # arctan2 yields (-180, 180]; move the negative half onto (180, 360).
    return np.where(angle < 0, angle + 360, angle)
|
|
|
|
|
|
|
|
def batch_compute_dihedral_angles(sidechains):
    """Convert ``sidechains`` to a numpy array and return its dihedral angles.

    Thin wrapper: the input is copied into an array with ``np.array`` and
    handed to ``dihedral_angle_batch``, whose result is returned unchanged.
    """
    return dihedral_angle_batch(np.array(sidechains))
|
|
|
|
|
|
|
|
def get_coords(prody_pdb):
    """Gather per-residue atom coordinates into a NaN-padded array.

    Parameters
    ----------
    prody_pdb
        Object exposing ``.ca.getResindices()`` and ``.select(...)``
        (presumably a ProDy AtomGroup or Selection -- confirm with callers).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(N, 14, 3)`` where ``N`` is the number of distinct
        residue indices among the CA atoms, visited in ascending order.
        Slot ``j`` of a residue holds the first coordinate of the atom named
        ``atom_order[<short name>][j]``; slots whose atom is not found stay
        NaN.
    """
    residue_ids = sorted(set(prody_pdb.ca.getResindices()))
    coords = np.full((len(residue_ids), 14, 3), np.nan)

    for row, residue_id in enumerate(residue_ids):
        residue = prody_pdb.select(f'resindex {residue_id}')
        long_name = residue.getResnames()[0]
        # Residue names without a short-code mapping use the 'X' entry of
        # atom_order.
        short_name = aa_long2short[long_name] if long_name in aa_long2short else 'X'

        for slot, atom_name in enumerate(atom_order[short_name]):
            atom = residue.select(f'name {atom_name}')
            if atom is None:
                # Atom absent: the slot keeps its NaN fill value.
                continue
            coords[row, slot, :] = atom.getCoords()[0]

    return coords
|
|
|
|
|
|
|
|
def get_onehot_sequence(seq):
    """One-hot encode a residue sequence into a ``(len(seq), 20)`` array.

    Parameters
    ----------
    seq
        Sized iterable of residue codes that are looked up in
        ``aa_short2aa_idx``.

    Returns
    -------
    numpy.ndarray
        Float array of zeros with a single 1 per row at the column given by
        ``aa_short2aa_idx``. Codes missing from that mapping are placed in
        column 7 (NOTE(review): presumably the index of a chosen default
        residue -- confirm against ``datasets.constants``).
    """
    columns = np.array(
        [aa_short2aa_idx[aa] if aa in aa_short2aa_idx else 7 for aa in seq],
        dtype=int,
    )
    onehot = np.zeros((len(seq), 20))
    # Set one entry per row in a single fancy-indexed assignment.
    onehot[np.arange(len(seq)), columns] = 1
    return onehot
|
|
|
|
|
|
|
|
def get_dihedral_indices(onehot_sequence):
    """Stack the precomputed chi atom-index tables for a one-hot sequence.

    NOTE(review): this definition reuses the name of the earlier
    two-argument ``get_dihedral_indices(resname, chi_num)`` and replaces it
    from this point on; the module-level ``dihedral_indices`` table is built
    before that happens.

    Parameters
    ----------
    onehot_sequence : numpy.ndarray
        Array whose non-zero entries mark, along axis 1, the residue class
        of each row.

    Returns
    -------
    numpy.ndarray
        ``dihedral_indices[...]`` entries, one per non-zero element of
        ``onehot_sequence``, stacked in the order ``np.where`` reports them.
    """
    residue_classes = np.nonzero(onehot_sequence)[1]
    tables = [dihedral_indices[aa_idx2aa_short[class_idx]] for class_idx in residue_classes]
    return np.array(tables)
|
|
|
|
|
|
|
|
def _get_chi_angles(coords, indices):
    """Gather the atoms of every chi dihedral and compute the angles.

    Parameters
    ----------
    coords : numpy.ndarray
        Per-residue atom coordinates, indexed as ``coords[residue, atom, :]``
        with three components per atom.
    indices : numpy.ndarray
        Per-residue atom positions for four dihedrals of four atoms each
        (``(N, 4, 4)``); NaN marks a dihedral that is not defined.

    Returns
    -------
    numpy.ndarray
        Shape ``(N, 4)`` array of angles from
        ``batch_compute_dihedral_angles``; entries whose index row contained
        NaN (or whose gathered coordinates are NaN) come out as NaN.
    """
    n_residues = coords.shape[0]

    # Use explicit shapes rather than ``-1`` so that a zero-residue input
    # (where the index array arrives with shape ``(0,)``) reshapes cleanly
    # instead of raising.
    lookup = np.asarray(indices, dtype=float).reshape(n_residues, 4, 4)
    undefined = np.isnan(lookup)

    # Replace NaN with a valid placeholder index *before* the integer cast:
    # converting NaN to int yields an unspecified value and emits a
    # RuntimeWarning on recent NumPy releases.
    atom_positions = np.where(undefined, 0, lookup).astype(int)

    # Fancy indexing returns a fresh array, so ``coords`` is left untouched
    # when the placeholder entries are blanked out below.
    gathered = coords[np.arange(n_residues)[:, None, None], atom_positions, :]
    gathered[undefined] = np.nan

    angles = batch_compute_dihedral_angles(gathered.reshape(n_residues * 4, 4, 3))
    return angles.reshape(n_residues, 4)
|
|
|
|
|
|
|
|
def get_chi_angles(coords, seq, return_onehot=False):
    """Compute the side-chain chi angles for a set of residues.

    Parameters
    ----------
    coords : numpy.ndarray
        Per-residue atom coordinates, expected in the ``(N, 14, 3)`` layout
        that ``get_coords`` produces.
    seq
        Residue codes, one per row of ``coords``, in the form accepted by
        ``get_onehot_sequence``.
    return_onehot : bool, optional
        When true, also return the one-hot encoding of ``seq``, by default
        False.

    Returns
    -------
    numpy array of shape (N, 4)
        Chi angles in row order of ``coords``. An angle that is not defined
        for a residue (no chi entry, or missing atoms) is NaN.
    tuple of (angles, onehot)
        Returned instead of the bare array when ``return_onehot`` is true.
    """
    onehot = get_onehot_sequence(seq)
    chi_atom_indices = get_dihedral_indices(onehot)
    chi_angles = _get_chi_angles(coords, chi_atom_indices)

    return (chi_angles, onehot) if return_onehot else chi_angles
|
|
|
|
|
|
|
|
def test_get_chi_angles(print_chi_angles=False):
    """Smoke-test ``get_chi_angles`` on chain A of PDB entry 6w70.

    Fetches the structure through ProDy, derives the coordinate array and
    the residue sequence from it, and checks the shape of the result and the
    first chi1 value. The downloaded ``6w70.pdb.gz`` is removed afterwards,
    whether or not the assertions pass.

    Parameters
    ----------
    print_chi_angles : bool, optional
        Print the full chi-angle array after the checks, by default False.

    Returns
    -------
    bool
        Always ``True`` when the assertions pass.
    """
    try:
        pdb = pr.parsePDB('6w70')
        prody_pdb = pdb.select('chain A')

        # get_chi_angles takes (coords, seq), not a ProDy object, so both
        # inputs are derived here first.
        coords = get_coords(prody_pdb)

        # One residue name per distinct residue index, in ascending index
        # order -- the same residue ordering get_coords uses for its rows.
        ca_atoms = prody_pdb.ca
        _, first_ca = np.unique(ca_atoms.getResindices(), return_index=True)
        seq = [
            aa_long2short[resname] if resname in aa_long2short else 'X'
            for resname in ca_atoms.getResnames()[first_ca]
        ]

        chi_angles = get_chi_angles(coords, seq)

        assert chi_angles.shape == (prody_pdb.ca.numAtoms(), 4)
        assert 55.0 < chi_angles[0, 0] < 56.0
        print('test_get_chi_angles passed')
    finally:
        # Best-effort cleanup of the downloaded file; a missing file is fine.
        try:
            os.remove('6w70.pdb.gz')
        except OSError:
            pass

    if print_chi_angles:
        print(chi_angles)

    return True
|
|
|
|
|
|
|
|
# Running the module as a script executes the smoke test and prints the
# computed chi angles.
if __name__ == "__main__":
    test_get_chi_angles(print_chi_angles=True)
|
|
|
|
|
|
|
|
|