CompactAI committed on
Commit
fc89df5
Β·
verified Β·
1 Parent(s): c65f6ae

Upload 2 files

Browse files
Files changed (1) hide show
  1. app.py +44 -23
app.py CHANGED
@@ -2891,8 +2891,7 @@ def _chat_stream(history, version, ckpt_label, mode_key, use_custom,
2891
 
2892
  def _compare_fn(prompt, selected_versions, mode_key, use_custom,
2893
  temperature, top_k, top_p, min_p, rep_penalty,
2894
- ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode,
2895
- progress=gr.Progress(track_tqdm=True)):
2896
  if use_custom:
2897
  cfg = {
2898
  "sft_mode": not raw_mode,
@@ -2906,16 +2905,39 @@ def _compare_fn(prompt, selected_versions, mode_key, use_custom,
2906
  cfg = dict(MODES[mode_key])
2907
 
2908
  all_versions = _collection_versions()
2909
- results = {}
2910
- for version in progress.tqdm(selected_versions or [], desc="Running models"):
 
 
 
 
 
 
 
 
 
2911
  labels = _ckpt_labels(version)
2912
  ckpt_label = labels[0] if labels else None
2913
  if not ckpt_label:
2914
- results[version] = "[No checkpoint found]"
 
2915
  continue
 
 
 
 
2916
  try:
2917
  bundle = _load_bundle(version, ckpt_label)
2918
- out = generate(
 
 
 
 
 
 
 
 
 
2919
  model=bundle["model"], tokenizer=bundle["tokenizer"],
2920
  prompt=prompt, device=str(bundle["device"]),
2921
  sft_mode=cfg["sft_mode"],
@@ -2927,14 +2949,12 @@ def _compare_fn(prompt, selected_versions, mode_key, use_custom,
2927
  loop_penalty=cfg["loop_penalty"],
2928
  max_new_tokens=cfg["max_new_tokens"],
2929
  context_window=cfg["context_window"],
2930
- stream=False,
2931
- )
2932
- results[version] = out
2933
  except Exception as e:
2934
- results[version] = f"[Error: {e}]"
2935
-
2936
- # Return one value per discovered version (empty string if not selected/run)
2937
- return [results.get(v, "") for v in all_versions]
2938
 
2939
 
2940
  # ---- benchmark ----
@@ -3090,24 +3110,25 @@ with gr.Blocks(title="CompactAI Models") as demo:
3090
 
3091
  # ── Compare ───────────────────────────────────────────────────────────────
3092
  with gr.Tab("Compare All Models"):
3093
- gr.Markdown("Run the same prompt on multiple models and compare side-by-side.")
3094
  with gr.Row():
3095
- cmp_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt…", lines=3, scale=3)
3096
  with gr.Column(scale=1):
3097
  cmp_models = gr.CheckboxGroup(
3098
- choices=_initial_versions, value=_initial_versions, label="Models"
3099
  )
3100
  cmp_mode = gr.Dropdown(
3101
  choices=_mode_keys, value="chat-coherent", label="Mode preset"
3102
  )
3103
  cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp, cmp_rep, cmp_ng, cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw = _advanced_block()
3104
- cmp_run = gr.Button("Run comparison", variant="primary")
3105
-
3106
- with gr.Row():
3107
- cmp_outputs = [
3108
- gr.Textbox(label=v, lines=8)
3109
- for v in _initial_versions
3110
- ]
 
3111
 
3112
  cmp_run.click(
3113
  _compare_fn,
 
2891
 
2892
  def _compare_fn(prompt, selected_versions, mode_key, use_custom,
2893
  temperature, top_k, top_p, min_p, rep_penalty,
2894
+ ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode):
 
2895
  if use_custom:
2896
  cfg = {
2897
  "sft_mode": not raw_mode,
 
2905
  cfg = dict(MODES[mode_key])
2906
 
2907
  all_versions = _collection_versions()
2908
+ selected = set(selected_versions or [])
2909
+ state = {v: ("⏳ Queued…" if v in selected else "") for v in all_versions}
2910
+
2911
+ def _emit():
2912
+ return [state[v] for v in all_versions]
2913
+
2914
+ yield _emit()
2915
+
2916
+ for version in all_versions:
2917
+ if version not in selected:
2918
+ continue
2919
  labels = _ckpt_labels(version)
2920
  ckpt_label = labels[0] if labels else None
2921
  if not ckpt_label:
2922
+ state[version] = "[No checkpoint found]"
2923
+ yield _emit()
2924
  continue
2925
+
2926
+ state[version] = "⏳ Loading…"
2927
+ yield _emit()
2928
+
2929
  try:
2930
  bundle = _load_bundle(version, ckpt_label)
2931
+ except Exception as e:
2932
+ state[version] = f"[Load error: {e}]"
2933
+ yield _emit()
2934
+ continue
2935
+
2936
+ state[version] = ""
2937
+ yield _emit()
2938
+
2939
+ try:
2940
+ for partial in generate_stream(
2941
  model=bundle["model"], tokenizer=bundle["tokenizer"],
2942
  prompt=prompt, device=str(bundle["device"]),
2943
  sft_mode=cfg["sft_mode"],
 
2949
  loop_penalty=cfg["loop_penalty"],
2950
  max_new_tokens=cfg["max_new_tokens"],
2951
  context_window=cfg["context_window"],
2952
+ ):
2953
+ state[version] = partial
2954
+ yield _emit()
2955
  except Exception as e:
2956
+ state[version] = f"[Generation error: {e}]"
2957
+ yield _emit()
 
 
2958
 
2959
 
2960
  # ---- benchmark ----
 
3110
 
3111
  # ── Compare ───────────────────────────────────────────────────────────────
3112
  with gr.Tab("Compare All Models"):
3113
+ gr.Markdown("Run the same prompt on every selected model. Outputs stream live one model at a time.")
3114
  with gr.Row():
3115
+ cmp_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt…", lines=4, scale=3)
3116
  with gr.Column(scale=1):
3117
  cmp_models = gr.CheckboxGroup(
3118
+ choices=_initial_versions, value=_initial_versions, label="Models to run"
3119
  )
3120
  cmp_mode = gr.Dropdown(
3121
  choices=_mode_keys, value="chat-coherent", label="Mode preset"
3122
  )
3123
  cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp, cmp_rep, cmp_ng, cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw = _advanced_block()
3124
+ cmp_run = gr.Button("β–Ά Run comparison", variant="primary")
3125
+
3126
+ # 2-column grid of output boxes
3127
+ cmp_outputs = []
3128
+ for row_start in range(0, len(_initial_versions), 2):
3129
+ with gr.Row():
3130
+ for v in _initial_versions[row_start:row_start + 2]:
3131
+ cmp_outputs.append(gr.Textbox(label=v, lines=10, interactive=False))
3132
 
3133
  cmp_run.click(
3134
  _compare_fn,