# Hugging Face Spaces app — runs on ZeroGPU ("Running on Zero") hardware.
| import os | |
| import spaces | |
| REPO_TYPE = "hf" | |
| from huggingface_hub import snapshot_download | |
# Hugging Face Hub repositories for the two ASR models and the shared VAD model.
FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
SENSE_VOICE_SMALL_REPO_ID = "FunAudioLLM/SenseVoiceSmall"
VAD_MODEL_REPO_ID = "funasr/fsmn-vad"

# All model snapshots live under one cache root (relative to the working
# directory); it is created up front so the downloads below have a home.
MODEL_CACHE_DIR = "./models"
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
FUN_ASR_NANO_LOCAL_PATH, SENSE_VOICE_SMALL_LOCAL_PATH, VAD_MODEL_LOCAL_PATH = (
    os.path.join(MODEL_CACHE_DIR, subdir)
    for subdir in ("Fun-ASR-Nano", "SenseVoiceSmall", "fsmn-vad")
)
def download_model_if_not_exists(repo_id, local_path, model_name):
    """Download a Hub snapshot of *repo_id* into *local_path* unless one is already there.

    ``*.onnx`` files are excluded from the download. A directory that exists
    but is empty (for example, one created by hand, or left by a download that
    failed before writing any file) is treated as missing. In that case the
    download is retried instead of being silently skipped. A directory that
    already holds at least one entry is assumed to be a usable snapshot and is
    left as is.

    Args:
        repo_id: Hugging Face Hub repository id, e.g. ``"funasr/fsmn-vad"``.
        local_path: Destination directory for the snapshot files.
        model_name: Human-readable name, used only in progress messages.
    """
    # os.path.exists() alone would also accept an empty directory (or a plain
    # file), which would later make model loading fail with a confusing error.
    already_present = os.path.isdir(local_path) and bool(os.listdir(local_path))
    if already_present:
        print(f"{model_name} found locally, skipping download.")
        return
    print(f"Downloading {model_name} to {local_path} ...")
    snapshot_download(repo_id=repo_id, local_dir=local_path, ignore_patterns=["*.onnx"])
    print(f"{model_name} downloaded.")
# Fetch every model snapshot at import time, before the Gradio/FunASR imports
# that follow. Both ASR pipelines reuse the same VAD model, so it is fetched once.
download_model_if_not_exists(
    repo_id=FUN_ASR_NANO_REPO_ID,
    local_path=FUN_ASR_NANO_LOCAL_PATH,
    model_name="Fun-ASR-Nano",
)
download_model_if_not_exists(
    repo_id=SENSE_VOICE_SMALL_REPO_ID,
    local_path=SENSE_VOICE_SMALL_LOCAL_PATH,
    model_name="SenseVoiceSmall",
)
download_model_if_not_exists(
    repo_id=VAD_MODEL_REPO_ID,
    local_path=VAD_MODEL_LOCAL_PATH,
    model_name="VAD Model",
)
| import gradio as gr | |
| import time | |
| import tempfile | |
| import logging | |
| import torch | |
| from funasr import AutoModel | |
| from funasr.utils.postprocess_utils import rich_transcription_postprocess | |
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Cache of instantiated FunASR models, keyed by pipeline type. A model is built
# lazily on the first request for its pipeline and reused afterwards.
loaded_models = {}


def get_or_load_model(pipeline_type):
    """Return the FunASR ``AutoModel`` for *pipeline_type*, building it on first use.

    Both pipelines are created from their local snapshot directory with the
    shared local VAD model (``max_single_segment_time=30000``), ``device="cuda"``,
    ``disable_update=True`` and ``hub="hf"``. Only Fun-ASR-Nano is loaded with
    ``trust_remote_code=True``.

    Args:
        pipeline_type: ``"fun-asr-nano"`` or ``"sensevoice"``.

    Raises:
        ValueError: for any other pipeline type; nothing is cached in that case.
    """
    try:
        return loaded_models[pipeline_type]
    except KeyError:
        pass  # not built yet: fall through and create it

    if pipeline_type == "fun-asr-nano":
        model_dir, remote_code = FUN_ASR_NANO_LOCAL_PATH, True
    elif pipeline_type == "sensevoice":
        model_dir, remote_code = SENSE_VOICE_SMALL_LOCAL_PATH, False
    else:
        raise ValueError(f"Unknown pipeline type: {pipeline_type}")

    loaded_models[pipeline_type] = AutoModel(
        model=model_dir,
        trust_remote_code=remote_code,
        vad_model=VAD_MODEL_LOCAL_PATH,
        vad_kwargs={"max_single_segment_time": 30000},
        device="cuda",
        disable_update=True,
        hub="hf",
    )
    return loaded_models[pipeline_type]
