import os
from typing import Optional, Any

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from utils import register as R
from utils.const import sidechain_atoms
from data.converter.list_blocks_to_pdb import list_blocks_to_pdb

from .format import VOCAB, Block, Atom
from .mmap_dataset import MMAPDataset
from .resample import ClusterResampler


def calculate_covariance_matrix(point_cloud):
    # Calculate the covariance matrix of the point cloud
    covariance_matrix = np.cov(point_cloud, rowvar=False)
    return covariance_matrix
@R.register('CoDesignDataset')
class CoDesignDataset(MMAPDataset):

    MAX_N_ATOM = 14  # 4 backbone atoms (N, CA, C, O) + at most 10 sidechain heavy atoms

    def __init__(
        self,
        mmap_dir: str,
        backbone_only: bool,  # only backbone (N, CA, C, O) or full-atom
        specify_data: Optional[str] = None,
        specify_index: Optional[str] = None,
        padding_collate: bool = False,
        cluster: Optional[str] = None,
        use_covariance_matrix: bool = False
    ) -> None:
        super().__init__(mmap_dir, specify_data, specify_index)
        self.mmap_dir = mmap_dir
        self.backbone_only = backbone_only
        self._lengths = [len(prop[-1].split(',')) + int(prop[1]) for prop in self._properties]
        self.padding_collate = padding_collate
        self.resampler = ClusterResampler(cluster) if cluster else None  # should only be used in training!
        self.use_covariance_matrix = use_covariance_matrix
        self.dynamic_idxs = list(range(len(self)))
        self.update_epoch()  # should be called every epoch

    def update_epoch(self):
        # resample entry indexes (cluster-based rebalancing) at the start of each epoch
        if self.resampler is not None:
            self.dynamic_idxs = self.resampler(len(self))

    def get_len(self, idx):
        return self._lengths[self.dynamic_idxs[idx]]

    def get_summary(self, idx: int):
        props = self._properties[idx]
        _id = self._indexes[idx][0].split('.')[0]
        ref_pdb = os.path.join(self.mmap_dir, '..', 'pdbs', _id + '.pdb')
        rec_chain, lig_chain = props[4], props[5]
        return _id, ref_pdb, rec_chain, lig_chain
    def __getitem__(self, idx: int):
        idx = self.dynamic_idxs[idx]
        rec_blocks, lig_blocks = super().__getitem__(idx)
        # receptor, (lig_chain_id, lig_blocks) = super().__getitem__(idx)
        # pocket = {}
        # for i in self._properties[idx][-1].split(','):
        #     chain, i = i.split(':')
        #     if chain not in pocket:
        #         pocket[chain] = []
        #     pocket[chain].append(int(i))
        # rec_blocks = []
        # for chain_id, blocks in receptor:
        #     for i in pocket[chain_id]:
        #         rec_blocks.append(blocks[i])

        # select pocket residues from the receptor, keeping their original (1-based) positions
        pocket_idx = [int(i) for i in self._properties[idx][-1].split(',')]
        rec_position_ids = [i + 1 for i, _ in enumerate(rec_blocks)]
        rec_blocks = [rec_blocks[i] for i in pocket_idx]
        rec_position_ids = [rec_position_ids[i] for i in pocket_idx]

        rec_blocks = [Block.from_tuple(tup) for tup in rec_blocks]
        lig_blocks = [Block.from_tuple(tup) for tup in lig_blocks]
        # for block in lig_blocks:
        #     block.units = [Atom('CA', [0, 0, 0], 'C')]
        # if idx == 0:
        #     print(self._properties[idx])
        #     print(''.join(VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks))
        #     list_blocks_to_pdb([
        #         rec_blocks, lig_blocks
        #     ], ['B', 'A'], 'pocket.pdb')

        mask = [0 for _ in rec_blocks] + [1 for _ in lig_blocks]  # 1 for ligand (generation) residues
        position_ids = rec_position_ids + [i + 1 for i, _ in enumerate(lig_blocks)]

        # build per-residue coordinates padded to MAX_N_ATOM; missing atoms are
        # placed at the residue centroid and masked out in atom_mask
        X, S, atom_mask = [], [], []
        for block in rec_blocks + lig_blocks:
            symbol = VOCAB.abrv_to_symbol(block.abrv)
            atom2coord = { unit.name: unit.get_coord() for unit in block.units }
            bb_pos = np.mean(list(atom2coord.values()), axis=0).tolist()
            coords, coord_mask = [], []
            for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(symbol, []):
                if atom_name in atom2coord:
                    coords.append(atom2coord[atom_name])
                    coord_mask.append(1)
                else:
                    coords.append(bb_pos)
                    coord_mask.append(0)
            n_pad = self.MAX_N_ATOM - len(coords)
            for _ in range(n_pad):
                coords.append(bb_pos)
                coord_mask.append(0)
            X.append(coords)
            S.append(VOCAB.symbol_to_idx(symbol))
            atom_mask.append(coord_mask)

        X, atom_mask = torch.tensor(X, dtype=torch.float), torch.tensor(atom_mask, dtype=torch.bool)
        mask = torch.tensor(mask, dtype=torch.bool)
        if self.backbone_only:
            X, atom_mask = X[:, :4], atom_mask[:, :4]

        if self.use_covariance_matrix:
            # only use the receptor CA coordinates to derive the affine transformation
            cov = calculate_covariance_matrix(X[~mask][:, 1][atom_mask[~mask][:, 1]].numpy())
            eps = 1e-4
            cov = cov + eps * np.identity(cov.shape[0])  # regularize before Cholesky
            L = torch.from_numpy(np.linalg.cholesky(cov)).float().unsqueeze(0)
        else:
            L = None

        item = {
            'X': X,  # [N, 14, 3] or [N, 4, 3] if backbone_only == True
            'S': torch.tensor(S, dtype=torch.long),  # [N]
            'position_ids': torch.tensor(position_ids, dtype=torch.long),  # [N]
            'mask': mask,            # [N], 1 for generation
            'atom_mask': atom_mask,  # [N, 14] or [N, 4], 1 for having records in the PDB
            'lengths': len(S),
        }
        if L is not None:
            item['L'] = L
        return item
    def collate_fn(self, batch):
        if self.padding_collate:
            # pad each per-sample tensor to the max length in the batch: [batch_size, max_N, ...]
            results = {}
            pad_idx = VOCAB.symbol_to_idx(VOCAB.PAD)
            for key in batch[0]:
                values = [item[key] for item in batch]
                if values[0] is None:
                    results[key] = None
                    continue
                if key == 'lengths':
                    results[key] = torch.tensor(values, dtype=torch.long)
                elif key == 'S':
                    results[key] = pad_sequence(values, batch_first=True, padding_value=pad_idx)
                else:
                    results[key] = pad_sequence(values, batch_first=True, padding_value=0)
            return results
        else:
            # concatenate along the residue dimension: [sum_N, ...], with 'lengths' keeping per-sample sizes
            results = {}
            for key in batch[0]:
                values = [item[key] for item in batch]
                if values[0] is None:
                    results[key] = None
                    continue
                if key == 'lengths':
                    results[key] = torch.tensor(values, dtype=torch.long)
                else:
                    results[key] = torch.cat(values, dim=0)
            return results
@R.register('ShapeDataset')
class ShapeDataset(CoDesignDataset):

    def __init__(
        self,
        mmap_dir: str,
        specify_data: Optional[str] = None,
        specify_index: Optional[str] = None,
        padding_collate: bool = False,
        cluster: Optional[str] = None
    ) -> None:
        super().__init__(mmap_dir, False, specify_data, specify_index, padding_collate, cluster)
        self.ca_idx = VOCAB.backbone_atoms.index('CA')

    def __getitem__(self, idx: int):
        item = super().__getitem__(idx)
        # refine coordinates to CA and the atom furthest from CA
        X = item['X']  # [N, 14, 3]
        atom_mask = item['atom_mask']
        ca_x = X[:, self.ca_idx].unsqueeze(1)  # [N, 1, 3]
        sc_x = X[:, 4:]  # [N, 10, 3], sidechain atom indexes
        dist = torch.norm(sc_x - ca_x, dim=-1)  # [N, 10]
        # exclude missing (padded) atoms so they are never selected as the furthest atom
        dist = dist.masked_fill(~atom_mask[:, 4:], -1e10)
        furthest_atom_x = sc_x[torch.arange(sc_x.shape[0]), torch.argmax(dist, dim=-1)]  # [N, 3]
        X = torch.cat([ca_x, furthest_atom_x.unsqueeze(1)], dim=1)
        item['X'] = X
        return item
if __name__ == '__main__':
    import sys
    dataset = CoDesignDataset(sys.argv[1], backbone_only=True)
    print(dataset[0])
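
    # Minimal batching sketch (not part of the original script): collate the first
    # two items with the dataset's own collate_fn; assumes the mmap directory
    # passed as sys.argv[1] contains at least two entries.
    if len(dataset) > 1:
        batch = dataset.collate_fn([dataset[0], dataset[1]])
        for key, value in batch.items():
            print(key, tuple(value.shape) if torch.is_tensor(value) else value)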