| import argparse |
| import sys |
| import os |
|
|
| from typing import Dict, Optional, Tuple, List |
| from omegaconf import OmegaConf |
| from PIL import Image |
| from dataclasses import dataclass |
| from collections import defaultdict |
| import torch |
| import torch.utils.checkpoint |
| from torchvision.utils import make_grid, save_image |
| from accelerate.utils import set_seed |
| from tqdm.auto import tqdm |
| import torch.nn.functional as F |
| from einops import rearrange |
| from rembg import remove, new_session |
| import pdb |
| from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline |
| from econdataset import SMPLDataset |
| from reconstruct import ReMesh |
# Options for the CUDA execution provider used by the rembg session below.
_CUDA_PROVIDER_OPTIONS = {
    'device_id': 0,
    'arena_extend_strategy': 'kSameAsRequested',
    'gpu_mem_limit': 8 * 1024 ** 3,  # 8 GiB, expressed in bytes
    'cudnn_conv_algo_search': 'HEURISTIC',
}

# Provider list handed to rembg's ``new_session``.
providers = [('CUDAExecutionProvider', _CUDA_PROVIDER_OPTIONS)]

# One shared rembg session, created at import time and reused for every
# ``remove(...)`` call in ``run_inference``.
session = new_session(providers=providers)

# dtype the diffusion pipeline weights are loaded with.
weight_dtype = torch.float16
def tensor_to_numpy(tensor):
    """Convert a channel-first image tensor into a channel-last ``uint8`` array.

    Values are scaled by 255, rounded by adding 0.5 and clamped to
    ``[0, 255]`` before the cast, so inputs are expected to be roughly in
    ``[0, 1]`` (anything outside saturates).

    Args:
        tensor: A 3-D tensor; the ``permute(1, 2, 0)`` call moves its first
            axis to the end (presumably ``(C, H, W)`` -> ``(H, W, C)``).

    Returns:
        A NumPy ``uint8`` array on the CPU with the first axis moved last.
    """
    # detach() lets tensors that still track gradients be converted;
    # Tensor.numpy() raises a RuntimeError for them otherwise.
    # mul() allocates a fresh tensor, so the in-place ops that follow never
    # touch the caller's data.
    scaled = tensor.detach().mul(255).add_(0.5).clamp_(0, 255)
    return scaled.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
|
|
|
@dataclass
class TestConfig:
    """Schema for the inference configuration.

    The ``__main__`` block turns this class into an OmegaConf structured
    schema and merges the loaded YAML/CLI configuration into it, so field
    names, order and types here define what a config may contain.
    """

    # --- model / data -----------------------------------------------------
    pretrained_model_name_or_path: str  # passed to ``from_pretrained``
    revision: Optional[str]
    validation_dataset: Dict  # expanded as kwargs into ``SingleImageDataset``
    save_dir: str  # forwarded to ``run_inference`` as its output root
    seed: Optional[int]  # ``None`` leaves the run unseeded
    validation_batch_size: int
    dataloader_num_workers: int

    # --- runtime ----------------------------------------------------------
    save_mode: str  # ``run_inference`` handles 'concat' and 'rgb'
    local_rank: int

    # --- pipeline ---------------------------------------------------------
    pipe_kwargs: Dict
    pipe_validation_kwargs: Dict  # extra kwargs for every pipeline call
    unet_from_pretrained_kwargs: Dict
    validation_guidance_scales: float  # used as a single guidance scale
    validation_grid_nrow: int

    num_views: int
    enable_xformers_memory_efficient_attention: bool
    with_smpl: Optional[bool]  # selects the SMPL-conditioned dataset/inputs

    # --- reconstruction ---------------------------------------------------
    recon_opt: Dict  # options handed to ``ReMesh``
|
|
|
|
def convert_to_numpy(tensor):
    """Turn a channel-first image tensor into a channel-last ``uint8`` array.

    The tensor is multiplied by 255, offset by 0.5 (so the integer cast
    rounds to nearest) and clamped to ``[0, 255]``; its first axis is then
    moved to the end and the result is copied to the CPU as ``uint8``.

    Args:
        tensor: A 3-D tensor, presumably ``(C, H, W)`` with values in
            ``[0, 1]``; out-of-range values saturate at 0 or 255.

    Returns:
        A NumPy ``uint8`` array with the input's first axis moved last.
    """
    # Tensor.numpy() refuses tensors that require grad, so drop the autograd
    # link first; this is a no-op for ordinary inference outputs.
    values = tensor.detach()
    # mul() returns a new tensor, which makes the following in-place ops
    # safe with respect to the caller's tensor.
    values = values.mul(255).add_(0.5).clamp_(0, 255)
    values = values.permute(1, 2, 0)
    return values.to("cpu", torch.uint8).numpy()
