Samarthrr commited on
Commit
879b56d
·
verified ·
1 Parent(s): e740563

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -164
app.py CHANGED
@@ -1,9 +1,8 @@
1
  import ast
2
  import torch
3
  import torch.nn as nn
4
- from fastapi import FastAPI, HTTPException, BackgroundTasks
5
  from pydantic import BaseModel
6
- from typing import Optional
7
  from transformers import (
8
  T5ForConditionalGeneration,
9
  RobertaTokenizer,
@@ -12,225 +11,163 @@ from transformers import (
12
  )
13
  import pandas as pd
14
  import os
15
- import threading
16
 
17
- # Import the training function
18
- from train_engine import train_on_devign
19
-
20
- app = FastAPI(title="Revcode AI ULTRA Orchestrator")
21
-
22
- # Global training status
23
- training_lock = threading.Lock()
24
- is_training = False
25
-
26
- # ---------------------------------------------------------
27
- # 1. DATA MODELS
28
- # ---------------------------------------------------------
29
- class CodeInput(BaseModel):
30
- code: str
31
- filename: Optional[str] = "snippet.js"
32
 
33
  # ---------------------------------------------------------
34
- # 2. ADVANCED SECURITY SCANNER (CodeBERT-Devign + XAI)
35
  # ---------------------------------------------------------
36
  class DeepVulnerabilityScanner:
37
  def __init__(self):
38
- # We check if a locally trained model exists, otherwise use the base
39
- local_model = "./trained_model"
40
- if os.path.exists(local_model):
41
- self.model_name = local_model
42
- self.tokenizer_name = local_model
43
- print(f"Loading Locally Trained Security Scanner ({self.model_name})...")
44
- else:
45
- self.model_name = "mahdin70/codebert-devign-code-vulnerability-detector"
46
- self.tokenizer_name = "microsoft/codebert-base"
47
- print(f"Loading SOTA Security Scanner ({self.model_name})...")
48
-
49
- self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
50
- self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
51
- self.model.eval()
52
-
53
  def scan(self, code: str) -> dict:
54
- inputs = self.tokenizer(code, return_tensors="pt", truncation=True, padding=True, max_length=512)
 
 
 
55
  with torch.no_grad():
56
  logits = self.model(**inputs).logits
57
 
58
  probs = torch.softmax(logits, dim=1)
59
  vuln_prob = probs[0][1].item()
60
 
61
- reasoning = "Analyzing code logic for Devign-pattern vulnerabilities."
62
- if vuln_prob > 0.9:
63
- reasoning = "CRITICAL: High-confidence fingerprint of a known vulnerability pattern (e.g., Buffer Overflow, Improper Sanitization)."
64
- elif vuln_prob > 0.5:
65
- reasoning = "WARNING: Code semantics mirror dangerous patterns found in the Devign security dataset."
66
- elif vuln_prob < 0.1:
67
- reasoning = "SAFE: Code logic is clean of any recognized vulnerability fingerprints."
68
-
69
  return {
70
  "is_vulnerable": vuln_prob > 0.5,
71
  "risk_score": round(vuln_prob * 100, 2),
72
- "verdict": "VULNERABLE" if vuln_prob > 0.5 else "SECURE",
73
- "reasoning": reasoning
74
  }
75
 
76
  # ---------------------------------------------------------
77
- # 3. STRUCTURAL SCANNER (Mini-Semgrep)
78
- # ---------------------------------------------------------
79
- class StructuralScanner:
80
- @staticmethod
81
- def scan_patterns(code: str, filename: str) -> list:
82
- findings = []
83
- if "os.system(" in code or "subprocess.Popen(..., shell=True)" in code:
84
- findings.append({
85
- "type": "Security",
86
- "title": "Command Injection Risk",
87
- "reasoning": "Detected use of shell=True or os.system which can lead to Remote Code Execution."
88
- })
89
- if "pickle.load" in code or "yaml.load(..., Loader=None)" in code:
90
- findings.append({
91
- "type": "Security",
92
- "title": "Insecure Deserialization",
93
- "reasoning": "Insecure loading of serialized data can lead to arbitrary code execution."
94
- })
95
- if "Password =" in code or "API_KEY =" in code:
96
- findings.append({
97
- "type": "Compliance",
98
- "title": "Hardcoded Secret",
99
- "reasoning": "Sensitive credentials found in source code. Use environment variables instead."
100
- })
101
- return findings
102
-
103
- # ---------------------------------------------------------
104
- # 4. AUTOMATED REPAIR ENGINE (The "Surgeon" + Context)
105
  # ---------------------------------------------------------
106
  class AutomatedRepairEngine:
107
  def __init__(self):
108
- print("Loading Repair Engine (CodeT5+)...")
 
109
  self.model_name = "Salesforce/codet5p-220m"
110
- self.tokenizer = RobertaTokenizer.from_pretrained(self.model_name)
111
- self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
112
- self.model.eval()
 
 
 
 
 
 
 
 
113
 
114
- def repair(self, buggy_code: str, filename: str) -> str:
115
- prompt = f"Fix the security vulnerability in this {filename} file: {buggy_code}"
116
- inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
117
  with torch.no_grad():
118
  outputs = self.model.generate(
119
  **inputs,
120
- max_length=512,
121
  num_beams=5,
122
  temperature=0.7,
123
  early_stopping=True
124
  )
 
125
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
126
 
127
  # ---------------------------------------------------------
128
- # 5. ARCHITECTURAL GUARDRAILS
129
  # ---------------------------------------------------------
130
- class Guardrails:
131
  @staticmethod
132
- def validate(code: str):
133
  try:
134
  ast.parse(code)
135
- return True, "Valid"
136
- except Exception as e:
137
- return False, f"Syntax analysis failed: {str(e)}"
 
 
 
 
 
 
 
 
 
138
 
139
  # ---------------------------------------------------------
140
- # 6. GLOBAL HANDLERS
141
  # ---------------------------------------------------------
142
- scanner = None
143
- repairer = None
144
- struct_scanner = StructuralScanner()
145
- guardrails = Guardrails()
146
 
147
- def get_scanner(force_reload=False):
148
- global scanner
149
- if scanner is None or force_reload:
150
- scanner = DeepVulnerabilityScanner()
151
- return scanner
152
 
153
  def get_repairer():
154
- global repairer
155
- if repairer is None:
156
- repairer = AutomatedRepairEngine()
157
- return repairer
158
-
159
- # ---------------------------------------------------------
160
- # 7. TRAINING WRAPPER
161
- # ---------------------------------------------------------
162
- def run_training():
163
- global is_training
164
- with training_lock:
165
- is_training = True
166
- try:
167
- print("--- STARTING BACKGROUND TRAINING CYCLE ---")
168
- train_on_devign(output_dir="./trained_model")
169
- print("--- TRAINING CYCLE COMPLETED. RELOADING SCANNER ---")
170
- get_scanner(force_reload=True)
171
- finally:
172
- with training_lock:
173
- is_training = False
174
 
175
  # ---------------------------------------------------------
176
- # 8. API ENDPOINTS
177
  # ---------------------------------------------------------
178
- @app.get("/")
179
- async def health():
180
- return {
181
- "status": "Revcode AI ULTRA Orchestrator Operational",
182
- "is_training": is_training,
183
- "features": ["XAI", "Structural-Scan", "Context-Injection", "Auto-Train"]
184
- }
185
-
186
- @app.post("/train")
187
- async def trigger_training(background_tasks: BackgroundTasks):
188
- global is_training
189
- if is_training:
190
- return {"status": "error", "message": "Training already in progress."}
191
-
192
- background_tasks.add_task(run_training)
193
- return {"status": "success", "message": "Training started in background."}
194
 
195
  @app.post("/analyze")
196
  async def analyze_security(data: CodeInput):
197
- eng = get_scanner()
198
- res = eng.scan(data.code)
199
- structural_findings = struct_scanner.scan_patterns(data.code, data.filename)
200
- if structural_findings:
201
- res["is_vulnerable"] = True
202
- res["reasoning"] += " | Structural rules flagged: " + ", ".join([f['title'] for f in structural_findings])
203
- res["verdict"] = "CRITICAL_VULNERABILITY"
204
-
205
  return {
206
- "is_vulnerable": res["is_vulnerable"],
207
- "confidence": res["risk_score"],
208
- "verdict": res["verdict"],
209
- "reasoning": res["reasoning"],
210
- "structural_findings": structural_findings,
211
- "is_training": is_training,
212
- "provider": "DeepScanner-ULTRA"
213
  }
214
 
215
  @app.post("/fix")
216
  async def fix_code(data: CodeInput):
217
- rep = get_repairer()
218
- suggestion = rep.repair(data.code, data.filename)
219
- is_valid, msg = guardrails.validate(suggestion)
220
- return {
221
- "suggestion": suggestion,
222
- "guardrail_status": "PASSED" if is_valid else "FAILED",
223
- "guardrail_msg": msg,
224
- "context_applied": data.filename
225
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
- @app.post("/verify")
228
- async def verify_fix(data: CodeInput):
229
- is_valid, msg = guardrails.validate(data.code)
230
  return {
231
- "is_valid": is_valid,
232
- "message": msg,
233
- "status": "PASSED" if is_valid else "WARNING"
234
  }
235
 
236
  @app.post("/feedback")
@@ -239,3 +176,11 @@ async def store_feedback(data: dict):
239
  df = pd.DataFrame([data])
240
  df.to_csv(feedback_file, mode='a', header=not os.path.exists(feedback_file), index=False)
241
  return {"status": "Feedback stored for retraining"}
 
 
 
 
 
 
 
 
 
1
  import ast
2
  import torch
3
  import torch.nn as nn
4
+ from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
 
6
  from transformers import (
7
  T5ForConditionalGeneration,
8
  RobertaTokenizer,
 
11
  )
12
  import pandas as pd
13
  import os
 
14
 
15
+ app = FastAPI(title="Revcode AI Strong Orchestrator")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # ---------------------------------------------------------
18
+ # 1. ADVANCED SECURITY SCANNER (The "Brain")
19
  # ---------------------------------------------------------
20
  class DeepVulnerabilityScanner:
21
  def __init__(self):
22
+ print("Loading Deep Security Scanner (DistilRoBERTa)...")
23
+ self.model_name = "distilroberta-base"
24
+ try:
25
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
26
+ self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=2)
27
+ self.model.eval()
28
+ except Exception as e:
29
+ print(f"Failed to load Deep Scanner: {e}")
30
+ self.model = None
31
+
 
 
 
 
 
32
  def scan(self, code: str) -> dict:
33
+ if not self.model:
34
+ return {"is_vulnerable": False, "risk_score": 0.0, "details": "Scanner unavailable"}
35
+
36
+ inputs = self.tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
37
  with torch.no_grad():
38
  logits = self.model(**inputs).logits
39
 
40
  probs = torch.softmax(logits, dim=1)
41
  vuln_prob = probs[0][1].item()
42
 
 
 
 
 
 
 
 
 
43
  return {
44
  "is_vulnerable": vuln_prob > 0.5,
45
  "risk_score": round(vuln_prob * 100, 2),
46
+ "details": "Potential vulnerability detected in code logic." if vuln_prob > 0.5 else "Clean code."
 
47
  }
48
 
49
  # ---------------------------------------------------------
50
+ # 2. AUTOMATED REPAIR ENGINE (The "Surgeon")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # ---------------------------------------------------------
52
  class AutomatedRepairEngine:
53
  def __init__(self):
54
+ print("Loading Repair Engine (CodeT5)...")
55
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
  self.model_name = "Salesforce/codet5p-220m"
57
+ try:
58
+ self.tokenizer = RobertaTokenizer.from_pretrained(self.model_name)
59
+ self.model = T5ForConditionalGeneration.from_pretrained(self.model_name).to(self.device)
60
+ self.model.eval()
61
+ except Exception as e:
62
+ print(f"Failed to load Repair Engine: {e}")
63
+ self.model = None
64
+
65
+ def repair(self, buggy_code: str) -> str:
66
+ if not self.model:
67
+ return buggy_code
68
 
69
+ prompt = f"Fix the security vulnerability: {buggy_code}"
70
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(self.device)
71
+
72
  with torch.no_grad():
73
  outputs = self.model.generate(
74
  **inputs,
75
+ max_length=256,
76
  num_beams=5,
77
  temperature=0.7,
78
  early_stopping=True
79
  )
80
+
81
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
82
 
83
  # ---------------------------------------------------------
84
+ # 3. SYNTAX & LOGIC VALIDATOR (The "Quality Control")
85
  # ---------------------------------------------------------
86
+ class CodeValidator:
87
  @staticmethod
88
+ def is_syntax_valid(code: str) -> bool:
89
  try:
90
  ast.parse(code)
91
+ return True
92
+ except Exception:
93
+ return False
94
+
95
+ @staticmethod
96
+ def check_security_patterns(code: str) -> list:
97
+ issues = []
98
+ dangerous_calls = ["eval(", "exec(", "os.system(", "subprocess.call("]
99
+ for call in dangerous_calls:
100
+ if call in code:
101
+ issues.append(f"Dangerous call found: {call}")
102
+ return issues
103
 
104
  # ---------------------------------------------------------
105
+ # 4. GLOBAL HANDLERS (Lazy Loading)
106
  # ---------------------------------------------------------
107
+ _scanner = None
108
+ _repairer = None
 
 
109
 
110
+ def get_scanner():
111
+ global _scanner
112
+ if _scanner is None:
113
+ _scanner = DeepVulnerabilityScanner()
114
+ return _scanner
115
 
116
  def get_repairer():
117
+ global _repairer
118
+ if _repairer is None:
119
+ _repairer = AutomatedRepairEngine()
120
+ return _repairer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  # ---------------------------------------------------------
123
+ # 5. API ENDPOINTS
124
  # ---------------------------------------------------------
125
+ class CodeInput(BaseModel):
126
+ code: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  @app.post("/analyze")
129
  async def analyze_security(data: CodeInput):
130
+ scanner = get_scanner()
131
+ result = scanner.scan(data.code)
132
+
 
 
 
 
 
133
  return {
134
+ "is_vulnerable": result["is_vulnerable"],
135
+ "confidence": result["risk_score"],
136
+ "verdict": "VULNERABLE" if result["is_vulnerable"] else "SECURE",
137
+ "details": result["details"],
138
+ "provider": "DistilRoBERTa-Strong"
 
 
139
  }
140
 
141
  @app.post("/fix")
142
  async def fix_code(data: CodeInput):
143
+ repairer = get_repairer()
144
+ validator = CodeValidator()
145
+
146
+ # ML Repair
147
+ suggestion = repairer.repair(data.code)
148
+
149
+ # Validation Loop
150
+ status = "PASSED"
151
+ msg = "Valid syntax"
152
+
153
+ if not validator.is_syntax_valid(suggestion):
154
+ status = "FAILED"
155
+ msg = "Repair generated invalid syntax"
156
+ # Heuristic fallback (from user's logic)
157
+ suggestion = data.code.replace("eval(", "safe_eval(")
158
+
159
+ # Final Security Pattern Check
160
+ final_issues = validator.check_security_patterns(suggestion)
161
+ if final_issues:
162
+ for issue in final_issues:
163
+ call_name = issue.split(": ")[1]
164
+ suggestion = suggestion.replace(call_name, f"# BLOCKED_{call_name.replace('(', '')}")
165
+ msg += f" | Blocked {len(final_issues)} dangerous calls"
166
 
 
 
 
167
  return {
168
+ "suggestion": suggestion,
169
+ "guardrail_status": status,
170
+ "guardrail_msg": msg
171
  }
172
 
173
  @app.post("/feedback")
 
176
  df = pd.DataFrame([data])
177
  df.to_csv(feedback_file, mode='a', header=not os.path.exists(feedback_file), index=False)
178
  return {"status": "Feedback stored for retraining"}
179
+
180
+ @app.get("/")
181
+ async def health():
182
+ return {"status": "Revcode AI Strong Engine is alive"}
183
+
184
+ if __name__ == "__main__":
185
+ import uvicorn
186
+ uvicorn.run(app, host="0.0.0.0", port=8000)