| import os |
| import numpy as np |
| import torch |
| from contextlib import nullcontext |
| from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker |
| from einops import rearrange |
| from ldm.util import instantiate_from_config |
| from ldm.models.diffusion.ddim import DDIMSampler |
| from omegaconf import OmegaConf |
| from PIL import Image |
| from rich import print |
| from transformers import CLIPImageProcessor |
| from torch import autocast |
| from torchvision import transforms |
|
|
|
|
def load_model_from_config(config, ckpt, device, verbose=False):
    """Instantiate the model described by `config` and load weights from `ckpt`.

    Args:
        config: OmegaConf config whose `model` node is fed to
            `instantiate_from_config`.
        ckpt: path to a PyTorch-Lightning checkpoint file.
        device: device the model is moved to.
        verbose: when True, print missing/unexpected state-dict keys.

    Returns:
        The model on `device`, switched to eval mode.
    """
    print(f'Loading model from {ckpt}')
    checkpoint = torch.load(ckpt, map_location='cpu')
    if 'global_step' in checkpoint:
        print(f'Global Step: {checkpoint["global_step"]}')
    state_dict = checkpoint['state_dict']

    model = instantiate_from_config(config.model)
    # strict=False: fine-tuned checkpoints may add/drop keys relative to the config.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose and missing:
        print('missing keys:')
        print(missing)
    if verbose and unexpected:
        print('unexpected keys:')
        print(unexpected)

    model.to(device)
    model.eval()
    return model
|
|
|
|
def init_model(device, ckpt):
    """Build the models used for inference.

    Returns a dict with:
        'turncam': the compiled zero123 LatentDiffusion model loaded from `ckpt`,
        'nsfw':    the Stable Diffusion safety checker on `device`,
        'clip_fe': the CLIP image processor feeding the safety checker.
    """
    config_path = os.path.join(os.path.dirname(__file__), 'configs/sd-objaverse-finetune-c_concat-256.yaml')
    cfg = OmegaConf.load(config_path)

    print('Instantiating LatentDiffusion...')
    turncam = torch.compile(load_model_from_config(cfg, ckpt, device=device))

    print('Instantiating StableDiffusionSafetyChecker...')
    nsfw = StableDiffusionSafetyChecker.from_pretrained(
        'CompVis/stable-diffusion-safety-checker').to(device)
    clip_fe = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-large-patch14")

    # Tighten the safety-checker thresholds by 20% relative to the defaults.
    nsfw.concept_embeds_weights *= 1.2
    nsfw.special_care_embeds_weights *= 1.2

    return {'turncam': turncam, 'nsfw': nsfw, 'clip_fe': clip_fe}
|
|
@torch.no_grad()
def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision='autocast', ddim_eta=1.0, ddim_steps=75, scale=3.0, h=256, w=256):
    """Sample `n_samples` novel views of `input_im` with DDIM.

    Args:
        model: zero123 LatentDiffusion model.
        sampler: DDIMSampler wrapping `model`.
        input_im: conditioning image tensor, already normalized to [-1, 1].
        xs, ys: per-sample camera angle offsets in degrees (len == n_samples).
        n_samples: batch size of views to generate.
        precision: 'autocast' enables CUDA mixed precision, anything else disables it.
        ddim_eta, ddim_steps, scale: DDIM eta, step count, and CFG guidance scale.
        h, w: output image size in pixels (latents are h//8 x w//8).

    Returns:
        CPU tensor of images in [0, 1], shape (n_samples, C, h, w).
    """
    scope = autocast if precision == 'autocast' else nullcontext
    with scope("cuda"), model.ema_scope():
        # Image conditioning, repeated once per requested view.
        cond_emb = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
        # Per-view camera token: (radians(x), sin(radians(y)), cos(radians(y)), 0).
        pose = [[np.radians(dx), np.sin(np.radians(dy)), np.cos(np.radians(dy)), 0]
                for dx, dy in zip(xs, ys)]
        pose = torch.tensor(np.array(pose))[:, None, :].float().to(cond_emb.device)
        c = model.cc_projection(torch.cat([cond_emb, pose], dim=-1))
        cond = {
            'c_crossattn': [c],
            'c_concat': [model.encode_first_stage(input_im).mode().detach()
                         .repeat(n_samples, 1, 1, 1)],
        }
        # Classifier-free guidance needs an all-zero unconditional branch.
        if scale != 1.0:
            uc = {
                'c_concat': [torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)],
                'c_crossattn': [torch.zeros_like(c).to(c.device)],
            }
        else:
            uc = None

        samples_ddim, _ = sampler.sample(S=ddim_steps,
                                         conditioning=cond,
                                         batch_size=n_samples,
                                         shape=[4, h // 8, w // 8],
                                         verbose=False,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=uc,
                                         eta=ddim_eta,
                                         x_T=None)
        print(samples_ddim.shape)

        decoded = model.decode_first_stage(samples_ddim)
        # Map from [-1, 1] back to [0, 1] and move off the GPU.
        ret_imgs = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0).cpu()
        del cond, c, decoded, samples_ddim, uc, input_im
        torch.cuda.empty_cache()
        return ret_imgs
