File size: 6,678 Bytes
52007f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
#!/usr/bin/python
# -*- coding:utf-8 -*-
from copy import copy
from typing import Any, Iterator, List, Optional, Tuple

from utils import const


class MoleculeVocab:
    """Block-level and atom-level vocabularies.

    Block level: four special tokens followed by the ``(symbol, abbreviation)``
    pairs in ``const.aas`` (the small-molecule entries are currently disabled).
    Atom level: three special tokens followed by ``const.periodic_table``, plus a
    fixed list of atom position codes.
    """

    # NOTE(review): presumably the maximum number of atoms per block — confirm.
    MAX_ATOM_NUMBER = 14

    def __init__(self):
        self.backbone_atoms = ['N', 'CA', 'C', 'O']
        self.PAD, self.MASK, self.UNK, self.LAT = '#', '*', '?', '&'  # pad / mask / unk / latent node
        specials = [  # special added
            (self.PAD, 'PAD'), (self.MASK, 'MASK'), (self.UNK, 'UNK'),  # pad / mask / unk
            (self.LAT, '<L>'),  # latent node in latent space
        ]

        aas = const.aas

        # Small-molecule vocabulary is disabled; to re-enable it use
        # sms = [(e.lower(), e) for e in const.periodic_table]
        sms = []

        # Special atom tokens are lowercase so they cannot clash with element symbols (e.g. 'P').
        self.atom_pad, self.atom_mask, self.atom_latent = 'pad', 'msk', 'lat'
        self.atom_pos_pad, self.atom_pos_mask, self.atom_pos_latent = 'pad', 'msk', 'lat'
        self.atom_pos_sm = 'sml'  # small molecule

        # block level vocab
        self.idx2block = specials + aas + sms
        self.symbol2idx = {symbol: i for i, (symbol, _) in enumerate(self.idx2block)}
        self.abrv2idx = {abrv: i for i, (_, abrv) in enumerate(self.idx2block)}
        # 1 marks a special token, 0 a real block; aligned with idx2block.
        self.special_mask = [1 for _ in specials] + [0 for _ in aas] + [0 for _ in sms]

        # atom level vocab
        self.idx2atom = [self.atom_pad, self.atom_mask, self.atom_latent] + const.periodic_table
        # 'sml' is for atoms in small molecules, 'P' for O1P, O2P, O3P
        self.idx2atom_pos = [self.atom_pos_pad, self.atom_pos_mask, self.atom_pos_latent, '', 'A', 'B', 'G', 'D', 'E', 'Z', 'H', 'XT', 'P', self.atom_pos_sm]
        self.atom2idx = {atom: i for i, atom in enumerate(self.idx2atom)}
        self.atom_pos2idx = {atom_pos: i for i, atom_pos in enumerate(self.idx2atom_pos)}

    # block level APIs

    def abrv_to_symbol(self, abrv):
        """Map an abbreviation (case-insensitive) to its symbol; unknown -> UNK symbol."""
        # abrv_to_idx never returns None (it falls back to UNK), so no None check is needed.
        return self.idx2block[self.abrv_to_idx(abrv)][0]

    def symbol_to_abrv(self, symbol):
        """Map a symbol (case-sensitive) to its abbreviation; unknown -> 'UNK'."""
        return self.idx2block[self.symbol_to_idx(symbol)][1]

    def abrv_to_idx(self, abrv):
        """Return the block index of an abbreviation (upper-cased first); unknown -> UNK index."""
        abrv = abrv.upper()
        return self.abrv2idx.get(abrv, self.abrv2idx['UNK'])

    def symbol_to_idx(self, symbol):
        """Return the block index of a symbol (case-sensitive lookup); unknown -> UNK index."""
        return self.symbol2idx.get(symbol, self.abrv2idx['UNK'])

    def idx_to_symbol(self, idx):
        return self.idx2block[idx][0]

    def idx_to_abrv(self, idx):
        return self.idx2block[idx][1]

    def get_pad_idx(self):
        return self.symbol_to_idx(self.PAD)

    def get_mask_idx(self):
        return self.symbol_to_idx(self.MASK)

    def get_special_mask(self):
        """Return a copy of the special-token mask so callers cannot mutate the vocab."""
        return copy(self.special_mask)

    # atom level APIs

    def get_atom_pad_idx(self):
        return self.atom2idx[self.atom_pad]

    def get_atom_mask_idx(self):
        return self.atom2idx[self.atom_mask]

    def get_atom_latent_idx(self):
        return self.atom2idx[self.atom_latent]

    def get_atom_pos_pad_idx(self):
        return self.atom_pos2idx[self.atom_pos_pad]

    def get_atom_pos_mask_idx(self):
        return self.atom_pos2idx[self.atom_pos_mask]

    def get_atom_pos_latent_idx(self):
        return self.atom_pos2idx[self.atom_pos_latent]

    def idx_to_atom(self, idx):
        return self.idx2atom[idx]

    def atom_to_idx(self, atom):
        """Return the atom-type index of ``atom``; unknown atoms map to the mask index.

        An exact match is tried first so the lowercase special tokens
        ('pad', 'msk', 'lat') resolve to their own indices. Otherwise the
        upper-cased name is looked up, as before.
        """
        idx = self.atom2idx.get(atom)
        if idx is None:
            idx = self.atom2idx.get(atom.upper(), self.atom2idx[self.atom_mask])
        return idx

    def idx_to_atom_pos(self, idx):
        return self.idx2atom_pos[idx]

    def atom_pos_to_idx(self, atom_pos):
        """Return the index of a position code (case-sensitive); unknown -> mask index."""
        return self.atom_pos2idx.get(atom_pos, self.atom_pos2idx[self.atom_pos_mask])

    # sizes

    def get_num_atom_type(self):
        return len(self.idx2atom)

    def get_num_atom_pos(self):
        return len(self.idx2atom_pos)

    def get_num_block_type(self):
        """Number of non-special block types."""
        return len(self.special_mask) - sum(self.special_mask)

    def __len__(self):
        """Number of distinct block symbols."""
        return len(self.symbol2idx)

    # others
    @property
    def ca_channel_idx(self):
        """Position of 'CA' within ``backbone_atoms``."""
        return self.backbone_atoms.index('CA')


