# Last commit by ayshajavd (7336b37, verified):
# v2: Deploy updated app with per-class thresholds, temperature calibration, CWE-aware fix generation
"""
Code Security Risk Analyzer v2 - Gradio UI + REST API
=====================================================
IMPROVEMENTS OVER v1:
- Per-class threshold optimization (not global 0.3)
- Temperature scaling calibration (meaningful probabilities)
- Uses label_config.json for thresholds + calibration
- Better vulnerability detection across rare CWEs
Run AFTER notebooks 1-4 to use the improved models.
Upload this to: https://huggingface.co/spaces/ayshajavd/code-security-analyzer
"""
import json
import re
import time
import torch
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
T5ForConditionalGeneration,
)
from huggingface_hub import hf_hub_download
import numpy as np
# ============================================================
# Label Mappings
# ============================================================
# Classifier output labels in the order classify_code pairs them with the
# model's probabilities: index 0 is the "safe" class, followed by the 30 CWE
# classes in ascending numeric order.
TARGET_CWES = ["safe"] + [
    f"CWE-{number}"
    for number in (
        20, 22, 78, 79, 89, 94,
        119, 125, 190, 200, 264, 269,
        276, 284, 287, 310, 327, 330,
        352, 362, 399, 401, 416, 434,
        476, 502, 601, 787, 798, 918,
    )
]
# Display name for every classifier label (same key order as TARGET_CWES).
CWE_NAMES = {
    "safe": "Safe Code",
    "CWE-20": "Improper Input Validation",
    "CWE-22": "Path Traversal",
    "CWE-78": "OS Command Injection",
    "CWE-79": "Cross-Site Scripting (XSS)",
    "CWE-89": "SQL Injection",
    "CWE-94": "Code Injection",
    "CWE-119": "Buffer Overflow",
    "CWE-125": "Out-of-bounds Read",
    "CWE-190": "Integer Overflow",
    "CWE-200": "Information Exposure",
    "CWE-264": "Permissions/Privileges/Access Controls",
    "CWE-269": "Improper Privilege Management",
    "CWE-276": "Incorrect Default Permissions",
    "CWE-284": "Improper Access Control",
    "CWE-287": "Improper Authentication",
    "CWE-310": "Cryptographic Issues",
    "CWE-327": "Broken Crypto Algorithm",
    "CWE-330": "Insufficient Randomness",
    "CWE-352": "Cross-Site Request Forgery (CSRF)",
    "CWE-362": "Race Condition",
    "CWE-399": "Resource Management Errors",
    "CWE-401": "Memory Leak",
    "CWE-416": "Use After Free",
    "CWE-434": "Unrestricted File Upload",
    "CWE-476": "NULL Pointer Dereference",
    "CWE-502": "Insecure Deserialization",
    "CWE-601": "Open Redirect",
    "CWE-787": "Out-of-bounds Write",
    "CWE-798": "Hardcoded Credentials",
    "CWE-918": "Server-Side Request Forgery (SSRF)",
}
# CWE -> OWASP Top 10 (2021) category shown in reports. Written as
# category -> member CWEs and flattened; insertion order follows this listing.
# NOTE(review): the memory-safety CWEs (119, 125, 190, 401, 416, 476, 787) are
# grouped under A03 here, which OWASP's own A03 mapping does not appear to
# include — confirm this is an intentional simplification.
CWE_TO_OWASP = {
    cwe: category
    for category, members in (
        (
            "A01:2021 - Broken Access Control",
            ("CWE-22", "CWE-200", "CWE-264", "CWE-276", "CWE-284", "CWE-352", "CWE-601", "CWE-269"),
        ),
        (
            "A02:2021 - Cryptographic Failures",
            ("CWE-310", "CWE-327", "CWE-330"),
        ),
        (
            "A03:2021 - Injection",
            ("CWE-20", "CWE-78", "CWE-79", "CWE-89", "CWE-94", "CWE-119", "CWE-125",
             "CWE-190", "CWE-416", "CWE-476", "CWE-401", "CWE-787"),
        ),
        (
            "A04:2021 - Insecure Design",
            ("CWE-434", "CWE-362", "CWE-399"),
        ),
        (
            "A07:2021 - Identification & Auth Failures",
            ("CWE-287", "CWE-798"),
        ),
        (
            "A08:2021 - Software & Data Integrity Failures",
            ("CWE-502",),
        ),
        (
            "A10:2021 - Server-Side Request Forgery",
            ("CWE-918",),
        ),
    )
    for cwe in members
}
# CWE -> (severity tier, 0-100 severity score). build_json_report uses the
# score for the overall risk score and the per-vulnerability exploit likelihood.
SEVERITY_MAP = {
    # Critical
    "CWE-89": ("Critical", 95),
    "CWE-78": ("Critical", 93),
    "CWE-94": ("Critical", 92),
    "CWE-502": ("Critical", 90),
    "CWE-918": ("Critical", 88),
    "CWE-798": ("Critical", 87),
    # High
    "CWE-119": ("High", 85),
    "CWE-787": ("High", 84),
    "CWE-416": ("High", 83),
    "CWE-79": ("High", 80),
    "CWE-22": ("High", 78),
    "CWE-287": ("High", 77),
    "CWE-284": ("High", 76),
    "CWE-434": ("High", 75),
    # Medium
    "CWE-125": ("Medium", 70),
    "CWE-190": ("Medium", 68),
    "CWE-352": ("Medium", 67),
    "CWE-476": ("Medium", 65),
    "CWE-362": ("Medium", 63),
    "CWE-20": ("Medium", 60),
    "CWE-264": ("Medium", 58),
    "CWE-269": ("Medium", 57),
    "CWE-310": ("Medium", 65),
    "CWE-327": ("Medium", 62),
    "CWE-330": ("Medium", 55),
    # Low
    "CWE-399": ("Low", 45),
    "CWE-401": ("Low", 42),
    "CWE-200": ("Low", 40),
    "CWE-276": ("Low", 38),
    # Medium (kept last to preserve the table's original key order)
    "CWE-601": ("Medium", 55),
}
# Plain-language Markdown explanation per CWE. Each text starts with the
# vulnerability name in **bold**; build_json_report strips the "**" markers
# for the JSON report.
EXPLANATIONS = {
    "CWE-89": "**SQL Injection** means an attacker can manipulate your database queries by injecting malicious SQL code through user inputs. This could let them steal, modify, or delete ALL your data.",
    "CWE-79": "**Cross-Site Scripting (XSS)** lets attackers inject malicious JavaScript into your web pages. When other users visit the page, the script runs in their browser - stealing cookies, session tokens, or redirecting them to fake sites.",
    "CWE-78": "**OS Command Injection** means user input is being passed directly to system commands. An attacker could run ANY command on your server.",
    "CWE-94": "**Code Injection** allows attackers to inject and execute arbitrary code. Functions like `eval()`, `exec()`, or dynamic code compilation with untrusted input are the usual culprits.",
    "CWE-119": "**Buffer Overflow** happens when your code writes data beyond the allocated memory buffer. Attackers can exploit this to crash your program or execute malicious code.",
    "CWE-125": "**Out-of-bounds Read** means your code reads memory outside the intended buffer. This can leak sensitive data like passwords or encryption keys.",
    "CWE-190": "**Integer Overflow** occurs when an arithmetic operation produces a value too large for the data type, which can be chained with buffer overflows for code execution.",
    "CWE-200": "**Information Exposure** means sensitive data (API keys, passwords, stack traces) is being leaked to unauthorized parties.",
    "CWE-264": "**Improper Access Control** means users can access resources or perform actions they shouldn't be authorized for.",
    "CWE-287": "**Authentication Bypass** means the login/identity verification can be circumvented.",
    "CWE-310": "**Cryptographic Issues** - you're using weak, broken, or improperly configured encryption.",
    "CWE-352": "**CSRF** tricks authenticated users into performing unwanted actions on your site.",
    "CWE-362": "**Race Condition** means two operations compete for the same resource without proper synchronization.",
    "CWE-416": "**Use After Free** - memory is being used after it's been freed. Attackers can exploit this for arbitrary code execution.",
    "CWE-434": "**Unrestricted File Upload** lets attackers upload malicious files (like web shells) to your server.",
    "CWE-476": "**NULL Pointer Dereference** - your code tries to use a pointer that's NULL, causing crashes.",
    "CWE-502": "**Insecure Deserialization** means untrusted data is deserialized without validation, enabling code execution.",
    "CWE-601": "**Open Redirect** lets attackers redirect users from your trusted site to a malicious one for phishing.",
    "CWE-787": "**Out-of-bounds Write** - data is written outside the intended memory buffer, often leading to remote code execution.",
    "CWE-798": "**Hardcoded Credentials** - passwords, API keys, or tokens are embedded directly in the source code.",
    "CWE-918": "**SSRF** lets attackers make your server send requests to internal systems, accessing internal APIs or cloud metadata.",
    "CWE-22": "**Path Traversal** means user input is used in file paths without sanitization. Attackers can use `../` to access any file on the server.",
    "CWE-269": "**Privilege Escalation** - a user can gain higher privileges than intended.",
    "CWE-276": "**Incorrect Permissions** - files or resources have permissions that are too permissive.",
    "CWE-327": "**Broken Cryptography** - you're using algorithms like MD5 or SHA1 that are cryptographically broken.",
    "CWE-330": "**Insufficient Randomness** - security-critical random values (tokens, keys) are predictable.",
    "CWE-399": "**Resource Management Issues** - improper handling of system resources can lead to denial of service.",
    "CWE-401": "**Memory Leak** - memory is allocated but never freed, eventually causing crashes.",
    "CWE-20": "**Improper Input Validation** - user input isn't properly checked before use, enabling many other vulnerabilities.",
    "CWE-284": "**Broken Access Control** - authorization checks are missing or incorrectly implemented.",
}
# ============================================================
# Model Loading
# ============================================================
CLASSIFIER_ID = "ayshajavd/graphcodebert-vuln-classifier"
FIXER_ID = "ayshajavd/codet5p-vuln-fixer"
# Defaults used when label_config.json is unavailable or one of its entries is
# invalid: a single global 0.3 decision threshold and no temperature scaling.
THRESHOLDS = {cwe: 0.3 for cwe in TARGET_CWES}
TEMPERATURE = 1.0