|
|
@torch.no_grad()
def predict_stage1_gradio(model, raw_im, save_path="", adjust_set=None, device="cuda", ddim_steps=75, scale=3.0):
    """Stage-1 inference: generate up to 12 novel views of `raw_im`.

    Args:
        model: zero123 LatentDiffusion model.
        raw_im: input image (anything `np.asarray` accepts, e.g. a PIL image).
        save_path: if non-empty, each view is also saved there as '<idx>.png'.
        adjust_set: optional subset of view indices (0-11) to generate; None or
            an empty list generates all 12 views. (Default changed from the
            mutable `[]` to `None` — same behavior, avoids the shared-default
            pitfall.)
        device: device for the input tensor.
        ddim_steps: DDIM step count.
        scale: classifier-free guidance scale.

    Returns:
        List of PIL images, ordered by ascending view index.
    """
    # Normalize to [-1, 1], shape (1, C, H, W).
    input_im_init = np.asarray(raw_im, dtype=np.float32) / 255.0
    input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device)
    input_im = input_im * 2 - 1

    # Camera offsets for the 12 stage-1 views: 4 at the input elevation,
    # 4 at +30 deg and 4 at -30 deg; azimuths step by 90 deg, with a 30 deg
    # phase shift for the two tilted rings.
    delta_x_1_8 = [0] * 4 + [30] * 4 + [-30] * 4
    delta_y_1_8 = [90 * i for i in range(4)] + [30 + 90 * i for i in range(4)] * 2

    sampler = DDIMSampler(model)
    if adjust_set:
        x_samples_ddims_8 = sample_model_batch(
            model, sampler, input_im,
            [delta_x_1_8[i] for i in adjust_set],
            [delta_y_1_8[i] for i in adjust_set],
            n_samples=len(adjust_set), ddim_steps=ddim_steps, scale=scale)
    else:
        x_samples_ddims_8 = sample_model_batch(
            model, sampler, input_im, delta_x_1_8, delta_y_1_8,
            n_samples=len(delta_x_1_8), ddim_steps=ddim_steps, scale=scale)

    ret_imgs = []
    sample_idx = 0  # position in the (possibly subset) sample batch
    for stage1_idx in range(len(delta_x_1_8)):
        if adjust_set and stage1_idx not in adjust_set:
            continue
        x_sample = 255.0 * rearrange(x_samples_ddims_8[sample_idx].numpy(), 'c h w -> h w c')
        out_image = Image.fromarray(x_sample.astype(np.uint8))
        ret_imgs.append(out_image)
        if save_path:
            out_image.save(os.path.join(save_path, '%d.png' % (stage1_idx)))
        sample_idx += 1

    del x_samples_ddims_8
    del sampler
    torch.cuda.empty_cache()
    return ret_imgs
|
|
def infer_stage_2(model, save_path_stage1, save_path_stage2, delta_x_2, delta_y_2, indices, device, ddim_steps=75, scale=3.0):
    """Stage-2 inference: refine each stage-1 view with nearby camera offsets.

    For every index in `indices`, loads '<save_path_stage1>/<idx>.png',
    samples len(delta_x_2) views offset by (delta_x_2[j], delta_y_2[j]) degrees,
    and writes them to '<save_path_stage2>/<idx>_<j>.png'.

    Args:
        model: zero123 LatentDiffusion model.
        save_path_stage1: directory holding stage-1 outputs ('<idx>.png').
        save_path_stage2: existing directory to write stage-2 outputs into.
        delta_x_2, delta_y_2: per-view camera angle offsets (equal length).
        indices: stage-1 view indices to process.
        device: device for the input tensors.
        ddim_steps, scale: DDIM step count and guidance scale.
    """
    # The sampler only wraps the model, so build it once instead of per image.
    sampler = DDIMSampler(model)
    for stage1_idx in indices:
        stage1_image_path = os.path.join(save_path_stage1, '%d.png' % (stage1_idx))

        # Context manager closes the PNG file handle promptly (the original
        # left it open until GC). The dtype conversion makes a writable copy.
        with Image.open(stage1_image_path) as raw_im:
            input_im_init = np.asarray(raw_im, dtype=np.float32)
        # Snap near-white pixels to pure white — presumably cleaning up
        # background values perturbed by the PNG round trip.
        input_im_init[input_im_init >= 253.0] = 255.0
        input_im_init = input_im_init / 255.0
        input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device)
        input_im = input_im * 2 - 1

        x_samples_ddims_stage2 = sample_model_batch(
            model, sampler, input_im, delta_x_2, delta_y_2,
            n_samples=len(delta_x_2), ddim_steps=ddim_steps, scale=scale)
        for stage2_idx in range(len(delta_x_2)):
            x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx].numpy(), 'c h w -> h w c')
            Image.fromarray(x_sample_stage2.astype(np.uint8)).save(
                os.path.join(save_path_stage2, '%d_%d.png' % (stage1_idx, stage2_idx)))

        # Free GPU memory between images to keep peak usage down.
        del input_im
        del x_samples_ddims_stage2
        torch.cuda.empty_cache()
|
|
def zero123_infer(model, input_dir_path, start_idx=0, end_idx=12, indices=None, device="cuda", ddim_steps=75, scale=3.0):
    """Run stage-2 refinement over the stage-1 views in `<input_dir_path>/stage1_8`.

    Outputs go to `<input_dir_path>/stage2_8` (created if absent). The views
    processed are `indices` when given and non-empty, otherwise
    range(start_idx, end_idx).
    """
    stage1_dir = os.path.join(input_dir_path, "stage1_8")
    stage2_dir = os.path.join(input_dir_path, "stage2_8")
    os.makedirs(stage2_dir, exist_ok=True)

    # +/-10 degree offsets in each of the two camera angles around every
    # stage-1 view.
    delta_x_2 = [-10, 10, 0, 0]
    delta_y_2 = [0, 0, -10, 10]

    if indices:
        view_indices = indices
    else:
        view_indices = list(range(start_idx, end_idx))
    infer_stage_2(model, stage1_dir, stage2_dir, delta_x_2, delta_y_2,
                  indices=view_indices, device=device,
                  ddim_steps=ddim_steps, scale=scale)
|
|