Garrulus21yyx
Add minimal Gradio Space files
c772fc0
Raw
History Blame Contribute Delete
3.09 kB
import os
import gradio as gr
import torch
from transformers import AutoTokenizer
from modeling_bert import BertForSequenceClassification
# 当前 app.py 所在目录
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# 训练完成后保存模型的目录
MODEL_DIR = os.path.join(BASE_DIR, "experiments")
# 如果 Spaces 提供 GPU 就用 GPU,否则自动回退到 CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 类别 id 到文本标签的映射
ID2LABEL = {
0: "not_disaster",
1: "disaster",
}
# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# 加载训练好的分类模型
model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
model.to(DEVICE)
model.eval()
def inference(input_text):
# 处理空输入,避免直接送进模型报错
input_text = (input_text or "").strip()
if not input_text:
return "Please input a sentence."
# 把文本编码成模型可接收的输入格式
# 包括 input_ids 和 attention_mask
inputs = tokenizer(
input_text,
max_length=128,
truncation=True,
padding="max_length",
return_tensors="pt",
)
# 把输入张量移动到和模型相同的设备上
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
# 推理阶段不需要计算梯度
with torch.no_grad():
logits = model(**inputs).logits
# 取分数最高的类别作为最终预测
predicted_class_id = logits.argmax(dim=-1).item()
output = ID2LABEL[predicted_class_id]
return output
# 使用 Gradio Blocks 搭建一个简单网页界面
with gr.Blocks(css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-2 > div.wrap.svelte-w6rprc {height: 600px;}
""") as demo:
gr.Markdown("# Disaster Tweet Classifier")
gr.Markdown("Input a sentence or tweet, and the model will predict whether it describes a real disaster.")
# 一行布局,里面放一个输入列
with gr.Row():
with gr.Column():
# 用户输入文本
input_text = gr.Textbox(
placeholder="Insert your text here...",
label="Input Text",
lines=4,
)
# 显示模型预测结果
answer = gr.Textbox(label="Prediction")
# 点击按钮后触发推理
generate_bt = gr.Button("Generate")
# 把按钮、输入框、输出框和推理函数绑定起来
generate_bt.click(
fn=inference,
inputs=[input_text],
outputs=[answer],
show_progress=True,
)
# 提供几个示例,方便在线体验
gr.Examples(
examples=[
["Forest fire near La Ronge Sask. Canada"],
["I love fruits and summer weather."],
["There is an emergency evacuation happening now in the building across the street."],
],
inputs=input_text,
outputs=answer,
fn=inference,
cache_examples=False,
)
# 启动 Gradio 服务
demo.launch()