sid-betalol committed on
Commit
3c2a001
·
verified ·
1 Parent(s): 6668ede

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -221
app.py CHANGED
@@ -1,174 +1,179 @@
1
  from pathlib import Path
2
  import ast
3
  import json
 
4
  import pandas as pd
5
  import numpy as np
6
- import traceback
7
 
8
  import gradio as gr
9
- from datasets import load_dataset
10
  from datetime import datetime
11
  import os
12
- from huggingface_hub import list_repo_files, hf_hub_download
13
- import pandas as pd
14
  from about import (
15
  PROBLEM_TYPES, TOKEN, CACHE_PATH, API, submissions_repo, results_repo,
16
  COLUMN_DISPLAY_NAMES, COUNT_BASED_METRICS, METRIC_GROUPS,
17
  METRIC_GROUP_COLORS, COLUMN_TO_GROUP, TRAINING_DATASETS
18
  )
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
def get_leaderboard():
    """Download the results dataset and return it as a sorted DataFrame.

    Returns:
        pd.DataFrame: all result rows, sorted by the combined MSUN + SUN
        count (descending) when both count columns are present; an empty
        placeholder frame when the dataset has no rows.
    """
    # force_redownload bypasses the local HF cache so the leaderboard
    # always reflects the most recently uploaded results.
    ds = load_dataset(results_repo, split='train', download_mode="force_redownload")
    full_df = pd.DataFrame(ds)
    # FIX: removed leftover debug `print(full_df.columns)` from the hot path.
    if len(full_df) == 0:
        return pd.DataFrame({'date': [], 'model': [], 'score': [], 'verified': []})

    # Computed column: combined metastable-SUN + SUN counts.
    if 'msun_count' in full_df.columns and 'sun_count' in full_df.columns:
        full_df['msun_plus_sun'] = full_df['msun_count'] + full_df['sun_count']

    # Default leaderboard ordering.
    if 'msun_plus_sun' in full_df.columns:
        full_df = full_df.sort_values(by='msun_plus_sun', ascending=False)

    return full_df
36
def debug_csv_schemas():
    """Print dtype and first value of `paper_link`/`notes` for each results CSV.

    Diagnostic helper for spotting schema drift between per-model CSV
    files stored in the results repo.
    """
    csv_names = sorted(
        name
        for name in list_repo_files(results_repo, repo_type="dataset")
        if name.endswith(".csv")
    )
    for name in csv_names:
        local_path = hf_hub_download(results_repo, filename=name, repo_type="dataset")
        frame = pd.read_csv(local_path)

        def describe(column):
            # Returns (dtype-as-string or "MISSING", first value or None).
            if column not in frame:
                return "MISSING", None
            first = frame[column].iloc[0] if len(frame) else None
            return str(frame[column].dtype), first

        paper_dtype, paper_value = describe("paper_link")
        notes_dtype, notes_value = describe("notes")
        print(
            name,
            "paper_link:", paper_dtype, repr(paper_value),
            "notes:", notes_dtype, repr(notes_value),
            flush=True,
        )
54
def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
    """Format the dataframe with proper column names and optional percentages.

    Args:
        df: raw leaderboard dataframe (one row per model run).
        show_percentage: render count-based metrics as percentages of
            `n_structures` instead of raw counts.
        selected_groups: metric-group names to display (non-compact view
            only); falsy means "all groups".
        compact_view: use the predefined COMPACT_VIEW_COLUMNS subset.

    Returns:
        A pandas Styler (via apply_color_styling) with HTML model links,
        group colors, and baselines moved below a separator.
    """
    if len(df) == 0:
        return df

    # Build column list based on view mode
    selected_cols = ['model_name']

    if compact_view:
        # Use predefined compact columns
        from about import COMPACT_VIEW_COLUMNS
        selected_cols = [col for col in COMPACT_VIEW_COLUMNS if col in df.columns]
    else:
        # Build from selected groups
        if 'training_set' in df.columns:
            selected_cols.append('training_set')
        if 'n_structures' in df.columns:
            selected_cols.append('n_structures')

        # If no groups selected, show all.
        # NOTE(review): this group-selection logic is assumed to sit inside
        # the non-compact branch (matches the UI note that metric families
        # are "only active when Compact View is off") — confirm indentation.
        if not selected_groups:
            selected_groups = list(METRIC_GROUPS.keys())

        # Add columns from selected groups, preserving group order
        for group in selected_groups:
            if group in METRIC_GROUPS:
                for col in METRIC_GROUPS[group]:
                    if col in df.columns and col not in selected_cols:
                        selected_cols.append(col)

    # Create a copy with selected columns
    display_df = df[selected_cols].copy()

    # Add symbols to model names based on various properties
    if 'model_name' in display_df.columns:
        # Model links mapping: name -> HF model page
        model_links = {
            'CrystaLLM-pi': 'https://huggingface.co/c-bone/CrystaLLM-pi_base',
            'OMatG': 'https://huggingface.co/OMatG/MP-20-DNG/tree/main/EncDec-ODE-Gamma'
        }

        def add_model_symbols(row):
            # Decorate a model name with paper link, relaxed flag and badges.
            name = row['model_name']
            symbols = []

            # Add paper link emoji (only when the column exists and is non-empty)
            if 'paper_link' in df.columns:
                paper_val = row.get('paper_link', None)
                if paper_val and isinstance(paper_val, str) and paper_val.strip():
                    symbols.append(f'<a href="{paper_val.strip()}" target="_blank">📄</a>')

            # Add relaxed symbol
            if 'relaxed' in df.columns and row.get('relaxed', False):
                symbols.append('⚡')

            # Reference dataset symbols:
            # ★ Alexandria/OQMD (in-distribution, part of reference dataset)
            # ◆ AFLOW (out-of-distribution relative to reference dataset)
            # ✅ verified model submissions
            if name in ['Alexandria', 'OQMD']:
                symbols.append('★')
            elif name == 'AFLOW':
                symbols.append('◆')
            elif name == 'CrystaLLM-pi' or name == 'OMatG' or name == 'Zatom-1-WD':
                symbols.append('✅')

            symbol_str = f" {' '.join(symbols)}" if symbols else ""

            # Add link if model has one
            if name in model_links:
                return f'<a href="{model_links[name]}" target="_blank">{name}</a>{symbol_str}'
            return f"{name}{symbol_str}"

        display_df['model_name'] = df.apply(add_model_symbols, axis=1)

    # Format training_set column for clean display
    if 'training_set' in display_df.columns:
        def format_training_set(val):
            # Turn "['MP-20']"-style strings into plain "MP-20" text.
            if val is None or (isinstance(val, float) and np.isnan(val)):
                return ''
            val = str(val).strip()
            if val in ('[]', '', 'nan', 'None'):
                return ''
            # Strip brackets and quotes for list-like strings
            val = val.strip('[]')
            val = val.replace("'", "").replace('"', '')
            return val
        display_df['training_set'] = display_df['training_set'].apply(format_training_set)

    # Convert count-based metrics to percentages if requested
    if show_percentage and 'n_structures' in df.columns:
        n_structures = df['n_structures']
        for col in COUNT_BASED_METRICS:
            if col in display_df.columns:
                # Becomes a string column, so the float rounding below skips it.
                display_df[col] = (df[col] / n_structures * 100).round(1).astype(str) + '%'

    # Round numeric columns for cleaner display
    for col in display_df.columns:
        if display_df[col].dtype in ['float64', 'float32']:
            display_df[col] = display_df[col].round(4)

    # Separate baseline models to the bottom of the table
    baseline_indices = set()
    if 'notes' in df.columns:
        is_baseline = df['notes'].fillna('').str.contains('baseline', case=False, na=False)
        non_baseline_df = display_df[~is_baseline]
        baseline_df = display_df[is_baseline]
        display_df = pd.concat([non_baseline_df, baseline_df]).reset_index(drop=True)
        # Track baseline row indices in the new dataframe
        baseline_indices = set(range(len(non_baseline_df), len(display_df)))

    # Rename columns for display
    display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)

    # Apply color coding based on metric groups
    styled_df = apply_color_styling(display_df, selected_cols, baseline_indices)

    return styled_df
