import functools
import gc
import os

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import DistilBertTokenizer

# NOTE: the star import stays ahead of the explicit project imports so that
# those explicit names still take precedence over anything it exports.
from implement import *
import config as CFG
from main import build_loaders
from CLIP import CLIPModel
@functools.lru_cache(maxsize=1)
def _load_tokenizer():
    """Return the tokenizer named by ``CFG.text_tokenizer``, loading it once.

    Both the embedding pass and every search query need the same tokenizer;
    caching avoids re-reading it on each button press.
    """
    return DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)


def get_image_embeddings(valid_df, model_path):
    """Load the CLIP checkpoint and embed every image served by the loader.

    Args:
        valid_df: Frame handed to ``build_loaders`` (mode ``"valid"``).
        model_path: Path of the ``state_dict`` checkpoint to load.

    Returns:
        A ``(model, embeddings)`` tuple: the model in eval mode on
        ``CFG.device`` and the projected image embeddings of all batches
        concatenated along the first dimension.
    """
    valid_loader = build_loaders(valid_df, _load_tokenizer(), mode="valid")

    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()

    valid_image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_features = model.image_encoder(batch["image"].to(CFG.device))
            valid_image_embeddings.append(model.image_projection(image_features))
    return model, torch.cat(valid_image_embeddings)


def find_matches(query, n=9):
    """Return the image that best matches the text ``query``.

    The query is embedded with the text tower, compared with the precomputed
    image embeddings by cosine similarity, and the best-ranked candidate that
    can actually be read from disk is returned.

    Args:
        query: Free-text search string.
        n: Number of ranked candidates to consider. Lower-ranked candidates
            are only used as a fallback when a better one cannot be read.

    Returns:
        The selected image as an RGB ``numpy`` array.

    Raises:
        gr.Error: If none of the candidate image files could be read.
    """
    encoded_query = _load_tokenizer()([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    # torch.topk raises when k exceeds the number of candidates, so clamp it.
    k = min(n * 5, dot_similarity.shape[-1])
    _, indices = torch.topk(dot_similarity.squeeze(0), k)
    # Keep every fifth hit. Presumably valid_df holds five caption rows per
    # image, so the stride skips repeats of the same file -- TODO confirm.
    matches = [valid_df["image"].values[idx] for idx in indices[::5].tolist()]

    for match in matches:
        bgr = cv2.imread(f"{CFG.image_path}/{match}")
        if bgr is None:
            # cv2.imread signals a missing/unreadable file with None rather
            # than an exception; fall through to the next-best candidate.
            continue
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    raise gr.Error("No matching image could be read from disk.")


# Loaded once at start-up and shared by every query.
_, valid_df = make_train_valid_dfs()
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
# Query-independent, so normalise once here instead of on every search.
image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)


# Only the UI components need to live inside the Blocks context.
with gr.Blocks(css="style.css") as demo:
    with gr.Row():
        textbox = gr.Textbox(label="Enter a query to find matching images using a CLIP model.")
        image = gr.Image(type="numpy")

    button = gr.Button("Press")
    button.click(fn=find_matches, inputs=textbox, outputs=image)


demo.launch(share=True)
|
|