zhu-mingye committed on
Commit
8110e4e
·
1 Parent(s): 718985b

Add CodeT5 model integration

Browse files
Files changed (2) hide show
  1. app.py +46 -4
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,7 +1,49 @@
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
 
5
+ # Load the CodeT5 model
6
+ model_name = "Salesforce/codet5-base"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
 
10
def generate_code(prompt: str, max_length: int = 128) -> str:
    """Generate or complete code for ``prompt`` using the loaded CodeT5 model.

    Args:
        prompt: Natural-language description or code snippet to continue.
            ``None``, empty, or whitespace-only input yields ``""`` without
            touching the model.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
            It is coerced to ``int`` because UI sliders may deliver a float.

    Returns:
        The first generated sequence decoded with special tokens skipped, or
        an empty string when there is no usable prompt.
    """
    # A cleared textbox may arrive as None rather than "".
    if not prompt or not prompt.strip():
        return ""

    # Generation length must be an integer; slider widgets can hand over floats.
    max_length = int(max_length)

    # Truncate the encoded prompt to at most 512 tokens.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            early_stopping=True,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
26
+
27
# Build the Gradio interface: a prompt box and a length slider feed
# generate_code, and the result is shown in a read-out textbox.
_prompt_box = gr.Textbox(
    label="Prompt",
    placeholder="输入代码描述或代码片段,例如:def fibonacci(n):",
    lines=5,
)
_length_slider = gr.Slider(32, 512, value=128, step=32, label="Max Length")
_result_box = gr.Textbox(label="Generated Code", lines=10)

# Each example row is [prompt, max_length], matching the input order above.
_example_rows = [
    ["def fibonacci(n):", 128],
    ["# Python function to calculate factorial", 128],
    ["def quick_sort(arr):", 128],
]

demo = gr.Interface(
    fn=generate_code,
    inputs=[_prompt_box, _length_slider],
    outputs=_result_box,
    title="CodeT5 Code Generation",
    description="基于 Salesforce CodeT5 的代码生成模型。支持代码补全、代码生成等任务。",
    examples=_example_rows,
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers>=4.30.0
2
+ torch>=2.0.0