| | import os |
| | import spaces |
| | from huggingface_hub import HfApi, hf_hub_download |
| | from apscheduler.schedulers.background import BackgroundScheduler |
| | from concurrent.futures import ThreadPoolExecutor |
| | from datetime import datetime |
| | import threading |
| | from sqlalchemy import or_ |
| |
|
# Capture "now" exactly once so year/month can never disagree if the two
# reads straddle a month or year rollover (the original called
# datetime.now() twice).
_now = datetime.now()
year = _now.year
month = _now.month
| |
|
| | |
# Detect whether we are running inside a Hugging Face Space: the platform
# injects SPACE_REPO_NAME into the environment.
IS_SPACES = bool(os.getenv("SPACE_REPO_NAME"))
if IS_SPACES:
    print("Running in a Hugging Face Space 🤗")
| |
|
| | |
# Bootstrap the SQLite database: when no local copy exists, pull the latest
# snapshot from the HF dataset. A failed download is non-fatal — the app
# simply starts without the pre-seeded database.
if not os.path.exists("instance/tts_arena.db"):
    os.makedirs("instance", exist_ok=True)
    try:
        print("Database not found, downloading from HF dataset...")
        hf_hub_download(
            repo_id="kemuriririn/database-arena",
            filename="tts_arena.db",
            repo_type="dataset",
            local_dir="instance",
            token=os.getenv("HF_TOKEN"),
        )
        print("Database downloaded successfully ✅")
    except Exception as e:
        print(f"Error downloading database from HF dataset: {str(e)} ⚠️")
| |
|
| | from flask import ( |
| | Flask, |
| | render_template, |
| | g, |
| | request, |
| | jsonify, |
| | send_file, |
| | redirect, |
| | url_for, |
| | session, |
| | abort, |
| | ) |
| | from flask_login import LoginManager, current_user |
| | from models import * |
| | from auth import auth, init_oauth, is_admin |
| | from admin import admin |
| | import os |
| | from dotenv import load_dotenv |
| | from flask_limiter import Limiter |
| | from flask_limiter.util import get_remote_address |
| | import uuid |
| | import tempfile |
| | import shutil |
| | from tts import predict_tts |
| | import random |
| | import json |
| | from datetime import datetime, timedelta |
| | from flask_migrate import Migrate |
| | import requests |
| | import functools |
| | import time |
| |
|
| |
|
| | |
# Load a local .env only outside Spaces; the Spaces platform injects real
# environment variables instead.
if not IS_SPACES:
    load_dotenv()

app = Flask(__name__)
# NOTE(review): the os.urandom(24) fallback means sessions are invalidated
# on every restart when SECRET_KEY is unset — confirm that is acceptable.
app.config["SECRET_KEY"] = os.getenv("SECRET_KEY", os.urandom(24))
app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv(
    "DATABASE_URI", "sqlite:///tts_arena.db"
)
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
# Cookies are HTTPS-only; SameSite=None is needed for the cross-site iframe
# context of a HF Space, Lax is the safer default elsewhere.
app.config["SESSION_COOKIE_SECURE"] = True
app.config["SESSION_COOKIE_SAMESITE"] = (
    "None" if IS_SPACES else "Lax"
)
app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(days=30)

# Spaces terminates TLS at its proxy; make url_for(_external=True) emit https.
if IS_SPACES:
    app.config["PREFERRED_URL_SCHEME"] = "https"

# Cloudflare Turnstile (bot protection) configuration, all env-driven.
app.config["TURNSTILE_ENABLED"] = (
    os.getenv("TURNSTILE_ENABLED", "False").lower() == "true"
)
app.config["TURNSTILE_SITE_KEY"] = os.getenv("TURNSTILE_SITE_KEY", "")
app.config["TURNSTILE_SECRET_KEY"] = os.getenv("TURNSTILE_SECRET_KEY", "")
app.config["TURNSTILE_VERIFY_URL"] = (
    "https://challenges.cloudflare.com/turnstile/v0/siteverify"
)
| |
|
# Schema migrations (Flask-Migrate / Alembic).
migrate = Migrate(app, db)

# Bind the shared SQLAlchemy instance (declared in models) and wire up auth.
db.init_app(app)
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = "auth.login"

# OAuth provider registration (implemented in the auth module).
init_oauth(app)

# Global rate limiting keyed on the client IP.
# NOTE(review): memory:// storage is per-process — limits are not shared
# across multiple workers; confirm single-process deployment.
limiter = Limiter(
    app=app,
    key_func=get_remote_address,
    default_limits=["2000 per day", "50 per minute"],
    storage_uri="memory://",
)
| |
|
| | |
# --- TTS pre-generation cache configuration ---
# Target number of sentence entries kept ready for instant cache hits.
TTS_CACHE_SIZE = int(os.getenv("TTS_CACHE_SIZE", "10"))
CACHE_AUDIO_SUBDIR = "cache"
tts_cache = {}  # sentence -> generated audio pair; always guarded by tts_cache_lock
tts_cache_lock = threading.Lock()
# NOTE(review): consumer of this constant is not visible in this file —
# presumably used by get_weighted_random_models; confirm.
SMOOTHING_FACTOR_MODEL_SELECTION = 500

# Background pool that refills tts_cache after hits.
cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
all_harvard_sentences = []  # populated from sentences.txt further below

# Temp directories for generated audio; cache entries live in a subdirectory.
TEMP_AUDIO_DIR = os.path.join(tempfile.gettempdir(), "tts_arena_audio")
CACHE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, CACHE_AUDIO_SUBDIR)
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
os.makedirs(CACHE_AUDIO_DIR, exist_ok=True)

