File size: 7,041 Bytes
cb8606e
4a8414a
a198709
59b606f
9c07b4e
59b606f
cb8606e
a483939
 
 
 
 
 
 
59b606f
 
 
4e69efc
 
a483939
 
59b606f
 
 
a483939
59b606f
 
399aaa2
a483939
 
 
cb8606e
 
 
 
 
 
 
 
 
 
 
59b606f
 
cb8606e
 
59b606f
 
 
cb8606e
 
59b606f
 
 
 
 
 
 
 
 
cb8606e
59b606f
 
 
 
 
 
 
 
 
cb8606e
59b606f
cb8606e
59b606f
 
cb8606e
 
59b606f
 
 
 
cb8606e
 
 
 
 
 
59b606f
 
 
 
 
 
 
 
cb8606e
59b606f
cb8606e
59b606f
cb8606e
59b606f
8ced29b
59b606f
 
 
8ced29b
 
59b606f
 
 
 
 
 
 
 
 
 
 
d666310
59b606f
cb8606e
 
59b606f
 
cb8606e
 
 
59b606f
cb8606e
59b606f
cb8606e
59b606f
 
 
 
cb8606e
59b606f
cb8606e
 
59b606f
 
cb8606e
59b606f
cb8606e
 
59b606f
 
 
 
 
 
 
 
 
 
 
cb8606e
 
59b606f
cb8606e
 
59b606f
 
 
 
 
8ced29b
 
59b606f
 
 
 
 
cb8606e
59b606f
 
 
 
 
 
cb8606e
59b606f
 
 
 
 
cb8606e
 
59b606f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import os
import spaces

# Hub type label. NOTE(review): not referenced by the code visible here — the
# AutoModel calls pass hub="hf" literally; kept in case other code imports it.
REPO_TYPE = "hf"

from huggingface_hub import snapshot_download

# Hugging Face Hub repositories for the two ASR checkpoints and the VAD model
# that both pipelines share.
FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
SENSE_VOICE_SMALL_REPO_ID = "FunAudioLLM/SenseVoiceSmall"
VAD_MODEL_REPO_ID = "funasr/fsmn-vad"

# Each repository is snapshotted into its own sub-directory of one cache root,
# which is created eagerly so the downloads below have somewhere to land.
MODEL_CACHE_DIR = "./models"
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
FUN_ASR_NANO_LOCAL_PATH, SENSE_VOICE_SMALL_LOCAL_PATH, VAD_MODEL_LOCAL_PATH = (
    os.path.join(MODEL_CACHE_DIR, subdir)
    for subdir in ("Fun-ASR-Nano", "SenseVoiceSmall", "fsmn-vad")
)


def download_model_if_not_exists(repo_id, local_path, model_name):
    """Fetch a Hub snapshot into *local_path* unless a completed copy is already there.

    Completion is recorded in a sidecar marker file
    (``<local_path>.download_complete``) that is written only after
    ``snapshot_download`` returns. An interrupted download still leaves
    ``local_path`` on disk. Checking the marker instead of the directory means
    such a partial download is retried on the next start rather than being
    mistaken for a finished checkpoint. Recent ``huggingface_hub`` versions
    reuse files already present in ``local_dir``, so a retry should be
    incremental.

    Args:
        repo_id: Hub repository id, e.g. ``"FunAudioLLM/SenseVoiceSmall"``.
        local_path: Directory the snapshot is materialised into.
        model_name: Human-readable name used only in progress messages.
    """
    # Marker sits next to (not inside) the model directory so the checkpoint
    # folder the model loader reads stays untouched.
    marker = os.path.normpath(local_path) + ".download_complete"
    if os.path.isdir(local_path) and os.path.exists(marker):
        print(f"{model_name} found locally, skipping download.")
        return

    print(f"Downloading {model_name} to {local_path} ...")
    # ONNX files are skipped; presumably only the PyTorch weights are needed here.
    snapshot_download(repo_id=repo_id, local_dir=local_path, ignore_patterns=["*.onnx"])
    # Only reached when snapshot_download did not raise, i.e. the copy is complete.
    with open(marker, "w", encoding="utf-8") as fh:
        fh.write(repo_id)
    print(f"{model_name} downloaded.")


# Make sure all three checkpoints are on disk at import time, before anything
# tries to load them from their local paths.
for _repo_id, _target_dir, _label in (
    (FUN_ASR_NANO_REPO_ID, FUN_ASR_NANO_LOCAL_PATH, "Fun-ASR-Nano"),
    (SENSE_VOICE_SMALL_REPO_ID, SENSE_VOICE_SMALL_LOCAL_PATH, "SenseVoiceSmall"),
    (VAD_MODEL_REPO_ID, VAD_MODEL_LOCAL_PATH, "VAD Model"),
):
    download_model_if_not_exists(_repo_id, _target_dir, _label)