|
|
def convert_to_pil(tensor):
    """Return ``tensor`` as a PIL image.

    The tensor is first converted with ``convert_to_numpy`` and the
    resulting array is wrapped with ``Image.fromarray``.
    """
    pixels = convert_to_numpy(tensor)
    return Image.fromarray(pixels)
|
|
# NOTE: this definition replaces the ``save_image`` imported from
# ``torchvision.utils`` at the top of the file; later calls in this module
# resolve to this function.
def save_image(tensor, fp):
    """Write ``tensor`` to ``fp`` as an image and return the pixel array.

    Args:
        tensor: Image tensor accepted by ``convert_to_numpy``.
        fp: Destination passed through to ``save_image_numpy``.

    Returns:
        The array produced by ``convert_to_numpy`` that was written out.
    """
    pixels = convert_to_numpy(tensor)
    save_image_numpy(pixels, fp)
    return pixels
|
|
def save_image_numpy(ndarr, fp):
    """Save the image array ``ndarr`` to ``fp`` through PIL."""
    Image.fromarray(ndarr).save(fp)
|
|
def _save_concat_grid(cond_image, normals_pred, images_pred, first_idx, num_views, out_filename):
    """Save one single-row grid: the conditioning image, then (color, normal) per view."""
    tiles = [cond_image]
    for idx in range(first_idx, first_idx + num_views):
        tiles.append(images_pred[idx])
        tiles.append(normals_pred[idx])
    grid = make_grid(torch.stack(tiles, dim=0), nrow=len(tiles), padding=0, value_range=(0, 1))
    save_image(grid, out_filename)


def _collect_view_maps(cond_image, normals_pred, images_pred, first_idx, num_views):
    """Gather the per-view color and normal maps of one scene for reconstruction."""
    normals, colors = [], []
    for j in range(num_views):
        idx = first_idx + j
        normal = normals_pred[idx]
        # View 0 reuses the scene's own conditioning image instead of the
        # generated color image.
        color = cond_image if j == 0 else images_pred[idx]
        if j in (3, 4):
            # Views 3 and 4 are mirrored along the last axis.
            # NOTE(review): presumably these views are produced flipped - confirm.
            normal = torch.flip(normal, dims=[2])
            color = torch.flip(color, dims=[2])

        colors.append(color)
        if j == 6:
            # View 6 only contributes a 256x256 normal patch (pasted below).
            # NOTE(review): presumably the close-up/face view - confirm.
            normal = F.interpolate(
                normal.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False
            ).squeeze(0)
        normals.append(normal)

    # Overwrite a fixed 256x256 window of the first normal map with the last
    # collected normal map (the resized view-6 patch when there are 7 views).
    normals[0][:, :256, 256:512] = normals[-1]
    return colors, normals


def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
    """Run the multi-view pipeline on every batch and save or reconstruct each scene.

    For each batch the inputs are duplicated (first half is paired with the
    normal prompt embeddings, second half with the color ones), pushed
    through ``pipeline`` and split back into normal and color predictions.
    What happens next depends on ``cfg.save_mode``:

    * ``'concat'``: one PNG per scene is written under ``save_dir`` with the
      conditioning image followed by the color/normal pair of every view.
    * ``'rgb'``: the first six color and normal maps are background-removed
      with rembg and passed, with the scene's pose, to
      ``carving.optimize_case``.

    Any other ``save_mode`` runs the pipeline but produces no output.

    Args:
        dataloader: Iterable of batches providing ``imgs_in``, ``filename``,
            ``normal_prompt_embeddings``, ``color_prompt_embeddings`` and,
            when ``cfg.with_smpl`` is set, ``smpl_imgs_in``.
        econdata: Indexable dataset returning the pose data for a sample.
        pipeline: Diffusion pipeline; its call result exposes ``.images``.
        carving: Object exposing ``optimize_case(scene, pose, colors, normals)``.
        cfg: Inference configuration.
        save_dir: Root directory for ``'concat'`` outputs.
    """
    pipeline.set_progress_bar_config(disable=True)

    generator = None
    if cfg.seed is not None:
        generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)

    guidance_scale = cfg.validation_guidance_scales

    # Running count of scenes handled so far; used to index ``econdata`` per
    # sample rather than per batch, which matters for batch sizes above 1.
    # Assumes ``econdata`` is ordered like the (unshuffled) dataloader - confirm.
    samples_seen = 0

    # Wrapping the dataloader itself lets tqdm display a total when available.
    for batch in tqdm(dataloader):
        # First view of every sample: the conditioning image. Only the
        # current batch is kept; nothing is accumulated across batches.
        cond_images = batch['imgs_in'][:, 0]

        imgs_in = torch.cat([batch['imgs_in']] * 2, dim=0)
        num_views = imgs_in.shape[1]
        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")
        if cfg.with_smpl:
            smpl_in = torch.cat([batch['smpl_imgs_in']] * 2, dim=0)
            smpl_in = rearrange(smpl_in, "B Nv C H W -> (B Nv) C H W")
        else:
            smpl_in = None

        prompt_embeddings = torch.cat(
            [batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']], dim=0
        )
        prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")

        # NOTE(review): the source's indentation was lost; post-processing is
        # kept inside the autocast context - confirm against the original.
        with torch.autocast("cuda"):
            unet_out = pipeline(
                imgs_in, None, prompt_embeds=prompt_embeddings,
                dino_feature=None, smpl_in=smpl_in,
                generator=generator, guidance_scale=guidance_scale, output_type='pt', num_images_per_prompt=1,
                **cfg.pipe_validation_kwargs
            )

            out = unet_out.images
            bsz = out.shape[0] // 2

            # First half: normal maps; second half: color images.
            normals_pred = out[:bsz]
            images_pred = out[bsz:]
            num_scenes = bsz // num_views

            if cfg.save_mode == 'concat':
                cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}-seed{cfg.seed}-smpl-{cfg.with_smpl}")
                os.makedirs(cur_dir, exist_ok=True)
                for i in range(num_scenes):
                    scene = batch['filename'][i].split('.')[0]
                    _save_concat_grid(
                        cond_images[i].to(out.device),
                        normals_pred,
                        images_pred,
                        i * num_views,
                        num_views,
                        f"{cur_dir}/{scene}.png",
                    )
            elif cfg.save_mode == 'rgb':
                for i in range(num_scenes):
                    scene = batch['filename'][i].split('.')[0]

                    colors, normals = _collect_view_maps(
                        cond_images[i].to(out.device),
                        normals_pred,
                        images_pred,
                        i * num_views,
                        num_views,
                    )

                    # Only the first six views are handed to reconstruction.
                    colors = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]]
                    normals = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]]
                    pose = econdata[samples_seen + i]
                    carving.optimize_case(scene, pose, colors, normals)
                    torch.cuda.empty_cache()

        samples_seen += num_scenes
