Justine Yuan committed on
Commit
3beba17
·
1 Parent(s): e732716

Caliby HuggingFace example

Browse files
.gitignore ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ # Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ # poetry.lock
109
+ # poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ # pdm.lock
116
+ # pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ # pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # Redis
135
+ *.rdb
136
+ *.aof
137
+ *.pid
138
+
139
+ # RabbitMQ
140
+ mnesia/
141
+ rabbitmq/
142
+ rabbitmq-data/
143
+
144
+ # ActiveMQ
145
+ activemq-data/
146
+
147
+ # SageMath parsed files
148
+ *.sage.py
149
+
150
+ # Environments
151
+ .env
152
+ .envrc
153
+ .venv
154
+ env/
155
+ venv/
156
+ ENV/
157
+ env.bak/
158
+ venv.bak/
159
+
160
+ # Spyder project settings
161
+ .spyderproject
162
+ .spyproject
163
+
164
+ # Rope project settings
165
+ .ropeproject
166
+
167
+ # mkdocs documentation
168
+ /site
169
+
170
+ # mypy
171
+ .mypy_cache/
172
+ .dmypy.json
173
+ dmypy.json
174
+
175
+ # Pyre type checker
176
+ .pyre/
177
+
178
+ # pytype static type analyzer
179
+ .pytype/
180
+
181
+ # Cython debug symbols
182
+ cython_debug/
183
+
184
+ # PyCharm
185
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
186
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
187
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
188
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
189
+ # .idea/
190
+
191
+ # Abstra
192
+ # Abstra is an AI-powered process automation framework.
193
+ # Ignore directories containing user credentials, local state, and settings.
194
+ # Learn more at https://abstra.io/docs
195
+ .abstra/
196
+
197
+ # Visual Studio Code
198
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
199
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
200
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
201
+ # you could uncomment the following to ignore the entire vscode folder
202
+ # .vscode/
203
+
204
+ # Ruff stuff:
205
+ .ruff_cache/
206
+
207
+ # PyPI configuration file
208
+ .pypirc
209
+
210
+ # Marimo
211
+ marimo/_static/
212
+ marimo/_lsp/
213
+ __marimo__/
214
+
215
+ # Streamlit
216
+ .streamlit/secrets.toml
217
+
218
+ envs
219
+
220
+ CLAUDE.md
README.md CHANGED
@@ -4,7 +4,8 @@ emoji: 🐢
4
  colorFrom: yellow
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.49.1
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
4
  colorFrom: yellow
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: "6.6.0"
8
+ python_version: "3.12"
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio app for Caliby sequence design."""
2
+
3
+ import base64
4
+ from pathlib import Path
5
+
6
+ import gradio as gr
7
+
8
+ # Eagerly import so the wandb/pydantic init runs in the main thread
9
+ # (where sys.modules['__main__'] exists), not in a Gradio worker thread.
10
+ import caliby.data.preprocessing.atomworks.clean_pdbs # noqa: F401
11
+
12
+ from design import design_sequences
13
+ from file_utils import _get_file_path, _write_zip_from_paths
14
+ from viewers import (
15
+ _csv_download_output,
16
+ _file_output,
17
+ _format_results_display,
18
+ _get_best_sc_sample,
19
+ _render_af2_viewer,
20
+ _update_viewers,
21
+ )
22
+
23
+
24
+ def _get_upload_instructions(mode: str) -> str:
25
+ if mode == "none":
26
+ return "Upload a single PDB or CIF file."
27
+ elif mode == "synthetic":
28
+ return "Upload a single PDB or CIF file. Conformers will be generated automatically."
29
+ else:
30
+ return "Upload all PDB files — primary conformer first, then additional conformers."
31
+
32
+
33
def _clean_uploaded_pdbs(pdb_files: list | None):
    """Clean the uploaded structure files and bundle them for download.

    Returns a 4-tuple: the cleaned file paths (or ``None``), an update for the
    notification text, an update for the download widget, and an update that
    toggles interactivity of the last output component.
    """
    if not pdb_files:
        # Nothing uploaded: clear the state, hide both widgets, disable the last output.
        return None, gr.update(visible=False), gr.update(visible=False), gr.update(interactive=False)

    # Imported lazily, only once there is something to clean.
    from caliby import clean_pdbs

    uploaded_paths = [str(_get_file_path(item)) for item in pdb_files]
    cleaned_paths = clean_pdbs(uploaded_paths)
    archive_path = _write_zip_from_paths(cleaned_paths, "cleaned_pdbs", ".zip")

    note = (
        "**Note:** Your files have been cleaned and standardized to mmCIF format "
        "to avoid downstream parsing and alignment issues. "
        "If you plan to use positional constraints, please download the cleaned files and double "
        "check the new residue indices."
    )
    notification_update = gr.update(value=note, visible=True)
    download_update = gr.update(value=archive_path, visible=True)
    enable_update = gr.update(interactive=True)
    return cleaned_paths, notification_update, download_update, enable_update
55
+
56
+
57
def _reset_cleaned_state():
    """Return the "nothing cleaned yet" values for the four cleaning outputs.

    The tuple is ``None`` for the stored paths, two updates that hide their
    components, and one update that makes its component non-interactive.
    """
    cleared_paths = None
    hide_notification = gr.update(visible=False)
    hide_download = gr.update(visible=False)
    disable_action = gr.update(interactive=False)
    return cleared_paths, hide_notification, hide_download, disable_action
59
+
60
+
61
def submit_design_sequences(
    cleaned_files: list[str] | None,
    ensemble_mode: str,
    model_variant: str,
    num_seqs: int,
    omit_aas: list[str] | None,
    temperature: float,
    fixed_pos_seq: str,
    fixed_pos_scn: str,
    fixed_pos_override_seq: str,
    pos_restrict_aatype: str,
    symmetry_pos: str,
    num_protpardelle_conformers: int,
    run_af2_eval: bool = False,
):
    """Run ``design_sequences`` and translate its result into UI output values.

    All arguments are forwarded unchanged (``cleaned_files`` is passed as
    ``pdb_files``). The return value is a 14-element tuple of component
    updates and state values; the AF2 viewer entries are only populated when
    the pipeline returned AF2 structure data.
    """
    pipeline_result = design_sequences(
        pdb_files=cleaned_files,
        ensemble_mode=ensemble_mode,
        model_variant=model_variant,
        num_seqs=num_seqs,
        omit_aas=omit_aas,
        temperature=temperature,
        fixed_pos_seq=fixed_pos_seq,
        fixed_pos_scn=fixed_pos_scn,
        fixed_pos_override_seq=fixed_pos_override_seq,
        pos_restrict_aatype=pos_restrict_aatype,
        symmetry_pos=symmetry_pos,
        num_protpardelle_conformers=num_protpardelle_conformers,
        run_af2_eval=run_af2_eval,
    )
    df, fasta_text, out_zip_path, sc_zip_path, af2_pdb_data, input_pdb_data = pipeline_result

    # Only pick a best sample and render the AF2 viewer when AF2 data came back.
    if af2_pdb_data:
        has_af2 = True
        best_sample = _get_best_sc_sample(df)
        af2_html = _render_af2_viewer(best_sample, af2_pdb_data)
    else:
        has_af2 = False
        best_sample = ""
        af2_html = ""

    return (
        gr.update(visible=True),  # results header
        gr.update(value=_format_results_display(df), visible=True),  # results table
        df,  # raw results
        gr.update(value=fasta_text, visible=True),  # FASTA text
        _file_output(out_zip_path),  # designed structures archive
        _file_output(sc_zip_path),  # self-consistency archive
        af2_pdb_data,
        input_pdb_data,
        best_sample,
        gr.update(visible=has_af2),  # viewer section shown only with AF2 data
        af2_html,
        gr.update(value="", visible=False),  # reference viewer cleared and hidden
        gr.update(visible=False),  # reference section hidden
        gr.update(visible=False),  # placeholder hidden
    )
111
+
112
+
113
+ theme = gr.themes.Base(
114
+ primary_hue="amber",
115
+ secondary_hue="orange",
116
+ radius_size="lg",
117
+ font=[gr.themes.GoogleFont('Instrument Sans'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
118
+ ).set(
119
+ body_text_color='*neutral_700',
120
+ body_text_color_dark='*neutral_300',
121
+ body_text_color_subdued='*neutral_500',
122
+ block_title_text_color='*neutral_700',
123
+ block_info_text_color='*neutral_500',
124
+ block_border_width_dark='0px',
125
+ block_padding='*spacing_xl calc(*spacing_xl + 3px)',
126
+ block_label_border_width_dark='0px',
127
+ block_label_padding='*spacing_md *spacing_lg',
128
+ button_secondary_background_fill_dark='*neutral_600',
129
+ checkbox_label_text_color_dark='*neutral_100',
130
+ )
131
+
132
+ css = """
133
+ .loading-pulse { animation: pulse 2.5s ease-in-out infinite; }
134
+ @keyframes pulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.3; } }
135
+ .omit-aa-dropdown ul { max-height: 200px !important; overflow-y: auto; }
136
+ .compact-file .large { min-height: 50px !important; }
137
+ #results-table th:nth-child(2),
138
+ #results-table td:nth-child(2) {
139
+ max-width: 28rem;
140
+ width: 28rem;
141
+ }
142
+ #results-table td:nth-child(2) {
143
+ overflow: hidden;
144
+ }
145
+ #results-table td:nth-child(2) > div {
146
+ display: block;
147
+ max-width: 100%;
148
+ overflow-x: auto;
149
+ overflow-y: hidden;
150
+ white-space: nowrap !important;
151
+ scrollbar-width: thin;
152
+ }
153
+ #af2-viewer, #ref-viewer {
154
+ display: flex;
155
+ justify-content: center;
156
+ }
157
+ #af2-viewer iframe, #ref-viewer iframe {
158
+ max-width: 100%;
159
+ }
160
+ """
161
+
162
+ _LOGO_B64 = base64.b64encode(Path(__file__).with_name("caliby_transparent.png").read_bytes()).decode()
163
+
164
+ with gr.Blocks(title="Caliby - Protein Sequence Design") as demo:
165
+ gr.HTML(
166
+ '<div style="display: flex; align-items: center; gap: 16px;">'
167
+ f'<img src="data:image/png;base64,{_LOGO_B64}" alt="Caliby logo" style="height: 80px;">'
168
+ '<h1 style="margin: 0;">Caliby - Protein Sequence Design</h1>'
169
+ '</div>'
170
+ )
171
+
172
+ with gr.Row():
173
+ with gr.Column(scale=1):
174
+ model_variant = gr.Radio(
175
+ choices=[
176
+ ("Caliby", "caliby"),
177
+ ("SolubleCaliby v1", "soluble_caliby_v1"),
178
+ ],
179
+ value="caliby",
180
+ label="Model",
181
+ )
182
+ ensemble_mode = gr.Radio(
183
+ choices=[
184
+ ("Fixed backbone", "none"),
185
+ ("Synthetic ensemble", "synthetic"),
186
+ ("Upload your own ensemble", "user"),
187
+ ],
188
+ value="synthetic",
189
+ label="Ensemble mode",
190
+ )
191
+ run_af2_eval = gr.Checkbox(
192
+ label="Run AF2 self-consistency evaluation",
193
+ value=False,
194
+ info="Refold designed sequences with AlphaFold2 and compute scRMSD, pLDDT, and TM-score",
195
+ )
196
+ upload_instructions = gr.Markdown(
197
+ _get_upload_instructions("synthetic"),
198
+ )
199
+ pdb_input = gr.File(
200
+ file_count="multiple",
201
+ label="PDB/CIF file(s)",
202
+ file_types=[".pdb", ".cif"],
203
+ )
204
+ finish_upload_btn = gr.Button("Upload", variant="secondary")
205
+ cleaned_files_state = gr.State(None)
206
+ clean_notification = gr.Markdown(visible=False)
207
+ clean_download = gr.File(
208
+ label="Download cleaned files", visible=False, elem_classes=["compact-file"]
209
+ )
210
+ num_seqs = gr.Slider(
211
+ minimum=1,
212
+ maximum=4,
213
+ value=1,
214
+ step=1,
215
+ label="Number of sequences",
216
+ )
217
+ omit_aas = gr.Dropdown(
218
+ choices=[
219
+ "A",
220
+ "C",
221
+ "D",
222
+ "E",
223
+ "F",
224
+ "G",
225
+ "H",
226
+ "I",
227
+ "K",
228
+ "L",
229
+ "M",
230
+ "N",
231
+ "P",
232
+ "Q",
233
+ "R",
234
+ "S",
235
+ "T",
236
+ "V",
237
+ "W",
238
+ "Y",
239
+ ],
240
+ multiselect=True,
241
+ label="Amino acids to omit",
242
+ elem_classes=["omit-aa-dropdown"],
243
+ )
244
+ temperature = gr.Slider(
245
+ minimum=0.01,
246
+ maximum=1,
247
+ value=0.01,
248
+ step=0.01,
249
+ label="Sampling temperature",
250
+ )
251
+
252
+ submit_btn = gr.Button("Design sequences", variant="primary", interactive=False)
253
+
254
+ with gr.Accordion("Advanced constraints", open=False):
255
+ fixed_pos_seq = gr.Textbox(
256
+ label="Fixed positions",
257
+ info="Format: A1-100,B1-100 \nSequence positions in the input PDB to condition on so that they"
258
+ " remain fixed during design. For ensemble-conditioned design, fixed_pos_seq is applied using"
259
+ " the primary conformer's sequence.",
260
+ placeholder="e.g. A1-100,B1-100",
261
+ )
262
+ fixed_pos_scn = gr.Textbox(
263
+ label="Fixed sidechain positions",
264
+ info="Format: A1-10,A12,A15-20 \nSidechain positions in the input PDB to condition on so that they"
265
+ " remain fixed during design. Note that fixed sidechain positions must be a subset of fixed"
266
+ " sequence positions, since it does not make sense to condition on a sidechain without also"
267
+ " conditioning on its sequence identity.",
268
+ placeholder="e.g. A1-10,A12,A15-20",
269
+ )
270
+ fixed_pos_override_seq = gr.Textbox(
271
+ label="Override sequence at positions",
272
+ info="Format: A26:A,A27:L \nSequence positions in the input PDB to first override the sequence at,"
273
+ " and then condition on. The colon separates the position and the desired amino acid.",
274
+ placeholder="e.g. A26:A,A27:L",
275
+ )
276
+ pos_restrict_aatype = gr.Textbox(
277
+ label="Position restrictions",
278
+ info="Format: A26:AVG,A27:VG \nAllowed amino acids for certain positions in the input PDB. The"
279
+ " colon separates the position and the allowed amino acids.",
280
+ placeholder="e.g. A26:AVG,A27:VG",
281
+ )
282
+ symmetry_pos = gr.Textbox(
283
+ label="Symmetry positions",
284
+ info="Format: A10,B10,C10|A11,B11,C11 \nSymmetry positions for tying sampling across residue"
285
+ " positions. The pipe separates groups of positions to sample symmetrically. In the example,"
286
+ " A10, B10, and C10 are tied together, and A11, B11, and C11 are tied together.",
287
+ placeholder="e.g. A10,B10,C10|A11,B11,C11",
288
+ )
289
+ num_protpardelle_conformers = gr.Slider(
290
+ minimum=1,
291
+ maximum=15,
292
+ value=15,
293
+ step=1,
294
+ label="Number of conformers to generate",
295
+ visible=True,
296
+ )
297
+
298
+ with gr.Column(scale=2):
299
+ raw_results_df = gr.State(None)
300
+ af2_pdb_state = gr.State({})
301
+ input_pdb_state = gr.State({})
302
+ best_sample_state = gr.State("")
303
+
304
+ results_placeholder = gr.Markdown(
305
+ "Results will appear here after designing sequences.",
306
+ )
307
+ results_header = gr.Markdown("### Results", visible=False)
308
+ results_df = gr.Dataframe(
309
+ show_label=False,
310
+ interactive=False,
311
+ wrap=False,
312
+ column_widths=[160, 448],
313
+ elem_id="results-table",
314
+ visible=False,
315
+ )
316
+ fasta_output = gr.Textbox(
317
+ label="Sequences (FASTA)",
318
+ lines=10,
319
+ visible=False,
320
+ )
321
+ with gr.Row():
322
+ csv_download = gr.File(label="Download results CSV", elem_classes=["compact-file"], visible=False)
323
+ output_files = gr.File(label="Download CIF files", elem_classes=["compact-file"], visible=False)
324
+ sc_output_files = gr.File(
325
+ label="Download AF2 self-consistency outputs",
326
+ elem_classes=["compact-file"],
327
+ visible=False,
328
+ )
329
+
330
+ with gr.Column(visible=False) as viewer_section:
331
+ gr.Markdown("---")
332
+ with gr.Row():
333
+ gr.Markdown("### AF2 Prediction")
334
+ af2_color_mode = gr.Dropdown(
335
+ choices=[
336
+ ("pLDDT", "plddt"),
337
+ ("Chain", "chain"),
338
+ ("Rainbow", "rainbow"),
339
+ ("Secondary structure", "secondary"),
340
+ ],
341
+ value="plddt",
342
+ label="Color by",
343
+ scale=0,
344
+ )
345
+ af2_viewer = gr.HTML(elem_id="af2-viewer")
346
+ show_overlay = gr.Checkbox(label="Show reference structure", value=False)
347
+ with gr.Column(visible=False) as ref_section:
348
+ with gr.Row():
349
+ gr.Markdown("### Reference Structure")
350
+ ref_color_mode = gr.Dropdown(
351
+ choices=[
352
+ ("Chain", "chain"),
353
+ ("pLDDT", "plddt"),
354
+ ("Rainbow", "rainbow"),
355
+ ("Secondary structure", "secondary"),
356
+ ],
357
+ value="chain",
358
+ label="Color by",
359
+ scale=0,
360
+ )
361
+ reference_viewer = gr.HTML(elem_id="ref-viewer")
362
+
363
+ submit_btn.click(
364
+ fn=lambda: gr.update(value='<div class="loading-pulse">Running design pipeline\u2026</div>', visible=True),
365
+ outputs=[results_placeholder],
366
+ ).then(
367
+ fn=submit_design_sequences,
368
+ inputs=[
369
+ cleaned_files_state,
370
+ ensemble_mode,
371
+ model_variant,
372
+ num_seqs,
373
+ omit_aas,
374
+ temperature,
375
+ fixed_pos_seq,
376
+ fixed_pos_scn,
377
+ fixed_pos_override_seq,
378
+ pos_restrict_aatype,
379
+ symmetry_pos,
380
+ num_protpardelle_conformers,
381
+ run_af2_eval,
382
+ ],
383
+ outputs=[
384
+ results_header,
385
+ results_df,
386
+ raw_results_df,
387
+ fasta_output,
388
+ output_files,
389
+ sc_output_files,
390
+ af2_pdb_state,
391
+ input_pdb_state,
392
+ best_sample_state,
393
+ viewer_section,
394
+ af2_viewer,
395
+ reference_viewer,
396
+ ref_section,
397
+ results_placeholder,
398
+ ],
399
+ )
400
+
401
+ raw_results_df.change(fn=_csv_download_output, inputs=[raw_results_df], outputs=[csv_download])
402
+
403
+ finish_upload_btn.click(
404
+ fn=lambda: gr.update(value="Processing\u2026", interactive=False),
405
+ outputs=[finish_upload_btn],
406
+ ).then(
407
+ fn=_clean_uploaded_pdbs,
408
+ inputs=[pdb_input],
409
+ outputs=[cleaned_files_state, clean_notification, clean_download, submit_btn],
410
+ ).then(
411
+ fn=lambda: gr.update(value="Upload", interactive=True),
412
+ outputs=[finish_upload_btn],
413
+ )
414
+
415
+ pdb_input.change(
416
+ fn=_reset_cleaned_state,
417
+ outputs=[cleaned_files_state, clean_notification, clean_download, submit_btn],
418
+ )
419
+
420
+ ensemble_mode.change(
421
+ fn=lambda mode: (gr.update(visible=(mode == "synthetic")), _get_upload_instructions(mode)),
422
+ inputs=[ensemble_mode],
423
+ outputs=[num_protpardelle_conformers, upload_instructions],
424
+ )
425
+
426
+ viewer_inputs = [best_sample_state, af2_pdb_state, input_pdb_state, show_overlay, af2_color_mode, ref_color_mode]
427
+ viewer_outputs = [af2_viewer, reference_viewer, ref_section]
428
+
429
+ show_overlay.change(fn=_update_viewers, inputs=viewer_inputs, outputs=viewer_outputs)
430
+ af2_color_mode.change(fn=_update_viewers, inputs=viewer_inputs, outputs=viewer_outputs)
431
+ ref_color_mode.change(fn=_update_viewers, inputs=viewer_inputs, outputs=viewer_outputs)
432
+
433
+ if __name__ == "__main__":
434
+ demo.launch(theme=theme, css=css, ssr_mode=False)
app_config.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration constants and module-level side effects (weight downloads)."""
2
+
3
+ import os
4
+
5
+ from huggingface_hub import snapshot_download
6
+
7
+ WEIGHTS_DIR = snapshot_download(
8
+ repo_id="ProteinDesignLab/caliby-weights", repo_type="model", token=os.environ.get("HF_TOKEN")
9
+ )
10
+
11
+ # Set MODEL_PARAMS_DIR so caliby's weight utilities can find/download files.
12
+ os.environ.setdefault("MODEL_PARAMS_DIR", WEIGHTS_DIR)
caliby_transparent.png ADDED
constraints.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Input validation and position constraint building."""
2
+
3
+ import pandas as pd
4
+
5
+
6
+ def _validate_design_inputs(pdb_files: list | None, ensemble_mode: str) -> str | None:
7
+ if not pdb_files:
8
+ return "Upload at least one PDB or CIF file."
9
+
10
+ if ensemble_mode == "user" and len(pdb_files) < 2:
11
+ return "User ensemble mode requires at least two files."
12
+
13
+ single_file_mode_messages = {
14
+ "none": "Single structure mode requires exactly one file.",
15
+ "synthetic": "Protpardelle mode requires exactly one file.",
16
+ }
17
+ message = single_file_mode_messages.get(ensemble_mode)
18
+ if message and len(pdb_files) != 1:
19
+ return message
20
+
21
+ return None
22
+
23
+
24
+ def _build_pos_constraint_df(
25
+ pdb_key: str,
26
+ fixed_pos_seq: str,
27
+ fixed_pos_scn: str,
28
+ fixed_pos_override_seq: str,
29
+ pos_restrict_aatype: str,
30
+ symmetry_pos: str,
31
+ ) -> pd.DataFrame | None:
32
+ row = {}
33
+ if fixed_pos_seq and fixed_pos_seq.strip():
34
+ row["fixed_pos_seq"] = fixed_pos_seq.strip()
35
+ if fixed_pos_scn and fixed_pos_scn.strip():
36
+ row["fixed_pos_scn"] = fixed_pos_scn.strip()
37
+ if fixed_pos_override_seq and fixed_pos_override_seq.strip():
38
+ row["fixed_pos_override_seq"] = fixed_pos_override_seq.strip()
39
+ if pos_restrict_aatype and pos_restrict_aatype.strip():
40
+ row["pos_restrict_aatype"] = pos_restrict_aatype.strip()
41
+ if symmetry_pos and symmetry_pos.strip():
42
+ row["symmetry_pos"] = symmetry_pos.strip()
43
+ if not row:
44
+ return None
45
+ row["pdb_key"] = pdb_key
46
+ return pd.DataFrame([row])
design.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Core design pipeline: context building, execution, output formatting."""
2
+
3
+ import re
4
+ import tempfile
5
+ from pathlib import Path
6
+
7
+ import gradio as gr
8
+ import pandas as pd
9
+ import spaces
10
+ import torch
11
+
12
+ from app_config import WEIGHTS_DIR
13
+ from constraints import _build_pos_constraint_df, _validate_design_inputs
14
+ from ensemble import _generate_protpardelle_ensemble, _setup_user_ensemble_dir
15
+ from file_utils import _copy_uploaded_files, _get_file_path, _sanitize_download_stem, _write_zip_from_paths
16
+ from models import get_model
17
+ from self_consistency import _run_self_consistency
18
+
19
+ # ZeroGPU quota-aware retry: request the max duration first, and if the
20
+ # scheduler returns a quota error (which is free — no GPU time consumed),
21
+ # parse the remaining seconds and retry with that exact amount.
22
+ _MAX_GPU_DURATION = 120 # Per-call max; daily quota is 210s but per-call cap is lower
23
+ _gpu_duration_override: int | None = None
24
+
25
+
26
def _dynamic_gpu_duration(*args, **kwargs) -> int:
    """Return the GPU duration to request for the next ``@spaces.GPU`` call.

    All arguments are accepted and ignored. The module-level override wins
    when it is set (not ``None``); otherwise the per-call maximum is used.
    """
    if _gpu_duration_override is None:
        return _MAX_GPU_DURATION
    return _gpu_duration_override