del _repo_id, _target_dir, _label

import gradio as gr
import time
import tempfile
import logging
import torch
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Pipeline-type -> AutoModel cache. Entries are only created on first request by
# get_or_load_model(), which the @spaces.GPU-decorated handler calls.
loaded_models = {}


def get_or_load_model(pipeline_type):
    """Return the AutoModel for *pipeline_type*, building and caching it on first use.

    Both pipelines use the same fsmn-vad model with
    ``max_single_segment_time=30000`` (presumably milliseconds) and run on
    CUDA. They differ only in the ASR checkpoint directory and in whether
    ``trust_remote_code`` is enabled.

    Raises:
        ValueError: if *pipeline_type* is neither ``"fun-asr-nano"`` nor
            ``"sensevoice"``.
    """
    try:
        return loaded_models[pipeline_type]
    except KeyError:
        pass  # not built yet

    if pipeline_type == "fun-asr-nano":
        checkpoint_dir, allow_remote_code = FUN_ASR_NANO_LOCAL_PATH, True
    elif pipeline_type == "sensevoice":
        checkpoint_dir, allow_remote_code = SENSE_VOICE_SMALL_LOCAL_PATH, False
    else:
        raise ValueError(f"Unknown pipeline type: {pipeline_type}")

    built = AutoModel(
        model=checkpoint_dir,
        trust_remote_code=allow_remote_code,
        vad_model=VAD_MODEL_LOCAL_PATH,
        vad_kwargs={"max_single_segment_time": 30000},
        device="cuda",
        disable_update=True,
        hub="hf",
    )
    loaded_models[pipeline_type] = built
    return built


def _as_seconds(value):
    """Return *value* as non-negative seconds, treating ``None`` as 0.

    The trim bounds default to ``None`` and may also arrive as ``None`` from a
    cleared number field. Comparing ``None > 0`` would raise ``TypeError``.
    Negative values collapse to 0, i.e. "not set", matching the ``> 0`` checks
    the trim logic relies on.
    """
    if value is None:
        return 0.0
    return max(float(value), 0.0)


def _download_audio(url, temp_files):
    """Stream *url* into a temporary ``.wav`` file.

    The temp path is appended to *temp_files* before any bytes are written, so
    the caller's cleanup removes it even if the download fails part-way.

    Returns:
        ``(path, None)`` on HTTP 200, otherwise ``(None, error_message)``.
    """
    import requests as req

    # NOTE(review): *url* is user-supplied and fetched server-side with no host
    # allow-list or size limit (SSRF / disk-exhaustion risk); consider both.
    with req.get(url, stream=True, timeout=30) as response:
        if response.status_code != 200:
            return None, f"Failed to download audio: HTTP {response.status_code}"
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            temp_files.append(f.name)
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    return f.name, None


def _trim_audio(audio_path, start_s, end_s, temp_files):
    """Cut ``[start_s, end_s)`` out of *audio_path* into a new temporary WAV file.

    An *end_s* of 0 means "until the end of the audio". The new file's path is
    recorded in *temp_files* for cleanup and returned.

    Raises:
        ValueError: if the requested range is empty.
    """
    from pydub import AudioSegment

    audio = AudioSegment.from_file(audio_path)
    duration = len(audio) / 1000  # pydub slices/lengths are in milliseconds
    end_s = end_s if end_s > 0 else duration
    if end_s <= start_s:
        raise ValueError(
            f"Invalid trim range: start {start_s:.2f}s is not before end {end_s:.2f}s "
            f"(audio length {duration:.2f}s)."
        )
    trimmed = audio[int(start_s * 1000):int(end_s * 1000)]
    # Create and close the file first so the handle does not leak and pydub
    # can reopen it by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        temp_files.append(tmp.name)
    trimmed.export(tmp.name, format="wav")
    return tmp.name