172
 
173
  def apply_color_styling(display_df, original_cols, baseline_indices=None):
174
  """Apply background colors to dataframe based on metric groups using pandas Styler."""
@@ -176,79 +181,70 @@ def apply_color_styling(display_df, original_cols, baseline_indices=None):
176
  baseline_indices = set()
177
 
178
  def style_by_group(x):
179
- # Create a DataFrame with the same shape filled with empty strings
180
- styles = pd.DataFrame('', index=x.index, columns=x.columns)
181
 
182
- # Map display column names back to original column names
183
  for i, display_col in enumerate(x.columns):
184
  if i < len(original_cols):
185
  original_col = original_cols[i]
186
-
187
- # Check if this column belongs to a metric group
188
  if original_col in COLUMN_TO_GROUP:
189
  group = COLUMN_TO_GROUP[original_col]
190
- color = METRIC_GROUP_COLORS.get(group, '')
191
  if color:
192
- styles[display_col] = f'background-color: {color}'
193
 
194
- # Add thick top border to the first baseline row as a separator
195
  if baseline_indices:
196
  first_baseline_idx = min(baseline_indices)
197
  for col in x.columns:
198
  current = styles.at[first_baseline_idx, col]
199
- separator_style = 'border-top: 3px solid #555'
200
- styles.at[first_baseline_idx, col] = f'{current}; {separator_style}' if current else separator_style
 
 
201
 
202
  return styles
203
 
204
- # Apply the styling function
205
  return display_df.style.apply(style_by_group, axis=None)
206
 
 
207
def parse_training_set(val):
    """Parse a training_set value (stored as a string like "['MP-20']") into a list."""
    try:
        parsed = ast.literal_eval(str(val))
    except (ValueError, SyntaxError):
        # Unparseable / non-literal text means "no training sets recorded".
        return []
    return parsed
213
 
214
def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction, training_set_filter):
    """Update the leaderboard based on user selections.

    Uses cached dataframe to avoid re-downloading data on every change.

    Args:
        show_percentage: forwarded to format_dataframe.
        selected_groups: metric-group names forwarded to format_dataframe.
        compact_view: forwarded to format_dataframe.
        cached_df: the full leaderboard frame held in gr.State.
        sort_by: display column name to sort by, or "None".
        sort_direction: "Ascending" or "Descending".
        training_set_filter: training-set name, or "All" for no filter.

    Returns:
        The formatted (styled) dataframe for the gr.Dataframe component.
    """
    # Use cached dataframe instead of re-downloading
    df_to_format = cached_df.copy()

    # Apply training set filter (baselines always shown regardless of filter)
    ALWAYS_SHOW_MODELS = {'AFLOW', 'Alexandria', 'OQMD'}
    if training_set_filter and training_set_filter != "All" and 'training_set' in df_to_format.columns:
        mask = (
            df_to_format['training_set'].apply(lambda x: training_set_filter in parse_training_set(x)) |
            df_to_format['model_name'].isin(ALWAYS_SHOW_MODELS)
        )
        df_to_format = df_to_format[mask]

    # Convert display name back to raw column name for sorting
    if sort_by and sort_by != "None":
        # Create reverse mapping from display names to raw column names
        display_to_raw = {v: k for k, v in COLUMN_DISPLAY_NAMES.items()}
        raw_column_name = display_to_raw.get(sort_by, sort_by)

        if raw_column_name in df_to_format.columns:
            ascending = (sort_direction == "Ascending")
            df_to_format = df_to_format.sort_values(by=raw_column_name, ascending=ascending)

    formatted_df = format_dataframe(df_to_format, show_percentage, selected_groups, compact_view)
    return formatted_df
243
 
244
def show_output_box(message):
    """Reveal the status textbox and set its contents to *message*."""
    return gr.update(visible=True, value=message)
246
 
 
247
  def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_settings, training_datasets, training_dataset_other, paper_link, hf_model_link, email, profile: gr.OAuthProfile | None):
248
  """Submit structures to the leaderboard."""
249
  from huggingface_hub import upload_file
250
 
251
- # Validate inputs
252
  if not model_name or not model_name.strip():
253
  return "Error: Please provide a model name.", None
254
 
@@ -268,7 +264,6 @@ def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_se
268
  username = profile.username
269
  timestamp = datetime.now().isoformat()
270
 