def _parse_thresholds(raw):
    """Normalize the ``optimized_thresholds`` entry to ``{label: float}``.

    Accepts a mapping keyed by label, or a list/tuple aligned with TARGET_CWES.
    NOTE(review): the sequence form is a defensive assumption about how the
    training notebooks might serialize thresholds — confirm.

    Raises:
        ValueError / TypeError: for any other shape or non-numeric values.
    """
    if isinstance(raw, dict):
        return {str(label): float(value) for label, value in raw.items()}
    if isinstance(raw, (list, tuple)) and len(raw) == len(TARGET_CWES):
        return {label: float(value) for label, value in zip(TARGET_CWES, raw)}
    raise ValueError(f"unsupported optimized_thresholds format: {type(raw).__name__}")


def _parse_temperature(raw):
    """Return *raw* as a finite, strictly positive float temperature.

    classify_code divides the logits by this value, so zero, negative, or
    non-finite values would crash or invert the calibrated probabilities.

    Raises:
        ValueError / TypeError: when *raw* is not a usable temperature.
    """
    value = float(raw)
    if not np.isfinite(value) or value <= 0:
        raise ValueError(f"temperature must be positive and finite, got {raw!r}")
    return value


def _load_label_config(repo_id, thresholds, temperature):
    """Fetch ``label_config.json`` from *repo_id* and return updated settings.

    Returns:
        ``(thresholds, temperature)``. Each setting keeps the value passed in
        when the file cannot be fetched or parsed, or when its own entry is
        invalid; an invalid entry does not affect the other one.
    """
    try:
        config_path = hf_hub_download(repo_id, "label_config.json")
        with open(config_path, encoding="utf-8") as f:
            label_config = json.load(f)
    except Exception as e:  # network failure, missing file, or malformed JSON
        print(f"Could not load label_config: {e}. Using defaults.")
        return thresholds, temperature
    if not isinstance(label_config, dict):
        print("Could not load label_config: expected a JSON object. Using defaults.")
        return thresholds, temperature
    if "optimized_thresholds" in label_config:
        try:
            thresholds = _parse_thresholds(label_config["optimized_thresholds"])
            print(f"Per-class thresholds loaded ({len(thresholds)} classes)")
        except (TypeError, ValueError) as e:
            print(f"Ignoring invalid optimized_thresholds: {e}")
    if "temperature" in label_config:
        try:
            temperature = _parse_temperature(label_config["temperature"])
            print(f"Temperature calibration loaded (T={temperature:.4f})")
        except (TypeError, ValueError) as e:
            print(f"Ignoring invalid temperature: {e}")
    return thresholds, temperature


