# xbgp / app.py
# Author: methodw | commit 80bed1b "switch to dinov3" (verified) | 2.72 kB
import gradio as gr
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import json
import numpy as np
import faiss
# Hugging Face checkpoint used for both preprocessing and feature extraction.
MODEL_ID = "facebook/dinov3-vitb16-pretrain-lvd1689m"

# Init similarity search AI model and processor (inference runs on CPU).
device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID)
# torch.jit.trace needs plain tuple outputs rather than a ModelOutput object.
model.config.return_dict = False
model.to(device)
model.eval()  # Inference mode: disables training-only layers such as dropout.

# Trace once with a dummy batch so every request runs the TorchScript graph.
# NOTE(review): tracing records the computation for this 1x3x224x224 example;
# confirm the processor always emits 224x224 pixel_values for uploaded images.
example_input = torch.rand(1, 3, 224, 224).to(device)
traced_model = torch.jit.trace(model, example_input)
traced_model = traced_model.to(device)

# Load the FAISS index that query embeddings are searched against
# (presumably built from gamerpic embeddings with the same pooling).
index = faiss.read_index("xbgp-faiss.index")
# Load the map from FAISS row number to gamerpic id (indexed by search labels).
with open("xbgp-faiss-map.json", "r", encoding="utf-8") as f:
    images = json.load(f)
# Number of nearest neighbours returned for each query image.
TOP_K = 50
# Length, in pixels, that the shorter image side is scaled to before preprocessing.
TARGET_SHORT_SIDE = 224


def _resize_shortest_side(image, target=TARGET_SHORT_SIDE):
    """Return *image* scaled so its shorter side is *target* px, keeping aspect ratio."""
    width, height = image.size
    if width < height:
        new_width = target
        new_height = int(height * (target / width))
    else:
        new_height = target
        new_width = int(width * (target / height))
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


def _embed(image):
    """Return the L2-normalised float32 embedding of *image* as a 2-D numpy array.

    The array presumably has shape (1, hidden_dim), one row for the single
    image passed to the processor — TODO confirm.
    """
    with torch.no_grad():
        pixel_values = processor(images=image, return_tensors="pt")["pixel_values"].to(device)
        outputs = traced_model(pixel_values)
        # Mean-pool the first traced output over dim 1 (token axis, presumably).
        # The FAISS index must have been built with the same pooling, so do not
        # change this without rebuilding the index.
        embeddings = outputs[0].mean(dim=1)
    # FAISS requires a C-contiguous float32 matrix.
    vector = np.ascontiguousarray(embeddings.detach().cpu().numpy(), dtype=np.float32)
    faiss.normalize_L2(vector)
    return vector


def _search(vector, k=TOP_K):
    """Search the FAISS index with *vector* and return up to *k* match dicts."""
    distances, indices = index.search(vector, k)
    matches = []
    for distance, image_idx in zip(distances[0], indices[0]):
        # FAISS pads the result with label -1 when the index holds fewer than
        # k vectors; images[-1] would otherwise silently return the last entry.
        if image_idx < 0:
            continue
        matches.append({
            "id": images[image_idx],
            # Maps distance d to 100 / (d + 1) percent. This assumes smaller
            # distances mean more similar (an L2-type index) — confirm the metric.
            "score": str(round((1 / (distance + 1) * 100), 2)) + "%",
        })
    return matches


def process_image(image):
    """
    Process the image and extract features using the DINOv3 model, then return
    the most similar gamerpics from the FAISS index.

    Args:
        image: PIL image supplied by the Gradio input, or None if nothing was uploaded.

    Returns:
        A list of up to TOP_K dicts, each with an "id" taken from the id map and
        a "score" string such as "66.67%", in the order FAISS returns them.

    Raises:
        gr.Error: when no image was provided, so the UI shows a readable message.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Convert to RGB if it isn't already
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = _resize_shortest_side(image)
    return _search(_embed(image))
# Build the Gradio UI: one PIL image in, the JSON list of matches out, with
# Gradio's request queue enabled.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Xbox Gamerpic Finder - DINOv3",
    description="Upload an image to find similar Xbox 360 gamerpics using Meta's DINOv3 vision model",
).queue()

# Launch only when run as a script (`python app.py`), so importing this module
# does not start a web server as a side effect.
if __name__ == "__main__":
    iface.launch()