Samarthrr committed on
Commit 76e3fab · verified · 1 Parent(s): 879b56d

Update app.py

Files changed (1)
  1. app.py +143 -111
app.py CHANGED
@@ -1,8 +1,9 @@
 import ast
 import torch
 import torch.nn as nn
-from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from transformers import (
     T5ForConditionalGeneration,
     RobertaTokenizer,
@@ -11,176 +12,207 @@ from transformers import (
 )
 import pandas as pd
 import os

-app = FastAPI(title="Revcode AI Strong Orchestrator")

 # ---------------------------------------------------------
-# 1. ADVANCED SECURITY SCANNER (The "Brain")
 # ---------------------------------------------------------
 class DeepVulnerabilityScanner:
     def __init__(self):
-        print("Loading Deep Security Scanner (DistilRoBERTa)...")
-        self.model_name = "distilroberta-base"
-        try:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=2)
-            self.model.eval()
-        except Exception as e:
-            print(f"Failed to load Deep Scanner: {e}")
-            self.model = None
-
-    def scan(self, code: str) -> dict:
-        if not self.model:
-            return {"is_vulnerable": False, "risk_score": 0.0, "details": "Scanner unavailable"}

-        inputs = self.tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
         with torch.no_grad():
             logits = self.model(**inputs).logits

         probs = torch.softmax(logits, dim=1)
         vuln_prob = probs[0][1].item()

         return {
-            "is_vulnerable": vuln_prob > 0.5,
-            "risk_score": round(vuln_prob * 100, 2),
-            "details": "Potential vulnerability detected in code logic." if vuln_prob > 0.5 else "Clean code."
         }

 # ---------------------------------------------------------
-# 2. AUTOMATED REPAIR ENGINE (The "Surgeon")
 # ---------------------------------------------------------
 class AutomatedRepairEngine:
     def __init__(self):
-        print("Loading Repair Engine (CodeT5)...")
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model_name = "Salesforce/codet5p-220m"
-        try:
-            self.tokenizer = RobertaTokenizer.from_pretrained(self.model_name)
-            self.model = T5ForConditionalGeneration.from_pretrained(self.model_name).to(self.device)
-            self.model.eval()
-        except Exception as e:
-            print(f"Failed to load Repair Engine: {e}")
-            self.model = None
-
-    def repair(self, buggy_code: str) -> str:
-        if not self.model:
-            return buggy_code
-
-        prompt = f"Fix the security vulnerability: {buggy_code}"
-        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(self.device)

         with torch.no_grad():
             outputs = self.model.generate(
                 **inputs,
-                max_length=256,
                 num_beams=5,
-                temperature=0.7,
                 early_stopping=True
             )

         return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

 # ---------------------------------------------------------
-# 3. SYNTAX & LOGIC VALIDATOR (The "Quality Control")
-# ---------------------------------------------------------
-class CodeValidator:
-    @staticmethod
-    def is_syntax_valid(code: str) -> bool:
-        try:
-            ast.parse(code)
-            return True
-        except Exception:
-            return False
-
-    @staticmethod
-    def check_security_patterns(code: str) -> list:
-        issues = []
-        dangerous_calls = ["eval(", "exec(", "os.system(", "subprocess.call("]
-        for call in dangerous_calls:
-            if call in code:
-                issues.append(f"Dangerous call found: {call}")
-        return issues
-
-# ---------------------------------------------------------
-# 4. GLOBAL HANDLERS (Lazy Loading)
 # ---------------------------------------------------------
 _scanner = None
 _repairer = None

-def get_scanner():
     global _scanner
-    if _scanner is None:
-        _scanner = DeepVulnerabilityScanner()
     return _scanner

 def get_repairer():
     global _repairer
-    if _repairer is None:
-        _repairer = AutomatedRepairEngine()
     return _repairer

-# ---------------------------------------------------------
-# 5. API ENDPOINTS
-# ---------------------------------------------------------
-class CodeInput(BaseModel):
-    code: str

 @app.post("/analyze")
 async def analyze_security(data: CodeInput):
     scanner = get_scanner()
-    result = scanner.scan(data.code)

     return {
-        "is_vulnerable": result["is_vulnerable"],
-        "confidence": result["risk_score"],
-        "verdict": "VULNERABLE" if result["is_vulnerable"] else "SECURE",
-        "details": result["details"],
-        "provider": "DistilRoBERTa-Strong"
     }

 @app.post("/fix")
 async def fix_code(data: CodeInput):
     repairer = get_repairer()