29
+
30
+
31
+ def _parse_quota_left(error: Exception) -> int | None:
32
+ """Extract remaining GPU seconds from a ZeroGPU quota error message.
33
+
34
+ Returns the number of seconds left, or None if not a recoverable quota error.
35
+ """
36
+ message = getattr(error, 'message', None)
37
+ if not isinstance(message, str):
38
+ return None
39
+ match = re.search(r'(\d+)s left\)', message)
40
+ return int(match.group(1)) if match else None
41
+
42
+
43
def _build_design_context(
    pdb_paths: list[str],
    ensemble_mode: str,
    tmpdir: Path,
    num_protpardelle_conformers: int,
    fixed_pos_seq: str,
    fixed_pos_scn: str,
    fixed_pos_override_seq: str,
    pos_restrict_aatype: str,
    symmetry_pos: str,
) -> tuple[list[str] | dict[str, list[str]], pd.DataFrame | None]:
    """Assemble the design inputs and position constraints for one run.

    Returns ``(design_inputs, pos_constraint_df)``. In ``"none"`` mode the
    inputs are ``pdb_paths`` itself; in ``"synthetic"`` mode they come from
    ``_generate_protpardelle_ensemble``; any other mode uses
    ``_setup_user_ensemble_dir``. In the ensemble modes, a non-empty
    constraint table is passed through ``caliby.make_ensemble_constraints``.
    """
    # Constraints are keyed by the stem of the first (primary) structure.
    primary_key = Path(pdb_paths[0]).stem
    constraints = _build_pos_constraint_df(
        pdb_key=primary_key,
        fixed_pos_seq=fixed_pos_seq,
        fixed_pos_scn=fixed_pos_scn,
        fixed_pos_override_seq=fixed_pos_override_seq,
        pos_restrict_aatype=pos_restrict_aatype,
        symmetry_pos=symmetry_pos,
    )

    if ensemble_mode == "none":
        return pdb_paths, constraints

    if ensemble_mode == "synthetic":
        design_inputs = _generate_protpardelle_ensemble(
            pdb_path=pdb_paths[0],
            num_conformers=num_protpardelle_conformers,
            out_dir=tmpdir,
            weights_dir=WEIGHTS_DIR,
        )
    else:
        design_inputs = _setup_user_ensemble_dir(pdb_paths=pdb_paths)

    if constraints is None:
        return design_inputs, None

    from caliby import make_ensemble_constraints

    # Hand the single constraint row over as {pdb_key: {column: value}}.
    first_row = constraints.iloc[0]
    per_key_constraints = {}
    for column in constraints.columns:
        if column != "pdb_key":
            per_key_constraints[column] = first_row[column]
    constraints = make_ensemble_constraints({primary_key: per_key_constraints}, design_inputs)

    return design_inputs, constraints