271
- # Create submission metadata
272
  submission_data = {
273
  "username": username,
274
  "model_name": model_name.strip(),
@@ -281,13 +276,10 @@ def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_se
281
  "hf_model_link": hf_model_link.strip() if hf_model_link else None,
282
  "email": email.strip(),
283
  "timestamp": timestamp,
284
- "file_name": Path(cif_files).name
285
  }
286
 
287
- # Create a unique submission ID
288
  submission_id = f"{username}_{model_name.strip().replace(' ', '_')}_{timestamp.replace(':', '-')}"
289
-
290
- # Upload the submission file
291
  file_path = Path(cif_files)
292
  uploaded_file_path = f"submissions/{submission_id}/{file_path.name}"
293
 
@@ -296,13 +288,13 @@ def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_se
296
  path_in_repo=uploaded_file_path,
297
  repo_id=submissions_repo,
298
  token=TOKEN,
299
- repo_type="dataset"
300
  )
301
 
302
- # Upload metadata as JSON
303
  metadata_path = f"submissions/{submission_id}/metadata.json"
304
  import tempfile
305
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
 
306
  json.dump(submission_data, f, indent=2)
307
  temp_metadata_path = f.name
308
 
@@ -311,52 +303,53 @@ def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_se
311
  path_in_repo=metadata_path,
312
  repo_id=submissions_repo,
313
  token=TOKEN,
314
- repo_type="dataset"
315
  )
316
 
317
- # Clean up temp file
318
  os.unlink(temp_metadata_path)
319
 
320
  return f"Success! Submitted {model_name} for {problem_type} evaluation. Submission ID: {submission_id}", submission_id
321
 
322
  except Exception as e:
323
- return f"Error during submission: {str(e)}", None
 
324
 
325
def generate_metric_legend_html():
    """Generate HTML table with color-coded metric group legend.

    Returns:
        str: a static HTML `<table>` mapping each metric group's color
        swatch to its member metrics and optimization direction.
    """
    # group display name (with direction arrow) -> (metric list, direction text)
    metric_details = {
        'Validity ↑': ('Valid, Charge Neutral, Distance Valid, Plausibility Valid', '↑ Higher is better'),
        'Uniqueness & Novelty ↑': ('Unique, Novel', '↑ Higher is better'),
        'Energy Metrics ↓': ('E Above Hull, Formation Energy, Relaxation RMSD (with std)', '↓ Lower is better'),
        'Stability ↑': ('Stable, Unique in Stable, SUN', '↑ Higher is better'),
        'Metastability ↑': ('Metastable, Unique in Metastable, MSUN', '↑ Higher is better'),
        'Distribution ↓': ('JS Distance, MMD, FID', '↓ Lower is better'),
        'Diversity ↑': ('Element, Space Group, Atomic Site, Crystal Size', '↑ Higher is better'),
        'HHI ↓': ('HHI Production, HHI Reserve', '↓ Lower is better'),
    }

    html = '<table style="width: 100%; border-collapse: collapse;">'
    html += '<thead><tr>'
    html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Color</th>'
    html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Group</th>'
    html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Metrics</th>'
    html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Direction</th>'
    html += '</tr></thead><tbody>'

    for group, color in METRIC_GROUP_COLORS.items():
        metrics, direction = metric_details.get(group, ('', ''))
        # FIX: the original called .replace('', '') twice — a no-op that left
        # the direction arrow duplicated in the Group column. Strip the ↑/↓
        # arrows since the Direction column already conveys them.
        group_name = group.replace('↑', '').replace('↓', '').strip()

        html += '<tr>'
        html += f'<td style="border: 1px solid #ddd; padding: 8px;"><div style="width: 30px; height: 20px; background-color: {color}; border: 1px solid #999;"></div></td>'
        html += f'<td style="border: 1px solid #ddd; padding: 8px;"><strong>{group_name}</strong></td>'
        html += f'<td style="border: 1px solid #ddd; padding: 8px;">{metrics}</td>'
        html += f'<td style="border: 1px solid #ddd; padding: 8px;">{direction}</td>'
        html += '</tr>'

    html += '</tbody></table>'
    return html
359
 
 
360
  def gradio_interface() -> gr.Blocks:
361
  with gr.Blocks() as demo:
362
  gr.Markdown("""
@@ -368,110 +361,109 @@ Generative machine learning models hold great promise for accelerating materials
368
 
369
  📄 **Paper**: [arXiv](https://arxiv.org/abs/2512.04562) | 💻 **Code**: [GitHub](https://github.com/LeMaterial/lemat-genbench) | 📧 **Contact**: siddharth.betala [at] entalpic.ai, alexandre.duval [at] entalpic.ai
370
  """)
 
371
  with gr.Tabs(elem_classes="tab-buttons"):
372
  with gr.TabItem("🚀 Leaderboard", elem_id="boundary-benchmark-tab-table"):
373
  gr.Markdown("# LeMat-GenBench")
374
 
375
- # Display options
376
  with gr.Row():
377
  with gr.Column(scale=1):
378
  compact_view = gr.Checkbox(
379
  value=True,
380
  label="Compact View",
381
- info="Show only key metrics"
382
  )
383
  show_percentage = gr.Checkbox(
384
  value=True,
385
  label="Show as Percentages",
386
- info="Display count-based metrics as percentages of total structures"
387
  )
 
388
  with gr.Column(scale=1):
389
- # Create choices with display names, but values are the raw column names
390
- sort_choices = ["None"] + [COLUMN_DISPLAY_NAMES.get(col, col) for col in COLUMN_DISPLAY_NAMES.keys()]
391
- # Note: The initial sort is already applied in get_leaderboard() by MSUN+SUN
 
392
  sort_by = gr.Dropdown(
393
  choices=sort_choices,
394
  value="None",
395
  label="Sort By",
396
- info="Select column to sort by (default: sorted by MSUN+SUN descending)"
397
  )
398
  sort_direction = gr.Radio(
399
  choices=["Ascending", "Descending"],
400
  value="Descending",
401
- label="Sort Direction"
402
  )
 
403
  with gr.Column(scale=1):
404
  training_set_filter = gr.Dropdown(
405
  choices=["All"] + TRAINING_DATASETS,
406
  value="MP-20",
407
  label="Filter by Training Set",
408
- info="Show only models trained on a specific dataset"
409
  )
 
