Mark8398 committed (verified) · Commit f634c61 · Parent(s): b51196c

Uploaded 6 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/beach.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,152 @@
+import gradio as gr
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import ViTModel, AutoModel, AutoTokenizer
+from torchvision import transforms
+from datasets import load_dataset
+from PIL import Image
+
+# --- 1. MODEL ARCHITECTURE ---
+class MultiModalEngine(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.image_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
+        self.text_model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
+        self.image_projection = nn.Linear(768, 256)
+        self.text_projection = nn.Linear(768, 256)
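+        # CLIP-style temperature: initialised to ln(1/0.07) ≈ 2.659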
+        self.logit_scale = nn.Parameter(torch.ones([]) * 2.659)
+
+    def encode_text(self, input_ids, attention_mask):
+        text_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
+        text_embeds = self.text_projection(self.mean_pooling(text_out, attention_mask))
+        return F.normalize(text_embeds, dim=1)
+
+    def encode_image(self, images):
+        vision_out = self.image_model(pixel_values=images)
+        image_embeds = self.image_projection(vision_out.last_hidden_state[:, 0, :])
+        return F.normalize(image_embeds, dim=1)
+
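+    # Masked mean over token embeddings (the standard sentence-transformers pooling)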
+    def mean_pooling(self, model_output, attention_mask):
+        token_embeddings = model_output.last_hidden_state
+        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
+
+# --- 2. LOAD RESOURCES ---
+print("⏳ Loading resources...")
+device = "cpu"
+
+# Load Model
+model = MultiModalEngine()
+model.load_state_dict(torch.load("flickr8k_best_model_r1_27.pth", map_location=device))
+model.eval()
+
+# Load Index
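+# One row per dataset image; expected shape [N, 256], L2-normalised like the encoder outputs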
+image_embeddings = torch.load("flickr8k_best_index.pt", map_location=device)
+
+# Load Tokenizer & Transforms
+tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
+val_transform = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor(),
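+    # Standard ImageNet normalisation stats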
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+# Load Dataset (standard, non-streaming mode, so result images can be fetched by index)
+print("⏳ Downloading dataset (this may take a minute)...")
+dataset = load_dataset("tsystems/flickr8k", split="train")
+
+print("✅ Server Ready!")
+
+# --- 3. SEARCH LOGIC ---
+def search_text(query):
+    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
+    with torch.no_grad():
+        text_emb = model.encode_text(inputs['input_ids'], inputs['attention_mask'])
+
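+    # text_emb is L2-normalised; if the index is too, this dot product ranks by cosine similarity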
+    scores = text_emb @ image_embeddings.T
+    scores = scores.squeeze()
+    _, indices = torch.topk(scores, 3)
+
+    return [dataset[int(idx)]['image'] for idx in indices]
+
+def search_image(query_img):
+    if query_img is None:
+        return []
+    # Ensure it's a PIL Image (Gradio handles this, but good safety)
+    if not isinstance(query_img, Image.Image):
+        query_img = Image.fromarray(query_img)
+
+    img_tensor = val_transform(query_img).unsqueeze(0)
+    with torch.no_grad():
+        img_emb = model.encode_image(img_tensor)
+
+    scores = img_emb @ image_embeddings.T
+    scores = scores.squeeze()
+    _, indices = torch.topk(scores, 3)
+
+    return [dataset[int(idx)]['image'] for idx in indices]
+
+# --- 4. UI WITH EXAMPLES ---
+with gr.Blocks(title="Flickr8k AI Search", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🔍 AI Super-Search")
+    gr.Markdown("Search for images using **Text** OR using another **Image**.")
+
+    with gr.Tabs():
+        # --- TAB 1: TEXT SEARCH ---
+        with gr.TabItem("Search by Text"):
+            with gr.Row():
+                txt_input = gr.Textbox(label="Type your query", placeholder="e.g. A dog running...")
+                txt_btn = gr.Button("Search", variant="primary")
+
+            txt_gallery = gr.Gallery(label="Top Matches", columns=3, height=300)
+
+            # CLICKABLE TEXT EXAMPLES
+            gr.Examples(
+                examples=[
+                    ["A dog running on grass"],
+                    ["Children playing in the water"],
+                    ["A girl in a pink dress"],
+                    ["A man climbing a rock"]
+                ],
+                inputs=txt_input,     # Clicking populates this box
+                outputs=txt_gallery,  # Results appear here
+                fn=search_text,       # Function to run
+                run_on_click=True,    # Run immediately when clicked
+                label="Try these examples:"
+            )
+
+            txt_btn.click(search_text, inputs=txt_input, outputs=txt_gallery)
+
+        # --- TAB 2: IMAGE SEARCH ---
+        with gr.TabItem("Search by Image"):
+            # 1. Define components first (but don't draw them yet).
+            #    We set render=False so we can place them visually later.
+            img_input = gr.Image(type="pil", label="Upload Source Image", sources=['upload', 'clipboard'], render=False)
+            img_gallery = gr.Gallery(label="Similar Images", columns=3, height=300, render=False)
+
+            # 2. Draw Examples first, so they appear at the very top
+            gr.Examples(
+                examples=[
+                    ["examples/dog.jpg"],
+                    ["examples/beach.jpg"]
+                ],
+                inputs=img_input,
+                outputs=img_gallery,
+                fn=search_image,
+                run_on_click=True,
+                label="Click an image to test:"
+            )
+
+            # 3. Draw the input and button, visually below the examples
+            with gr.Row():
+                img_input.render()  # Now we actually draw the input box
+                img_btn = gr.Button("Find Similar", variant="primary")
+
+            # 4. Draw the gallery, visually at the bottom
+            img_gallery.render()
+
+            # 5. Connect the button
+            img_btn.click(search_image, inputs=img_input, outputs=img_gallery)
+
+if __name__ == "__main__":
+    demo.launch()
examples/beach.jpg ADDED

Git LFS Details

  • SHA256: 9f957fcad6e690f37f99e5bba984de7d8958a4527288298759ee7754002dade9
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
examples/dog.jpg ADDED
flickr8k_best_index.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c5c6aa5489f0d9ce320e37ab0ebbed3e49f25b5f769a1826c4684a363766b5b5
+size 8286845
flickr8k_best_model_r1_27.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d1003c7746967816f6e1d02d58c199fb8c32cd94ae9c45366a38707a79d1c43b
+size 785252523
requirements.txt ADDED
@@ -0,0 +1,77 @@
+aiofiles==24.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.13.2
+aiosignal==1.4.0
+annotated-doc==0.0.4
+annotated-types==0.7.0
+anyio==4.12.0
+attrs==25.4.0
+brotli==1.2.0
+certifi==2025.11.12
+charset-normalizer==3.4.4
+click==8.3.1
+colorama==0.4.6
+datasets==4.4.2
+dill==0.4.0
+fastapi==0.125.0
+ffmpy==1.0.0
+filelock==3.20.1
+frozenlist==1.8.0
+fsspec==2025.10.0
+gradio==6.2.0
+gradio_client==2.0.2
+groovy==0.1.2
+h11==0.16.0
+hf-xet==1.2.0
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==0.36.0
+idna==3.11
+Jinja2==3.1.6
+markdown-it-py==4.0.0
+MarkupSafe==3.0.3
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.7.0
+multiprocess==0.70.18
+networkx==3.6.1
+numpy==2.3.5
+orjson==3.11.5
+packaging==25.0
+pandas==2.3.3
+pillow==12.0.0
+propcache==0.4.1
+pyarrow==22.0.0
+pydantic==2.12.5
+pydantic_core==2.41.5
+pydub==0.25.1
+Pygments==2.19.2
+python-dateutil==2.9.0.post0
+python-multipart==0.0.21
+pytz==2025.2
+PyYAML==6.0.3
+regex==2025.11.3
+requests==2.32.5
+rich==14.2.0
+safehttpx==0.1.7
+safetensors==0.7.0
+semantic-version==2.10.0
+shellingham==1.5.4
+six==1.17.0
+starlette==0.50.0
+sympy==1.14.0
+tokenizers==0.22.1
+tomlkit==0.13.3
+torch==2.9.1
+torchvision==0.24.1
+tqdm==4.67.1
+transformers==4.57.3
+typer==0.20.1
+typer-slim==0.20.1
+typing-inspection==0.4.2
+typing_extensions==4.15.0
+tzdata==2025.3
+urllib3==2.6.2
+uvicorn==0.38.0
+xxhash==3.6.0
+yarl==1.22.0