| | |
| | import spaces |
| | import gradio as gr |
| | from gradio import update |
| | from functools import lru_cache |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
| | from opencc import OpenCC |
| | from math import gcd |
| | from termcolor import cprint |
| |
|
| | |
# Shared OpenCC converter using the "s2t" (Simplified -> Traditional Chinese)
# profile; suggestion text is passed through it before being shown.
cc = OpenCC("s2t")
| |
|
| | |
# Hugging Face Hub model ids offered in the model dropdown.
# Order matters: MODEL_LIST[0] is used as the dropdown's default value.
MODEL_LIST = [
    "liswei/Taiwan-ELM-270M",
    "Mxode/SmolLM-Chinese-180M",
    "openbmb/BitCPM4-0.5B",
    "flyingfishinwater/chinese-baby-llama2",
    "unsloth/gemma-3-1b-pt",
    "taide/TAIDE-LX-7B",
    "ckiplab/gpt2-tiny-chinese",
    "ckiplab/gpt2-base-chinese",
    "liswei/Taiwan-ELM-1_1B",
    "benchang1110/Qwen2.5-Taiwan-1.5B-Instruct",
    "benchang1110/Taiwan-tinyllama-v1.0-base",
    "lianghsun/Llama-3.2-Taiwan-3B",
    "twinkle-ai/Llama-3.2-3B-F1-Instruct",
    "Epiculous/Violet_Twilight-v0.2",
]
| |
|
| |
|
@lru_cache(maxsize=8)
def get_pipeline(model_name):
    """Load *model_name* and return a cached text-generation pipeline for it.

    Results are memoised with ``lru_cache(maxsize=8)``. Each cached pipeline
    keeps its tokenizer and model referenced, so up to eight models can stay
    loaded at once. NOTE(review): MODEL_LIST includes a 7B model; consider a
    smaller ``maxsize`` if GPU memory runs out.

    The model is moved to CUDA when possible. If that move raises, the error
    is printed and the pipeline is built on CPU (``device=-1``) instead of
    passing ``device=0``. Passing ``device=0`` would make ``pipeline()`` retry
    the same failing CUDA move. A CPU fallback is cached like any other
    result.

    Args:
        model_name: Hugging Face Hub model id.

    Returns:
        A ``transformers`` "text-generation" pipeline wrapping the model.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_name, weights_only=False, trust_remote_code=True
    )
    try:
        mdl.to("cuda")
        device = 0
    except Exception as e:
        # Best effort: keep serving on CPU rather than failing a second time
        # inside pipeline() with the same CUDA error.
        print(f"Error: {e} (falling back to CPU)")
        device = -1
    return pipeline("text-generation", model=mdl, tokenizer=tok, device=device)
| |
|
@spaces.GPU
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
    """Generate up to *m* continuation candidates for *text* via beam search.

    Beam search runs with ``num_beams = num_return_sequences = m``. When
    *diversity_penalty* is positive, Diverse (group) Beam Search is enabled
    with ``gcd(m, num_beam_groups)`` groups, so ``num_beams`` is always
    divisible by the group count.

    Each candidate is only the newly generated text (the prompt is excluded
    by the pipeline itself). It is right-stripped and converted to
    Traditional Chinese with ``cc``. Empty candidates are dropped and
    duplicates are removed while keeping the beam ranking order.

    Args:
        text: Current input text, used as the prompt.
        model_name: Model id passed to ``get_pipeline``.
        k: Maximum number of new tokens per candidate.
        m: Number of beams and of returned sequences.
        num_beam_groups: Requested number of beam groups (reduced via gcd).
        diversity_penalty: Group diversity penalty; 0/None disables groups.

    Returns:
        A Gradio ``update`` that sets the radio choices to the candidates
        and clears the current selection.
    """
    # Nothing to continue from: clear the candidate bar without spending a
    # generation call on an empty prompt.
    if not text:
        return update(choices=[], value=None)

    # Gradio sliders may deliver floats; gcd() and generate() need ints.
    k, m, num_beam_groups = int(k), int(m), int(num_beam_groups)

    gen_pipe = get_pipeline(model_name)

    gen_kwargs = {
        "max_new_tokens": k,
        "num_beams": m,
        "num_return_sequences": m,
        "do_sample": False,
        "early_stopping": True,
        # Let the pipeline strip the (decoded) prompt. Slicing by len(text)
        # breaks when the tokenizer does not round-trip the input exactly.
        "return_full_text": False,
    }
    if diversity_penalty and diversity_penalty > 0:
        # num_beams must be divisible by num_beam_groups.
        gen_kwargs["num_beam_groups"] = gcd(m, num_beam_groups)
        gen_kwargs["diversity_penalty"] = float(diversity_penalty)

    outs = gen_pipe(text, **gen_kwargs)

    ranked = []
    for out in outs:
        snippet = out["generated_text"].rstrip()
        if snippet:
            ranked.append(cc.convert(snippet))
    # dict.fromkeys de-duplicates while preserving beam (score) order; a set
    # would shuffle the candidates.
    suggestions = list(dict.fromkeys(ranked))

    return update(choices=suggestions, value=None)
| |
|
| |
|
def append_suggestion(current, choice):
    """Return *current* with the selected candidate *choice* appended.

    A ``None`` choice (no candidate selected) leaves the text unchanged.
    """
    return current if choice is None else current + choice
| |
|
| | |
# Page styles for the Gradio UI.
# The suggestions Radio is created with elem_id="suggestions-bar" and
# elem_classes="candidate-list". Gradio puts both on the SAME element, so
# these rules use the compound selector `#suggestions-bar.candidate-list`;
# a descendant selector (`#suggestions-bar .candidate-list`) never matches.
# Gradio nests each radio <input> inside its <label>, so the selected
# option is styled with `label:has(input:checked)` rather than a sibling
# selector.
custom_css = """
#suggestions-bar {
    width: 100%;
    margin-bottom: 8px;
}
#suggestions-bar.candidate-list {
    display: flex;
    gap: 8px;
    background: #fff;
    border: 1px solid #999;
    border-radius: 4px;
    padding: 6px;
    overflow-x: auto;
    white-space: nowrap;
}
#suggestions-bar.candidate-list label {
    cursor: pointer;
    padding: 6px 10px;
    font-size: 16px;
}
#suggestions-bar.candidate-list label:hover {
    background: #f5f5f5;
}
#suggestions-bar.candidate-list label:has(input[type=radio]:checked) {
    background: #e6f7ff;
    border: 1px solid #1890ff;
}
#input-box textarea {
    width: 100%;
    font-size: 16px;
    padding: 6px;
    box-sizing: border-box;
    overflow: hidden;
    resize: none;
}
#predict-button {
    margin-top: 8px;
    width: 100%;
}
/* 手機響應式 */
@media only screen and (max-width: 600px) {
    #suggestions-bar.candidate-list label {
        padding: 8px;
        font-size: 18px;
    }
    #predict-button {
        font-size: 18px;
    }
}
"""
| |
|
| | |
# Auto-grow script for the input textarea (#input-box).
# It listens for `input` events on `document` (event delegation) instead of
# looking the textarea up once on window `load`. The app UI can be rendered
# after `load` has fired, in which case that one-time lookup finds nothing
# and the handler is never attached.
auto_height_js = """
<script>
document.addEventListener('input', (event) => {
  const el = event.target;
  if (!el || !el.matches || !el.matches('#input-box textarea')) return;
  el.style.height = 'auto';
  el.style.height = el.scrollHeight + 'px';
});
</script>
"""
| |
|
# Build the UI.
# The auto-grow <script> is injected via `head=` because Gradio does not
# execute <script> tags rendered through gr.HTML.
with gr.Blocks(css=custom_css, head=auto_height_js) as demo:
    gr.Markdown(
        # Heading and description need a real line break between them.
        # Otherwise the description is rendered as part of the "##" heading.
        "## 🇹🇼 繁體中文 IME 加速器\n\n"
        "結合小型語言模型與 ZeroGPU,提供即時輸入法風格候選欄。"
    )

    with gr.Column():
        # IME-style candidate bar: one radio option per suggestion.
        suggestions = gr.Radio(
            [], label="", interactive=True, type="value",
            elem_id="suggestions-bar", elem_classes="candidate-list",
        )
        input_text = gr.Textbox(
            label="", placeholder="請輸入拼音或文字…",
            lines=1, max_lines=20, elem_id="input-box",
        )

    with gr.Row():
        auto_predict = gr.Checkbox(
            value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict",
        )
        predict_button = gr.Button("預測", elem_id="predict-button")

    with gr.Accordion("進階設定", open=False):
        model_selector = gr.Dropdown(
            MODEL_LIST, value=MODEL_LIST[0], label="模型"
        )
        k_slider = gr.Slider(
            minimum=1, maximum=50, step=1, value=1, label="K(最大新詞元數)"
        )
        m_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=10, label="M(建議數/Beam 數)"
        )
        group_slider = gr.Slider(
            minimum=2, maximum=30, step=2, value=6,
            label="Beam 群組數 (num_beam_groups)",
        )
        diversity_penalty_slider = gr.Slider(
            minimum=0.0, maximum=2.0, step=0.1, value=0.0,
            label="多樣性懲罰 (diversity_penalty)",
        )

    # Inputs shared by the manual and the automatic prediction triggers,
    # in suggest_next's parameter order.
    predict_inputs = [
        input_text,
        model_selector,
        k_slider,
        m_slider,
        group_slider,
        diversity_penalty_slider,
    ]

    predict_button.click(
        fn=suggest_next,
        inputs=predict_inputs,
        outputs=suggestions,
    )

    def _auto_suggest(txt, mdl, k, m, g, d, auto):
        """Predict on text change when auto-predict is on; else clear."""
        if auto:
            return suggest_next(txt, mdl, k, m, g, d)
        return update(choices=[], value=None)

    input_text.change(
        fn=_auto_suggest,
        inputs=predict_inputs + [auto_predict],
        outputs=suggestions,
    )
    # `.input` fires only when the user picks a candidate. `.change` would
    # also fire when suggest_next resets the value to None, causing a
    # redundant round-trip that writes back a possibly stale copy of the text.
    suggestions.input(
        fn=append_suggestion,
        inputs=[input_text, suggestions],
        outputs=input_text,
    )
| |
|
# Start the Gradio server only when executed as a script (e.g. `python app.py`),
# so importing this module does not launch a server as a side effect.
if __name__ == "__main__":
    demo.launch()