import io
import os
import gzip
import uuid
import traceback
from concurrent.futures import ThreadPoolExecutor
from flask import Flask, request, jsonify, send_file
import matplotlib.pyplot as plt
import torch

from bit_transformer.dashboard_app import ModelManager
from bit_transformer.dashboard import plot_telemetry
from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
from bit_transformer.optimization import configure_optimizer
from bit_transformer.bit_io import text_to_bits
from bit_transformer.dataset_builder import BitTransformerDatasetBuilder, create_bittransformerlm_dataset

app = Flask(__name__)
manager = ModelManager()

executor = ThreadPoolExecutor(max_workers=4)
jobs: dict[str, dict] = {}


def _submit_job(fn, *args, **kwargs) -> str:
    """Schedule a function for background execution and return a job id."""
    job_id = str(uuid.uuid4())
    jobs[job_id] = {"status": "queued", "result": None, "error": None, "logs": []}

    def wrapper():
        jobs[job_id]["status"] = "running"
        try:
            jobs[job_id]["result"] = fn(*args, **kwargs)
            jobs[job_id]["status"] = "completed"
        except Exception as err:
            jobs[job_id]["status"] = "error"
            jobs[job_id]["error"] = str(err)
            jobs[job_id]["trace"] = traceback.format_exc()

    executor.submit(wrapper)
    return job_id
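
# A minimal sketch (not part of the server) of how a client might drive the
# async job API above: submit work, then poll /job/<id> until it settles.
# Assumes the server runs locally on port 7000 and `requests` is installed;
# the payload shape is illustrative only.
#
#     import time
#     import requests
#
#     BASE = "http://localhost:7000"
#     job_id = requests.post(f"{BASE}/train", json={"bits": [[0, 1] * 32]}).json()["job_id"]
#     while True:
#         job = requests.get(f"{BASE}/job/{job_id}").json()
#         if job["status"] in ("completed", "error"):
#             break
#         time.sleep(0.5)
#     print(job["result"] or job["error"])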


@app.errorhandler(Exception)
def handle_exception(err):
    """Return JSON error responses with stack traces."""
    return (
        jsonify({"error": str(err), "trace": traceback.format_exc()}),
        getattr(err, "code", 500),
    )


@app.route("/init", methods=["POST"])
def init_model():
    data = request.json or {}
    int_fields = {
        "d_model",
        "nhead",
        "num_layers",
        "dim_feedforward",
        "max_seq_len",
        "chunk_size",
        "overlap",
    }
    float_fields = {"act_threshold"}
    bool_fields = {"reversible", "use_checkpoint"}
    params = {}
    for k, v in data.items():
        if v is None:
            params[k] = None
        elif k in int_fields:
            params[k] = int(v)
        elif k in float_fields:
            params[k] = float(v)
        elif k in bool_fields:
            params[k] = bool(v)
        else:
            params[k] = v
    manager.init_model(params)
    return jsonify({"status": "initialized", "params": params})
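
# Illustrative /init payload (the field values here are assumptions, not
# documented defaults). Each field is coerced by the loop above: the integer
# fields to int, act_threshold to float, reversible/use_checkpoint to bool.
#
#     import requests
#
#     requests.post("http://localhost:7000/init", json={
#         "d_model": 64,
#         "nhead": 4,
#         "num_layers": 2,
#         "reversible": True,
#     })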


@app.route("/train", methods=["POST"])
def train_model():
    bits = request.json["bits"]

    def task():
        tensor = torch.tensor(bits, dtype=torch.long)
        loss, ratio = manager.train_step(tensor)
        return {"loss": loss, "ratio": ratio}

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})


@app.route("/train_epochs", methods=["POST"])
def train_epochs_route():
    data = request.json
    bits = data["bits"]
    epochs = int(data.get("epochs", 1))
    compress_prob = float(data.get("compress_prob", 0.5))
    direct_prob = float(data.get("direct_prob", 0.0))

    def task():
        tensor = torch.tensor(bits, dtype=torch.long)
        metrics = manager.train_epochs(
            tensor,
            epochs=epochs,
            compress_prob=compress_prob,
            direct_prob=direct_prob,
        )
        return {"metrics": metrics}

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})


@app.route("/scale_up", methods=["POST"])
def scale_up():
    width_mult = float(request.json.get("width_mult", 1.0))

    def task():
        manager.scale_up(width_mult)
        return {
            "status": "scaled",
            "layers": manager.model.num_layers,
            "d_model": manager.model.d_model,
        }

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})


@app.route("/collapse", methods=["POST"])
def collapse_model():
    cluster_bits = request.json["clusters"]
    params = {k: int(v) for k, v in request.json["params"].items()}
    width_scale = float(request.json.get("width_scale", 1.0))

    def task():
        manager.collapse(cluster_bits, params, width_scale)
        return {"status": "collapsed"}

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})


@app.route("/job/<job_id>", methods=["GET"])
def get_job(job_id: str):
    job = jobs.get(job_id)
    if job is None:
        return jsonify({"error": "not found"}), 404
    return jsonify(job)


@app.route("/jobs", methods=["GET"])
def list_jobs():
    return jsonify(jobs)


@app.route("/lambdas", methods=["GET", "POST"])
def update_lambdas():
    if request.method == "POST":
        data = request.json
        manager.set_lambdas(float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"]))
        return jsonify({"status": "updated"})
    return jsonify({
        "lambda_K": manager.lambda_K,
        "lambda_C": manager.lambda_C,
        "lambda_S": manager.lambda_S,
    })


@app.route("/diffusion", methods=["GET", "POST"])
def update_diffusion():
    if request.method == "POST":
        manager.set_diffusion(bool(request.json.get("diffusion", False)))
        return jsonify({"status": "updated"})
    return jsonify({"diffusion": manager.diffusion})


@app.route("/qat", methods=["GET", "POST"])
def update_qat():
    if request.method == "POST":
        manager.set_qat(bool(request.json.get("qat", False)))
        return jsonify({"status": "updated"})
    return jsonify({"qat": manager.qat})


@app.route("/gpu", methods=["GET", "POST"])
def update_gpu():
    if request.method == "POST":
        manager.set_gpu(bool(request.json.get("use_gpu", False)))
        return jsonify({"status": "updated"})
    return jsonify({"use_gpu": manager.use_gpu})


@app.route("/infer", methods=["POST"])
def inference():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    result = manager.infer(bits)
    return jsonify(result)


@app.route("/infer_long", methods=["POST"])
def inference_long():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    ctx = int(request.json.get("ctx_bits", 4096))
    overlap = int(request.json.get("overlap", 256))
    result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap)
    return jsonify(result)


@app.route("/infer_text", methods=["POST"])
def inference_text():
    text = request.json.get("text", "")
    result = manager.infer_text(text)
    return jsonify(result)


@app.route("/status", methods=["GET"])
def status():
    return jsonify(manager.get_status())


@app.route("/model_config", methods=["GET"])
def model_config():
    return jsonify(manager.get_model_config())


@app.route("/metrics", methods=["GET"])
def metrics():
    return jsonify(manager.get_metrics())


@app.route("/save_checkpoint", methods=["POST"])
def save_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if manager.model is None:
        return jsonify({"error": "model not initialized"}), 400
    if token:
        hf_login(token=token)
    save_checkpoint(manager.model, repo_id=repo_id)
    return jsonify({"status": "saved"})


@app.route("/download_checkpoint", methods=["POST"])
def download_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if token:
        hf_login(token=token)
    dest = manager.weights_path + ".gz"
    ok = download_checkpoint(dest, repo_id=repo_id)
    if not ok:
        return jsonify({"status": "failed"}), 500
    if manager.model is None:
        return jsonify({"status": "downloaded", "loaded": False})
    with gzip.open(dest, "rb") as f:
        state = torch.load(f, map_location="cpu")
    manager.model.load_state_dict(state)
    manager.optimizer, manager.scheduler = configure_optimizer(
        manager.model, lr=1e-3, total_steps=manager.total_steps
    )
    manager._apply_device()
    manager._save_state()
    return jsonify({"status": "downloaded", "loaded": True})
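
# For reference, a checkpoint compatible with the loading code above can be
# produced locally like this (a sketch; the actual on-disk format is defined
# by bit_transformer.hf_checkpoint.save_checkpoint, and `model` stands in
# for an initialized BitTransformerLM instance):
#
#     import gzip
#     import torch
#
#     with gzip.open("weights.pt.gz", "wb") as f:
#         torch.save(model.state_dict(), f)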


@app.route("/plot.png")
def plot_png():
    fig, _ = plot_telemetry(manager.metrics)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return send_file(buf, mimetype="image/png")


@app.route("/text_to_bits", methods=["POST"])
def text_to_bits_route():
    text = request.json.get("text", "")
    if len(text) > 100_000:
        return jsonify({"error": "text too large"}), 413
    return jsonify({"bits": text_to_bits(text)})
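
# Sketch of calling /text_to_bits (the response shape follows from the route
# above; the exact bit encoding is whatever bit_transformer.bit_io.text_to_bits
# emits):
#
#     import requests
#
#     bits = requests.post(
#         "http://localhost:7000/text_to_bits", json={"text": "hi"}
#     ).json()["bits"]
#     # `bits` can then be fed back into /train or /infer as the "bits" field.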


@app.route("/dataset", methods=["GET"])
def dataset_route():
    name = request.args.get("name", "")
    split = request.args.get("split", "train")
    size = int(request.args.get("size", 1))
    seq_len = int(request.args.get("seq_len", 64))
    if size * seq_len > 1_000_000:
        return jsonify({"error": "dataset too large"}), 413
    if name == "wikitext2":
        try:
            from datasets import load_dataset

            ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
            lines = [t for t in ds["text"] if t.strip()][:size]
        except Exception:
            # Fall back to random bits if the dataset cannot be fetched.
            bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
            return jsonify({"bits": bits.tolist()})
        bits_list = []
        for text in lines:
            b = text_to_bits(text)[:seq_len]
            if len(b) < seq_len:
                b.extend([0] * (seq_len - len(b)))
            bits_list.append(b)
        if len(bits_list) < size:
            # Pad out short datasets with random rows.
            pad = size - len(bits_list)
            bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist())
        return jsonify({"bits": bits_list})
    return jsonify({"error": "unknown dataset"}), 400


@app.route("/dataset/create", methods=["POST"])
def create_dataset():
    """Create and upload a new BitTransformerLM dataset."""
    data = request.json or {}

    hf_token = data.get("hf_token") or os.getenv("HF_TOKEN")
    repo_id = data.get("repo_id", "BitTransformerLM")
    source_texts = data.get("source_texts")

    if not hf_token:
        return jsonify({"error": "HF token required"}), 400

    def task():
        # Errors propagate to _submit_job, which records them on the job.
        dataset_url = create_bittransformerlm_dataset(
            hf_token=hf_token,
            repo_id=repo_id,
            source_texts=source_texts,
        )
        return {
            "status": "success",
            "dataset_url": dataset_url,
            "repo_id": repo_id,
        }

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id, "message": "Dataset creation started"})


@app.route("/dataset/builder", methods=["POST"])
def create_dataset_builder():
    """Initialize a dataset builder for custom dataset creation."""
    data = request.json or {}

    hf_token = data.get("hf_token") or os.getenv("HF_TOKEN")
    repo_id = data.get("repo_id", "BitTransformerLM")

    if not hf_token:
        return jsonify({"error": "HF token required"}), 400

    try:
        builder = BitTransformerDatasetBuilder(hf_token, repo_id)
        builder_info = {
            "repo_id": repo_id,
            "config": builder.config,
            "status": "ready",
        }
        return jsonify({
            "status": "builder_created",
            "builder_info": builder_info,
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/dataset/generate", methods=["POST"])
def generate_dataset_samples():
    """Generate specific types of dataset samples."""
    data = request.json or {}

    sample_type = data.get("type", "text_to_bits")
    count = int(data.get("count", 100))
    max_len = int(data.get("max_len", 256))
    texts = data.get("texts")

    if count > 5000:
        return jsonify({"error": "count too large, max 5000"}), 400

    def task():
        # Placeholder credentials: generation is local and does not touch the Hub.
        builder = BitTransformerDatasetBuilder("dummy_token", "temp")
        # Copy to a local name; rebinding `texts` here would shadow the
        # closure variable and raise UnboundLocalError.
        source_texts = texts

        if sample_type == "text_to_bits":
            if not source_texts:
                source_texts = builder._get_default_texts()[:count]
            samples = builder.generate_text_to_bits_data(source_texts[:count], max_len)
        elif sample_type == "synthetic":
            samples = builder.generate_synthetic_patterns(count, max_len)
        elif sample_type == "safety":
            samples = builder.generate_safety_benchmarks(count)
        elif sample_type == "compression":
            # Derive compressed variants from a small base set of texts.
            base_texts = builder._get_default_texts()[:50]
            base_samples = builder.generate_text_to_bits_data(base_texts, max_len)
            samples = builder.generate_compression_variants(base_samples)[:count]
        else:
            raise ValueError(f"Unknown sample type: {sample_type}")

        return {
            "status": "success",
            "samples": samples[:10],  # return a preview, not the full set
            "total_generated": len(samples),
            "sample_type": sample_type,
        }

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id, "message": f"Generating {sample_type} samples"})


@app.route("/dataset/info", methods=["GET"])
def dataset_info():
    """Get information about available dataset generation options."""
    return jsonify({
        "sample_types": [
            {
                "type": "text_to_bits",
                "description": "Convert text to parity-protected bit sequences",
                "parameters": ["texts", "max_len"],
            },
            {
                "type": "synthetic",
                "description": "Generate synthetic bit patterns",
                "parameters": ["count", "max_len"],
                "patterns": ["alternating", "blocks", "fibonacci", "prime_based", "random_walk"],
            },
            {
                "type": "safety",
                "description": "Generate safety benchmark sequences",
                "parameters": ["count"],
                "categories": ["low_entropy", "medium_entropy", "high_entropy", "edge_cases"],
            },
            {
                "type": "compression",
                "description": "Generate compressed variants of base sequences",
                "parameters": ["count", "compression_ratios"],
            },
        ],
        "default_config": {
            "max_sequence_length": 512,
            "total_samples": 25000,
            "safety_thresholds": {
                "min_negentropy": 0.1,
                "max_lz_complexity": 0.9,
                "min_symbiosis": 0.3,
            },
        },
    })


@app.route("/health")
def health_check():
    return jsonify({"status": "ok"})


def run_mcp_server(host: str = "0.0.0.0", port: int = 7000) -> None:
    app.run(host=host, port=port, debug=True)
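
# Note: app.run(debug=True) is Flask's development server and is intended for
# local experimentation only. A minimal sketch of launching from Python
# (the module name `mcp_server` is an assumption about how this file is saved):
#
#     from mcp_server import run_mcp_server
#
#     run_mcp_server(host="127.0.0.1", port=7000)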


if __name__ == "__main__":
    run_mcp_server()