| import torch |
| import numpy as np |
| from torch_geometric.data import Data, Batch |
| from rdkit import Chem |
| from rdkit.Chem import AllChem |
| from rdkit.Chem import Descriptors |
| import py3Dmol |
| from jinja2 import Environment, FileSystemLoader |
| from google import genai |
| from decouple import config |
| import time |
|
|
# Read the Gemini key from the environment / .env file. default="" keeps the
# module importable when the key is absent; get_gemini_explanation checks for
# an empty value and returns a warning paragraph instead of calling the API.
GEMINI_API_KEY = config("GEMINI_API_KEY", default="")
|
|
|
|
| from dataset import get_atom_features, get_protein_features |
| from model_attention import BindingAffinityModel |
|
|
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyper-parameters handed to BindingAffinityModel in get_inference_data.
# NOTE(review): presumably these must match the values the checkpoint below
# was trained with -- confirm against the training script.
HIDDEN_CHANNELS = 256
GAT_HEADS = 2

# Default checkpoint loaded by get_inference_data.
MODEL_PATH = "models/model_ep041_attention_mse1.9153.pth"
| |
| |
| |
|
|
|
|
def get_inference_data(ligand_smiles, protein_sequence, model_path=MODEL_PATH):
    """Predict binding affinity for one ligand/protein pair.

    Args:
        ligand_smiles: SMILES string of the ligand.
        protein_sequence: Protein sequence; each character is tokenised with
            ``get_protein_features`` and the result is truncated or
            zero-padded to 1200 tokens.
        model_path: Checkpoint file passed to ``torch.load``.

    Returns:
        A tuple ``(mol, importance, affinity)``:

        * ``mol`` -- the RDKit molecule with explicit hydrogens (and an
          embedded conformer when RDKit managed to generate one),
        * ``importance`` -- NumPy array with one score per row of the model's
          cross-attention map, min-max scaled to ``[0, 1]`` when the scores
          are not all equal, with values below 0.01 set to 0,
        * ``affinity`` -- the model output as a Python float.

    Raises:
        ValueError: If the SMILES cannot be parsed, yields no atoms, or the
            protein sequence is empty.
    """
    max_protein_len = 1200  # fixed token length fed to the model

    if not protein_sequence:
        raise ValueError("protein_sequence must not be empty")

    # --- Ligand -> molecular graph -------------------------------------
    mol = Chem.MolFromSmiles(ligand_smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError(f"Could not parse ligand SMILES: {ligand_smiles!r}")
    if mol.GetNumAtoms() == 0:
        raise ValueError("Ligand SMILES does not describe any atoms")
    mol = Chem.AddHs(mol)
    # EmbedMolecule returns -1 when no conformer could be generated; retry
    # from random coordinates so downstream 3D rendering has a conformer
    # whenever RDKit can produce one.
    if AllChem.EmbedMolecule(mol, randomSeed=42) == -1:
        AllChem.EmbedMolecule(mol, randomSeed=42, useRandomCoords=True)

    atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(np.array(atom_features), dtype=torch.float)

    # Each bond contributes both directions so the graph is undirected.
    bond_pairs = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_pairs.extend([(i, j), (j, i)])
    # The explicit reshape keeps the (2, num_edges) layout even when the
    # molecule has no bonds (a bare torch.tensor([]) would be 1-D).
    edge_index = (
        torch.tensor(bond_pairs, dtype=torch.long)
        .reshape(len(bond_pairs), 2)
        .t()
        .contiguous()
    )

    # --- Protein -> fixed-length token tensor --------------------------
    tokens = [get_protein_features(c) for c in protein_sequence[:max_protein_len]]
    # Remember the unpadded length directly instead of counting non-zero
    # tokens, which would under-count if any residue tokenises to 0.
    real_prot_len = len(tokens)
    tokens.extend([0] * (max_protein_len - real_prot_len))
    protein_sequence_tensor = (
        torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)
    )

    data = Data(x=x, edge_index=edge_index)
    batch = Batch.from_data_list([data]).to(DEVICE)
    num_features = x.shape[1]

    # --- Model ----------------------------------------------------------
    model = BindingAffinityModel(
        num_features, hidden_channels=HIDDEN_CHANNELS, gat_heads=GAT_HEADS
    ).to(DEVICE)
    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
    # arbitrary objects; only load checkpoints from a trusted source, and
    # switch to weights_only=True if the file is a plain state_dict.
    model.load_state_dict(
        torch.load(model_path, map_location=DEVICE, weights_only=False)
    )
    model.eval()

    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.batch, protein_sequence_tensor)
        attention_weights = model.cross_attention.last_attention_weights[0]

    # Per-row maximum over the non-padded columns.
    # NOTE(review): presumably rows are ligand atoms and columns are protein
    # positions -- confirm against model_attention.
    importance = attention_weights[:, :real_prot_len].max(dim=1).values.cpu().numpy()

    # Min-max scale, but skip it when all scores are equal: the denominator
    # would be zero and the result would be NaN.
    lowest, highest = importance.min(), importance.max()
    if highest > 0 and highest > lowest:
        importance = (importance - lowest) / (highest - lowest)

    # Suppress negligible scores.
    importance[importance < 0.01] = 0
    return mol, importance, pred.item()
