| """ | |
| HTML visualization generator for UncheatableEval. | |
| Generates interactive HTML visualizations comparing byte-level losses between two models. | |
| """ | |
| import bisect | |
| import json | |
| import math | |
| import re | |
| from pathlib import Path | |
| from typing import List, Tuple, Optional, Set | |
| import numpy as np | |
| from core.escaping import escape_json_for_script | |
| from core.render_model import RenderModel, TokenInfo, build_display | |
| from visualization.render import render_page | |
| from core.helpers import TokenizerBytesConverter | |
| ASSETS_DIR = Path(__file__).resolve().parent / "assets" | |
| # Compression rate conversion factor | |
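# Assuming per-byte losses are natural-log (nat) values: 1/ln(2) converts nats to bits,
# 0.125 expresses bits as a fraction of the 8 bits in an uncompressed byte, and 100 turns
# that fraction into a percentage.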
COMPRESSION_RATE_FACTOR = (1.0 / math.log(2.0)) * 0.125 * 100.0

# Global tokenizers (lazy loaded)
_qwen_tokenizer = None
_rwkv_tokenizer = None
_token_bytes_converter_cache = {}


def get_qwen_tokenizer():
    """Lazy-load the Qwen tokenizer."""
    global _qwen_tokenizer
    if _qwen_tokenizer is None:
        _qwen_tokenizer = TokenizerBytesConverter("Qwen/Qwen3-0.6B-Base")
    return _qwen_tokenizer


def get_rwkv_tokenizer():
    """Lazy-load the RWKV tokenizer."""
    global _rwkv_tokenizer
    if _rwkv_tokenizer is None:
        from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
        import os

        script_dir = os.path.dirname(os.path.abspath(__file__))
        vocab_path = os.path.join(os.path.dirname(script_dir), "support", "rwkv_vocab_v20230424.txt")
        _rwkv_tokenizer = TRIE_TOKENIZER(vocab_path)
    return _rwkv_tokenizer


def get_tokenizer_boundaries(text: str, tokenizer, is_rwkv: bool = False) -> Set[int]:
    """Get token boundaries (byte positions) for a given text."""
    boundaries = set()
    boundaries.add(0)
    if is_rwkv:
        tokenized = tokenizer.encode(text)
        if hasattr(tokenized, "ids"):
            token_ids = tokenized.ids
        else:
            token_ids = tokenized
        byte_pos = 0
        for token_id in token_ids:
            token_bytes = tokenizer.decodeBytes([token_id])
            byte_pos += len(token_bytes)
            boundaries.add(byte_pos)
    else:
        token_bytes_list = tokenizer.encode_to_bytes(text)
        byte_pos = 0
        for token_bytes in token_bytes_list:
            byte_pos += len(token_bytes)
            boundaries.add(byte_pos)
    return boundaries


def get_token_info_for_text(text: str) -> dict:
    """Get detailed token information for each byte position."""
    qwen_tokenizer = get_qwen_tokenizer()
    rwkv_tokenizer = get_rwkv_tokenizer()

    # Get Qwen tokens with positions
    qwen_tokens = []
    byte_to_qwen = {}
    # Keep both the token id (vocab id) and the decoded bytes so the tooltip can show true token ids.
    qwen_id_and_bytes = qwen_tokenizer.encode_to_ids_and_bytes(text)
    byte_pos = 0
    for idx, (token_id, token_bytes) in enumerate(qwen_id_and_bytes):
        start = byte_pos
        token_bytes_blob = bytes(token_bytes)
        end = byte_pos + len(token_bytes_blob)
        qwen_tokens.append((start, end, token_id, token_bytes_blob))
        byte_to_qwen[start] = idx
        byte_pos = end

    # Get RWKV tokens with positions
    rwkv_tokens = []
    byte_to_rwkv = {}
    tokenized = rwkv_tokenizer.encode(text)
    if hasattr(tokenized, "ids"):
        token_ids = tokenized.ids
    else:
        token_ids = tokenized
    byte_pos = 0
    for idx, token_id in enumerate(token_ids):
        token_bytes = rwkv_tokenizer.decodeBytes([token_id])
        start = byte_pos
        end = byte_pos + len(token_bytes)
        rwkv_tokens.append((start, end, token_id, token_bytes))
        byte_to_rwkv[start] = idx
        byte_pos = end

    # Get common boundaries, but keep only UTF-8 codepoint boundaries
    qwen_boundaries = set([0] + [t[1] for t in qwen_tokens])
    rwkv_boundaries = set([0] + [t[1] for t in rwkv_tokens])
    utf8_boundaries = set([0])
    whitespace_boundaries = set()
    linebreak_boundaries = set()
    byte_pos = 0
    for ch in text:
        ch_bytes = ch.encode("utf-8")
        byte_pos += len(ch_bytes)
        utf8_boundaries.add(byte_pos)
        if ch.isspace():
            whitespace_boundaries.add(byte_pos)
        if ch in ("\n", "\r"):
            linebreak_boundaries.add(byte_pos)
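
    # Illustration: if Qwen token ends fell at {0, 4, 9} and RWKV token ends at {0, 4, 7, 9},
    # only the shared, codepoint-aligned positions {0, 4, 9} survive as comparison segments.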
    common_boundaries = sorted(qwen_boundaries & rwkv_boundaries & utf8_boundaries)

    # Ensure we always include the end boundary
    text_end = len(text.encode("utf-8"))
    if text_end not in common_boundaries:
        common_boundaries.append(text_end)
        common_boundaries = sorted(common_boundaries)

    # Refine overly large segments to avoid giant spans in the UI.
    max_segment_bytes = 24
    utf8_sorted = sorted(utf8_boundaries)
    linebreak_sorted = sorted(linebreak_boundaries)

    def split_by_max(start: int, end: int) -> List[int]:
        """Greedily split [start, end) into chunks of at most max_segment_bytes,
        cutting only at UTF-8 codepoint boundaries and preferring whitespace cuts."""
        if end - start <= max_segment_bytes:
            return [end]
        left = bisect.bisect_right(utf8_sorted, start)
        right = bisect.bisect_left(utf8_sorted, end)
        candidates = utf8_sorted[left:right]
        if not candidates:
            return [end]
        out = []
        pos = start
        idx = 0
        while pos < end:
            # Furthest candidate boundary that still keeps the chunk within the size limit
            limit = min(end, pos + max_segment_bytes)
            j = bisect.bisect_right(candidates, limit) - 1
            if j < idx:
                out.append(end)
                break
            # Prefer the last whitespace boundary within the limit, if any
            split_at = None
            for k in range(j, idx - 1, -1):
                if candidates[k] in whitespace_boundaries:
                    split_at = candidates[k]
                    j = k
                    break
            if split_at is None:
                split_at = candidates[j]
            if split_at <= pos:
                split_at = candidates[j]
            out.append(split_at)
            pos = split_at
            idx = j + 1
            if pos >= end:
                break
            if idx >= len(candidates):
                out.append(end)
                break
        if not out:
            out = [end]
        elif out[-1] != end:
            out.append(end)
        return out

    def split_segment(start: int, end: int) -> List[int]:
        if start >= end:
            return []
        lb_left = bisect.bisect_right(linebreak_sorted, start)
        lb_right = bisect.bisect_left(linebreak_sorted, end)
        linebreaks = linebreak_sorted[lb_left:lb_right]
        if not linebreaks:
            return split_by_max(start, end)
        out = []
        seg_start = start
        for lb in linebreaks:
            out.extend(split_by_max(seg_start, lb))
            seg_start = lb
        out.extend(split_by_max(seg_start, end))
        return out

    refined_boundaries = [common_boundaries[0]] if common_boundaries else [0]
    for i in range(len(common_boundaries) - 1):
        start = common_boundaries[i]
        end = common_boundaries[i + 1]
        refined_boundaries.extend(split_segment(start, end))
    common_boundaries = sorted(set(refined_boundaries))

    return {
        "common_boundaries": common_boundaries,
        "qwen_tokens": qwen_tokens,
        "rwkv_tokens": rwkv_tokens,
        "byte_to_qwen": byte_to_qwen,
        "byte_to_rwkv": byte_to_rwkv,
    }


def generate_comparison_html(
    text: str,
    byte_losses_a: List[float],
    byte_losses_b: List[float],
    model_a_name: str,
    model_b_name: str,
    topk_predictions_a: Optional[List] = None,
    topk_predictions_b: Optional[List] = None,
    tokenizer_a=None,
    tokenizer_b=None,
    model_type_a: str = "hf",
    model_type_b: str = "rwkv7",
    token_info_override: Optional[dict] = None,
    return_render_model: bool = False,
) -> str:
    """
    Generate an interactive HTML visualization comparing two models.

    Args:
        text: The input text that was evaluated
        byte_losses_a: Per-byte losses from model A
        byte_losses_b: Per-byte losses from model B
        model_a_name: Display name for model A
        model_b_name: Display name for model B
        topk_predictions_a: Top-k predictions from model A
        topk_predictions_b: Top-k predictions from model B
        tokenizer_a: Tokenizer for model A
        tokenizer_b: Tokenizer for model B
        model_type_a: Type of model A ("hf" or "rwkv7")
        model_type_b: Type of model B ("hf" or "rwkv7")
        token_info_override: Optional precomputed token info (for offline tests)
        return_render_model: If True, return (html, render_model_dict)

    Returns:
        HTML string with the interactive visualization, or (html, render_model_dict)
        if return_render_model is True.
    """

    def decode_token(token_id: int, tokenizer, model_type: str) -> Tuple[str, bool]:
        """Decode a single token ID to text using the appropriate tokenizer.

        Returns (text, is_raw_bytes).
        """

        def bytes_to_hex_str(byte_values) -> str:
            if isinstance(byte_values, list):
                byte_values = bytes(byte_values)
            return "".join([f"\\x{b:02x}" for b in byte_values])

        def get_bytes_converter(tokenizer):
            if tokenizer is None:
                return None
            key = getattr(tokenizer, "name_or_path", None)
            if not key:
                key = str(id(tokenizer))
            if key not in _token_bytes_converter_cache:
                try:
                    _token_bytes_converter_cache[key] = TokenizerBytesConverter(
                        model_name_or_path=getattr(tokenizer, "name_or_path", None),
                        tokenizer=tokenizer,
                        trust_remote_code=True,
                    )
                except Exception:
                    _token_bytes_converter_cache[key] = None
            return _token_bytes_converter_cache.get(key)

        if tokenizer is None:
            return f"[{token_id}]", False
        try:
            if model_type in ["rwkv", "rwkv7"]:
                # The RWKV tokenizer provides raw bytes
                try:
                    token_bytes = tokenizer.decodeBytes([token_id])
                except Exception as e:
                    if token_id == 0:
                        return f"[{token_id}]", False
                    raise e
                if token_bytes:
                    try:
                        decoded = token_bytes.decode("utf-8")
                        return (decoded if decoded else f"[{token_id}]"), False
                    except UnicodeDecodeError:
                        return bytes_to_hex_str(token_bytes), True
                return f"[{token_id}]", False
            else:
                # HuggingFace tokenizer: prefer raw bytes when possible
                converter = get_bytes_converter(tokenizer)
                token_bytes = None
                if converter is not None:
                    try:
                        token_bytes = converter.token_to_bytes(token_id)
                    except Exception:
                        token_bytes = None
                if token_bytes:
                    try:
                        decoded = bytes(token_bytes).decode("utf-8")
                        return (decoded if decoded else f"[{token_id}]"), False
                    except UnicodeDecodeError:
                        return bytes_to_hex_str(token_bytes), True
                decoded = tokenizer.decode([token_id])
                if decoded and "�" not in decoded:
                    return decoded, False
                return (decoded if decoded else f"[{token_id}]"), False
        except Exception as e:
            print(f"Warning: Failed to decode token {token_id} ({model_type}): {e}")
            return f"[{token_id}]", False

    def build_byte_to_token_map(text: str, tokenizer, model_type: str):
        """Build a mapping from byte position to token index using the correct tokenizer.

        Returns a list of (start, end, token_idx) tuples for range-based lookup.
        """
        if tokenizer is None:
            return []
        token_ranges = []
        try:
            if model_type in ["rwkv", "rwkv7"]:
                # RWKV tokenizer
                tokenized = tokenizer.encode(text)
                if hasattr(tokenized, "ids"):
                    token_ids = tokenized.ids
                else:
                    token_ids = tokenized
                byte_pos = 0
                for idx, token_id in enumerate(token_ids):
                    try:
                        token_bytes = tokenizer.decodeBytes([token_id])
                        token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
                        byte_pos += len(token_bytes)
                    except Exception as e:
                        print(f"Warning: Failed to decode RWKV token {token_id}: {e}")
            else:
                # HuggingFace tokenizer: use TokenizerBytesConverter
                tokenizer_name = getattr(tokenizer, "name_or_path", None)
                if tokenizer_name:
                    converter = TokenizerBytesConverter(tokenizer_name, trust_remote_code=True)
                    token_bytes_list = converter.encode_to_bytes(text)
                    byte_pos = 0
                    for idx, token_bytes in enumerate(token_bytes_list):
                        token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
                        byte_pos += len(token_bytes)
                else:
                    print("Warning: Could not get tokenizer name for HF model")
        except Exception as e:
            print(f"Warning: Could not build byte-to-token map ({model_type}): {e}")
            return []
        return token_ranges

    def find_token_for_byte(byte_pos: int, token_ranges):
        for start, end, idx in token_ranges:
            if start <= byte_pos < end:
                return idx
        return None

    # Calculate deltas
    deltas = [a - b for a, b in zip(byte_losses_a, byte_losses_b)]
    avg_delta = sum(deltas) / len(deltas) if deltas else 0

    # Calculate average compression rates
    avg_compression_a = sum(byte_losses_a) / len(byte_losses_a) * COMPRESSION_RATE_FACTOR if byte_losses_a else 0
    avg_compression_b = sum(byte_losses_b) / len(byte_losses_b) * COMPRESSION_RATE_FACTOR if byte_losses_b else 0
    avg_delta_compression = avg_delta * COMPRESSION_RATE_FACTOR

    # Get token info
    text_bytes = text.encode("utf-8")
    token_info = token_info_override if token_info_override is not None else get_token_info_for_text(text)
    common_boundaries = token_info["common_boundaries"]
    qwen_tokens = token_info["qwen_tokens"]
    rwkv_tokens = token_info["rwkv_tokens"]

    # Build byte position to token index mappings
    model_a_token_ranges = build_byte_to_token_map(text, tokenizer_a, model_type_a)
    model_b_token_ranges = build_byte_to_token_map(text, tokenizer_b, model_type_b)

    def get_tokens_for_range(byte_start, byte_end, token_list):
        result = []
        for idx, (t_start, t_end, token_id, t_bytes) in enumerate(token_list):
            if t_start < byte_end and t_end > byte_start:
                result.append((idx, token_id, t_bytes))
        return result

    # Build tokens based on common boundaries
    tokens = []
    for i in range(len(common_boundaries) - 1):
        start_byte = common_boundaries[i]
        end_byte = common_boundaries[i + 1]
        token_bytes = text_bytes[start_byte:end_byte]
        decoded_ok = True
        try:
            token_text = token_bytes.decode("utf-8")
        except UnicodeDecodeError:
            # Show raw bytes when UTF-8 decoding fails
            token_text = "".join([f"\\x{b:02x}" for b in token_bytes])
            decoded_ok = False
        qwen_toks = get_tokens_for_range(start_byte, end_byte, qwen_tokens)
        rwkv_toks = get_tokens_for_range(start_byte, end_byte, rwkv_tokens)
        if decoded_ok and re.search(r"\w", token_text, re.UNICODE):
            tokens.append(
                {
                    "type": "word",
                    "text": token_text,
                    "byte_start": start_byte,
                    "byte_end": end_byte,
                    "word_lower": token_text.lower(),
                    "qwen_tokens": qwen_toks,
                    "rwkv_tokens": rwkv_toks,
                }
            )
        else:
            tokens.append(
                {
                    "type": "non-word",
                    "text": token_text,
                    "byte_start": start_byte,
                    "byte_end": end_byte,
                    "qwen_tokens": qwen_toks,
                    "rwkv_tokens": rwkv_toks,
                }
            )

    # Track word occurrences
    word_occurrences = {}
    word_id_counter = 0
    for i, token in enumerate(tokens):
        if token["type"] == "word":
            word_lower = token["word_lower"]
            if word_lower not in word_occurrences:
                word_occurrences[word_lower] = []
            word_occurrences[word_lower].append(i)
            token["word_id"] = word_id_counter
            word_id_counter += 1

    # Build the render model (HTML content is built in JS)
    render_tokens = []
    for token in tokens:
        token_text = token["text"]
        byte_start = token["byte_start"]
        byte_end = token["byte_end"]

        # Get the actual model token IDs for this byte range
        model_a_token_idx = find_token_for_byte(byte_start, model_a_token_ranges)
        model_b_token_idx = find_token_for_byte(byte_start, model_b_token_ranges)

        # Build token info strings showing all tokens in this byte range
        def token_bytes_to_display_text(token_bytes: bytes) -> Tuple[str, bool]:
            if token_bytes is None:
                return "", False
            if isinstance(token_bytes, list):
                token_bytes = bytes(token_bytes)
            if isinstance(token_bytes, str):
                return token_bytes, False
            if len(token_bytes) == 0:
                return "", False
            try:
                return token_bytes.decode("utf-8"), False
            except UnicodeDecodeError:
                return "".join([f"\\x{b:02x}" for b in token_bytes]), True

        raw_bytes = list(text_bytes[byte_start:byte_end])
        losses_a = byte_losses_a[byte_start:byte_end]
        losses_b = byte_losses_b[byte_start:byte_end]
        bytes_str = " ".join([f"{b:02x}" for b in raw_bytes])
        compression_a_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_a])
        compression_b_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_b])

        # Calculate the average compression rate for this token
        avg_compression_a_token = sum(losses_a) / len(losses_a) * COMPRESSION_RATE_FACTOR if losses_a else 0
        avg_compression_b_token = sum(losses_b) / len(losses_b) * COMPRESSION_RATE_FACTOR if losses_b else 0

        topk_a_data = None
        topk_b_data = None
        # Top-k predictions come either as (actual_id, rank, actual_prob, [(token_id, prob), ...])
        # or as a legacy 3-tuple (actual_id, rank, [(token_id, prob), ...]).
        if topk_predictions_a is not None and model_a_token_ranges:
            model_a_token_idx = find_token_for_byte(byte_start, model_a_token_ranges)
            if model_a_token_idx is not None and model_a_token_idx < len(topk_predictions_a):
                pred = topk_predictions_a[model_a_token_idx]
                try:
                    if len(pred) >= 4:
                        actual_id, rank, actual_prob, topk_list = pred[0], pred[1], pred[2], pred[3]
                        topk_a_data = [
                            actual_id,
                            rank,
                            actual_prob,
                            [[tid, prob, *decode_token(tid, tokenizer_a, model_type_a)] for tid, prob in topk_list],
                        ]
                    else:
                        topk_a_data = [
                            pred[0],
                            pred[1],
                            [[tid, prob, *decode_token(tid, tokenizer_a, model_type_a)] for tid, prob in pred[2]],
                        ]
                except Exception:
                    # Tolerate malformed prediction entries
                    pass
        if topk_predictions_b is not None and model_b_token_ranges:
            model_b_token_idx = find_token_for_byte(byte_start, model_b_token_ranges)
            if model_b_token_idx is not None and model_b_token_idx < len(topk_predictions_b):
                pred = topk_predictions_b[model_b_token_idx]
                try:
                    if len(pred) >= 4:
                        actual_id, rank, actual_prob, topk_list = pred[0], pred[1], pred[2], pred[3]
                        topk_b_data = [
                            actual_id,
                            rank,
                            actual_prob,
                            [[tid, prob, *decode_token(tid, tokenizer_b, model_type_b)] for tid, prob in topk_list],
                        ]
                    else:
                        topk_b_data = [
                            pred[0],
                            pred[1],
                            [[tid, prob, *decode_token(tid, tokenizer_b, model_type_b)] for tid, prob in pred[2]],
                        ]
                except Exception:
                    # Tolerate malformed prediction entries
                    pass

        token_deltas = deltas[byte_start:byte_end]
        avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
        tuned_delta = avg_token_delta - avg_delta
        raw_delta = avg_token_delta

        # Initial rendering uses white; JavaScript applies colors based on the slider
        r, g, b = 255, 255, 255

        raw_display_text = token_text
        display_text = token_text.replace("\t", " ")

        def classify_kind(text_value: str, is_raw_value: bool) -> str:
            return build_display(text_value, is_raw=is_raw_value).kind

        def get_actual_prob(topk_predictions, token_idx: Optional[int]):
            if not topk_predictions or token_idx is None:
                return None
            if token_idx < 0 or token_idx >= len(topk_predictions):
                return None
            pred = topk_predictions[token_idx]
            if isinstance(pred, (list, tuple)) and len(pred) >= 3:
                return pred[2]
            return None

        model_tokens_render = {}
        if token["rwkv_tokens"]:
            rwkv_items = []
            for tok_idx, tid, tb in token["rwkv_tokens"]:
                txt, is_raw = token_bytes_to_display_text(tb)
                rwkv_items.append([tid, txt, classify_kind(txt, is_raw), get_actual_prob(topk_predictions_a, tok_idx)])
            model_tokens_render["rwkv"] = rwkv_items
        if token["qwen_tokens"]:
            qwen_items = []
            for tok_idx, tid, tb in token["qwen_tokens"]:
                txt, is_raw = token_bytes_to_display_text(tb)
                qwen_items.append([tid, txt, classify_kind(txt, is_raw), get_actual_prob(topk_predictions_b, tok_idx)])
            model_tokens_render["qwen"] = qwen_items

        display_info = build_display(raw_display_text, is_raw=not decoded_ok)
        if display_info.kind == "control":
            display_text = raw_display_text
            display_info.text = display_text

        render_tokens.append(
            TokenInfo(
                byte_start=byte_start,
                byte_end=byte_end,
                display=display_info,
                is_word=token["type"] == "word",
                word_id=token.get("word_id"),
                word_key=token.get("word_lower"),
                bytes_hex=bytes_str,
                compression={"rwkv": compression_a_str, "qwen": compression_b_str},
                model_tokens=model_tokens_render,
                loss={"rwkv": avg_compression_a_token, "qwen": avg_compression_b_token},
                topk={
                    "rwkv": topk_a_data,
                    "qwen": topk_b_data,
                },
                raw_delta=raw_delta,
                tuned_delta=tuned_delta,
            )
        )
| delta_color = "#64ff64" if avg_delta < 0 else "#ff6464" | |
| render_model = RenderModel( | |
| text=text, | |
| tokens=render_tokens, | |
| meta={ | |
| "model_a": model_a_name, | |
| "model_b": model_b_name, | |
| "avg_compression": { | |
| "rwkv": avg_compression_a, | |
| "qwen": avg_compression_b, | |
| }, | |
| "avg_delta": avg_delta, | |
| "avg_delta_compression": avg_delta_compression, | |
| }, | |
| ) | |
| render_model_json = escape_json_for_script(render_model.to_dict()) | |
| style_block = (ASSETS_DIR / "main.css").read_text(encoding="utf-8") | |
| header_html = f""" | |
| <div class="header"> | |
| <div class="meta"> | |
| <div>Model A: {model_a_name}</div> | |
| <div>Model B: {model_b_name}</div> | |
| <div>RWKV Compression: {avg_compression_a:.2f}%</div> | |
| <div>Qwen Compression: {avg_compression_b:.2f}%</div> | |
| <div style="color: {delta_color}">Avg Delta: {avg_delta_compression:+.2f}%</div> | |
| </div> | |
| <div class="legend"> | |
| <div class="legend-row"> | |
| <div class="legend-item legend-toggle"> | |
| <span style="color: #aaa;">Coloring Mode:</span> | |
| <label><input type="radio" name="delta-mode" value="relative" checked> vs Avg Delta</label> | |
| <label><input type="radio" name="delta-mode" value="absolute"> Absolute</label> | |
| </div> | |
| <div class="legend-item"> | |
| <span style="color: #aaa;">Color Range:</span> | |
| <input type="range" id="color-range-slider" min="0" max="100" value="10" step="0.1" style="width: 200px; vertical-align: middle;"> | |
| <span id="color-range-value" style="color: #fff; min-width: 45px; display: inline-block;">10%</span> | |
| </div> | |
| </div> | |
| <div class="legend-row"> | |
| <div class="legend-item"> | |
| <div class="legend-box" style="background-color: rgb(77, 255, 77)"></div> | |
| <span id="legend-better">RWKV better than avg delta</span> | |
| </div> | |
| <div class="legend-item"> | |
| <div class="legend-box" style="background-color: rgb(255, 255, 255)"></div> | |
| <span id="legend-equal">Equal to avg delta</span> | |
| </div> | |
| <div class="legend-item"> | |
| <div class="legend-box" style="background-color: rgb(255, 77, 77)"></div> | |
| <span id="legend-worse">RWKV worse than avg delta</span> | |
| </div> | |
| </div> | |
| </div> | |
| </div> | |
| """.strip("\n") | |

    script_body = (ASSETS_DIR / "main.js").read_text(encoding="utf-8")

    html_doc = render_page(
        {
            "page_title": "Model Comparison",
            "style_block": style_block.strip("\n"),
            "header_html": header_html,
            "content_html": "",
            "render_model_json": render_model_json,
            "script_body": script_body.strip("\n"),
        }
    )

    if return_render_model:
        return html_doc, render_model.to_dict()
    return html_doc
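

# --- Usage sketch (not part of the library) ----------------------------------
# A minimal smoke test, assuming this module is run inside the repo so that the
# core.* and visualization.render imports resolve and the assets/ directory
# (main.css, main.js) is present. The text, per-byte losses, and token
# boundaries below are synthetic; passing token_info_override avoids loading
# any real tokenizer.
if __name__ == "__main__":
    sample_text = "Hi!"
    sample_bytes = sample_text.encode("utf-8")
    fake_token_info = {
        "common_boundaries": [0, 1, 2, 3],
        "qwen_tokens": [(0, 3, 0, sample_bytes)],  # (start, end, token_id, token_bytes)
        "rwkv_tokens": [(0, 3, 0, sample_bytes)],
        "byte_to_qwen": {0: 0},
        "byte_to_rwkv": {0: 0},
    }
    html = generate_comparison_html(
        text=sample_text,
        byte_losses_a=[0.9, 0.5, 0.7],  # synthetic per-byte losses (nats)
        byte_losses_b=[1.0, 0.4, 0.8],
        model_a_name="model-a",
        model_b_name="model-b",
        token_info_override=fake_token_info,
    )
    Path("comparison_demo.html").write_text(html, encoding="utf-8")
    print("Wrote comparison_demo.html")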