| import argparse |
| import torch |
| import gradio as gr |
| from torchvision import transforms |
| from runner import MaskGIT |
| import numpy as np |
| import random |
| import torchvision.utils as vutils |
|
|
|
|
class Args(argparse.Namespace):
    """Hard-coded settings used in place of parsed command-line flags.

    Every option is a class attribute, so ``Args()`` yields a ready-to-use
    namespace without any argument parsing.  The meaning of each value is
    defined by the consumer of this namespace (``runner.MaskGIT``); the
    grouping comments below are reading aids, not a specification.
    """

    # Locations; the dataset and log entries are left empty in this script.
    data_folder = ""
    vqgan_folder = "pretrained_maskgit/VQGAN"
    writer_log = ""
    data = ""

    # NOTE(review): presumably the token id reserved for masked positions --
    # confirm against runner.MaskGIT.
    mask_value = 1024
    seed = 1
    channel = 3
    num_workers = 0

    # Counters / optimiser settings (presumably only relevant to training).
    iter = 0
    global_epoch = 0
    lr = 1e-4
    drop_label = 0.1
    resume = True

    device = "cpu"
    # Runs once, when the class body is executed, and echoes the device.
    print(device)

    debug = True
    test_only = False
    is_master = True
    is_multi_gpus = False

    vit_size = "base"
    vit_folder = "pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
    img_size = 256
    # Derived from img_size rather than repeating the literal 256, so the two
    # values cannot drift apart when the image size is edited.
    patch_size = img_size // 16
|
|
|
|
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds PyTorch (CPU and CUDA), NumPy and the standard ``random`` module,
    then sets the cuDNN flags that influence run-to-run reproducibility.

    Args:
        seed: Seed value.  It is truncated to ``int`` first, because UI
            widgets may deliver floats and ``np.random.seed`` rejects them.
            A value that is not strictly positive leaves all generator and
            cuDNN state untouched.
    """
    seed = int(seed)
    if seed <= 0:
        return

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Use the documented cuDNN switches.  An earlier revision assigned to
    # ``torch.backends.cudnn.enable``, which is not a PyTorch setting.
    # ``benchmark = False`` stops cuDNN from auto-tuning (and thereby
    # varying) its algorithm choice while leaving cuDNN itself enabled.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
|
# Build the fixed configuration and construct the model once, at import time;
# synthesize_image reuses these module-level objects on every call.
# NOTE(review): presumably MaskGIT(args) loads the pretrained weights named in
# args (resume=True) -- confirm in runner.MaskGIT.
args = Args()
maskgit = MaskGIT(args)
|
|
|
|
| |
def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
    """Generate ``nb_img`` images for one class label and tile them in a grid.

    Args:
        cls: Class label to condition on; repeated once per requested image.
        sm_temp: Forwarded unchanged to ``maskgit.sample`` as ``sm_temp``.
        w: Forwarded unchanged to ``maskgit.sample`` as ``w``.
        r_temp: Forwarded unchanged to ``maskgit.sample`` as ``r_temp``.
        step: Forwarded to ``maskgit.sample`` as ``step`` (as an ``int``).
        seed: Passed to ``set_seed`` before sampling.
        nb_img: Number of images to generate; must be at least 1.

    Returns:
        A PIL image holding the samples in a grid with two images per row
        and no padding, value-normalised by ``make_grid``.

    Raises:
        ValueError: If ``nb_img`` is smaller than 1.
    """
    # Number/Slider widgets may hand over floats; the label tensor, the
    # sample count and the step count need true integers.
    cls, step, seed, nb_img = int(cls), int(step), int(seed), int(nb_img)
    if nb_img < 1:
        raise ValueError(f"nb_img must be >= 1, got {nb_img}")

    set_seed(seed)
    with torch.no_grad():
        # One copy of the class label per requested image.
        labels = torch.full((nb_img,), cls, dtype=torch.long, device=args.device)
        # Only the first element of what sample() returns is used here.
        gen_sample = maskgit.sample(nb_sample=nb_img, labels=labels, sm_temp=sm_temp, w=w,
                                    randomize="linear", r_temp=r_temp, sched_mode="arccos",
                                    step=step)[0]

    grid = vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True)
    return transforms.ToPILImage()(grid)
|
|
|
|
| |
# One input widget per positional parameter of synthesize_image, in order.
# Integer-valued parameters use precision=0 so the callback receives ints.
app = gr.Interface(
    fn=synthesize_image,
    inputs=[
        gr.Number(31, precision=0),    # cls
        gr.Number(1.3),                # sm_temp
        gr.Number(25),                 # w
        gr.Number(4.5),                # r_temp
        gr.Number(16, precision=0),    # step
        # seed; step=1 makes every integer selectable, and 0 skips seeding
        # because set_seed only acts on strictly positive values.
        gr.Slider(0, 1000, 60, step=1),
        gr.Number(1, minimum=1, maximum=4, precision=0),  # nb_img
    ],
    outputs=gr.Image(),
    title="Image Synthesis using MaskGIT",
)

# NOTE(review): share=True asks Gradio to create a publicly reachable link to
# this app; drop it if the demo should only be served locally.
app.launch(share=True)
|
|
|
|