# File size: 6,941 Bytes | 714cf46
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
import random
import requests
import tempfile
import sys
from Bio.PDB import PDBParser, PPBuilder
from esm2.modeling_fastesm import FastEsmModel
def download_random_pdb():
    """
    Download a randomly chosen example PDB entry into a temporary file.

    The entry is picked from a small built-in list of PDB IDs and fetched
    from files.rcsb.org. The temporary file is created with ``delete=False``,
    so the caller is responsible for removing it when done.

    Returns:
        str: Path to the downloaded PDB file.

    Raises:
        Exception: If the server answers with a status code other than 200.
        requests.exceptions.RequestException: If the request itself fails
            (for example a connection error or a timeout).
    """
    example_pdbs = ["1AKE"]
    # Select a random PDB ID
    pdb_id = random.choice(example_pdbs)
    print(f"Selected random PDB ID: {pdb_id}")

    # Download before creating the temporary file, so that a failed download
    # never leaves an orphaned empty file behind. The timeout prevents the
    # script from hanging forever on a stalled connection.
    url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        raise Exception(f"Failed to download PDB file: {response.status_code}")

    # delete=False: the file has to outlive this function.
    with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file:
        temp_file.write(response.content)
        temp_file_path = temp_file.name
    print(f"Downloaded PDB file to: {temp_file_path}")
    return temp_file_path
def parse_pdb(pdb_file):
    """
    Parse a PDB file and extract the protein sequence and CA atom coordinates.

    Only the first polypeptide returned by ``PPBuilder.build_peptides`` is
    used. The coordinate array always has one row per residue of the returned
    sequence: a residue that has no CA atom gets a row of NaN instead of being
    dropped, so row ``i`` of ``coords`` always belongs to ``sequence[i]``.

    Parameters:
        pdb_file (str): Path to the PDB file.

    Returns:
        tuple: (sequence (str), coords (np.ndarray of shape (L, 3)))

    Raises:
        ValueError: If no polypeptide is found, or if the first polypeptide
            contains no CA atom at all.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("protein", pdb_file)

    # Assume a single protein chain; take the first polypeptide found.
    first_peptide = next(iter(PPBuilder().build_peptides(structure)), None)
    if first_peptide is None:
        raise ValueError("No polypeptide chains were found in the PDB file.")

    sequence = str(first_peptide.get_sequence())

    # Placeholder row for residues without a CA atom. Keeping a row for them
    # preserves the residue <-> row alignment; distances involving NaN compare
    # as False against any threshold, i.e. "no contact".
    missing_ca = np.full(3, np.nan, dtype=np.float32)
    coords = []
    found_ca = False
    for residue in first_peptide:
        if 'CA' in residue:
            coords.append(residue['CA'].get_coord())
            found_ca = True
        else:
            coords.append(missing_ca)

    if not found_ca:
        raise ValueError("No CA atoms found in the polypeptide.")
    return sequence, np.array(coords)
def compute_distance_matrix(coords):
    """
    Compute the pairwise Euclidean distance matrix from a set of coordinates.

    Parameters:
        coords (array-like): Points of shape (L, D), e.g. (L, 3) for L
            residues. Any input accepted by ``np.asarray`` works (ndarray,
            nested lists, ...).

    Returns:
        np.ndarray: A matrix of shape (L, L) where entry (i, j) is the
        distance between point i and point j.

    Raises:
        ValueError: If ``coords`` is not two-dimensional.
    """
    # Accept plain Python sequences as well as arrays.
    points = np.asarray(coords)
    if points.ndim != 2:
        raise ValueError(
            f"coords must be a 2-D array of shape (L, D); got shape {points.shape}."
        )
    # Broadcasting (L, 1, D) against (1, L, D) yields every pairwise difference.
    diff = points[:, None, :] - points[None, :, :]
    return np.sqrt(np.sum(diff**2, axis=-1))
def get_esm_contact_map(sequence):
    """
    Use the ESM model to predict a contact map for the given protein sequence.

    The "Synthyra/ESM2-650M" checkpoint is loaded on every call and run on
    CUDA when available, otherwise on the CPU. The tensor shape is printed
    before and after the batch axis is removed.

    Parameters:
        sequence (str): Amino acid sequence.

    Returns:
        np.ndarray: A 2D array (L x L) with contact probabilities.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = "Synthyra/ESM2-650M"
    model = FastEsmModel.from_pretrained(model_path).eval().to(device)
    tokenizer = model.tokenizer

    inputs = tokenizer(sequence, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Inference only: no gradients needed.
    with torch.no_grad():
        contact_map = model.predict_contacts(inputs["input_ids"], inputs["attention_mask"])
    print(contact_map.shape)

    # Drop only the leading batch axis. A bare squeeze() would also remove
    # residue axes of size 1 and return a 0-d array for very short inputs.
    # NOTE(review): assumes predict_contacts returns (batch, L, L) - confirm.
    contact_map = contact_map.squeeze(0).cpu().numpy()
    print(contact_map.shape)
    return contact_map
def plot_maps(true_contact_map, predicted_contact_map, pdb_file, threshold=8.0):
    """
    Generate two subplots and save them as a PNG in the working directory:
    1. ESM predicted contact map.
    2. True contact map from the PDB (binary, thresholded).

    Parameters:
        true_contact_map (np.ndarray): Binary (0/1) contact map from PDB.
        predicted_contact_map (np.ndarray): Predicted contact probabilities from ESM.
        pdb_file (str): Path to the PDB file, used to generate output filename.
        threshold (float): Distance threshold (in Å) that was used to build
            ``true_contact_map``; only used for the subplot title
            (default: 8.0).

    Returns:
        str: Name of the written image file (``contact_maps_<pdb name>.png``).
    """
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    # Both panels share the same styling; only the data and the title differ.
    panels = (
        (predicted_contact_map, "Predicted contact probabilities"),
        # ':g' drops a trailing ".0" so the default renders as "8 Å".
        (true_contact_map, f"True contacts (PDB, threshold = {threshold:g} Å)"),
    )
    for ax, (matrix, title) in zip(axs, panels):
        image = ax.imshow(matrix, cmap='RdYlBu_r', aspect='equal')
        ax.set_title(title)
        ax.set_xlabel("Residue index")
        ax.set_ylabel("Residue index")
        fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()

    # Generate output filename from PDB filename
    pdb_name = os.path.splitext(os.path.basename(pdb_file))[0]
    output_file = f"contact_maps_{pdb_name}.png"
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)
    return output_file