# Reference (voice) audio clips mirrored from an HF dataset.
REFERENCE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, "reference_audios")
REFERENCE_AUDIO_DATASET = os.getenv("REFERENCE_AUDIO_DATASET", "kemuriririn/arena-files")
REFERENCE_AUDIO_PATTERN = os.getenv("REFERENCE_AUDIO_PATTERN", "reference_audios/")
reference_audio_files = []  # local paths filled by download_reference_audios()
| |
|
def download_reference_audios():
    """Mirror reference audio clips from the HF dataset into a local directory.

    Downloads every ``.wav`` under ``REFERENCE_AUDIO_PATTERN`` in
    ``REFERENCE_AUDIO_DATASET`` and replaces the module-level
    ``reference_audio_files`` list with the resulting local paths.
    On any error the list is reset to empty.

    Fix vs. original: results are collected in a local list and assigned to
    the global once, so repeated calls refresh the list instead of appending
    duplicates, and a mid-loop failure cannot leave partial state behind.
    """
    global reference_audio_files
    os.makedirs(REFERENCE_AUDIO_DIR, exist_ok=True)
    downloaded = []
    try:
        api = HfApi(token=os.getenv("HF_TOKEN"))
        files = api.list_repo_files(repo_id=REFERENCE_AUDIO_DATASET, repo_type="dataset")
        # Keep only the .wav files under the configured prefix.
        wav_files = [f for f in files if f.startswith(REFERENCE_AUDIO_PATTERN) and f.endswith(".wav")]
        for f in wav_files:
            local_path = hf_hub_download(
                repo_id=REFERENCE_AUDIO_DATASET,
                filename=f,
                repo_type="dataset",
                local_dir=REFERENCE_AUDIO_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            downloaded.append(local_path)
        # Single atomic assignment: idempotent across repeated calls.
        reference_audio_files = downloaded
        print(f"Downloaded {len(reference_audio_files)} reference audios.")
    except Exception as e:
        print(f"Error downloading reference audios: {e}")
        reference_audio_files = []
| |
|
| | |
# In-memory stores of active comparison sessions (session_id -> session dict).
app.tts_sessions = {}
tts_sessions = app.tts_sessions  # module-level alias

app.conversational_sessions = {}
conversational_sessions = app.conversational_sessions  # module-level alias

# Blueprints: OAuth/login endpoints under /auth, admin panel on its own routes.
app.register_blueprint(auth, url_prefix="/auth")
app.register_blueprint(admin)
| |
|
| |
|
@login_manager.user_loader
def load_user(user_id):
    """Flask-Login user loader: resolve the session-stored id to a User row.

    Returns None (treated as anonymous) when the stored id is not a valid
    integer, instead of raising ValueError/TypeError and breaking every
    request carrying a stale or corrupted session cookie.
    """
    try:
        return User.query.get(int(user_id))
    except (TypeError, ValueError):
        return None
| |
|
| |
|
@app.before_request
def before_request():
    """Per-request gate: expose user info on g, force HTTPS on Spaces, and
    enforce Cloudflare Turnstile verification (with expiry) when enabled.
    """
    g.user = current_user
    g.is_admin = is_admin(current_user)

    # Behind the Spaces proxy, redirect plain-HTTP requests to HTTPS.
    if IS_SPACES and request.headers.get("X-Forwarded-Proto") == "http":
        url = request.url.replace("http://", "https://", 1)
        return redirect(url, code=301)

    if app.config["TURNSTILE_ENABLED"]:
        # Routes that must stay reachable without verification.
        excluded_routes = ["verify_turnstile", "turnstile_page", "static"]
        if request.endpoint not in excluded_routes:
            if not session.get("turnstile_verified"):
                # Not yet verified: remember where the user wanted to go.
                redirect_url = request.url
                if IS_SPACES and redirect_url.startswith("http://"):
                    redirect_url = redirect_url.replace("http://", "https://", 1)

                # API clients get a JSON 403; browsers get the challenge page.
                if request.path.startswith("/api/"):
                    return jsonify({"error": "Turnstile verification required"}), 403
                return redirect(url_for("turnstile_page", redirect_url=redirect_url))
            else:
                # Already verified: enforce the configured expiry window.
                verification_timeout = (
                    int(os.getenv("TURNSTILE_TIMEOUT_HOURS", "24")) * 3600
                )
                verified_at = session.get("turnstile_verified_at", 0)
                current_time = datetime.utcnow().timestamp()

                if current_time - verified_at > verification_timeout:
                    # Expired: drop the flags and re-challenge.
                    session.pop("turnstile_verified", None)
                    session.pop("turnstile_verified_at", None)

                    redirect_url = request.url
                    if IS_SPACES and redirect_url.startswith("http://"):
                        redirect_url = redirect_url.replace("http://", "https://", 1)

                    if request.path.startswith("/api/"):
                        return jsonify({"error": "Turnstile verification expired"}), 403
                    return redirect(
                        url_for("turnstile_page", redirect_url=redirect_url)
                    )
| |
|
| |
|
| | @app.route("/turnstile", methods=["GET"]) |
| | def turnstile_page(): |
| | """Display Cloudflare Turnstile verification page""" |
| | redirect_url = request.args.get("redirect_url", url_for("arena", _external=True)) |
| |
|
| | |
| | if IS_SPACES and redirect_url.startswith("http://"): |
| | redirect_url = redirect_url.replace("http://", "https://", 1) |
| |
|
| | return render_template( |
| | "turnstile.html", |
| | turnstile_site_key=app.config["TURNSTILE_SITE_KEY"], |
| | redirect_url=redirect_url, |
| | ) |
| |
|
| |
|
| | @app.route("/verify-turnstile", methods=["POST"]) |
| | def verify_turnstile(): |
| | """Verify Cloudflare Turnstile token""" |
| | token = request.form.get("cf-turnstile-response") |
| | redirect_url = request.form.get("redirect_url", url_for("arena", _external=True)) |
| |
|
| | |
| | if IS_SPACES and redirect_url.startswith("http://"): |
| | redirect_url = redirect_url.replace("http://", "https://", 1) |
| |
|
| | if not token: |
| | |
| | if request.headers.get("X-Requested-With") == "XMLHttpRequest": |
| | return ( |
| | jsonify({"success": False, "error": "Missing verification token"}), |
| | 400, |
| | ) |
| | |
| | return redirect(url_for("turnstile_page", redirect_url=redirect_url)) |
| |
|
| | |
| | data = { |
| | "secret": app.config["TURNSTILE_SECRET_KEY"], |
| | "response": token, |
| | "remoteip": request.remote_addr, |
| | } |
| |
|
| | try: |
| | response = requests.post(app.config["TURNSTILE_VERIFY_URL"], data=data) |
| | result = response.json() |
| |
|
| | if result.get("success"): |
| | |
| | session["turnstile_verified"] = True |
| | session["turnstile_verified_at"] = datetime.utcnow().timestamp() |
| |
|
| | |
| | is_xhr = request.headers.get("X-Requested-With") == "XMLHttpRequest" |
| | accepts_json = "application/json" in request.headers.get("Accept", "") |
| |
|
| | |
| | if is_xhr or accepts_json: |
| | return jsonify({"success": True, "redirect": redirect_url}) |
| |
|
| | |
| | return redirect(redirect_url) |
| | else: |
| | |
| | app.logger.warning(f"Turnstile verification failed: {result}") |
| |
|
| | |
| | if request.headers.get("X-Requested-With") == "XMLHttpRequest": |
| | return jsonify({"success": False, "error": "Verification failed"}), 403 |
| |
|
| | |
| | return redirect(url_for("turnstile_page", redirect_url=redirect_url)) |
| |
|
| | except Exception as e: |
| | app.logger.error(f"Turnstile verification error: {str(e)}") |
| |
|
| | |
| | if request.headers.get("X-Requested-With") == "XMLHttpRequest": |
| | return ( |
| | jsonify( |
| | {"success": False, "error": "Server error during verification"} |
| | ), |
| | 500, |
| | ) |
| |
|
| | |
| | return redirect(url_for("turnstile_page", redirect_url=redirect_url)) |
| |
|
| | with open("sentences.txt", "r") as f, open("emotional_sentences.txt", "r") as f_emotional: |
| | |
| | all_harvard_sentences = [line.strip() for line in f.readlines() if line.strip()] + [line.strip() for line in f_emotional.readlines() if line.strip()] |
| | |
| | initial_sentences = random.sample(all_harvard_sentences, min(len(all_harvard_sentences), 500)) |
| |
|
| | @app.route("/") |
| | def arena(): |
| | |
| | return render_template("arena.html", harvard_sentences=json.dumps(initial_sentences)) |
| |
|
| |
|
| | @app.route("/leaderboard") |
| | def leaderboard(): |
| | tts_leaderboard = get_leaderboard_data(ModelType.TTS) |
| | conversational_leaderboard = get_leaderboard_data(ModelType.CONVERSATIONAL) |
| | top_voters = get_top_voters(10) |
| |
|
| | |
| | tts_personal_leaderboard = None |
| | conversational_personal_leaderboard = None |
| | user_leaderboard_visibility = None |
| |
|
| | |
| | if current_user.is_authenticated: |
| | tts_personal_leaderboard = get_user_leaderboard(current_user.id, ModelType.TTS) |
| | conversational_personal_leaderboard = get_user_leaderboard( |
| | current_user.id, ModelType.CONVERSATIONAL |
| | ) |
| | user_leaderboard_visibility = current_user.show_in_leaderboard |
| |
|
| | |
| | tts_key_dates = get_key_historical_dates(ModelType.TTS) |
| | conversational_key_dates = get_key_historical_dates(ModelType.CONVERSATIONAL) |
| |
|
| | |
| | formatted_tts_dates = [date.strftime("%B %Y") for date in tts_key_dates] |
| | formatted_conversational_dates = [ |
| | date.strftime("%B %Y") for date in conversational_key_dates |
| | ] |
| |
|
| | return render_template( |
| | "leaderboard.html", |
| | tts_leaderboard=tts_leaderboard, |
| | conversational_leaderboard=conversational_leaderboard, |
| | tts_personal_leaderboard=tts_personal_leaderboard, |
| | conversational_personal_leaderboard=conversational_personal_leaderboard, |
| | tts_key_dates=tts_key_dates, |
| | conversational_key_dates=conversational_key_dates, |
| | formatted_tts_dates=formatted_tts_dates, |
| | formatted_conversational_dates=formatted_conversational_dates, |
| | top_voters=top_voters, |
| | user_leaderboard_visibility=user_leaderboard_visibility |
| | ) |
| |
|
| |
|
| | @app.route("/api/historical-leaderboard/<model_type>") |
| | def historical_leaderboard(model_type): |
| | """Get historical leaderboard data for a specific date""" |
| | if model_type not in [ModelType.TTS, ModelType.CONVERSATIONAL]: |
| | return jsonify({"error": "Invalid model type"}), 400 |
| |
|
| | |
| | date_str = request.args.get("date") |
| | if not date_str: |
| | return jsonify({"error": "Date parameter is required"}), 400 |
| |
|
| | try: |
| | |
| | target_date = datetime.strptime(date_str, "%Y-%m-%d") |
| |
|
| | |
| | leaderboard_data = get_historical_leaderboard_data(model_type, target_date) |
| |
|
| | return jsonify( |
| | {"date": target_date.strftime("%B %d, %Y"), "leaderboard": leaderboard_data} |
| | ) |
| | except ValueError: |
| | return jsonify({"error": "Invalid date format. Use YYYY-MM-DD"}), 400 |
| |
|
| |
|
| | @app.route("/about") |
| | def about(): |
| | return render_template("about.html") |
| |
|
| |
|
| | |
| |
|
def generate_and_save_tts(text, model_id, output_dir):
    """Generate TTS audio for *text* with *model_id* and move it into *output_dir*.

    Picks a random reference voice when any are available. Returns a tuple
    ``(saved_path, reference_audio_path)`` on success, ``(None, None)`` on
    any failure (the partially produced file is cleaned up best-effort).
    """
    produced_path = None
    try:
        app.logger.debug(f"[TTS Gen {model_id}] Starting generation for: '{text[:30]}...'")

        # Optionally condition the model on a random reference voice.
        ref_path = random.choice(reference_audio_files) if reference_audio_files else None
        if ref_path is not None:
            app.logger.debug(f"[TTS Gen {model_id}] Using reference audio: {ref_path}")

        produced_path = predict_tts(text, model_id, reference_audio_path=ref_path)
        app.logger.debug(f"[TTS Gen {model_id}] predict_tts returned: {produced_path}")

        if not (produced_path and os.path.exists(produced_path)):
            app.logger.warning(f"[TTS Gen {model_id}] predict_tts failed or returned invalid path: {produced_path}")
            raise ValueError("predict_tts did not return a valid path or file does not exist")

        final_path = os.path.join(output_dir, f"{uuid.uuid4()}.wav")
        app.logger.debug(f"[TTS Gen {model_id}] Moving {produced_path} to {final_path}")
        shutil.move(produced_path, final_path)
        app.logger.debug(f"[TTS Gen {model_id}] Move successful. Returning {final_path}")
        return final_path, ref_path

    except Exception as e:
        app.logger.error(f"Error generating/saving TTS for model {model_id} and text '{text[:30]}...': {str(e)}")
        # Best-effort cleanup of the un-moved temporary file.
        if produced_path and os.path.exists(produced_path):
            try:
                app.logger.debug(f"[TTS Gen {model_id}] Cleaning up temporary file {produced_path} after error.")
                os.remove(produced_path)
            except OSError:
                pass
        return None, None
| |
|
| |
|
def _generate_cache_entry_task(sentence):
    """Task function to generate audio for a sentence and add to cache.

    Runs on the cache_executor pool inside an app context. When *sentence*
    is None, a random not-yet-cached sentence is chosen. Generates audio
    with two distinct weighted-random TTS models concurrently, then inserts
    the pair into tts_cache if there is still room.
    """
    with app.app_context():
        if not sentence:
            # Snapshot cached keys under the lock, then pick an uncached sentence.
            with tts_cache_lock:
                cached_keys = set(tts_cache.keys())
            available_sentences = [s for s in all_harvard_sentences if s not in cached_keys]
            if not available_sentences:
                app.logger.warning("No more unique Harvard sentences available for caching.")
                return
            sentence = random.choice(available_sentences)

        print(f"[Cache Task] Querying models for: '{sentence[:50]}...'")
        available_models = Model.query.filter_by(
            model_type=ModelType.TTS, is_active=True
        ).all()

        if len(available_models) < 2:
            app.logger.error("Not enough active TTS models to generate cache entry.")
            return

        try:
            models = get_weighted_random_models(available_models, 2, ModelType.TTS)
            model_a_id = models[0].id
            model_b_id = models[1].id

            # Generate both sides concurrently on a dedicated two-thread pool.
            with ThreadPoolExecutor(max_workers=2, thread_name_prefix='AudioGen') as audio_executor:
                future_a = audio_executor.submit(generate_and_save_tts, sentence, model_a_id, CACHE_AUDIO_DIR)
                future_b = audio_executor.submit(generate_and_save_tts, sentence, model_b_id, CACHE_AUDIO_DIR)

                timeout_seconds = 120
                # NOTE(review): if future_a succeeds but future_b raises or
                # times out, the except handler below never deletes audio_a's
                # file — a disk leak. Confirm and fix separately.
                audio_a_path, ref_a = future_a.result(timeout=timeout_seconds)
                audio_b_path, ref_b = future_b.result(timeout=timeout_seconds)

            if audio_a_path and audio_b_path:
                with tts_cache_lock:
                    # Re-check under the lock: another task may have cached
                    # the same sentence, or the cache may have filled up.
                    if sentence not in tts_cache and len(tts_cache) < TTS_CACHE_SIZE:
                        tts_cache[sentence] = {
                            "model_a": model_a_id,
                            "model_b": model_b_id,
                            "audio_a": audio_a_path,
                            "audio_b": audio_b_path,
                            "ref_a": ref_a,
                            "ref_b": ref_b,
                            "created_at": datetime.utcnow(),
                        }
                        app.logger.info(f"Successfully cached entry for: '{sentence[:50]}...'")
                    elif sentence in tts_cache:
                        # Lost the race: discard our freshly generated pair.
                        app.logger.warning(f"Sentence '{sentence[:50]}...' already re-cached. Discarding new generation.")
                        if os.path.exists(audio_a_path): os.remove(audio_a_path)
                        if os.path.exists(audio_b_path): os.remove(audio_b_path)
                    else:
                        # Cache filled while we were generating: discard.
                        app.logger.warning(f"Cache is full ({len(tts_cache)} entries). Discarding new generation for '{sentence[:50]}...'.")
                        if os.path.exists(audio_a_path): os.remove(audio_a_path)
                        if os.path.exists(audio_b_path): os.remove(audio_b_path)

            else:
                # One or both generations failed; clean up whatever exists.
                app.logger.error(f"Failed to generate one or both audio files for cache: '{sentence[:50]}...'")
                if audio_a_path and os.path.exists(audio_a_path): os.remove(audio_a_path)
                if audio_b_path and os.path.exists(audio_b_path): os.remove(audio_b_path)

        except Exception as e:
            app.logger.error(f"Exception in _generate_cache_entry_task for '{sentence[:50]}...': {str(e)}", exc_info=True)
| |
|
| |
|
def initialize_tts_cache():
    """Pick an initial batch of sentences and queue cache-generation tasks."""
    print("Initializing TTS cache")
    with app.app_context():
        if not all_harvard_sentences:
            app.logger.error("Harvard sentences not loaded. Cannot initialize cache.")
            return

        batch_size = min(len(all_harvard_sentences), TTS_CACHE_SIZE)
        seed_sentences = random.sample(all_harvard_sentences, batch_size)
        app.logger.info(f"Initializing TTS cache with {len(seed_sentences)} sentences...")

        # Fan the work out to the shared cache pool; tasks run concurrently.
        for seed in seed_sentences:
            cache_executor.submit(_generate_cache_entry_task, seed)
        app.logger.info("Submitted initial cache generation tasks.")
| |
|
| | |
| |
|
| |
|
| | @app.route("/api/tts/generate", methods=["POST"]) |
| | @limiter.limit("10 per minute") |
| | @spaces.GPU |
| | def generate_tts(): |
| | |
| | if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"): |
| | return jsonify({"error": "Turnstile verification required"}), 403 |
| |
|
| | |
| | if request.content_type and request.content_type.startswith('multipart/form-data'): |
| | text = request.form.get("text", "").strip() |
| | voice_file = request.files.get("voice_file") |
| | reference_audio_path = None |
| | if voice_file: |
| | temp_voice_path = os.path.join(TEMP_AUDIO_DIR, f"ref_{uuid.uuid4()}.wav") |
| | voice_file.save(temp_voice_path) |
| | reference_audio_path = temp_voice_path |
| | else: |
| | data = request.json |
| | text = data.get("text", "").strip() |
| | reference_audio_path = None |
| |
|
| | if not text or len(text) > 1000: |
| | return jsonify({"error": "Invalid or too long text"}), 400 |
| |
|
| | |
| | cache_hit = False |
| | session_data_from_cache = None |
| | with tts_cache_lock: |
| | if text in tts_cache: |
| | cache_hit = True |
| | cached_entry = tts_cache.pop(text) |
| | app.logger.info(f"TTS Cache HIT for: '{text[:50]}...'") |
| |
|
| | |
| | session_id = str(uuid.uuid4()) |
| | session_data_from_cache = { |
| | "model_a": cached_entry["model_a"], |
| | "model_b": cached_entry["model_b"], |
| | "audio_a": cached_entry["audio_a"], |
| | "audio_b": cached_entry["audio_b"], |
| | "text": text, |
| | "created_at": datetime.utcnow(), |
| | "expires_at": datetime.utcnow() + timedelta(minutes=30), |
| | "voted": False, |
| | } |
| | app.tts_sessions[session_id] = session_data_from_cache |
| |
|
| | |
| | |
| | current_cache_size = len(tts_cache) |
| | needed_refills = TTS_CACHE_SIZE - current_cache_size |
| | |
| | refills_to_submit = min(needed_refills, 8) |
| |
|
| | if refills_to_submit > 0: |
| | app.logger.info(f"Cache hit: Submitting {refills_to_submit} background task(s) to refill cache (current size: {current_cache_size}, target: {TTS_CACHE_SIZE}).") |
| | for _ in range(refills_to_submit): |
| | |
| | cache_executor.submit(_generate_cache_entry_task, None) |
| | else: |
| | app.logger.info(f"Cache hit: Cache is already full or at target size ({current_cache_size}/{TTS_CACHE_SIZE}). No refill tasks submitted.") |
| | |
| |
|
| | if cache_hit and session_data_from_cache: |
| | |
| | |
| | return jsonify( |
| | { |
| | "session_id": session_id, |
| | "audio_a": f"/api/tts/audio/{session_id}/a", |
| | "audio_b": f"/api/tts/audio/{session_id}/b", |
| | "expires_in": 1800, |
| | "cache_hit": True, |
| | } |
| | ) |
| | |
| |
|
| | |
| | app.logger.info(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.") |
| | available_models = Model.query.filter_by( |
| | model_type=ModelType.TTS, is_active=True |
| | ).all() |
| | if len(available_models) < 2: |
| | return jsonify({"error": "Not enough TTS models available"}), 500 |
| |
|
| | selected_models = get_weighted_random_models(available_models, 2, ModelType.TTS) |
| |
|
| | try: |
| | audio_files = [] |
| | model_ids = [] |
| |
|
| | |
| |
|
| | def process_model_on_the_fly(model): |
| | |
| | temp_audio_path = predict_tts(text, model.id, reference_audio_path=reference_audio_path) |
| | if not temp_audio_path or not os.path.exists(temp_audio_path): |
| | raise ValueError(f"predict_tts failed for model {model.id}") |
| |
|
| | |
| | file_uuid = str(uuid.uuid4()) |
| | dest_path = os.path.join(TEMP_AUDIO_DIR, f"{file_uuid}.wav") |
| | shutil.move(temp_audio_path, dest_path) |
| |
|
| | return {"model_id": model.id, "audio_path": dest_path} |
| |
|
| | |
| | with ThreadPoolExecutor(max_workers=2) as executor: |
| | results = list(executor.map(process_model_on_the_fly, selected_models)) |
| |
|
| | |
| | for result in results: |
| | model_ids.append(result["model_id"]) |
| | audio_files.append(result["audio_path"]) |
| |
|
| | |
| | session_id = str(uuid.uuid4()) |
| | app.tts_sessions[session_id] = { |
| | "model_a": model_ids[0], |
| | "model_b": model_ids[1], |
| | "audio_a": audio_files[0], |
| | "audio_b": audio_files[1], |
| | "text": text, |
| | "created_at": datetime.utcnow(), |
| | "expires_at": datetime.utcnow() + timedelta(minutes=30), |
| | "voted": False, |
| | } |
| |
|
| | |
| | if reference_audio_path and os.path.exists(reference_audio_path): |
| | os.remove(reference_audio_path) |
| |
|
| | |
| | return jsonify( |
| | { |
| | "session_id": session_id, |
| | "audio_a": f"/api/tts/audio/{session_id}/a", |
| | "audio_b": f"/api/tts/audio/{session_id}/b", |
| | "expires_in": 1800, |
| | "cache_hit": False, |
| | } |
| | ) |
| |
|
| | except Exception as e: |
| | app.logger.error(f"TTS on-the-fly generation error: {str(e)}", exc_info=True) |
| | |
| | if 'results' in locals(): |
| | for res in results: |
| | if 'audio_path' in res and os.path.exists(res['audio_path']): |
| | try: |
| | os.remove(res['audio_path']) |
| | except OSError: |
| | pass |
| | |
| | if reference_audio_path and os.path.exists(reference_audio_path): |
| | os.remove(reference_audio_path) |
| | return jsonify({"error": "Failed to generate TTS"}), 500 |
| | |
| |
|
| |
|
| | @app.route("/api/tts/audio/<session_id>/<model_key>") |
| | def get_audio(session_id, model_key): |
| | |
| | if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"): |
| | return jsonify({"error": "Turnstile verification required"}), 403 |
| |
|
| | if session_id not in app.tts_sessions: |
| | return jsonify({"error": "Invalid or expired session"}), 404 |
| |
|
| | session_data = app.tts_sessions[session_id] |
| |
|
| | |
| | if datetime.utcnow() > session_data["expires_at"]: |
| | cleanup_session(session_id) |
| | return jsonify({"error": "Session expired"}), 410 |
| |
|
| | if model_key == "a": |
| | audio_path = session_data["audio_a"] |
| | elif model_key == "b": |
| | audio_path = session_data["audio_b"] |
| | else: |
| | return jsonify({"error": "Invalid model key"}), 400 |
| |
|
| | |
| | if not os.path.exists(audio_path): |
| | return jsonify({"error": "Audio file not found"}), 404 |
| |
|
| | return send_file(audio_path, mimetype="audio/wav") |
| |
|
| |
|
| | @app.route("/api/tts/vote", methods=["POST"]) |
| | @limiter.limit("30 per minute") |
| | def submit_vote(): |
| | |
| | if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"): |
| | return jsonify({"error": "Turnstile verification required"}), 403 |
| |
|
| | data = request.json |
| | session_id = data.get("session_id") |
| | chosen_model_key = data.get("chosen_model") |
| |
|
| | if not session_id or session_id not in app.tts_sessions: |
| | return jsonify({"error": "Invalid or expired session"}), 404 |
| |
|
| | if not chosen_model_key or chosen_model_key not in ["a", "b"]: |
| | return jsonify({"error": "Invalid chosen model"}), 400 |
| |
|
| | session_data = app.tts_sessions[session_id] |
| |
|
| | |
| | if datetime.utcnow() > session_data["expires_at"]: |
| | cleanup_session(session_id) |
| | return jsonify({"error": "Session expired"}), 410 |
| |
|
| | |
| | if session_data["voted"]: |
| | return jsonify({"error": "Vote already submitted for this session"}), 400 |
| |
|
| | |
| | chosen_id = ( |
| | session_data["model_a"] if chosen_model_key == "a" else session_data["model_b"] |
| | ) |
| | rejected_id = ( |
| | session_data["model_b"] if chosen_model_key == "a" else session_data["model_a"] |
| | ) |
| | chosen_audio_path = ( |
| | session_data["audio_a"] if chosen_model_key == "a" else session_data["audio_b"] |
| | ) |
| | rejected_audio_path = ( |
| | session_data["audio_b"] if chosen_model_key == "a" else session_data["audio_a"] |
| | ) |
| |
|
| | |
| | user_id = current_user.id if current_user.is_authenticated else None |
| | vote, error = record_vote( |
| | user_id, session_data["text"], chosen_id, rejected_id, ModelType.TTS |
| | ) |
| |
|
| | if error: |
| | return jsonify({"error": error}), 500 |
| |
|
| | |
| | try: |
| | vote_uuid = str(uuid.uuid4()) |
| | vote_dir = os.path.join("./votes", vote_uuid) |
| | os.makedirs(vote_dir, exist_ok=True) |
| |
|
| | |
| | shutil.copy(chosen_audio_path, os.path.join(vote_dir, "chosen.wav")) |
| | shutil.copy(rejected_audio_path, os.path.join(vote_dir, "rejected.wav")) |
| |
|
| | |
| | chosen_model_obj = Model.query.get(chosen_id) |
| | rejected_model_obj = Model.query.get(rejected_id) |
| | metadata = { |
| | "text": session_data["text"], |
| | "chosen_model": chosen_model_obj.name if chosen_model_obj else "Unknown", |
| | "chosen_model_id": chosen_model_obj.id if chosen_model_obj else "Unknown", |
| | "rejected_model": rejected_model_obj.name if rejected_model_obj else "Unknown", |
| | "rejected_model_id": rejected_model_obj.id if rejected_model_obj else "Unknown", |
| | "session_id": session_id, |
| | "timestamp": datetime.utcnow().isoformat(), |
| | "username": current_user.username if current_user.is_authenticated else None, |
| | "model_type": "TTS" |
| | } |
| | with open(os.path.join(vote_dir, "metadata.json"), "w") as f: |
| | json.dump(metadata, f, indent=2) |
| |
|
| | except Exception as e: |
| | app.logger.error(f"Error saving preference data for vote {session_id}: {str(e)}") |
| | |
| |
|
| | |
| | session_data["voted"] = True |
| |
|
| | |
| | return jsonify( |
| | { |
| | "success": True, |
| | "chosen_model": {"id": chosen_id, "name": chosen_model_obj.name if chosen_model_obj else "Unknown"}, |
| | "rejected_model": { |
| | "id": rejected_id, |
| | "name": rejected_model_obj.name if rejected_model_obj else "Unknown", |
| | }, |
| | "names": { |
| | "a": ( |
| | chosen_model_obj.name if chosen_model_key == "a" else rejected_model_obj.name |
| | if chosen_model_obj and rejected_model_obj else "Unknown" |
| | ), |
| | "b": ( |
| | rejected_model_obj.name if chosen_model_key == "a" else chosen_model_obj.name |
| | if chosen_model_obj and rejected_model_obj else "Unknown" |
| | ), |
| | }, |
| | } |
| | ) |
| |
|
| |
|
def cleanup_session(session_id):
    """Drop a TTS session and delete its two generated audio files."""
    sess = app.tts_sessions.get(session_id)
    if sess is None:
        return

    # Best-effort removal of both audio files; log but don't raise.
    for path in (sess["audio_a"], sess["audio_b"]):
        if os.path.exists(path):
            try:
                os.remove(path)
            except Exception as e:
                app.logger.error(f"Error removing audio file: {str(e)}")

    del app.tts_sessions[session_id]
| |
|
| |
|
| | @app.route("/api/conversational/generate", methods=["POST"]) |
| | @limiter.limit("5 per minute") |
| | def generate_podcast(): |
| | |
| | if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"): |
| | return jsonify({"error": "Turnstile verification required"}), 403 |
| |
|
| | data = request.json |
| | script = data.get("script") |
| |
|
| | if not script or not isinstance(script, list) or len(script) < 2: |
| | return jsonify({"error": "Invalid script format or too short"}), 400 |
| |
|
| | |
| | for line in script: |
| | if not isinstance(line, dict) or "text" not in line or "speaker_id" not in line: |
| | return ( |
| | jsonify( |
| | { |
| | "error": "Invalid script line format. Each line must have text and speaker_id" |
| | } |
| | ), |
| | 400, |
| | ) |
| | if ( |
| | not line["text"] |
| | or not isinstance(line["speaker_id"], int) |
| | or line["speaker_id"] not in [0, 1] |
| | ): |
| | return ( |
| | jsonify({"error": "Invalid script content. Speaker ID must be 0 or 1"}), |
| | 400, |
| | ) |
| |
|
| | |
| | available_models = Model.query.filter_by( |
| | model_type=ModelType.CONVERSATIONAL, is_active=True |
| | ).all() |
| |
|
| | if len(available_models) < 2: |
| | return jsonify({"error": "Not enough conversational models available"}), 500 |
| |
|
| | selected_models = get_weighted_random_models(available_models, 2, ModelType.CONVERSATIONAL) |
| |
|
| | try: |
| | |
| | audio_files = [] |
| | model_ids = [] |
| |
|
| | |
| | def process_model(model): |
| | |
| | audio_content = predict_tts(script, model.id) |
| |
|
| | |
| | file_uuid = str(uuid.uuid4()) |
| | dest_path = os.path.join(TEMP_AUDIO_DIR, f"{file_uuid}.wav") |
| |
|
| | with open(dest_path, "wb") as f: |
| | f.write(audio_content) |
| |
|
| | return {"model_id": model.id, "audio_path": dest_path} |
| |
|
| | |
| | with ThreadPoolExecutor(max_workers=2) as executor: |
| | results = list(executor.map(process_model, selected_models)) |
| |
|
| | |
| | for result in results: |
| | model_ids.append(result["model_id"]) |
| | audio_files.append(result["audio_path"]) |
| |
|
| | |
| | session_id = str(uuid.uuid4()) |
| | script_text = " ".join([line["text"] for line in script]) |
| | app.conversational_sessions[session_id] = { |
| | "model_a": model_ids[0], |
| | "model_b": model_ids[1], |
| | "audio_a": audio_files[0], |
| | "audio_b": audio_files[1], |
| | "text": script_text[:1000], |
| | "created_at": datetime.utcnow(), |
| | "expires_at": datetime.utcnow() + timedelta(minutes=30), |
| | "voted": False, |
| | "script": script, |
| | } |
| |
|
| | |
| | return jsonify( |
| | { |
| | "session_id": session_id, |
| | "audio_a": f"/api/conversational/audio/{session_id}/a", |
| | "audio_b": f"/api/conversational/audio/{session_id}/b", |
| | "expires_in": 1800, |
| | } |
| | ) |
| |
|
| | except Exception as e: |
| | app.logger.error(f"Conversational generation error: {str(e)}") |
| | return jsonify({"error": f"Failed to generate podcast: {str(e)}"}), 500 |
| |
|
| |
|
@app.route("/api/conversational/audio/<session_id>/<model_key>")
def get_podcast_audio(session_id, model_key):
    """Serve one side ("a" or "b") of a conversational session's audio."""
    # Same Turnstile gate as the generation endpoints.
    if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
        return jsonify({"error": "Turnstile verification required"}), 403

    session_data = app.conversational_sessions.get(session_id)
    if session_data is None:
        return jsonify({"error": "Invalid or expired session"}), 404

    # Sessions past their TTL are torn down on access.
    if datetime.utcnow() > session_data["expires_at"]:
        cleanup_conversational_session(session_id)
        return jsonify({"error": "Session expired"}), 410

    if model_key not in ("a", "b"):
        return jsonify({"error": "Invalid model key"}), 400
    audio_path = session_data["audio_a"] if model_key == "a" else session_data["audio_b"]

    # The file may already have been removed by the cleanup job.
    if not os.path.exists(audio_path):
        return jsonify({"error": "Audio file not found"}), 404

    return send_file(audio_path, mimetype="audio/wav")
| |
|
| |
|
@app.route("/api/conversational/vote", methods=["POST"])
@limiter.limit("30 per minute")
def submit_podcast_vote():
    """Record a vote for one of the two models in a conversational session.

    Expects JSON: {"session_id": str, "chosen_model": "a" | "b"}.
    Persists the vote, archives chosen/rejected audio plus metadata under
    ./votes/<uuid>/ for later dataset upload (best-effort), and marks the
    session as voted. Returns the revealed model identities.
    """
    if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
        return jsonify({"error": "Turnstile verification required"}), 403

    data = request.json
    session_id = data.get("session_id")
    chosen_model_key = data.get("chosen_model")

    if not session_id or session_id not in app.conversational_sessions:
        return jsonify({"error": "Invalid or expired session"}), 404

    if not chosen_model_key or chosen_model_key not in ["a", "b"]:
        return jsonify({"error": "Invalid chosen model"}), 400

    session_data = app.conversational_sessions[session_id]

    # Reject votes on sessions past their 30-minute TTL.
    if datetime.utcnow() > session_data["expires_at"]:
        cleanup_conversational_session(session_id)
        return jsonify({"error": "Session expired"}), 410

    # One vote per session.
    if session_data["voted"]:
        return jsonify({"error": "Vote already submitted for this session"}), 400

    # Map the anonymous "a"/"b" key back to concrete model ids and audio paths.
    chosen_id = (
        session_data["model_a"] if chosen_model_key == "a" else session_data["model_b"]
    )
    rejected_id = (
        session_data["model_b"] if chosen_model_key == "a" else session_data["model_a"]
    )
    chosen_audio_path = (
        session_data["audio_a"] if chosen_model_key == "a" else session_data["audio_b"]
    )
    rejected_audio_path = (
        session_data["audio_b"] if chosen_model_key == "a" else session_data["audio_a"]
    )

    # Record the vote (anonymous votes allowed).
    user_id = current_user.id if current_user.is_authenticated else None
    vote, error = record_vote(
        user_id, session_data["text"], chosen_id, rejected_id, ModelType.CONVERSATIONAL
    )

    if error:
        return jsonify({"error": error}), 500

    # BUGFIX: fetch the model rows *before* the archiving try-block. They
    # were previously bound only inside it, so any failure before that point
    # (e.g. os.makedirs or shutil.copy raising) left them undefined and the
    # response construction below crashed with UnboundLocalError.
    chosen_model_obj = Model.query.get(chosen_id)
    rejected_model_obj = Model.query.get(rejected_id)

    # Archive audio + metadata for the preference dataset; best-effort only —
    # the vote itself is already committed above.
    try:
        vote_uuid = str(uuid.uuid4())
        vote_dir = os.path.join("./votes", vote_uuid)
        os.makedirs(vote_dir, exist_ok=True)

        shutil.copy(chosen_audio_path, os.path.join(vote_dir, "chosen.wav"))
        shutil.copy(rejected_audio_path, os.path.join(vote_dir, "rejected.wav"))

        metadata = {
            "script": session_data["script"],
            "chosen_model": chosen_model_obj.name if chosen_model_obj else "Unknown",
            "chosen_model_id": chosen_model_obj.id if chosen_model_obj else "Unknown",
            "rejected_model": rejected_model_obj.name if rejected_model_obj else "Unknown",
            "rejected_model_id": rejected_model_obj.id if rejected_model_obj else "Unknown",
            "session_id": session_id,
            "timestamp": datetime.utcnow().isoformat(),
            "username": current_user.username if current_user.is_authenticated else None,
            "model_type": "CONVERSATIONAL"
        }
        with open(os.path.join(vote_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f, indent=2)

    except Exception as e:
        app.logger.error(f"Error saving preference data for conversational vote {session_id}: {str(e)}")

    session_data["voted"] = True

    return jsonify(
        {
            "success": True,
            "chosen_model": {"id": chosen_id, "name": chosen_model_obj.name if chosen_model_obj else "Unknown"},
            "rejected_model": {
                "id": rejected_id,
                "name": rejected_model_obj.name if rejected_model_obj else "Unknown",
            },
            "names": {
                "a": Model.query.get(session_data["model_a"]).name,
                "b": Model.query.get(session_data["model_b"]).name,
            },
        }
    )
| |
|
| |
|
def cleanup_conversational_session(session_id):
    """Remove a conversational session from memory and delete its audio files."""
    sess = app.conversational_sessions.get(session_id)
    if sess is None:
        return

    # Best-effort deletion of both generated audio files.
    for path in (sess["audio_a"], sess["audio_b"]):
        if not os.path.exists(path):
            continue
        try:
            os.remove(path)
        except Exception as e:
            app.logger.error(
                f"Error removing conversational audio file: {str(e)}"
            )

    del app.conversational_sessions[session_id]
| |
|
| |
|
| | |
def setup_cleanup():
    """Start a daemon scheduler that purges expired TTS and conversational
    sessions (and their audio files) every 15 minutes."""

    def cleanup_expired_sessions():
        with app.app_context():
            now = datetime.utcnow()

            # Collect ids first: the cleanup helpers mutate the dicts.
            stale_tts = [
                sid
                for sid, data in app.tts_sessions.items()
                if now > data["expires_at"]
            ]
            for sid in stale_tts:
                cleanup_session(sid)

            stale_conv = [
                sid
                for sid, data in app.conversational_sessions.items()
                if now > data["expires_at"]
            ]
            for sid in stale_conv:
                cleanup_conversational_session(sid)

            app.logger.info(f"Cleaned up {len(stale_tts)} TTS and {len(stale_conv)} conversational sessions.")

    scheduler = BackgroundScheduler(daemon=True)
    scheduler.add_job(cleanup_expired_sessions, "interval", minutes=15)
    scheduler.start()
    print("Cleanup scheduler started")
| |
|
| |
|
| | |
def setup_periodic_tasks():
    """Setup periodic database synchronization and preference data upload for Spaces"""
    if not IS_SPACES:
        return

    # "sqlite:///tts_arena.db" -> "instance/tts_arena.db" (Flask instance dir).
    db_path = app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "instance/")
    preferences_repo_id = "kemuriririn/arena-preferences"
    database_repo_id = "kemuriririn/database-arena"
    votes_dir = "./votes"

    def sync_database():
        """Uploads the database to HF dataset"""
        with app.app_context():
            try:
                if not os.path.exists(db_path):
                    app.logger.warning(f"Database file not found at {db_path}, skipping sync.")
                    return

                api = HfApi(token=os.getenv("HF_TOKEN"))
                api.upload_file(
                    path_or_fileobj=db_path,
                    path_in_repo="tts_arena.db",
                    repo_id=database_repo_id,
                    repo_type="dataset",
                )
                app.logger.info(f"Database uploaded to {database_repo_id} at {datetime.utcnow()}")
            except Exception as e:
                app.logger.error(f"Error uploading database to {database_repo_id}: {str(e)}")

    def sync_preferences_data():
        """Zips and uploads preference data folders in batches to HF dataset"""
        with app.app_context():
            if not os.path.isdir(votes_dir):
                return

            temp_batch_dir = None
            temp_individual_zip_dir = None
            local_batch_zip_path = None
            # BUGFIX: track failure with an explicit flag. The old code
            # tested `'e' in locals()` in the finally block, but
            # `except ... as e` unbinds `e` when the handler exits
            # (PEP 3110), so that test was always False and the batch dir
            # was removed even after a failed upload — defeating the
            # "keep it for the next attempt" branch below.
            sync_failed = False

            try:
                api = HfApi(token=os.getenv("HF_TOKEN"))
                vote_uuids = [d for d in os.listdir(votes_dir) if os.path.isdir(os.path.join(votes_dir, d))]

                if not vote_uuids:
                    return

                app.logger.info(f"Found {len(vote_uuids)} vote directories to process.")

                # Two scratch dirs: one to build individual zips in, one to
                # accumulate the batch before zipping it as a whole.
                temp_batch_dir = tempfile.mkdtemp(prefix="hf_batch_")
                temp_individual_zip_dir = tempfile.mkdtemp(prefix="hf_indiv_zips_")
                app.logger.debug(f"Created temp directories: {temp_batch_dir}, {temp_individual_zip_dir}")

                processed_vote_dirs = []
                individual_zips_in_batch = []

                # Zip each vote dir individually, then move the zip into the
                # batch dir. A failure on one vote skips it but continues.
                for vote_uuid in vote_uuids:
                    dir_path = os.path.join(votes_dir, vote_uuid)
                    individual_zip_base_path = os.path.join(temp_individual_zip_dir, vote_uuid)
                    individual_zip_path = f"{individual_zip_base_path}.zip"

                    try:
                        shutil.make_archive(individual_zip_base_path, 'zip', dir_path)
                        app.logger.debug(f"Created individual zip: {individual_zip_path}")

                        final_individual_zip_path = os.path.join(temp_batch_dir, f"{vote_uuid}.zip")
                        shutil.move(individual_zip_path, final_individual_zip_path)
                        app.logger.debug(f"Moved individual zip to batch dir: {final_individual_zip_path}")

                        processed_vote_dirs.append(dir_path)
                        individual_zips_in_batch.append(final_individual_zip_path)

                    except Exception as zip_err:
                        app.logger.error(f"Error creating or moving zip for {vote_uuid}: {str(zip_err)}")
                        # Drop the half-built zip so it can't pollute a later run.
                        if os.path.exists(individual_zip_path):
                            try:
                                os.remove(individual_zip_path)
                            except OSError:
                                pass

                # All surviving zips now live in the batch dir.
                shutil.rmtree(temp_individual_zip_dir)
                temp_individual_zip_dir = None
                app.logger.debug("Cleaned up temporary individual zip directory.")

                if not individual_zips_in_batch:
                    app.logger.warning("No individual zips were successfully created for batching.")
                    if temp_batch_dir and os.path.exists(temp_batch_dir):
                        shutil.rmtree(temp_batch_dir)
                        temp_batch_dir = None
                    return

                # Batch name: timestamp + short uuid to avoid collisions.
                batch_timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
                batch_uuid_short = str(uuid.uuid4())[:8]
                batch_zip_filename = f"{batch_timestamp}_batch_{batch_uuid_short}.zip"
                local_batch_zip_base = os.path.join(tempfile.gettempdir(), batch_zip_filename.replace('.zip', ''))
                local_batch_zip_path = f"{local_batch_zip_base}.zip"

                app.logger.info(f"Creating batch zip: {local_batch_zip_path} with {len(individual_zips_in_batch)} individual zips.")
                shutil.make_archive(local_batch_zip_base, 'zip', temp_batch_dir)
                app.logger.info(f"Batch zip created successfully: {local_batch_zip_path}")

                # Upload under votes/<year>/<month>/ in the preferences repo.
                hf_repo_path = f"votes/{year}/{month}/{batch_zip_filename}"
                app.logger.info(f"Uploading batch zip to HF Hub: {preferences_repo_id}/{hf_repo_path}")

                api.upload_file(
                    path_or_fileobj=local_batch_zip_path,
                    path_in_repo=hf_repo_path,
                    repo_id=preferences_repo_id,
                    repo_type="dataset",
                    commit_message=f"Add batch preference data {batch_zip_filename} ({len(individual_zips_in_batch)} votes)"
                )
                app.logger.info(f"Successfully uploaded batch {batch_zip_filename} to {preferences_repo_id}")

                # Upload succeeded: remove uploaded originals and scratch files.
                app.logger.info("Cleaning up local files after successful upload.")
                for dir_path in processed_vote_dirs:
                    try:
                        shutil.rmtree(dir_path)
                        app.logger.debug(f"Removed original vote directory: {dir_path}")
                    except OSError as e:
                        app.logger.error(f"Error removing processed vote directory {dir_path}: {str(e)}")

                shutil.rmtree(temp_batch_dir)
                temp_batch_dir = None
                app.logger.debug("Removed temporary batch directory.")

                os.remove(local_batch_zip_path)
                local_batch_zip_path = None
                app.logger.debug("Removed local batch zip file.")

                app.logger.info(f"Finished preference data sync. Uploaded batch {batch_zip_filename}.")

            except Exception as e:
                sync_failed = True
                app.logger.error(f"Error during preference data batch sync: {str(e)}", exc_info=True)
                # The half-written batch zip is useless; remove it.
                if local_batch_zip_path and os.path.exists(local_batch_zip_path):
                    try:
                        os.remove(local_batch_zip_path)
                        app.logger.debug("Cleaned up local batch zip after failed upload.")
                    except OSError as clean_err:
                        app.logger.error(f"Error cleaning up batch zip after failed upload: {clean_err}")

            finally:
                # Belt-and-braces cleanup for the staging dir.
                if temp_individual_zip_dir and os.path.exists(temp_individual_zip_dir):
                    try:
                        shutil.rmtree(temp_individual_zip_dir)
                    except Exception as final_clean_err:
                        app.logger.error(f"Error in final cleanup (indiv zips): {final_clean_err}")

                if temp_batch_dir and os.path.exists(temp_batch_dir):
                    if not sync_failed:
                        try:
                            shutil.rmtree(temp_batch_dir)
                        except Exception as final_clean_err:
                            app.logger.error(f"Error in final cleanup (batch dir): {final_clean_err}")
                    else:
                        app.logger.warning("Keeping temporary batch directory due to upload failure for next attempt.")

    scheduler = BackgroundScheduler()
    scheduler.add_job(sync_database, "interval", minutes=15, id="sync_db_job")
    scheduler.add_job(sync_preferences_data, "interval", minutes=5, id="sync_pref_job")
    scheduler.start()
    print("Periodic tasks scheduler started (DB sync and Preferences upload)")
| |
|
| |
|
@app.cli.command("init-db")
def init_db():
    """Initialize the database."""
    # CLI entry point (`flask init-db`): creates all tables declared on the
    # SQLAlchemy models. create_all is a no-op for tables that already exist.
    with app.app_context():
        db.create_all()
        print("Database initialized!")
| |
|
| |
|
@app.route("/api/toggle-leaderboard-visibility", methods=["POST"])
def toggle_leaderboard_visibility():
    """Toggle whether the current user appears in the top voters leaderboard"""
    if not current_user.is_authenticated:
        return jsonify({"error": "You must be logged in to change this setting"}), 401

    new_status = toggle_user_leaderboard_visibility(current_user.id)
    if new_status is None:
        return jsonify({"error": "User not found"}), 404

    message = (
        "You are now visible in the voters leaderboard"
        if new_status
        else "You are now hidden from the voters leaderboard"
    )
    return jsonify({"success": True, "visible": new_status, "message": message})
| |
|
| |
|
@app.route("/api/tts/cached-sentences")
def get_cached_sentences():
    """Returns a list of sentences currently available in the TTS cache, with reference audio."""

    def _relative_ref(path):
        # Reference audio is exposed relative to REFERENCE_AUDIO_DIR.
        return os.path.relpath(path, start=REFERENCE_AUDIO_DIR) if path else None

    with tts_cache_lock:
        entries = []
        for sentence, entry in tts_cache.items():
            entries.append(
                {
                    "sentence": sentence,
                    "model_a": entry["model_a"],
                    "model_b": entry["model_b"],
                    "ref_a": _relative_ref(entry.get("ref_a")),
                    "ref_b": _relative_ref(entry.get("ref_b")),
                }
            )
    return jsonify(entries)
| |
|
@app.route("/api/tts/reference-audio/<filename>")
def get_reference_audio(filename):
    """Serve a reference audio sample for preview.

    Returns the WAV file from REFERENCE_AUDIO_DIR, or 404 if it is missing
    or the requested name is not a plain filename.
    """
    # SECURITY: `filename` is untrusted. Flask's default <string> converter
    # rejects "/", but not "\" — on Windows a crafted name could otherwise
    # traverse out of REFERENCE_AUDIO_DIR. Accept only bare filenames.
    if os.path.basename(filename) != filename or filename in ("", ".", ".."):
        return jsonify({"error": "Reference audio not found"}), 404

    file_path = os.path.join(REFERENCE_AUDIO_DIR, filename)
    if not os.path.exists(file_path):
        return jsonify({"error": "Reference audio not found"}), 404
    return send_file(file_path, mimetype="audio/wav")
| |
|
| |
|
def get_weighted_random_models(
    applicable_models: list[Model], num_to_select: int, model_type: ModelType
) -> list[Model]:
    """
    Selects a specified number of models randomly from a list of applicable_models,
    weighting models with fewer votes higher. A smoothing factor is used to ensure
    the preference is slight and to prevent models with zero votes from being
    overwhelmingly favored. Models are selected without replacement.

    Assumes len(applicable_models) >= num_to_select, which should be checked by the caller.
    """

    def _vote_count(model):
        # Total votes (either chosen or rejected) for this model in this arena.
        return (
            Vote.query.filter(Vote.model_type == model_type)
            .filter(or_(Vote.model_chosen == model.id, Vote.model_rejected == model.id))
            .count()
        )

    candidates = list(applicable_models)
    # Inverse-vote weights; smoothing keeps zero-vote models from dominating.
    candidate_weights = [
        1.0 / (_vote_count(m) + SMOOTHING_FACTOR_MODEL_SELECTION) for m in candidates
    ]

    picked = []
    for _ in range(num_to_select):
        if not candidates:
            app.logger.warning("Not enough candidates left for weighted selection.")
            break

        winner = random.choices(candidates, weights=candidate_weights, k=1)[0]
        picked.append(winner)

        # Remove the winner so selection is without replacement.
        try:
            idx = candidates.index(winner)
        except ValueError:
            # Should be impossible: winner was drawn from candidates.
            app.logger.error(f"Error removing model {winner.id} from weighted selection candidates.")
            break
        candidates.pop(idx)
        candidate_weights.pop(idx)

    return picked
| |
|
| |
|
if __name__ == "__main__":
    with app.app_context():
        # Ensure writable directories exist before anything touches them.
        os.makedirs("instance", exist_ok=True)
        os.makedirs("./votes", exist_ok=True)
        os.makedirs(CACHE_AUDIO_DIR, exist_ok=True)

        # Start each run with an empty TTS cache directory; entries from a
        # previous process would never be served again anyway.
        try:
            app.logger.info(f"Clearing old cache audio files from {CACHE_AUDIO_DIR}")
            for filename in os.listdir(CACHE_AUDIO_DIR):
                file_path = os.path.join(CACHE_AUDIO_DIR, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)
                except Exception as e:
                    app.logger.error(f'Failed to delete {file_path}. Reason: {e}')
        except Exception as e:
            app.logger.error(f"Error clearing cache directory {CACHE_AUDIO_DIR}: {e}")

        # On Spaces, pull the sqlite DB from the HF dataset if it is missing.
        # NOTE(review): the top of this file performs the same download at
        # import time — confirm this second check is still needed.
        if IS_SPACES and not os.path.exists(app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "")):
            try:
                print("Database not found, downloading from HF dataset...")
                hf_hub_download(
                    repo_id="kemuriririn/database-arena",
                    filename="tts_arena.db",
                    repo_type="dataset",
                    local_dir="instance",
                    token=os.getenv("HF_TOKEN"),
                )
                print("Database downloaded successfully ✅")
            except Exception as e:
                print(f"Error downloading database from HF dataset: {str(e)} ⚠️")

        # Fetch reference voice samples used by the TTS cache endpoints.
        download_reference_audios()

        # Create tables and seed the model list, then start background jobs
        # (session cleanup; DB/preference sync when running on Spaces).
        db.create_all()
        insert_initial_models()

        setup_cleanup()
        setup_periodic_tasks()

        from werkzeug.middleware.proxy_fix import ProxyFix

        # Trust the reverse proxy's X-Forwarded-Proto/Host so url_for and
        # redirects generate correct external URLs.
        app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1)

        app.config["PREFERRED_URL_SCHEME"] = "https"

        from waitress import serve

        # Thread pool size for the waitress WSGI server.
        threads = 12

        # Spaces expects the app on $PORT (default 7860); locally use 5000.
        if IS_SPACES:
            serve(
                app,
                host="0.0.0.0",
                port=int(os.environ.get("PORT", 7860)),
                threads=threads,
                connection_limit=100,
                channel_timeout=30,
                url_scheme='https'
            )
        else:
            print(f"Starting Waitress server with {threads} threads")
            serve(
                app,
                host="0.0.0.0",
                port=5000,
                threads=threads,
                connection_limit=100,
                channel_timeout=30,
                url_scheme='https'
            )
| |
|