| | import torch |
| | import torch.nn as nn |
| | import torchvision.transforms as transforms |
| | from torchvision import models |
| | from PIL import Image |
| | import logging |
| | import os |
| |
|
| | |
| | logging.basicConfig(level=logging.INFO) |
| | logger = logging.getLogger(__name__) |
| |
|
| | class VisionProcessingModel(nn.Module): |
| | def __init__(self, model_name="resnet50", num_classes=1000, top_k=5): |
| | super(VisionProcessingModel, self).__init__() |
| | |
| | |
| | self.model = self._load_pretrained_model(model_name, num_classes) |
| | self.model.eval() |
| |
|
| | |
| | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| | self.model.to(self.device) |
| | logger.info(f"Model loaded on device: {self.device}") |
| |
|
| | |
| | self.top_k = top_k |
| |
|
| | |
| | self.transform = transforms.Compose([ |
| | transforms.Resize(256), |
| | transforms.CenterCrop(224), |
| | transforms.ToTensor(), |
| | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
| | ]) |
| |
|
| | def _load_pretrained_model(self, model_name, num_classes): |
| | """Helper function to load a pre-trained model.""" |
| | if model_name == "resnet50": |
| | return models.resnet50(pretrained=True) |
| | elif model_name == "efficientnet_b0": |
| | return models.efficientnet_b0(pretrained=True) |
| | elif model_name == "mobilenet_v2": |
| | return models.mobilenet_v2(pretrained=True) |
| | else: |
| | raise ValueError(f"Unsupported model: {model_name}") |
| |
|
| | def forward(self, image): |
| | """Forward pass through the model.""" |
| | image = image.to(self.device) |
| | with torch.no_grad(): |
| | outputs = self.model(image) |
| | return outputs |
| |
|
| | def process_image(self, image_path): |
| | """Process an image and get predictions.""" |
| | try: |
| | |
| | image = Image.open(image_path).convert('RGB') |
| |
|
| | |
| | image = self.transform(image).unsqueeze(0) |
| |
|
| | |
| | outputs = self.forward(image) |
| |
|
| | |
| | probabilities = torch.nn.functional.softmax(outputs[0], dim=0) |
| |
|
| | |
| | top_k_prob, top_k_catid = torch.topk(probabilities, self.top_k) |
| |
|
| | return top_k_prob, top_k_catid |
| | except Exception as e: |
| | logger.error(f"Error processing image: {e}") |
| | return None, None |
| |
|
| | def get_category_labels(self, category_ids): |
| | """Map category IDs to human-readable labels.""" |
| | |
| | labels_path = os.getenv("IMAGENET_LABELS_PATH", "path/to/imagenet_labels.txt") |
| | with open(labels_path, "r") as f: |
| | labels = [line.strip() for line in f.readlines()] |
| |
|
| | |
| | return [labels[cat_id] for cat_id in category_ids] |
| |
|
| | def enhance_vision_processing(self, image_path): |
| | """Enhance vision capabilities by extracting top-k predictions.""" |
| | top_k_prob, top_k_catid = self.process_image(image_path) |
| |
|
| | if top_k_prob is not None and top_k_catid is not None: |
| | |
| | top_k_prob = top_k_prob.tolist() |
| | top_k_catid = top_k_catid.tolist() |
| |
|
| | |
| | category_labels = self.get_category_labels(top_k_catid) |
| |
|
| | return top_k_prob, top_k_catid, category_labels |
| | else: |
| | return None, None, None |
| |
|