-    validator = CodeValidator()

-    # ML Repair
-    suggestion = repairer.repair(data.code)

-    # Validation Loop
-    status = "PASSED"
-    msg = "Valid syntax"
-
-    if not validator.is_syntax_valid(suggestion):
-        status = "FAILED"
-        msg = "Repair generated invalid syntax"
-        # Heuristic fallback (from user's logic)
-        suggestion = data.code.replace("eval(", "safe_eval(")
-
-    # Final Security Pattern Check
-    final_issues = validator.check_security_patterns(suggestion)
-    if final_issues:
-        for issue in final_issues:
-            call_name = issue.split(": ")[1]
-            suggestion = suggestion.replace(call_name, f"# BLOCKED_{call_name.replace('(', '')}")
-        msg += f" | Blocked {len(final_issues)} dangerous calls"
-
     return {
         "suggestion": suggestion,
-        "guardrail_status": status,
-        "guardrail_msg": msg
     }

 @app.post("/feedback")
 async def store_feedback(data: dict):
     feedback_file = "feedback_dataset.csv"
-    df = pd.DataFrame([data])
-    df.to_csv(feedback_file, mode='a', header=not os.path.exists(feedback_file), index=False)
-    return {"status": "Feedback stored for retraining"}
-
-@app.get("/")
-async def health():
-    return {"status": "Revcode AI Strong Engine is alive"}
-
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)

 import ast
 import torch
 import torch.nn as nn
+from fastapi import FastAPI, HTTPException, BackgroundTasks
 from pydantic import BaseModel
+from typing import Optional, List
 from transformers import (
     T5ForConditionalGeneration,
     RobertaTokenizer,

 )
 import pandas as pd
 import os
+import threading
+import re

+# Import the training function
+from train_engine import train_on_devign
+
+app = FastAPI(title="Revcode AI Precision Engine")
+
+# Global State
+training_lock = threading.Lock()
+is_training = False
+
+class CodeInput(BaseModel):
+    code: str
+    filename: Optional[str] = "snippet.js"

 # ---------------------------------------------------------
+# 1. PRECISION SCANNER (CodeBERT-Devign)
 # ---------------------------------------------------------
 class DeepVulnerabilityScanner:
     def __init__(self):
+        # Prefer locally trained model if it exists
+        local_model = "./trained_model"
+        if os.path.exists(local_model):
+            self.model_name = local_model
+            self.tokenizer_name = local_model
+        else:
+            self.model_name = "mahdin70/codebert-devign-code-vulnerability-detector"
+            self.tokenizer_name = "microsoft/codebert-base"

+        print(f"Loading Precision Scanner ({self.model_name})...")
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
+        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
+        self.model.eval()
+
+    def scan(self, code: str) -> dict:
+        inputs = self.tokenizer(code, return_tensors="pt", truncation=True, padding=True, max_length=512)
         with torch.no_grad():
             logits = self.model(**inputs).logits

         probs = torch.softmax(logits, dim=1)
         vuln_prob = probs[0][1].item()

+        # RAISED THRESHOLD: Only flag as 'is_vulnerable' if we are > 85% certain
+        is_vuln = vuln_prob > 0.85
+
+        verdict = "SECURE"
+        if vuln_prob > 0.9: verdict = "CRITICAL"
+        elif vuln_prob > 0.7: verdict = "WARNING"
+        elif vuln_prob > 0.4: verdict = "POTENTIAL"
+
         return {
+            "is_vulnerable": is_vuln,
+            "confidence": round(vuln_prob * 100, 2),
+            "threat_level": verdict,
+            "reasoning": self._generate_reasoning(vuln_prob, code)
         }

+    def _generate_reasoning(self, prob, code):
+        if prob > 0.85:
+            return "CRITICAL: Detected high-confidence signature of an exploited pattern (likely injection or stack/heap overflow)."
+        if prob > 0.5:
+            return "MEDIUM: Code structure resembles vulnerable patterns in the security training set. Recommended audit."
+        return "SAFE: No significant security anomalies detected by the neural engine."
+
+# ---------------------------------------------------------
+# 2. RULE-BASED PATTERN FILTER (Hardened)
+# ---------------------------------------------------------
+class StructuralScanner:
+    @staticmethod
+    def scan(code: str, filename: str) -> List[dict]:
+        findings = []
+
+        # Rule 1: Code Injection (Detecting RAW eval, excluding json/safe wraps)
+        if "eval(" in code:
+            if not any(x in code for x in ["JSON.parse(", "safe_eval", "ast.literal_eval"]):
+                findings.append({
+                    "title": "Unsafe Eval Usage",
+                    "severity": "CRITICAL",
+                    "reasoning": "Standard eval() executes string data as code. Use JSON.parse() or ast.literal_eval() for data."
+                })
+
+        # Rule 2: RAW Command Injection
+        if any(x in code for x in ["os.system(", "subprocess.Popen(..., shell=True)"]):
+            findings.append({
+                "title": "Direct Shell Execution",
+                "severity": "HIGH",
+                "reasoning": "Detected shell invocation with shell=True. This is highly susceptible to command injection."
+            })
+
+        return findings
+
 # ---------------------------------------------------------
