| import torch |
| import requests |
| import torchvision.transforms as transforms |
| from math import ceil |
| from PIL import Image |
| import matplotlib.pyplot as plt |
|
|
MAX_RESOLUTION = 1024


def get_resize_output_image_size(
    image_size,
    fix_resolution=False,
    max_resolution: int = MAX_RESOLUTION,
    patch_size=32
) -> tuple:
    """Compute the output size for resizing so both sides are patch-aligned.

    The longer side is rounded up to a multiple of ``patch_size`` and capped
    at ``max_resolution``; the shorter side is scaled proportionally and then
    rounded up to a multiple of ``patch_size`` as well.

    Args:
        image_size: (width, height) of the source image.
        fix_resolution: if True, ignore ``image_size`` and return (224, 224).
        max_resolution: upper bound for the longer output side.
        patch_size: both output dimensions are rounded up to a multiple of this.

    Returns:
        (new_width, new_height), preserving the input's orientation.
    """
    if fix_resolution:
        # 224 is already a multiple of the default patch size (32).
        return (224, 224)

    l1, l2 = image_size
    short, long = (l2, l1) if l2 <= l1 else (l1, l2)

    # Round the long side up to a patch multiple, but never exceed the cap.
    requested_new_long = min(ceil(long / patch_size) * patch_size, max_resolution)

    # Scale the short side proportionally, then round it up to a patch
    # multiple.  Because requested_new_long is itself a patch multiple, the
    # rounded short side can never exceed it.
    new_long = requested_new_long
    new_short = ceil(int(requested_new_long * short / long) / patch_size) * patch_size

    return (new_long, new_short) if l2 <= l1 else (new_short, new_long)
|
|
|
|
def preprocess_image(
    image_tensor: torch.Tensor,
    patch_size=32
) -> torch.Tensor:
    """Split a (C, H, W) image tensor into a grid of square patches.

    Returns a tensor of shape (H // patch_size, W // patch_size, C,
    patch_size, patch_size).  H and W are assumed to be multiples of
    patch_size — TODO confirm callers guarantee this.
    """
    row_patches = image_tensor.unfold(1, patch_size, patch_size)
    grid_patches = row_patches.unfold(2, patch_size, patch_size)
    # Move the channel axis behind the two grid axes and make memory contiguous.
    return grid_patches.permute(1, 2, 0, 3, 4).contiguous()
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def get_transform(height, width):
    """Build the preprocessing pipeline for a PIL image.

    Resizes to (height, width), converts to a tensor, and normalizes with
    ImageNet channel statistics.
    """
    return transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
|
|
def convert_image_to_patches(image, patch_size=32) -> torch.Tensor:
    """Resize a PIL image to patch-aligned dimensions and cut it into patches.

    Returns the tensor produced by preprocess_image, i.e. a
    (rows, cols, C, patch_size, patch_size) patch grid.
    """
    source_size = image.size  # PIL reports (width, height)
    target_w, target_h = get_resize_output_image_size(
        source_size, patch_size=patch_size, fix_resolution=False
    )
    # get_transform takes (height, width), hence the swapped order.
    resized_tensor = get_transform(target_h, target_w)(image)
    return preprocess_image(resized_tensor, patch_size=patch_size)
|
|