Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer | |
| from modeling_bert import BertForSequenceClassification | |
| # 当前 app.py 所在目录 | |
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| # 训练完成后保存模型的目录 | |
| MODEL_DIR = os.path.join(BASE_DIR, "experiments") | |
| # 如果 Spaces 提供 GPU 就用 GPU,否则自动回退到 CPU | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| # 类别 id 到文本标签的映射 | |
| ID2LABEL = { | |
| 0: "not_disaster", | |
| 1: "disaster", | |
| } | |
| # 加载 tokenizer | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR) | |
| # 加载训练好的分类模型 | |
| model = BertForSequenceClassification.from_pretrained(MODEL_DIR) | |
| model.to(DEVICE) | |
| model.eval() | |
| def inference(input_text): | |
| # 处理空输入,避免直接送进模型报错 | |
| input_text = (input_text or "").strip() | |
| if not input_text: | |
| return "Please input a sentence." | |
| # 把文本编码成模型可接收的输入格式 | |
| # 包括 input_ids 和 attention_mask | |
| inputs = tokenizer( | |
| input_text, | |
| max_length=128, | |
| truncation=True, | |
| padding="max_length", | |
| return_tensors="pt", | |
| ) | |
| # 把输入张量移动到和模型相同的设备上 | |
| inputs = {k: v.to(DEVICE) for k, v in inputs.items()} | |
| # 推理阶段不需要计算梯度 | |
| with torch.no_grad(): | |
| logits = model(**inputs).logits | |
| # 取分数最高的类别作为最终预测 | |
| predicted_class_id = logits.argmax(dim=-1).item() | |
| output = ID2LABEL[predicted_class_id] | |
| return output | |
| # 使用 Gradio Blocks 搭建一个简单网页界面 | |
| with gr.Blocks(css=""" | |
| .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px} | |
| #component-2 > div.wrap.svelte-w6rprc {height: 600px;} | |
| """) as demo: | |
| gr.Markdown("# Disaster Tweet Classifier") | |
| gr.Markdown("Input a sentence or tweet, and the model will predict whether it describes a real disaster.") | |
| # 一行布局,里面放一个输入列 | |
| with gr.Row(): | |
| with gr.Column(): | |
| # 用户输入文本 | |
| input_text = gr.Textbox( | |
| placeholder="Insert your text here...", | |
| label="Input Text", | |
| lines=4, | |
| ) | |
| # 显示模型预测结果 | |
| answer = gr.Textbox(label="Prediction") | |
| # 点击按钮后触发推理 | |
| generate_bt = gr.Button("Generate") | |
| # 把按钮、输入框、输出框和推理函数绑定起来 | |
| generate_bt.click( | |
| fn=inference, | |
| inputs=[input_text], | |
| outputs=[answer], | |
| show_progress=True, | |
| ) | |
| # 提供几个示例,方便在线体验 | |
| gr.Examples( | |
| examples=[ | |
| ["Forest fire near La Ronge Sask. Canada"], | |
| ["I love fruits and summer weather."], | |
| ["There is an emergency evacuation happening now in the building across the street."], | |
| ], | |
| inputs=input_text, | |
| outputs=answer, | |
| fn=inference, | |
| cache_examples=False, | |
| ) | |
| # 启动 Gradio 服务 | |
| demo.launch() | |