import os
import shutil

# --- SEARCH AND DESTROY POISONED CACHE ---
corrupted_dir = "/root/.cache/huggingface/hub/models--google--umt5-base"

if os.path.exists(corrupted_dir):
    print("[SYSTEM] Found corrupted UMT5 cache. Deleting...")
    shutil.rmtree(corrupted_dir, ignore_errors=True)
else:
    print("[SYSTEM] Cache is clean.")

# --- YOUR ORIGINAL CODE STARTS HERE ---
from flask import Flask, request, jsonify, send_from_directory  # Added send_from_directory
from flask_sock import Sock
from transformers import AutoModel
import torch
import time
import json
from flask_cors import CORS

app = Flask(__name__)
CORS(app)
sock = Sock(app)  # Initialize WebSocket support

print("[SYSTEM] Booting up Network Server...")
print("[SYSTEM] Loading FloodDiffusionTiny model from Hugging Face...")

# 1. Load the model
model = AutoModel.from_pretrained(
    "ShandaAI/FloodDiffusionTiny",
    trust_remote_code=True
)

# 2. Cloud Architecture Override
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print(f"[SYSTEM] Model loaded successfully onto device: {device}")

@app.route('/')
def serve_ui():
    # This tells Flask to send the index.html file to the user's browser
    return send_from_directory('.', 'index.html')

# --- ADD THIS NEW ROUTE HERE ---
# Note: this must be a catch-all path rule, not a second '/' route --
# Flask would otherwise raise an error for the duplicate endpoint.
@app.route('/<path:filename>')
def serve_static_files(filename):
    # This allows Flask to send the smpl.glb file when the browser asks for it!
    return send_from_directory('.', filename)

# --- THE NEW WEBSOCKET PIPELINE ---
@sock.route('/api/generate_stream')
def stream_motion(ws):
    print("\n[NETWORK] 🟢 WebSocket Connection Opened! Client connected.")

    # Keep the connection open forever
    while True:
        try:
            # 1. Wait for the live prompt from the client's text box
            raw_data = ws.receive()
            if raw_data is None:
                continue

            data = json.loads(raw_data)
            text_prompt = data.get('prompt', '')
            ticket_number = data.get('ticket', 0)

            print(f"[NETWORK] Live Prompt Received: '{text_prompt}'")
            start_time = time.time()

            # 2. Server Processing (Inference). The call signature is defined
            #    by the model's custom remote code; no_grad skips building a
            #    gradient graph we never use, saving memory per request.
            with torch.no_grad():
                motion_joints = model(text_prompt, length=150, output_joints=True)

            processing_time = (time.time() - start_time) * 1000

            # 3. Format Network Payload
            payload = {
                "status": "success",
                "ticket": ticket_number,
                "latency_ms": round(processing_time, 2),
                "tensor_shape": list(motion_joints.shape),
                "data": motion_joints.tolist()
            }

            # 4. Push data back through the pipe instantly!
            ws.send(json.dumps(payload))
            print(f"[NETWORK] ⚡ Streamed motion tensor {list(motion_joints.shape)} to client in {processing_time:.2f}ms")

        except Exception as e:
            print(f"[NETWORK] 🔴 WebSocket Error or Disconnect: {e}")
            break

if __name__ == '__main__':
    # --- PORT 7860 FOR HUGGING FACE ---
    app.run(host='0.0.0.0', port=7860, debug=False)
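
For a quick end-to-end check, the sketch below is one way to exercise the /api/generate_stream endpoint from a separate script. It assumes the server above is running locally on port 7860 and uses the third-party `websockets` package (pip install websockets), which is not part of the app itself; the prompt text and ticket value are placeholders. The JSON keys (prompt, ticket, tensor_shape, latency_ms) mirror what the server reads and writes.

# test_client.py -- minimal smoke test for the WebSocket pipeline (separate
# file, not part of app.py; filename is just a suggestion)
import asyncio
import json

import websockets  # third-party package, assumed installed

async def main():
    uri = "ws://localhost:7860/api/generate_stream"  # assumes a local server
    async with websockets.connect(uri) as ws:
        # Send a prompt in the same JSON shape the server expects
        await ws.send(json.dumps({"prompt": "a person walks forward", "ticket": 1}))

        # Block until the server pushes the motion payload back
        payload = json.loads(await ws.recv())
        print(f"Ticket {payload['ticket']}: tensor {payload['tensor_shape']} "
              f"generated in {payload['latency_ms']}ms")

asyncio.run(main())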