def _download_audio_to_temp(url, temp_files):
    """Stream the audio at *url* into a new temporary ``.wav`` file.

    The temp file's path is appended to *temp_files* as soon as the file is
    created, so the caller can delete it even if the transfer fails part-way.
    The ``.wav`` suffix is used regardless of the remote file's actual format.

    Returns:
        ``(path, None)`` on HTTP 200, otherwise ``(None, error_message)``.

    Raises:
        requests.RequestException: on connection errors or timeouts.
    """
    import requests as req

    # Closing the response (via ``with``) releases the connection, which a
    # ``stream=True`` request otherwise keeps until the body is fully consumed.
    with req.get(url, stream=True, timeout=30) as response:
        if response.status_code != 200:
            return None, f"Failed to download audio: HTTP {response.status_code}"
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            temp_files.append(f.name)
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    return f.name, None


def _trim_audio_to_temp(path, start_s, end_s, temp_files):
    """Export seconds ``[start_s, end_s)`` of *path* to a new temporary ``.wav`` and return its path.

    ``end_s <= 0`` means "through the end of the audio". The new path is
    appended to *temp_files* for later cleanup.
    """
    from pydub import AudioSegment

    audio = AudioSegment.from_file(path)
    # pydub slices by milliseconds; len(audio) is the duration in ms.
    end_ms = int(end_s * 1000) if end_s > 0 else len(audio)
    trimmed = audio[int(start_s * 1000):end_ms]
    # Create and immediately close the temp file so no handle of ours stays
    # open while pydub writes to the same path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        temp_files.append(tmp.name)
    # export() returns the file object it wrote; close it so it does not leak.
    trimmed.export(tmp.name, format="wav").close()
    return tmp.name


def _write_transcript_file(text):
    """Write *text* to a new UTF-8 ``.txt`` temp file, kept on disk for download; return its path."""
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".txt", mode="w", encoding="utf-8"
    ) as txt_file:
        txt_file.write(text)
    return txt_file.name


# On ZeroGPU hardware CUDA is only available inside functions decorated with
# @spaces.GPU. The models used here are created with device="cuda" (see
# get_or_load_model), so both loading and inference must run under it.
@spaces.GPU
def transcribe_audio(audio_input, audio_url, pipeline_type, start_time=None, end_time=None):
    """Transcribe an uploaded/recorded file, or audio fetched from a URL.

    Args:
        audio_input: Path to an uploaded or recorded audio file, or ``None``/empty.
            Takes precedence over *audio_url*.
        audio_url: URL to download the audio from when no file is given.
        pipeline_type: ``"fun-asr-nano"`` or ``"sensevoice"``.
        start_time: Optional trim start in seconds; ``None``/``0`` means the beginning.
        end_time: Optional trim end in seconds; ``None``/``0`` means the end of the audio.

    Returns:
        ``(metrics, transcription, txt_path)`` for the three Gradio outputs. On
        failure, ``metrics`` holds the error message, ``transcription`` is empty
        and ``txt_path`` is ``None``. Errors are reported this way, never raised.
    """
    # Every temp file created for this request; all are removed in ``finally``.
    temp_files = []
    try:
        if audio_input:
            audio_path = audio_input
        elif audio_url and audio_url.strip():
            audio_path, error = _download_audio_to_temp(audio_url.strip(), temp_files)
            if error:
                return error, "", None
        else:
            return "No audio provided. Upload a file, record, or enter a URL.", "", None

        # Both times default to None (and a cleared number field may also send
        # None); treat that as 0. A negative start is clamped to 0.
        start_s = max(float(start_time or 0), 0.0)
        end_s = float(end_time or 0)
        if start_s > 0 or end_s > 0:
            if 0 < end_s <= start_s:
                return "Invalid time range: End Time must be greater than Start Time.", "", None
            audio_path = _trim_audio_to_temp(audio_path, start_s, end_s, temp_files)

        # Load model (lazy, inside the GPU context).
        model = get_or_load_model(pipeline_type)

        t0 = time.perf_counter()
        if pipeline_type == "fun-asr-nano":
            res = model.generate(input=[audio_path], use_itn=True, batch_size=1)
        else:
            res = model.generate(
                input=audio_path, cache={}, language="auto",
                use_itn=True, batch_size_s=60, merge_vad=True, merge_length_s=15,
            )
        # Guard against an empty result list instead of failing with IndexError.
        transcription = rich_transcription_postprocess(res[0]["text"]) if res else ""
        elapsed = time.perf_counter() - t0
        metrics = f"Transcription time: {elapsed:.2f}s\nPipeline: {pipeline_type}\nDevice: cuda"

        return metrics, transcription, _write_transcript_file(transcription)
    except Exception as e:
        # UI boundary: report the failure to the user, but keep the traceback in the logs.
        logging.exception("Transcription error: %s", e)
        return f"Error: {e}", "", None
    finally:
        # Best-effort cleanup. An error here must not replace the return value.
        for path in temp_files:
            if os.path.exists(path):
                try:
                    os.remove(path)
                except OSError as exc:
                    logging.warning("Could not remove temp file %s: %s", path, exc)
