File size: 2,721 Bytes
95cecf1
 
 
 
 
 
 
 
 
ab7ed99
80bed1b
 
 
 
f4a55ae
 
80bed1b
95cecf1
f4a55ae
 
 
 
ac1e96e
3d74fc7
 
 
 
 
 
 
95cecf1
 
 
80bed1b
95cecf1
 
 
 
 
f289a83
95cecf1
 
f289a83
 
95cecf1
 
f289a83
 
95cecf1
80bed1b
95cecf1
 
 
e23ab6f
ead7e30
95cecf1
80bed1b
 
 
 
 
95cecf1
 
 
9455029
 
95cecf1
9455029
3399730
9455029
 
4313d66
 
95cecf1
 
 
 
 
80bed1b
 
 
 
39f5f2d
95cecf1
 
cac43f3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import json
import numpy as np
import faiss

# Init similarity search AI model and processor
# NOTE(review): inference is pinned to CPU here; switch to "cuda" if a GPU is available.
device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained(
    "facebook/dinov3-vitb16-pretrain-lvd1689m"
)
model = AutoModel.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m")
model.config.return_dict = False  # Set return_dict to False for JIT tracing
model.to(device)
model.eval()  # Set model to evaluation mode for inference

# Prepare an example input for tracing
# (1, 3, 224, 224) = (batch, RGB channels, height, width) — the ViT-B/16 input size.
example_input = torch.rand(1, 3, 224, 224).to(device)  # Adjust size if needed
traced_model = torch.jit.trace(model, example_input)
traced_model = traced_model.to(device)

# Load faiss index of precomputed gamerpic embeddings
index = faiss.read_index("xbgp-faiss.index")

# Load faiss map: position i in the index -> identifier of the i-th gamerpic
with open("xbgp-faiss-map.json", "r") as f:
    images = json.load(f)


def process_image(image):
    """
    Extract DINOv3 features from the uploaded image and return the top 50
    most similar gamerpics from the FAISS index.

    Args:
        image: PIL.Image uploaded via the Gradio UI (any mode; converted to RGB).

    Returns:
        list[dict]: one dict per match, with "id" (entry from the index map)
        and "score" (similarity expressed as a percentage string).
    """
    # Convert to RGB if it isn't already (handles RGBA PNGs, palette GIFs, etc.)
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize so the shorter side is 224px, preserving aspect ratio, before
    # the processor applies its own cropping/normalization.
    width, height = image.size
    if width < height:
        scale = 224 / float(width)
        new_size = (224, int(float(height) * scale))
    else:
        scale = 224 / float(height)
        new_size = (int(float(width) * scale), 224)
    image = image.resize(new_size, Image.Resampling.LANCZOS)

    # Extract the features from the uploaded image
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")["pixel_values"].to(device)
        outputs = traced_model(inputs)

        # Mean-pool the token embeddings, then L2-normalize before search
        # (faiss.normalize_L2 requires a contiguous float32 array).
        embeddings = outputs[0].mean(dim=1)
        vector = embeddings.detach().cpu().numpy().astype(np.float32)
        faiss.normalize_L2(vector)

    # Search the index for the 50 nearest gamerpics
    distances, indices = index.search(vector, 50)

    matches = []
    for distance, matching_gamerpic in zip(distances[0], indices[0]):
        # FAISS pads the result with -1 when the index holds fewer than k
        # vectors; skip those placeholders instead of mis-indexing the map.
        if matching_gamerpic < 0:
            continue
        matches.append(
            {
                "id": images[matching_gamerpic],
                # Map distance to a (0, 100] "similarity" percentage
                "score": str(round((1 / (distance + 1) * 100), 2)) + "%",
            }
        )

    return matches


# Create a Gradio interface: a single PIL-image input wired to process_image,
# rendering the returned list of matches as JSON. .queue() enables request
# queuing so concurrent uploads are processed in order.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Xbox Gamerpic Finder - DINOv3",
    description="Upload an image to find similar Xbox 360 gamerpics using Meta's DINOv3 vision model",
).queue()

# Launch the Gradio app
iface.launch()