+# 3. CONSERVATIVE REPAIR ENGINE (Minimal Changes)
 # ---------------------------------------------------------
 class AutomatedRepairEngine:
     def __init__(self):
+        print("Loading Conservative Repair Engine (CodeT5+)...")
         self.model_name = "Salesforce/codet5p-220m"
+        self.tokenizer = RobertaTokenizer.from_pretrained(self.model_name)
+        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
+        self.model.eval()
+
+    def repair(self, buggy_code: str, filename: str) -> str:
+        # CONSTRAINED PROMPT: Focus only on the security fix
+        prompt = f"Fix the security scan vulnerability in this {filename} file accurately and with minimal changes: {buggy_code}"
+        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

         with torch.no_grad():
             outputs = self.model.generate(
                 **inputs,
+                max_length=512,
                 num_beams=5,
+                temperature=0.2, # LOWER TEMPERATURE for less creativity/more precision
                 early_stopping=True
             )

         return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

 # ---------------------------------------------------------
+# 4. ORCHESTRATION & API
 # ---------------------------------------------------------
 _scanner = None
 _repairer = None
+_struct = StructuralScanner()

+def get_scanner(reload=False):
     global _scanner
+    if _scanner is None or reload: _scanner = DeepVulnerabilityScanner()
     return _scanner

 def get_repairer():
     global _repairer
+    if _repairer is None: _repairer = AutomatedRepairEngine()
     return _repairer

+@app.get("/")
+async def health():
+    return {"status": "Revcode Precision Engine Live", "is_training": is_training}

 @app.post("/analyze")
 async def analyze_security(data: CodeInput):
     scanner = get_scanner()

+    # 1. Neural Analysis
+    res = scanner.scan(data.code)
+
+    # 2. Structural Analysis
+    struct_findings = _struct.scan(data.code, data.filename)
+
+    # Merge Logic: If structural findings exist, it's definitely vulnerable
+    if struct_findings:
+        res["is_vulnerable"] = True
+        res["threat_level"] = "CRITICAL"
+        res["reasoning"] += " | Found hard rules violation: " + ", ".join([f['title'] for f in struct_findings])
+
     return {
+        "is_vulnerable": res["is_vulnerable"],
+        "confidence": res["confidence"],
+        "threat_level": res["threat_level"],
+        "reasoning": res["reasoning"],
+        "structural_findings": struct_findings,
+        "is_training": is_training
     }

 @app.post("/fix")
 async def fix_code(data: CodeInput):
     repairer = get_repairer()

+    # 1. Primary generative fix
+    suggestion = repairer.repair(data.code, data.filename)

+    # 2. Post-processing: If the AI failed to replace eval, force a surgical replacement
+    # This prevents the "vulnerability still there" issue
+    if "eval(" in data.code and "eval(" in suggestion:
+        suggestion = suggestion.replace("eval(", "JSON.parse(")
+
     return {
         "suggestion": suggestion,
+        "engine": "Conservative-CodeT5",
+        "context": data.filename
     }

+@app.post("/train")
+async def trigger_training(background_tasks: BackgroundTasks):
+    global is_training
+    if is_training: return {"status": "error", "message": "Training in progress"}
+
+    def run():
+        global is_training
+        is_training = True
+        try:
+            train_on_devign(output_dir="./trained_model")
+            get_scanner(reload=True)
+        finally: is_training = False
+
+    background_tasks.add_task(run)
+    return {"status": "success", "message": "Training started"}
+
 @app.post("/feedback")
 async def store_feedback(data: dict):
     feedback_file = "feedback_dataset.csv"
+    pd.DataFrame([data]).to_csv(feedback_file, mode='a', header=not os.path.exists(feedback_file), index=False)
+    return {"status": "stored"}