| from flask import Flask, request, jsonify |
| import torch |
| from transformers import RobertaTokenizer, RobertaForSequenceClassification |
| import os |
| import gc |
| from functools import lru_cache |
|
|
# Flask application object; the route decorators below register on it.
app = Flask(__name__)

# Inference state shared by all requests. Each starts as None: the start-up
# code assigns model and tokenizer, and load_model() assigns device. The
# endpoints report "not loaded" / "unknown" while they are still None.
model = tokenizer = device = None
|
|
def setup_device():
    """Choose the torch device to run inference on.

    Preference order: CUDA, then Apple MPS, then CPU. The ``hasattr`` guard
    keeps this working on torch builds that have no MPS backend at all.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        name = 'cuda'
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        name = 'mps'
    else:
        name = 'cpu'
    return torch.device(name)
|
|
def load_tokenizer(path='./tokenizer_vulnerability',
                   fallback='microsoft/codebert-base',
                   max_length=512):
    """Load the RoBERTa tokenizer, falling back to a second source.

    Args:
        path: location of the fine-tuned tokenizer, tried first.
        fallback: model id or path tried when ``path`` cannot be loaded.
        max_length: assigned to ``model_max_length`` on whichever tokenizer
            is returned, so both sources report the same limit (the
            prediction code in this module caps inputs at 512 tokens).

    Returns:
        The loaded tokenizer, or None if both sources fail. Errors are
        printed rather than raised so start-up can continue and the API can
        report that loading failed.
    """
    attempts = (
        (path, "Error loading tokenizer"),
        (fallback, "Fallback tokenizer failed"),
    )
    for source, failure_label in attempts:
        try:
            tok = RobertaTokenizer.from_pretrained(source)
        except Exception as e:
            print(f"{failure_label}: {e}")
            continue
        # Previously only the primary tokenizer got this limit; apply it to
        # the fallback too so behaviour does not depend on which one loaded.
        tok.model_max_length = max_length
        return tok
    return None
|
|
def load_model():
    """Build the sequence classifier and load the fine-tuned weights.

    Side effect: sets the module-level ``device`` using ``setup_device()``.

    The checkpoint file may hold either a dict with ``'model_state_dict'``
    (optionally alongside a ``'config'`` dict used to build the
    architecture) or a bare state dict. Without a ``'config'`` entry, the
    model is created from ``microsoft/codebert-base`` with a single-label
    head and the checkpoint weights are loaded on top.

    Returns:
        The model on ``device`` in eval mode (cast to float16 on CUDA), or
        None if anything in the loading step fails; the error is printed.
    """
    global device
    device = setup_device()
    print(f"Using device: {device}")

    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load trusted checkpoints; consider
        # weights_only=True if the installed torch version supports it.
        ckpt = torch.load("codebert_vulnerability_scorer.pth", map_location=device)

        if 'config' not in ckpt:
            net = RobertaForSequenceClassification.from_pretrained(
                'microsoft/codebert-base',
                num_labels=1,
            )
        else:
            from transformers import RobertaConfig
            net = RobertaForSequenceClassification(
                RobertaConfig.from_dict(ckpt['config'])
            )

        # Wrapped checkpoints keep the weights under 'model_state_dict'.
        weights = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
        net.load_state_dict(weights)

        net.to(device)
        net.eval()
        if device.type == 'cuda':
            # Half precision on GPU to reduce memory and speed up inference.
            net.half()
        return net

    except Exception as e:
        print(f"Error loading model: {e}")
        return None
|
|
def cleanup_gpu_memory():
    """Free memory left over from inference.

    ``gc.collect()`` always runs. The CUDA allocator cache is emptied only
    when the module-level ``device`` is a CUDA device.

    Collection runs first: ``torch.cuda.empty_cache()`` can only release
    cached blocks that are no longer in use. Tensors that are freed only
    when the cycle collector runs would otherwise stay in the cache until
    the next call.
    """
    gc.collect()
    if device is not None and device.type == 'cuda':
        torch.cuda.empty_cache()
|
|
# Start-up: load the tokenizer and model once, at import time, so every
# request handler shares them. The loaders print and return None on failure.
# Any unexpected exception resets both globals to None, and the endpoints
# then report the model as not loaded instead of the app failing to start.
try:
    print("Loading tokenizer...")
    tokenizer = load_tokenizer()
    print("Tokenizer loaded successfully!" if tokenizer else "Failed to load tokenizer!")

    print("Loading model...")
    model = load_model()
    print("Model loaded successfully!" if model else "Failed to load model!")

except Exception as e:
    print(f"Error during initialization: {e!s}")
    tokenizer = None
    model = None
|
|
@app.route("/", methods=['GET'])
def home():
    """Describe the API: model load status, active device and endpoints."""
    return jsonify({
        "message": "CodeBERT Vulnerability Evaluator API",
        "status": "Model loaded" if model is not None else "Model not loaded",
        "device": str(device) if device else "unknown",
        "endpoints": {
            "/predict": "POST with JSON body containing 'codes' array",
            "/health": "GET service and model load status",
        },
    })
|
|
# Request limits enforced by /predict.
_MAX_CODES_PER_REQUEST = 100
_MAX_CODE_CHARS = 50000
# Snippets scored per round before GPU memory is released.
_PREDICT_BATCH_SIZE = 16
# Snippets with more tokens than this (special tokens excluded) are scored
# with sliding-window chunking instead of the shared batched forward pass.
_LONG_CODE_TOKENS = 450


def _validate_codes(codes):
    """Validate and normalise the ``codes`` field of a /predict request.

    Returns:
        ``(cleaned, None)`` on success. Each entry is stripped, and
        whitespace-only entries become ``"# empty code"``. On failure,
        ``(None, message)`` describing the first problem found.
    """
    if not isinstance(codes, list) or not codes:
        return None, "'codes' must be a non-empty array"
    if len(codes) > _MAX_CODES_PER_REQUEST:
        return None, f"Too many codes. Maximum {_MAX_CODES_PER_REQUEST} allowed."

    cleaned = []
    for i, code in enumerate(codes):
        if not isinstance(code, str):
            return None, f"Code at index {i} must be a string"
        # Check the raw length first so whitespace padding cannot bypass it.
        if len(code) > _MAX_CODE_CHARS:
            return None, f"Code at index {i} too long. Maximum {_MAX_CODE_CHARS} characters."
        cleaned.append(code.strip() or "# empty code")
    return cleaned, None


def _score_batch(batch, offset=0):
    """Return the raw model score for each snippet in ``batch``, in order.

    Long snippets (more than ``_LONG_CODE_TOKENS`` tokens) are scored one by
    one with chunking; the rest share one batched forward pass. A failure in
    either path scores the affected snippets 0.0. ``offset`` is the batch's
    position within the request and is used only in log messages.
    """
    short_pos, long_pos = [], []
    for pos, code in enumerate(batch):
        try:
            # A 1000-token cap is enough to decide short vs. long.
            n_tokens = len(tokenizer.encode(
                code, add_special_tokens=False, max_length=1000, truncation=True))
        except Exception as e:
            # Length unknown: fall back to the batched path.
            print(f"Tokenization error for code {offset + pos}: {e}")
            short_pos.append(pos)
            continue
        if n_tokens > _LONG_CODE_TOKENS:
            long_pos.append(pos)
        else:
            short_pos.append(pos)

    scores = [0.0] * len(batch)

    if short_pos:
        try:
            short_scores = predict_vulnerability_batch([batch[p] for p in short_pos])
            if len(short_scores) != len(short_pos):
                raise ValueError(
                    f"expected {len(short_pos)} scores, got {len(short_scores)}")
            for pos, score in zip(short_pos, short_scores):
                scores[pos] = score
        except Exception as e:
            print(f"Batch prediction error: {e}")
            for pos in short_pos:
                scores[pos] = 0.0

    for pos in long_pos:
        try:
            scores[pos] = predict_vulnerability_with_chunking(batch[pos])
        except Exception as e:
            print(f"Chunking prediction error: {e}")
            scores[pos] = 0.0

    return scores


@app.route("/predict", methods=['POST'])
def predict_batch():
    """Score a list of code snippets.

    Expects a JSON object ``{"codes": [str, ...]}`` holding 1-100 strings of
    at most 50000 characters each. Responds with
    ``{"results": [{"score": s}, ...]}`` in input order, where
    ``s = round(1.0 - raw_model_score, 4)`` for every request size.

    Status codes: 400 for a missing, malformed or invalid body; 500 if the
    model/tokenizer is not loaded or an unexpected error occurs.

    NOTE(review): the scoring helpers map inference failures to a raw score
    of 0.0, which this endpoint reports as 1.0. A failure therefore looks
    like a genuine maximal score. Confirm this fail-open behaviour is
    intended.
    """
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500

        # silent=True: a malformed body or wrong Content-Type yields None
        # (answered with 400) instead of raising, which the catch-all below
        # would turn into a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'codes' not in data:
            return jsonify({"error": "Missing 'codes' field in JSON body"}), 400

        codes, error = _validate_codes(data['codes'])
        if error is not None:
            return jsonify({"error": error}), 400

        if len(codes) == 1:
            try:
                raw_scores = [predict_vulnerability_with_chunking(codes[0])]
            finally:
                cleanup_gpu_memory()
        else:
            raw_scores = []
            for start in range(0, len(codes), _PREDICT_BATCH_SIZE):
                try:
                    raw_scores.extend(_score_batch(
                        codes[start:start + _PREDICT_BATCH_SIZE], offset=start))
                finally:
                    # Release memory after every round, even on failure.
                    cleanup_gpu_memory()

        return jsonify({"results": [{"score": round(1.0 - s, 4)} for s in raw_scores]})

    except Exception as e:
        cleanup_gpu_memory()
        return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500
|
|
def predict_vulnerability_with_chunking(code, *, chunk_size=400, overlap=50,
                                        max_tokens=2000, short_threshold=450):
    """Score ``code``, splitting long inputs into overlapping token windows.

    Inputs of at most ``short_threshold`` tokens are scored in one call to
    ``predict_vulnerability``. Longer inputs are split into windows of
    ``chunk_size`` tokens, with consecutive windows sharing ``overlap``
    tokens. Each window is decoded back to text and scored, and the highest
    window score is returned, so the most suspicious region dominates.

    Only the first ``max_tokens`` tokens are considered; the rest of the
    input is silently ignored. The defaults keep the original behaviour.
    Raise ``max_tokens`` to cover more of long inputs, at proportional cost.

    Args:
        code: source text to score.
        chunk_size: tokens per window. ``predict_vulnerability`` truncates
            its input at 512 tokens including special tokens, so windows
            much above ~500 tokens would lose their tail.
        overlap: tokens shared by consecutive windows; must satisfy
            ``0 <= overlap < chunk_size``.
        max_tokens: truncation limit applied before windowing.
        short_threshold: inputs up to this many tokens skip windowing.

    Returns:
        float: the score (highest window score for long inputs). 0.0 for
        empty or whitespace-only input, or on any error; errors are printed.

    Raises:
        ValueError: if ``overlap``/``chunk_size`` do not give a positive stride.
    """
    if not 0 <= overlap < chunk_size:
        raise ValueError("chunk_size must be positive and 0 <= overlap < chunk_size")
    stride = chunk_size - overlap

    try:
        if not code or not code.strip():
            return 0.0

        tokens = tokenizer.encode(code, add_special_tokens=False,
                                  max_length=max_tokens, truncation=True)
        if len(tokens) <= short_threshold:
            return predict_vulnerability(code)

        max_score = 0.0
        for start in range(0, len(tokens), stride):
            end = min(start + chunk_size, len(tokens))
            try:
                # Decode back to text because predict_vulnerability takes a string.
                chunk_code = tokenizer.decode(tokens[start:end], skip_special_tokens=True)
                if chunk_code.strip():
                    max_score = max(max_score, predict_vulnerability(chunk_code))
            except Exception as e:
                print(f"Chunk processing error: {e}")
            # Checked even after a failed window, so the loop does not go on
            # to a tail window that this window already covered.
            if end >= len(tokens):
                break

        return max_score

    except Exception as e:
        print(f"Chunking error: {e}")
        return 0.0
|
|
def predict_vulnerability(code, max_length=512):
    """Return the model's score for a single code snippet.

    Args:
        code: source text to score.
        max_length: truncation limit in tokens, special tokens included.

    Returns:
        float in [0, 1], rounded to 4 decimals: sigmoid of twice the model
        logit. Returns 0.0 for empty or whitespace-only input, or on any
        error (errors are printed). A classification head with more than one
        logit makes ``.item()`` raise, which also yields 0.0.
    """
    try:
        if not code or not code.strip():
            return 0.0

        # Tokenise to the real length (no padding for a single sequence),
        # capped at max_length. The previous "2 * whitespace words" length
        # guess can undercount BPE tokens for code (operators, punctuation,
        # indentation), which silently truncated the input.
        inputs = tokenizer(
            code,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            if device.type == 'cuda':
                with torch.autocast(device_type='cuda'):
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

        # Factor 2.0 sharpens the logit before the sigmoid. Cast to float32
        # first: on CUDA the model runs in half precision, and fp16 cannot
        # resolve the 4th decimal near 0 or 1.
        score = torch.sigmoid(2.0 * outputs.logits.float()).item()
        return round(max(0.0, min(1.0, score)), 4)

    except Exception as e:
        print(f"Single prediction error: {e}")
        return 0.0
|
|
def predict_vulnerability_batch(codes, max_length=512):
    """Score several code snippets in a single forward pass.

    Falsy or whitespace-only entries are replaced by the placeholder
    ``"# empty"`` before tokenisation.

    Args:
        codes: list of source texts.
        max_length: per-sequence truncation limit in tokens, special tokens
            included.

    Returns:
        list of floats in [0, 1], rounded to 4 decimals, one per input.
        This assumes a single-logit head (TODO confirm num_labels == 1 for
        checkpoints loaded with a config). Returns ``[]`` for empty input,
        and ``[0.0] * len(codes)`` on any error (errors are printed).
    """
    try:
        if not codes:
            return []

        texts = [c if c and c.strip() else "# empty" for c in codes]

        # padding=True pads only to the longest sequence in the batch, and
        # truncation stops at max_length, so no snippet is cut below the
        # model limit. The previous whitespace word-count length guess could
        # truncate snippets whose BPE token count exceeded it.
        inputs = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            if device.type == 'cuda':
                with torch.autocast(device_type='cuda'):
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

        # Same scaling as the single-snippet scorer; float32 avoids
        # fp16 quantisation of the rounded result.
        probs = torch.sigmoid(2.0 * outputs.logits.float()).flatten().tolist()
        return [round(max(0.0, min(1.0, p)), 4) for p in probs]

    except Exception as e:
        print(f"Batch prediction error: {e}")
        return [0.0] * len(codes)
|
|
@app.route("/health", methods=['GET'])
def health_check():
    """Report service status, which components loaded, and the active device.

    NOTE(review): ``status`` is always "healthy" (HTTP 200), even when the
    model or tokenizer failed to load. Callers needing readiness should
    inspect ``model_loaded`` and ``tokenizer_loaded``.
    """
    payload = {
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "device": "unknown" if not device else str(device),
    }
    return jsonify(payload)
|
|
if __name__ == "__main__":
    # Direct-run entry point: listen on all interfaces with Flask's threaded
    # server. The port defaults to 7860 and can be overridden with the PORT
    # environment variable.
    port = int(os.environ.get("PORT", "7860"))
    app.run(host="0.0.0.0", port=port, debug=False, threaded=True)