84
+
85
+
86
+ def _format_outputs(outputs: dict) -> tuple[pd.DataFrame, str, list[str]]:
87
+ out_pdb_list = outputs["out_pdb"]
88
+ df = pd.DataFrame(
89
+ {
90
+ "Sample": [Path(out_pdb).stem for out_pdb in out_pdb_list],
91
+ "Sequence": outputs["seq"],
92
+ "Energy (U)": outputs["U"],
93
+ }
94
+ )
95
+ fasta_lines = []
96
+ for i, (eid, seq) in enumerate(zip(outputs["example_id"], outputs["seq"])):
97
+ fasta_lines.append(f">{eid}_sample{i}")
98
+ fasta_lines.append(seq)
99
+ fasta_text = "\n".join(fasta_lines)
100
+ return df, fasta_text, out_pdb_list
101
+
102
+
103
@spaces.GPU(duration=_dynamic_gpu_duration)
def _design_sequences_gpu(
    pdb_files: list | None,
    ensemble_mode: str,
    model_variant: str,
    num_seqs: int,
    omit_aas: list[str] | None,
    temperature: float,
    fixed_pos_seq: str,
    fixed_pos_scn: str,
    fixed_pos_override_seq: str,
    pos_restrict_aatype: str,
    symmetry_pos: str,
    num_protpardelle_conformers: int,
    run_af2_eval: bool = False,
):
    """Validate the inputs, run sequence design, and package the results.

    Wrapped by ``spaces.GPU``; the requested duration is supplied by
    ``_dynamic_gpu_duration``.

    Returns a 6-tuple ``(df, fasta_text, out_zip_path, sc_zip_path,
    af2_pdb_data, input_pdb_data)``. If validation fails the tuple is an empty
    DataFrame, the validation message in the ``fasta_text`` position, ``None``
    for both zip paths, and two empty dicts.
    """
    validation_error = _validate_design_inputs(pdb_files, ensemble_mode)
    if validation_error:
        return pd.DataFrame(), validation_error, None, None, {}, {}

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_grad_enabled(False)
    first_upload = _get_file_path(pdb_files[0])
    download_stem = _sanitize_download_stem(first_upload.stem)

    gr.Info("Loading model...")
    model = get_model(model_variant, device)

    # NOTE(review): indentation was lost in this capture; everything below is
    # placed inside the temporary directory's lifetime because the zip step
    # reads files written under it — confirm against the original file.
    with tempfile.TemporaryDirectory() as tmp_name:
        workdir = Path(tmp_name)
        pdb_paths = _copy_uploaded_files(pdb_files, workdir)

        # Keep the text of every input structure, keyed by file stem.
        input_pdb_data = {}
        for copied in pdb_paths:
            copied_path = Path(copied)
            input_pdb_data[copied_path.stem] = copied_path.read_text()

        out_dir = workdir / "outputs"
        out_dir.mkdir(parents=True, exist_ok=True)

        if ensemble_mode == "synthetic":
            gr.Info("Generating conformer ensemble...")
        elif ensemble_mode == "user":
            gr.Info("Preparing uploaded ensemble...")
        design_inputs, pos_constraint_df = _build_design_context(
            pdb_paths=pdb_paths,
            ensemble_mode=ensemble_mode,
            tmpdir=workdir,
            num_protpardelle_conformers=num_protpardelle_conformers,
            fixed_pos_seq=fixed_pos_seq,
            fixed_pos_scn=fixed_pos_scn,
            fixed_pos_override_seq=fixed_pos_override_seq,
            pos_restrict_aatype=pos_restrict_aatype,
            symmetry_pos=symmetry_pos,
        )

        gr.Info("Designing sequences...")
        sample_kwargs = {
            "out_dir": str(out_dir),
            "num_seqs_per_pdb": num_seqs,
            "omit_aas": omit_aas or None,  # an empty selection is passed as None
            "temperature": temperature,
            "num_workers": 0,
            "pos_constraint_df": pos_constraint_df,
        }
        # Single-structure runs use sample(); both ensemble modes use ensemble_sample().
        sampler = model.sample if ensemble_mode == "none" else model.ensemble_sample
        outputs = sampler(design_inputs, **sample_kwargs)

        df, fasta_text, out_pdb_list = _format_outputs(outputs)

        sc_zip_path = None
        af2_pdb_data = {}
        if run_af2_eval:
            gr.Info("Running AF2 self-consistency evaluation...")
            sc_zip_path, af2_pdb_data = _run_self_consistency(model, df, out_pdb_list, out_dir, download_stem)

        out_zip_path = _write_zip_from_paths(out_pdb_list, download_stem, "_designs.zip")
        return df, fasta_text, out_zip_path, sc_zip_path, af2_pdb_data, input_pdb_data
