""" HTML visualization generator for UncheatableEval. Generates interactive HTML visualizations comparing byte-level losses between two models. """ import bisect import json import math import re from pathlib import Path from typing import List, Tuple, Optional, Set import numpy as np from core.escaping import escape_json_for_script from core.render_model import RenderModel, TokenInfo, build_display from visualization.render import render_page from core.helpers import TokenizerBytesConverter ASSETS_DIR = Path(__file__).resolve().parent / "assets" # Compression rate conversion factor COMPRESSION_RATE_FACTOR = (1.0 / math.log(2.0)) * 0.125 * 100.0 # Global tokenizers (lazy loaded) _qwen_tokenizer = None _rwkv_tokenizer = None _token_bytes_converter_cache = {} def get_qwen_tokenizer(): """Lazy load Qwen tokenizer.""" global _qwen_tokenizer if _qwen_tokenizer is None: _qwen_tokenizer = TokenizerBytesConverter("Qwen/Qwen3-0.6B-Base") return _qwen_tokenizer def get_rwkv_tokenizer(): """Lazy load RWKV tokenizer.""" global _rwkv_tokenizer if _rwkv_tokenizer is None: from rwkv.rwkv_tokenizer import TRIE_TOKENIZER import os script_dir = os.path.dirname(os.path.abspath(__file__)) vocab_path = os.path.join(os.path.dirname(script_dir), "support", "rwkv_vocab_v20230424.txt") _rwkv_tokenizer = TRIE_TOKENIZER(vocab_path) return _rwkv_tokenizer def get_tokenizer_boundaries(text: str, tokenizer, is_rwkv: bool = False) -> Set[int]: """Get token boundaries (byte positions) for a given text.""" boundaries = set() boundaries.add(0) if is_rwkv: tokenized = tokenizer.encode(text) if hasattr(tokenized, "ids"): token_ids = tokenized.ids else: token_ids = tokenized byte_pos = 0 for token_id in token_ids: token_bytes = tokenizer.decodeBytes([token_id]) byte_pos += len(token_bytes) boundaries.add(byte_pos) else: token_bytes_list = tokenizer.encode_to_bytes(text) byte_pos = 0 for token_bytes in token_bytes_list: byte_pos += len(token_bytes) boundaries.add(byte_pos) return boundaries def get_token_info_for_text(text: str) -> dict: """Get detailed token information for each byte position.""" qwen_tokenizer = get_qwen_tokenizer() rwkv_tokenizer = get_rwkv_tokenizer() # Get Qwen tokens with positions qwen_tokens = [] byte_to_qwen = {} # Keep both token id (vocab id) and decoded bytes so the tooltip can show true token ids. qwen_id_and_bytes = qwen_tokenizer.encode_to_ids_and_bytes(text) byte_pos = 0 for idx, (token_id, token_bytes) in enumerate(qwen_id_and_bytes): start = byte_pos token_bytes_blob = bytes(token_bytes) end = byte_pos + len(token_bytes_blob) qwen_tokens.append((start, end, token_id, token_bytes_blob)) byte_to_qwen[start] = idx byte_pos = end # Get RWKV tokens with positions rwkv_tokens = [] byte_to_rwkv = {} tokenized = rwkv_tokenizer.encode(text) if hasattr(tokenized, "ids"): token_ids = tokenized.ids else: token_ids = tokenized byte_pos = 0 for idx, token_id in enumerate(token_ids): token_bytes = rwkv_tokenizer.decodeBytes([token_id]) start = byte_pos end = byte_pos + len(token_bytes) rwkv_tokens.append((start, end, token_id, token_bytes)) byte_to_rwkv[start] = idx byte_pos = end # Get common boundaries, but keep only UTF-8 codepoint boundaries qwen_boundaries = set([0] + [t[1] for t in qwen_tokens]) rwkv_boundaries = set([0] + [t[1] for t in rwkv_tokens]) utf8_boundaries = set([0]) whitespace_boundaries = set() linebreak_boundaries = set() byte_pos = 0 for ch in text: ch_bytes = ch.encode("utf-8") byte_pos += len(ch_bytes) utf8_boundaries.add(byte_pos) if ch.isspace(): whitespace_boundaries.add(byte_pos) if ch in ("\n", "\r"): linebreak_boundaries.add(byte_pos) common_boundaries = sorted(qwen_boundaries & rwkv_boundaries & utf8_boundaries) # Ensure we always include the end boundary text_end = len(text.encode("utf-8")) if text_end not in common_boundaries: common_boundaries.append(text_end) common_boundaries = sorted(common_boundaries) # Refine overly large segments to avoid giant spans in the UI. max_segment_bytes = 24 utf8_sorted = sorted(utf8_boundaries) linebreak_sorted = sorted(linebreak_boundaries) def split_by_max(start: int, end: int) -> List[int]: if end - start <= max_segment_bytes: return [end] left = bisect.bisect_right(utf8_sorted, start) right = bisect.bisect_left(utf8_sorted, end) candidates = utf8_sorted[left:right] if not candidates: return [end] out = [] pos = start idx = 0 while pos < end: limit = min(end, pos + max_segment_bytes) j = bisect.bisect_right(candidates, limit) - 1 if j < idx: out.append(end) break split_at = None for k in range(j, idx - 1, -1): if candidates[k] in whitespace_boundaries: split_at = candidates[k] j = k break if split_at is None: split_at = candidates[j] if split_at <= pos: split_at = candidates[j] out.append(split_at) pos = split_at idx = j + 1 if pos >= end: break if idx >= len(candidates): out.append(end) break if not out: out = [end] elif out[-1] != end: out.append(end) return out def split_segment(start: int, end: int) -> List[int]: if start >= end: return [] lb_left = bisect.bisect_right(linebreak_sorted, start) lb_right = bisect.bisect_left(linebreak_sorted, end) linebreaks = linebreak_sorted[lb_left:lb_right] if not linebreaks: return split_by_max(start, end) out = [] seg_start = start for lb in linebreaks: out.extend(split_by_max(seg_start, lb)) seg_start = lb out.extend(split_by_max(seg_start, end)) return out refined_boundaries = [common_boundaries[0]] if common_boundaries else [0] for i in range(len(common_boundaries) - 1): start = common_boundaries[i] end = common_boundaries[i + 1] refined_boundaries.extend(split_segment(start, end)) common_boundaries = sorted(set(refined_boundaries)) return { "common_boundaries": common_boundaries, "qwen_tokens": qwen_tokens, "rwkv_tokens": rwkv_tokens, "byte_to_qwen": byte_to_qwen, "byte_to_rwkv": byte_to_rwkv, } def generate_comparison_html( text: str, byte_losses_a: List[float], byte_losses_b: List[float], model_a_name: str, model_b_name: str, topk_predictions_a: Optional[List] = None, topk_predictions_b: Optional[List] = None, tokenizer_a=None, tokenizer_b=None, model_type_a: str = "hf", model_type_b: str = "rwkv7", token_info_override: Optional[dict] = None, return_render_model: bool = False, ) -> str: """ Generate an interactive HTML visualization comparing two models. Args: text: The input text that was evaluated byte_losses_a: Per-byte losses from model A byte_losses_b: Per-byte losses from model B model_a_name: Display name for model A model_b_name: Display name for model B topk_predictions_a: Top-k predictions from model A topk_predictions_b: Top-k predictions from model B tokenizer_a: Tokenizer for model A tokenizer_b: Tokenizer for model B model_type_a: Type of model A ("hf" or "rwkv7") model_type_b: Type of model B ("hf" or "rwkv7") token_info_override: Optional precomputed token info (for offline tests). return_render_model: If True, return (html, render_model_dict) Returns: HTML string with interactive visualization, or (html, render_model_dict) if return_render_model=True """ def decode_token(token_id: int, tokenizer, model_type: str) -> Tuple[str, bool]: """Decode a single token ID to text using the appropriate tokenizer. Returns (text, is_raw_bytes). """ def bytes_to_hex_str(byte_values) -> str: if isinstance(byte_values, list): byte_values = bytes(byte_values) return "".join([f"\\x{b:02x}" for b in byte_values]) def get_bytes_converter(tokenizer): if tokenizer is None: return None key = getattr(tokenizer, "name_or_path", None) if not key: key = str(id(tokenizer)) if key not in _token_bytes_converter_cache: try: _token_bytes_converter_cache[key] = TokenizerBytesConverter( model_name_or_path=getattr(tokenizer, "name_or_path", None), tokenizer=tokenizer, trust_remote_code=True, ) except Exception: _token_bytes_converter_cache[key] = None return _token_bytes_converter_cache.get(key) if tokenizer is None: return f"[{token_id}]", False try: if model_type in ["rwkv", "rwkv7"]: # RWKV tokenizer provides raw bytes try: token_bytes = tokenizer.decodeBytes([token_id]) except Exception as e: if token_id == 0: return f"[{token_id}]", False raise e if token_bytes: try: decoded = token_bytes.decode("utf-8") return (decoded if decoded else f"[{token_id}]"), False except UnicodeDecodeError: return bytes_to_hex_str(token_bytes), True return f"[{token_id}]", False else: # HuggingFace tokenizer: prefer raw bytes when possible converter = get_bytes_converter(tokenizer) token_bytes = None if converter is not None: try: token_bytes = converter.token_to_bytes(token_id) except Exception: token_bytes = None if token_bytes: try: decoded = bytes(token_bytes).decode("utf-8") return (decoded if decoded else f"[{token_id}]"), False except UnicodeDecodeError: return bytes_to_hex_str(token_bytes), True decoded = tokenizer.decode([token_id]) if decoded and "�" not in decoded: return decoded, False return (decoded if decoded else f"[{token_id}]"), False except Exception as e: print(f"Warning: Failed to decode token {token_id} ({model_type}): {e}") return f"[{token_id}]", False def build_byte_to_token_map(text: str, tokenizer, model_type: str): """Build mapping from byte position to token index using the correct tokenizer. Returns a list of (start, end, token_idx) tuples for range-based lookup.""" if tokenizer is None: return [] token_ranges = [] try: if model_type in ["rwkv", "rwkv7"]: # RWKV tokenizer tokenized = tokenizer.encode(text) if hasattr(tokenized, "ids"): token_ids = tokenized.ids else: token_ids = tokenized byte_pos = 0 for idx, token_id in enumerate(token_ids): try: token_bytes = tokenizer.decodeBytes([token_id]) token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx)) byte_pos += len(token_bytes) except Exception as e: print(f"Warning: Failed to decode RWKV token {token_id}: {e}") pass else: # HuggingFace tokenizer - use TokenizerBytesConverter tokenizer_name = getattr(tokenizer, "name_or_path", None) if tokenizer_name: converter = TokenizerBytesConverter(tokenizer_name, trust_remote_code=True) token_bytes_list = converter.encode_to_bytes(text) byte_pos = 0 for idx, token_bytes in enumerate(token_bytes_list): token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx)) byte_pos += len(token_bytes) else: print(f"Warning: Could not get tokenizer name for HF model") except Exception as e: print(f"Warning: Could not build byte-to-token map ({model_type}): {e}") return [] return token_ranges def find_token_for_byte(byte_pos: int, token_ranges): for start, end, idx in token_ranges: if start <= byte_pos < end: return idx return None # Calculate deltas deltas = [a - b for a, b in zip(byte_losses_a, byte_losses_b)] avg_delta = sum(deltas) / len(deltas) if deltas else 0 # Calculate average compression rates avg_compression_a = sum(byte_losses_a) / len(byte_losses_a) * COMPRESSION_RATE_FACTOR if byte_losses_a else 0 avg_compression_b = sum(byte_losses_b) / len(byte_losses_b) * COMPRESSION_RATE_FACTOR if byte_losses_b else 0 avg_delta_compression = avg_delta * COMPRESSION_RATE_FACTOR # Get token info text_bytes = text.encode("utf-8") token_info = token_info_override if token_info_override is not None else get_token_info_for_text(text) common_boundaries = token_info["common_boundaries"] qwen_tokens = token_info["qwen_tokens"] rwkv_tokens = token_info["rwkv_tokens"] # Build byte position to token index mapping model_a_token_ranges = build_byte_to_token_map(text, tokenizer_a, model_type_a) model_b_token_ranges = build_byte_to_token_map(text, tokenizer_b, model_type_b) def get_tokens_for_range(byte_start, byte_end, token_list): result = [] for idx, (t_start, t_end, token_id, t_bytes) in enumerate(token_list): if t_start < byte_end and t_end > byte_start: result.append((idx, token_id, t_bytes)) return result # Build tokens based on common boundaries tokens = [] for i in range(len(common_boundaries) - 1): start_byte = common_boundaries[i] end_byte = common_boundaries[i + 1] token_bytes = text_bytes[start_byte:end_byte] decoded_ok = True try: token_text = token_bytes.decode("utf-8") except UnicodeDecodeError: # Show raw bytes when UTF-8 decoding fails token_text = "".join([f"\\x{b:02x}" for b in token_bytes]) decoded_ok = False qwen_toks = get_tokens_for_range(start_byte, end_byte, qwen_tokens) rwkv_toks = get_tokens_for_range(start_byte, end_byte, rwkv_tokens) if decoded_ok and re.search(r"\w", token_text, re.UNICODE): tokens.append( { "type": "word", "text": token_text, "byte_start": start_byte, "byte_end": end_byte, "word_lower": token_text.lower(), "qwen_tokens": qwen_toks, "rwkv_tokens": rwkv_toks, } ) else: tokens.append( { "type": "non-word", "text": token_text, "byte_start": start_byte, "byte_end": end_byte, "qwen_tokens": qwen_toks, "rwkv_tokens": rwkv_toks, } ) # Track word occurrences word_occurrences = {} word_id_counter = 0 for i, token in enumerate(tokens): if token["type"] == "word": word_lower = token["word_lower"] if word_lower not in word_occurrences: word_occurrences[word_lower] = [] word_occurrences[word_lower].append(i) token["word_id"] = word_id_counter word_id_counter += 1 # Build render model (HTML content built in JS) render_tokens = [] for token in tokens: token_text = token["text"] byte_start = token["byte_start"] byte_end = token["byte_end"] # Get actual model token IDs for this byte range model_a_token_idx = find_token_for_byte(byte_start, model_a_token_ranges) model_b_token_idx = find_token_for_byte(byte_start, model_b_token_ranges) # Build token info strings showing all tokens in this byte range def token_bytes_to_display_text(token_bytes: bytes) -> Tuple[str, bool]: if token_bytes is None: return "", False if isinstance(token_bytes, list): token_bytes = bytes(token_bytes) if isinstance(token_bytes, str): return token_bytes, False if len(token_bytes) == 0: return "", False try: return token_bytes.decode("utf-8"), False except UnicodeDecodeError: return "".join([f"\\x{b:02x}" for b in token_bytes]), True raw_bytes = list(text_bytes[byte_start:byte_end]) losses_a = byte_losses_a[byte_start:byte_end] losses_b = byte_losses_b[byte_start:byte_end] bytes_str = " ".join([f"{b:02x}" for b in raw_bytes]) compression_a_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_a]) compression_b_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_b]) # Calculate average compression rate for this token avg_compression_a_token = sum(losses_a) / len(losses_a) * COMPRESSION_RATE_FACTOR if losses_a else 0 avg_compression_b_token = sum(losses_b) / len(losses_b) * COMPRESSION_RATE_FACTOR if losses_b else 0 topk_a_data = None topk_b_data = None if topk_predictions_a is not None and model_a_token_ranges: model_a_token_idx = find_token_for_byte(byte_start, model_a_token_ranges) if model_a_token_idx is not None and model_a_token_idx < len(topk_predictions_a): pred = topk_predictions_a[model_a_token_idx] try: if len(pred) >= 4: actual_id, rank, actual_prob, topk_list = pred[0], pred[1], pred[2], pred[3] topk_a_data = [ actual_id, rank, actual_prob, [[tid, prob, *decode_token(tid, tokenizer_a, model_type_a)] for tid, prob in topk_list], ] else: topk_a_data = [ pred[0], pred[1], [[tid, prob, *decode_token(tid, tokenizer_a, model_type_a)] for tid, prob in pred[2]], ] except Exception as e: pass if topk_predictions_b is not None and model_b_token_ranges: model_b_token_idx = find_token_for_byte(byte_start, model_b_token_ranges) if model_b_token_idx is not None and model_b_token_idx < len(topk_predictions_b): pred = topk_predictions_b[model_b_token_idx] try: if len(pred) >= 4: actual_id, rank, actual_prob, topk_list = pred[0], pred[1], pred[2], pred[3] topk_b_data = [ actual_id, rank, actual_prob, [[tid, prob, *decode_token(tid, tokenizer_b, model_type_b)] for tid, prob in topk_list], ] else: topk_b_data = [pred[0], pred[1], [[tid, prob, *decode_token(tid, tokenizer_b, model_type_b)] for tid, prob in pred[2]]] except Exception as e: pass token_deltas = deltas[byte_start:byte_end] avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0 tuned_delta = avg_token_delta - avg_delta raw_delta = avg_token_delta # Initial rendering uses white color, JavaScript will apply colors based on slider r, g, b = 255, 255, 255 raw_display_text = token_text display_text = token_text.replace("\t", " ") def classify_kind(text_value: str, is_raw_value: bool) -> str: return build_display(text_value, is_raw=is_raw_value).kind def get_actual_prob(topk_predictions, token_idx: Optional[int]): if not topk_predictions or token_idx is None: return None if token_idx < 0 or token_idx >= len(topk_predictions): return None pred = topk_predictions[token_idx] if isinstance(pred, (list, tuple)) and len(pred) >= 3: return pred[2] return None model_tokens_render = {} if token["rwkv_tokens"]: rwkv_items = [] for tok_idx, tid, tb in token["rwkv_tokens"]: txt, is_raw = token_bytes_to_display_text(tb) rwkv_items.append([tid, txt, classify_kind(txt, is_raw), get_actual_prob(topk_predictions_a, tok_idx)]) model_tokens_render["rwkv"] = rwkv_items if token["qwen_tokens"]: qwen_items = [] for tok_idx, tid, tb in token["qwen_tokens"]: txt, is_raw = token_bytes_to_display_text(tb) qwen_items.append([tid, txt, classify_kind(txt, is_raw), get_actual_prob(topk_predictions_b, tok_idx)]) model_tokens_render["qwen"] = qwen_items display_info = build_display(raw_display_text, is_raw=not decoded_ok) if display_info.kind == "control": display_text = raw_display_text display_info.text = display_text render_tokens.append( TokenInfo( byte_start=byte_start, byte_end=byte_end, display=display_info, is_word=token["type"] == "word", word_id=token.get("word_id"), word_key=token.get("word_lower"), bytes_hex=bytes_str, compression={"rwkv": compression_a_str, "qwen": compression_b_str}, model_tokens=model_tokens_render, loss={"rwkv": avg_compression_a_token, "qwen": avg_compression_b_token}, topk={ "rwkv": topk_a_data, "qwen": topk_b_data, }, raw_delta=raw_delta, tuned_delta=tuned_delta, ) ) delta_color = "#64ff64" if avg_delta < 0 else "#ff6464" render_model = RenderModel( text=text, tokens=render_tokens, meta={ "model_a": model_a_name, "model_b": model_b_name, "avg_compression": { "rwkv": avg_compression_a, "qwen": avg_compression_b, }, "avg_delta": avg_delta, "avg_delta_compression": avg_delta_compression, }, ) render_model_json = escape_json_for_script(render_model.to_dict()) style_block = (ASSETS_DIR / "main.css").read_text(encoding="utf-8") header_html = f"""