Spaces:
Sleeping
Sleeping
| # app/models/image_model.py | |
| # EfficientNet-based image classification model with ONNX optimization | |
| from pathlib import Path | |
| import numpy as np | |
| from PIL import Image | |
| from app.config import get_settings | |
| from app.observability.logging import get_logger | |
# Module-level logger from the app's observability helpers. Calls in this file
# pass an event name plus keyword fields, e.g.
# logger.info("image_model_loaded", backend="onnx").
logger = get_logger(__name__)
class ImageClassificationModel:
    """
    Image content classifier using EfficientNet.

    Detects violence, NSFW content, and other harmful imagery.
    Supports ONNX (fast) and PyTorch (fallback) inference.

    The class-index -> label-name map is read from the Hugging Face model
    config during ``load()`` and cached on the instance. ONNX and PyTorch
    inference therefore report the same label names and make the same
    ``is_harmful`` decision.
    """

    # Only its length seeds ``_num_labels``; output label names come from the
    # model config (see ``_format_output``).
    LABELS = ["safe", "violence", "nsfw", "self_harm", "hate_symbol"]

    # Checkpoint used whenever ``settings.image_model_name`` cannot be loaded.
    FALLBACK_MODEL_NAME = "google/efficientnet-b0"

    # Top labels, normalised (lower-case, "-" and " " -> "_"), that are NOT harmful.
    SAFE_LABELS = frozenset({"safe", "non_violence", "normal", "neutral"})

    # Dummy-input shape for ONNX export when the processor cannot report one.
    DEFAULT_INPUT_SHAPE = (1, 3, 224, 224)

    def __init__(self):
        self.settings = get_settings()
        self.processor = None
        self.onnx_session = None
        self.pt_model = None
        self.device = None
        self._loaded = False
        self._num_labels = len(self.LABELS)
        # Class index -> label name; filled from the model config in load().
        self._id2label: dict[int, str] = {}

    def load(self) -> None:
        """Load the image processor and an inference backend.

        Backend selection:

        * ONNX enabled and ``image_classifier.onnx`` already cached -> ONNX.
          If that file cannot be opened, fall through to the PyTorch path
          below, which re-exports it.
        * Otherwise load the PyTorch model. If ONNX is enabled, export it and
          switch to ONNX; stay on PyTorch when export or session creation fails.
        """
        model_name = self.settings.image_model_name
        cache_dir = self.settings.model_cache_path / "efficientnet"
        onnx_path = cache_dir / "image_classifier.onnx"

        logger.info("loading_image_model", model=model_name)
        self.processor = self._load_processor(model_name, cache_dir)

        if self.settings.onnx_enabled and onnx_path.exists():
            try:
                from app.models.onnx_utils import load_onnx_session

                self.onnx_session = load_onnx_session(onnx_path)
            except Exception as e:
                # A stale or corrupt cached export should not take the service down.
                logger.warning("onnx_load_failed", error=str(e), fallback="pytorch")
                self.onnx_session = None
            else:
                # No PyTorch model on this path, so read labels from the config.
                self._set_labels(self._load_label_map(model_name, cache_dir))
                logger.info("image_model_loaded", backend="onnx")
                self._loaded = True
                return

        self._load_pytorch(model_name, cache_dir)
        if self.settings.onnx_enabled:
            try:
                self._export_onnx(onnx_path)
                from app.models.onnx_utils import load_onnx_session

                self.onnx_session = load_onnx_session(onnx_path)
                # Labels were cached by _load_pytorch, so the PyTorch model can
                # be released once the ONNX session is ready.
                self.pt_model = None
                logger.info("image_model_loaded", backend="onnx", note="exported")
            except Exception as e:
                logger.warning("onnx_export_failed", error=str(e), fallback="pytorch")
        else:
            logger.info("image_model_loaded", backend="pytorch")
        self._loaded = True

    def _load_processor(self, model_name: str, cache_dir: Path):
        """Return the image processor for ``model_name``, or the fallback's."""
        from transformers import AutoImageProcessor

        try:
            return AutoImageProcessor.from_pretrained(model_name, cache_dir=cache_dir)
        except Exception as e:
            # Fallback: use a generic processor, but make the substitution visible.
            logger.warning(
                "image_processor_fallback",
                model=model_name,
                fallback=self.FALLBACK_MODEL_NAME,
                error=str(e),
            )
            return AutoImageProcessor.from_pretrained(
                self.FALLBACK_MODEL_NAME, cache_dir=cache_dir
            )

    def _load_pytorch(self, model_name: str, cache_dir: Path) -> None:
        """Load the PyTorch classifier (CUDA if available, else CPU) and cache its labels."""
        import torch
        from transformers import AutoModelForImageClassification

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            model = AutoModelForImageClassification.from_pretrained(
                model_name, cache_dir=cache_dir
            )
        except Exception as e:
            # If the model doesn't exist as a pretrained classifier, load base EfficientNet.
            logger.warning(
                "image_model_fallback",
                model=model_name,
                fallback=self.FALLBACK_MODEL_NAME,
                error=str(e),
            )
            model = AutoModelForImageClassification.from_pretrained(
                self.FALLBACK_MODEL_NAME, cache_dir=cache_dir
            )
        model.to(self.device)
        model.eval()
        self.pt_model = model
        self._set_labels(self._labels_from_config(model.config))

    def _load_label_map(self, model_name: str, cache_dir: Path) -> dict[int, str]:
        """Read ``id2label`` from the model config without loading weights.

        Used when inference runs from a cached ONNX file. Tries the configured
        model first, then the fallback checkpoint, mirroring ``_load_pytorch``.
        Returns ``{}`` (generic ``class_i`` labels) if neither config loads.
        """
        from transformers import AutoConfig

        # NOTE(review): assumes the cached ONNX file was exported from the same
        # checkpoint whose config loads first here; if the export came from the
        # fallback model while the configured config still loads, the label
        # names may not match the logits. Confirm for your deployment.
        for name in (model_name, self.FALLBACK_MODEL_NAME):
            try:
                config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
            except Exception as e:
                logger.warning("image_label_config_failed", model=name, error=str(e))
                continue
            return self._labels_from_config(config)
        return {}

    @staticmethod
    def _labels_from_config(config) -> dict[int, str]:
        """Return ``config.id2label`` as ``{int: str}``, or ``{}`` if absent."""
        id2label = getattr(config, "id2label", None) or {}
        return {int(idx): str(name) for idx, name in id2label.items()}

    def _set_labels(self, id2label: dict[int, str]) -> None:
        """Cache ``id2label``; update the label count only when it is non-empty."""
        self._id2label = id2label
        if id2label:
            self._num_labels = len(id2label)

    def _input_shape(self) -> tuple[int, ...]:
        """Shape of ``pixel_values`` the processor produces for one image."""
        if self.processor is None:
            return self.DEFAULT_INPUT_SHAPE
        try:
            # The processor resizes to its configured resolution, so a blank
            # square image reveals the size the exported graph will receive.
            sample = self.processor(
                images=Image.new("RGB", (224, 224)), return_tensors="np"
            )
            return tuple(int(dim) for dim in sample["pixel_values"].shape)
        except Exception as e:
            logger.warning("image_input_shape_probe_failed", error=str(e))
            return self.DEFAULT_INPUT_SHAPE

    def _export_onnx(self, onnx_path: Path) -> None:
        """Export the loaded PyTorch model to ``onnx_path``.

        The dummy input's shape comes from the image processor, so the traced
        graph matches the resolution used at inference time.
        """
        import torch
        from app.models.onnx_utils import export_to_onnx

        dummy_input = torch.randn(*self._input_shape()).to(self.device)
        export_to_onnx(
            model=self.pt_model,
            sample_input={"pixel_values": dummy_input},
            output_path=onnx_path,
            input_names=["pixel_values"],
            output_names=["logits"],
        )

    def predict(self, image: "Image.Image") -> dict:
        """
        Classify an image for harmful content.

        Args:
            image: PIL Image. Non-RGB PIL images (e.g. RGBA, L, P) are
                converted to RGB before preprocessing.

        Returns:
            Dict with labels, scores, is_harmful, max_score, max_label.

        Raises:
            RuntimeError: if ``load()`` has not completed.
        """
        if not self._loaded:
            raise RuntimeError("Image model not loaded. Call load() first.")
        # The processor's normalisation presumably expects 3 channels, so
        # normalise the PIL mode up front.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        use_onnx = self.onnx_session is not None
        inputs = self.processor(images=image, return_tensors="np" if use_onnx else "pt")
        if use_onnx:
            return self._predict_onnx(inputs)
        return self._predict_pytorch(inputs)

    def _predict_onnx(self, inputs) -> dict:
        """ONNX inference on processor output (numpy tensors)."""
        from app.models.onnx_utils import onnx_inference

        pixel_values = inputs["pixel_values"].astype(np.float32)
        outputs = onnx_inference(self.onnx_session, {"pixel_values": pixel_values})
        logits = outputs[0][0]  # first output, first (only) image in the batch
        return self._format_output(logits)

    def _predict_pytorch(self, inputs) -> dict:
        """PyTorch inference on processor output (torch tensors)."""
        import torch

        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.pt_model(**inputs)
        logits = outputs.logits[0].cpu().numpy()
        return self._format_output(logits)

    def _format_output(self, logits: np.ndarray) -> dict:
        """Convert a 1-D logit vector into the prediction dict.

        Applies a max-shifted (numerically stable) softmax for single-label
        classification and names each score from the cached id2label map
        (``class_i`` when there is no entry). The image is flagged harmful
        unless its top label is one of ``SAFE_LABELS``. The result is the same
        whichever backend produced the logits.
        """
        logits = np.asarray(logits)
        exp_logits = np.exp(logits - np.max(logits))
        scores = (exp_logits / exp_logits.sum()).tolist()

        labels = [self._id2label.get(i, f"class_{i}") for i in range(len(scores))]
        max_idx = int(np.argmax(scores))
        max_label = labels[max_idx]

        # Determine if harmful (anything not classified as safe/non-violent).
        normalised = max_label.lower().replace("-", "_").replace(" ", "_")
        is_harmful = normalised not in self.SAFE_LABELS

        return {
            "labels": labels,
            "scores": scores,
            "is_harmful": is_harmful,
            "max_score": scores[max_idx],
            "max_label": max_label,
        }

    def is_loaded(self) -> bool:
        """Return True once ``load()`` has completed."""
        return self._loaded