#!/usr/bin/env python3
# -*- coding:utf-8 -*-
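"""
Preprocess a gzipped PDB database into monomers: each chain whose residues all
have a complete backbone (N, CA, C) and whose peptide bonds are unbroken is
stored as a separate entry in a memory-mapped dataset.

Usage (script name and paths are illustrative):
    python process_monomers.py --pdb_dir /path/to/pdb --out_dir ./processed
"""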
import os
import gzip
import shutil
import argparse

import numpy as np

from utils.logger import print_log
from utils.file_utils import get_filename, cnt_num_files
from data.format import VOCAB
from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
from data.converter.blocks_to_data import blocks_to_data
from data.mmap_dataset import create_mmap


def parse():
parser = argparse.ArgumentParser(description='Process PDB to monomers')
parser.add_argument('--pdb_dir', type=str, required=True,
help='Directory of pdb database')
parser.add_argument('--out_dir', type=str, required=True,
help='Output directory')
return parser.parse_args()


def process_iterator(data_dir):
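    # expects a two-level layout: <data_dir>/<category>/<entry>.gz
    # (e.g. the "divided" organization used by PDB mirrors)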
tmp_dir = './tmp'
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
file_cnt = 0
for category in os.listdir(data_dir):
category_dir = os.path.join(data_dir, category)
for pdb_file in os.listdir(category_dir):
file_cnt += 1
path = os.path.join(category_dir, pdb_file)
tmp_file = os.path.join(tmp_dir, f'{pdb_file}.decompressed')
try:
# uncompress the file to the tmp file
with gzip.open(path, 'rb') as fin:
with open(tmp_file, 'wb') as fout:
shutil.copyfileobj(fin, fout)
list_blocks, chains = pdb_to_list_blocks(tmp_file, return_chain_ids=True)
except Exception as e:
print_log(f'Parsing {pdb_file} failed: {e}', level='WARN')
continue
for blocks, chain in zip(list_blocks, chains):
                # keep residues with a complete backbone (N, CA, C); collect N/C
                # coordinates to detect chain breaks (the sequence runs from the N-terminus)
                filter_blocks, NC_coords = [], []
for block in blocks:
N_coord, C_coord, CA_coord = None, None, None
for atom in block:
if atom.name == 'N':
N_coord = atom.coordinate
elif atom.name == 'C':
C_coord = atom.coordinate
elif atom.name == 'CA':
CA_coord = atom.coordinate
                    if all(x is not None for x in (N_coord, C_coord, CA_coord)):
filter_blocks.append(block)
NC_coords.append(N_coord)
NC_coords.append(C_coord)
if len(filter_blocks) == 0: # no valid residues
continue
NC_coords = np.array(NC_coords)
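                # NC_coords alternates (N, C) per kept residue: even indices are N,
                # odd indices are C. The peptide bond links C of residue i to N of
                # residue i+1, so compare NC_coords[1::2][:-1] with NC_coords[2::2];
                # a C-N distance above 1.5 A (ideal peptide bond ~1.33 A) marks a break.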
pep_bond_len = np.linalg.norm(NC_coords[1::2][:-1] - NC_coords[2::2], axis=-1)
# broken = np.nonzero(pep_bond_len > 1.5)[0]
if np.any(pep_bond_len > 1.5):
continue
blocks = filter_blocks
item_id = chain + '_' + pdb_file
# data = blocks_to_data(blocks)
num_blocks = len(blocks)
num_units = sum([len(block.units) for block in blocks])
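                # serialize each residue block to a plain tuple and build the
                # one-letter amino-acid sequence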
data = [block.to_tuple() for block in blocks]
seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in blocks])
                # yield: id, serialized data, properties, and the running count of
                # processed source files (create_mmap uses it to track progress)
yield item_id, data, [num_blocks, num_units, chain, seq], file_cnt
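            # remove the decompressed copy of this entry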
if os.path.exists(tmp_file):
os.remove(tmp_file)
shutil.rmtree(tmp_dir)


def main(args):
cnt = cnt_num_files(args.pdb_dir, recursive=True)
print_log(f'Processing data from directory: {args.pdb_dir}.')
print_log(f'Number of entries: {cnt}')
create_mmap(
process_iterator(args.pdb_dir),
args.out_dir, cnt)
print_log('Finished!')


if __name__ == '__main__':
    main(parse())