| |
| |
|
|
def load_pshuman_pipeline(cfg):
    """Load the multi-view diffusion pipeline described by ``cfg``.

    Args:
        cfg: Configuration providing ``pretrained_model_name_or_path`` and,
            optionally, ``enable_xformers_memory_efficient_attention``.

    Returns:
        The ``StableUnCLIPImg2ImgPipeline`` loaded with ``weight_dtype``
        weights, moved to CUDA when a CUDA device is available.
    """
    pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
        cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype
    )
    # Honour the config switch instead of enabling xformers unconditionally.
    # Configs without the field keep the previous always-on behaviour.
    if getattr(cfg, 'enable_xformers_memory_efficient_attention', True):
        pipeline.unet.enable_xformers_memory_efficient_attention()
    if torch.cuda.is_available():
        pipeline.to('cuda')
    return pipeline
|
|
def main(
    cfg: TestConfig
):
    """Build the pipeline, datasets and reconstructor, then run inference.

    Args:
        cfg: Merged inference configuration.
    """
    if cfg.seed is not None:
        set_seed(cfg.seed)

    pipeline = load_pshuman_pipeline(cfg)

    # The dataset implementation depends on whether SMPL inputs are used.
    if cfg.with_smpl:
        from mvdiffusion.data.testdata_with_smpl import SingleImageDataset
    else:
        from mvdiffusion.data.single_image_dataset import SingleImageDataset

    image_dataset = SingleImageDataset(**cfg.validation_dataset)
    image_loader = torch.utils.data.DataLoader(
        image_dataset,
        batch_size=cfg.validation_batch_size,
        shuffle=False,
        num_workers=cfg.dataloader_num_workers,
    )

    # The SMPL dataset reads from the same image directory as the image dataset.
    smpl_params = {
        'image_dir': image_dataset.root_dir,
        'seg_dir': None,
        'colab': False,
        'has_det': True,
        'hps_type': 'pixie',
    }
    econdata = SMPLDataset(smpl_params, device='cuda')

    carving = ReMesh(cfg.recon_opt, econ_dataset=econdata)
    run_inference(image_loader, econdata, pipeline, carving, cfg, cfg.save_dir)
| |
|
|
if __name__ == '__main__':
    # Only --config is declared; every other CLI token is collected and
    # forwarded to ``load_config`` as ``cli_args``.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--config', type=str, required=True)
    known_args, cli_overrides = arg_parser.parse_known_args()

    from utils.misc import load_config

    loaded_cfg = load_config(known_args.config, cli_args=cli_overrides)
    # Merge the loaded values into the TestConfig structured schema.
    typed_schema = OmegaConf.structured(TestConfig)
    main(OmegaConf.merge(typed_schema, loaded_cfg))
|
|