Spaces:
Running on Zero
Running on Zero
Justine Yuan commited on
Commit ·
3beba17
1
Parent(s): e732716
Caliby HuggingFace example
Browse files- .gitignore +220 -0
- README.md +2 -1
- app.py +434 -0
- app_config.py +12 -0
- caliby_transparent.png +0 -0
- constraints.py +46 -0
- design.py +244 -0
- ensemble.py +36 -0
- file_utils.py +73 -0
- models.py +23 -0
- pyproject.toml +30 -0
- requirements.txt +8 -0
- self_consistency.py +35 -0
- tests/conftest.py +44 -0
- tests/test_design_sequences.py +475 -0
- tests/test_helpers.py +295 -0
- viewers.py +79 -0
.gitignore
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[codz]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py.cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
# Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
# poetry.lock
|
| 109 |
+
# poetry.toml
|
| 110 |
+
|
| 111 |
+
# pdm
|
| 112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 113 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 114 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 115 |
+
# pdm.lock
|
| 116 |
+
# pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# pixi
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 122 |
+
# pixi.lock
|
| 123 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 124 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 125 |
+
.pixi
|
| 126 |
+
|
| 127 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 128 |
+
__pypackages__/
|
| 129 |
+
|
| 130 |
+
# Celery stuff
|
| 131 |
+
celerybeat-schedule
|
| 132 |
+
celerybeat.pid
|
| 133 |
+
|
| 134 |
+
# Redis
|
| 135 |
+
*.rdb
|
| 136 |
+
*.aof
|
| 137 |
+
*.pid
|
| 138 |
+
|
| 139 |
+
# RabbitMQ
|
| 140 |
+
mnesia/
|
| 141 |
+
rabbitmq/
|
| 142 |
+
rabbitmq-data/
|
| 143 |
+
|
| 144 |
+
# ActiveMQ
|
| 145 |
+
activemq-data/
|
| 146 |
+
|
| 147 |
+
# SageMath parsed files
|
| 148 |
+
*.sage.py
|
| 149 |
+
|
| 150 |
+
# Environments
|
| 151 |
+
.env
|
| 152 |
+
.envrc
|
| 153 |
+
.venv
|
| 154 |
+
env/
|
| 155 |
+
venv/
|
| 156 |
+
ENV/
|
| 157 |
+
env.bak/
|
| 158 |
+
venv.bak/
|
| 159 |
+
|
| 160 |
+
# Spyder project settings
|
| 161 |
+
.spyderproject
|
| 162 |
+
.spyproject
|
| 163 |
+
|
| 164 |
+
# Rope project settings
|
| 165 |
+
.ropeproject
|
| 166 |
+
|
| 167 |
+
# mkdocs documentation
|
| 168 |
+
/site
|
| 169 |
+
|
| 170 |
+
# mypy
|
| 171 |
+
.mypy_cache/
|
| 172 |
+
.dmypy.json
|
| 173 |
+
dmypy.json
|
| 174 |
+
|
| 175 |
+
# Pyre type checker
|
| 176 |
+
.pyre/
|
| 177 |
+
|
| 178 |
+
# pytype static type analyzer
|
| 179 |
+
.pytype/
|
| 180 |
+
|
| 181 |
+
# Cython debug symbols
|
| 182 |
+
cython_debug/
|
| 183 |
+
|
| 184 |
+
# PyCharm
|
| 185 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 186 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 187 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 188 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 189 |
+
# .idea/
|
| 190 |
+
|
| 191 |
+
# Abstra
|
| 192 |
+
# Abstra is an AI-powered process automation framework.
|
| 193 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 194 |
+
# Learn more at https://abstra.io/docs
|
| 195 |
+
.abstra/
|
| 196 |
+
|
| 197 |
+
# Visual Studio Code
|
| 198 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 199 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 200 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 201 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 202 |
+
# .vscode/
|
| 203 |
+
|
| 204 |
+
# Ruff stuff:
|
| 205 |
+
.ruff_cache/
|
| 206 |
+
|
| 207 |
+
# PyPI configuration file
|
| 208 |
+
.pypirc
|
| 209 |
+
|
| 210 |
+
# Marimo
|
| 211 |
+
marimo/_static/
|
| 212 |
+
marimo/_lsp/
|
| 213 |
+
__marimo__/
|
| 214 |
+
|
| 215 |
+
# Streamlit
|
| 216 |
+
.streamlit/secrets.toml
|
| 217 |
+
|
| 218 |
+
envs
|
| 219 |
+
|
| 220 |
+
CLAUDE.md
|
README.md
CHANGED
|
@@ -4,7 +4,8 @@ emoji: 🐢
|
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
|
|
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "6.6.0"
|
| 8 |
+
python_version: "3.12"
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
license: apache-2.0
|
app.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio app for Caliby sequence design."""
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
|
| 8 |
+
# Eagerly import so the wandb/pydantic init runs in the main thread
|
| 9 |
+
# (where sys.modules['__main__'] exists), not in a Gradio worker thread.
|
| 10 |
+
import caliby.data.preprocessing.atomworks.clean_pdbs # noqa: F401
|
| 11 |
+
|
| 12 |
+
from design import design_sequences
|
| 13 |
+
from file_utils import _get_file_path, _write_zip_from_paths
|
| 14 |
+
from viewers import (
|
| 15 |
+
_csv_download_output,
|
| 16 |
+
_file_output,
|
| 17 |
+
_format_results_display,
|
| 18 |
+
_get_best_sc_sample,
|
| 19 |
+
_render_af2_viewer,
|
| 20 |
+
_update_viewers,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_upload_instructions(mode: str) -> str:
|
| 25 |
+
if mode == "none":
|
| 26 |
+
return "Upload a single PDB or CIF file."
|
| 27 |
+
elif mode == "synthetic":
|
| 28 |
+
return "Upload a single PDB or CIF file. Conformers will be generated automatically."
|
| 29 |
+
else:
|
| 30 |
+
return "Upload all PDB files — primary conformer first, then additional conformers."
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _clean_uploaded_pdbs(pdb_files: list | None):
|
| 34 |
+
if not pdb_files:
|
| 35 |
+
return None, gr.update(visible=False), gr.update(visible=False), gr.update(interactive=False)
|
| 36 |
+
|
| 37 |
+
from caliby import clean_pdbs
|
| 38 |
+
|
| 39 |
+
pdb_paths = [str(_get_file_path(f)) for f in pdb_files]
|
| 40 |
+
cleaned_paths = clean_pdbs(pdb_paths)
|
| 41 |
+
zip_path = _write_zip_from_paths(cleaned_paths, "cleaned_pdbs", ".zip")
|
| 42 |
+
|
| 43 |
+
return (
|
| 44 |
+
cleaned_paths,
|
| 45 |
+
gr.update(
|
| 46 |
+
value="**Note:** Your files have been cleaned and standardized to mmCIF format "
|
| 47 |
+
"to avoid downstream parsing and alignment issues. "
|
| 48 |
+
"If you plan to use positional constraints, please download the cleaned files and double "
|
| 49 |
+
"check the new residue indices.",
|
| 50 |
+
visible=True,
|
| 51 |
+
),
|
| 52 |
+
gr.update(value=zip_path, visible=True),
|
| 53 |
+
gr.update(interactive=True),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _reset_cleaned_state():
|
| 58 |
+
return None, gr.update(visible=False), gr.update(visible=False), gr.update(interactive=False)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def submit_design_sequences(
|
| 62 |
+
cleaned_files: list[str] | None,
|
| 63 |
+
ensemble_mode: str,
|
| 64 |
+
model_variant: str,
|
| 65 |
+
num_seqs: int,
|
| 66 |
+
omit_aas: list[str] | None,
|
| 67 |
+
temperature: float,
|
| 68 |
+
fixed_pos_seq: str,
|
| 69 |
+
fixed_pos_scn: str,
|
| 70 |
+
fixed_pos_override_seq: str,
|
| 71 |
+
pos_restrict_aatype: str,
|
| 72 |
+
symmetry_pos: str,
|
| 73 |
+
num_protpardelle_conformers: int,
|
| 74 |
+
run_af2_eval: bool = False,
|
| 75 |
+
):
|
| 76 |
+
df, fasta_text, out_zip_path, sc_zip_path, af2_pdb_data, input_pdb_data = design_sequences(
|
| 77 |
+
pdb_files=cleaned_files,
|
| 78 |
+
ensemble_mode=ensemble_mode,
|
| 79 |
+
model_variant=model_variant,
|
| 80 |
+
num_seqs=num_seqs,
|
| 81 |
+
omit_aas=omit_aas,
|
| 82 |
+
temperature=temperature,
|
| 83 |
+
fixed_pos_seq=fixed_pos_seq,
|
| 84 |
+
fixed_pos_scn=fixed_pos_scn,
|
| 85 |
+
fixed_pos_override_seq=fixed_pos_override_seq,
|
| 86 |
+
pos_restrict_aatype=pos_restrict_aatype,
|
| 87 |
+
symmetry_pos=symmetry_pos,
|
| 88 |
+
num_protpardelle_conformers=num_protpardelle_conformers,
|
| 89 |
+
run_af2_eval=run_af2_eval,
|
| 90 |
+
)
|
| 91 |
+
has_af2 = bool(af2_pdb_data)
|
| 92 |
+
best_sample = _get_best_sc_sample(df) if has_af2 else ""
|
| 93 |
+
af2_html = _render_af2_viewer(best_sample, af2_pdb_data) if has_af2 else ""
|
| 94 |
+
|
| 95 |
+
return (
|
| 96 |
+
gr.update(visible=True),
|
| 97 |
+
gr.update(value=_format_results_display(df), visible=True),
|
| 98 |
+
df,
|
| 99 |
+
gr.update(value=fasta_text, visible=True),
|
| 100 |
+
_file_output(out_zip_path),
|
| 101 |
+
_file_output(sc_zip_path),
|
| 102 |
+
af2_pdb_data,
|
| 103 |
+
input_pdb_data,
|
| 104 |
+
best_sample,
|
| 105 |
+
gr.update(visible=has_af2),
|
| 106 |
+
af2_html,
|
| 107 |
+
gr.update(value="", visible=False),
|
| 108 |
+
gr.update(visible=False),
|
| 109 |
+
gr.update(visible=False),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
theme = gr.themes.Base(
|
| 114 |
+
primary_hue="amber",
|
| 115 |
+
secondary_hue="orange",
|
| 116 |
+
radius_size="lg",
|
| 117 |
+
font=[gr.themes.GoogleFont('Instrument Sans'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
| 118 |
+
).set(
|
| 119 |
+
body_text_color='*neutral_700',
|
| 120 |
+
body_text_color_dark='*neutral_300',
|
| 121 |
+
body_text_color_subdued='*neutral_500',
|
| 122 |
+
block_title_text_color='*neutral_700',
|
| 123 |
+
block_info_text_color='*neutral_500',
|
| 124 |
+
block_border_width_dark='0px',
|
| 125 |
+
block_padding='*spacing_xl calc(*spacing_xl + 3px)',
|
| 126 |
+
block_label_border_width_dark='0px',
|
| 127 |
+
block_label_padding='*spacing_md *spacing_lg',
|
| 128 |
+
button_secondary_background_fill_dark='*neutral_600',
|
| 129 |
+
checkbox_label_text_color_dark='*neutral_100',
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
css = """
|
| 133 |
+
.loading-pulse { animation: pulse 2.5s ease-in-out infinite; }
|
| 134 |
+
@keyframes pulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.3; } }
|
| 135 |
+
.omit-aa-dropdown ul { max-height: 200px !important; overflow-y: auto; }
|
| 136 |
+
.compact-file .large { min-height: 50px !important; }
|
| 137 |
+
#results-table th:nth-child(2),
|
| 138 |
+
#results-table td:nth-child(2) {
|
| 139 |
+
max-width: 28rem;
|
| 140 |
+
width: 28rem;
|
| 141 |
+
}
|
| 142 |
+
#results-table td:nth-child(2) {
|
| 143 |
+
overflow: hidden;
|
| 144 |
+
}
|
| 145 |
+
#results-table td:nth-child(2) > div {
|
| 146 |
+
display: block;
|
| 147 |
+
max-width: 100%;
|
| 148 |
+
overflow-x: auto;
|
| 149 |
+
overflow-y: hidden;
|
| 150 |
+
white-space: nowrap !important;
|
| 151 |
+
scrollbar-width: thin;
|
| 152 |
+
}
|
| 153 |
+
#af2-viewer, #ref-viewer {
|
| 154 |
+
display: flex;
|
| 155 |
+
justify-content: center;
|
| 156 |
+
}
|
| 157 |
+
#af2-viewer iframe, #ref-viewer iframe {
|
| 158 |
+
max-width: 100%;
|
| 159 |
+
}
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
_LOGO_B64 = base64.b64encode(Path(__file__).with_name("caliby_transparent.png").read_bytes()).decode()
|
| 163 |
+
|
| 164 |
+
with gr.Blocks(title="Caliby - Protein Sequence Design") as demo:
|
| 165 |
+
gr.HTML(
|
| 166 |
+
'<div style="display: flex; align-items: center; gap: 16px;">'
|
| 167 |
+
f'<img src="data:image/png;base64,{_LOGO_B64}" alt="Caliby logo" style="height: 80px;">'
|
| 168 |
+
'<h1 style="margin: 0;">Caliby - Protein Sequence Design</h1>'
|
| 169 |
+
'</div>'
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
with gr.Row():
|
| 173 |
+
with gr.Column(scale=1):
|
| 174 |
+
model_variant = gr.Radio(
|
| 175 |
+
choices=[
|
| 176 |
+
("Caliby", "caliby"),
|
| 177 |
+
("SolubleCaliby v1", "soluble_caliby_v1"),
|
| 178 |
+
],
|
| 179 |
+
value="caliby",
|
| 180 |
+
label="Model",
|
| 181 |
+
)
|
| 182 |
+
ensemble_mode = gr.Radio(
|
| 183 |
+
choices=[
|
| 184 |
+
("Fixed backbone", "none"),
|
| 185 |
+
("Synthetic ensemble", "synthetic"),
|
| 186 |
+
("Upload your own ensemble", "user"),
|
| 187 |
+
],
|
| 188 |
+
value="synthetic",
|
| 189 |
+
label="Ensemble mode",
|
| 190 |
+
)
|
| 191 |
+
run_af2_eval = gr.Checkbox(
|
| 192 |
+
label="Run AF2 self-consistency evaluation",
|
| 193 |
+
value=False,
|
| 194 |
+
info="Refold designed sequences with AlphaFold2 and compute scRMSD, pLDDT, and TM-score",
|
| 195 |
+
)
|
| 196 |
+
upload_instructions = gr.Markdown(
|
| 197 |
+
_get_upload_instructions("synthetic"),
|
| 198 |
+
)
|
| 199 |
+
pdb_input = gr.File(
|
| 200 |
+
file_count="multiple",
|
| 201 |
+
label="PDB/CIF file(s)",
|
| 202 |
+
file_types=[".pdb", ".cif"],
|
| 203 |
+
)
|
| 204 |
+
finish_upload_btn = gr.Button("Upload", variant="secondary")
|
| 205 |
+
cleaned_files_state = gr.State(None)
|
| 206 |
+
clean_notification = gr.Markdown(visible=False)
|
| 207 |
+
clean_download = gr.File(
|
| 208 |
+
label="Download cleaned files", visible=False, elem_classes=["compact-file"]
|
| 209 |
+
)
|
| 210 |
+
num_seqs = gr.Slider(
|
| 211 |
+
minimum=1,
|
| 212 |
+
maximum=4,
|
| 213 |
+
value=1,
|
| 214 |
+
step=1,
|
| 215 |
+
label="Number of sequences",
|
| 216 |
+
)
|
| 217 |
+
omit_aas = gr.Dropdown(
|
| 218 |
+
choices=[
|
| 219 |
+
"A",
|
| 220 |
+
"C",
|
| 221 |
+
"D",
|
| 222 |
+
"E",
|
| 223 |
+
"F",
|
| 224 |
+
"G",
|
| 225 |
+
"H",
|
| 226 |
+
"I",
|
| 227 |
+
"K",
|
| 228 |
+
"L",
|
| 229 |
+
"M",
|
| 230 |
+
"N",
|
| 231 |
+
"P",
|
| 232 |
+
"Q",
|
| 233 |
+
"R",
|
| 234 |
+
"S",
|
| 235 |
+
"T",
|
| 236 |
+
"V",
|
| 237 |
+
"W",
|
| 238 |
+
"Y",
|
| 239 |
+
],
|
| 240 |
+
multiselect=True,
|
| 241 |
+
label="Amino acids to omit",
|
| 242 |
+
elem_classes=["omit-aa-dropdown"],
|
| 243 |
+
)
|
| 244 |
+
temperature = gr.Slider(
|
| 245 |
+
minimum=0.01,
|
| 246 |
+
maximum=1,
|
| 247 |
+
value=0.01,
|
| 248 |
+
step=0.01,
|
| 249 |
+
label="Sampling temperature",
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
submit_btn = gr.Button("Design sequences", variant="primary", interactive=False)
|
| 253 |
+
|
| 254 |
+
with gr.Accordion("Advanced constraints", open=False):
|
| 255 |
+
fixed_pos_seq = gr.Textbox(
|
| 256 |
+
label="Fixed positions",
|
| 257 |
+
info="Format: A1-100,B1-100 \nSequence positions in the input PDB to condition on so that they"
|
| 258 |
+
" remain fixed during design. For ensemble-conditioned design, fixed_pos_seq is applied using"
|
| 259 |
+
" the primary conformer's sequence.",
|
| 260 |
+
placeholder="e.g. A1-100,B1-100",
|
| 261 |
+
)
|
| 262 |
+
fixed_pos_scn = gr.Textbox(
|
| 263 |
+
label="Fixed sidechain positions",
|
| 264 |
+
info="Format: A1-10,A12,A15-20 \nSidechain positions in the input PDB to condition on so that they"
|
| 265 |
+
" remain fixed during design. Note that fixed sidechain positions must be a subset of fixed"
|
| 266 |
+
" sequence positions, since it does not make sense to condition on a sidechain without also"
|
| 267 |
+
" conditioning on its sequence identity.",
|
| 268 |
+
placeholder="e.g. A1-10,A12,A15-20",
|
| 269 |
+
)
|
| 270 |
+
fixed_pos_override_seq = gr.Textbox(
|
| 271 |
+
label="Override sequence at positions",
|
| 272 |
+
info="Format: A26:A,A27:L \nSequence positions in the input PDB to first override the sequence at,"
|
| 273 |
+
" and then condition on. The colon separates the position and the desired amino acid.",
|
| 274 |
+
placeholder="e.g. A26:A,A27:L",
|
| 275 |
+
)
|
| 276 |
+
pos_restrict_aatype = gr.Textbox(
|
| 277 |
+
label="Position restrictions",
|
| 278 |
+
info="Format: A26:AVG,A27:VG \nAllowed amino acids for certain positions in the input PDB. The"
|
| 279 |
+
" colon separates the position and the allowed amino acids.",
|
| 280 |
+
placeholder="e.g. A26:AVG,A27:VG",
|
| 281 |
+
)
|
| 282 |
+
symmetry_pos = gr.Textbox(
|
| 283 |
+
label="Symmetry positions",
|
| 284 |
+
info="Format: A10,B10,C10|A11,B11,C11 \nSymmetry positions for tying sampling across residue"
|
| 285 |
+
" positions. The pipe separates groups of positions to sample symmetrically. In the example,"
|
| 286 |
+
" A10, B10, and C10 are tied together, and A11, B11, and C11 are tied together.",
|
| 287 |
+
placeholder="e.g. A10,B10,C10|A11,B11,C11",
|
| 288 |
+
)
|
| 289 |
+
num_protpardelle_conformers = gr.Slider(
|
| 290 |
+
minimum=1,
|
| 291 |
+
maximum=15,
|
| 292 |
+
value=15,
|
| 293 |
+
step=1,
|
| 294 |
+
label="Number of conformers to generate",
|
| 295 |
+
visible=True,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
with gr.Column(scale=2):
|
| 299 |
+
raw_results_df = gr.State(None)
|
| 300 |
+
af2_pdb_state = gr.State({})
|
| 301 |
+
input_pdb_state = gr.State({})
|
| 302 |
+
best_sample_state = gr.State("")
|
| 303 |
+
|
| 304 |
+
results_placeholder = gr.Markdown(
|
| 305 |
+
"Results will appear here after designing sequences.",
|
| 306 |
+
)
|
| 307 |
+
results_header = gr.Markdown("### Results", visible=False)
|
| 308 |
+
results_df = gr.Dataframe(
|
| 309 |
+
show_label=False,
|
| 310 |
+
interactive=False,
|
| 311 |
+
wrap=False,
|
| 312 |
+
column_widths=[160, 448],
|
| 313 |
+
elem_id="results-table",
|
| 314 |
+
visible=False,
|
| 315 |
+
)
|
| 316 |
+
fasta_output = gr.Textbox(
|
| 317 |
+
label="Sequences (FASTA)",
|
| 318 |
+
lines=10,
|
| 319 |
+
visible=False,
|
| 320 |
+
)
|
| 321 |
+
with gr.Row():
|
| 322 |
+
csv_download = gr.File(label="Download results CSV", elem_classes=["compact-file"], visible=False)
|
| 323 |
+
output_files = gr.File(label="Download CIF files", elem_classes=["compact-file"], visible=False)
|
| 324 |
+
sc_output_files = gr.File(
|
| 325 |
+
label="Download AF2 self-consistency outputs",
|
| 326 |
+
elem_classes=["compact-file"],
|
| 327 |
+
visible=False,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
with gr.Column(visible=False) as viewer_section:
|
| 331 |
+
gr.Markdown("---")
|
| 332 |
+
with gr.Row():
|
| 333 |
+
gr.Markdown("### AF2 Prediction")
|
| 334 |
+
af2_color_mode = gr.Dropdown(
|
| 335 |
+
choices=[
|
| 336 |
+
("pLDDT", "plddt"),
|
| 337 |
+
("Chain", "chain"),
|
| 338 |
+
("Rainbow", "rainbow"),
|
| 339 |
+
("Secondary structure", "secondary"),
|
| 340 |
+
],
|
| 341 |
+
value="plddt",
|
| 342 |
+
label="Color by",
|
| 343 |
+
scale=0,
|
| 344 |
+
)
|
| 345 |
+
af2_viewer = gr.HTML(elem_id="af2-viewer")
|
| 346 |
+
show_overlay = gr.Checkbox(label="Show reference structure", value=False)
|
| 347 |
+
with gr.Column(visible=False) as ref_section:
|
| 348 |
+
with gr.Row():
|
| 349 |
+
gr.Markdown("### Reference Structure")
|
| 350 |
+
ref_color_mode = gr.Dropdown(
|
| 351 |
+
choices=[
|
| 352 |
+
("Chain", "chain"),
|
| 353 |
+
("pLDDT", "plddt"),
|
| 354 |
+
("Rainbow", "rainbow"),
|
| 355 |
+
("Secondary structure", "secondary"),
|
| 356 |
+
],
|
| 357 |
+
value="chain",
|
| 358 |
+
label="Color by",
|
| 359 |
+
scale=0,
|
| 360 |
+
)
|
| 361 |
+
reference_viewer = gr.HTML(elem_id="ref-viewer")
|
| 362 |
+
|
| 363 |
+
submit_btn.click(
|
| 364 |
+
fn=lambda: gr.update(value='<div class="loading-pulse">Running design pipeline\u2026</div>', visible=True),
|
| 365 |
+
outputs=[results_placeholder],
|
| 366 |
+
).then(
|
| 367 |
+
fn=submit_design_sequences,
|
| 368 |
+
inputs=[
|
| 369 |
+
cleaned_files_state,
|
| 370 |
+
ensemble_mode,
|
| 371 |
+
model_variant,
|
| 372 |
+
num_seqs,
|
| 373 |
+
omit_aas,
|
| 374 |
+
temperature,
|
| 375 |
+
fixed_pos_seq,
|
| 376 |
+
fixed_pos_scn,
|
| 377 |
+
fixed_pos_override_seq,
|
| 378 |
+
pos_restrict_aatype,
|
| 379 |
+
symmetry_pos,
|
| 380 |
+
num_protpardelle_conformers,
|
| 381 |
+
run_af2_eval,
|
| 382 |
+
],
|
| 383 |
+
outputs=[
|
| 384 |
+
results_header,
|
| 385 |
+
results_df,
|
| 386 |
+
raw_results_df,
|
| 387 |
+
fasta_output,
|
| 388 |
+
output_files,
|
| 389 |
+
sc_output_files,
|
| 390 |
+
af2_pdb_state,
|
| 391 |
+
input_pdb_state,
|
| 392 |
+
best_sample_state,
|
| 393 |
+
viewer_section,
|
| 394 |
+
af2_viewer,
|
| 395 |
+
reference_viewer,
|
| 396 |
+
ref_section,
|
| 397 |
+
results_placeholder,
|
| 398 |
+
],
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
raw_results_df.change(fn=_csv_download_output, inputs=[raw_results_df], outputs=[csv_download])
|
| 402 |
+
|
| 403 |
+
finish_upload_btn.click(
|
| 404 |
+
fn=lambda: gr.update(value="Processing\u2026", interactive=False),
|
| 405 |
+
outputs=[finish_upload_btn],
|
| 406 |
+
).then(
|
| 407 |
+
fn=_clean_uploaded_pdbs,
|
| 408 |
+
inputs=[pdb_input],
|
| 409 |
+
outputs=[cleaned_files_state, clean_notification, clean_download, submit_btn],
|
| 410 |
+
).then(
|
| 411 |
+
fn=lambda: gr.update(value="Upload", interactive=True),
|
| 412 |
+
outputs=[finish_upload_btn],
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
pdb_input.change(
|
| 416 |
+
fn=_reset_cleaned_state,
|
| 417 |
+
outputs=[cleaned_files_state, clean_notification, clean_download, submit_btn],
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
ensemble_mode.change(
|
| 421 |
+
fn=lambda mode: (gr.update(visible=(mode == "synthetic")), _get_upload_instructions(mode)),
|
| 422 |
+
inputs=[ensemble_mode],
|
| 423 |
+
outputs=[num_protpardelle_conformers, upload_instructions],
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
viewer_inputs = [best_sample_state, af2_pdb_state, input_pdb_state, show_overlay, af2_color_mode, ref_color_mode]
|
| 427 |
+
viewer_outputs = [af2_viewer, reference_viewer, ref_section]
|
| 428 |
+
|
| 429 |
+
show_overlay.change(fn=_update_viewers, inputs=viewer_inputs, outputs=viewer_outputs)
|
| 430 |
+
af2_color_mode.change(fn=_update_viewers, inputs=viewer_inputs, outputs=viewer_outputs)
|
| 431 |
+
ref_color_mode.change(fn=_update_viewers, inputs=viewer_inputs, outputs=viewer_outputs)
|
| 432 |
+
|
| 433 |
+
if __name__ == "__main__":
|
| 434 |
+
demo.launch(theme=theme, css=css, ssr_mode=False)
|
app_config.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration constants and module-level side effects (weight downloads)."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from huggingface_hub import snapshot_download
|
| 6 |
+
|
| 7 |
+
WEIGHTS_DIR = snapshot_download(
|
| 8 |
+
repo_id="ProteinDesignLab/caliby-weights", repo_type="model", token=os.environ.get("HF_TOKEN")
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
# Set MODEL_PARAMS_DIR so caliby's weight utilities can find/download files.
|
| 12 |
+
os.environ.setdefault("MODEL_PARAMS_DIR", WEIGHTS_DIR)
|
caliby_transparent.png
ADDED
|
constraints.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Input validation and position constraint building."""
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _validate_design_inputs(pdb_files: list | None, ensemble_mode: str) -> str | None:
|
| 7 |
+
if not pdb_files:
|
| 8 |
+
return "Upload at least one PDB or CIF file."
|
| 9 |
+
|
| 10 |
+
if ensemble_mode == "user" and len(pdb_files) < 2:
|
| 11 |
+
return "User ensemble mode requires at least two files."
|
| 12 |
+
|
| 13 |
+
single_file_mode_messages = {
|
| 14 |
+
"none": "Single structure mode requires exactly one file.",
|
| 15 |
+
"synthetic": "Protpardelle mode requires exactly one file.",
|
| 16 |
+
}
|
| 17 |
+
message = single_file_mode_messages.get(ensemble_mode)
|
| 18 |
+
if message and len(pdb_files) != 1:
|
| 19 |
+
return message
|
| 20 |
+
|
| 21 |
+
return None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _build_pos_constraint_df(
|
| 25 |
+
pdb_key: str,
|
| 26 |
+
fixed_pos_seq: str,
|
| 27 |
+
fixed_pos_scn: str,
|
| 28 |
+
fixed_pos_override_seq: str,
|
| 29 |
+
pos_restrict_aatype: str,
|
| 30 |
+
symmetry_pos: str,
|
| 31 |
+
) -> pd.DataFrame | None:
|
| 32 |
+
row = {}
|
| 33 |
+
if fixed_pos_seq and fixed_pos_seq.strip():
|
| 34 |
+
row["fixed_pos_seq"] = fixed_pos_seq.strip()
|
| 35 |
+
if fixed_pos_scn and fixed_pos_scn.strip():
|
| 36 |
+
row["fixed_pos_scn"] = fixed_pos_scn.strip()
|
| 37 |
+
if fixed_pos_override_seq and fixed_pos_override_seq.strip():
|
| 38 |
+
row["fixed_pos_override_seq"] = fixed_pos_override_seq.strip()
|
| 39 |
+
if pos_restrict_aatype and pos_restrict_aatype.strip():
|
| 40 |
+
row["pos_restrict_aatype"] = pos_restrict_aatype.strip()
|
| 41 |
+
if symmetry_pos and symmetry_pos.strip():
|
| 42 |
+
row["symmetry_pos"] = symmetry_pos.strip()
|
| 43 |
+
if not row:
|
| 44 |
+
return None
|
| 45 |
+
row["pdb_key"] = pdb_key
|
| 46 |
+
return pd.DataFrame([row])
|
design.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core design pipeline: context building, execution, output formatting."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import spaces
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from app_config import WEIGHTS_DIR
|
| 13 |
+
from constraints import _build_pos_constraint_df, _validate_design_inputs
|
| 14 |
+
from ensemble import _generate_protpardelle_ensemble, _setup_user_ensemble_dir
|
| 15 |
+
from file_utils import _copy_uploaded_files, _get_file_path, _sanitize_download_stem, _write_zip_from_paths
|
| 16 |
+
from models import get_model
|
| 17 |
+
from self_consistency import _run_self_consistency
|
| 18 |
+
|
| 19 |
+
# ZeroGPU quota-aware retry: request the max duration first, and if the
|
| 20 |
+
# scheduler returns a quota error (which is free — no GPU time consumed),
|
| 21 |
+
# parse the remaining seconds and retry with that exact amount.
|
| 22 |
+
_MAX_GPU_DURATION = 120 # Per-call max; daily quota is 210s but per-call cap is lower
|
| 23 |
+
_gpu_duration_override: int | None = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _dynamic_gpu_duration(*args, **kwargs) -> int:
|
| 27 |
+
"""Return the current GPU duration for @spaces.GPU scheduling."""
|
| 28 |
+
return _gpu_duration_override if _gpu_duration_override is not None else _MAX_GPU_DURATION
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _parse_quota_left(error: Exception) -> int | None:
|
| 32 |
+
"""Extract remaining GPU seconds from a ZeroGPU quota error message.
|
| 33 |
+
|
| 34 |
+
Returns the number of seconds left, or None if not a recoverable quota error.
|
| 35 |
+
"""
|
| 36 |
+
message = getattr(error, 'message', None)
|
| 37 |
+
if not isinstance(message, str):
|
| 38 |
+
return None
|
| 39 |
+
match = re.search(r'(\d+)s left\)', message)
|
| 40 |
+
return int(match.group(1)) if match else None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _build_design_context(
|
| 44 |
+
pdb_paths: list[str],
|
| 45 |
+
ensemble_mode: str,
|
| 46 |
+
tmpdir: Path,
|
| 47 |
+
num_protpardelle_conformers: int,
|
| 48 |
+
fixed_pos_seq: str,
|
| 49 |
+
fixed_pos_scn: str,
|
| 50 |
+
fixed_pos_override_seq: str,
|
| 51 |
+
pos_restrict_aatype: str,
|
| 52 |
+
symmetry_pos: str,
|
| 53 |
+
) -> tuple[list[str] | dict[str, list[str]], pd.DataFrame | None]:
|
| 54 |
+
pdb_key = Path(pdb_paths[0]).stem
|
| 55 |
+
pos_constraint_df = _build_pos_constraint_df(
|
| 56 |
+
pdb_key=pdb_key,
|
| 57 |
+
fixed_pos_seq=fixed_pos_seq,
|
| 58 |
+
fixed_pos_scn=fixed_pos_scn,
|
| 59 |
+
fixed_pos_override_seq=fixed_pos_override_seq,
|
| 60 |
+
pos_restrict_aatype=pos_restrict_aatype,
|
| 61 |
+
symmetry_pos=symmetry_pos,
|
| 62 |
+
)
|
| 63 |
+
if ensemble_mode == "none":
|
| 64 |
+
return pdb_paths, pos_constraint_df
|
| 65 |
+
|
| 66 |
+
if ensemble_mode == "synthetic":
|
| 67 |
+
design_inputs = _generate_protpardelle_ensemble(
|
| 68 |
+
pdb_path=pdb_paths[0],
|
| 69 |
+
num_conformers=num_protpardelle_conformers,
|
| 70 |
+
out_dir=tmpdir,
|
| 71 |
+
weights_dir=WEIGHTS_DIR,
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
design_inputs = _setup_user_ensemble_dir(pdb_paths=pdb_paths)
|
| 75 |
+
|
| 76 |
+
if pos_constraint_df is not None:
|
| 77 |
+
from caliby import make_ensemble_constraints
|
| 78 |
+
|
| 79 |
+
row = pos_constraint_df.iloc[0]
|
| 80 |
+
cols = {col: row[col] for col in pos_constraint_df.columns if col != "pdb_key"}
|
| 81 |
+
pos_constraint_df = make_ensemble_constraints({pdb_key: cols}, design_inputs)
|
| 82 |
+
|
| 83 |
+
return design_inputs, pos_constraint_df
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _format_outputs(outputs: dict) -> tuple[pd.DataFrame, str, list[str]]:
|
| 87 |
+
out_pdb_list = outputs["out_pdb"]
|
| 88 |
+
df = pd.DataFrame(
|
| 89 |
+
{
|
| 90 |
+
"Sample": [Path(out_pdb).stem for out_pdb in out_pdb_list],
|
| 91 |
+
"Sequence": outputs["seq"],
|
| 92 |
+
"Energy (U)": outputs["U"],
|
| 93 |
+
}
|
| 94 |
+
)
|
| 95 |
+
fasta_lines = []
|
| 96 |
+
for i, (eid, seq) in enumerate(zip(outputs["example_id"], outputs["seq"])):
|
| 97 |
+
fasta_lines.append(f">{eid}_sample{i}")
|
| 98 |
+
fasta_lines.append(seq)
|
| 99 |
+
fasta_text = "\n".join(fasta_lines)
|
| 100 |
+
return df, fasta_text, out_pdb_list
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@spaces.GPU(duration=_dynamic_gpu_duration)
|
| 104 |
+
def _design_sequences_gpu(
|
| 105 |
+
pdb_files: list | None,
|
| 106 |
+
ensemble_mode: str,
|
| 107 |
+
model_variant: str,
|
| 108 |
+
num_seqs: int,
|
| 109 |
+
omit_aas: list[str] | None,
|
| 110 |
+
temperature: float,
|
| 111 |
+
fixed_pos_seq: str,
|
| 112 |
+
fixed_pos_scn: str,
|
| 113 |
+
fixed_pos_override_seq: str,
|
| 114 |
+
pos_restrict_aatype: str,
|
| 115 |
+
symmetry_pos: str,
|
| 116 |
+
num_protpardelle_conformers: int,
|
| 117 |
+
run_af2_eval: bool = False,
|
| 118 |
+
):
|
| 119 |
+
validation_error = _validate_design_inputs(pdb_files, ensemble_mode)
|
| 120 |
+
if validation_error:
|
| 121 |
+
return pd.DataFrame(), validation_error, None, None, {}, {}
|
| 122 |
+
|
| 123 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 124 |
+
torch.set_grad_enabled(False)
|
| 125 |
+
download_stem = _sanitize_download_stem(_get_file_path(pdb_files[0]).stem)
|
| 126 |
+
|
| 127 |
+
gr.Info("Loading model...")
|
| 128 |
+
model = get_model(model_variant, device)
|
| 129 |
+
|
| 130 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 131 |
+
tmpdir = Path(tmpdir)
|
| 132 |
+
pdb_paths = _copy_uploaded_files(pdb_files, tmpdir)
|
| 133 |
+
input_pdb_data = {Path(p).stem: Path(p).read_text() for p in pdb_paths}
|
| 134 |
+
|
| 135 |
+
out_dir = tmpdir / "outputs"
|
| 136 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 137 |
+
|
| 138 |
+
if ensemble_mode == "synthetic":
|
| 139 |
+
gr.Info("Generating conformer ensemble...")
|
| 140 |
+
elif ensemble_mode == "user":
|
| 141 |
+
gr.Info("Preparing uploaded ensemble...")
|
| 142 |
+
design_inputs, pos_constraint_df = _build_design_context(
|
| 143 |
+
pdb_paths=pdb_paths,
|
| 144 |
+
ensemble_mode=ensemble_mode,
|
| 145 |
+
tmpdir=tmpdir,
|
| 146 |
+
num_protpardelle_conformers=num_protpardelle_conformers,
|
| 147 |
+
fixed_pos_seq=fixed_pos_seq,
|
| 148 |
+
fixed_pos_scn=fixed_pos_scn,
|
| 149 |
+
fixed_pos_override_seq=fixed_pos_override_seq,
|
| 150 |
+
pos_restrict_aatype=pos_restrict_aatype,
|
| 151 |
+
symmetry_pos=symmetry_pos,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
gr.Info("Designing sequences...")
|
| 155 |
+
sample_kwargs = dict(
|
| 156 |
+
out_dir=str(out_dir),
|
| 157 |
+
num_seqs_per_pdb=num_seqs,
|
| 158 |
+
omit_aas=omit_aas if omit_aas else None,
|
| 159 |
+
temperature=temperature,
|
| 160 |
+
num_workers=0,
|
| 161 |
+
pos_constraint_df=pos_constraint_df,
|
| 162 |
+
)
|
| 163 |
+
if ensemble_mode == "none":
|
| 164 |
+
outputs = model.sample(design_inputs, **sample_kwargs)
|
| 165 |
+
else:
|
| 166 |
+
outputs = model.ensemble_sample(design_inputs, **sample_kwargs)
|
| 167 |
+
|
| 168 |
+
df, fasta_text, out_pdb_list = _format_outputs(outputs)
|
| 169 |
+
|
| 170 |
+
sc_zip_path = None
|
| 171 |
+
af2_pdb_data = {}
|
| 172 |
+
if run_af2_eval:
|
| 173 |
+
gr.Info("Running AF2 self-consistency evaluation...")
|
| 174 |
+
sc_zip_path, af2_pdb_data = _run_self_consistency(model, df, out_pdb_list, out_dir, download_stem)
|
| 175 |
+
|
| 176 |
+
out_zip_path = _write_zip_from_paths(out_pdb_list, download_stem, "_designs.zip")
|
| 177 |
+
return df, fasta_text, out_zip_path, sc_zip_path, af2_pdb_data, input_pdb_data
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def design_sequences(
|
| 181 |
+
pdb_files: list | None,
|
| 182 |
+
ensemble_mode: str,
|
| 183 |
+
model_variant: str,
|
| 184 |
+
num_seqs: int,
|
| 185 |
+
omit_aas: list[str] | None,
|
| 186 |
+
temperature: float,
|
| 187 |
+
fixed_pos_seq: str,
|
| 188 |
+
fixed_pos_scn: str,
|
| 189 |
+
fixed_pos_override_seq: str,
|
| 190 |
+
pos_restrict_aatype: str,
|
| 191 |
+
symmetry_pos: str,
|
| 192 |
+
num_protpardelle_conformers: int,
|
| 193 |
+
run_af2_eval: bool = False,
|
| 194 |
+
):
|
| 195 |
+
"""Run sequence design with ZeroGPU quota-aware retry.
|
| 196 |
+
|
| 197 |
+
Requests the max GPU duration first. If the scheduler returns a quota
|
| 198 |
+
error (free — no GPU time consumed), parses the remaining seconds and
|
| 199 |
+
retries with that exact amount to maximize GPU utilization.
|
| 200 |
+
"""
|
| 201 |
+
global _gpu_duration_override
|
| 202 |
+
|
| 203 |
+
_gpu_duration_override = None
|
| 204 |
+
try:
|
| 205 |
+
return _design_sequences_gpu(
|
| 206 |
+
pdb_files=pdb_files,
|
| 207 |
+
ensemble_mode=ensemble_mode,
|
| 208 |
+
model_variant=model_variant,
|
| 209 |
+
num_seqs=num_seqs,
|
| 210 |
+
omit_aas=omit_aas,
|
| 211 |
+
temperature=temperature,
|
| 212 |
+
fixed_pos_seq=fixed_pos_seq,
|
| 213 |
+
fixed_pos_scn=fixed_pos_scn,
|
| 214 |
+
fixed_pos_override_seq=fixed_pos_override_seq,
|
| 215 |
+
pos_restrict_aatype=pos_restrict_aatype,
|
| 216 |
+
symmetry_pos=symmetry_pos,
|
| 217 |
+
num_protpardelle_conformers=num_protpardelle_conformers,
|
| 218 |
+
run_af2_eval=run_af2_eval,
|
| 219 |
+
)
|
| 220 |
+
except gr.Error as e:
|
| 221 |
+
remaining = _parse_quota_left(e)
|
| 222 |
+
print(f"[ZeroGPU retry] Caught gr.Error, parsed remaining={remaining}, message={getattr(e, 'message', str(e))}")
|
| 223 |
+
if remaining is None or remaining <= 0:
|
| 224 |
+
raise
|
| 225 |
+
gr.Info(f"GPU quota: {remaining}s remaining, retrying with exact quota")
|
| 226 |
+
_gpu_duration_override = remaining - 1
|
| 227 |
+
try:
|
| 228 |
+
return _design_sequences_gpu(
|
| 229 |
+
pdb_files=pdb_files,
|
| 230 |
+
ensemble_mode=ensemble_mode,
|
| 231 |
+
model_variant=model_variant,
|
| 232 |
+
num_seqs=num_seqs,
|
| 233 |
+
omit_aas=omit_aas,
|
| 234 |
+
temperature=temperature,
|
| 235 |
+
fixed_pos_seq=fixed_pos_seq,
|
| 236 |
+
fixed_pos_scn=fixed_pos_scn,
|
| 237 |
+
fixed_pos_override_seq=fixed_pos_override_seq,
|
| 238 |
+
pos_restrict_aatype=pos_restrict_aatype,
|
| 239 |
+
symmetry_pos=symmetry_pos,
|
| 240 |
+
num_protpardelle_conformers=num_protpardelle_conformers,
|
| 241 |
+
run_af2_eval=run_af2_eval,
|
| 242 |
+
)
|
| 243 |
+
finally:
|
| 244 |
+
_gpu_duration_override = None
|
ensemble.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Protpardelle and user ensemble generation."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _generate_protpardelle_ensemble(
|
| 7 |
+
pdb_path: str,
|
| 8 |
+
num_conformers: int,
|
| 9 |
+
out_dir: Path,
|
| 10 |
+
weights_dir: str,
|
| 11 |
+
) -> dict[str, list[str]]:
|
| 12 |
+
"""Generate conformers with Protpardelle-1c, return pdb_to_conformers dict."""
|
| 13 |
+
from caliby import generate_ensembles
|
| 14 |
+
|
| 15 |
+
pdb_to_conformers = generate_ensembles(
|
| 16 |
+
[pdb_path],
|
| 17 |
+
out_dir=str(out_dir / "protpardelle_ensemble"),
|
| 18 |
+
num_samples_per_pdb=num_conformers,
|
| 19 |
+
model_params_path=weights_dir,
|
| 20 |
+
)
|
| 21 |
+
# generate_ensembles returns only generated conformers — prepend the primary structure.
|
| 22 |
+
pdb_stem = Path(pdb_path).stem
|
| 23 |
+
pdb_to_conformers[pdb_stem] = [pdb_path] + pdb_to_conformers.get(pdb_stem, [])
|
| 24 |
+
return pdb_to_conformers
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _setup_user_ensemble_dir(
|
| 28 |
+
pdb_paths: list[str],
|
| 29 |
+
**_ignored,
|
| 30 |
+
) -> dict[str, list[str]]:
|
| 31 |
+
"""Build pdb_to_conformers dict from user-uploaded files.
|
| 32 |
+
|
| 33 |
+
First file is the primary conformer, rest are additional conformers.
|
| 34 |
+
"""
|
| 35 |
+
pdb_key = Path(pdb_paths[0]).stem
|
| 36 |
+
return {pdb_key: list(pdb_paths)}
|
file_utils.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""File path helpers, ZIP operations, and CSV export."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
import tempfile
|
| 5 |
+
import zipfile
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _get_file_path(f):
|
| 12 |
+
if isinstance(f, str):
|
| 13 |
+
return Path(f)
|
| 14 |
+
if hasattr(f, "path"):
|
| 15 |
+
return Path(f.path)
|
| 16 |
+
if isinstance(f, dict) and "path" in f:
|
| 17 |
+
return Path(f["path"])
|
| 18 |
+
return Path(str(f))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _sanitize_download_stem(stem: str) -> str:
|
| 22 |
+
sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", stem).strip("._-")
|
| 23 |
+
return sanitized or "caliby"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _make_named_download_path(stem: str, suffix: str) -> str:
|
| 27 |
+
download_dir = Path(tempfile.mkdtemp(prefix="caliby_download_"))
|
| 28 |
+
return str(download_dir / f"{_sanitize_download_stem(stem)}{suffix}")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _get_results_stem(df: pd.DataFrame) -> str:
|
| 32 |
+
if "Sample" not in df.columns:
|
| 33 |
+
return "caliby"
|
| 34 |
+
sample_name = str(df.iloc[0]["Sample"])
|
| 35 |
+
return _sanitize_download_stem(re.sub(r"_sample\d+$", "", sample_name))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _copy_uploaded_files(pdb_files: list, tmpdir: Path) -> list[str]:
|
| 39 |
+
pdb_paths = []
|
| 40 |
+
for f in pdb_files:
|
| 41 |
+
src = _get_file_path(f)
|
| 42 |
+
path = tmpdir / src.name
|
| 43 |
+
path.write_bytes(src.read_bytes())
|
| 44 |
+
pdb_paths.append(str(path))
|
| 45 |
+
return pdb_paths
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _write_zip_from_paths(paths: list[str], download_stem: str, suffix: str) -> str | None:
|
| 49 |
+
if not paths:
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
zip_path = _make_named_download_path(download_stem, suffix)
|
| 53 |
+
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
| 54 |
+
for path in paths:
|
| 55 |
+
zf.write(path, Path(path).name)
|
| 56 |
+
return zip_path
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _write_zip_from_dir(directory: Path, download_stem: str, suffix: str) -> str:
|
| 60 |
+
zip_path = _make_named_download_path(download_stem, suffix)
|
| 61 |
+
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
| 62 |
+
for path in directory.rglob("*"):
|
| 63 |
+
if path.is_file():
|
| 64 |
+
zf.write(path, path.relative_to(directory))
|
| 65 |
+
return zip_path
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _df_to_csv(df: pd.DataFrame | None) -> str | None:
|
| 69 |
+
if df is None or df.empty:
|
| 70 |
+
return None
|
| 71 |
+
path = _make_named_download_path(_get_results_stem(df), "_results.csv")
|
| 72 |
+
df.to_csv(path, index=False)
|
| 73 |
+
return path
|
models.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model loading and caching."""
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import types
|
| 5 |
+
|
| 6 |
+
from caliby import CalibyModel
|
| 7 |
+
|
| 8 |
+
MODELS: dict[str, CalibyModel] = {}
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_model(variant: str, device: str) -> CalibyModel:
|
| 12 |
+
"""Load and cache a CalibyModel by variant name."""
|
| 13 |
+
if variant not in MODELS:
|
| 14 |
+
# ZeroGPU's @spaces.GPU decorator may remove sys.modules["__main__"].
|
| 15 |
+
# Lightning's load_from_checkpoint calls inspect.stack() which
|
| 16 |
+
# requires it, so ensure a placeholder exists.
|
| 17 |
+
if "__main__" not in sys.modules:
|
| 18 |
+
sys.modules["__main__"] = types.ModuleType("__main__")
|
| 19 |
+
|
| 20 |
+
from caliby import load_model
|
| 21 |
+
|
| 22 |
+
MODELS[variant] = load_model(variant, device=device)
|
| 23 |
+
return MODELS[variant]
|
pyproject.toml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "caliby-hf"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
dependencies = [
|
| 5 |
+
"caliby[af2] @ git+https://github.com/ProteinDesignLab/caliby@20d6757aaaba1662e71234ba25dde0f64b199683",
|
| 6 |
+
"gradio",
|
| 7 |
+
"huggingface_hub",
|
| 8 |
+
"molview>=0.1.0",
|
| 9 |
+
"omegaconf",
|
| 10 |
+
"pandas",
|
| 11 |
+
"pytest",
|
| 12 |
+
"spaces",
|
| 13 |
+
"torch",
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
[tool.pytest.ini_options]
|
| 17 |
+
testpaths = ["tests"]
|
| 18 |
+
pythonpath = ["."]
|
| 19 |
+
|
| 20 |
+
[tool.ruff]
|
| 21 |
+
line-length = 120
|
| 22 |
+
exclude = ["chroma"]
|
| 23 |
+
|
| 24 |
+
[tool.ruff.lint]
|
| 25 |
+
select = ["E", "F", "I"]
|
| 26 |
+
ignore = ["E731"]
|
| 27 |
+
|
| 28 |
+
[tool.ruff.format]
|
| 29 |
+
quote-style = "preserve" # avoid churning quotes
|
| 30 |
+
indent-style = "space"
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
caliby[af2] @ git+https://github.com/ProteinDesignLab/caliby@20d6757aaaba1662e71234ba25dde0f64b199683
|
| 2 |
+
gradio
|
| 3 |
+
huggingface_hub
|
| 4 |
+
molview>=0.1.0
|
| 5 |
+
omegaconf
|
| 6 |
+
pandas
|
| 7 |
+
spaces
|
| 8 |
+
torch
|
self_consistency.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AF2 self-consistency evaluation."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from caliby import CalibyModel
|
| 7 |
+
|
| 8 |
+
from file_utils import _write_zip_from_dir
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _run_self_consistency(
|
| 12 |
+
model: CalibyModel,
|
| 13 |
+
df: pd.DataFrame,
|
| 14 |
+
out_pdb_list: list[str],
|
| 15 |
+
out_dir: Path,
|
| 16 |
+
download_stem: str,
|
| 17 |
+
) -> tuple[str, dict[str, str]]:
|
| 18 |
+
from caliby.eval.eval_utils.folding_utils import clear_mem_torch
|
| 19 |
+
|
| 20 |
+
clear_mem_torch()
|
| 21 |
+
|
| 22 |
+
sc_out_dir = out_dir / "self_consistency"
|
| 23 |
+
id_to_metrics = model.self_consistency_eval(out_pdb_list, out_dir=str(sc_out_dir))
|
| 24 |
+
|
| 25 |
+
for metric in ["sc_ca_rmsd", "avg_ca_plddt", "tmalign_score"]:
|
| 26 |
+
df[metric] = [id_to_metrics.get(Path(path).stem, {}).get(metric, float("nan")) for path in out_pdb_list]
|
| 27 |
+
|
| 28 |
+
af2_pdb_data = {}
|
| 29 |
+
for path in out_pdb_list:
|
| 30 |
+
af2_path = sc_out_dir / "struct_preds" / f"af2_{Path(path).stem}.pdb"
|
| 31 |
+
if af2_path.exists():
|
| 32 |
+
af2_pdb_data[Path(path).stem] = af2_path.read_text()
|
| 33 |
+
|
| 34 |
+
sc_zip_path = _write_zip_from_dir(sc_out_dir, download_stem, "_self_consistency.zip")
|
| 35 |
+
return sc_zip_path, af2_pdb_data
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module-level mocking for app.py import-time side effects."""
|
| 2 |
+
|
| 3 |
+
import tempfile
|
| 4 |
+
from unittest.mock import patch
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
# ---------------------------------------------------------------------------
|
| 9 |
+
# Module-level patches — applied BEFORE any test file can `import app`.
|
| 10 |
+
#
|
| 11 |
+
# app_config.py executes on import:
|
| 12 |
+
# 1. snapshot_download(...) → needs HF_TOKEN + network
|
| 13 |
+
# 2. os.environ.setdefault(...) → safe, no mock needed
|
| 14 |
+
# app.py executes on import:
|
| 15 |
+
# 1. base64-encode caliby_transparent.png → file exists in repo, no mock
|
| 16 |
+
# 2. @spaces.GPU decorator → no-op when SPACES_ZERO_GPU unset
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
|
| 19 |
+
_FAKE_WEIGHTS_DIR = tempfile.mkdtemp(prefix="caliby_test_weights_")
|
| 20 |
+
|
| 21 |
+
patch("huggingface_hub.snapshot_download", return_value=_FAKE_WEIGHTS_DIR).start()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Shared fixtures
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@pytest.fixture
|
| 30 |
+
def sample_outputs():
|
| 31 |
+
"""Mock outputs dict matching caliby's CalibyModel.sample() return format."""
|
| 32 |
+
return {
|
| 33 |
+
"example_id": ["1YCR", "1YCR"],
|
| 34 |
+
"out_pdb": ["/tmp/out/1YCR_sample0.cif", "/tmp/out/1YCR_sample1.cif"],
|
| 35 |
+
"U": [-142.38, -139.92],
|
| 36 |
+
"input_seq": ["NATIVE_SEQ", "NATIVE_SEQ"],
|
| 37 |
+
"seq": ["MTEEQWAQ", "VSEQQWAQ"],
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@pytest.fixture
|
| 42 |
+
def sample_outputs_with_out_pdbs(sample_outputs):
|
| 43 |
+
"""Outputs dict with the 'out_pdbs' key that app.py's _format_outputs actually reads."""
|
| 44 |
+
return {**sample_outputs, "out_pdbs": sample_outputs["out_pdb"]}
|
tests/test_design_sequences.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Integration tests for get_model, _setup_user_ensemble_dir, and design_sequences."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from unittest.mock import MagicMock, patch
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
import design
|
| 11 |
+
import ensemble
|
| 12 |
+
import models
|
| 13 |
+
|
| 14 |
+
# ---------------------------------------------------------------------------
|
| 15 |
+
# get_model
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TestGetModel:
|
| 20 |
+
"""Lazy-loads and caches CalibyModel instances via caliby.load_model."""
|
| 21 |
+
|
| 22 |
+
@pytest.fixture(autouse=True)
|
| 23 |
+
def _clear_model_cache(self):
|
| 24 |
+
models.MODELS.clear()
|
| 25 |
+
yield
|
| 26 |
+
models.MODELS.clear()
|
| 27 |
+
|
| 28 |
+
def test_calls_load_model_with_variant_and_device(self):
|
| 29 |
+
mock_caliby_model = MagicMock()
|
| 30 |
+
with patch("caliby.load_model", return_value=mock_caliby_model) as mock_load:
|
| 31 |
+
result = models.get_model("caliby", "cpu")
|
| 32 |
+
|
| 33 |
+
mock_load.assert_called_once_with("caliby", device="cpu")
|
| 34 |
+
assert result is mock_caliby_model
|
| 35 |
+
|
| 36 |
+
def test_caches_model_on_repeat_call(self):
|
| 37 |
+
mock_caliby_model = MagicMock()
|
| 38 |
+
with patch("caliby.load_model", return_value=mock_caliby_model) as mock_load:
|
| 39 |
+
first = models.get_model("caliby", "cpu")
|
| 40 |
+
second = models.get_model("caliby", "cpu")
|
| 41 |
+
mock_load.assert_called_once()
|
| 42 |
+
assert first is second
|
| 43 |
+
|
| 44 |
+
def test_different_variants_cached_separately(self):
|
| 45 |
+
mock_a = MagicMock()
|
| 46 |
+
mock_b = MagicMock()
|
| 47 |
+
with patch("caliby.load_model", side_effect=[mock_a, mock_b]):
|
| 48 |
+
a = models.get_model("caliby", "cpu")
|
| 49 |
+
b = models.get_model("soluble_caliby_v1", "cpu")
|
| 50 |
+
assert a is mock_a
|
| 51 |
+
assert b is mock_b
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
# _setup_user_ensemble_dir
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class TestSetupUserEnsembleDir:
|
| 60 |
+
"""Builds pdb_to_conformers dict from user-uploaded files."""
|
| 61 |
+
|
| 62 |
+
def test_returns_dict_with_primary_key(self):
|
| 63 |
+
result = ensemble._setup_user_ensemble_dir(["/tmp/primary.pdb", "/tmp/conf1.pdb", "/tmp/conf2.pdb"])
|
| 64 |
+
assert "primary" in result
|
| 65 |
+
assert result["primary"] == ["/tmp/primary.pdb", "/tmp/conf1.pdb", "/tmp/conf2.pdb"]
|
| 66 |
+
|
| 67 |
+
def test_first_file_is_primary(self):
|
| 68 |
+
result = ensemble._setup_user_ensemble_dir(["/tmp/myprotein.cif", "/tmp/alt.pdb"])
|
| 69 |
+
assert result["myprotein"][0] == "/tmp/myprotein.cif"
|
| 70 |
+
|
| 71 |
+
def test_uses_stem_as_key(self):
|
| 72 |
+
result = ensemble._setup_user_ensemble_dir(["/path/to/foo.pdb"])
|
| 73 |
+
assert "foo" in result
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
# design_sequences — validation
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class TestDesignSequencesValidation:
|
| 82 |
+
"""Input validation before any model calls."""
|
| 83 |
+
|
| 84 |
+
def test_no_files(self):
|
| 85 |
+
df, msg, _, _, _, _ = design.design_sequences(None, "none", "caliby", 4, None, 0.1, "", "", "", "", "", 31)
|
| 86 |
+
assert df.empty
|
| 87 |
+
assert "Upload at least one" in msg
|
| 88 |
+
|
| 89 |
+
def test_empty_file_list(self):
|
| 90 |
+
df, msg, _, _, _, _ = design.design_sequences([], "none", "caliby", 4, None, 0.1, "", "", "", "", "", 31)
|
| 91 |
+
assert df.empty
|
| 92 |
+
assert "Upload at least one" in msg
|
| 93 |
+
|
| 94 |
+
def test_single_mode_multiple_files(self):
|
| 95 |
+
df, msg, _, _, _, _ = design.design_sequences(
|
| 96 |
+
["a.pdb", "b.pdb"], "none", "caliby", 4, None, 0.1, "", "", "", "", "", 31
|
| 97 |
+
)
|
| 98 |
+
assert "exactly one file" in msg
|
| 99 |
+
|
| 100 |
+
def test_synthetic_mode_multiple_files(self):
|
| 101 |
+
df, msg, _, _, _, _ = design.design_sequences(
|
| 102 |
+
["a.pdb", "b.pdb"], "synthetic", "caliby", 4, None, 0.1, "", "", "", "", "", 31
|
| 103 |
+
)
|
| 104 |
+
assert "exactly one file" in msg
|
| 105 |
+
|
| 106 |
+
def test_user_mode_too_few_files(self):
|
| 107 |
+
df, msg, _, _, _, _ = design.design_sequences(["a.pdb"], "user", "caliby", 4, None, 0.1, "", "", "", "", "", 31)
|
| 108 |
+
assert "at least two" in msg
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
# design_sequences — single structure mode
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class TestDesignSequencesSingleMode:
|
| 117 |
+
"""Tests ensemble_mode='none' — verifies correct args to CalibyModel.sample()."""
|
| 118 |
+
|
| 119 |
+
def _make_mock_outputs(self):
|
| 120 |
+
return {
|
| 121 |
+
"example_id": ["test"],
|
| 122 |
+
"out_pdb": ["/tmp/test_sample0.cif"],
|
| 123 |
+
"U": [-100.0],
|
| 124 |
+
"input_seq": ["NATIVE"],
|
| 125 |
+
"seq": ["ACDEF"],
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
def test_sample_called_with_correct_args(self, tmp_path):
|
| 129 |
+
pdb_file = tmp_path / "test.pdb"
|
| 130 |
+
pdb_file.write_text("FAKE PDB")
|
| 131 |
+
|
| 132 |
+
mock_model = MagicMock()
|
| 133 |
+
mock_model.sample.return_value = self._make_mock_outputs()
|
| 134 |
+
|
| 135 |
+
with (
|
| 136 |
+
patch.object(design, "get_model", return_value=mock_model),
|
| 137 |
+
patch.object(design, "_write_zip_from_paths", return_value=None),
|
| 138 |
+
patch("torch.cuda.is_available", return_value=False),
|
| 139 |
+
):
|
| 140 |
+
design.design_sequences(
|
| 141 |
+
[str(pdb_file)],
|
| 142 |
+
"none",
|
| 143 |
+
"caliby",
|
| 144 |
+
4,
|
| 145 |
+
["C"],
|
| 146 |
+
0.5,
|
| 147 |
+
"A1-100",
|
| 148 |
+
"A1-10",
|
| 149 |
+
"A26:A",
|
| 150 |
+
"A26:AVG",
|
| 151 |
+
"A10,B10",
|
| 152 |
+
31,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
mock_model.sample.assert_called_once()
|
| 156 |
+
args, kwargs = mock_model.sample.call_args
|
| 157 |
+
|
| 158 |
+
# First positional arg is pdb_paths
|
| 159 |
+
assert isinstance(args[0], list)
|
| 160 |
+
assert len(args[0]) == 1
|
| 161 |
+
assert args[0][0].endswith("test.pdb")
|
| 162 |
+
|
| 163 |
+
assert kwargs["num_seqs_per_pdb"] == 4
|
| 164 |
+
assert kwargs["omit_aas"] == ["C"]
|
| 165 |
+
assert kwargs["temperature"] == 0.5
|
| 166 |
+
assert kwargs["num_workers"] == 0
|
| 167 |
+
assert isinstance(kwargs["out_dir"], str)
|
| 168 |
+
assert isinstance(kwargs["pos_constraint_df"], pd.DataFrame)
|
| 169 |
+
assert kwargs["pos_constraint_df"].iloc[0]["pdb_key"] == "test"
|
| 170 |
+
|
| 171 |
+
def test_no_constraints_passes_none(self, tmp_path):
|
| 172 |
+
pdb_file = tmp_path / "test.pdb"
|
| 173 |
+
pdb_file.write_text("FAKE")
|
| 174 |
+
|
| 175 |
+
mock_model = MagicMock()
|
| 176 |
+
mock_model.sample.return_value = self._make_mock_outputs()
|
| 177 |
+
|
| 178 |
+
with (
|
| 179 |
+
patch.object(design, "get_model", return_value=mock_model),
|
| 180 |
+
patch.object(design, "_write_zip_from_paths", return_value=None),
|
| 181 |
+
patch("torch.cuda.is_available", return_value=False),
|
| 182 |
+
):
|
| 183 |
+
design.design_sequences(
|
| 184 |
+
[str(pdb_file)],
|
| 185 |
+
"none",
|
| 186 |
+
"caliby",
|
| 187 |
+
1,
|
| 188 |
+
None,
|
| 189 |
+
0.1,
|
| 190 |
+
"",
|
| 191 |
+
"",
|
| 192 |
+
"",
|
| 193 |
+
"",
|
| 194 |
+
"",
|
| 195 |
+
31,
|
| 196 |
+
)
|
| 197 |
+
assert mock_model.sample.call_args[1]["pos_constraint_df"] is None
|
| 198 |
+
|
| 199 |
+
def test_empty_omit_aas_becomes_none(self, tmp_path):
|
| 200 |
+
pdb_file = tmp_path / "test.pdb"
|
| 201 |
+
pdb_file.write_text("FAKE")
|
| 202 |
+
|
| 203 |
+
mock_model = MagicMock()
|
| 204 |
+
mock_model.sample.return_value = self._make_mock_outputs()
|
| 205 |
+
|
| 206 |
+
with (
|
| 207 |
+
patch.object(design, "get_model", return_value=mock_model),
|
| 208 |
+
patch.object(design, "_write_zip_from_paths", return_value=None),
|
| 209 |
+
patch("torch.cuda.is_available", return_value=False),
|
| 210 |
+
):
|
| 211 |
+
design.design_sequences(
|
| 212 |
+
[str(pdb_file)],
|
| 213 |
+
"none",
|
| 214 |
+
"caliby",
|
| 215 |
+
1,
|
| 216 |
+
[],
|
| 217 |
+
0.1,
|
| 218 |
+
"",
|
| 219 |
+
"",
|
| 220 |
+
"",
|
| 221 |
+
"",
|
| 222 |
+
"",
|
| 223 |
+
31,
|
| 224 |
+
)
|
| 225 |
+
assert mock_model.sample.call_args[1]["omit_aas"] is None
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# ---------------------------------------------------------------------------
|
| 229 |
+
# design_sequences — user ensemble mode
|
| 230 |
+
# ---------------------------------------------------------------------------
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class TestDesignSequencesUserEnsembleMode:
|
| 234 |
+
"""Tests ensemble_mode='user' — verifies correct args to CalibyModel.ensemble_sample()."""
|
| 235 |
+
|
| 236 |
+
def _make_mock_outputs(self):
|
| 237 |
+
return {
|
| 238 |
+
"example_id": ["primary"],
|
| 239 |
+
"out_pdb": ["/tmp/primary_sample0.cif"],
|
| 240 |
+
"U": [-100.0],
|
| 241 |
+
"input_seq": ["NATIVE"],
|
| 242 |
+
"seq": ["AAA"],
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
def test_calls_ensemble_sample(self, tmp_path):
|
| 246 |
+
pdb1 = tmp_path / "primary.pdb"
|
| 247 |
+
pdb2 = tmp_path / "conf1.pdb"
|
| 248 |
+
pdb1.write_text("PDB1")
|
| 249 |
+
pdb2.write_text("PDB2")
|
| 250 |
+
|
| 251 |
+
mock_model = MagicMock()
|
| 252 |
+
mock_model.ensemble_sample.return_value = self._make_mock_outputs()
|
| 253 |
+
mock_pdb_to_conf = {"primary": ["/some/primary.pdb", "/some/conf1.pdb"]}
|
| 254 |
+
|
| 255 |
+
with (
|
| 256 |
+
patch.object(design, "get_model", return_value=mock_model),
|
| 257 |
+
patch.object(design, "_setup_user_ensemble_dir", return_value=mock_pdb_to_conf),
|
| 258 |
+
patch.object(design, "_write_zip_from_paths", return_value=None),
|
| 259 |
+
patch("torch.cuda.is_available", return_value=False),
|
| 260 |
+
):
|
| 261 |
+
design.design_sequences([str(pdb1), str(pdb2)], "user", "caliby", 4, None, 0.1, "", "", "", "", "", 31)
|
| 262 |
+
|
| 263 |
+
mock_model.ensemble_sample.assert_called_once()
|
| 264 |
+
args, kwargs = mock_model.ensemble_sample.call_args
|
| 265 |
+
assert args[0] is mock_pdb_to_conf
|
| 266 |
+
assert kwargs["pos_constraint_df"] is None
|
| 267 |
+
|
| 268 |
+
def test_constraints_expand_via_make_ensemble_constraints(self, tmp_path):
|
| 269 |
+
pdb1 = tmp_path / "primary.pdb"
|
| 270 |
+
pdb2 = tmp_path / "conf1.pdb"
|
| 271 |
+
pdb1.write_text("PDB1")
|
| 272 |
+
pdb2.write_text("PDB2")
|
| 273 |
+
|
| 274 |
+
mock_model = MagicMock()
|
| 275 |
+
mock_model.ensemble_sample.return_value = self._make_mock_outputs()
|
| 276 |
+
mock_pdb_to_conf = {"primary": ["a.pdb", "b.pdb"]}
|
| 277 |
+
expanded_df = pd.DataFrame({"pdb_key": ["a", "b"], "fixed_pos_seq": ["A1-10", "A1-10"]})
|
| 278 |
+
|
| 279 |
+
with (
|
| 280 |
+
patch.object(design, "get_model", return_value=mock_model),
|
| 281 |
+
patch.object(design, "_setup_user_ensemble_dir", return_value=mock_pdb_to_conf),
|
| 282 |
+
patch("caliby.make_ensemble_constraints", return_value=expanded_df) as mock_expand,
|
| 283 |
+
patch.object(design, "_write_zip_from_paths", return_value=None),
|
| 284 |
+
patch("torch.cuda.is_available", return_value=False),
|
| 285 |
+
):
|
| 286 |
+
design.design_sequences([str(pdb1), str(pdb2)], "user", "caliby", 1, None, 0.1, "A1-10", "", "", "", "", 31)
|
| 287 |
+
|
| 288 |
+
mock_expand.assert_called_once()
|
| 289 |
+
constraints_dict, pdb_to_conf_arg = mock_expand.call_args[0]
|
| 290 |
+
assert isinstance(constraints_dict, dict)
|
| 291 |
+
assert "primary" in constraints_dict
|
| 292 |
+
assert constraints_dict["primary"]["fixed_pos_seq"] == "A1-10"
|
| 293 |
+
assert pdb_to_conf_arg is mock_pdb_to_conf
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# ---------------------------------------------------------------------------
|
| 297 |
+
# design_sequences — error handling
|
| 298 |
+
# ---------------------------------------------------------------------------
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class TestDesignSequencesErrorHandling:
|
| 302 |
+
"""Verifies non-validation failures now raise naturally."""
|
| 303 |
+
|
| 304 |
+
def test_value_error(self, tmp_path):
|
| 305 |
+
pdb_file = tmp_path / "test.pdb"
|
| 306 |
+
pdb_file.write_text("PDB")
|
| 307 |
+
|
| 308 |
+
with (
|
| 309 |
+
patch.object(design, "get_model", side_effect=ValueError("bad config")),
|
| 310 |
+
patch("torch.cuda.is_available", return_value=False),
|
| 311 |
+
):
|
| 312 |
+
with pytest.raises(ValueError, match="bad config"):
|
| 313 |
+
design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)
|
| 314 |
+
|
| 315 |
+
def test_file_not_found(self, tmp_path):
|
| 316 |
+
with (
|
| 317 |
+
patch.object(design, "get_model", side_effect=FileNotFoundError("missing.pdb")),
|
| 318 |
+
patch("torch.cuda.is_available", return_value=False),
|
| 319 |
+
):
|
| 320 |
+
with pytest.raises(FileNotFoundError, match="missing.pdb"):
|
| 321 |
+
design.design_sequences(
|
| 322 |
+
[str(tmp_path / "ghost.pdb")], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
def test_unexpected_runtime_error(self, tmp_path):
|
| 326 |
+
pdb_file = tmp_path / "test.pdb"
|
| 327 |
+
pdb_file.write_text("PDB")
|
| 328 |
+
|
| 329 |
+
with (
|
| 330 |
+
patch.object(design, "get_model", side_effect=RuntimeError("GPU OOM")),
|
| 331 |
+
patch("torch.cuda.is_available", return_value=False),
|
| 332 |
+
):
|
| 333 |
+
with pytest.raises(RuntimeError, match="GPU OOM"):
|
| 334 |
+
design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
# ---------------------------------------------------------------------------
|
| 338 |
+
# design_sequences — zip output
|
| 339 |
+
# ---------------------------------------------------------------------------
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class TestDesignSequencesZipOutput:
|
| 343 |
+
"""Tests ZIP file creation from output CIF files."""
|
| 344 |
+
|
| 345 |
+
def test_creates_zip_when_out_pdb_present(self, tmp_path):
|
| 346 |
+
pdb_file = tmp_path / "test.pdb"
|
| 347 |
+
pdb_file.write_text("PDB")
|
| 348 |
+
|
| 349 |
+
out_cif = tmp_path / "test_sample0.cif"
|
| 350 |
+
out_cif.write_text("CIF CONTENT")
|
| 351 |
+
|
| 352 |
+
mock_model = MagicMock()
|
| 353 |
+
mock_model.sample.return_value = {
|
| 354 |
+
"example_id": ["test"],
|
| 355 |
+
"out_pdb": [str(out_cif)],
|
| 356 |
+
"U": [-100.0],
|
| 357 |
+
"input_seq": ["NATIVE"],
|
| 358 |
+
"seq": ["AAA"],
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
with (
|
| 362 |
+
patch.object(design, "get_model", return_value=mock_model),
|
| 363 |
+
patch("torch.cuda.is_available", return_value=False),
|
| 364 |
+
):
|
| 365 |
+
_, _, zip_path, _, _, _ = design.design_sequences(
|
| 366 |
+
[str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
assert zip_path is not None
|
| 370 |
+
assert Path(zip_path).name == "test_designs.zip"
|
| 371 |
+
assert Path(zip_path).exists()
|
| 372 |
+
|
| 373 |
+
def test_empty_out_pdb_raises_for_invalid_caliby_output(self, tmp_path):
|
| 374 |
+
pdb_file = tmp_path / "test.pdb"
|
| 375 |
+
pdb_file.write_text("PDB")
|
| 376 |
+
|
| 377 |
+
mock_model = MagicMock()
|
| 378 |
+
mock_model.sample.return_value = {
|
| 379 |
+
"example_id": ["test"],
|
| 380 |
+
"out_pdb": [],
|
| 381 |
+
"U": [-100.0],
|
| 382 |
+
"input_seq": ["NATIVE"],
|
| 383 |
+
"seq": ["AAA"],
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
with (
|
| 387 |
+
patch.object(design, "get_model", return_value=mock_model),
|
| 388 |
+
patch("torch.cuda.is_available", return_value=False),
|
| 389 |
+
):
|
| 390 |
+
with pytest.raises(ValueError, match="All arrays must be of the same length"):
|
| 391 |
+
design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# ---------------------------------------------------------------------------
|
| 395 |
+
# design_sequences — ZeroGPU quota-aware retry
|
| 396 |
+
# ---------------------------------------------------------------------------
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class TestParseQuotaLeft:
|
| 400 |
+
"""Tests _parse_quota_left regex parsing of ZeroGPU error messages."""
|
| 401 |
+
|
| 402 |
+
def test_extracts_remaining_seconds(self):
|
| 403 |
+
e = gr.Error("You have exceeded your free GPU quota (210s requested vs. 45s left). Try again in 0:02:45")
|
| 404 |
+
assert design._parse_quota_left(e) == 45
|
| 405 |
+
|
| 406 |
+
def test_extracts_zero_remaining(self):
|
| 407 |
+
e = gr.Error("(210s requested vs. 0s left). Try again in 0:03:30")
|
| 408 |
+
assert design._parse_quota_left(e) == 0
|
| 409 |
+
|
| 410 |
+
def test_returns_none_for_non_quota_error(self):
|
| 411 |
+
e = gr.Error("Some other error")
|
| 412 |
+
assert design._parse_quota_left(e) is None
|
| 413 |
+
|
| 414 |
+
def test_returns_none_for_no_message_attr(self):
|
| 415 |
+
e = RuntimeError("no message attribute")
|
| 416 |
+
assert design._parse_quota_left(e) is None
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class TestDesignSequencesQuotaRetry:
|
| 420 |
+
"""Tests ZeroGPU quota-aware retry logic in design_sequences wrapper."""
|
| 421 |
+
|
| 422 |
+
_DESIGN_ARGS = (None, "none", "caliby", 4, None, 0.1, "", "", "", "", "", 31)
|
| 423 |
+
|
| 424 |
+
def test_retry_on_quota_exceeded(self, tmp_path):
|
| 425 |
+
pdb_file = tmp_path / "test.pdb"
|
| 426 |
+
pdb_file.write_text("PDB")
|
| 427 |
+
|
| 428 |
+
mock_model = MagicMock()
|
| 429 |
+
mock_model.sample.return_value = {
|
| 430 |
+
"example_id": ["test"],
|
| 431 |
+
"out_pdb": ["/tmp/t.cif"],
|
| 432 |
+
"U": [-100.0],
|
| 433 |
+
"input_seq": ["N"],
|
| 434 |
+
"seq": ["A"],
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
quota_error = gr.Error("(210s requested vs. 45s left). Try again in 0:02:45")
|
| 438 |
+
|
| 439 |
+
call_count = 0
|
| 440 |
+
original_fn = design._design_sequences_gpu
|
| 441 |
+
|
| 442 |
+
def side_effect(*args, **kwargs):
|
| 443 |
+
nonlocal call_count
|
| 444 |
+
call_count += 1
|
| 445 |
+
if call_count == 1:
|
| 446 |
+
raise quota_error
|
| 447 |
+
return original_fn(*args, **kwargs)
|
| 448 |
+
|
| 449 |
+
with (
|
| 450 |
+
patch.object(design, "_design_sequences_gpu", side_effect=side_effect),
|
| 451 |
+
patch.object(design, "get_model", return_value=mock_model),
|
| 452 |
+
patch.object(design, "_write_zip_from_paths", return_value=None),
|
| 453 |
+
patch("torch.cuda.is_available", return_value=False),
|
| 454 |
+
):
|
| 455 |
+
design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)
|
| 456 |
+
assert call_count == 2
|
| 457 |
+
assert design._gpu_duration_override is None # Reset after retry
|
| 458 |
+
|
| 459 |
+
def test_no_retry_when_remaining_zero(self):
|
| 460 |
+
quota_error = gr.Error("(210s requested vs. 0s left). Try again in 0:03:30")
|
| 461 |
+
with patch.object(design, "_design_sequences_gpu", side_effect=quota_error):
|
| 462 |
+
with pytest.raises(gr.Error):
|
| 463 |
+
design.design_sequences(*self._DESIGN_ARGS)
|
| 464 |
+
|
| 465 |
+
def test_no_retry_for_non_quota_gr_error(self):
|
| 466 |
+
other_error = gr.Error("The requested GPU duration (210s) is larger than the maximum allowed")
|
| 467 |
+
with patch.object(design, "_design_sequences_gpu", side_effect=other_error):
|
| 468 |
+
with pytest.raises(gr.Error, match="larger than the maximum allowed"):
|
| 469 |
+
design.design_sequences(*self._DESIGN_ARGS)
|
| 470 |
+
|
| 471 |
+
def test_non_gradio_errors_propagate(self):
|
| 472 |
+
"""ValueError, RuntimeError etc. are not caught by the retry logic."""
|
| 473 |
+
with patch.object(design, "_design_sequences_gpu", side_effect=ValueError("bad")):
|
| 474 |
+
with pytest.raises(ValueError, match="bad"):
|
| 475 |
+
design.design_sequences(*self._DESIGN_ARGS)
|
tests/test_helpers.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for helper functions."""
|
| 2 |
+
|
| 3 |
+
import types
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
import constraints
|
| 9 |
+
import design
|
| 10 |
+
import file_utils
|
| 11 |
+
import viewers
|
| 12 |
+
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
# _get_file_path
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestGetFilePath:
|
| 19 |
+
"""Normalizes Gradio's various file input formats to a Path."""
|
| 20 |
+
|
| 21 |
+
def test_string_input(self):
|
| 22 |
+
assert file_utils._get_file_path("/some/path.pdb") == Path("/some/path.pdb")
|
| 23 |
+
|
| 24 |
+
def test_object_with_path_attr(self):
|
| 25 |
+
obj = types.SimpleNamespace(path="/uploads/file.pdb")
|
| 26 |
+
assert file_utils._get_file_path(obj) == Path("/uploads/file.pdb")
|
| 27 |
+
|
| 28 |
+
def test_dict_with_path_key(self):
|
| 29 |
+
result = file_utils._get_file_path({"path": "/uploads/file.pdb", "name": "file.pdb"})
|
| 30 |
+
assert result == Path("/uploads/file.pdb")
|
| 31 |
+
|
| 32 |
+
def test_fallback_to_str(self):
|
| 33 |
+
assert file_utils._get_file_path(42) == Path("42")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# _build_pos_constraint_df
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TestBuildPosConstraintDf:
|
| 42 |
+
"""Builds a positional constraint DataFrame for caliby."""
|
| 43 |
+
|
| 44 |
+
def test_all_empty_returns_none(self):
|
| 45 |
+
assert constraints._build_pos_constraint_df("1YCR", "", "", "", "", "") is None
|
| 46 |
+
|
| 47 |
+
def test_all_whitespace_returns_none(self):
|
| 48 |
+
assert constraints._build_pos_constraint_df("1YCR", " ", " ", " ", " ", " ") is None
|
| 49 |
+
|
| 50 |
+
def test_single_field_populated(self):
|
| 51 |
+
df = constraints._build_pos_constraint_df("1YCR", "A1-100", "", "", "", "")
|
| 52 |
+
assert df is not None
|
| 53 |
+
assert len(df) == 1
|
| 54 |
+
assert df.iloc[0]["pdb_key"] == "1YCR"
|
| 55 |
+
assert df.iloc[0]["fixed_pos_seq"] == "A1-100"
|
| 56 |
+
# Only populated columns + pdb_key should be present
|
| 57 |
+
assert "fixed_pos_scn" not in df.columns
|
| 58 |
+
|
| 59 |
+
def test_all_fields_populated(self):
|
| 60 |
+
df = constraints._build_pos_constraint_df("X", "A1", "B2", "A3:G", "A4:V", "A5,B5")
|
| 61 |
+
assert set(df.columns) == {
|
| 62 |
+
"pdb_key",
|
| 63 |
+
"fixed_pos_seq",
|
| 64 |
+
"fixed_pos_scn",
|
| 65 |
+
"fixed_pos_override_seq",
|
| 66 |
+
"pos_restrict_aatype",
|
| 67 |
+
"symmetry_pos",
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
def test_columns_match_caliby_valid_columns(self):
|
| 71 |
+
"""All columns must be in caliby's _VALID_POS_CONSTRAINT_COLUMNS."""
|
| 72 |
+
valid = {
|
| 73 |
+
"pdb_key",
|
| 74 |
+
"fixed_pos_seq",
|
| 75 |
+
"fixed_pos_scn",
|
| 76 |
+
"fixed_pos_override_seq",
|
| 77 |
+
"pos_restrict_aatype",
|
| 78 |
+
"symmetry_pos",
|
| 79 |
+
}
|
| 80 |
+
df = constraints._build_pos_constraint_df("X", "A1", "B2", "A3:G", "A4:V", "A5,B5")
|
| 81 |
+
assert set(df.columns).issubset(valid)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# _df_to_csv
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class TestDfToCsv:
|
| 90 |
+
"""Writes a DataFrame to a temp CSV file."""
|
| 91 |
+
|
| 92 |
+
def test_none_returns_none(self):
|
| 93 |
+
assert file_utils._df_to_csv(None) is None
|
| 94 |
+
|
| 95 |
+
def test_empty_dataframe_returns_none(self):
|
| 96 |
+
assert file_utils._df_to_csv(pd.DataFrame()) is None
|
| 97 |
+
|
| 98 |
+
def test_valid_dataframe_roundtrips(self):
|
| 99 |
+
df = pd.DataFrame({"pdb_key": ["1YCR"], "fixed_pos_seq": ["A1-100"]})
|
| 100 |
+
path = file_utils._df_to_csv(df)
|
| 101 |
+
assert path is not None
|
| 102 |
+
assert Path(path).exists()
|
| 103 |
+
assert path.endswith(".csv")
|
| 104 |
+
loaded = pd.read_csv(path)
|
| 105 |
+
pd.testing.assert_frame_equal(df, loaded)
|
| 106 |
+
|
| 107 |
+
def test_uses_sample_name_for_csv_basename(self):
|
| 108 |
+
df = pd.DataFrame(
|
| 109 |
+
{
|
| 110 |
+
"Sample": ["1YCR_sample0"],
|
| 111 |
+
"Sequence": ["ACDE"],
|
| 112 |
+
"Energy (U)": [-1.0],
|
| 113 |
+
}
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
path = file_utils._df_to_csv(df)
|
| 117 |
+
|
| 118 |
+
assert path is not None
|
| 119 |
+
assert Path(path).name == "1YCR_results.csv"
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class TestCsvDownloadOutput:
|
| 123 |
+
"""Formats CSV downloads for the Gradio file component."""
|
| 124 |
+
|
| 125 |
+
def test_hides_component_for_empty_dataframe(self):
|
| 126 |
+
update = viewers._csv_download_output(pd.DataFrame())
|
| 127 |
+
|
| 128 |
+
assert update["visible"] is False
|
| 129 |
+
assert update["value"] is None
|
| 130 |
+
|
| 131 |
+
def test_shows_named_csv_for_results_dataframe(self):
|
| 132 |
+
df = pd.DataFrame(
|
| 133 |
+
{
|
| 134 |
+
"Sample": ["1YCR_sample0"],
|
| 135 |
+
"Sequence": ["ACDE"],
|
| 136 |
+
"Energy (U)": [-1.0],
|
| 137 |
+
}
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
update = viewers._csv_download_output(df)
|
| 141 |
+
|
| 142 |
+
assert update["visible"] is True
|
| 143 |
+
assert Path(update["value"]).name == "1YCR_results.csv"
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class TestFormatResultsDisplay:
|
| 147 |
+
"""Formats the on-screen results table without changing the raw dataframe."""
|
| 148 |
+
|
| 149 |
+
def test_formats_last_four_numeric_columns(self):
|
| 150 |
+
df = pd.DataFrame(
|
| 151 |
+
{
|
| 152 |
+
"Sample": ["1YCR_sample0"],
|
| 153 |
+
"Sequence": ["ACDE"],
|
| 154 |
+
"Energy (U)": [-1.2345],
|
| 155 |
+
"sc_ca_rmsd": [1.0],
|
| 156 |
+
"avg_ca_plddt": [88.888],
|
| 157 |
+
"tmalign_score": [0.12345],
|
| 158 |
+
}
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
styler = viewers._format_results_display(df)
|
| 162 |
+
html = styler.to_html()
|
| 163 |
+
|
| 164 |
+
assert "-1.23" in html
|
| 165 |
+
assert ">1<" in html
|
| 166 |
+
assert "88.89" in html
|
| 167 |
+
assert "0.12" in html
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ---------------------------------------------------------------------------
|
| 171 |
+
# _format_outputs
|
| 172 |
+
# ---------------------------------------------------------------------------
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class TestFormatOutputs:
|
| 176 |
+
"""Formats caliby output dict into (DataFrame, FASTA, out_pdb_list)."""
|
| 177 |
+
|
| 178 |
+
def test_dataframe_structure(self, sample_outputs_with_out_pdbs):
|
| 179 |
+
df, _, _ = design._format_outputs(sample_outputs_with_out_pdbs)
|
| 180 |
+
assert list(df.columns) == ["Sample", "Sequence", "Energy (U)"]
|
| 181 |
+
assert len(df) == 2
|
| 182 |
+
|
| 183 |
+
def test_sample_names_from_path_stems(self, sample_outputs_with_out_pdbs):
|
| 184 |
+
df, _, _ = design._format_outputs(sample_outputs_with_out_pdbs)
|
| 185 |
+
assert list(df["Sample"]) == ["1YCR_sample0", "1YCR_sample1"]
|
| 186 |
+
|
| 187 |
+
def test_fasta_format(self, sample_outputs_with_out_pdbs):
|
| 188 |
+
_, fasta, _ = design._format_outputs(sample_outputs_with_out_pdbs)
|
| 189 |
+
lines = fasta.strip().split("\n")
|
| 190 |
+
assert lines[0] == ">1YCR_sample0"
|
| 191 |
+
assert lines[1] == "MTEEQWAQ"
|
| 192 |
+
assert lines[2] == ">1YCR_sample1"
|
| 193 |
+
assert lines[3] == "VSEQQWAQ"
|
| 194 |
+
|
| 195 |
+
def test_uses_caliby_out_pdb_key(self, sample_outputs):
|
| 196 |
+
assert "out_pdbs" not in sample_outputs
|
| 197 |
+
df, fasta, out_pdb_list = design._format_outputs(sample_outputs)
|
| 198 |
+
|
| 199 |
+
assert list(df["Sample"]) == ["1YCR_sample0", "1YCR_sample1"]
|
| 200 |
+
assert ">1YCR_sample0" in fasta
|
| 201 |
+
assert out_pdb_list == sample_outputs["out_pdb"]
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# ---------------------------------------------------------------------------
|
| 205 |
+
# _get_best_sc_sample
|
| 206 |
+
# ---------------------------------------------------------------------------
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class TestGetBestScSample:
|
| 210 |
+
"""Picks the sample with the highest tmalign_score."""
|
| 211 |
+
|
| 212 |
+
def test_picks_highest_tmalign_score(self):
|
| 213 |
+
df = pd.DataFrame(
|
| 214 |
+
{
|
| 215 |
+
"Sample": ["1YCR_sample0", "1YCR_sample1", "1YCR_sample2"],
|
| 216 |
+
"tmalign_score": [0.5, 0.9, 0.7],
|
| 217 |
+
}
|
| 218 |
+
)
|
| 219 |
+
assert viewers._get_best_sc_sample(df) == "1YCR_sample1"
|
| 220 |
+
|
| 221 |
+
def test_falls_back_to_first_when_no_tmalign(self):
|
| 222 |
+
df = pd.DataFrame({"Sample": ["1YCR_sample0", "1YCR_sample1"]})
|
| 223 |
+
assert viewers._get_best_sc_sample(df) == "1YCR_sample0"
|
| 224 |
+
|
| 225 |
+
def test_falls_back_to_first_when_all_nan(self):
|
| 226 |
+
df = pd.DataFrame(
|
| 227 |
+
{
|
| 228 |
+
"Sample": ["A_sample0", "A_sample1"],
|
| 229 |
+
"tmalign_score": [float("nan"), float("nan")],
|
| 230 |
+
}
|
| 231 |
+
)
|
| 232 |
+
assert viewers._get_best_sc_sample(df) == "A_sample0"
|
| 233 |
+
|
| 234 |
+
def test_returns_none_for_empty_df(self):
|
| 235 |
+
assert viewers._get_best_sc_sample(pd.DataFrame()) is None
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# ---------------------------------------------------------------------------
|
| 239 |
+
# _render_af2_viewer / _render_reference_viewer
|
| 240 |
+
# ---------------------------------------------------------------------------
|
| 241 |
+
|
| 242 |
+
_MINIMAL_PDB = "ATOM 1 CA ALA A 1 0.000 0.000 0.000 1.00 90.00 C\nEND\n"
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class TestRenderAf2Viewer:
|
| 246 |
+
"""Renders AF2 prediction with pLDDT coloring via molview."""
|
| 247 |
+
|
| 248 |
+
def test_returns_html_with_valid_data(self):
|
| 249 |
+
html = viewers._render_af2_viewer("test_sample0", {"test_sample0": _MINIMAL_PDB})
|
| 250 |
+
assert "iframe" in html
|
| 251 |
+
|
| 252 |
+
def test_returns_empty_for_missing_sample(self):
|
| 253 |
+
assert viewers._render_af2_viewer("missing", {"other": _MINIMAL_PDB}) == ""
|
| 254 |
+
|
| 255 |
+
def test_returns_empty_for_none_sample(self):
|
| 256 |
+
assert viewers._render_af2_viewer(None, {"test": _MINIMAL_PDB}) == ""
|
| 257 |
+
|
| 258 |
+
def test_returns_empty_for_empty_data(self):
|
| 259 |
+
assert viewers._render_af2_viewer("test", {}) == ""
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class TestRenderReferenceViewer:
|
| 263 |
+
"""Renders original input PDB with chain coloring via molview."""
|
| 264 |
+
|
| 265 |
+
def test_maps_sample_to_input_key(self):
|
| 266 |
+
html = viewers._render_reference_viewer("1YCR_sample0", {"1YCR": _MINIMAL_PDB})
|
| 267 |
+
assert "iframe" in html
|
| 268 |
+
|
| 269 |
+
def test_returns_empty_when_input_key_missing(self):
|
| 270 |
+
assert viewers._render_reference_viewer("1YCR_sample0", {"OTHER": _MINIMAL_PDB}) == ""
|
| 271 |
+
|
| 272 |
+
def test_returns_empty_for_none_sample(self):
|
| 273 |
+
assert viewers._render_reference_viewer(None, {"1YCR": _MINIMAL_PDB}) == ""
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# ---------------------------------------------------------------------------
|
| 277 |
+
# _update_viewers
|
| 278 |
+
# ---------------------------------------------------------------------------
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class TestUpdateViewers:
|
| 282 |
+
"""Combined handler for overlay toggle."""
|
| 283 |
+
|
| 284 |
+
def test_overlay_off_hides_reference(self):
|
| 285 |
+
af2_html, ref_update = viewers._update_viewers("s0", {"s0": _MINIMAL_PDB}, {"s": _MINIMAL_PDB}, False)
|
| 286 |
+
assert "iframe" in af2_html
|
| 287 |
+
assert ref_update["visible"] is False
|
| 288 |
+
|
| 289 |
+
def test_overlay_on_shows_reference(self):
|
| 290 |
+
af2_html, ref_update = viewers._update_viewers(
|
| 291 |
+
"s_sample0", {"s_sample0": _MINIMAL_PDB}, {"s": _MINIMAL_PDB}, True
|
| 292 |
+
)
|
| 293 |
+
assert "iframe" in af2_html
|
| 294 |
+
assert ref_update["visible"] is True
|
| 295 |
+
assert "iframe" in ref_update["value"]
|
viewers.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Structure viewers and display formatting for the Gradio UI."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from file_utils import _df_to_csv
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _format_display_number(value) -> str:
|
| 12 |
+
if pd.isna(value):
|
| 13 |
+
return ""
|
| 14 |
+
return f"{float(value):.2f}".rstrip("0").rstrip(".")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _file_output(value: str | None) -> dict:
|
| 18 |
+
return gr.update(value=value, visible=bool(value))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _csv_download_output(df: pd.DataFrame | None) -> dict:
|
| 22 |
+
return _file_output(_df_to_csv(df))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _format_results_display(df: pd.DataFrame):
|
| 26 |
+
numeric_columns = [col for col in df.columns[-4:] if pd.api.types.is_numeric_dtype(df[col])]
|
| 27 |
+
if not numeric_columns:
|
| 28 |
+
return df
|
| 29 |
+
return df.style.format({col: _format_display_number for col in numeric_columns})
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _get_best_sc_sample(df: pd.DataFrame) -> str | None:
|
| 33 |
+
if df.empty or "Sample" not in df.columns:
|
| 34 |
+
return None
|
| 35 |
+
if "tmalign_score" in df.columns and df["tmalign_score"].notna().any():
|
| 36 |
+
return str(df.loc[df["tmalign_score"].idxmax(), "Sample"])
|
| 37 |
+
return str(df.iloc[0]["Sample"])
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _render_af2_viewer(sample_name: str | None, af2_pdb_data: dict[str, str], color_mode: str = "plddt") -> str:
|
| 41 |
+
if not sample_name or not af2_pdb_data or sample_name not in af2_pdb_data:
|
| 42 |
+
return ""
|
| 43 |
+
import molview as mv
|
| 44 |
+
|
| 45 |
+
v = mv.view(width=840, height=500)
|
| 46 |
+
v.addModel(af2_pdb_data[sample_name], name=f"AF2: {sample_name}")
|
| 47 |
+
v.setColorMode(color_mode)
|
| 48 |
+
v.setBackgroundColor("#000000")
|
| 49 |
+
return v._repr_html_()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _render_reference_viewer(sample_name: str | None, input_pdb_data: dict[str, str], color_mode: str = "chain") -> str:
|
| 53 |
+
if not sample_name or not input_pdb_data:
|
| 54 |
+
return ""
|
| 55 |
+
input_key = re.sub(r"_sample\d+$", "", sample_name)
|
| 56 |
+
if input_key not in input_pdb_data:
|
| 57 |
+
return ""
|
| 58 |
+
import molview as mv
|
| 59 |
+
|
| 60 |
+
v = mv.view(width=840, height=500)
|
| 61 |
+
v.addModel(input_pdb_data[input_key], name=f"Reference: {input_key}")
|
| 62 |
+
v.setColorMode(color_mode)
|
| 63 |
+
v.setBackgroundColor("#000000")
|
| 64 |
+
return v._repr_html_()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _update_viewers(
|
| 68 |
+
best_sample: str,
|
| 69 |
+
af2_pdb_data: dict[str, str],
|
| 70 |
+
input_pdb_data: dict[str, str],
|
| 71 |
+
show_overlay: bool,
|
| 72 |
+
color_mode: str = "plddt",
|
| 73 |
+
ref_color_mode: str = "chain",
|
| 74 |
+
):
|
| 75 |
+
af2_html = _render_af2_viewer(best_sample, af2_pdb_data, color_mode)
|
| 76 |
+
if show_overlay:
|
| 77 |
+
ref_html = _render_reference_viewer(best_sample, input_pdb_data, ref_color_mode)
|
| 78 |
+
return af2_html, gr.update(value=ref_html, visible=True), gr.update(visible=True)
|
| 79 |
+
return af2_html, gr.update(value="", visible=False), gr.update(visible=False)
|