# Commit 43efb12: Fix HF Spaces streaming timeouts and error handling
"""
Fix Agent β€” generates unified diffs, security report, and PR description
from Security + Performance findings.
"""
from __future__ import annotations

import json
import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple

from openai import AsyncOpenAI

from api.models import (
    FileFix,
    FindingFix,
    FixResult,
    PerformanceFinding,
    SecurityFinding,
)
from tools.code_parser import FileEntry
from tools.diff_generator import (
    format_pr_diff_block,
    generate_unified_diff,
)
logger = logging.getLogger(__name__)
# System prompt for the fix-generation LLM call. It demands minimal,
# commented changes and a STRICT JSON reply whose shape is consumed by
# FixAgent._parse_fix_response ("finding_fixes" + "files" keys).
FIX_SYSTEM_PROMPT = """You are CodeSentry Fix Agent β€” a senior security engineer generating precise, minimal code fixes.
Given a list of security and performance findings, produce a corrected version of each affected file.
## Rules:
1. Make the MINIMAL change required to fix each issue β€” don't refactor unrelated code.
2. Add a comment on each changed line explaining WHY the fix was applied.
3. For hardcoded secrets: replace with os.getenv("VAR_NAME") and add to .env.example.
4. For pickle.load: replace with torch.load(..., weights_only=True) or use safetensors.
5. For prompt injection: add input sanitisation or use structured prompts with variables.
6. For missing @torch.no_grad: add the decorator.
7. For N+1 embeddings: restructure to batch call.
8. For eval(llm_output): raise an error and use structured JSON parsing instead.
## Output Format (STRICT JSON):
{
"finding_fixes": [
{
"findingId": "<matching finding ID>",
"before": "<vulnerable code snippet>",
"after": "<fixed code snippet>",
"explanation": "Brief technical explanation"
}
],
"files": [
{
"file_path": "<original filename>",
"fixed_code": "<complete fixed file content>",
"explanation": "What was changed and why",
"fixes_applied": ["Fix 1 description", "Fix 2 description"]
}
],
"security_report_md": "<full markdown security report>",
"pr_description": "<GitHub PR description markdown>"
}
"""
# Markdown skeleton for the final security report; the {placeholders} are
# filled by FixAgent._build_security_report via str.format.
SECURITY_REPORT_TEMPLATE = """# πŸ›‘οΈ CodeSentry Security Report
**Generated:** {timestamp}
**Session ID:** {session_id}
**Model:** Qwen/Qwen2.5-Coder-32B-Instruct (AMD MI300X)
**Zero Data Retention:** βœ… All inference ran locally
---
## Executive Summary
| Severity | Count |
|----------|-------|
| πŸ”΄ Critical | {critical} |
| 🟠 High | {high} |
| 🟑 Medium | {medium} |
| 🟒 Low | {low} |
| ⚑ Performance | {perf} |
**Files Analysed:** {files_count}
**Estimated Memory Savings:** {memory_savings} MB
---
## Security Findings
{security_findings_md}
---
## Performance Optimisations
{performance_findings_md}
---
## Remediation Diffs
{diffs_md}
---
*Report generated by CodeSentry β€” AMD MI300X powered, Zero Data Retention*
"""
class FixAgent:
    """Fix Agent β€” turns security/performance findings into unified diffs,
    a markdown security report, and a GitHub PR description.

    LLM calls go through an OpenAI-compatible endpoint (vLLM by default).
    Every entry point degrades gracefully: if the LLM is unreachable or
    returns malformed output, the agent still produces the report and the
    PR description, just without automated diffs.
    """

    def __init__(
        self,
        vllm_base_url: str = "http://localhost:8080/v1",
        model: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
        api_key: str = "not-needed-local",
        max_tokens: int = 8192,
        temperature: float = 0.05,
    ) -> None:
        """Configure the async OpenAI-compatible client.

        Args:
            vllm_base_url: Base URL of the inference server.
            model: Model name passed on every completion request.
            api_key: Placeholder key for local servers that skip auth.
            max_tokens: Completion token cap per request.
            temperature: Near-zero for (mostly) deterministic code fixes.
        """
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        # 60s timeout + a single retry: fail fast instead of hanging a
        # streaming request behind a dead backend.
        self.client = AsyncOpenAI(
            base_url=vllm_base_url,
            api_key=api_key,
            timeout=60.0,
            max_retries=1,
        )

    # ─────────────────────────────────────────
    # Main entry point
    # ─────────────────────────────────────────
    async def generate_fixes(
        self,
        files: List[FileEntry],
        security_findings: List[SecurityFinding],
        performance_findings: List[PerformanceFinding],
        session_id: str = "",
        use_llm: bool = True,
    ) -> FixResult:
        """
        Generate diffs, security report, and PR description.
        Falls back to report-only mode if LLM is unavailable.
        """
        # Build the report up-front so a failed LLM call still yields output.
        report_md = self._build_security_report(
            session_id=session_id,
            security_findings=security_findings,
            performance_findings=performance_findings,
            files=files,
            diffs_md="",  # filled in after diff generation
        )
        pr_desc = self._build_pr_description(security_findings, performance_findings)
        file_fixes: List[FileFix] = []
        finding_fixes: List[FindingFix] = []
        # Only spend an LLM call when there is something to fix.
        if use_llm and files and (security_findings or performance_findings):
            file_fixes, finding_fixes = await self._llm_generate_fixes(
                files, security_findings, performance_findings
            )
        # Re-render the report once actual diffs exist.
        if file_fixes:
            all_diffs = [(fix.file_path, fix.diff) for fix in file_fixes]
            diffs_md = format_pr_diff_block(all_diffs)
            report_md = self._build_security_report(
                session_id=session_id,
                security_findings=security_findings,
                performance_findings=performance_findings,
                files=files,
                diffs_md=diffs_md,
            )
        return FixResult(
            finding_fixes=finding_fixes,
            diffs=file_fixes,
            files_changed=len(file_fixes),
            security_report_md=report_md,
            pr_description=pr_desc,
        )

    # ─────────────────────────────────────────
    # LLM fix generation
    # ─────────────────────────────────────────
    async def _llm_generate_fixes(
        self,
        files: List[FileEntry],
        security_findings: List[SecurityFinding],
        performance_findings: List[PerformanceFinding],
    ) -> Tuple[List[FileFix], List[FindingFix]]:
        """Ask the LLM to produce fixed versions of affected files.

        Returns ([], []) on any transport or parse failure so the caller
        can degrade to report-only output.
        """
        # Only send files that at least one finding points at; if none
        # match, fall back to the first two files so the prompt isn't empty.
        affected_paths = set()
        for f in security_findings:
            if f.file:
                affected_paths.add(f.file)
        for f in performance_findings:
            if f.file:
                affected_paths.add(f.file)
        # NOTE(review): assumes FileEntry unpacks as (path, content) β€” this
        # matches the (p, c) usage throughout; confirm in tools.code_parser.
        affected_files = [(p, c) for p, c in files if p in affected_paths] or files[:2]
        findings_summary = self._findings_to_text(security_findings, performance_findings)
        # Truncate each file to stay within Groq's TPM limits
        MAX_CHARS_PER_FILE = 1200
        MAX_TOTAL_CHARS = 3000
        total_chars = 0
        file_blocks = []
        for p, c in affected_files:
            truncated = c[:MAX_CHARS_PER_FILE]
            if len(c) > MAX_CHARS_PER_FILE:
                truncated += "\n# ... (truncated for brevity)"
            block = f"# FILE: {p}\n```python\n{truncated}\n```"
            if total_chars + len(block) > MAX_TOTAL_CHARS * 4:  # rough char budget
                break
            file_blocks.append(block)
            total_chars += len(block)
        files_content = "\n\n".join(file_blocks)
        user_message = (
            f"Findings to fix:\n{findings_summary}\n\n"
            f"Files:\n{files_content}\n\n"
            "Return ONLY the JSON response as specified."
        )
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": FIX_SYSTEM_PROMPT},
                    {"role": "user", "content": user_message},
                ],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
            raw = response.choices[0].message.content or "{}"
            return self._parse_fix_response(raw, dict(affected_files))
        except Exception as exc:
            # Broad catch is deliberate: any client failure means "no fixes".
            logger.error("[FixAgent] LLM call failed: %s", exc)
            return [], []

    def _parse_fix_response(
        self, raw: str, original_files: Dict[str, str]
    ) -> Tuple[List[FileFix], List[FindingFix]]:
        """Parse the LLM's JSON reply into (file fixes, per-finding fixes).

        Tolerant of markdown fences and surrounding chatter; returns empty
        lists rather than raising on malformed output.
        """
        # Strip code fences. NOTE(review): this removes ``` anywhere in the
        # text, including inside JSON string values β€” confirm acceptable.
        raw = re.sub(r"```(?:json)?\s*", "", raw).strip().rstrip("`").strip()
        # Find outermost JSON object
        start = raw.find("{")
        end = raw.rfind("}") + 1
        if start == -1 or end == 0:
            logger.warning("[FixAgent] No JSON object in LLM response")
            return [], []
        try:
            data = json.loads(raw[start:end])
        except json.JSONDecodeError as exc:
            logger.warning("[FixAgent] JSON parse error: %s", exc)
            return [], []
        fixes: List[FileFix] = []
        for file_info in data.get("files", []):
            path = file_info.get("file_path", "unknown")
            fixed_code = file_info.get("fixed_code", "")
            explanation = file_info.get("explanation", "")
            original = original_files.get(path, "")
            diff = generate_unified_diff(original, fixed_code, filename=path)
            if diff:  # skip files where the LLM returned unchanged code
                fixes.append(FileFix(file_path=path, diff=diff, explanation=explanation))
        finding_fixes: List[FindingFix] = []
        for f in data.get("finding_fixes", []):
            try:
                finding_fixes.append(FindingFix(**f))
            except Exception as e:
                logger.debug("[FixAgent] Skipping malformed finding fix: %s", e)
        # Lazy %-args instead of an f-string (logging best practice).
        logger.info(
            "[FixAgent] Parsed %d finding_fixes and %d file fixes.",
            len(finding_fixes),
            len(fixes),
        )
        return fixes, finding_fixes

    # ─────────────────────────────────────────
    # Report builders
    # ─────────────────────────────────────────
    def _build_security_report(
        self,
        session_id: str,
        security_findings: List[SecurityFinding],
        performance_findings: List[PerformanceFinding],
        files: List[FileEntry],
        diffs_md: str,
    ) -> str:
        """Render SECURITY_REPORT_TEMPLATE with counts, findings and diffs."""
        # BUGFIX: count by the severity's *string* value. The previous code
        # keyed the dict with Severity enum members but looked entries up
        # with plain strings ("critical", ...), which reads 0 for every
        # bucket unless Severity subclasses str.
        sev_counts: Dict[str, int] = {}
        for f in security_findings:
            key = f.severity.value
            sev_counts[key] = sev_counts.get(key, 0) + 1
        total_mem = sum(
            (pf.saving_mb or 0.0) for pf in performance_findings
        )
        # Security findings section
        sec_md_lines: List[str] = []
        for i, finding in enumerate(security_findings, 1):
            sev_icon = {"critical": "πŸ”΄", "high": "🟠", "medium": "🟑", "low": "🟒"}.get(
                finding.severity.value, "βšͺ"
            )
            sec_md_lines.append(
                f"### {i}. {sev_icon} [{finding.severity.value.upper()}] {finding.title}\n"
                f"- **CWE:** {finding.cwe or 'N/A'} \n"
                f"- **OWASP:** {finding.owasp_category or 'N/A'} \n"
                f"- **File:** `{finding.file or 'N/A'}` line {finding.line or 'N/A'} \n"
                f"- **Description:** {finding.description} \n"
                + (f"- **Fix:** `{finding.suggestion}`\n" if finding.suggestion else "")
                + (f"\n```\n{finding.code}\n```\n" if finding.code else "")
            )
        # Performance findings section
        perf_md_lines: List[str] = []
        for i, pf in enumerate(performance_findings, 1):
            perf_md_lines.append(
                f"### {i}. ⚑ {pf.title}\n"
                f"- **Type:** {pf.type.value} \n"
                f"- **Current:** {pf.current_estimate or 'N/A'} \n"
                f"- **Optimised:** {pf.optimized_estimate or 'N/A'} \n"
                f"- **Saving:** {pf.saving or f'{pf.saving_mb or 0:.0f} MB'} \n"
                f"- **Fix:** `{pf.suggestion}`\n"
            )
        return SECURITY_REPORT_TEMPLATE.format(
            timestamp=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC"),
            session_id=session_id,
            critical=sev_counts.get("critical", 0),
            high=sev_counts.get("high", 0),
            medium=sev_counts.get("medium", 0),
            low=sev_counts.get("low", 0),
            perf=len(performance_findings),
            files_count=len(files),
            memory_savings=f"{total_mem:.0f}",
            security_findings_md="\n".join(sec_md_lines) or "_No security findings._",
            performance_findings_md="\n".join(perf_md_lines) or "_No performance findings._",
            diffs_md=diffs_md or "_No automated fixes generated._",
        )

    def _build_pr_description(
        self,
        security_findings: List[SecurityFinding],
        performance_findings: List[PerformanceFinding],
    ) -> str:
        """Build the GitHub PR description (markdown) from the findings."""
        critical = [f for f in security_findings if f.severity.value == "critical"]
        high = [f for f in security_findings if f.severity.value == "high"]
        lines = [
            "## πŸ›‘οΈ CodeSentry Security & Performance Fix",
            "",
            "### What this PR fixes:",
            "",
        ]
        if critical:
            lines.append("#### πŸ”΄ Critical Security Issues:")
            for f in critical:
                lines.append(f"- **{f.title}** ({f.cwe or f.owasp_category}) β€” {f.description[:120]}...")
            lines.append("")
        if high:
            lines.append("#### 🟠 High Severity Issues:")
            for f in high:
                lines.append(f"- **{f.title}** β€” {f.description[:120]}...")
            lines.append("")
        if performance_findings:
            total_mb = sum((pf.saving_mb or 0.0) for pf in performance_findings)
            lines.append(f"#### ⚑ Performance Optimisations ({len(performance_findings)} fixes, ~{total_mb:.0f} MB VRAM saved):")
            # Cap the bullet list at five entries to keep the PR body short.
            for pf in performance_findings[:5]:
                lines.append(f"- {pf.title}: {pf.saving or 'improvement'}")
            lines.append("")
        lines += [
            "### How to review:",
            "1. Check diffs for each file β€” all changes are minimal and targeted",
            "2. Verify `.env.example` for any new environment variables",
            "3. Run `pytest tests/ -v` to confirm all tests pass",
            "",
            "---",
            "_Generated by CodeSentry on AMD MI300X β€” Zero Data Retention βœ…_",
        ]
        return "\n".join(lines)

    @staticmethod
    def _findings_to_text(
        security_findings: List[SecurityFinding],
        performance_findings: List[PerformanceFinding],
    ) -> str:
        """Serialise findings into the plain-text summary sent to the LLM."""
        lines = ["## Security Findings:"]
        for f in security_findings:
            lines.append(
                f"- ID: {f.id} [{f.severity.value.upper()}] {f.title} "
                f"(file={f.file}, line={f.line}, cwe={f.cwe}): {f.description}"
            )
        lines.append("\n## Performance Findings:")
        for f in performance_findings:
            lines.append(f"- ID: {f.id} [{f.type.value.upper()}] {f.title}: {f.suggestion}")
        return "\n".join(lines)