178
+
179
+
180
def design_sequences(
    pdb_files: list | None,
    ensemble_mode: str,
    model_variant: str,
    num_seqs: int,
    omit_aas: list[str] | None,
    temperature: float,
    fixed_pos_seq: str,
    fixed_pos_scn: str,
    fixed_pos_override_seq: str,
    pos_restrict_aatype: str,
    symmetry_pos: str,
    num_protpardelle_conformers: int,
    run_af2_eval: bool = False,
):
    """Run sequence design with ZeroGPU quota-aware retry.

    The first attempt requests the default (maximum) GPU duration. If it
    raises a ``gr.Error`` whose message reports the remaining quota, the call
    is retried once, requesting one second less than that remaining amount.

    Returns:
        Whatever ``_design_sequences_gpu`` returns for the same arguments.

    Raises:
        gr.Error: re-raised unchanged when no remaining quota can be parsed
            from the error, or when 1 second or less remains (the retry would
            otherwise request a zero-second slot).
    """
    global _gpu_duration_override

    # Both attempts use exactly the same arguments; build them once.
    call_kwargs = dict(
        pdb_files=pdb_files,
        ensemble_mode=ensemble_mode,
        model_variant=model_variant,
        num_seqs=num_seqs,
        omit_aas=omit_aas,
        temperature=temperature,
        fixed_pos_seq=fixed_pos_seq,
        fixed_pos_scn=fixed_pos_scn,
        fixed_pos_override_seq=fixed_pos_override_seq,
        pos_restrict_aatype=pos_restrict_aatype,
        symmetry_pos=symmetry_pos,
        num_protpardelle_conformers=num_protpardelle_conformers,
        run_af2_eval=run_af2_eval,
    )

    # NOTE(review): the override is module-level state, so overlapping calls
    # would share it — confirm how requests are queued before relying on it.
    _gpu_duration_override = None
    try:
        return _design_sequences_gpu(**call_kwargs)
    except gr.Error as e:
        remaining = _parse_quota_left(e)
        print(f"[ZeroGPU retry] Caught gr.Error, parsed remaining={remaining}, message={getattr(e, 'message', str(e))}")
        # The retry asks for `remaining - 1` seconds, so more than one second
        # must be left for that request to be non-zero.
        if remaining is None or remaining <= 1:
            raise
        gr.Info(f"GPU quota: {remaining}s remaining, retrying with exact quota")
        _gpu_duration_override = remaining - 1
        try:
            return _design_sequences_gpu(**call_kwargs)
        finally:
            # Always restore the default duration for later calls.
            _gpu_duration_override = None
ensemble.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Protpardelle and user ensemble generation."""
2
+
3
+ from pathlib import Path
4
+
5
+
6
def _generate_protpardelle_ensemble(
    pdb_path: str,
    num_conformers: int,
    out_dir: Path,
    weights_dir: str,
) -> dict[str, list[str]]:
    """Sample conformers for one structure and return the conformer mapping.

    Delegates to ``caliby.generate_ensembles`` with ``pdb_path`` as the only
    input, writing under ``out_dir / "protpardelle_ensemble"``.

    Args:
        pdb_path: Path of the primary structure file.
        num_conformers: Forwarded as ``num_samples_per_pdb``.
        out_dir: Parent directory for the ensemble output folder.
        weights_dir: Forwarded as ``model_params_path``.

    Returns:
        The dict returned by ``generate_ensembles``, with the entry for the
        stem of ``pdb_path`` replaced by ``[pdb_path] + <generated entries>``.
    """
    # Deferred import: caliby is only needed once this function actually runs.
    from caliby import generate_ensembles

    ensemble_dir = out_dir / "protpardelle_ensemble"
    conformers_by_stem = generate_ensembles(
        [pdb_path],
        out_dir=str(ensemble_dir),
        num_samples_per_pdb=num_conformers,
        model_params_path=weights_dir,
    )

    # The generated list does not contain the input itself, so put the primary
    # structure at the front of its own entry.
    primary_key = Path(pdb_path).stem
    generated = conformers_by_stem.get(primary_key, [])
    conformers_by_stem[primary_key] = [pdb_path] + generated
    return conformers_by_stem
25
+
26
+
27
def _setup_user_ensemble_dir(
    pdb_paths: list[str],
    **_ignored,
) -> dict[str, list[str]]:
    """Map the primary upload's file stem to all uploaded paths.

    The first element of ``pdb_paths`` is treated as the primary conformer and
    its stem becomes the single dictionary key; the value is a fresh list
    holding every path in the order given. Extra keyword arguments are accepted
    and discarded.
    """
    primary_stem = Path(pdb_paths[0]).stem
    every_conformer = list(pdb_paths)
    return {primary_stem: every_conformer}
file_utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """File path helpers, ZIP operations, and CSV export."""
2
+
3
+ import re
4
+ import tempfile
5
+ import zipfile
6
+ from pathlib import Path
7
+
8
+ import pandas as pd
9
+
10
+
11
def _get_file_path(f):
    """Convert a file reference into a ``Path``.

    Handled forms, checked in this order: a plain string, an object exposing a
    ``path`` attribute, a dict containing a ``"path"`` key. Anything else is
    converted with ``str()`` first.
    """
    if isinstance(f, str):
        raw = f
    elif hasattr(f, "path"):
        raw = f.path
    elif isinstance(f, dict) and "path" in f:
        raw = f["path"]
    else:
        raw = str(f)
    return Path(raw)
19
+
20
+
21
def _sanitize_download_stem(stem: str) -> str:
    """Return a filename-safe version of ``stem``.

    Each run of characters outside ``A-Za-z0-9._-`` collapses to one
    underscore, then leading and trailing ``.``, ``_`` and ``-`` are stripped.
    If nothing remains, ``"caliby"`` is returned.
    """
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "_", stem)
    trimmed = collapsed.strip("._-")
    if not trimmed:
        return "caliby"
    return trimmed
24
+
25
+
26
def _make_named_download_path(stem: str, suffix: str) -> str:
    """Return a path for a download named ``<sanitized stem><suffix>``.

    A new temporary directory (prefix ``caliby_download_``) is created on every
    call and the returned path points inside it; the file itself is not
    created here.
    """
    # Create the directory first, then compute the name that goes inside it.
    holder = tempfile.mkdtemp(prefix="caliby_download_")
    file_name = f"{_sanitize_download_stem(stem)}{suffix}"
    return str(Path(holder) / file_name)
29
+
30
+
31
def _get_results_stem(df: pd.DataFrame) -> str:
    """Derive a download stem from the first row's ``Sample`` value.

    A trailing ``_sample<digits>`` is removed from the sample name and the
    remainder is passed through ``_sanitize_download_stem``.

    Args:
        df: Results table; may lack a ``Sample`` column or have no rows.

    Returns:
        The sanitized stem, or ``"caliby"`` when there is no ``Sample`` column
        or the frame has no rows.
    """
    # Guard the zero-row case as well: ``df.iloc[0]`` raises IndexError on a
    # frame that has the column but no rows.
    if "Sample" not in df.columns or df.empty:
        return "caliby"
    sample_name = str(df.iloc[0]["Sample"])
    return _sanitize_download_stem(re.sub(r"_sample\d+$", "", sample_name))
36
+
37
+
38
def _copy_uploaded_files(pdb_files: list, tmpdir: Path) -> list[str]:
    """Copy each uploaded file into ``tmpdir`` and return the new paths.

    Args:
        pdb_files: File references in any form ``_get_file_path`` accepts.
        tmpdir: Existing directory that receives the copies.

    Returns:
        One destination path (as ``str``) per input, in input order.

    Uploads that share a basename are kept apart: the second and later
    occurrences are written as ``<stem>_1<suffix>``, ``<stem>_2<suffix>``, ...
    so that an earlier copy is never overwritten and every returned path is
    distinct.
    """
    pdb_paths: list[str] = []
    used_names: set[str] = set()
    for f in pdb_files:
        src = _get_file_path(f)

        # Pick a destination name not already used in this call.
        name = src.name
        counter = 0
        while name in used_names:
            counter += 1
            name = f"{src.stem}_{counter}{src.suffix}"
        used_names.add(name)

        path = tmpdir / name
        path.write_bytes(src.read_bytes())
        pdb_paths.append(str(path))
    return pdb_paths
46
+
47
+
48
def _write_zip_from_paths(paths: list[str], download_stem: str, suffix: str) -> str | None:
    """Bundle the given files into a deflate-compressed ZIP.

    Each file is stored under its basename only. The archive location comes
    from ``_make_named_download_path(download_stem, suffix)``.

    Returns:
        The archive path, or ``None`` when ``paths`` is empty (no archive is
        created in that case).
    """
    if paths:
        archive = _make_named_download_path(download_stem, suffix)
        with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as bundle:
            for member in paths:
                bundle.write(member, Path(member).name)
        return archive
    return None
57
+
58
+
59
def _write_zip_from_dir(directory: Path, download_stem: str, suffix: str) -> str:
    """Archive every regular file under ``directory`` into a ZIP.

    Files are found recursively and stored under their path relative to
    ``directory``. The archive location comes from
    ``_make_named_download_path(download_stem, suffix)``.

    Returns:
        The path of the written archive.
    """
    archive = _make_named_download_path(download_stem, suffix)
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as bundle:
        regular_files = (entry for entry in directory.rglob("*") if entry.is_file())
        for entry in regular_files:
            bundle.write(entry, entry.relative_to(directory))
    return archive
66
+
67
+
68
def _df_to_csv(df: pd.DataFrame | None) -> str | None:
    """Write ``df`` to a CSV download file without the index.

    The file is named ``<results stem>_results.csv`` where the stem comes from
    ``_get_results_stem(df)``.

    Returns:
        The CSV path, or ``None`` when ``df`` is ``None`` or empty.
    """
    has_rows = df is not None and not df.empty
    if not has_rows:
        return None
    csv_path = _make_named_download_path(_get_results_stem(df), "_results.csv")
    df.to_csv(csv_path, index=False)
    return csv_path
models.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model loading and caching."""
2
+
3
+ import sys
4
+ import types
5
+
6
+ from caliby import CalibyModel
7
+
8
+ MODELS: dict[str, CalibyModel] = {}
9
+
10
+
11
def get_model(variant: str, device: str) -> CalibyModel:
    """Return the cached model for ``variant``, loading it on first use.

    The cache (``MODELS``) is keyed by ``variant`` only; ``device`` is used
    when the model is first loaded and is not consulted on later calls.
    """
    try:
        return MODELS[variant]
    except KeyError:
        pass

    # ZeroGPU's @spaces.GPU decorator may remove sys.modules["__main__"].
    # Lightning's load_from_checkpoint calls inspect.stack(), which needs that
    # entry, so install a placeholder module when it is missing.
    if "__main__" not in sys.modules:
        sys.modules["__main__"] = types.ModuleType("__main__")

    # Resolved at call time, after the placeholder above is in place.
    from caliby import load_model

    loaded = load_model(variant, device=device)
    MODELS[variant] = loaded
    return MODELS[variant]