# --- Gradio UI ---------------------------------------------------------------
# Components are laid out in the order they are created inside the
# gr.Blocks / gr.Row contexts, so statement order here defines the page layout.
with gr.Blocks(title="Fun-ASR-Nano | GPU Demo") as demo:
    # Header: model description and project links.
    gr.Markdown("""
# Fun-ASR-Nano: LLM-Powered Speech Recognition (GPU)
End-to-end ASR model trained on tens of millions of hours, supporting **31 languages** including Chinese dialects.
- **GitHub**: [Fun-ASR](https://github.com/FunAudioLLM/Fun-ASR) | [FunASR Toolkit](https://github.com/modelscope/FunASR)
- **Model**: [Fun-ASR-Nano-2512](https://huggingface.co/FunAudioLLM/Fun-ASR-Nano-2512)
""")
    # Audio sources: an upload or microphone recording (passed to the callback
    # as a file path), or a URL. transcribe_audio uses the file when both are set.
    with gr.Row():
        audio_input = gr.Audio(label="Upload or Record Audio", sources=["upload", "microphone"], type="filepath")
        audio_url = gr.Textbox(label="Or Enter Audio URL", placeholder="https://example.com/audio.wav")
    # Model choice plus an optional trim window. transcribe_audio treats a start
    # of 0 as "from the beginning" and an end of 0 as "to the end".
    with gr.Row():
        pipeline_type = gr.Dropdown(
            choices=["fun-asr-nano", "sensevoice"],
            label="Model",
            value="fun-asr-nano"
        )
        start_time = gr.Number(label="Start Time (s)", value=0, minimum=0)
        end_time = gr.Number(label="End Time (s)", value=0, minimum=0)
    transcribe_btn = gr.Button("Transcribe", variant="primary")
    # Results: timing/pipeline info, the transcript text, and a .txt download.
    with gr.Row():
        metrics_output = gr.Textbox(label="Metrics", lines=4)
        transcription_output = gr.Textbox(label="Transcription", lines=10)
        transcription_file = gr.File(label="Download")
    # `inputs` map positionally onto transcribe_audio's parameters. Its 3-tuple
    # return (metrics, transcription, txt path) fills `outputs` in order.
    transcribe_btn.click(
        transcribe_audio,
        inputs=[audio_input, audio_url, pipeline_type, start_time, end_time],
        outputs=[metrics_output, transcription_output, transcription_file],
    )
    # Footer: languages supported by each selectable model.
    gr.Markdown("""
### Supported Languages
- **Fun-ASR-Nano**: 31 languages + Chinese dialects (Cantonese, Sichuan, Shanghai, Minnan, etc.)
- **SenseVoice**: Chinese, English, Cantonese, Japanese, Korean
""")
# Enable Gradio's request queue, then start the web server.
demo.queue().launch()