|
|
|
|
def get_lipinski_properties(mol):
    """Compute rule-of-five descriptors for *mol* and summarise violations.

    Args:
        mol: RDKit molecule passed to the ``Descriptors`` functions.

    Returns:
        dict with the rounded descriptors (``MW``, ``LogP``, ``TPSA``), the
        donor/acceptor counts (``HBD`` from ``NHOHCount``, ``HBA`` from
        ``NOCount``), the number of ``violations``, a human-readable
        ``status_text``, a ``css_class`` (``success``/``warning``/``danger``)
        and ``bad_params`` -- the violated rules joined by ", " or "None".
    """
    mol_weight = Descriptors.MolWt(mol)
    acceptors = Descriptors.NOCount(mol)
    donors = Descriptors.NHOHCount(mol)
    log_p = Descriptors.MolLogP(mol)
    polar_area = Descriptors.TPSA(mol)

    # (rule broken?, label) in the order the labels are reported.
    rule_checks = (
        (mol_weight > 500, "Mass > 500"),
        (log_p > 5, "LogP > 5"),
        (donors > 5, "H-Donors > 5"),
        (acceptors > 10, "H-Acceptors > 10"),
    )
    failed = [label for broken, label in rule_checks if broken]
    violations = len(failed)

    if violations == 0:
        status, css_class = "Excellent (Drug-like) 🟢", "success"
    elif violations == 1:
        status, css_class = "Acceptable (1 violation) 🟡", "warning"
    else:
        status, css_class = f"Poor ({violations} violations) 🔴", "danger"

    return {
        "MW": round(mol_weight, 2),
        "LogP": round(log_p, 2),
        "HBD": donors,
        "HBA": acceptors,
        "TPSA": round(polar_area, 2),
        "violations": violations,
        "status_text": status,
        "css_class": css_class,
        "bad_params": ", ".join(failed) or "None",
    }
|
|
|
|
def get_py3dmol_view(mol, importance):
    """Build a py3Dmol view of *mol* with labels on the highest-scoring atoms.

    Args:
        mol: RDKit molecule; its conformer supplies the label positions.
        importance: Sequence of per-atom scores, indexed by atom index.

    Returns:
        The ``py3Dmol.view`` showing the molecule as sticks and small
        spheres, with an ``index:symbol:score`` label on each of the (up to)
        15 atoms that have the largest scores.
    """
    viewer = py3Dmol.view(width="100%", height="600px")
    viewer.addModel(Chem.MolToMolBlock(mol), "sdf")
    viewer.setBackgroundColor("white")
    viewer.setStyle({}, {"stick": {"radius": 0.15}, "sphere": {"scale": 0.25}})

    # Indices of the 15 largest scores, visited in ascending atom order.
    ranked = np.argsort(importance)[::-1]
    highlighted = sorted(int(idx) for idx in ranked[:15])

    conformer = mol.GetConformer()

    # Appearance shared by every label; the position is added per atom.
    label_style = {
        "fontSize": 14,
        "fontColor": "white",
        "backgroundColor": "black",
        "backgroundOpacity": 0.7,
        "borderThickness": 0,
        "inFront": True,
        "showBackground": True,
    }

    for atom_idx in highlighted:
        point = conformer.GetAtomPosition(atom_idx)
        element = mol.GetAtomWithIdx(atom_idx).GetSymbol()
        score = importance[atom_idx]
        viewer.addLabel(
            f"{atom_idx}:{element}:{score:.2f}",
            {"position": {"x": point.x, "y": point.y, "z": point.z}, **label_style},
        )

    viewer.zoomTo()
    return viewer