@spaces.GPU(duration=120)
def transcribe_audio(audio_input, audio_url, pipeline_type, start_time=None, end_time=None):
    """Transcribe an uploaded/recorded file, or audio fetched from a URL.

    Args:
        audio_input: Path to an uploaded or recorded file; takes precedence
            over *audio_url* when non-empty.
        audio_url: URL to download when no file is given.
        pipeline_type: ``"fun-asr-nano"`` or ``"sensevoice"``.
        start_time: Optional trim start in seconds; 0 or ``None`` = beginning.
        end_time: Optional trim end in seconds; 0 or ``None`` = end of audio.

    Returns:
        ``(metrics, transcription, txt_path)``. On failure ``metrics`` holds a
        message, ``transcription`` is ``""`` and ``txt_path`` is ``None``.
    """
    # Every temp file created for this request (download, trimmed copy);
    # all of them are removed in ``finally``. The user's own file is not.
    temp_files = []
    try:
        if audio_input:
            audio_path = audio_input
        elif audio_url and audio_url.strip():
            audio_path, error = _download_audio(audio_url.strip(), temp_files)
            if error is not None:
                return error, "", None
        else:
            return "No audio provided. Upload a file, record, or enter a URL.", "", None

        start_s = _as_seconds(start_time)
        end_s = _as_seconds(end_time)
        if start_s > 0 or end_s > 0:
            audio_path = _trim_audio(audio_path, start_s, end_s, temp_files)

        # Load model lazily so construction happens inside the GPU context.
        model = get_or_load_model(pipeline_type)

        t0 = time.time()
        if pipeline_type == "fun-asr-nano":
            res = model.generate(input=[audio_path], use_itn=True, batch_size=1)
        else:
            res = model.generate(
                input=audio_path, cache={}, language="auto",
                use_itn=True, batch_size_s=60, merge_vad=True, merge_length_s=15,
            )
        # Guard against an empty result list instead of failing with IndexError.
        transcription = rich_transcription_postprocess(res[0]["text"]) if res else ""
        elapsed = time.time() - t0

        metrics = f"Transcription time: {elapsed:.2f}s\nPipeline: {pipeline_type}\nDevice: cuda"

        # delete=False: the path is returned to the gr.File download output.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".txt", mode="w", encoding="utf-8"
        ) as txt_file:
            txt_file.write(transcription)
        return metrics, transcription, txt_file.name

    except Exception as e:
        # Top-level UI boundary: report the error to the user, keep the traceback in logs.
        logging.exception("Transcription error: %s", e)
        return f"Error: {str(e)}", "", None
    finally:
        for path in temp_files:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
            except OSError as err:
                logging.warning("Could not remove temp file %s: %s", path, err)


# Gradio UI. Components appear in the order they are created inside the Blocks
# context, so the statement order below defines the page layout.
with gr.Blocks(title="Fun-ASR-Nano | GPU Demo") as demo:
    gr.Markdown("""
# Fun-ASR-Nano: LLM-Powered Speech Recognition (GPU)

End-to-end ASR model trained on tens of millions of hours, supporting **31 languages** including Chinese dialects.

- **GitHub**: [Fun-ASR](https://github.com/FunAudioLLM/Fun-ASR) | [FunASR Toolkit](https://github.com/modelscope/FunASR)
- **Model**: [Fun-ASR-Nano-2512](https://huggingface.co/FunAudioLLM/Fun-ASR-Nano-2512)
    """)

    # Audio source: an uploaded/recorded file (handed over as a filepath) or a URL.
    # transcribe_audio uses the file when one is present and falls back to the URL.
    with gr.Row():
        audio_input = gr.Audio(label="Upload or Record Audio", sources=["upload", "microphone"], type="filepath")
        audio_url = gr.Textbox(label="Or Enter Audio URL", placeholder="https://example.com/audio.wav")

    # Pipeline selector; the choice strings must match the branches in
    # get_or_load_model / transcribe_audio. Trim bounds of 0 mean "not set".
    with gr.Row():
        pipeline_type = gr.Dropdown(
            choices=["fun-asr-nano", "sensevoice"],
            label="Model",
            value="fun-asr-nano"
        )
        start_time = gr.Number(label="Start Time (s)", value=0, minimum=0)
        end_time = gr.Number(label="End Time (s)", value=0, minimum=0)

    transcribe_btn = gr.Button("Transcribe", variant="primary")

    with gr.Row():
        metrics_output = gr.Textbox(label="Metrics", lines=4)
        transcription_output = gr.Textbox(label="Transcription", lines=10)
        transcription_file = gr.File(label="Download")

    # inputs map positionally onto transcribe_audio's parameters; outputs map
    # onto its (metrics, transcription, txt_path) return tuple.
    transcribe_btn.click(
        transcribe_audio,
        inputs=[audio_input, audio_url, pipeline_type, start_time, end_time],
        outputs=[metrics_output, transcription_output, transcription_file],
    )

    gr.Markdown("""
### Supported Languages
- **Fun-ASR-Nano**: 31 languages + Chinese dialects (Cantonese, Sichuan, Shanghai, Minnan, etc.)
- **SenseVoice**: Chinese, English, Cantonese, Japanese, Korean
    """)

# Enable the request queue, then start the server (runs at import time).
demo.queue().launch()