# Shared vocabulary instance, constructed once when this module is imported.
VOCAB: MoleculeVocab = MoleculeVocab()


class Atom:
    """A single atom: name, coordinate, element symbol and position code."""

    def __init__(self, atom_name: str, coordinate: List[float], element: str, pos_code: Optional[str]=None):
        """
        Args:
            atom_name: atom name, e.g. 'CA'.
            coordinate: coordinate values; stored as given (not copied).
            element: element symbol, e.g. 'C'.
            pos_code: position code. When None it is derived by removing
                ``element`` from the front of ``atom_name`` ('CA' + 'C' -> 'A').
                If ``atom_name`` does not start with ``element``, the whole
                name is used.
        """
        self.name = atom_name
        self.coordinate = coordinate
        self.element = element
        if pos_code is None:
            # Remove the element as a literal prefix. str.lstrip(element) would strip
            # a *set* of characters, turning e.g. 'HH2' (element 'H') into '2'
            # instead of 'H2'.
            if element and atom_name.startswith(element):
                pos_code = atom_name[len(element):]
            else:
                pos_code = atom_name
        self.pos_code = pos_code

    def get_element(self):
        return self.element

    def get_coord(self):
        """Return a shallow copy of the coordinate so callers cannot mutate the atom."""
        return copy(self.coordinate)

    def get_pos_code(self):
        return self.pos_code

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"Atom ({self.name}): {self.element}({self.pos_code}) [{','.join(['{:.4f}'.format(num) for num in self.coordinate])}]"

    def to_tuple(self):
        """Serialize to ``(name, coordinate, element, pos_code)``; inverse of ``from_tuple``."""
        return (
            self.name,
            self.coordinate,
            self.element,
            self.pos_code
        )

    @classmethod
    def from_tuple(cls, data):
        """Rebuild an atom from the tuple produced by ``to_tuple``."""
        return cls(
            atom_name=data[0],
            coordinate=data[1],
            element=data[2],
            pos_code=data[3]
        )


class Block:
    """An ordered group of atoms (units) with an abbreviation and an optional id.

    Units are looked up by atom name. If names repeat, the last unit with a
    given name is the one found.
    """

    def __init__(self, abrv: str, units: List[Atom], id: Optional[Any]=None) -> None:
        self.abrv: str = abrv
        self.units: List[Atom] = units
        # atom name -> index into ``units`` (later duplicates overwrite earlier ones)
        self._uname2idx = {unit.name: i for i, unit in enumerate(self.units)}
        self.id = id

    def __len__(self) -> int:
        return len(self.units)

    def __iter__(self) -> Iterator[Atom]:
        return iter(self.units)

    def get_unit_by_name(self, name: str) -> Atom:
        """Return the unit called ``name``; raises KeyError if there is none."""
        return self.units[self._uname2idx[name]]

    def has_unit(self, name: str) -> bool:
        return name in self._uname2idx

    def to_tuple(self):
        """Serialize to ``(abrv, [atom tuples], id)``; inverse of ``from_tuple``."""
        return (
            self.abrv,
            [unit.to_tuple() for unit in self.units],
            self.id
        )

    def is_residue(self):
        """True when the block contains all of the atoms N, CA, C and O."""
        return all(self.has_unit(name) for name in ('N', 'CA', 'C', 'O'))

    @classmethod
    def from_tuple(cls, data):
        """Rebuild a block from the tuple produced by ``to_tuple``."""
        return cls(
            abrv=data[0],
            units=[Atom.from_tuple(unit_data) for unit_data in data[1]],
            id=data[2]
        )

    def __repr__(self) -> str:
        return f"Block ({self.abrv}):\n\t" + '\n\t'.join([repr(at) for at in self.units]) + '\n'