import os import gradio as gr import torch from transformers import AutoTokenizer from modeling_bert import BertForSequenceClassification # 当前 app.py 所在目录 BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # 训练完成后保存模型的目录 MODEL_DIR = os.path.join(BASE_DIR, "experiments") # 如果 Spaces 提供 GPU 就用 GPU,否则自动回退到 CPU DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # 类别 id 到文本标签的映射 ID2LABEL = { 0: "not_disaster", 1: "disaster", } # 加载 tokenizer tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR) # 加载训练好的分类模型 model = BertForSequenceClassification.from_pretrained(MODEL_DIR) model.to(DEVICE) model.eval() def inference(input_text): # 处理空输入,避免直接送进模型报错 input_text = (input_text or "").strip() if not input_text: return "Please input a sentence." # 把文本编码成模型可接收的输入格式 # 包括 input_ids 和 attention_mask inputs = tokenizer( input_text, max_length=128, truncation=True, padding="max_length", return_tensors="pt", ) # 把输入张量移动到和模型相同的设备上 inputs = {k: v.to(DEVICE) for k, v in inputs.items()} # 推理阶段不需要计算梯度 with torch.no_grad(): logits = model(**inputs).logits # 取分数最高的类别作为最终预测 predicted_class_id = logits.argmax(dim=-1).item() output = ID2LABEL[predicted_class_id] return output # 使用 Gradio Blocks 搭建一个简单网页界面 with gr.Blocks(css=""" .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px} #component-2 > div.wrap.svelte-w6rprc {height: 600px;} """) as demo: gr.Markdown("# Disaster Tweet Classifier") gr.Markdown("Input a sentence or tweet, and the model will predict whether it describes a real disaster.") # 一行布局,里面放一个输入列 with gr.Row(): with gr.Column(): # 用户输入文本 input_text = gr.Textbox( placeholder="Insert your text here...", label="Input Text", lines=4, ) # 显示模型预测结果 answer = gr.Textbox(label="Prediction") # 点击按钮后触发推理 generate_bt = gr.Button("Generate") # 把按钮、输入框、输出框和推理函数绑定起来 generate_bt.click( fn=inference, inputs=[input_text], outputs=[answer], show_progress=True, ) # 提供几个示例,方便在线体验 gr.Examples( examples=[ ["Forest fire near La Ronge Sask. Canada"], ["I love fruits and summer weather."], ["There is an emergency evacuation happening now in the building across the street."], ], inputs=input_text, outputs=answer, fn=inference, cache_examples=False, ) # 启动 Gradio 服务 demo.launch()