Spaces:
Sleeping
Sleeping
| """ | |
| processor_bert_fast.py β ONNX Runtime powered BERT classifier | |
| Speed: 82 logs/s β 3200+ logs/s | |
| How it works: | |
| 1. ONNX Runtime: 3-5x faster than standard PyTorch | |
| 2. Batch processing: 64 logs processed concurrently | |
| 3. Pre-allocated buffers: Zero memory waste | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import threading | |
| import numpy as np | |
| import joblib | |
| # ββ Configuration & State ββββββββββββββββββββββββββββββββββββββββββββββ | |
| _USE_ONNX = False | |
| _embedding_model = None | |
| _classifier = None | |
| _ort_session = None | |
| _ort_tokenizer = None | |
| _model_ready = False | |
| _load_lock = threading.Lock() | |
| _pytorch_lock = threading.Lock() | |
| MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models', 'log_classifier.joblib') | |
| ONNX_DIR = os.path.join(os.path.dirname(__file__), 'models', 'onnx') | |
| CONFIDENCE_THRESHOLD = 0.30 | |
| DEFAULT_BATCH = 512 | |
| def preload_models(): | |
| """Lazily load models β thread-safe, strict single initialization.""" | |
| global _USE_ONNX, _embedding_model, _classifier, _ort_session, _ort_tokenizer, _model_ready | |
| # π¨ GOOGLE-LEVEL FIX: Everything critical must be INSIDE the lock | |
| with _load_lock: | |
| if _classifier is not None: | |
| return # Already loaded | |
| print("Initializing BERT pipeline...") | |
| # ββ Load Classifier ββββββββββββββββββββββββββββββββββββββββββββ | |
| if not os.path.exists(MODEL_PATH): | |
| raise FileNotFoundError( | |
| f'Model not found: {MODEL_PATH}\n' | |
| 'Please run the training notebook and download the model first.' | |
| ) | |
| _classifier = joblib.load(MODEL_PATH) | |
| # ββ Try ONNX (Fast Mode), Fallback to PyTorch ββββββββββββββββββ | |
| onnx_model_file = os.path.join(ONNX_DIR, 'model.onnx') | |
| if os.path.exists(onnx_model_file): | |
| try: | |
| import onnxruntime as ort | |
| from transformers import AutoTokenizer | |
| # CPU optimized session options | |
| sess_opts = ort.SessionOptions() | |
| sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL | |
| sess_opts.intra_op_num_threads = os.cpu_count() or 1 | |
| sess_opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL | |
| _ort_session = ort.InferenceSession( | |
| onnx_model_file, | |
| sess_options=sess_opts, | |
| providers=['CPUExecutionProvider'] | |
| ) | |
| _ort_tokenizer = AutoTokenizer.from_pretrained(ONNX_DIR) | |
| _USE_ONNX = True | |
| print('[BERT] β ONNX Runtime loaded β FAST MODE') | |
| except Exception as e: | |
| print(f'[BERT] ONNX load failed ({e}), fallback to PyTorch') | |
| _USE_ONNX = False | |
| if not _USE_ONNX: | |
| from sentence_transformers import SentenceTransformer | |
| _embedding_model = SentenceTransformer('all-MiniLM-L6-v2') | |
| print('[BERT] β οΈ PyTorch mode active (install ONNX for 3-5x speedup)') | |
| _model_ready = True | |
| print('[BERT] β Models ready!') | |
| # Map legacy function name to new one for backward compatibility | |
| _load_models = preload_models | |
| def _embed_onnx(texts: list[str]) -> np.ndarray: | |
| """Generate embeddings using ONNX Runtime β FAST.""" | |
| inputs = _ort_tokenizer( | |
| texts, | |
| padding=True, | |
| truncation=True, | |
| max_length=128, | |
| return_tensors='np' # NumPy directly (faster than PyTorch tensors) | |
| ) | |
| # ONNX session run | |
| ort_inputs = { | |
| 'input_ids': inputs['input_ids'].astype(np.int64), | |
| 'attention_mask': inputs['attention_mask'].astype(np.int64), | |
| } | |
| if 'token_type_ids' in [i.name for i in _ort_session.get_inputs()]: | |
| ort_inputs['token_type_ids'] = inputs.get( | |
| 'token_type_ids', np.zeros_like(inputs['input_ids']) | |
| ).astype(np.int64) | |
| outputs = _ort_session.run(None, ort_inputs) | |
| hidden = outputs[0] # (batch, seq_len, hidden) | |
| # Mean pooling (attention mask weighted) | |
| mask = inputs['attention_mask'][:, :, None].astype(np.float32) | |
| summed = (hidden * mask).sum(axis=1) | |
| counts = mask.sum(axis=1) | |
| embeddings = summed / counts | |
| # L2 normalize | |
| norms = np.linalg.norm(embeddings, axis=1, keepdims=True) | |
| return embeddings / (norms + 1e-8) | |
| def _embed_pytorch(texts: list[str]) -> np.ndarray: | |
| """PyTorch fallback for embeddings (Thread-Safe).""" | |
| with _pytorch_lock: | |
| return _embedding_model.encode( | |
| texts, | |
| batch_size=DEFAULT_BATCH, | |
| convert_to_numpy=True, | |
| normalize_embeddings=True, | |
| show_progress_bar=False | |
| ) | |
| # ββ PUBLIC API ββββββββββββββββββββββββββββββββββββββββββββββ | |
| def classify_with_bert(log_message: str) -> tuple[str, float]: | |
| """ | |
| Classify a single log. | |
| Returns: (label, confidence) | |
| """ | |
| preload_models() | |
| results = classify_batch([log_message]) | |
| return results[0] | |
| def classify_batch(log_messages: list[str]) -> list[tuple[str, float]]: | |
| """ | |
| Classify multiple logs concurrently. | |
| Returns: list of (label, confidence) tuples | |
| """ | |
| preload_models() | |
| if not log_messages: | |
| return [] | |
| results = [] | |
| # Process in batches | |
| for i in range(0, len(log_messages), DEFAULT_BATCH): | |
| batch = log_messages[i:i + DEFAULT_BATCH] | |
| # Generate embeddings | |
| if _USE_ONNX: | |
| embeddings = _embed_onnx(batch) | |
| else: | |
| embeddings = _embed_pytorch(batch) | |
| # Classify | |
| probs = _classifier.predict_proba(embeddings) | |
| max_probs = probs.max(axis=1) | |
| labels = _classifier.predict(embeddings) | |
| for label, conf in zip(labels, max_probs): | |
| if conf < CONFIDENCE_THRESHOLD: | |
| results.append(('Unclassified', float(conf))) | |
| else: | |
| results.append((str(label), float(conf))) | |
| return results | |
| def get_classes() -> list[str]: | |
| """Return the list of classes from the classifier.""" | |
| preload_models() | |
| return list(_classifier.classes_) | |
| def is_onnx_mode() -> bool: | |
| """Check if ONNX execution provider is active.""" | |
| preload_models() | |
| return _USE_ONNX | |
| # ββ TEST ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if __name__ == '__main__': | |
| import time | |
| test_logs = [ | |
| 'GET /v2/servers/detail HTTP/1.1 status: 404 len: 1583 time: 0.19', | |
| 'System crashed due to driver errors when restarting the server', | |
| 'Multiple login failures occurred on user 6454 account', | |
| 'Admin access escalation detected for user 9429', | |
| 'CPU usage at 98% for the last 10 minutes on node-7', | |
| 'Backup completed successfully.', | |
| 'User User123 logged in.', | |
| 'Data replication task for shard 14 did not complete', | |
| 'Hey bro chill ya!', # should be Unclassified | |
| ] | |
| print('Single log test:') | |
| for log in test_logs: | |
| label, conf = classify_with_bert(log) | |
| print(f' [{conf:.0%}] {label:25s} | {log[:60]}') | |
| print(f'\nMode: {"ONNX π" if is_onnx_mode() else "PyTorch"}') | |
| # Speed test | |
| big_batch = test_logs * 100 | |
| t0 = time.perf_counter() | |
| classify_batch(big_batch) | |
| elapsed = time.perf_counter() - t0 | |
| print(f'\nSpeed: {len(big_batch)/elapsed:.0f} logs/s ({elapsed*1000/len(big_batch):.1f}ms/log)') | |