|
|
|
|
def save_standalone_ngl_html(mol, importance, filepath):
    """Render ``templates/ngl_view.html`` for *mol* and write it to *filepath*.

    The per-atom scores are stored in the temperature-factor field of the PDB
    block handed to the template, and the 15 highest-scoring atom indices are
    passed as an ``@i,j,...`` selection string.

    Args:
        mol: RDKit molecule to export.
        importance: Sequence of per-atom scores, indexed by atom index.
        filepath: Destination of the rendered HTML (written as UTF-8).

    Raises:
        ValueError: If RDKit cannot re-read the molecule from its PDB block.
    """
    # Round-trip through a PDB block so atoms come back with PDB residue
    # info, whose temp-factor field carries the score.
    mol_pdb = Chem.MolFromPDBBlock(Chem.MolToPDBBlock(mol), removeHs=False)
    if mol_pdb is None:
        # MolFromPDBBlock returns None when parsing/sanitisation fails.
        raise ValueError("RDKit could not re-read the ligand from its PDB block")

    for i, atom in enumerate(mol_pdb.GetAtoms()):
        info = atom.GetPDBResidueInfo()
        if info:
            info.SetTempFactor(float(importance[i]))

    # Escape backticks before the block is placed into the template.
    # NOTE(review): presumably the template embeds it in a JavaScript
    # template literal -- confirm in ngl_view.html.
    final_pdb_block = Chem.MolToPDBBlock(mol_pdb).replace("`", "\\`")

    top_indices = np.argsort(importance)[::-1][:15]
    if len(top_indices):
        selection_str = "@" + ",".join(str(i) for i in top_indices)
    else:
        # Fallback selection used when there are no scores at all.
        selection_str = "@-1"

    env = Environment(loader=FileSystemLoader("templates"))
    template = env.get_template("ngl_view.html")
    rendered_html = template.render(
        pdb_block=final_pdb_block, selection_str=selection_str
    )

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(rendered_html)
|
|
|
|
def get_gemini_explanation(
    ligand_smiles, protein_sequence, affinity, top_atoms, lipinski
):
    """Ask Gemini for a short HTML summary of a binding prediction.

    Args:
        ligand_smiles: Ligand SMILES string, quoted in the prompt.
        protein_sequence: Protein sequence; only the first 100 characters
            are sent.
        affinity: Predicted affinity, inserted into the prompt as pKd.
        top_atoms: Sequence of dicts with ``symbol``, ``id`` and ``score``
            keys; the first 10 entries are listed in the prompt.
        lipinski: Dict providing ``status_text`` and ``violations``.

    Returns:
        An HTML string: the model's answer, or a ``<p>`` element describing
        why no answer is available (missing key, API error, empty response,
        or service still overloaded after all retries).
    """
    from html import escape  # local import; only used to render error text

    if not GEMINI_API_KEY:
        return "<p class='text-warning'>API Key for Gemini not found. Please set GEMINI_API_KEY environment variable.</p>"

    atoms_desc = ", ".join(
        [f"{a['symbol']}(idx {a['id']}, score {a['score']})" for a in top_atoms[:10]]
    )

    # Keep the prompt short: send only the start of long sequences.
    prot_short = (
        protein_sequence[:100] + "..."
        if len(protein_sequence) > 100
        else protein_sequence
    )

    prompt = f"""
    You are an expert Computational Chemist and Drug Discovery Scientist.
    Analyze the following interaction results between a Ligand and a Protein.

    **Data:**
    1. **Ligand (SMILES):** `{ligand_smiles}`
    2. **Target Protein (Start):** `{prot_short}`
    3. **Predicted Binding Affinity (pKd):** {affinity} (Note: >7 is usually good, <5 is weak).
    4. **Top Active Atoms (Attention Weights):** {atoms_desc}. These atoms had the highest attention scores in the Graph Neural Network with attention.
    5. **Lipinski Properties:** {lipinski['status_text']} (Violations: {lipinski['violations']}).

    **Task:**
    Write a concise, professional scientific summary (in HTML format, use <p>, <ul>, <li>, <b>).
    Cover these points:
    1. **Affinity Analysis:** Is the binding strong? What does a pKd of {affinity} imply for a drug candidate?
    2. **Structural Basis:** Why might the model have focused on the atoms listed above (e.g., Nitrogen/Oxygen often act as H-bond donors/acceptors, Rings for stacking)?
    3. **Drug-Likeness:** Comment on the Lipinski status. Is it suitable for oral administration?
    4. **Conclusion:** Verdict on whether to proceed with this molecule.
    Keep it relatively short (max 150 words). Do not include markdown code blocks (```html), just return the raw HTML tags.
    """

    max_retries = 3
    client = genai.Client(api_key=GEMINI_API_KEY)
    for attempt in range(max_retries):
        try:
            response = client.models.generate_content(
                model="gemini-2.5-flash", contents=prompt
            )
            text = response.text
        except Exception as e:  # boundary: any API failure becomes an HTML message
            error_msg = str(e).lower()
            # Overload / rate-limit errors are retried; prefer the numeric
            # status when the exception exposes one, and fall back to
            # matching the message text.
            retryable = (
                getattr(e, "code", None) in (429, 503)
                or "503" in error_msg
                or "overloaded" in error_msg
                or "429" in error_msg
            )
            if not retryable:
                # Escape the message: it is interpolated into HTML.
                return f"<p class='text-danger'>Error generating explanation: {escape(str(e))}</p>"
            if attempt < max_retries - 1:
                wait_time = 2 * (attempt + 1)  # linear back-off: 2s, 4s
                print(f"Gemini overloaded, retrying in {wait_time}s...")
                time.sleep(wait_time)
            continue

        if text:
            return text
        # The response carried no text; return a string rather than None.
        return "<p class='text-warning'>Gemini returned an empty response.</p>"

    # Every attempt failed with a retryable error.
    return "<p class='text-danger'>Error: Gemini unavailable after retries.</p>"
|
|