pyproject.toml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "caliby-hf"
3
+ version = "0.1.0"
4
+ dependencies = [
5
+ "caliby[af2] @ git+https://github.com/ProteinDesignLab/caliby@20d6757aaaba1662e71234ba25dde0f64b199683",
6
+ "gradio",
7
+ "huggingface_hub",
8
+ "molview>=0.1.0",
9
+ "omegaconf",
10
+ "pandas",
11
+ "pytest",
12
+ "spaces",
13
+ "torch",
14
+ ]
15
+
16
+ [tool.pytest.ini_options]
17
+ testpaths = ["tests"]
18
+ pythonpath = ["."]
19
+
20
+ [tool.ruff]
21
+ line-length = 120
22
+ exclude = ["chroma"]
23
+
24
+ [tool.ruff.lint]
25
+ select = ["E", "F", "I"]
26
+ ignore = ["E731"]
27
+
28
+ [tool.ruff.format]
29
+ quote-style = "preserve" # avoid churning quotes
30
+ indent-style = "space"
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ caliby[af2] @ git+https://github.com/ProteinDesignLab/caliby@20d6757aaaba1662e71234ba25dde0f64b199683
2
+ gradio
3
+ huggingface_hub
4
+ molview>=0.1.0
5
+ omegaconf
6
+ pandas
7
+ spaces
8
+ torch
self_consistency.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AF2 self-consistency evaluation."""
2
+
3
+ from pathlib import Path
4
+
5
+ import pandas as pd
6
+ from caliby import CalibyModel
7
+
8
+ from file_utils import _write_zip_from_dir
9
+
10
+
11
def _run_self_consistency(
    model: CalibyModel,
    df: pd.DataFrame,
    out_pdb_list: list[str],
    out_dir: Path,
    download_stem: str,
) -> tuple[str, dict[str, str]]:
    """Run self-consistency evaluation and collect its outputs.

    Calls ``model.self_consistency_eval`` on ``out_pdb_list`` with output under
    ``out_dir / "self_consistency"``, then:

    * adds ``sc_ca_rmsd``, ``avg_ca_plddt`` and ``tmalign_score`` columns to
      ``df`` in place (one value per entry of ``out_pdb_list``, looked up by
      file stem, ``nan`` when missing);
    * reads ``struct_preds/af2_<stem>.pdb`` for every stem whose file exists;
    * zips the whole output directory.

    Returns:
        ``(zip_path, af2_pdb_data)`` where ``af2_pdb_data`` maps each stem to
        the text of its predicted-structure file.
    """
    from caliby.eval.eval_utils.folding_utils import clear_mem_torch

    clear_mem_torch()

    sc_out_dir = out_dir / "self_consistency"
    id_to_metrics = model.self_consistency_eval(out_pdb_list, out_dir=str(sc_out_dir))

    # Metrics and prediction files are both keyed by the design's file stem.
    stems = [Path(design_path).stem for design_path in out_pdb_list]

    for metric in ("sc_ca_rmsd", "avg_ca_plddt", "tmalign_score"):
        column = []
        for stem in stems:
            per_design = id_to_metrics.get(stem, {})
            column.append(per_design.get(metric, float("nan")))
        df[metric] = column

    predictions_dir = sc_out_dir / "struct_preds"
    af2_pdb_data = {}
    for stem in stems:
        prediction = predictions_dir / f"af2_{stem}.pdb"
        if prediction.exists():
            af2_pdb_data[stem] = prediction.read_text()

    sc_zip_path = _write_zip_from_dir(sc_out_dir, download_stem, "_self_consistency.zip")
    return sc_zip_path, af2_pdb_data
tests/conftest.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module-level mocking for app.py import-time side effects."""
2
+
3
+ import tempfile
4
+ from unittest.mock import patch
5
+
6
+ import pytest
7
+
8
+ # ---------------------------------------------------------------------------
9
+ # Module-level patches — applied BEFORE any test file can `import app`.
10
+ #
11
+ # app_config.py executes on import:
12
+ # 1. snapshot_download(...) → needs HF_TOKEN + network
13
+ # 2. os.environ.setdefault(...) → safe, no mock needed
14
+ # app.py executes on import:
15
+ # 1. base64-encode caliby_transparent.png → file exists in repo, no mock
16
+ # 2. @spaces.GPU decorator → no-op when SPACES_ZERO_GPU unset
17
+ # ---------------------------------------------------------------------------
18
+
19
+ _FAKE_WEIGHTS_DIR = tempfile.mkdtemp(prefix="caliby_test_weights_")
20
+
21
+ patch("huggingface_hub.snapshot_download", return_value=_FAKE_WEIGHTS_DIR).start()
22
+
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Shared fixtures
26
+ # ---------------------------------------------------------------------------
27
+
28
+
29
@pytest.fixture
def sample_outputs():
    """Two-sample outputs dict shaped like the return of CalibyModel.sample()."""
    return dict(
        example_id=["1YCR", "1YCR"],
        out_pdb=["/tmp/out/1YCR_sample0.cif", "/tmp/out/1YCR_sample1.cif"],
        U=[-142.38, -139.92],
        input_seq=["NATIVE_SEQ", "NATIVE_SEQ"],
        seq=["MTEEQWAQ", "VSEQQWAQ"],
    )
+
40
+
41
@pytest.fixture
def sample_outputs_with_out_pdbs(sample_outputs):
    """``sample_outputs`` plus an ``out_pdbs`` key mirroring ``out_pdb`` (the key _format_outputs reads)."""
    extended = dict(sample_outputs)
    extended["out_pdbs"] = sample_outputs["out_pdb"]
    return extended
tests/test_design_sequences.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Integration tests for get_model, _setup_user_ensemble_dir, and design_sequences."""
2
+
3
+ from pathlib import Path
4
+ from unittest.mock import MagicMock, patch
5
+
6
+ import gradio as gr
7
+ import pandas as pd
8
+ import pytest
9
+
10
+ import design
11
+ import ensemble
12
+ import models
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # get_model
16
+ # ---------------------------------------------------------------------------
17
+
18
+
19
class TestGetModel:
    """get_model loads through caliby.load_model and caches per variant."""

    @pytest.fixture(autouse=True)
    def _clear_model_cache(self):
        # Start and finish every test with an empty cache.
        models.MODELS.clear()
        yield
        models.MODELS.clear()

    def test_calls_load_model_with_variant_and_device(self):
        fake_model = MagicMock()
        with patch("caliby.load_model", return_value=fake_model) as loader:
            returned = models.get_model("caliby", "cpu")

        loader.assert_called_once_with("caliby", device="cpu")
        assert returned is fake_model

    def test_caches_model_on_repeat_call(self):
        fake_model = MagicMock()
        with patch("caliby.load_model", return_value=fake_model) as loader:
            results = [models.get_model("caliby", "cpu") for _ in range(2)]
        loader.assert_called_once()
        assert results[0] is results[1]

    def test_different_variants_cached_separately(self):
        fake_default = MagicMock()
        fake_soluble = MagicMock()
        with patch("caliby.load_model", side_effect=[fake_default, fake_soluble]):
            got_default = models.get_model("caliby", "cpu")
            got_soluble = models.get_model("soluble_caliby_v1", "cpu")
        assert got_default is fake_default
        assert got_soluble is fake_soluble
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # _setup_user_ensemble_dir
56
+ # ---------------------------------------------------------------------------
57
+
58
+
59
class TestSetupUserEnsembleDir:
    """_setup_user_ensemble_dir keys all uploads under the primary file's stem."""

    def test_returns_dict_with_primary_key(self):
        uploads = ["/tmp/primary.pdb", "/tmp/conf1.pdb", "/tmp/conf2.pdb"]
        mapping = ensemble._setup_user_ensemble_dir(list(uploads))
        assert "primary" in mapping
        assert mapping["primary"] == uploads

    def test_first_file_is_primary(self):
        mapping = ensemble._setup_user_ensemble_dir(["/tmp/myprotein.cif", "/tmp/alt.pdb"])
        first_conformer = mapping["myprotein"][0]
        assert first_conformer == "/tmp/myprotein.cif"

    def test_uses_stem_as_key(self):
        mapping = ensemble._setup_user_ensemble_dir(["/path/to/foo.pdb"])
        assert "foo" in mapping
