Upload 2 files
Browse files
app.py
CHANGED
|
@@ -2891,8 +2891,7 @@ def _chat_stream(history, version, ckpt_label, mode_key, use_custom,
|
|
| 2891 |
|
| 2892 |
def _compare_fn(prompt, selected_versions, mode_key, use_custom,
|
| 2893 |
temperature, top_k, top_p, min_p, rep_penalty,
|
| 2894 |
-
ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode
|
| 2895 |
-
progress=gr.Progress(track_tqdm=True)):
|
| 2896 |
if use_custom:
|
| 2897 |
cfg = {
|
| 2898 |
"sft_mode": not raw_mode,
|
|
@@ -2906,16 +2905,39 @@ def _compare_fn(prompt, selected_versions, mode_key, use_custom,
|
|
| 2906 |
cfg = dict(MODES[mode_key])
|
| 2907 |
|
| 2908 |
all_versions = _collection_versions()
|
| 2909 |
-
|
| 2910 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2911 |
labels = _ckpt_labels(version)
|
| 2912 |
ckpt_label = labels[0] if labels else None
|
| 2913 |
if not ckpt_label:
|
| 2914 |
-
|
|
|
|
| 2915 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2916 |
try:
|
| 2917 |
bundle = _load_bundle(version, ckpt_label)
|
| 2918 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2919 |
model=bundle["model"], tokenizer=bundle["tokenizer"],
|
| 2920 |
prompt=prompt, device=str(bundle["device"]),
|
| 2921 |
sft_mode=cfg["sft_mode"],
|
|
@@ -2927,14 +2949,12 @@ def _compare_fn(prompt, selected_versions, mode_key, use_custom,
|
|
| 2927 |
loop_penalty=cfg["loop_penalty"],
|
| 2928 |
max_new_tokens=cfg["max_new_tokens"],
|
| 2929 |
context_window=cfg["context_window"],
|
| 2930 |
-
|
| 2931 |
-
|
| 2932 |
-
|
| 2933 |
except Exception as e:
|
| 2934 |
-
|
| 2935 |
-
|
| 2936 |
-
# Return one value per discovered version (empty string if not selected/run)
|
| 2937 |
-
return [results.get(v, "") for v in all_versions]
|
| 2938 |
|
| 2939 |
|
| 2940 |
# ---- benchmark ----
|
|
@@ -3090,24 +3110,25 @@ with gr.Blocks(title="CompactAI Models") as demo:
|
|
| 3090 |
|
| 3091 |
# ── Compare ───────────────────────────────────────────────────────────────
|
| 3092 |
with gr.Tab("Compare All Models"):
|
| 3093 |
-
gr.Markdown("Run the same prompt on
|
| 3094 |
with gr.Row():
|
| 3095 |
-
cmp_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt…", lines=
|
| 3096 |
with gr.Column(scale=1):
|
| 3097 |
cmp_models = gr.CheckboxGroup(
|
| 3098 |
-
choices=_initial_versions, value=_initial_versions, label="Models"
|
| 3099 |
)
|
| 3100 |
cmp_mode = gr.Dropdown(
|
| 3101 |
choices=_mode_keys, value="chat-coherent", label="Mode preset"
|
| 3102 |
)
|
| 3103 |
cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp, cmp_rep, cmp_ng, cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw = _advanced_block()
|
| 3104 |
-
cmp_run = gr.Button("Run comparison", variant="primary")
|
| 3105 |
-
|
| 3106 |
-
|
| 3107 |
-
|
| 3108 |
-
|
| 3109 |
-
|
| 3110 |
-
|
|
|
|
| 3111 |
|
| 3112 |
cmp_run.click(
|
| 3113 |
_compare_fn,
|
|
|
|
| 2891 |
|
| 2892 |
def _compare_fn(prompt, selected_versions, mode_key, use_custom,
|
| 2893 |
temperature, top_k, top_p, min_p, rep_penalty,
|
| 2894 |
+
ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode):
|
|
|
|
| 2895 |
if use_custom:
|
| 2896 |
cfg = {
|
| 2897 |
"sft_mode": not raw_mode,
|
|
|
|
| 2905 |
cfg = dict(MODES[mode_key])
|
| 2906 |
|
| 2907 |
all_versions = _collection_versions()
|
| 2908 |
+
selected = set(selected_versions or [])
|
| 2909 |
+
state = {v: ("⏳ Queued…" if v in selected else "") for v in all_versions}
|
| 2910 |
+
|
| 2911 |
+
def _emit():
|
| 2912 |
+
return [state[v] for v in all_versions]
|
| 2913 |
+
|
| 2914 |
+
yield _emit()
|
| 2915 |
+
|
| 2916 |
+
for version in all_versions:
|
| 2917 |
+
if version not in selected:
|
| 2918 |
+
continue
|
| 2919 |
labels = _ckpt_labels(version)
|
| 2920 |
ckpt_label = labels[0] if labels else None
|
| 2921 |
if not ckpt_label:
|
| 2922 |
+
state[version] = "[No checkpoint found]"
|
| 2923 |
+
yield _emit()
|
| 2924 |
continue
|
| 2925 |
+
|
| 2926 |
+
state[version] = "⏳ Loading…"
|
| 2927 |
+
yield _emit()
|
| 2928 |
+
|
| 2929 |
try:
|
| 2930 |
bundle = _load_bundle(version, ckpt_label)
|
| 2931 |
+
except Exception as e:
|
| 2932 |
+
state[version] = f"[Load error: {e}]"
|
| 2933 |
+
yield _emit()
|
| 2934 |
+
continue
|
| 2935 |
+
|
| 2936 |
+
state[version] = ""
|
| 2937 |
+
yield _emit()
|
| 2938 |
+
|
| 2939 |
+
try:
|
| 2940 |
+
for partial in generate_stream(
|
| 2941 |
model=bundle["model"], tokenizer=bundle["tokenizer"],
|
| 2942 |
prompt=prompt, device=str(bundle["device"]),
|
| 2943 |
sft_mode=cfg["sft_mode"],
|
|
|
|
| 2949 |
loop_penalty=cfg["loop_penalty"],
|
| 2950 |
max_new_tokens=cfg["max_new_tokens"],
|
| 2951 |
context_window=cfg["context_window"],
|
| 2952 |
+
):
|
| 2953 |
+
state[version] = partial
|
| 2954 |
+
yield _emit()
|
| 2955 |
except Exception as e:
|
| 2956 |
+
state[version] = f"[Generation error: {e}]"
|
| 2957 |
+
yield _emit()
|
|
|
|
|
|
|
| 2958 |
|
| 2959 |
|
| 2960 |
# ---- benchmark ----
|
|
|
|
| 3110 |
|
| 3111 |
# ── Compare ───────────────────────────────────────────────────────────────
|
| 3112 |
with gr.Tab("Compare All Models"):
|
| 3113 |
+
gr.Markdown("Run the same prompt on every selected model. Outputs stream live one model at a time.")
|
| 3114 |
with gr.Row():
|
| 3115 |
+
cmp_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt…", lines=4, scale=3)
|
| 3116 |
with gr.Column(scale=1):
|
| 3117 |
cmp_models = gr.CheckboxGroup(
|
| 3118 |
+
choices=_initial_versions, value=_initial_versions, label="Models to run"
|
| 3119 |
)
|
| 3120 |
cmp_mode = gr.Dropdown(
|
| 3121 |
choices=_mode_keys, value="chat-coherent", label="Mode preset"
|
| 3122 |
)
|
| 3123 |
cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp, cmp_rep, cmp_ng, cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw = _advanced_block()
|
| 3124 |
+
cmp_run = gr.Button("▶ Run comparison", variant="primary")
|
| 3125 |
+
|
| 3126 |
+
# 2-column grid of output boxes
|
| 3127 |
+
cmp_outputs = []
|
| 3128 |
+
for row_start in range(0, len(_initial_versions), 2):
|
| 3129 |
+
with gr.Row():
|
| 3130 |
+
for v in _initial_versions[row_start:row_start + 2]:
|
| 3131 |
+
cmp_outputs.append(gr.Textbox(label=v, lines=10, interactive=False))
|
| 3132 |
|
| 3133 |
cmp_run.click(
|
| 3134 |
_compare_fn,
|