| | import numpy as np |
| |
|
| | import torch |
| | from transformers import GPT2TokenizerFast |
| | from .models import VisionGPT2Model |
| |
|
| | import albumentations as A |
| | from albumentations.pytorch import ToTensorV2 |
| |
|
| | from PIL import Image |
| | import matplotlib.pyplot as plt |
| | from types import SimpleNamespace |
| | import pathlib |
| | from tkinter import filedialog |
| |
|
def download(url: str, filename: str) -> pathlib.Path:
    """Stream the resource at ``url`` to ``filename``, showing a progress bar.

    The payload is first written to a sibling ``<filename>.part`` file and
    only moved onto the final path once the transfer finished, so an
    interrupted download never leaves a truncated file under the real name.

    Args:
        url: Address to fetch; redirects are followed.
        filename: Destination path. ``~`` is expanded and missing parent
            directories are created.

    Returns:
        The absolute path of the downloaded file.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        RuntimeError: The server answered with any other non-200 status.
        requests.RequestException: The connection failed or timed out.
    """
    import functools
    import shutil

    import requests
    from tqdm.auto import tqdm

    manual_hint = (
        "\n Please download the captioner.pt file manually from the link "
        "provided in the README.md file."
    )

    # The context manager returns the connection to the pool even when an
    # error occurs; the timeout keeps a stalled server from hanging the script.
    with requests.get(url, stream=True, allow_redirects=True, timeout=30) as r:
        if r.status_code != 200:
            try:
                r.raise_for_status()
            except requests.HTTPError as err:
                # Same exception type as before, but now carrying the hint
                # that was previously unreachable for 4xx/5xx responses.
                raise requests.HTTPError(f"{err}{manual_hint}", response=r) from err
            raise RuntimeError(
                f"Request to {url} returned status code {r.status_code}{manual_hint}"
            )
        file_size = int(r.headers.get("Content-Length", 0))

        path = pathlib.Path(filename).expanduser().resolve()
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = path.with_name(path.name + ".part")

        desc = "(Unknown total file size)" if file_size == 0 else ""
        # Make the raw urllib3 stream undo any Content-Encoding (gzip/deflate)
        # so the bytes written to disk are the decoded payload.
        r.raw.read = functools.partial(r.raw.read, decode_content=True)
        try:
            with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw:
                with tmp_path.open("wb") as f:
                    shutil.copyfileobj(r_raw, f)
        except BaseException:
            # Do not leave a half-written file behind (also on Ctrl-C).
            tmp_path.unlink(missing_ok=True)
            raise

    tmp_path.replace(path)
    return path
| |
|
def main():
    """Caption a user-selected image with the VisionGPT2 model and display it.

    Loads the weights from ``captioner.pt`` in the current working directory
    (downloading them first when the file does not exist), asks for an image
    through a file dialog, generates a caption and shows the image with the
    caption as the plot title. Returns early if the dialog is cancelled.
    """
    # NOTE(review): these values presumably mirror the architecture the
    # checkpoint was trained with -- confirm against VisionGPT2Model.
    model_config = SimpleNamespace(
        vocab_size=50257,
        embed_dim=768,
        num_heads=12,
        seq_len=1024,
        depth=12,
        attention_dropout=0.1,
        residual_dropout=0.1,
        mlp_ratio=4,
        mlp_dropout=0.1,
        emb_dropout=0.1,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = VisionGPT2Model(model_config).to(device)
    try:
        sd = torch.load("captioner.pt", map_location=device)
    except FileNotFoundError:
        # Only a missing checkpoint triggers the download; a corrupt file or
        # Ctrl-C now surfaces as-is instead of being reported as "not found".
        print("Model not found. Downloading Model ")
        url = "https://drive.usercontent.google.com/download?id=1X51wAI7Bsnrhd2Pa4WUoHIXvvhIcRH7Y&export=download&authuser=0&confirm=t&uuid=ae5c4861-4411-4f81-88cd-66ea30b6fe2b&at=APZUnTWodeDt1upcQVMej2TDcADs%3A1722666079498"
        path = download(url, "captioner.pt")
        sd = torch.load(path, map_location=device)

    model.load_state_dict(sd)
    model.eval()
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    tfms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], always_apply=True),
        ToTensorV2(),
    ])

    test_img: str = filedialog.askopenfilename(
        title="Select an image",
        filetypes=(("jpeg files", "*.jpg"), ("png files", '*.png'), ("all files", "*.*")),
    )
    if not test_img:
        # The dialog yields an empty value when the user cancels it.
        print("No image selected.")
        return

    im = Image.open(test_img).convert("RGB")

    # Generation settings, passed positionally to model.generate below.
    det = True
    temp = 1.0
    max_tokens = 50

    pixels = np.array(im)
    image: torch.Tensor = tfms(image=pixels)['image']
    image = image.unsqueeze(0).to(device)
    # Start the sequence with a single BOS token (a 1x1 int64 tensor).
    seq = torch.full((1, 1), tokenizer.bos_token_id, dtype=torch.long, device=device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        caption = model.generate(image, seq, max_tokens, temp, det)
    # Assumes generate() returns a torch.Tensor of token ids (the original
    # called .numpy() on it) -- TODO confirm. Tensor.numpy() fails for CUDA
    # or grad-tracking tensors, hence detach() + cpu() first.
    caption = tokenizer.decode(caption.detach().cpu().numpy(), skip_special_tokens=True)

    plt.imshow(im)
    plt.title(f"Predicted : {caption}")
    plt.axis('off')
    plt.show()
| |
|
| |
|
# Script entry point: run the captioning demo only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    main()