print("Loading classifier...")
try:
    cls_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_ID)
    cls_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_ID)
    cls_model.eval()
    CLASSIFIER_LOADED = True
    print("Classifier loaded successfully")
except Exception as e:
    print(f"Classifier not available: {e}")
    # Fallback keeps the app runnable. NOTE(review): the classification head
    # for these labels is presumably untrained on a base checkpoint, so its
    # predictions should not be trusted (reported as "base_model").
    cls_tokenizer = AutoTokenizer.from_pretrained("huggingface/CodeBERTa-small-v1")
    cls_model = AutoModelForSequenceClassification.from_pretrained(
        "huggingface/CodeBERTa-small-v1",
        num_labels=len(TARGET_CWES),
        problem_type="multi_label_classification",
    )
    cls_model.eval()
    CLASSIFIER_LOADED = False

# label_config.json belongs to CLASSIFIER_ID, so it is applied only when that
# model was actually loaded.
if CLASSIFIER_LOADED:
    THRESHOLDS, TEMPERATURE = _load_label_config(CLASSIFIER_ID, THRESHOLDS, TEMPERATURE)

print("Loading fix generator...")
try:
    fix_tokenizer = AutoTokenizer.from_pretrained(FIXER_ID)
    fix_model = T5ForConditionalGeneration.from_pretrained(FIXER_ID)
    fix_model.eval()
    FIXER_LOADED = True
    print("Fix generator loaded successfully")