74
+
75
+
76
+ # ---------------------------------------------------------------------------
77
+ # design_sequences — validation
78
+ # ---------------------------------------------------------------------------
79
+
80
+
81
class TestDesignSequencesValidation:
    """design_sequences rejects bad file/mode combinations before touching a model."""

    @staticmethod
    def _validate(files, mode):
        """Call design_sequences with default settings; return (df, message)."""
        df, msg, _, _, _, _ = design.design_sequences(files, mode, "caliby", 4, None, 0.1, "", "", "", "", "", 31)
        return df, msg

    def test_no_files(self):
        df, msg = self._validate(None, "none")
        assert df.empty
        assert "Upload at least one" in msg

    def test_empty_file_list(self):
        df, msg = self._validate([], "none")
        assert df.empty
        assert "Upload at least one" in msg

    def test_single_mode_multiple_files(self):
        _, msg = self._validate(["a.pdb", "b.pdb"], "none")
        assert "exactly one file" in msg

    def test_synthetic_mode_multiple_files(self):
        _, msg = self._validate(["a.pdb", "b.pdb"], "synthetic")
        assert "exactly one file" in msg

    def test_user_mode_too_few_files(self):
        _, msg = self._validate(["a.pdb"], "user")
        assert "at least two" in msg
109
+
110
+
111
+ # ---------------------------------------------------------------------------
112
+ # design_sequences — single structure mode
113
+ # ---------------------------------------------------------------------------
114
+
115
+
116
class TestDesignSequencesSingleMode:
    """ensemble_mode='none': checks the arguments forwarded to CalibyModel.sample()."""

    def _make_mock_outputs(self):
        return dict(
            example_id=["test"],
            out_pdb=["/tmp/test_sample0.cif"],
            U=[-100.0],
            input_seq=["NATIVE"],
            seq=["ACDEF"],
        )

    def _design_single(self, tmp_path, pdb_text, *settings):
        """Write test.pdb, run design_sequences in 'none' mode with patched
        model / zip / CUDA, and return the mock model for inspection.

        ``settings`` are the positional arguments that follow the model
        variant (num_seqs, omit_aas, temperature, five constraint strings,
        num_protpardelle_conformers).
        """
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text(pdb_text)

        mock_model = MagicMock()
        mock_model.sample.return_value = self._make_mock_outputs()

        with (
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences([str(pdb_file)], "none", "caliby", *settings)
        return mock_model

    def test_sample_called_with_correct_args(self, tmp_path):
        mock_model = self._design_single(
            tmp_path,
            "FAKE PDB",
            4,
            ["C"],
            0.5,
            "A1-100",
            "A1-10",
            "A26:A",
            "A26:AVG",
            "A10,B10",
            31,
        )

        mock_model.sample.assert_called_once()
        positional, keyword = mock_model.sample.call_args

        # First positional arg is pdb_paths
        pdb_paths = positional[0]
        assert isinstance(pdb_paths, list)
        assert len(pdb_paths) == 1
        assert pdb_paths[0].endswith("test.pdb")

        assert keyword["num_seqs_per_pdb"] == 4
        assert keyword["omit_aas"] == ["C"]
        assert keyword["temperature"] == 0.5
        assert keyword["num_workers"] == 0
        assert isinstance(keyword["out_dir"], str)
        constraint_df = keyword["pos_constraint_df"]
        assert isinstance(constraint_df, pd.DataFrame)
        assert constraint_df.iloc[0]["pdb_key"] == "test"

    def test_no_constraints_passes_none(self, tmp_path):
        mock_model = self._design_single(tmp_path, "FAKE", 1, None, 0.1, "", "", "", "", "", 31)
        assert mock_model.sample.call_args[1]["pos_constraint_df"] is None

    def test_empty_omit_aas_becomes_none(self, tmp_path):
        mock_model = self._design_single(tmp_path, "FAKE", 1, [], 0.1, "", "", "", "", "", 31)
        assert mock_model.sample.call_args[1]["omit_aas"] is None
226
+
227
+
228
+ # ---------------------------------------------------------------------------
229
+ # design_sequences — user ensemble mode
230
+ # ---------------------------------------------------------------------------
231
+
232
+
233
class TestDesignSequencesUserEnsembleMode:
    """ensemble_mode='user': checks the arguments forwarded to CalibyModel.ensemble_sample()."""

    def _make_mock_outputs(self):
        return dict(
            example_id=["primary"],
            out_pdb=["/tmp/primary_sample0.cif"],
            U=[-100.0],
            input_seq=["NATIVE"],
            seq=["AAA"],
        )

    def _make_inputs(self, tmp_path):
        """Write the two upload files and return (upload paths, mock model)."""
        uploads = []
        for file_name, text in (("primary.pdb", "PDB1"), ("conf1.pdb", "PDB2")):
            target = tmp_path / file_name
            target.write_text(text)
            uploads.append(str(target))

        mock_model = MagicMock()
        mock_model.ensemble_sample.return_value = self._make_mock_outputs()
        return uploads, mock_model

    def test_calls_ensemble_sample(self, tmp_path):
        uploads, mock_model = self._make_inputs(tmp_path)
        mock_pdb_to_conf = {"primary": ["/some/primary.pdb", "/some/conf1.pdb"]}

        with (
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_setup_user_ensemble_dir", return_value=mock_pdb_to_conf),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences(uploads, "user", "caliby", 4, None, 0.1, "", "", "", "", "", 31)

        mock_model.ensemble_sample.assert_called_once()
        positional, keyword = mock_model.ensemble_sample.call_args
        assert positional[0] is mock_pdb_to_conf
        assert keyword["pos_constraint_df"] is None

    def test_constraints_expand_via_make_ensemble_constraints(self, tmp_path):
        uploads, mock_model = self._make_inputs(tmp_path)
        mock_pdb_to_conf = {"primary": ["a.pdb", "b.pdb"]}
        expanded_df = pd.DataFrame({"pdb_key": ["a", "b"], "fixed_pos_seq": ["A1-10", "A1-10"]})

        with (
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_setup_user_ensemble_dir", return_value=mock_pdb_to_conf),
            patch("caliby.make_ensemble_constraints", return_value=expanded_df) as mock_expand,
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences(uploads, "user", "caliby", 1, None, 0.1, "A1-10", "", "", "", "", 31)

        mock_expand.assert_called_once()
        constraints_dict, pdb_to_conf_arg = mock_expand.call_args[0]
        assert isinstance(constraints_dict, dict)
        assert "primary" in constraints_dict
        assert constraints_dict["primary"]["fixed_pos_seq"] == "A1-10"
        assert pdb_to_conf_arg is mock_pdb_to_conf
294
+
295
+
296
+ # ---------------------------------------------------------------------------
297
+ # design_sequences — error handling
298
+ # ---------------------------------------------------------------------------
299
+
300
+
301
class TestDesignSequencesErrorHandling:
    """Failures raised past validation propagate unchanged to the caller."""

    @staticmethod
    def _expect_propagation(pdb_path, error, error_type, pattern):
        """Make get_model raise ``error`` and assert design_sequences re-raises it."""
        with (
            patch.object(design, "get_model", side_effect=error),
            patch("torch.cuda.is_available", return_value=False),
        ):
            with pytest.raises(error_type, match=pattern):
                design.design_sequences([str(pdb_path)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)

    def test_value_error(self, tmp_path):
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text("PDB")
        self._expect_propagation(pdb_file, ValueError("bad config"), ValueError, "bad config")

    def test_file_not_found(self, tmp_path):
        # The upload path is deliberately never created.
        ghost = tmp_path / "ghost.pdb"
        self._expect_propagation(ghost, FileNotFoundError("missing.pdb"), FileNotFoundError, "missing.pdb")

    def test_unexpected_runtime_error(self, tmp_path):
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text("PDB")
        self._expect_propagation(pdb_file, RuntimeError("GPU OOM"), RuntimeError, "GPU OOM")
335
+
336
+
337
+ # ---------------------------------------------------------------------------
338
+ # design_sequences — zip output
339
+ # ---------------------------------------------------------------------------
340
+
341
+
342
class TestDesignSequencesZipOutput:
    """ZIP creation from the designed CIF outputs."""

    @staticmethod
    def _model_returning(out_pdb):
        """Mock model whose sample() yields one design with the given out_pdb list."""
        mock_model = MagicMock()
        mock_model.sample.return_value = {
            "example_id": ["test"],
            "out_pdb": out_pdb,
            "U": [-100.0],
            "input_seq": ["NATIVE"],
            "seq": ["AAA"],
        }
        return mock_model

    @staticmethod
    def _run(pdb_file, mock_model):
        """Run design_sequences in 'none' mode with only the model and CUDA patched."""
        with (
            patch.object(design, "get_model", return_value=mock_model),
            patch("torch.cuda.is_available", return_value=False),
        ):
            return design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)

    def test_creates_zip_when_out_pdb_present(self, tmp_path):
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text("PDB")

        out_cif = tmp_path / "test_sample0.cif"
        out_cif.write_text("CIF CONTENT")

        _, _, zip_path, _, _, _ = self._run(pdb_file, self._model_returning([str(out_cif)]))

        assert zip_path is not None
        archive = Path(zip_path)
        assert archive.name == "test_designs.zip"
        assert archive.exists()

    def test_empty_out_pdb_raises_for_invalid_caliby_output(self, tmp_path):
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text("PDB")

        # out_pdb is shorter than the other columns of the outputs dict.
        mock_model = self._model_returning([])

        with pytest.raises(ValueError, match="All arrays must be of the same length"):
            self._run(pdb_file, mock_model)
392
+
393
+
394
+ # ---------------------------------------------------------------------------
395
+ # design_sequences — ZeroGPU quota-aware retry
396
+ # ---------------------------------------------------------------------------
397
+
398
+
399
class TestParseQuotaLeft:
    """_parse_quota_left pulls the remaining seconds out of ZeroGPU quota errors."""

    def test_extracts_remaining_seconds(self):
        message = "You have exceeded your free GPU quota (210s requested vs. 45s left). Try again in 0:02:45"
        remaining = design._parse_quota_left(gr.Error(message))
        assert remaining == 45

    def test_extracts_zero_remaining(self):
        message = "(210s requested vs. 0s left). Try again in 0:03:30"
        remaining = design._parse_quota_left(gr.Error(message))
        assert remaining == 0

    def test_returns_none_for_non_quota_error(self):
        unrelated = gr.Error("Some other error")
        assert design._parse_quota_left(unrelated) is None

    def test_returns_none_for_no_message_attr(self):
        plain = RuntimeError("no message attribute")
        assert design._parse_quota_left(plain) is None
417
+
418
+
419
class TestDesignSequencesQuotaRetry:
    """Quota-aware retry behaviour of the design_sequences wrapper."""

    _DESIGN_ARGS = (None, "none", "caliby", 4, None, 0.1, "", "", "", "", "", 31)

    def test_retry_on_quota_exceeded(self, tmp_path):
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text("PDB")

        mock_model = MagicMock()
        mock_model.sample.return_value = {
            "example_id": ["test"],
            "out_pdb": ["/tmp/t.cif"],
            "U": [-100.0],
            "input_seq": ["N"],
            "seq": ["A"],
        }

        quota_error = gr.Error("(210s requested vs. 45s left). Try again in 0:02:45")

        # Grab the real implementation before it is patched below.
        real_gpu_fn = design._design_sequences_gpu
        attempts = []

        def fail_first_then_delegate(*args, **kwargs):
            attempts.append(kwargs)
            if len(attempts) == 1:
                raise quota_error
            return real_gpu_fn(*args, **kwargs)

        with (
            patch.object(design, "_design_sequences_gpu", side_effect=fail_first_then_delegate),
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)
        assert len(attempts) == 2
        assert design._gpu_duration_override is None  # Reset after retry

    def test_no_retry_when_remaining_zero(self):
        exhausted = gr.Error("(210s requested vs. 0s left). Try again in 0:03:30")
        with patch.object(design, "_design_sequences_gpu", side_effect=exhausted):
            with pytest.raises(gr.Error):
                design.design_sequences(*self._DESIGN_ARGS)

    def test_no_retry_for_non_quota_gr_error(self):
        too_long = gr.Error("The requested GPU duration (210s) is larger than the maximum allowed")
        with patch.object(design, "_design_sequences_gpu", side_effect=too_long):
            with pytest.raises(gr.Error, match="larger than the maximum allowed"):
                design.design_sequences(*self._DESIGN_ARGS)

    def test_non_gradio_errors_propagate(self):
        """Exceptions other than gr.Error (here ValueError) bypass the retry logic."""
        with patch.object(design, "_design_sequences_gpu", side_effect=ValueError("bad")):
            with pytest.raises(ValueError, match="bad"):
                design.design_sequences(*self._DESIGN_ARGS)
tests/test_helpers.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for helper functions."""
2
+
3
+ import types
4
+ from pathlib import Path
5
+
6
+ import pandas as pd
7
+
8
+ import constraints
9
+ import design
10
+ import file_utils
11
+ import viewers
12
+
13
+ # ---------------------------------------------------------------------------
14
+ # _get_file_path
15
+ # ---------------------------------------------------------------------------
16
+
17
+
18
class TestGetFilePath:
    """_get_file_path turns each supported file-input form into a Path."""

    def test_string_input(self):
        resolved = file_utils._get_file_path("/some/path.pdb")
        assert resolved == Path("/some/path.pdb")

    def test_object_with_path_attr(self):
        upload = types.SimpleNamespace(path="/uploads/file.pdb")
        resolved = file_utils._get_file_path(upload)
        assert resolved == Path("/uploads/file.pdb")

    def test_dict_with_path_key(self):
        upload = {"path": "/uploads/file.pdb", "name": "file.pdb"}
        assert file_utils._get_file_path(upload) == Path("/uploads/file.pdb")

    def test_fallback_to_str(self):
        resolved = file_utils._get_file_path(42)
        assert resolved == Path("42")
34
+
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # _build_pos_constraint_df
38
+ # ---------------------------------------------------------------------------
39
+
40
+
41
class TestBuildPosConstraintDf:
    """_build_pos_constraint_df assembles the positional-constraint frame for caliby."""

    # Every column the builder may emit when all five constraint fields are set.
    _ALL_COLUMNS = frozenset(
        (
            "pdb_key",
            "fixed_pos_seq",
            "fixed_pos_scn",
            "fixed_pos_override_seq",
            "pos_restrict_aatype",
            "symmetry_pos",
        )
    )

    @staticmethod
    def _fully_populated():
        return constraints._build_pos_constraint_df("X", "A1", "B2", "A3:G", "A4:V", "A5,B5")

    def test_all_empty_returns_none(self):
        built = constraints._build_pos_constraint_df("1YCR", "", "", "", "", "")
        assert built is None

    def test_all_whitespace_returns_none(self):
        built = constraints._build_pos_constraint_df("1YCR", " ", " ", " ", " ", " ")
        assert built is None

    def test_single_field_populated(self):
        df = constraints._build_pos_constraint_df("1YCR", "A1-100", "", "", "", "")
        assert df is not None
        assert len(df) == 1
        first_row = df.iloc[0]
        assert first_row["pdb_key"] == "1YCR"
        assert first_row["fixed_pos_seq"] == "A1-100"
        # Only populated columns + pdb_key should be present
        assert "fixed_pos_scn" not in df.columns

    def test_all_fields_populated(self):
        df = self._fully_populated()
        assert set(df.columns) == set(self._ALL_COLUMNS)

    def test_columns_match_caliby_valid_columns(self):
        """All columns must be in caliby's _VALID_POS_CONSTRAINT_COLUMNS."""
        valid = set(self._ALL_COLUMNS)
        df = self._fully_populated()
        assert set(df.columns).issubset(valid)
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # _df_to_csv
86
+ # ---------------------------------------------------------------------------
87
+
88
+
89
class TestDfToCsv:
    """_df_to_csv writes a DataFrame to a named CSV file."""

    def test_none_returns_none(self):
        result = file_utils._df_to_csv(None)
        assert result is None

    def test_empty_dataframe_returns_none(self):
        result = file_utils._df_to_csv(pd.DataFrame())
        assert result is None

    def test_valid_dataframe_roundtrips(self):
        original = pd.DataFrame({"pdb_key": ["1YCR"], "fixed_pos_seq": ["A1-100"]})
        csv_path = file_utils._df_to_csv(original)
        assert csv_path is not None
        assert Path(csv_path).exists()
        assert csv_path.endswith(".csv")
        reloaded = pd.read_csv(csv_path)
        pd.testing.assert_frame_equal(original, reloaded)

    def test_uses_sample_name_for_csv_basename(self):
        results = pd.DataFrame(
            {
                "Sample": ["1YCR_sample0"],
                "Sequence": ["ACDE"],
                "Energy (U)": [-1.0],
            }
        )

        csv_path = file_utils._df_to_csv(results)

        assert csv_path is not None
        assert Path(csv_path).name == "1YCR_results.csv"
120
+
121
+
122
class TestCsvDownloadOutput:
    """_csv_download_output builds the update for the CSV download file component."""

    def test_hides_component_for_empty_dataframe(self):
        empty = pd.DataFrame()

        update = viewers._csv_download_output(empty)

        assert update["visible"] is False
        assert update["value"] is None

    def test_shows_named_csv_for_results_dataframe(self):
        results = pd.DataFrame(
            {
                "Sample": ["1YCR_sample0"],
                "Sequence": ["ACDE"],
                "Energy (U)": [-1.0],
            }
        )

        update = viewers._csv_download_output(results)

        assert update["visible"] is True
        csv_name = Path(update["value"]).name
        assert csv_name == "1YCR_results.csv"
144
+
145
+
146
class TestFormatResultsDisplay:
    """Checks the number formatting applied to the rendered results table."""

    def test_formats_last_four_numeric_columns(self):
        """Numeric values in the trailing columns render with at most two decimals."""
        results = pd.DataFrame(
            {
                "Sample": ["1YCR_sample0"],
                "Sequence": ["ACDE"],
                "Energy (U)": [-1.2345],
                "sc_ca_rmsd": [1.0],
                "avg_ca_plddt": [88.888],
                "tmalign_score": [0.12345],
            }
        )

        rendered = viewers._format_results_display(results).to_html()

        # One expected fragment per numeric column, in column order.
        # ">1<" pins that 1.0 is shown without a trailing ".00".
        for fragment in ("-1.23", ">1<", "88.89", "0.12"):
            assert fragment in rendered
168
+
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # _format_outputs
172
+ # ---------------------------------------------------------------------------
173
+
174
+
175
class TestFormatOutputs:
    """Checks ``design._format_outputs``, which returns ``(DataFrame, FASTA text, out_pdb list)``."""

    def test_dataframe_structure(self, sample_outputs_with_out_pdbs):
        """The table has exactly the three result columns and one row per sample."""
        table, _, _ = design._format_outputs(sample_outputs_with_out_pdbs)
        assert list(table.columns) == ["Sample", "Sequence", "Energy (U)"]
        assert len(table) == 2

    def test_sample_names_from_path_stems(self, sample_outputs_with_out_pdbs):
        """Sample names match the expected ``<key>_sample<N>`` labels."""
        table, _, _ = design._format_outputs(sample_outputs_with_out_pdbs)
        assert list(table["Sample"]) == ["1YCR_sample0", "1YCR_sample1"]

    def test_fasta_format(self, sample_outputs_with_out_pdbs):
        """FASTA text alternates ``>name`` header lines with sequence lines."""
        _, fasta_text, _ = design._format_outputs(sample_outputs_with_out_pdbs)
        records = fasta_text.strip().split("\n")
        assert records[:4] == [
            ">1YCR_sample0",
            "MTEEQWAQ",
            ">1YCR_sample1",
            "VSEQQWAQ",
        ]

    def test_uses_caliby_out_pdb_key(self, sample_outputs):
        """Outputs keyed by ``out_pdb`` (no ``out_pdbs``) are formatted just the same."""
        assert "out_pdbs" not in sample_outputs
        table, fasta_text, pdb_paths = design._format_outputs(sample_outputs)

        assert list(table["Sample"]) == ["1YCR_sample0", "1YCR_sample1"]
        assert ">1YCR_sample0" in fasta_text
        assert pdb_paths == sample_outputs["out_pdb"]
202
+
203
+
204
+ # ---------------------------------------------------------------------------
205
+ # _get_best_sc_sample
206
+ # ---------------------------------------------------------------------------
207
+
208
+
209
class TestGetBestScSample:
    """Checks which sample ``viewers._get_best_sc_sample`` selects for display."""

    def test_picks_highest_tmalign_score(self):
        """The sample with the largest ``tmalign_score`` is chosen."""
        scored = pd.DataFrame(
            {
                "Sample": ["1YCR_sample0", "1YCR_sample1", "1YCR_sample2"],
                "tmalign_score": [0.5, 0.9, 0.7],
            }
        )
        chosen = viewers._get_best_sc_sample(scored)
        assert chosen == "1YCR_sample1"

    def test_falls_back_to_first_when_no_tmalign(self):
        """Without a score column, the first sample is chosen."""
        unscored = pd.DataFrame({"Sample": ["1YCR_sample0", "1YCR_sample1"]})
        chosen = viewers._get_best_sc_sample(unscored)
        assert chosen == "1YCR_sample0"

    def test_falls_back_to_first_when_all_nan(self):
        """With only NaN scores, the first sample is chosen."""
        nan = float("nan")
        all_missing = pd.DataFrame(
            {
                "Sample": ["A_sample0", "A_sample1"],
                "tmalign_score": [nan, nan],
            }
        )
        chosen = viewers._get_best_sc_sample(all_missing)
        assert chosen == "A_sample0"

    def test_returns_none_for_empty_df(self):
        """An empty frame has no sample to choose."""
        chosen = viewers._get_best_sc_sample(pd.DataFrame())
        assert chosen is None
236
+
237
+
238
+ # ---------------------------------------------------------------------------
239
+ # _render_af2_viewer / _render_reference_viewer
240
+ # ---------------------------------------------------------------------------
241
+
242
+ _MINIMAL_PDB = "ATOM 1 CA ALA A 1 0.000 0.000 0.000 1.00 90.00 C\nEND\n"
243
+
244
+
245
class TestRenderAf2Viewer:
    """Checks ``viewers._render_af2_viewer``: embedded viewer HTML, or "" when there is nothing to show."""

    def test_returns_html_with_valid_data(self):
        """A known sample renders HTML containing an iframe."""
        structures = {"test_sample0": _MINIMAL_PDB}
        rendered = viewers._render_af2_viewer("test_sample0", structures)
        assert "iframe" in rendered

    def test_returns_empty_for_missing_sample(self):
        """A sample name absent from the data renders nothing."""
        rendered = viewers._render_af2_viewer("missing", {"other": _MINIMAL_PDB})
        assert rendered == ""

    def test_returns_empty_for_none_sample(self):
        """No selected sample renders nothing."""
        rendered = viewers._render_af2_viewer(None, {"test": _MINIMAL_PDB})
        assert rendered == ""

    def test_returns_empty_for_empty_data(self):
        """No structure data renders nothing."""
        rendered = viewers._render_af2_viewer("test", {})
        assert rendered == ""
260
+
261
+
262
class TestRenderReferenceViewer:
    """Checks ``viewers._render_reference_viewer``, which shows the input structure behind a sample."""

    def test_maps_sample_to_input_key(self):
        """``1YCR_sample0`` resolves to the ``1YCR`` input and renders an iframe."""
        inputs = {"1YCR": _MINIMAL_PDB}
        rendered = viewers._render_reference_viewer("1YCR_sample0", inputs)
        assert "iframe" in rendered

    def test_returns_empty_when_input_key_missing(self):
        """A sample whose input key is not in the data renders nothing."""
        inputs = {"OTHER": _MINIMAL_PDB}
        rendered = viewers._render_reference_viewer("1YCR_sample0", inputs)
        assert rendered == ""

    def test_returns_empty_for_none_sample(self):
        """No selected sample renders nothing."""
        rendered = viewers._render_reference_viewer(None, {"1YCR": _MINIMAL_PDB})
        assert rendered == ""
274
+
275
+
276
+ # ---------------------------------------------------------------------------
277
+ # _update_viewers
278
+ # ---------------------------------------------------------------------------
279
+
280
+
281
class TestUpdateViewers:
    """Combined handler for the overlay toggle.

    ``viewers._update_viewers`` returns three values: the AF2 viewer HTML, an
    update for the reference viewer, and a further visibility-only update.
    All three must be unpacked; unpacking only two raises ``ValueError``.
    """

    def test_overlay_off_hides_reference(self):
        """With the overlay off, the AF2 view renders and both updates hide their targets."""
        af2_html, ref_update, extra_update = viewers._update_viewers(
            "s0", {"s0": _MINIMAL_PDB}, {"s": _MINIMAL_PDB}, False
        )
        assert "iframe" in af2_html
        assert ref_update["visible"] is False
        # NOTE(review): the third update presumably targets a wrapper or label
        # for the reference viewer - confirm against the Gradio wiring.
        assert extra_update["visible"] is False

    def test_overlay_on_shows_reference(self):
        """With the overlay on, both viewers render and both updates show their targets."""
        af2_html, ref_update, extra_update = viewers._update_viewers(
            "s_sample0", {"s_sample0": _MINIMAL_PDB}, {"s": _MINIMAL_PDB}, True
        )
        assert "iframe" in af2_html
        assert ref_update["visible"] is True
        assert "iframe" in ref_update["value"]
        assert extra_update["visible"] is True
viewers.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Structure viewers and display formatting for the Gradio UI."""
2
+
3
+ import re
4
+
5
+ import gradio as gr
6
+ import pandas as pd
7
+
8
+ from file_utils import _df_to_csv
9
+
10
+
11
+ def _format_display_number(value) -> str:
12
+ if pd.isna(value):
13
+ return ""
14
+ return f"{float(value):.2f}".rstrip("0").rstrip(".")
15
+
16
+
17
def _file_output(value: str | None) -> dict:
    """Build a Gradio update that shows a file component only when it has a value.

    Args:
        value: Value to place in the component, or ``None``/empty for nothing.

    Returns:
        The result of ``gr.update`` carrying ``value`` and a ``visible`` flag
        equal to the truthiness of ``value``.
    """
    has_file = bool(value)
    return gr.update(visible=has_file, value=value)
19
+
20
+
21
def _csv_download_output(df: pd.DataFrame | None) -> dict:
    """Export ``df`` through ``_df_to_csv`` and wrap the result for the download component.

    Args:
        df: Frame to export, or ``None``.

    Returns:
        The update produced by ``_file_output`` for whatever ``_df_to_csv`` returned.
    """
    csv_path = _df_to_csv(df)
    return _file_output(csv_path)
23
+
24
+
25
+ def _format_results_display(df: pd.DataFrame):
26
+ numeric_columns = [col for col in df.columns[-4:] if pd.api.types.is_numeric_dtype(df[col])]
27
+ if not numeric_columns:
28
+ return df
29
+ return df.style.format({col: _format_display_number for col in numeric_columns})
30
+
31
+
32
+ def _get_best_sc_sample(df: pd.DataFrame) -> str | None:
33
+ if df.empty or "Sample" not in df.columns:
34
+ return None
35
+ if "tmalign_score" in df.columns and df["tmalign_score"].notna().any():
36
+ return str(df.loc[df["tmalign_score"].idxmax(), "Sample"])
37
+ return str(df.iloc[0]["Sample"])
38
+
39
+
40
+ def _render_af2_viewer(sample_name: str | None, af2_pdb_data: dict[str, str], color_mode: str = "plddt") -> str:
41
+ if not sample_name or not af2_pdb_data or sample_name not in af2_pdb_data:
42
+ return ""
43
+ import molview as mv
44
+
45
+ v = mv.view(width=840, height=500)
46
+ v.addModel(af2_pdb_data[sample_name], name=f"AF2: {sample_name}")
47
+ v.setColorMode(color_mode)
48
+ v.setBackgroundColor("#000000")
49
+ return v._repr_html_()
50
+
51
+
52
+ def _render_reference_viewer(sample_name: str | None, input_pdb_data: dict[str, str], color_mode: str = "chain") -> str:
53
+ if not sample_name or not input_pdb_data:
54
+ return ""
55
+ input_key = re.sub(r"_sample\d+$", "", sample_name)
56
+ if input_key not in input_pdb_data:
57
+ return ""
58
+ import molview as mv
59
+
60
+ v = mv.view(width=840, height=500)
61
+ v.addModel(input_pdb_data[input_key], name=f"Reference: {input_key}")
62
+ v.setColorMode(color_mode)
63
+ v.setBackgroundColor("#000000")
64
+ return v._repr_html_()
65
+
66
+
67
def _update_viewers(
    best_sample: str,
    af2_pdb_data: dict[str, str],
    input_pdb_data: dict[str, str],
    show_overlay: bool,
    color_mode: str = "plddt",
    ref_color_mode: str = "chain",
):
    """Recompute the viewer outputs when the sample or the overlay toggle changes.

    Args:
        best_sample: Sample to display.
        af2_pdb_data: Mapping of sample name to PDB text for the AF2 viewer.
        input_pdb_data: Mapping of input key to PDB text for the reference viewer.
        show_overlay: Whether the reference viewer should be shown.
        color_mode: Colour mode for the AF2 viewer.
        ref_color_mode: Colour mode for the reference viewer.

    Returns:
        A 3-tuple: the AF2 viewer HTML, a ``gr.update`` carrying the reference
        HTML and its visibility, and a visibility-only ``gr.update``.
    """
    af2_html = _render_af2_viewer(best_sample, af2_pdb_data, color_mode)

    if not show_overlay:
        # Overlay off: clear the reference viewer and hide both targets.
        return af2_html, gr.update(value="", visible=False), gr.update(visible=False)

    reference_html = _render_reference_viewer(best_sample, input_pdb_data, ref_color_mode)
    reference_update = gr.update(value=reference_html, visible=True)
    return af2_html, reference_update, gr.update(visible=True)