"""Gradio demo that asks a fine-tuned causal LM to repair vulnerable code.

The tokenizer and model are loaded once at import time on CPU. ``fix_code``
builds a chat prompt, samples a completion and returns the code portion of
the model's reply; the Gradio UI below wires it to a dropdown and textbox.
"""

import textwrap

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model Configuration
model_id = "jugalgajjar/PyJavaCPP-Vuln-Fixer"

# Load tokenizer and model (CPU only, full float32 precision).
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.float32,
    device_map="cpu",
)

SYSTEM_MESSAGE = (
    "You are a code security expert. Given vulnerable source code, "
    "output ONLY the fixed version of the code with the vulnerability repaired. "
    "Do not include explanations, just the corrected code."
)

_FENCE = "```"


def _extract_code(text):
    """Return the code inside a Markdown fence in *text*, or *text* stripped.

    * No fence at all: the whole reply, stripped.
    * An opening fence followed by an info string (for example ``python``):
      that info-string line is dropped so it does not leak into the code.
    * A closing fence: the body ends at the LAST fence in the reply.
    * An opening fence that is never closed (for example, generation hit
      ``max_new_tokens``): everything after the opening fence line is kept
      instead of returning an empty string.
    """
    start = text.find(_FENCE)
    if start == -1:
        return text.strip()

    body_start = start + len(_FENCE)
    end = text.rfind(_FENCE, body_start)
    if end == -1:
        # Unterminated fence: keep the rest of the reply rather than nothing.
        end = len(text)

    # The opening fence line may carry an info string ("python", "java", ...);
    # when the body continues on a following line, skip that first line.
    newline = text.find("\n", body_start, end)
    if newline != -1:
        body_start = newline + 1

    return text[body_start:end].strip()


# Prediction
def fix_code(language, vulnerable_code):
    """Ask the model to repair *vulnerable_code* written in *language*.

    Args:
        language: Language label from the dropdown ("python", "java" or
            "cpp"); it is interpolated into the user prompt.
        vulnerable_code: Source code pasted by the user.

    Returns:
        The repaired code extracted from the model's reply, or a hint
        message when no code was supplied.
    """
    if not vulnerable_code or not vulnerable_code.strip():
        return "Please enter the code you want to fix."

    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": f"Fix the below given vulnerable {language} code:\n{vulnerable_code}"},
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The rendered chat template already contains whatever special tokens the
    # model needs, so don't let the tokenizer add them a second time.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.2,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.15,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    result = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    return _extract_code(result)


EXAMPLES = [
    [
        "python",
        textwrap.dedent("""\
            import os
            from flask import Flask, request

            app = Flask(__name__)

            @app.route("/run")
            def run():
                cmd = request.args.get("cmd")
                return os.popen(cmd).read()

            if __name__ == "__main__":
                app.run(debug=False)"""),
    ],
    [
        "java",
        textwrap.dedent("""\
            import java.sql.*;
            import javax.servlet.http.*;

            public class UserServlet extends HttpServlet {
                public void doGet(HttpServletRequest req, HttpServletResponse res) {
                    try {
                        String id = req.getParameter("id");
                        Connection conn = DriverManager.getConnection("jdbc:mysql://localhost/db", "user", "pass");
                        Statement stmt = conn.createStatement();
                        ResultSet rs = stmt.executeQuery("SELECT * FROM users WHERE id='" + id + "'");
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }"""),
    ],
    [
        "cpp",
        textwrap.dedent("""\
            #include <stdio.h>
            #include <string.h>

            void login(char *input) {
                char password[8];
                strcpy(password, input);
            }

            int main(int argc, char *argv[]) {
                if (argc > 1) {
                    login(argv[1]);
                }
                return 0;
            }"""),
    ],
]

# UI Layout
with gr.Blocks(title="PyJavaCPP Vuln-Fixer") as demo:
    gr.Markdown("# 🛡️ PyJavaCPP Vulnerability Fixer (CPU)")
    gr.Markdown(
        "Select your language, paste your code, and get a secured version of your code!"
    )

    with gr.Row():
        with gr.Column():
            lang_input = gr.Dropdown(
                choices=["python", "java", "cpp"],
                value="python",
                label="Target Language",
            )
            code_input = gr.Textbox(
                label="Vulnerable Code",
                lines=15,
                max_lines=30,
                placeholder="Paste your vulnerable code here...",
            )
            submit_btn = gr.Button("Secure My Code ✨", variant="primary")

        with gr.Column():
            code_output = gr.Textbox(
                label="Fixed Code",
                lines=15,
                max_lines=30,
                interactive=False,
            )

    gr.Examples(
        examples=EXAMPLES,
        inputs=[lang_input, code_input],
    )

    submit_btn.click(fix_code, [lang_input, code_input], code_output)

demo.launch(ssr_mode=False)