Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,174 +1,179 @@
|
|
| 1 |
from pathlib import Path
|
| 2 |
import ast
|
| 3 |
import json
|
|
|
|
| 4 |
import pandas as pd
|
| 5 |
import numpy as np
|
| 6 |
-
import traceback
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
-
from datasets import load_dataset
|
| 10 |
from datetime import datetime
|
| 11 |
import os
|
| 12 |
-
|
| 13 |
-
import pandas as pd
|
| 14 |
from about import (
|
| 15 |
PROBLEM_TYPES, TOKEN, CACHE_PATH, API, submissions_repo, results_repo,
|
| 16 |
COLUMN_DISPLAY_NAMES, COUNT_BASED_METRICS, METRIC_GROUPS,
|
| 17 |
METRIC_GROUP_COLORS, COLUMN_TO_GROUP, TRAINING_DATASETS
|
| 18 |
)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def get_leaderboard():
|
| 21 |
-
ds = load_dataset(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
full_df = pd.DataFrame(ds)
|
| 23 |
-
|
| 24 |
if len(full_df) == 0:
|
| 25 |
-
return pd.DataFrame(
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
full_df['msun_plus_sun'] = full_df['msun_count'] + full_df['sun_count']
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
full_df = full_df.sort_values(by='msun_plus_sun', ascending=False)
|
| 34 |
|
| 35 |
return full_df
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
f for f in list_repo_files(results_repo, repo_type="dataset")
|
| 39 |
-
if f.endswith(".csv")
|
| 40 |
-
]
|
| 41 |
-
for file in sorted(files):
|
| 42 |
-
path = hf_hub_download(results_repo, filename=file, repo_type="dataset")
|
| 43 |
-
df = pd.read_csv(path)
|
| 44 |
-
paper_dtype = str(df["paper_link"].dtype) if "paper_link" in df else "MISSING"
|
| 45 |
-
notes_dtype = str(df["notes"].dtype) if "notes" in df else "MISSING"
|
| 46 |
-
paper_value = df["paper_link"].iloc[0] if "paper_link" in df and len(df) else None
|
| 47 |
-
notes_value = df["notes"].iloc[0] if "notes" in df and len(df) else None
|
| 48 |
-
print(
|
| 49 |
-
file,
|
| 50 |
-
"paper_link:", paper_dtype, repr(paper_value),
|
| 51 |
-
"notes:", notes_dtype, repr(notes_value),
|
| 52 |
-
flush=True,
|
| 53 |
-
)
|
| 54 |
def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
|
| 55 |
"""Format the dataframe with proper column names and optional percentages."""
|
| 56 |
if len(df) == 0:
|
| 57 |
return df
|
| 58 |
|
| 59 |
-
|
| 60 |
-
selected_cols = ['model_name']
|
| 61 |
|
| 62 |
if compact_view:
|
| 63 |
-
# Use predefined compact columns
|
| 64 |
from about import COMPACT_VIEW_COLUMNS
|
| 65 |
selected_cols = [col for col in COMPACT_VIEW_COLUMNS if col in df.columns]
|
| 66 |
else:
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
selected_cols.append('n_structures')
|
| 72 |
|
| 73 |
-
# If no groups selected, show all
|
| 74 |
if not selected_groups:
|
| 75 |
selected_groups = list(METRIC_GROUPS.keys())
|
| 76 |
|
| 77 |
-
# Add columns from selected groups
|
| 78 |
for group in selected_groups:
|
| 79 |
if group in METRIC_GROUPS:
|
| 80 |
for col in METRIC_GROUPS[group]:
|
| 81 |
if col in df.columns and col not in selected_cols:
|
| 82 |
selected_cols.append(col)
|
| 83 |
|
| 84 |
-
# Create a copy with selected columns
|
| 85 |
display_df = df[selected_cols].copy()
|
| 86 |
|
| 87 |
-
|
| 88 |
-
if 'model_name' in display_df.columns:
|
| 89 |
-
# Model links mapping
|
| 90 |
model_links = {
|
| 91 |
-
|
| 92 |
-
|
| 93 |
}
|
| 94 |
|
| 95 |
def add_model_symbols(row):
|
| 96 |
-
name = row[
|
| 97 |
symbols = []
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
paper_val = row.get('paper_link', None)
|
| 102 |
if paper_val and isinstance(paper_val, str) and paper_val.strip():
|
| 103 |
symbols.append(f'<a href="{paper_val.strip()}" target="_blank">📄</a>')
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
symbols.append('⚡')
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
symbols.append(
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
symbols.append('◆')
|
| 116 |
-
elif name == 'CrystaLLM-pi' or name == 'OMatG' or name == 'Zatom-1-WD':
|
| 117 |
-
symbols.append('✅')
|
| 118 |
|
| 119 |
symbol_str = f" {' '.join(symbols)}" if symbols else ""
|
| 120 |
|
| 121 |
-
# Add link if model has one
|
| 122 |
if name in model_links:
|
| 123 |
return f'<a href="{model_links[name]}" target="_blank">{name}</a>{symbol_str}'
|
| 124 |
return f"{name}{symbol_str}"
|
| 125 |
|
| 126 |
-
display_df[
|
| 127 |
|
| 128 |
-
|
| 129 |
-
if 'training_set' in display_df.columns:
|
| 130 |
def format_training_set(val):
|
| 131 |
if val is None or (isinstance(val, float) and np.isnan(val)):
|
| 132 |
-
return
|
| 133 |
val = str(val).strip()
|
| 134 |
-
if val in (
|
| 135 |
-
return
|
| 136 |
-
|
| 137 |
-
val = val.
|
| 138 |
-
val = val.replace("'", "").replace('"', '')
|
| 139 |
return val
|
| 140 |
-
display_df['training_set'] = display_df['training_set'].apply(format_training_set)
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
| 145 |
for col in COUNT_BASED_METRICS:
|
| 146 |
if col in display_df.columns:
|
| 147 |
-
|
| 148 |
-
display_df[col] = (df[col] / n_structures * 100).round(1).astype(str) + '%'
|
| 149 |
|
| 150 |
-
# Round numeric columns for cleaner display
|
| 151 |
for col in display_df.columns:
|
| 152 |
-
if display_df[col].dtype in [
|
| 153 |
display_df[col] = display_df[col].round(4)
|
| 154 |
|
| 155 |
-
# Separate baseline models to the bottom
|
| 156 |
baseline_indices = set()
|
| 157 |
-
if
|
| 158 |
-
is_baseline = df[
|
| 159 |
non_baseline_df = display_df[~is_baseline]
|
| 160 |
baseline_df = display_df[is_baseline]
|
| 161 |
display_df = pd.concat([non_baseline_df, baseline_df]).reset_index(drop=True)
|
| 162 |
-
# Track baseline row indices in the new dataframe
|
| 163 |
baseline_indices = set(range(len(non_baseline_df), len(display_df)))
|
| 164 |
|
| 165 |
-
# Rename columns for display
|
| 166 |
display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)
|
|
|
|
| 167 |
|
| 168 |
-
# Apply color coding based on metric groups
|
| 169 |
-
styled_df = apply_color_styling(display_df, selected_cols, baseline_indices)
|
| 170 |
-
|
| 171 |
-
return styled_df
|
| 172 |
|
| 173 |
def apply_color_styling(display_df, original_cols, baseline_indices=None):
|
| 174 |
"""Apply background colors to dataframe based on metric groups using pandas Styler."""
|
|
@@ -176,79 +181,70 @@ def apply_color_styling(display_df, original_cols, baseline_indices=None):
|
|
| 176 |
baseline_indices = set()
|
| 177 |
|
| 178 |
def style_by_group(x):
|
| 179 |
-
|
| 180 |
-
styles = pd.DataFrame('', index=x.index, columns=x.columns)
|
| 181 |
|
| 182 |
-
# Map display column names back to original column names
|
| 183 |
for i, display_col in enumerate(x.columns):
|
| 184 |
if i < len(original_cols):
|
| 185 |
original_col = original_cols[i]
|
| 186 |
-
|
| 187 |
-
# Check if this column belongs to a metric group
|
| 188 |
if original_col in COLUMN_TO_GROUP:
|
| 189 |
group = COLUMN_TO_GROUP[original_col]
|
| 190 |
-
color = METRIC_GROUP_COLORS.get(group,
|
| 191 |
if color:
|
| 192 |
-
styles[display_col] = f
|
| 193 |
|
| 194 |
-
# Add thick top border to the first baseline row as a separator
|
| 195 |
if baseline_indices:
|
| 196 |
first_baseline_idx = min(baseline_indices)
|
| 197 |
for col in x.columns:
|
| 198 |
current = styles.at[first_baseline_idx, col]
|
| 199 |
-
separator_style =
|
| 200 |
-
styles.at[first_baseline_idx, col] =
|
|
|
|
|
|
|
| 201 |
|
| 202 |
return styles
|
| 203 |
|
| 204 |
-
# Apply the styling function
|
| 205 |
return display_df.style.apply(style_by_group, axis=None)
|
| 206 |
|
|
|
|
| 207 |
def parse_training_set(val):
|
| 208 |
-
"""Parse a training_set value
|
| 209 |
try:
|
| 210 |
return ast.literal_eval(str(val))
|
| 211 |
except (ValueError, SyntaxError):
|
| 212 |
return []
|
| 213 |
|
| 214 |
-
def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction, training_set_filter):
|
| 215 |
-
"""Update the leaderboard based on user selections.
|
| 216 |
|
| 217 |
-
|
| 218 |
-
"""
|
| 219 |
-
# Use cached dataframe instead of re-downloading
|
| 220 |
df_to_format = cached_df.copy()
|
| 221 |
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
if training_set_filter and training_set_filter != "All" and 'training_set' in df_to_format.columns:
|
| 225 |
mask = (
|
| 226 |
-
df_to_format[
|
| 227 |
-
df_to_format[
|
| 228 |
)
|
| 229 |
df_to_format = df_to_format[mask]
|
| 230 |
|
| 231 |
-
# Convert display name back to raw column name for sorting
|
| 232 |
if sort_by and sort_by != "None":
|
| 233 |
-
# Create reverse mapping from display names to raw column names
|
| 234 |
display_to_raw = {v: k for k, v in COLUMN_DISPLAY_NAMES.items()}
|
| 235 |
raw_column_name = display_to_raw.get(sort_by, sort_by)
|
| 236 |
|
| 237 |
if raw_column_name in df_to_format.columns:
|
| 238 |
-
ascending =
|
| 239 |
df_to_format = df_to_format.sort_values(by=raw_column_name, ascending=ascending)
|
| 240 |
|
| 241 |
-
|
| 242 |
-
|
| 243 |
|
| 244 |
def show_output_box(message):
|
| 245 |
return gr.update(value=message, visible=True)
|
| 246 |
|
|
|
|
| 247 |
def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_settings, training_datasets, training_dataset_other, paper_link, hf_model_link, email, profile: gr.OAuthProfile | None):
|
| 248 |
"""Submit structures to the leaderboard."""
|
| 249 |
from huggingface_hub import upload_file
|
| 250 |
|
| 251 |
-
# Validate inputs
|
| 252 |
if not model_name or not model_name.strip():
|
| 253 |
return "Error: Please provide a model name.", None
|
| 254 |
|
|
@@ -268,7 +264,6 @@ def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_se
|
|
| 268 |
username = profile.username
|
| 269 |
timestamp = datetime.now().isoformat()
|
| 270 |
|
| 271 |
-
# Create submission metadata
|
| 272 |
submission_data = {
|
| 273 |
"username": username,
|
| 274 |
"model_name": model_name.strip(),
|
|
@@ -281,13 +276,10 @@ def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_se
|
|
| 281 |
"hf_model_link": hf_model_link.strip() if hf_model_link else None,
|
| 282 |
"email": email.strip(),
|
| 283 |
"timestamp": timestamp,
|
| 284 |
-
"file_name": Path(cif_files).name
|
| 285 |
}
|
| 286 |
|
| 287 |
-
# Create a unique submission ID
|
| 288 |
submission_id = f"{username}_{model_name.strip().replace(' ', '_')}_{timestamp.replace(':', '-')}"
|
| 289 |
-
|
| 290 |
-
# Upload the submission file
|
| 291 |
file_path = Path(cif_files)
|
| 292 |
uploaded_file_path = f"submissions/{submission_id}/{file_path.name}"
|
| 293 |
|
|
@@ -296,13 +288,13 @@ def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_se
|
|
| 296 |
path_in_repo=uploaded_file_path,
|
| 297 |
repo_id=submissions_repo,
|
| 298 |
token=TOKEN,
|
| 299 |
-
repo_type="dataset"
|
| 300 |
)
|
| 301 |
|
| 302 |
-
# Upload metadata as JSON
|
| 303 |
metadata_path = f"submissions/{submission_id}/metadata.json"
|
| 304 |
import tempfile
|
| 305 |
-
|
|
|
|
| 306 |
json.dump(submission_data, f, indent=2)
|
| 307 |
temp_metadata_path = f.name
|
| 308 |
|
|
@@ -311,52 +303,53 @@ def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_se
|
|
| 311 |
path_in_repo=metadata_path,
|
| 312 |
repo_id=submissions_repo,
|
| 313 |
token=TOKEN,
|
| 314 |
-
repo_type="dataset"
|
| 315 |
)
|
| 316 |
|
| 317 |
-
# Clean up temp file
|
| 318 |
os.unlink(temp_metadata_path)
|
| 319 |
|
| 320 |
return f"Success! Submitted {model_name} for {problem_type} evaluation. Submission ID: {submission_id}", submission_id
|
| 321 |
|
| 322 |
except Exception as e:
|
| 323 |
-
return f"Error during submission: {str(e)}", None
|
|
|
|
| 324 |
|
| 325 |
def generate_metric_legend_html():
|
| 326 |
"""Generate HTML table with color-coded metric group legend."""
|
| 327 |
metric_details = {
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
}
|
| 337 |
|
| 338 |
html = '<table style="width: 100%; border-collapse: collapse;">'
|
| 339 |
-
html +=
|
| 340 |
html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Color</th>'
|
| 341 |
html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Group</th>'
|
| 342 |
html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Metrics</th>'
|
| 343 |
html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Direction</th>'
|
| 344 |
-
html +=
|
| 345 |
|
| 346 |
for group, color in METRIC_GROUP_COLORS.items():
|
| 347 |
-
metrics, direction = metric_details.get(group, (
|
| 348 |
-
group_name = group.replace(
|
| 349 |
|
| 350 |
-
html +=
|
| 351 |
html += f'<td style="border: 1px solid #ddd; padding: 8px;"><div style="width: 30px; height: 20px; background-color: {color}; border: 1px solid #999;"></div></td>'
|
| 352 |
html += f'<td style="border: 1px solid #ddd; padding: 8px;"><strong>{group_name}</strong></td>'
|
| 353 |
html += f'<td style="border: 1px solid #ddd; padding: 8px;">{metrics}</td>'
|
| 354 |
html += f'<td style="border: 1px solid #ddd; padding: 8px;">{direction}</td>'
|
| 355 |
-
html +=
|
| 356 |
|
| 357 |
-
html +=
|
| 358 |
return html
|
| 359 |
|
|
|
|
| 360 |
def gradio_interface() -> gr.Blocks:
|
| 361 |
with gr.Blocks() as demo:
|
| 362 |
gr.Markdown("""
|
|
@@ -368,110 +361,109 @@ Generative machine learning models hold great promise for accelerating materials
|
|
| 368 |
|
| 369 |
📄 **Paper**: [arXiv](https://arxiv.org/abs/2512.04562) | 💻 **Code**: [GitHub](https://github.com/LeMaterial/lemat-genbench) | 📧 **Contact**: siddharth.betala [at] entalpic.ai, alexandre.duval [at] entalpic.ai
|
| 370 |
""")
|
|
|
|
| 371 |
with gr.Tabs(elem_classes="tab-buttons"):
|
| 372 |
with gr.TabItem("🚀 Leaderboard", elem_id="boundary-benchmark-tab-table"):
|
| 373 |
gr.Markdown("# LeMat-GenBench")
|
| 374 |
|
| 375 |
-
# Display options
|
| 376 |
with gr.Row():
|
| 377 |
with gr.Column(scale=1):
|
| 378 |
compact_view = gr.Checkbox(
|
| 379 |
value=True,
|
| 380 |
label="Compact View",
|
| 381 |
-
info="Show only key metrics"
|
| 382 |
)
|
| 383 |
show_percentage = gr.Checkbox(
|
| 384 |
value=True,
|
| 385 |
label="Show as Percentages",
|
| 386 |
-
info="Display count-based metrics as percentages of total structures"
|
| 387 |
)
|
|
|
|
| 388 |
with gr.Column(scale=1):
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
|
|
|
| 392 |
sort_by = gr.Dropdown(
|
| 393 |
choices=sort_choices,
|
| 394 |
value="None",
|
| 395 |
label="Sort By",
|
| 396 |
-
info="Select column to sort by (default: sorted by MSUN+SUN descending)"
|
| 397 |
)
|
| 398 |
sort_direction = gr.Radio(
|
| 399 |
choices=["Ascending", "Descending"],
|
| 400 |
value="Descending",
|
| 401 |
-
label="Sort Direction"
|
| 402 |
)
|
|
|
|
| 403 |
with gr.Column(scale=1):
|
| 404 |
training_set_filter = gr.Dropdown(
|
| 405 |
choices=["All"] + TRAINING_DATASETS,
|
| 406 |
value="MP-20",
|
| 407 |
label="Filter by Training Set",
|
| 408 |
-
info="Show only models trained on a specific dataset"
|
| 409 |
)
|
|
|
|
| 410 |
with gr.Column(scale=2):
|
| 411 |
selected_groups = gr.CheckboxGroup(
|
| 412 |
choices=list(METRIC_GROUPS.keys()),
|
| 413 |
value=list(METRIC_GROUPS.keys()),
|
| 414 |
label="Metric Families (only active when Compact View is off)",
|
| 415 |
-
info="Select which metric groups to display"
|
| 416 |
)
|
| 417 |
|
| 418 |
-
# Metric legend with color coding
|
| 419 |
with gr.Accordion("Metric Groups Legend", open=False):
|
| 420 |
gr.HTML(generate_metric_legend_html())
|
| 421 |
|
| 422 |
try:
|
| 423 |
-
# Initial dataframe - load once and cache
|
| 424 |
-
debug_csv_schemas()
|
| 425 |
initial_df = get_leaderboard()
|
| 426 |
cached_df_state = gr.State(initial_df)
|
| 427 |
|
| 428 |
-
ALWAYS_SHOW_MODELS = {
|
| 429 |
-
|
| 430 |
-
initial_df[
|
| 431 |
-
initial_df[
|
| 432 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
|
| 434 |
leaderboard_table = gr.Dataframe(
|
| 435 |
label="GenBench Leaderboard",
|
| 436 |
value=formatted_df,
|
| 437 |
interactive=False,
|
| 438 |
wrap=True,
|
| 439 |
-
datatype=["html"] + [None] * (len(
|
| 440 |
-
column_widths=["180px"] + ["160px"] * (len(
|
| 441 |
-
show_fullscreen_button=True
|
| 442 |
)
|
| 443 |
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
)
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
)
|
| 460 |
-
sort_by.change(
|
| 461 |
-
fn=update_leaderboard,
|
| 462 |
-
inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
|
| 463 |
-
outputs=leaderboard_table
|
| 464 |
-
)
|
| 465 |
-
sort_direction.change(
|
| 466 |
-
fn=update_leaderboard,
|
| 467 |
-
inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
|
| 468 |
-
outputs=leaderboard_table
|
| 469 |
-
)
|
| 470 |
-
training_set_filter.change(
|
| 471 |
-
fn=update_leaderboard,
|
| 472 |
-
inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
|
| 473 |
-
outputs=leaderboard_table
|
| 474 |
-
)
|
| 475 |
|
| 476 |
except Exception as e:
|
| 477 |
traceback.print_exc()
|
|
@@ -488,16 +480,14 @@ Generative machine learning models hold great promise for accelerating materials
|
|
| 488 |
Verified submissions mean the results came from a model submission rather than a CIF submission.
|
| 489 |
|
| 490 |
Models marked as baselines appear below the separator line at the bottom of the table.
|
| 491 |
-
""")
|
| 492 |
|
| 493 |
with gr.TabItem("✉️ Submit", elem_id="boundary-benchmark-tab-table"):
|
| 494 |
-
gr.Markdown(
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
)
|
| 500 |
-
filename = gr.State(value=None)
|
| 501 |
|
| 502 |
gr.LoginButton()
|
| 503 |
|
|
@@ -506,79 +496,91 @@ Models marked as baselines appear below the separator line at the bottom of the
|
|
| 506 |
model_name_input = gr.Textbox(
|
| 507 |
label="Model Name",
|
| 508 |
placeholder="Enter your model name",
|
| 509 |
-
info="Provide a name for your model/method"
|
| 510 |
)
|
| 511 |
email_input = gr.Textbox(
|
| 512 |
label="Email Address",
|
| 513 |
placeholder="Enter your email address",
|
| 514 |
-
info="Contact email for correspondence about this submission"
|
| 515 |
)
|
| 516 |
paper_link_input = gr.Textbox(
|
| 517 |
label="Paper Link (optional)",
|
| 518 |
placeholder="https://arxiv.org/abs/...",
|
| 519 |
-
info="Link to the paper describing your model/method"
|
| 520 |
)
|
| 521 |
hf_model_link_input = gr.Textbox(
|
| 522 |
label="HuggingFace Model Link (optional)",
|
| 523 |
placeholder="https://huggingface.co/...",
|
| 524 |
-
info="Link to your model on HuggingFace"
|
| 525 |
)
|
| 526 |
problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
|
|
|
|
| 527 |
with gr.Column():
|
| 528 |
cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
|
| 529 |
relaxed = gr.Checkbox(
|
| 530 |
value=False,
|
| 531 |
label="Structures are pre-relaxed",
|
| 532 |
-
info="Check this box if your submitted structures have already been relaxed"
|
| 533 |
)
|
| 534 |
relaxation_settings_input = gr.Textbox(
|
| 535 |
label="Relaxation Settings",
|
| 536 |
placeholder="e.g., VASP PBE, 520 eV cutoff, ...",
|
| 537 |
info="Describe the relaxation settings used",
|
| 538 |
-
visible=False
|
| 539 |
)
|
| 540 |
training_dataset_input = gr.Dropdown(
|
| 541 |
choices=TRAINING_DATASETS,
|
| 542 |
label="Training Dataset",
|
| 543 |
info="Select all datasets used for training",
|
| 544 |
-
multiselect=True
|
| 545 |
)
|
| 546 |
training_dataset_other_input = gr.Textbox(
|
| 547 |
label="Other Training Dataset",
|
| 548 |
placeholder="Specify your training dataset",
|
| 549 |
info="Provide details if you selected 'Others (must specify)'",
|
| 550 |
-
visible=False
|
| 551 |
)
|
| 552 |
|
| 553 |
-
# Show/hide relaxation settings based on pre-relaxed checkbox
|
| 554 |
relaxed.change(
|
| 555 |
fn=lambda x: gr.update(visible=x),
|
| 556 |
inputs=[relaxed],
|
| 557 |
-
outputs=[relaxation_settings_input]
|
| 558 |
)
|
| 559 |
|
| 560 |
-
# Show/hide other dataset text box based on dropdown selection
|
| 561 |
training_dataset_input.change(
|
| 562 |
fn=lambda x: gr.update(visible="Others (must specify)" in (x or [])),
|
| 563 |
inputs=[training_dataset_input],
|
| 564 |
-
outputs=[training_dataset_other_input]
|
| 565 |
)
|
| 566 |
|
| 567 |
submit_btn = gr.Button("Submission")
|
| 568 |
message = gr.Textbox(label="Status", lines=1, visible=False)
|
| 569 |
-
|
| 570 |
-
gr.Markdown(
|
| 571 |
-
|
|
|
|
|
|
|
| 572 |
submit_btn.click(
|
| 573 |
submit_cif_files,
|
| 574 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 575 |
outputs=[message, filename],
|
| 576 |
).then(
|
| 577 |
fn=show_output_box,
|
| 578 |
inputs=[message],
|
| 579 |
outputs=[message],
|
| 580 |
)
|
| 581 |
-
|
| 582 |
return demo
|
| 583 |
|
| 584 |
|
|
|
|
| 1 |
from pathlib import Path
|
| 2 |
import ast
|
| 3 |
import json
|
| 4 |
+
import traceback
|
| 5 |
import pandas as pd
|
| 6 |
import numpy as np
|
|
|
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
+
from datasets import Features, Value, load_dataset
|
| 10 |
from datetime import datetime
|
| 11 |
import os
|
| 12 |
+
|
|
|
|
| 13 |
from about import (
|
| 14 |
PROBLEM_TYPES, TOKEN, CACHE_PATH, API, submissions_repo, results_repo,
|
| 15 |
COLUMN_DISPLAY_NAMES, COUNT_BASED_METRICS, METRIC_GROUPS,
|
| 16 |
METRIC_GROUP_COLORS, COLUMN_TO_GROUP, TRAINING_DATASETS
|
| 17 |
)
|
| 18 |
|
| 19 |
+
|
| 20 |
+
# Explicit column schema for the per-model result CSVs loaded from the results
# repo. Passing this to `load_dataset` pins every column's dtype up front —
# presumably to stop the loader from mis-inferring types on sparse columns
# such as all-null "paper_link"/"notes" (TODO confirm against the results repo).
RESULT_FEATURES = Features({
    # Run identification
    "run_name": Value("string"),
    "timestamp": Value("string"),
    # Structure counts (validity / uniqueness / novelty)
    "n_structures": Value("int64"),
    "overall_valid_count": Value("int64"),
    "charge_neutral_count": Value("int64"),
    "distance_valid_count": Value("int64"),
    "plausibility_valid_count": Value("int64"),
    "unique_count": Value("int64"),
    "novel_count": Value("int64"),
    # Energy / stability statistics
    "mean_formation_energy": Value("float64"),
    "formation_energy_std": Value("float64"),
    "stability_mean_above_hull": Value("float64"),
    "stability_std_e_above_hull": Value("float64"),
    "stability_mean_ensemble_std": Value("float64"),
    "mean_relaxation_RMSD": Value("float64"),
    "relaxation_RMSE_std": Value("float64"),
    "stable_count": Value("int64"),
    "unique_in_stable_count": Value("int64"),
    "sun_count": Value("int64"),
    "metastable_count": Value("int64"),
    "unique_in_metastable_count": Value("int64"),
    "msun_count": Value("int64"),
    # Distribution-distance metrics
    "JSDistance": Value("float64"),
    "MMD": Value("float64"),
    "FrechetDistance": Value("float64"),
    # Diversity metrics
    "element_diversity": Value("float64"),
    "space_group_diversity": Value("float64"),
    "site_diversity": Value("float64"),
    "physical_size_diversity": Value("float64"),
    # Supply-risk (HHI) metrics
    "hhi_production_mean": Value("float64"),
    "hhi_reserve_mean": Value("float64"),
    "hhi_combined_mean": Value("float64"),
    # Submission metadata
    "model_name": Value("string"),
    "relaxed": Value("bool"),
    "training_set": Value("string"),
    "paper_link": Value("string"),
    "notes": Value("string"),
})
|
| 59 |
+
|
| 60 |
+
|
| 61 |
def get_leaderboard():
    """Load all result CSVs from the results repo into one ranked DataFrame.

    Downloads every ``*.csv`` in ``results_repo`` with an explicit schema
    (``RESULT_FEATURES``) and, when the count columns are present, ranks rows
    by the combined MSUN + SUN count, highest first.
    """
    ds = load_dataset(
        results_repo,
        data_files="*.csv",
        split="train",
        # Always re-fetch so newly published results show up immediately.
        download_mode="force_redownload",
        features=RESULT_FEATURES,
    )
    leaderboard = pd.DataFrame(ds)

    # Empty repo: return an empty frame that still carries the full schema.
    if leaderboard.empty:
        return pd.DataFrame(columns=list(RESULT_FEATURES.keys()))

    has_counts = {"msun_count", "sun_count"} <= set(leaderboard.columns)
    if has_counts:
        leaderboard["msun_plus_sun"] = leaderboard["msun_count"] + leaderboard["sun_count"]

    if "msun_plus_sun" in leaderboard.columns:
        leaderboard = leaderboard.sort_values(by="msun_plus_sun", ascending=False)

    return leaderboard
|
| 81 |
+
|
| 82 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
    """Format the dataframe with proper column names and optional percentages."""
    if len(df) == 0:
        return df

    selected_cols = ["model_name"]

    if compact_view:
        # Compact mode shows only the predefined key columns.
        from about import COMPACT_VIEW_COLUMNS
        selected_cols = [c for c in COMPACT_VIEW_COLUMNS if c in df.columns]
    else:
        for extra in ("training_set", "n_structures"):
            if extra in df.columns:
                selected_cols.append(extra)

        # No explicit group selection means "show everything".
        if not selected_groups:
            selected_groups = list(METRIC_GROUPS.keys())

        for group in selected_groups:
            if group in METRIC_GROUPS:
                for col in METRIC_GROUPS[group]:
                    if col in df.columns and col not in selected_cols:
                        selected_cols.append(col)

    display_df = df[selected_cols].copy()

    if "model_name" in display_df.columns:
        # Models whose names link out to a hosted checkpoint.
        model_links = {
            "CrystaLLM-pi": "https://huggingface.co/c-bone/CrystaLLM-pi_base",
            "OMatG": "https://huggingface.co/OMatG/MP-20-DNG/tree/main/EncDec-ODE-Gamma",
        }

        def decorate_name(row):
            """Render one model name as HTML with its badge symbols."""
            name = row["model_name"]
            badges = []

            # 📄 links to the paper when a non-blank URL is present.
            if "paper_link" in df.columns:
                paper_val = row.get("paper_link", None)
                if paper_val and isinstance(paper_val, str) and paper_val.strip():
                    badges.append(f'<a href="{paper_val.strip()}" target="_blank">📄</a>')

            # ⚡ marks pre-relaxed submissions.
            if "relaxed" in df.columns and row.get("relaxed", False):
                badges.append("⚡")

            # Fixed per-model markers (baselines / verified submissions).
            if name in ["Alexandria", "OQMD"]:
                badges.append("★")
            elif name == "AFLOW":
                badges.append("◆")
            elif name in ["CrystaLLM-pi", "OMatG", "Zatom-1-WD"]:
                badges.append("✅")

            suffix = f" {' '.join(badges)}" if badges else ""

            if name in model_links:
                return f'<a href="{model_links[name]}" target="_blank">{name}</a>{suffix}'
            return f"{name}{suffix}"

        display_df["model_name"] = df.apply(decorate_name, axis=1)

    if "training_set" in display_df.columns:
        def tidy_training_set(val):
            """Turn a stored list-like string ("['MP-20']") into plain text."""
            if val is None or (isinstance(val, float) and np.isnan(val)):
                return ""
            val = str(val).strip()
            if val in ("[]", "", "nan", "None"):
                return ""
            val = val.strip("[]")
            val = val.replace("'", "").replace('"', "")
            return val

        display_df["training_set"] = display_df["training_set"].apply(tidy_training_set)

    # Optionally present count metrics as percentages of total structures.
    if show_percentage and "n_structures" in df.columns:
        n_structures = df["n_structures"]
        for col in COUNT_BASED_METRICS:
            if col in display_df.columns:
                display_df[col] = (df[col] / n_structures * 100).round(1).astype(str) + "%"

    # Round remaining float columns for cleaner display.
    for col in display_df.columns:
        if display_df[col].dtype in ["float64", "float32"]:
            display_df[col] = display_df[col].round(4)

    # Push rows whose notes mention "baseline" below the other models.
    baseline_indices = set()
    if "notes" in df.columns:
        is_baseline = df["notes"].fillna("").str.contains("baseline", case=False, na=False)
        regular_rows = display_df[~is_baseline]
        baseline_rows = display_df[is_baseline]
        display_df = pd.concat([regular_rows, baseline_rows]).reset_index(drop=True)
        # Remember where the baselines landed so styling can draw a separator.
        baseline_indices = set(range(len(regular_rows), len(display_df)))

    display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)
    return apply_color_styling(display_df, selected_cols, baseline_indices)
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
def apply_color_styling(display_df, original_cols, baseline_indices=None):
|
| 179 |
"""Apply background colors to dataframe based on metric groups using pandas Styler."""
|
|
|
|
| 181 |
baseline_indices = set()
|
| 182 |
|
| 183 |
def style_by_group(x):
|
| 184 |
+
styles = pd.DataFrame("", index=x.index, columns=x.columns)
|
|
|
|
| 185 |
|
|
|
|
| 186 |
for i, display_col in enumerate(x.columns):
|
| 187 |
if i < len(original_cols):
|
| 188 |
original_col = original_cols[i]
|
|
|
|
|
|
|
| 189 |
if original_col in COLUMN_TO_GROUP:
|
| 190 |
group = COLUMN_TO_GROUP[original_col]
|
| 191 |
+
color = METRIC_GROUP_COLORS.get(group, "")
|
| 192 |
if color:
|
| 193 |
+
styles[display_col] = f"background-color: {color}"
|
| 194 |
|
|
|
|
| 195 |
if baseline_indices:
|
| 196 |
first_baseline_idx = min(baseline_indices)
|
| 197 |
for col in x.columns:
|
| 198 |
current = styles.at[first_baseline_idx, col]
|
| 199 |
+
separator_style = "border-top: 3px solid #555"
|
| 200 |
+
styles.at[first_baseline_idx, col] = (
|
| 201 |
+
f"{current}; {separator_style}" if current else separator_style
|
| 202 |
+
)
|
| 203 |
|
| 204 |
return styles
|
| 205 |
|
|
|
|
| 206 |
return display_df.style.apply(style_by_group, axis=None)
|
| 207 |
|
| 208 |
+
|
| 209 |
def parse_training_set(val):
    """Interpret *val* (e.g. ``"['MP-20']"``) as a Python literal.

    Returns the evaluated literal, or ``[]`` when the text is not a
    valid Python literal.
    """
    text = str(val)
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return []
    return parsed
|
| 215 |
|
|
|
|
|
|
|
| 216 |
|
| 217 |
+
def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction, training_set_filter):
    """Update the leaderboard based on user selections."""
    # Work on a copy of the cached data instead of re-downloading.
    working = cached_df.copy()

    # Reference-database rows stay visible regardless of the filter.
    ALWAYS_SHOW_MODELS = {"AFLOW", "Alexandria", "OQMD"}
    if training_set_filter and training_set_filter != "All" and "training_set" in working.columns:
        keep = (
            working["training_set"].apply(lambda x: training_set_filter in parse_training_set(x))
            | working["model_name"].isin(ALWAYS_SHOW_MODELS)
        )
        working = working[keep]

    if sort_by and sort_by != "None":
        # The dropdown shows display names; map back to the raw column.
        name_map = {display: raw for raw, display in COLUMN_DISPLAY_NAMES.items()}
        raw_col = name_map.get(sort_by, sort_by)

        if raw_col in working.columns:
            ascending = sort_direction == "Ascending"
            working = working.sort_values(by=raw_col, ascending=ascending)

    return format_dataframe(working, show_percentage, selected_groups, compact_view)
|
| 238 |
+
|
| 239 |
|
| 240 |
def show_output_box(message):
    """Reveal the status textbox and fill it with *message*."""
    update = gr.update(value=message, visible=True)
    return update
|
| 242 |
|
| 243 |
+
|
| 244 |
def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_settings, training_datasets, training_dataset_other, paper_link, hf_model_link, email, profile: gr.OAuthProfile | None):
|
| 245 |
"""Submit structures to the leaderboard."""
|
| 246 |
from huggingface_hub import upload_file
|
| 247 |
|
|
|
|
| 248 |
if not model_name or not model_name.strip():
|
| 249 |
return "Error: Please provide a model name.", None
|
| 250 |
|
|
|
|
| 264 |
username = profile.username
|
| 265 |
timestamp = datetime.now().isoformat()
|
| 266 |
|
|
|
|
| 267 |
submission_data = {
|
| 268 |
"username": username,
|
| 269 |
"model_name": model_name.strip(),
|
|
|
|
| 276 |
"hf_model_link": hf_model_link.strip() if hf_model_link else None,
|
| 277 |
"email": email.strip(),
|
| 278 |
"timestamp": timestamp,
|
| 279 |
+
"file_name": Path(cif_files).name,
|
| 280 |
}
|
| 281 |
|
|
|
|
| 282 |
submission_id = f"{username}_{model_name.strip().replace(' ', '_')}_{timestamp.replace(':', '-')}"
|
|
|
|
|
|
|
| 283 |
file_path = Path(cif_files)
|
| 284 |
uploaded_file_path = f"submissions/{submission_id}/{file_path.name}"
|
| 285 |
|
|
|
|
| 288 |
path_in_repo=uploaded_file_path,
|
| 289 |
repo_id=submissions_repo,
|
| 290 |
token=TOKEN,
|
| 291 |
+
repo_type="dataset",
|
| 292 |
)
|
| 293 |
|
|
|
|
| 294 |
metadata_path = f"submissions/{submission_id}/metadata.json"
|
| 295 |
import tempfile
|
| 296 |
+
|
| 297 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
| 298 |
json.dump(submission_data, f, indent=2)
|
| 299 |
temp_metadata_path = f.name
|
| 300 |
|
|
|
|
| 303 |
path_in_repo=metadata_path,
|
| 304 |
repo_id=submissions_repo,
|
| 305 |
token=TOKEN,
|
| 306 |
+
repo_type="dataset",
|
| 307 |
)
|
| 308 |
|
|
|
|
| 309 |
os.unlink(temp_metadata_path)
|
| 310 |
|
| 311 |
return f"Success! Submitted {model_name} for {problem_type} evaluation. Submission ID: {submission_id}", submission_id
|
| 312 |
|
| 313 |
except Exception as e:
|
| 314 |
+
return f"Error during submission: {str(e)}", None
|
| 315 |
+
|
| 316 |
|
| 317 |
def generate_metric_legend_html():
    """Generate HTML table with color-coded metric group legend."""
    # Per-group metric list and optimisation direction shown in the legend.
    metric_details = {
        "Validity ↑": ("Valid, Charge Neutral, Distance Valid, Plausibility Valid", "↑ Higher is better"),
        "Uniqueness & Novelty ↑": ("Unique, Novel", "↑ Higher is better"),
        "Energy Metrics ↓": ("E Above Hull, Formation Energy, Relaxation RMSD (with std)", "↓ Lower is better"),
        "Stability ↑": ("Stable, Unique in Stable, SUN", "↑ Higher is better"),
        "Metastability ↑": ("Metastable, Unique in Metastable, MSUN", "↑ Higher is better"),
        "Distribution ↓": ("JS Distance, MMD, FID", "↓ Lower is better"),
        "Diversity ↑": ("Element, Space Group, Atomic Site, Crystal Size", "↑ Higher is better"),
        "HHI ↓": ("HHI Production, HHI Reserve", "↓ Lower is better"),
    }

    parts = ['<table style="width: 100%; border-collapse: collapse;">', "<thead><tr>"]
    for header in ("Color", "Group", "Metrics", "Direction"):
        parts.append(f'<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">{header}</th>')
    parts.append("</tr></thead><tbody>")

    for group, color in METRIC_GROUP_COLORS.items():
        metrics, direction = metric_details.get(group, ("", ""))
        # Strip the direction arrows from the label; they get their own column.
        group_name = group.replace("↑", "").replace("↓", "").strip()
        parts.append("<tr>")
        parts.append(f'<td style="border: 1px solid #ddd; padding: 8px;"><div style="width: 30px; height: 20px; background-color: {color}; border: 1px solid #999;"></div></td>')
        parts.append(f'<td style="border: 1px solid #ddd; padding: 8px;"><strong>{group_name}</strong></td>')
        parts.append(f'<td style="border: 1px solid #ddd; padding: 8px;">{metrics}</td>')
        parts.append(f'<td style="border: 1px solid #ddd; padding: 8px;">{direction}</td>')
        parts.append("</tr>")

    parts.append("</tbody></table>")
    return "".join(parts)
|
| 351 |
|
| 352 |
+
|
| 353 |
def gradio_interface() -> gr.Blocks:
|
| 354 |
with gr.Blocks() as demo:
|
| 355 |
gr.Markdown("""
|
|
|
|
| 361 |
|
| 362 |
📄 **Paper**: [arXiv](https://arxiv.org/abs/2512.04562) | 💻 **Code**: [GitHub](https://github.com/LeMaterial/lemat-genbench) | 📧 **Contact**: siddharth.betala [at] entalpic.ai, alexandre.duval [at] entalpic.ai
|
| 363 |
""")
|
| 364 |
+
|
| 365 |
with gr.Tabs(elem_classes="tab-buttons"):
|
| 366 |
with gr.TabItem("🚀 Leaderboard", elem_id="boundary-benchmark-tab-table"):
|
| 367 |
gr.Markdown("# LeMat-GenBench")
|
| 368 |
|
|
|
|
| 369 |
with gr.Row():
|
| 370 |
with gr.Column(scale=1):
|
| 371 |
compact_view = gr.Checkbox(
|
| 372 |
value=True,
|
| 373 |
label="Compact View",
|
| 374 |
+
info="Show only key metrics",
|
| 375 |
)
|
| 376 |
show_percentage = gr.Checkbox(
|
| 377 |
value=True,
|
| 378 |
label="Show as Percentages",
|
| 379 |
+
info="Display count-based metrics as percentages of total structures",
|
| 380 |
)
|
| 381 |
+
|
| 382 |
with gr.Column(scale=1):
|
| 383 |
+
sort_choices = ["None"] + [
|
| 384 |
+
COLUMN_DISPLAY_NAMES.get(col, col)
|
| 385 |
+
for col in COLUMN_DISPLAY_NAMES.keys()
|
| 386 |
+
]
|
| 387 |
sort_by = gr.Dropdown(
|
| 388 |
choices=sort_choices,
|
| 389 |
value="None",
|
| 390 |
label="Sort By",
|
| 391 |
+
info="Select column to sort by (default: sorted by MSUN+SUN descending)",
|
| 392 |
)
|
| 393 |
sort_direction = gr.Radio(
|
| 394 |
choices=["Ascending", "Descending"],
|
| 395 |
value="Descending",
|
| 396 |
+
label="Sort Direction",
|
| 397 |
)
|
| 398 |
+
|
| 399 |
with gr.Column(scale=1):
|
| 400 |
training_set_filter = gr.Dropdown(
|
| 401 |
choices=["All"] + TRAINING_DATASETS,
|
| 402 |
value="MP-20",
|
| 403 |
label="Filter by Training Set",
|
| 404 |
+
info="Show only models trained on a specific dataset",
|
| 405 |
)
|
| 406 |
+
|
| 407 |
with gr.Column(scale=2):
|
| 408 |
selected_groups = gr.CheckboxGroup(
|
| 409 |
choices=list(METRIC_GROUPS.keys()),
|
| 410 |
value=list(METRIC_GROUPS.keys()),
|
| 411 |
label="Metric Families (only active when Compact View is off)",
|
| 412 |
+
info="Select which metric groups to display",
|
| 413 |
)
|
| 414 |
|
|
|
|
| 415 |
with gr.Accordion("Metric Groups Legend", open=False):
|
| 416 |
gr.HTML(generate_metric_legend_html())
|
| 417 |
|
| 418 |
try:
|
|
|
|
|
|
|
| 419 |
initial_df = get_leaderboard()
|
| 420 |
cached_df_state = gr.State(initial_df)
|
| 421 |
|
| 422 |
+
ALWAYS_SHOW_MODELS = {"AFLOW", "Alexandria", "OQMD"}
|
| 423 |
+
filtered_initial_df = initial_df[
|
| 424 |
+
initial_df["training_set"].apply(lambda x: "MP-20" in parse_training_set(x))
|
| 425 |
+
| initial_df["model_name"].isin(ALWAYS_SHOW_MODELS)
|
| 426 |
+
]
|
| 427 |
+
|
| 428 |
+
formatted_df = format_dataframe(
|
| 429 |
+
filtered_initial_df,
|
| 430 |
+
show_percentage=True,
|
| 431 |
+
selected_groups=list(METRIC_GROUPS.keys()),
|
| 432 |
+
compact_view=True,
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
formatted_columns = (
|
| 436 |
+
list(formatted_df.data.columns)
|
| 437 |
+
if hasattr(formatted_df, "data")
|
| 438 |
+
else list(formatted_df.columns)
|
| 439 |
+
)
|
| 440 |
|
| 441 |
leaderboard_table = gr.Dataframe(
|
| 442 |
label="GenBench Leaderboard",
|
| 443 |
value=formatted_df,
|
| 444 |
interactive=False,
|
| 445 |
wrap=True,
|
| 446 |
+
datatype=["html"] + [None] * (len(formatted_columns) - 1) if formatted_columns else None,
|
| 447 |
+
column_widths=["180px"] + ["160px"] * (len(formatted_columns) - 1) if formatted_columns else None,
|
| 448 |
+
show_fullscreen_button=True,
|
| 449 |
)
|
| 450 |
|
| 451 |
+
inputs = [
|
| 452 |
+
show_percentage,
|
| 453 |
+
selected_groups,
|
| 454 |
+
compact_view,
|
| 455 |
+
cached_df_state,
|
| 456 |
+
sort_by,
|
| 457 |
+
sort_direction,
|
| 458 |
+
training_set_filter,
|
| 459 |
+
]
|
| 460 |
+
|
| 461 |
+
show_percentage.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
|
| 462 |
+
selected_groups.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
|
| 463 |
+
compact_view.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
|
| 464 |
+
sort_by.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
|
| 465 |
+
sort_direction.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
|
| 466 |
+
training_set_filter.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 467 |
|
| 468 |
except Exception as e:
|
| 469 |
traceback.print_exc()
|
|
|
|
| 480 |
Verified submissions mean the results came from a model submission rather than a CIF submission.
|
| 481 |
|
| 482 |
Models marked as baselines appear below the separator line at the bottom of the table.
|
| 483 |
+
""")
|
| 484 |
|
| 485 |
with gr.TabItem("✉️ Submit", elem_id="boundary-benchmark-tab-table"):
|
| 486 |
+
gr.Markdown("""
|
| 487 |
+
# Materials Submission
|
| 488 |
+
Upload a ZIP of CIFs with your structures. To ensure eligibility for the leaderboard, please provide exactly 2,500 representative structures.
|
| 489 |
+
""")
|
| 490 |
+
filename = gr.State(value=None)
|
|
|
|
|
|
|
| 491 |
|
| 492 |
gr.LoginButton()
|
| 493 |
|
|
|
|
| 496 |
model_name_input = gr.Textbox(
|
| 497 |
label="Model Name",
|
| 498 |
placeholder="Enter your model name",
|
| 499 |
+
info="Provide a name for your model/method",
|
| 500 |
)
|
| 501 |
email_input = gr.Textbox(
|
| 502 |
label="Email Address",
|
| 503 |
placeholder="Enter your email address",
|
| 504 |
+
info="Contact email for correspondence about this submission",
|
| 505 |
)
|
| 506 |
paper_link_input = gr.Textbox(
|
| 507 |
label="Paper Link (optional)",
|
| 508 |
placeholder="https://arxiv.org/abs/...",
|
| 509 |
+
info="Link to the paper describing your model/method",
|
| 510 |
)
|
| 511 |
hf_model_link_input = gr.Textbox(
|
| 512 |
label="HuggingFace Model Link (optional)",
|
| 513 |
placeholder="https://huggingface.co/...",
|
| 514 |
+
info="Link to your model on HuggingFace",
|
| 515 |
)
|
| 516 |
problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
|
| 517 |
+
|
| 518 |
with gr.Column():
|
| 519 |
cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
|
| 520 |
relaxed = gr.Checkbox(
|
| 521 |
value=False,
|
| 522 |
label="Structures are pre-relaxed",
|
| 523 |
+
info="Check this box if your submitted structures have already been relaxed",
|
| 524 |
)
|
| 525 |
relaxation_settings_input = gr.Textbox(
|
| 526 |
label="Relaxation Settings",
|
| 527 |
placeholder="e.g., VASP PBE, 520 eV cutoff, ...",
|
| 528 |
info="Describe the relaxation settings used",
|
| 529 |
+
visible=False,
|
| 530 |
)
|
| 531 |
training_dataset_input = gr.Dropdown(
|
| 532 |
choices=TRAINING_DATASETS,
|
| 533 |
label="Training Dataset",
|
| 534 |
info="Select all datasets used for training",
|
| 535 |
+
multiselect=True,
|
| 536 |
)
|
| 537 |
training_dataset_other_input = gr.Textbox(
|
| 538 |
label="Other Training Dataset",
|
| 539 |
placeholder="Specify your training dataset",
|
| 540 |
info="Provide details if you selected 'Others (must specify)'",
|
| 541 |
+
visible=False,
|
| 542 |
)
|
| 543 |
|
|
|
|
| 544 |
relaxed.change(
|
| 545 |
fn=lambda x: gr.update(visible=x),
|
| 546 |
inputs=[relaxed],
|
| 547 |
+
outputs=[relaxation_settings_input],
|
| 548 |
)
|
| 549 |
|
|
|
|
| 550 |
training_dataset_input.change(
|
| 551 |
fn=lambda x: gr.update(visible="Others (must specify)" in (x or [])),
|
| 552 |
inputs=[training_dataset_input],
|
| 553 |
+
outputs=[training_dataset_other_input],
|
| 554 |
)
|
| 555 |
|
| 556 |
submit_btn = gr.Button("Submission")
|
| 557 |
message = gr.Textbox(label="Status", lines=1, visible=False)
|
| 558 |
+
|
| 559 |
+
gr.Markdown(
|
| 560 |
+
"If you have issues with submission or using the leaderboard, please start a discussion in the Community tab of this Space."
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
submit_btn.click(
|
| 564 |
submit_cif_files,
|
| 565 |
+
inputs=[
|
| 566 |
+
model_name_input,
|
| 567 |
+
problem_type,
|
| 568 |
+
cif_file,
|
| 569 |
+
relaxed,
|
| 570 |
+
relaxation_settings_input,
|
| 571 |
+
training_dataset_input,
|
| 572 |
+
training_dataset_other_input,
|
| 573 |
+
paper_link_input,
|
| 574 |
+
hf_model_link_input,
|
| 575 |
+
email_input,
|
| 576 |
+
],
|
| 577 |
outputs=[message, filename],
|
| 578 |
).then(
|
| 579 |
fn=show_output_box,
|
| 580 |
inputs=[message],
|
| 581 |
outputs=[message],
|
| 582 |
)
|
| 583 |
+
|
| 584 |
return demo
|
| 585 |
|
| 586 |
|