# Page header captured with this file: author jugalgajjar, commit 1fff7d6 ("update app.py", verified).
import textwrap
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ---------------------------------------------------------------------------
# Model configuration
# ---------------------------------------------------------------------------
# Fine-tuned checkpoint that this demo serves.
model_id = "jugalgajjar/PyJavaCPP-Vuln-Fixer"

# Loading options: full float32 precision, with every weight placed on the CPU.
_MODEL_LOAD_KWARGS = {
    "dtype": torch.float32,
    "device_map": "cpu",
}

# Tokenizer and weights are loaded once at import time; fix_code() reads these
# module-level objects on every request.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, **_MODEL_LOAD_KWARGS)
# System turn sent with every request: tells the model to answer with the
# repaired code only, without any explanation.
SYSTEM_MESSAGE = "".join(
    (
        "You are a code security expert. Given vulnerable source code, ",
        "output ONLY the fixed version of the code with the vulnerability repaired. ",
        "Do not include explanations, just the corrected code.",
    )
)
# ---------------------------------------------------------------------------
# Prediction
# ---------------------------------------------------------------------------
def _extract_code(text):
    """Return the code inside the Markdown fence of a model reply.

    The span between the first and the last ``` in *text* is returned. An info
    string such as ``python`` or ``c++`` on the opening fence line is removed
    so it does not end up in the fixed code. If the reply has an opening fence
    but no closing one (e.g. generation stopped at ``max_new_tokens``), the
    text after the opening fence is returned instead of an empty string. Text
    without any fence is returned stripped.
    """
    fence = "```"
    start = text.find(fence)
    if start == -1:
        return text.strip()

    end = text.rfind(fence)
    if end < start + len(fence):
        # Only one fence: keep everything after it rather than returning "".
        end = len(text)
    body = text[start + len(fence):end]

    first_line, newline, rest = body.partition("\n")
    tag = first_line.strip()
    # An info string is a single token like "python", "cpp" or "c++"; anything
    # else on the fence line (spaces, brackets, "=" ...) is treated as code.
    if newline and all(ch.isalnum() or ch in "+#._-" for ch in tag):
        body = rest
    return body.strip()


def fix_code(language, vulnerable_code):
    """Generate a repaired version of *vulnerable_code* with the model.

    Args:
        language: Language name inserted into the user prompt (the UI offers
            "python", "java" and "cpp").
        vulnerable_code: Source code pasted by the user.

    Returns:
        The fenced code from the model's reply (see ``_extract_code``), or the
        whole stripped reply when it contains no fence. If *vulnerable_code*
        is missing or blank, a short message asking for code is returned and
        the model is not run.
    """
    if not vulnerable_code or not vulnerable_code.strip():
        return "Please enter the code you want to fix."

    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {
            "role": "user",
            "content": f"Fix the below given vulnerable {language} code:\n{vulnerable_code}",
        },
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered chat template already contains the model's special tokens,
    # so don't let the tokenizer add them (e.g. a second BOS) again.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.2,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.15,
        )

    # The output sequence starts with the prompt tokens; decode only the
    # tokens generated after them.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    result = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return _extract_code(result)
# Sample inputs for the gr.Examples widget: one [language, code] pair per
# language offered in the UI. Each snippet is deliberately vulnerable:
#   - python: a request argument is passed straight to os.popen (command injection)
#   - java:   the "id" parameter is concatenated into a SQL query (SQL injection)
#   - cpp:    argv[1] is strcpy'd into an 8-byte stack buffer (buffer overflow)
# textwrap.dedent strips the indentation the literals share, which leaves each
# snippet's own block nesting intact.
EXAMPLES = [
    [
        "python",
        textwrap.dedent("""\
            import os
            from flask import Flask, request
            app = Flask(__name__)
            @app.route("/run")
            def run():
                cmd = request.args.get("cmd")
                return os.popen(cmd).read()
            if __name__ == "__main__":
                app.run(debug=False)"""),
    ],
    [
        "java",
        textwrap.dedent("""\
            import java.sql.*;
            import javax.servlet.http.*;
            public class UserServlet extends HttpServlet {
                public void doGet(HttpServletRequest req, HttpServletResponse res) {
                    try {
                        String id = req.getParameter("id");
                        Connection conn = DriverManager.getConnection("jdbc:mysql://localhost/db", "user", "pass");
                        Statement stmt = conn.createStatement();
                        ResultSet rs = stmt.executeQuery("SELECT * FROM users WHERE id='" + id + "'");
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }"""),
    ],
    [
        "cpp",
        textwrap.dedent("""\
            #include <iostream>
            #include <cstring>
            void login(char *input) {
                char password[8];
                strcpy(password, input);
            }
            int main(int argc, char *argv[]) {
                if (argc > 1) {
                    login(argv[1]);
                }
                return 0;
            }"""),
    ],
]
# ---------------------------------------------------------------------------
# UI layout
# ---------------------------------------------------------------------------
# The page layout follows the order and nesting of the component constructors
# below, so statements here should not be reordered casually.
with gr.Blocks(title="PyJavaCPP Vuln-Fixer") as demo:
    gr.Markdown("# 🛡️ PyJavaCPP Vulnerability Fixer (CPU)")
    gr.Markdown(
        "Select your language, paste your code, and get a secured version of your code!"
    )
    with gr.Row():
        # First column: language picker, code input and the submit button.
        with gr.Column():
            # Choices are the same languages used in EXAMPLES.
            lang_input = gr.Dropdown(
                choices=["python", "java", "cpp"],
                value="python",
                label="Target Language",
            )
            code_input = gr.Textbox(
                label="Vulnerable Code",
                lines=15,
                max_lines=30,
                placeholder="Paste your vulnerable code here...",
            )
            submit_btn = gr.Button("Secure My Code ✨", variant="primary")
        # Second column: the model's answer; interactive=False so the user
        # cannot edit it.
        with gr.Column():
            code_output = gr.Textbox(
                label="Fixed Code",
                lines=15,
                max_lines=30,
                interactive=False,
            )
    # Each EXAMPLES row supplies values for (lang_input, code_input).
    gr.Examples(
        examples=EXAMPLES,
        inputs=[lang_input, code_input],
    )
    # Clicking the button calls fix_code(language, code) and shows the
    # returned string in code_output.
    submit_btn.click(fix_code, [lang_input, code_input], code_output)
# Start the app; ssr_mode=False disables Gradio's server-side rendering.
demo.launch(ssr_mode=False)