except Exception as e:
    print(f"Fix generator not available: {e}")
    fix_tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-220m")
    fix_model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5p-220m")
    fix_model.eval()
    FIXER_LOADED = False
def detect_language(code: str) -> str:
    """Guess the programming language of *code* from simple textual markers.

    Only the first 500 characters are inspected, lower-cased. Checks run in
    priority order (PHP, Go, C/C++, Java, JavaScript, Python) and the first
    match wins. Returns "Unknown" when no marker matches.
    """
    code_lower = code[:500].lower()
    if "<?php" in code_lower:
        return "PHP"
    # Accept any Go package clause, not only ``package main``. Java package
    # declarations end with ";" and are therefore rejected by the pattern.
    go_package = "package main" in code_lower or re.search(
        r"^\s*package\s+\w+\s*(?://.*)?$", code_lower, re.MULTILINE
    )
    if go_package and "func " in code_lower:
        return "Go"
    if "#include" in code_lower:
        cpp_markers = ("class ", "std::", "cout", "using namespace")
        if any(marker in code_lower for marker in cpp_markers):
            return "C++"
        return "C"
    if "import java" in code_lower or "public class" in code_lower:
        return "Java"
    # ``\b`` guards only the keyword alternatives: an arrow such as
    # ``(req, res) =>`` has no word boundary before "=", so it must be
    # matched without one.
    if re.search(r"\b(?:const |let |var |function |require\(|module\.exports)|=>", code_lower):
        return "JavaScript"
    if re.search(r"\b(def |import |from |class |self\.|print\()", code_lower):
        return "Python"
    return "Unknown"
def classify_code(code):
    """Score *code* against every label in TARGET_CWES with the classifier.

    The logits are divided by TEMPERATURE before the sigmoid. Every non-"safe"
    label whose probability is strictly greater than its threshold
    (THRESHOLDS, falling back to 0.3) counts as a detection.

    Returns:
        ``(detected, safe_prob, all_probs)``:
        - *detected* is a list of ``(cwe, prob)`` sorted by descending
          probability;
        - *safe_prob* is the probability of the first label ("safe");
        - *all_probs* maps every label to its probability.
    """
    encoded = cls_tokenizer(code, return_tensors="pt", max_length=512, truncation=True, padding=True)
    with torch.no_grad():
        raw_logits = cls_model(**encoded).logits.squeeze()
    probabilities = torch.sigmoid(raw_logits / TEMPERATURE).numpy()

    labelled = list(zip(TARGET_CWES, probabilities))
    detected = sorted(
        (
            (label, float(prob))
            for label, prob in labelled
            if label != "safe" and prob > THRESHOLDS.get(label, 0.3)
        ),
        key=lambda item: item[1],
        reverse=True,
    )
    all_probs = {label: float(prob) for label, prob in labelled}
    return detected, float(probabilities[0]), all_probs