410
  with gr.Column(scale=2):
411
  selected_groups = gr.CheckboxGroup(
412
  choices=list(METRIC_GROUPS.keys()),
413
  value=list(METRIC_GROUPS.keys()),
414
  label="Metric Families (only active when Compact View is off)",
415
- info="Select which metric groups to display"
416
  )
417
 
418
- # Metric legend with color coding
419
  with gr.Accordion("Metric Groups Legend", open=False):
420
  gr.HTML(generate_metric_legend_html())
421
 
422
  try:
423
- # Initial dataframe - load once and cache
424
- debug_csv_schemas()
425
  initial_df = get_leaderboard()
426
  cached_df_state = gr.State(initial_df)
427
 
428
- ALWAYS_SHOW_MODELS = {'AFLOW', 'Alexandria', 'OQMD'}
429
- formatted_df = format_dataframe(initial_df[
430
- initial_df['training_set'].apply(lambda x: 'MP-20' in parse_training_set(x)) |
431
- initial_df['model_name'].isin(ALWAYS_SHOW_MODELS)
432
- ], show_percentage=True, selected_groups=list(METRIC_GROUPS.keys()), compact_view=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
433
 
434
  leaderboard_table = gr.Dataframe(
435
  label="GenBench Leaderboard",
436
  value=formatted_df,
437
  interactive=False,
438
  wrap=True,
439
- datatype=["html"] + [None] * (len(formatted_df.columns) - 1) if len(formatted_df.columns) > 0 else None,
440
- column_widths=["180px"] + ["160px"] * (len(formatted_df.columns) - 1) if len(formatted_df.columns) > 0 else None,
441
- show_fullscreen_button=True
442
  )
443
 
444
- # Update dataframe when options change (using cached data)
445
- show_percentage.change(
446
- fn=update_leaderboard,
447
- inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
448
- outputs=leaderboard_table
449
- )
450
- selected_groups.change(
451
- fn=update_leaderboard,
452
- inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
453
- outputs=leaderboard_table
454
- )
455
- compact_view.change(
456
- fn=update_leaderboard,
457
- inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
458
- outputs=leaderboard_table
459
- )
460
- sort_by.change(
461
- fn=update_leaderboard,
462
- inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
463
- outputs=leaderboard_table
464
- )
465
- sort_direction.change(
466
- fn=update_leaderboard,
467
- inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
468
- outputs=leaderboard_table
469
- )
470
- training_set_filter.change(
471
- fn=update_leaderboard,
472
- inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
473
- outputs=leaderboard_table
474
- )
475
 
476
  except Exception as e:
477
  traceback.print_exc()
@@ -488,16 +480,14 @@ Generative machine learning models hold great promise for accelerating materials
488
  Verified submissions mean the results came from a model submission rather than a CIF submission.
489
 
490
  Models marked as baselines appear below the separator line at the bottom of the table.
491
- """)
492
 
493
  with gr.TabItem("✉️ Submit", elem_id="boundary-benchmark-tab-table"):
494
- gr.Markdown(
495
- """
496
- # Materials Submission
497
- Upload a ZIP of CIFs with your structures. To ensure eligibility for the leaderboard, please provide exactly 2,500 representative structures.
498
- """
499
- )
500
- filename = gr.State(value=None)
501
 
502
  gr.LoginButton()
503
 
@@ -506,79 +496,91 @@ Models marked as baselines appear below the separator line at the bottom of the
506
  model_name_input = gr.Textbox(
507
  label="Model Name",
508
  placeholder="Enter your model name",
509
- info="Provide a name for your model/method"
510
  )
511
  email_input = gr.Textbox(
512
  label="Email Address",
513
  placeholder="Enter your email address",
514
- info="Contact email for correspondence about this submission"
515
  )
516
  paper_link_input = gr.Textbox(
517
  label="Paper Link (optional)",
518
  placeholder="https://arxiv.org/abs/...",
519
- info="Link to the paper describing your model/method"
520
  )
521
  hf_model_link_input = gr.Textbox(
522
  label="HuggingFace Model Link (optional)",
523
  placeholder="https://huggingface.co/...",
524
- info="Link to your model on HuggingFace"
525
  )
526
  problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
 
527
  with gr.Column():
528
  cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
529
  relaxed = gr.Checkbox(
530
  value=False,
531
  label="Structures are pre-relaxed",
532
- info="Check this box if your submitted structures have already been relaxed"
533
  )
534
  relaxation_settings_input = gr.Textbox(
535
  label="Relaxation Settings",
536
  placeholder="e.g., VASP PBE, 520 eV cutoff, ...",
537
  info="Describe the relaxation settings used",
538
- visible=False
539
  )
540
  training_dataset_input = gr.Dropdown(
541
  choices=TRAINING_DATASETS,
542
  label="Training Dataset",
543
  info="Select all datasets used for training",
544
- multiselect=True
545
  )
546
  training_dataset_other_input = gr.Textbox(
547
  label="Other Training Dataset",
548
  placeholder="Specify your training dataset",
549
  info="Provide details if you selected 'Others (must specify)'",
550
- visible=False
551
  )
552
 
553
- # Show/hide relaxation settings based on pre-relaxed checkbox
554
  relaxed.change(
555
  fn=lambda x: gr.update(visible=x),
556
  inputs=[relaxed],
557
- outputs=[relaxation_settings_input]
558
  )
559
 
560
- # Show/hide other dataset text box based on dropdown selection
561
  training_dataset_input.change(
562
  fn=lambda x: gr.update(visible="Others (must specify)" in (x or [])),
563
  inputs=[training_dataset_input],
564
- outputs=[training_dataset_other_input]
565
  )
566
 
567
  submit_btn = gr.Button("Submission")
568
  message = gr.Textbox(label="Status", lines=1, visible=False)
569
- # help message
570
- gr.Markdown("If you have issues with submission or using the leaderboard, please start a discussion in the Community tab of this Space.")
571
-
 
 
572
  submit_btn.click(
573
  submit_cif_files,
574
- inputs=[model_name_input, problem_type, cif_file, relaxed, relaxation_settings_input, training_dataset_input, training_dataset_other_input, paper_link_input, hf_model_link_input, email_input],
 
 
 
 
 
 
 
 
 
 
 
575
  outputs=[message, filename],
576
  ).then(
577
  fn=show_output_box,
578
  inputs=[message],
579
  outputs=[message],
580
  )
581
-
582
  return demo
583
 
584
 
 
1
  from pathlib import Path
2
  import ast
3
  import json
4
+ import traceback
5
  import pandas as pd
6
  import numpy as np
 
7
 
8
  import gradio as gr
9
+ from datasets import Features, Value, load_dataset
10
  from datetime import datetime
11
  import os
12
+
 
13
  from about import (
14
  PROBLEM_TYPES, TOKEN, CACHE_PATH, API, submissions_repo, results_repo,
15
  COLUMN_DISPLAY_NAMES, COUNT_BASED_METRICS, METRIC_GROUPS,
16
  METRIC_GROUP_COLORS, COLUMN_TO_GROUP, TRAINING_DATASETS
17
  )
18
 
19
+
20
# Explicit schema for the results CSVs so every file is parsed with
# consistent dtypes (avoids per-file dtype-inference drift, e.g. for
# sparsely-populated string columns such as paper_link / notes).
RESULT_FEATURES = Features({
    # Run identification
    "run_name": Value("string"),
    "timestamp": Value("string"),
    # Structure counts / validity
    "n_structures": Value("int64"),
    "overall_valid_count": Value("int64"),
    "charge_neutral_count": Value("int64"),
    "distance_valid_count": Value("int64"),
    "plausibility_valid_count": Value("int64"),
    # Uniqueness & novelty
    "unique_count": Value("int64"),
    "novel_count": Value("int64"),
    # Energy metrics
    "mean_formation_energy": Value("float64"),
    "formation_energy_std": Value("float64"),
    "stability_mean_above_hull": Value("float64"),
    "stability_std_e_above_hull": Value("float64"),
    "stability_mean_ensemble_std": Value("float64"),
    "mean_relaxation_RMSD": Value("float64"),
    "relaxation_RMSE_std": Value("float64"),
    # Stability / metastability counts
    "stable_count": Value("int64"),
    "unique_in_stable_count": Value("int64"),
    "sun_count": Value("int64"),
    "metastable_count": Value("int64"),
    "unique_in_metastable_count": Value("int64"),
    "msun_count": Value("int64"),
    # Distribution distances vs. the reference dataset
    "JSDistance": Value("float64"),
    "MMD": Value("float64"),
    "FrechetDistance": Value("float64"),
    # Diversity metrics
    "element_diversity": Value("float64"),
    "space_group_diversity": Value("float64"),
    "site_diversity": Value("float64"),
    "physical_size_diversity": Value("float64"),
    # HHI supply-risk indices
    "hhi_production_mean": Value("float64"),
    "hhi_reserve_mean": Value("float64"),
    "hhi_combined_mean": Value("float64"),
    # Submission metadata
    "model_name": Value("string"),
    "relaxed": Value("bool"),
    "training_set": Value("string"),
    "paper_link": Value("string"),
    "notes": Value("string"),
})
59
+
60
+
61
def get_leaderboard():
    """Fetch the results CSVs with an explicit schema and return a sorted frame.

    Loads every CSV in the results repo under RESULT_FEATURES so dtypes are
    stable across files, computes the combined MSUN+SUN column, and sorts
    by it descending. Returns an empty frame with the schema's columns
    when no rows exist.
    """
    # force_redownload keeps the board in sync with freshly uploaded results.
    ds = load_dataset(
        results_repo,
        data_files="*.csv",
        split="train",
        download_mode="force_redownload",
        features=RESULT_FEATURES,
    )
    full_df = pd.DataFrame(ds)

    if full_df.empty:
        return pd.DataFrame(columns=list(RESULT_FEATURES.keys()))

    # Derived column used as the default leaderboard ordering.
    if "msun_count" in full_df.columns and "sun_count" in full_df.columns:
        full_df["msun_plus_sun"] = full_df["msun_count"] + full_df["sun_count"]

    if "msun_plus_sun" in full_df.columns:
        full_df = full_df.sort_values(by="msun_plus_sun", ascending=False)

    return full_df
81
+
82
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
    """Format the dataframe with proper column names and optional percentages.

    Args:
        df: raw leaderboard dataframe (one row per model run).
        show_percentage: render count-based metrics as percentages of
            `n_structures` instead of raw counts.
        selected_groups: metric-group names to display (non-compact view
            only); falsy means "all groups".
        compact_view: use the predefined COMPACT_VIEW_COLUMNS subset.

    Returns:
        A pandas Styler (via apply_color_styling) with HTML model links,
        group colors, and baselines moved below a separator row.
    """
    if len(df) == 0:
        return df

    selected_cols = ["model_name"]

    if compact_view:
        from about import COMPACT_VIEW_COLUMNS
        selected_cols = [col for col in COMPACT_VIEW_COLUMNS if col in df.columns]
    else:
        if "training_set" in df.columns:
            selected_cols.append("training_set")
        if "n_structures" in df.columns:
            selected_cols.append("n_structures")

        # NOTE(review): group-selection logic assumed to sit inside the
        # non-compact branch (matches the UI note that metric families are
        # "only active when Compact View is off") — confirm indentation.
        if not selected_groups:
            selected_groups = list(METRIC_GROUPS.keys())

        # Append columns from each selected metric group, preserving order.
        for group in selected_groups:
            if group in METRIC_GROUPS:
                for col in METRIC_GROUPS[group]:
                    if col in df.columns and col not in selected_cols:
                        selected_cols.append(col)

    display_df = df[selected_cols].copy()

    if "model_name" in display_df.columns:
        # name -> HF model page used to hyperlink the model name.
        model_links = {
            "CrystaLLM-pi": "https://huggingface.co/c-bone/CrystaLLM-pi_base",
            "OMatG": "https://huggingface.co/OMatG/MP-20-DNG/tree/main/EncDec-ODE-Gamma",
        }

        def add_model_symbols(row):
            # Decorate the model name with paper link, relaxed flag and badges.
            name = row["model_name"]
            symbols = []

            # 📄 paper link (only when column exists and value is non-empty)
            if "paper_link" in df.columns:
                paper_val = row.get("paper_link", None)
                if paper_val and isinstance(paper_val, str) and paper_val.strip():
                    symbols.append(f'<a href="{paper_val.strip()}" target="_blank">📄</a>')

            # ⚡ structures were submitted pre-relaxed
            if "relaxed" in df.columns and row.get("relaxed", False):
                symbols.append("⚡")

            # ★ in-distribution reference baselines; ◆ out-of-distribution
            # baseline; ✅ verified model submissions
            if name in ["Alexandria", "OQMD"]:
                symbols.append("★")
            elif name == "AFLOW":
                symbols.append("◆")
            elif name in ["CrystaLLM-pi", "OMatG", "Zatom-1-WD"]:
                symbols.append("✅")

            symbol_str = f" {' '.join(symbols)}" if symbols else ""

            if name in model_links:
                return f'<a href="{model_links[name]}" target="_blank">{name}</a>{symbol_str}'
            return f"{name}{symbol_str}"

        display_df["model_name"] = df.apply(add_model_symbols, axis=1)

    if "training_set" in display_df.columns:
        def format_training_set(val):
            # Turn "['MP-20']"-style strings into plain "MP-20" text.
            if val is None or (isinstance(val, float) and np.isnan(val)):
                return ""
            val = str(val).strip()
            if val in ("[]", "", "nan", "None"):
                return ""
            # Strip brackets and quotes for list-like strings.
            val = val.strip("[]")
            val = val.replace("'", "").replace('"', "")
            return val

        display_df["training_set"] = display_df["training_set"].apply(format_training_set)

    # Convert count-based metrics to percentages if requested.
    if show_percentage and "n_structures" in df.columns:
        n_structures = df["n_structures"]
        for col in COUNT_BASED_METRICS:
            if col in display_df.columns:
                # Becomes a string column, so the float rounding below skips it.
                display_df[col] = (df[col] / n_structures * 100).round(1).astype(str) + "%"

    # Round remaining numeric columns for cleaner display.
    for col in display_df.columns:
        if display_df[col].dtype in ["float64", "float32"]:
            display_df[col] = display_df[col].round(4)

    # Move baseline models to the bottom of the table.
    baseline_indices = set()
    if "notes" in df.columns:
        is_baseline = df["notes"].fillna("").str.contains("baseline", case=False, na=False)
        non_baseline_df = display_df[~is_baseline]
        baseline_df = display_df[is_baseline]
        display_df = pd.concat([non_baseline_df, baseline_df]).reset_index(drop=True)
        # Track baseline row indices in the re-indexed dataframe.
        baseline_indices = set(range(len(non_baseline_df), len(display_df)))

    display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)
    return apply_color_styling(display_df, selected_cols, baseline_indices)
176
 
 
 
 
 
177
 
178
def apply_color_styling(display_df, original_cols, baseline_indices=None):
    """Apply background colors to dataframe based on metric groups using pandas Styler.

    Args:
        display_df: frame whose columns were already renamed for display.
        original_cols: raw column names, positionally aligned with
            display_df's columns, used to look up each column's group.
        baseline_indices: row positions of baseline models (bottom of the
            table); the first one gets a thick top border as a separator.
    """
    # NOTE(review): this None-guard is reconstructed from diff context
    # (the elided line 180) — confirm against the original file.
    if baseline_indices is None:
        baseline_indices = set()

    def style_by_group(x):
        # Same-shaped frame of CSS strings; empty string = no styling.
        styles = pd.DataFrame("", index=x.index, columns=x.columns)

        # Map display column names back to raw names by position.
        for i, display_col in enumerate(x.columns):
            if i < len(original_cols):
                original_col = original_cols[i]
                if original_col in COLUMN_TO_GROUP:
                    group = COLUMN_TO_GROUP[original_col]
                    color = METRIC_GROUP_COLORS.get(group, "")
                    if color:
                        styles[display_col] = f"background-color: {color}"

        # Thick top border on the first baseline row acts as a visual separator.
        if baseline_indices:
            first_baseline_idx = min(baseline_indices)
            for col in x.columns:
                current = styles.at[first_baseline_idx, col]
                separator_style = "border-top: 3px solid #555"
                styles.at[first_baseline_idx, col] = (
                    f"{current}; {separator_style}" if current else separator_style
                )

        return styles

    return display_df.style.apply(style_by_group, axis=None)
207
 
208
+
209
def parse_training_set(val):
    """Parse a training_set value stored as a string like "['MP-20']" into a list."""
    text = str(val)
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # Non-literal / malformed text means "no training sets recorded".
        return []
215
 
 
 
216
 
217
def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction, training_set_filter):
    """Update the leaderboard based on user selections.

    Operates on the cached dataframe (gr.State) so option changes never
    re-download the results dataset.

    Args:
        show_percentage: forwarded to format_dataframe.
        selected_groups: metric-group names forwarded to format_dataframe.
        compact_view: forwarded to format_dataframe.
        cached_df: the full leaderboard frame held in gr.State.
        sort_by: display column name to sort by, or "None".
        sort_direction: "Ascending" or "Descending".
        training_set_filter: training-set name, or "All" for no filter.

    Returns:
        The formatted (styled) dataframe for the gr.Dataframe component.
    """
    df_to_format = cached_df.copy()

    # Baselines are always shown regardless of the training-set filter.
    ALWAYS_SHOW_MODELS = {"AFLOW", "Alexandria", "OQMD"}
    if training_set_filter and training_set_filter != "All" and "training_set" in df_to_format.columns:
        mask = (
            df_to_format["training_set"].apply(lambda x: training_set_filter in parse_training_set(x))
            | df_to_format["model_name"].isin(ALWAYS_SHOW_MODELS)
        )
        df_to_format = df_to_format[mask]

    # Map the display column name back to the raw column name before sorting.
    if sort_by and sort_by != "None":
        display_to_raw = {v: k for k, v in COLUMN_DISPLAY_NAMES.items()}
        raw_column_name = display_to_raw.get(sort_by, sort_by)

        if raw_column_name in df_to_format.columns:
            ascending = sort_direction == "Ascending"
            df_to_format = df_to_format.sort_values(by=raw_column_name, ascending=ascending)

    return format_dataframe(df_to_format, show_percentage, selected_groups, compact_view)
238
+
239
 
240
def show_output_box(message):
    """Reveal the status textbox and fill it with *message*."""
    update_kwargs = {"value": message, "visible": True}
    return gr.update(**update_kwargs)
242
 
243
+
244
  def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_settings, training_datasets, training_dataset_other, paper_link, hf_model_link, email, profile: gr.OAuthProfile | None):
245
  """Submit structures to the leaderboard."""
246
  from huggingface_hub import upload_file
247
 
 
248
  if not model_name or not model_name.strip():
249
  return "Error: Please provide a model name.", None
250
 
 
264
  username = profile.username
265
  timestamp = datetime.now().isoformat()
266
 
 
267
  submission_data = {
268
  "username": username,
269
  "model_name": model_name.strip(),
 
276
  "hf_model_link": hf_model_link.strip() if hf_model_link else None,
277
  "email": email.strip(),
278
  "timestamp": timestamp,
279
+ "file_name": Path(cif_files).name,
280
  }
281
 
 
282
  submission_id = f"{username}_{model_name.strip().replace(' ', '_')}_{timestamp.replace(':', '-')}"
 
 
283
  file_path = Path(cif_files)
284
  uploaded_file_path = f"submissions/{submission_id}/{file_path.name}"
285
 
 
288
  path_in_repo=uploaded_file_path,
289
  repo_id=submissions_repo,
290
  token=TOKEN,
291
+ repo_type="dataset",
292
  )
293
 
 
294
  metadata_path = f"submissions/{submission_id}/metadata.json"
295
  import tempfile
296
+
297
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
298
  json.dump(submission_data, f, indent=2)
299
  temp_metadata_path = f.name
300
 
 
303
  path_in_repo=metadata_path,
304
  repo_id=submissions_repo,
305
  token=TOKEN,
306
+ repo_type="dataset",
307
  )
308
 
 
309
  os.unlink(temp_metadata_path)
310
 
311
  return f"Success! Submitted {model_name} for {problem_type} evaluation. Submission ID: {submission_id}", submission_id
312
 
313
  except Exception as e:
314
+ return f"Error during submission: {str(e)}", None
315
+
316
 
317
def generate_metric_legend_html():
    """Generate an HTML table with a color-coded metric-group legend.

    Returns:
        str: An HTML ``<table>`` mapping each metric-group color swatch to
        its group name, member metrics, and optimization direction.
    """
    # Per-group (metrics, direction) descriptions; keys must match the keys
    # of METRIC_GROUP_COLORS (including the trailing direction arrow).
    metric_details = {
        "Validity ↑": ("Valid, Charge Neutral, Distance Valid, Plausibility Valid", "↑ Higher is better"),
        "Uniqueness & Novelty ↑": ("Unique, Novel", "↑ Higher is better"),
        "Energy Metrics ↓": ("E Above Hull, Formation Energy, Relaxation RMSD (with std)", "↓ Lower is better"),
        "Stability ↑": ("Stable, Unique in Stable, SUN", "↑ Higher is better"),
        "Metastability ↑": ("Metastable, Unique in Metastable, MSUN", "↑ Higher is better"),
        "Distribution ↓": ("JS Distance, MMD, FID", "↓ Lower is better"),
        "Diversity ↑": ("Element, Space Group, Atomic Site, Crystal Size", "↑ Higher is better"),
        "HHI ↓": ("HHI Production, HHI Reserve", "↓ Lower is better"),
    }

    html = '<table style="width: 100%; border-collapse: collapse;">'
    html += "<thead><tr>"
    html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Color</th>'
    html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Group</th>'
    html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Metrics</th>'
    html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Direction</th>'
    html += "</tr></thead><tbody>"

    for group, color in METRIC_GROUP_COLORS.items():
        metrics, direction = metric_details.get(group, ("", ""))
        # Strip the direction arrows so the Group column shows the bare name;
        # direction is reported in its own column. (The previous code called
        # str.replace with empty strings, which is a no-op and left the arrow
        # duplicated in the Group column.)
        group_name = group.replace("↑", "").replace("↓", "").strip()

        html += "<tr>"
        html += f'<td style="border: 1px solid #ddd; padding: 8px;"><div style="width: 30px; height: 20px; background-color: {color}; border: 1px solid #999;"></div></td>'
        html += f'<td style="border: 1px solid #ddd; padding: 8px;"><strong>{group_name}</strong></td>'
        html += f'<td style="border: 1px solid #ddd; padding: 8px;">{metrics}</td>'
        html += f'<td style="border: 1px solid #ddd; padding: 8px;">{direction}</td>'
        html += "</tr>"

    html += "</tbody></table>"
    return html
351
 
352
+
353
  def gradio_interface() -> gr.Blocks:
354
  with gr.Blocks() as demo:
355
  gr.Markdown("""
 
361
 
362
  📄 **Paper**: [arXiv](https://arxiv.org/abs/2512.04562) | 💻 **Code**: [GitHub](https://github.com/LeMaterial/lemat-genbench) | 📧 **Contact**: siddharth.betala [at] entalpic.ai, alexandre.duval [at] entalpic.ai
363
  """)
364
+
365
  with gr.Tabs(elem_classes="tab-buttons"):
366
  with gr.TabItem("🚀 Leaderboard", elem_id="boundary-benchmark-tab-table"):
367
  gr.Markdown("# LeMat-GenBench")
368
 
 
369
  with gr.Row():
370
  with gr.Column(scale=1):
371
  compact_view = gr.Checkbox(
372
  value=True,
373
  label="Compact View",
374
+ info="Show only key metrics",
375
  )
376
  show_percentage = gr.Checkbox(
377
  value=True,
378
  label="Show as Percentages",
379
+ info="Display count-based metrics as percentages of total structures",
380
  )
381
+
382
  with gr.Column(scale=1):
383
+ sort_choices = ["None"] + [
384
+ COLUMN_DISPLAY_NAMES.get(col, col)
385
+ for col in COLUMN_DISPLAY_NAMES.keys()
386
+ ]
387
  sort_by = gr.Dropdown(
388
  choices=sort_choices,
389
  value="None",
390
  label="Sort By",
391
+ info="Select column to sort by (default: sorted by MSUN+SUN descending)",
392
  )
393
  sort_direction = gr.Radio(
394
  choices=["Ascending", "Descending"],
395
  value="Descending",
396
+ label="Sort Direction",
397
  )
398
+
399
  with gr.Column(scale=1):
400
  training_set_filter = gr.Dropdown(
401
  choices=["All"] + TRAINING_DATASETS,
402
  value="MP-20",
403
  label="Filter by Training Set",
404
+ info="Show only models trained on a specific dataset",
405
  )
406
+
407
  with gr.Column(scale=2):
408
  selected_groups = gr.CheckboxGroup(
409
  choices=list(METRIC_GROUPS.keys()),
410
  value=list(METRIC_GROUPS.keys()),
411
  label="Metric Families (only active when Compact View is off)",
412
+ info="Select which metric groups to display",
413
  )
414
 
 
415
  with gr.Accordion("Metric Groups Legend", open=False):
416
  gr.HTML(generate_metric_legend_html())
417
 
418
  try:
 
 
419
  initial_df = get_leaderboard()
420
  cached_df_state = gr.State(initial_df)
421
 
422
+ ALWAYS_SHOW_MODELS = {"AFLOW", "Alexandria", "OQMD"}
423
+ filtered_initial_df = initial_df[
424
+ initial_df["training_set"].apply(lambda x: "MP-20" in parse_training_set(x))
425
+ | initial_df["model_name"].isin(ALWAYS_SHOW_MODELS)
426
+ ]
427
+
428
+ formatted_df = format_dataframe(
429
+ filtered_initial_df,
430
+ show_percentage=True,
431
+ selected_groups=list(METRIC_GROUPS.keys()),
432
+ compact_view=True,
433
+ )
434
+
435
+ formatted_columns = (
436
+ list(formatted_df.data.columns)
437
+ if hasattr(formatted_df, "data")
438
+ else list(formatted_df.columns)
439
+ )
440
 
441
  leaderboard_table = gr.Dataframe(
442
  label="GenBench Leaderboard",
443
  value=formatted_df,
444
  interactive=False,
445
  wrap=True,
446
+ datatype=["html"] + [None] * (len(formatted_columns) - 1) if formatted_columns else None,
447
+ column_widths=["180px"] + ["160px"] * (len(formatted_columns) - 1) if formatted_columns else None,
448
+ show_fullscreen_button=True,
449
  )
450
 
451
+ inputs = [
452
+ show_percentage,
453
+ selected_groups,
454
+ compact_view,
455
+ cached_df_state,
456
+ sort_by,
457
+ sort_direction,
458
+ training_set_filter,
459
+ ]
460
+
461
+ show_percentage.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
462
+ selected_groups.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
463
+ compact_view.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
464
+ sort_by.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
465
+ sort_direction.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
466
+ training_set_filter.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
 
468
  except Exception as e:
469
  traceback.print_exc()
 
480
  Verified submissions mean the results came from a model submission rather than a CIF submission.
481
 
482
  Models marked as baselines appear below the separator line at the bottom of the table.
483
+ """)
484
 
485
  with gr.TabItem("✉️ Submit", elem_id="boundary-benchmark-tab-table"):
486
+ gr.Markdown("""
487
+ # Materials Submission
488
+ Upload a ZIP of CIFs with your structures. To ensure eligibility for the leaderboard, please provide exactly 2,500 representative structures.
489
+ """)
490
+ filename = gr.State(value=None)
 
 
491
 
492
  gr.LoginButton()
493
 
 
496
  model_name_input = gr.Textbox(
497
  label="Model Name",
498
  placeholder="Enter your model name",
499
+ info="Provide a name for your model/method",
500
  )
501
  email_input = gr.Textbox(
502
  label="Email Address",
503
  placeholder="Enter your email address",
504
+ info="Contact email for correspondence about this submission",
505
  )
506
  paper_link_input = gr.Textbox(
507
  label="Paper Link (optional)",
508
  placeholder="https://arxiv.org/abs/...",
509
+ info="Link to the paper describing your model/method",
510
  )
511
  hf_model_link_input = gr.Textbox(
512
  label="HuggingFace Model Link (optional)",
513
  placeholder="https://huggingface.co/...",
514
+ info="Link to your model on HuggingFace",
515
  )
516
  problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
517
+
518
  with gr.Column():
519
  cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
520
  relaxed = gr.Checkbox(
521
  value=False,
522
  label="Structures are pre-relaxed",
523
+ info="Check this box if your submitted structures have already been relaxed",
524
  )
525
  relaxation_settings_input = gr.Textbox(
526
  label="Relaxation Settings",
527
  placeholder="e.g., VASP PBE, 520 eV cutoff, ...",
528
  info="Describe the relaxation settings used",
529
+ visible=False,
530
  )
531
  training_dataset_input = gr.Dropdown(
532
  choices=TRAINING_DATASETS,
533
  label="Training Dataset",
534
  info="Select all datasets used for training",
535
+ multiselect=True,
536
  )
537
  training_dataset_other_input = gr.Textbox(
538
  label="Other Training Dataset",
539
  placeholder="Specify your training dataset",
540
  info="Provide details if you selected 'Others (must specify)'",
541
+ visible=False,
542
  )
543
 
 
544
  relaxed.change(
545
  fn=lambda x: gr.update(visible=x),
546
  inputs=[relaxed],
547
+ outputs=[relaxation_settings_input],
548
  )
549
 
 
550
  training_dataset_input.change(
551
  fn=lambda x: gr.update(visible="Others (must specify)" in (x or [])),
552
  inputs=[training_dataset_input],
553
+ outputs=[training_dataset_other_input],
554
  )
555
 
556
  submit_btn = gr.Button("Submission")
557
  message = gr.Textbox(label="Status", lines=1, visible=False)
558
+
559
+ gr.Markdown(
560
+ "If you have issues with submission or using the leaderboard, please start a discussion in the Community tab of this Space."
561
+ )
562
+
563
  submit_btn.click(
564
  submit_cif_files,
565
+ inputs=[
566
+ model_name_input,
567
+ problem_type,
568
+ cif_file,
569
+ relaxed,
570
+ relaxation_settings_input,
571
+ training_dataset_input,
572
+ training_dataset_other_input,
573
+ paper_link_input,
574
+ hf_model_link_input,
575
+ email_input,
576
+ ],
577
  outputs=[message, filename],
578
  ).then(
579
  fn=show_output_box,
580
  inputs=[message],
581
  outputs=[message],
582
  )
583
+
584
  return demo
585
 
586