| | import textwrap |
| | import gradio as gr |
| | import torch |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| |
|
| | |
# Hugging Face Hub repository of the fine-tuned vulnerability-repair model.
model_id = "jugalgajjar/PyJavaCPP-Vuln-Fixer"

# Tokenizer and weights are loaded once, at import time, and reused by every
# call to fix_code(). The weights are kept in float32 and placed on the CPU.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cpu", dtype=torch.float32
)
| |
|
# System turn sent with every request: instructs the model to answer with the
# repaired code only, without any explanation.
SYSTEM_MESSAGE = (
    "You are a code security expert. "
    "Given vulnerable source code, output ONLY the fixed version of the code "
    "with the vulnerability repaired. "
    "Do not include explanations, just the corrected code."
)
| |
|
| | |
def _extract_code(text):
    """Return the contents of the fenced code block in *text*, stripped.

    Everything between the first and the last ``` fence is kept, minus the
    opening fence's info string (e.g. the "python" in ```python). If only one
    fence is present (such as when the reply was cut off by max_new_tokens),
    everything after that fence is kept. Without any fence, *text* itself is
    returned.
    """
    start = text.find("```")
    if start == -1:
        return text.strip()

    body_start = start + 3
    end = text.rfind("```")
    if end <= start:
        # Only one fence: treat the block as unterminated rather than returning
        # the empty slice text[start + 3:start].
        end = len(text)

    # Per CommonMark, the rest of the opening fence line is an info string
    # (language tag), not code; the block body starts on the next line.
    newline = text.find("\n", body_start, end)
    if newline != -1:
        body_start = newline + 1

    return text[body_start:end].strip()


def fix_code(language, vulnerable_code):
    """Ask the model to repair *vulnerable_code* and return the fixed code.

    Args:
        language: Language name interpolated into the user prompt (the UI
            offers "python", "java" or "cpp").
        vulnerable_code: Source text to repair.

    Returns:
        The contents of the model's fenced code block if it produced one,
        otherwise its whole reply with surrounding whitespace stripped.
        Empty, whitespace-only or None input returns a short hint message
        without calling the model.
    """
    if not vulnerable_code or not vulnerable_code.strip():
        return "Please enter the code you want to fix."

    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {
            "role": "user",
            "content": f"Fix the below given vulnerable {language} code:\n{vulnerable_code}",
        },
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The rendered chat template already contains the special tokens it needs;
    # letting the tokenizer add them again can duplicate e.g. a BOS token.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.2,  # low-temperature sampling keeps the fix close to greedy
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.15,
        )

    # generate() returns prompt + continuation; decode only the new tokens.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    result = tokenizer.decode(new_tokens, skip_special_tokens=True)

    return _extract_code(result)
| |
|
# Sample inputs for gr.Examples, as [language, snippet] rows: a Flask route
# that passes a query parameter to os.popen, a servlet that concatenates a
# request parameter into SQL, and a strcpy into an 8-byte stack buffer.
EXAMPLES = [
    [language, textwrap.dedent(snippet)]
    for language, snippet in (
        (
            "python",
            """\
            import os
            from flask import Flask, request

            app = Flask(__name__)

            @app.route("/run")
            def run():
                cmd = request.args.get("cmd")
                return os.popen(cmd).read()

            if __name__ == "__main__":
                app.run(debug=False)""",
        ),
        (
            "java",
            """\
            import java.sql.*;
            import javax.servlet.http.*;

            public class UserServlet extends HttpServlet {
                public void doGet(HttpServletRequest req, HttpServletResponse res) {
                    try {
                        String id = req.getParameter("id");
                        Connection conn = DriverManager.getConnection("jdbc:mysql://localhost/db", "user", "pass");
                        Statement stmt = conn.createStatement();
                        ResultSet rs = stmt.executeQuery("SELECT * FROM users WHERE id='" + id + "'");
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }""",
        ),
        (
            "cpp",
            """\
            #include <iostream>
            #include <cstring>

            void login(char *input) {
                char password[8];
                strcpy(password, input);
            }

            int main(int argc, char *argv[]) {
                if (argc > 1) {
                    login(argv[1]);
                }
                return 0;
            }""",
        ),
    )
]
| |
|
| | |
# Two-column UI: language picker, code input and submit button on the left;
# the model's fixed code on the right; clickable samples underneath.
with gr.Blocks(title="PyJavaCPP Vuln-Fixer") as demo:
    gr.Markdown("# 🛡️ PyJavaCPP Vulnerability Fixer (CPU)")
    gr.Markdown(
        "Select your language, paste your code, and get a secured version of your code!"
    )

    with gr.Row():
        with gr.Column():
            lang_input = gr.Dropdown(
                choices=["python", "java", "cpp"],
                value="python",
                label="Target Language",
            )
            code_input = gr.Textbox(
                label="Vulnerable Code",
                lines=15,
                max_lines=30,
                placeholder="Paste your vulnerable code here...",
            )
            submit_btn = gr.Button("Secure My Code ✨", variant="primary")

        with gr.Column():
            code_output = gr.Textbox(
                label="Fixed Code",
                lines=15,
                max_lines=30,
                interactive=False,  # output only; users cannot edit it
            )

    # Clicking a sample fills both the language dropdown and the code box.
    gr.Examples(
        examples=EXAMPLES,
        inputs=[lang_input, code_input],
    )

    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(fix_code, [lang_input, code_input], code_output)

# Only start the server when run as a script (`python app.py`), so importing
# this module (tests, tooling) does not block on a running web server.
if __name__ == "__main__":
    demo.launch(ssr_mode=False)  # Gradio server-side rendering turned off