def main():
    """
    Command-line entry point.

    Parses the arguments, obtains a PDB file (downloading an example when
    none is given), builds the thresholded true contact map, predicts a
    contact map with ESM, and saves a side-by-side plot. A downloaded PDB
    file is removed again at the end, even if an error occurred.
    """
    # Example invocation: py tests/test_contact_maps.py
    parser = argparse.ArgumentParser(
        description="Extract protein sequence and compute contact maps from a PDB file using ESM predictions."
    )
    parser.add_argument("--pdb_file", type=str, help="Path to the PDB file of the protein. If not provided, a random PDB will be downloaded.", default=None)
    parser.add_argument(
        "--threshold",
        type=float,
        default=8.0,
        help="Distance threshold (in Å) for defining true contacts (default: 8.0 Å)."
    )
    args = parser.parse_args()
    if args.threshold <= 0:
        # A non-positive cutoff can never produce a contact.
        parser.error("--threshold must be a positive distance in Å.")

    # If no PDB file is provided, download a random one
    downloaded = args.pdb_file is None
    pdb_file = download_random_pdb() if downloaded else args.pdb_file

    try:
        # Parse the PDB file.
        sequence, coords = parse_pdb(pdb_file)
        print("Extracted Protein Sequence:")
        print(sequence)

        # Compute the pairwise distance matrix.
        dist_matrix = compute_distance_matrix(coords)

        # Create a binary contact map from the distance matrix using the threshold.
        true_contact_map = (dist_matrix < args.threshold).astype(float)

        # Get the predicted contact map from the ESM model.
        predicted_contact_map = get_esm_contact_map(sequence)

        # Check that the dimensions agree (every axis, not only the first one).
        if predicted_contact_map.shape != true_contact_map.shape:
            print("Warning: The predicted contact map and true contact map have different dimensions.")

        # Plot the maps; the title reflects the threshold actually used.
        output_file = plot_maps(
            true_contact_map,
            predicted_contact_map,
            pdb_file,
            threshold=args.threshold,
        )
        print(f"Contact maps saved to: {output_file}")
    finally:
        # Clean up the temporary file if we downloaded a random PDB
        if downloaded and os.path.exists(pdb_file):
            os.remove(pdb_file)
            print(f"Removed temporary PDB file: {pdb_file}")
# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
# End of file.