def generate_fix(code, language, cwe_id=None):
    """Generate a candidate patched version of *code* with the fixer model.

    The prompt prefix names the vulnerability (its CWE_NAMES entry, or the raw
    id if unknown) so the CWE-aware fixer can condition on it. Without
    *cwe_id*, the generic prefix ``"fix <language>: "`` is used.

    Args:
        code: Source code to repair. Prefix + code is truncated to 512 tokens.
        language: Display language name (e.g. "Python"); lower-cased in the prompt.
        cwe_id: Optional CWE identifier such as "CWE-89".

    Returns:
        The decoded best beam with special tokens removed.
    """
    if cwe_id:
        cwe_name = CWE_NAMES.get(cwe_id, cwe_id)
        prefix = f"fix {cwe_name} vulnerability in {language.lower()}: "
    else:
        prefix = f"fix {language.lower()}: "
    encoded = fix_tokenizer(prefix + code, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():
        out = fix_model.generate(
            input_ids=encoded["input_ids"],
            # Pass the tokenizer's mask explicitly instead of letting
            # generate() infer it from pad tokens.
            attention_mask=encoded["attention_mask"],
            max_length=512,
            num_beams=5,
            early_stopping=True,
            # NOTE(review): no_repeat_ngram_size=3 forbids any 3-token sequence
            # from repeating, which source code legitimately does. Kept to
            # match the existing generation settings — confirm against evals.
            no_repeat_ngram_size=3,
        )
    return fix_tokenizer.decode(out[0], skip_special_tokens=True)
# Fallback (severity, score) for a CWE missing from SEVERITY_MAP. Every
# non-"safe" label currently has an entry. A single shared default keeps the
# overall risk score and the per-vulnerability entries consistent.
_DEFAULT_SEVERITY = ("Medium", 50)

# Attack-chain stages in report order: (trigger CWEs, phase, description).
# A stage is included when any detected CWE is among its triggers.
_ATTACK_CHAIN_STAGES = (
    (frozenset({"CWE-20", "CWE-89", "CWE-79", "CWE-78", "CWE-94"}),
     "Initial Access", "Exploit input validation weakness"),
    (frozenset({"CWE-264", "CWE-269", "CWE-284", "CWE-287"}),
     "Privilege Escalation", "Bypass access controls"),
    (frozenset({"CWE-200", "CWE-22", "CWE-125"}),
     "Data Exfiltration", "Read sensitive files or memory"),
    (frozenset({"CWE-119", "CWE-416", "CWE-787", "CWE-502"}),
     "Code Execution", "Exploit memory corruption"),
)


def _risk_level(score):
    """Map a 0-100 risk score to "Critical" / "High" / "Medium" / "Low"."""
    if score >= 80:
        return "Critical"
    if score >= 60:
        return "High"
    if score >= 40:
        return "Medium"
    return "Low"


def _overall_risk(detected, safe_prob):
    """Return ``(overall_risk_score, risk_level)`` for the detections."""
    if not detected:
        # Nothing crossed a threshold: the score reflects 1 - P(safe), but the
        # level is always reported as "Low".
        return max(0, int(100 - 100 * safe_prob)), "Low"
    max_sev = max(SEVERITY_MAP.get(cwe, _DEFAULT_SEVERITY)[1] for cwe, _ in detected)
    avg_conf = sum(p for _, p in detected) / len(detected)
    score = min(100, int(max_sev * avg_conf * 1.2))
    return score, _risk_level(score)


def _vulnerability_entries(detected):
    """Build one JSON-serializable dict per detected ``(cwe, confidence)`` pair."""
    entries = []
    for cwe, conf in detected:
        sev, score = SEVERITY_MAP.get(cwe, _DEFAULT_SEVERITY)
        entries.append({
            "cwe_id": cwe,
            "name": CWE_NAMES.get(cwe, cwe),
            "owasp_category": CWE_TO_OWASP.get(cwe, "N/A"),
            "severity": sev,
            "severity_score": score,
            "detection_confidence": round(conf, 4),
            "threshold_used": round(THRESHOLDS.get(cwe, 0.3), 3),
            "exploit_likelihood": min(100, int(conf * score)),
            "explanation": EXPLANATIONS.get(cwe, "Security risk detected.").replace("**", ""),
        })
    return entries


def _attack_chain(detected):
    """Return numbered attack-chain steps.

    Returns None for fewer than two detections or when no stage matches.
    """
    if len(detected) <= 1:
        return None
    found = {cwe for cwe, _ in detected}
    phases = [
        (phase, description)
        for triggers, phase, description in _ATTACK_CHAIN_STAGES
        if found & triggers
    ]
    steps = [
        {"step": number, "phase": phase, "description": description}
        for number, (phase, description) in enumerate(phases, 1)
    ]
    return steps or None


def _suggest_fix(code, language, detected):
    """Best-effort fix for the top-ranked CWE (generic prompt when none).

    Returns None when generation raises or yields only whitespace. Failures are
    logged rather than silently swallowed, and a failed fix never aborts the
    report.
    """
    top_cwe = detected[0][0] if detected else None
    try:
        fix = generate_fix(code, language, top_cwe)
    except Exception as e:  # best-effort: the report is still useful without a fix
        print(f"Fix generation failed: {e}")
        return None
    return fix if fix and fix.strip() else None


def build_json_report(code):
    """Analyze *code* and return the full report as a JSON-serializable dict.

    Keys:
    - language
    - model_status
    - overall_risk_score, risk_level
    - safe_probability, num_vulnerabilities, vulnerabilities
    - attack_chain: None unless at least two CWEs are detected and one of them
      matches a chain stage
    - suggested_fix: None when generation fails or is blank
    - all_class_probabilities
    - timestamp: UTC, ISO-8601
    """
    language = detect_language(code)
    detected, safe_prob, all_probs = classify_code(code)
    overall_risk, risk_level = _overall_risk(detected, safe_prob)
    vulns = _vulnerability_entries(detected)
    chain = _attack_chain(detected)
    fix = _suggest_fix(code, language, detected)
    return {
        "language": language,
        "model_status": {
            "classifier": "trained_v2" if CLASSIFIER_LOADED else "base_model",
            "fix_generator": "trained_v2" if FIXER_LOADED else "base_model",
            "calibration": f"T={TEMPERATURE:.4f}" if TEMPERATURE != 1.0 else "none",
            "thresholds": "per_class_optimized" if any(v != 0.3 for v in THRESHOLDS.values()) else "global_0.3",
        },
        "overall_risk_score": overall_risk,
        "risk_level": risk_level,
        "safe_probability": round(safe_prob, 4),
        "num_vulnerabilities": len(vulns),
        "vulnerabilities": vulns,
        "attack_chain": chain,
        "suggested_fix": fix,
        "all_class_probabilities": all_probs,
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
# Status emoji per severity / risk level in the Markdown report.
LEVEL_EMOJI = {
    "Critical": "\U0001F534",  # red circle
    "High": "\U0001F7E0",  # orange circle
    "Medium": "\U0001F7E1",  # yellow circle
    "Low": "\U0001F7E2",  # green circle
}
UNKNOWN_LEVEL_EMOJI = "\u26AA"  # white circle


def analyze_code(code):
    """Render the security report for *code* as Markdown for the Gradio UI.

    Returns a short prompt for blank input. The "(calibrated)" and
    "(per-class optimized)" annotations are shown only when the report's
    ``model_status`` says that calibration / per-class thresholds are in use.
    """
    if not code or not code.strip():
        return "Please paste some code to analyze."
    data = build_json_report(code)
    status = data["model_status"]
    calibrated = status["calibration"] != "none"
    per_class = status["thresholds"] == "per_class_optimized"

    r = ["# Code Security Analysis Report\n"]
    r.append(f"**Language:** {data['language']}")
    cls_status = "Trained v2 (GraphCodeBERT + ASL)" if status["classifier"] == "trained_v2" else "Base Model"
    fix_status = "Trained v2 (CodeT5+ CWE-aware)" if status["fix_generator"] == "trained_v2" else "Base Model"
    r.append(f"**Classifier:** {cls_status}")
    r.append(f"**Fix Generator:** {fix_status}")
    if calibrated or per_class:
        r.append(f"**Calibration:** {status['calibration']} | **Thresholds:** {status['thresholds']}")
    r.append("")

    if data["num_vulnerabilities"] == 0:
        r.append("## No Vulnerabilities Detected")
        r.append(f"**Risk Score:** {data['overall_risk_score']}/100 | **Safe Confidence:** {data['safe_probability']:.1%}\n")
        r.append("Code appears safe. Always supplement with manual review and SAST tools.")
        return "\n".join(r)

    emoji = LEVEL_EMOJI.get(data["risk_level"], UNKNOWN_LEVEL_EMOJI)
    r.append(f"## {emoji} {data['num_vulnerabilities']} Vulnerability(ies) Detected\n")
    r.append(f"**Risk Score:** {data['overall_risk_score']}/100 ({data['risk_level']}) | **Safe Probability:** {data['safe_probability']:.1%}\n---\n")
    # Describe the probabilities/thresholds honestly for the loaded config.
    confidence_note = " (calibrated)" if calibrated else ""
    threshold_note = " (per-class optimized)" if per_class else " (global default)"
    for i, v in enumerate(data["vulnerabilities"], 1):
        se = LEVEL_EMOJI.get(v["severity"], UNKNOWN_LEVEL_EMOJI)
        r.append(f"### {i}. {se} {v['name']}")
        r.append("| Property | Value |\n|----------|-------|")
        r.append(f"| **CWE ID** | {v['cwe_id']} |")
        r.append(f"| **OWASP** | {v['owasp_category']} |")
        r.append(f"| **Severity** | {v['severity']} ({v['severity_score']}/100) |")
        r.append(f"| **Confidence** | {v['detection_confidence']:.1%}{confidence_note} |")
        r.append(f"| **Threshold** | {v['threshold_used']:.3f}{threshold_note} |")
        r.append(f"| **Exploit Likelihood** | {v['exploit_likelihood']}% |")
        r.append(f"\n**Why Dangerous:** {v['explanation']}\n")

    if data["attack_chain"]:
        r.append("---\n## Attack Chain\n")
        for s in data["attack_chain"]:
            r.append(f"{s['step']}. **{s['phase']}** — {s['description']}")

    r.append("\n---\n## Suggested Fix\n")
    if data["suggested_fix"]:
        r.append(f"```{data['language'].lower()}\n{data['suggested_fix']}\n```")
    else:
        r.append("*Fix generation unavailable. Please review manually.*")
    r.append("\n---\n*AI-generated report (v2: calibrated probabilities + per-class thresholds). Verify with manual review and SAST tools.*")
    return "\n".join(r)
def get_json_report(code):
    """Return the structured report for *code*.

    For empty, None, or whitespace-only input, return ``{"error": ...}``
    instead of running the models.
    """
    has_content = bool(code) and bool(code.strip())
    if has_content:
        return build_json_report(code)
    return {"error": "No code provided"}
# Sample inputs for gr.Examples (one single-input row per snippet). The last
# snippet is the parameterized-query counterpart of the first. Its lookup
# reads the row from the cursor returned by execute(), since
# sqlite3.Connection has no fetchone().
EXAMPLES = [
    ["""import sqlite3\n\ndef get_user(username):\n conn = sqlite3.connect('users.db')\n query = f"SELECT * FROM users WHERE username = '{username}'"\n return conn.execute(query).fetchone()\n"""],
    ["""#include <stdio.h>\n#include <string.h>\n\nvoid process_input(char *user_input) {\n char buffer[64];\n strcpy(buffer, user_input);\n printf("Processed: %s\\n", buffer);\n}\n"""],
    ["""const express = require('express');\nconst app = express();\n\napp.get('/search', (req, res) => {\n const query = req.query.q;\n res.send(`<h1>Results for: ${query}</h1>`);\n});\n"""],
    ["""import requests, hashlib\n\nAPI_KEY = "sk-proj-abc123def456"\nDB_PASSWORD = "admin123"\n\ndef hash_password(password):\n return hashlib.md5(password.encode()).hexdigest()\n"""],
    ["""import sqlite3\nfrom hashlib import sha256\nimport hmac, secrets\n\ndef get_user(username):\n conn = sqlite3.connect('users.db')\n return conn.execute("SELECT * FROM users WHERE username = ?", (username,)).fetchone()\n"""],
]
# Gradio UI. Two named API endpoints are exposed: "analyze" (Markdown report)
# and "get_json_report" (structured report).
with gr.Blocks(
    title="Code Security Risk Analyzer v2",
    theme=gr.themes.Soft(),
    css=".gradio-container { max-width: 1200px; margin: auto; }",
) as demo:
    gr.Markdown("""
# 🔒 AI-Powered Code Security Risk Analyzer v2
### Detect OWASP Top 10 & CWE vulnerabilities with calibrated confidence + per-class thresholds
Paste code in Python, JavaScript, Java, C, C++, PHP, or Go.
**Models:** [GraphCodeBERT](https://huggingface.co/ayshajavd/graphcodebert-vuln-classifier) (detection, Macro F1=0.476) + [CodeT5+](https://huggingface.co/ayshajavd/codet5p-vuln-fixer) (fixes, BLEU=81.0) | **Dataset:** [175K samples](https://huggingface.co/datasets/ayshajavd/code-security-vulnerability-dataset)
**v2 Improvements:** Per-class threshold optimization | Temperature-calibrated probabilities | Asymmetric Loss training | GraphCodeBERT-base (125M params) | CodeT5+ 220M CWE-aware fixer
""")
    with gr.Row():
        with gr.Column(scale=1):
            code_input = gr.Code(label="Paste Your Code Here", language="python", lines=20)
            with gr.Row():
                analyze_btn = gr.Button("🔍 Analyze Security", variant="primary", size="lg")
                json_btn = gr.Button("📋 JSON Report", variant="secondary", size="lg")
        with gr.Column(scale=1):
            report_output = gr.Markdown(label="Security Report")
            json_output = gr.JSON(label="JSON Report", visible=False)
    gr.Examples(examples=EXAMPLES, inputs=[code_input], label="Example Code Snippets")

    def show_json(code):
        """Reveal the initially hidden JSON panel, filled with the report for *code*."""
        return gr.update(visible=True, value=get_json_report(code))

    analyze_btn.click(fn=analyze_code, inputs=[code_input], outputs=[report_output], api_name="analyze")
    json_btn.click(fn=show_json, inputs=[code_input], outputs=[json_output])
    # Invisible button whose only purpose is to register get_json_report under
    # a named API endpoint.
    with gr.Row(visible=False):
        api_json_btn = gr.Button("get_json", visible=False)
    api_json_btn.click(fn=get_json_report, inputs=[code_input], outputs=[json_output], api_name="get_json_report")
    with gr.Accordion("🌐 REST API Documentation", open=False):
        gr.Markdown("""
### Python Client
```python
from gradio_client import Client
client = Client("ayshajavd/code-security-analyzer")
report = client.predict(code="your code here", api_name="/analyze")
json_report = client.predict(code="your code here", api_name="/get_json_report")
```
### cURL
```bash
curl -X POST https://ayshajavd-code-security-analyzer.hf.space/call/analyze \\
-H "Content-Type: application/json" -d '{"data": ["your code here"]}'
```
""")
    gr.Markdown("""
---
### 30 CWE Vulnerability Classes → OWASP Top 10
| OWASP Category | CWEs |
|---|---|
| **A01: Broken Access Control** | CWE-22, 200, 264, 269, 276, 284, 352, 601 |
| **A02: Cryptographic Failures** | CWE-310, 327, 330 |
| **A03: Injection** | CWE-20, 78, 79, 89, 94, 119, 125, 190, 401, 416, 476, 787 |
| **A04: Insecure Design** | CWE-362, 399, 434 |
| **A07: Auth Failures** | CWE-287, 798 |
| **A08: Integrity Failures** | CWE-502 |
| **A10: SSRF** | CWE-918 |
""")
# Start the Gradio server when this file is executed directly.
if __name__ == "__main__":
    demo.launch()