| | import torch |
| | from torch import nn |
| | from nsr.triplane import Triplane_fg_bg_plane |
| | |
| | from vit.vit_triplane import Triplane, ViTTriplaneDecomposed |
| | import argparse |
| | import inspect |
| | import dnnlib |
| | from guided_diffusion import dist_util |
| |
|
| | from pdb import set_trace as st |
| |
|
| | import vit.vision_transformer as vits |
| | from guided_diffusion import logger |
| | from .confnet import ConfNet |
| |
|
| | from ldm.modules.diffusionmodules.model import Encoder, MVEncoder, MVEncoderGS |
| | from ldm.modules.diffusionmodules.mv_unet import MVUNet, LGM_MVEncoder |
| |
|
| | |
| |
|
| | |
| |
|
| |
|
class AE(torch.nn.Module):
    """Image-to-triplane auto-encoder.

    Wraps an image encoder (DINO v1/v2 ViT, frozen CLIP visual backbone, or
    an SD-style CNN encoder) together with a triplane-based decoder behind a
    single ``nn.Module`` so every operation can be routed through
    :meth:`forward` (required for DDP, which only hooks ``forward``).

    Args:
        encoder: backbone that maps images to latent tokens.
        decoder: module exposing ``vit_decode`` / ``triplane_decode`` etc.
        img_size: rendering/decoding resolution passed to the decoder.
        encoder_cls_token: keep the CLS token in encoder outputs.
        decoder_cls_token: keep the CLS token for the decoder path.
        preprocess: optional input preprocessing callable (CLIP transform).
        use_clip: route encoding through the CLIP visual transformer.
        dino_version: 'v1', 'v2', or an 'sd'/'dit' variant string that
            selects the encode/decode code paths.
        clip_dtype: dtype of the (frozen) CLIP model, stored when use_clip.
        no_dim_up_mlp: disable the encoder->decoder dimension-up projection.
        dim_up_mlp_as_func: apply the dim-up MLP inside the decoder instead
            of before it.
        uvit_skip_encoder: add zero-initialised U-ViT skip connections to
            the second half of the encoder blocks.
        confnet: optional confidence-map network.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 img_size,
                 encoder_cls_token,
                 decoder_cls_token,
                 preprocess,
                 use_clip,
                 dino_version='v1',
                 clip_dtype=None,
                 no_dim_up_mlp=False,
                 dim_up_mlp_as_func=False,
                 uvit_skip_encoder=False,
                 confnet=None) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.img_size = img_size
        self.encoder_cls_token = encoder_cls_token
        self.decoder_cls_token = decoder_cls_token
        self.use_clip = use_clip
        self.dino_version = dino_version
        self.confnet = confnet

        # BUG FIX: these two attributes must be initialised *before* the
        # use_clip branch below. Previously they were assigned at the very
        # end of __init__, so (a) the log f-string below raised
        # AttributeError on self.dim_up_mlp_as_func, and (b) the
        # nn.Linear created for dim_up_mlp was unconditionally clobbered
        # back to None.
        self.dim_up_mlp = None
        self.dim_up_mlp_as_func = dim_up_mlp_as_func

        if self.dino_version == 'v2':
            # Drop unused mask tokens so DDP sees no unused parameters.
            self.encoder.mask_token = None
            self.decoder.vit_decoder.mask_token = None

        if 'sd' not in self.dino_version:

            self.uvit_skip_encoder = uvit_skip_encoder
            if uvit_skip_encoder:
                logger.log(
                    f'enables uvit: length of vit_encoder.blocks: {len(self.encoder.blocks)}'
                )
                # U-ViT style: the second half of the encoder blocks gets a
                # linear that fuses the current activation with a stored
                # skip activation from the first half.
                for blk in self.encoder.blocks[len(self.encoder.blocks) // 2:]:
                    blk.skip_linear = nn.Linear(2 * self.encoder.embed_dim,
                                                self.encoder.embed_dim)

                    # Zero-init so each skip connection starts as identity.
                    nn.init.constant_(blk.skip_linear.weight, 0)
                    if isinstance(
                            blk.skip_linear,
                            nn.Linear) and blk.skip_linear.bias is not None:
                        nn.init.constant_(blk.skip_linear.bias, 0)
            else:
                logger.log(f'disable uvit')
        else:
            if 'dit' not in self.dino_version:
                # SD-style decoder: strip ViT components it never uses so
                # DDP does not complain about unused parameters.
                self.decoder.vit_decoder.cls_token = None
                self.decoder.vit_decoder.patch_embed.proj = nn.Identity()
                self.decoder.triplane_decoder.planes = None
                self.decoder.vit_decoder.mask_token = None

        if self.use_clip:
            self.clip_dtype = clip_dtype  # dtype of the frozen CLIP model

        else:
            # Project encoder tokens up to the decoder width when they
            # differ (unless explicitly disabled).
            if not no_dim_up_mlp and self.encoder.embed_dim != self.decoder.vit_decoder.embed_dim:
                self.dim_up_mlp = nn.Linear(
                    self.encoder.embed_dim,
                    self.decoder.vit_decoder.embed_dim)
                logger.log(
                    f"dim_up_mlp: {self.encoder.embed_dim} -> {self.decoder.vit_decoder.embed_dim}, as_func: {self.dim_up_mlp_as_func}"
                )
            else:
                logger.log('ignore dim_up_mlp: ', no_dim_up_mlp)

        self.preprocess = preprocess

        torch.cuda.empty_cache()

    def encode(self, *args, **kwargs):
        """Dispatch encoding to the configured backbone variant."""
        if not self.use_clip:
            if self.dino_version == 'v1':
                latent = self.encode_dinov1(*args, **kwargs)
            elif self.dino_version == 'v2':
                if self.uvit_skip_encoder:
                    latent = self.encode_dinov2_uvit(*args, **kwargs)
                else:
                    latent = self.encode_dinov2(*args, **kwargs)
            else:
                # SD-style CNN encoder: plain call (kwargs intentionally
                # not forwarded, matching the original behaviour).
                latent = self.encoder(*args)

        else:
            latent = self.encode_clip(*args, **kwargs)

        return latent

    def encode_dinov1(self, x):
        """Encode with a DINO-v1 ViT; optionally drop the CLS token."""
        x = self.encoder.prepare_tokens(x)
        for blk in self.encoder.blocks:
            x = blk(x)
        x = self.encoder.norm(x)
        if not self.encoder_cls_token:
            return x[:, 1:]

        return x

    def encode_dinov2(self, x):
        """Encode with a DINO-v2 ViT; optionally drop the CLS token."""
        x = self.encoder.prepare_tokens_with_masks(x, masks=None)
        for blk in self.encoder.blocks:
            x = blk(x)
        x_norm = self.encoder.norm(x)

        if not self.encoder_cls_token:
            return x_norm[:, 1:]

        return x_norm

    def encode_dinov2_uvit(self, x):
        """DINO-v2 encoding with U-ViT skip connections.

        The first half of the blocks stores activations; the second half
        fuses them back in via the zero-initialised ``skip_linear`` layers
        created in ``__init__``.
        """
        x = self.encoder.prepare_tokens_with_masks(x, masks=None)

        skips = [x]

        # First half (minus the middle block): run and remember outputs.
        for blk in self.encoder.blocks[0:len(self.encoder.blocks) // 2 - 1]:
            x = blk(x)
            skips.append(x)

        # Middle block: run without storing a skip.
        for blk in self.encoder.blocks[len(self.encoder.blocks) // 2 -
                                       1:len(self.encoder.blocks) // 2]:
            x = blk(x)

        # Second half: fuse the stored skips (LIFO) before each block.
        for blk in self.encoder.blocks[len(self.encoder.blocks) // 2:]:
            x = x + blk.skip_linear(torch.cat(
                [x, skips.pop()], dim=-1))  # long skip connections
            x = blk(x)

        x_norm = self.encoder.norm(x)

        # NOTE(review): this checks decoder_cls_token while the other
        # encode_* paths check encoder_cls_token — kept as-is; confirm
        # intentional.
        if not self.decoder_cls_token:
            return x_norm[:, 1:]

        return x_norm

    def encode_clip(self, x):
        """Encode with a CLIP visual transformer; drops the CLS token."""
        x = self.encoder.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Prepend the class embedding token.
        x = torch.cat([
            self.encoder.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ],
                      dim=1)
        x = x + self.encoder.positional_embedding.to(x.dtype)
        x = self.encoder.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND (CLIP transformer layout)
        x = self.encoder.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.encoder.ln_post(x[:, 1:, :])  # drop CLS token

        return x

    def decode_wo_triplane(self, latent, c=None, img_size=None):
        """Run the ViT decoder stage only (no triplane rendering)."""
        if img_size is None:
            img_size = self.img_size

        if self.dim_up_mlp is not None:
            if not self.dim_up_mlp_as_func:
                latent = self.dim_up_mlp(latent)
            else:
                # Let the decoder apply the projection internally.
                return self.decoder.vit_decode(
                    latent, img_size,
                    dim_up_mlp=self.dim_up_mlp)

        return self.decoder.vit_decode(latent, img_size, c=c)

    def decode(self, latent, c, img_size=None, return_raw_only=False):
        """ViT decode followed by triplane rendering under camera ``c``."""
        latent = self.decode_wo_triplane(latent, img_size=img_size, c=c)
        return self.decoder.triplane_decode(latent, c)

    def decode_after_vae_no_render(
        self,
        ret_dict,
        img_size=None,
    ):
        """Decode a VAE-reparameterised latent without rendering."""
        if img_size is None:
            img_size = self.img_size

        assert self.dim_up_mlp is None

        latent = self.decoder.vit_decode_backbone(ret_dict, img_size)
        ret_dict = self.decoder.vit_decode_postprocess(latent, ret_dict)
        return ret_dict

    def decode_after_vae(
            self,
            ret_dict,
            c,
            img_size=None,
            return_raw_only=False):
        """Decode a VAE-reparameterised latent and render under ``c``."""
        ret_dict = self.decode_after_vae_no_render(ret_dict, img_size)
        return self.decoder.triplane_decode(ret_dict, c)

    def decode_confmap(self, img):
        """Predict a confidence map for ``img`` via the ConfNet."""
        assert self.confnet is not None
        return self.confnet(img)

    def encode_decode(self, img, c, return_raw_only=False):
        """Full round trip: encode ``img``, render under ``c``."""
        latent = self.encode(img)
        pred = self.decode(latent, c, return_raw_only=return_raw_only)
        if self.confnet is not None:
            pred.update({
                'conf_sigma': self.decode_confmap(img)
            })

        return pred

    def forward(self,
                img=None,
                c=None,
                latent=None,
                behaviour='enc_dec',
                coordinates=None,
                directions=None,
                return_raw_only=False,
                *args,
                **kwargs):
        """wrap all operations inside forward() for DDP use.
        """

        if behaviour == 'enc_dec':
            pred = self.encode_decode(img, c, return_raw_only=return_raw_only)
            return pred

        elif behaviour == 'enc':
            latent = self.encode(img)
            return latent

        elif behaviour == 'dec':
            assert latent is not None
            pred: dict = self.decode(latent,
                                     c,
                                     self.img_size,
                                     return_raw_only=return_raw_only)
            return pred

        elif behaviour == 'dec_wo_triplane':
            assert latent is not None
            # BUG FIX: self.img_size was passed positionally and therefore
            # bound to the ``c`` parameter of decode_wo_triplane; pass it
            # by keyword so ``c`` stays None.
            pred: dict = self.decode_wo_triplane(latent,
                                                 img_size=self.img_size)
            return pred

        elif behaviour == 'enc_dec_wo_triplane':
            latent = self.encode(img)
            pred: dict = self.decode_wo_triplane(latent,
                                                 img_size=self.img_size,
                                                 c=c)
            return pred

        elif behaviour == 'encoder_vae':
            latent = self.encode(img)
            ret_dict = self.decoder.vae_reparameterization(latent, True)
            return ret_dict

        elif behaviour == 'decode_after_vae_no_render':
            pred: dict = self.decode_after_vae_no_render(latent, self.img_size)
            return pred

        elif behaviour == 'decode_after_vae':
            pred: dict = self.decode_after_vae(latent, c, self.img_size)
            return pred

        elif behaviour == 'triplane_dec':
            assert latent is not None
            pred: dict = self.decoder.triplane_decode(
                latent, c, return_raw_only=return_raw_only, **kwargs)

        elif behaviour == 'triplane_decode_grid':
            assert latent is not None
            pred: dict = self.decoder.triplane_decode_grid(
                latent, **kwargs)

        elif behaviour == 'vit_postprocess_triplane_dec':
            assert latent is not None
            latent = self.decoder.vit_decode_postprocess(
                latent)
            pred: dict = self.decoder.triplane_decode(
                latent, c)

        elif behaviour == 'triplane_renderer':
            assert latent is not None
            pred: dict = self.decoder.triplane_renderer(
                latent, coordinates, directions)

        elif behaviour == 'get_rendering_kwargs':
            pred = self.decoder.triplane_decoder.rendering_kwargs

        else:
            # Previously an unknown behaviour fell through to an
            # UnboundLocalError on ``pred``; fail explicitly instead.
            raise ValueError(f'unknown behaviour: {behaviour}')

        return pred
| |
|
| |
|
class AE_CLIPEncoder(AE):
    """AE variant constructed around a CLIP visual backbone."""

    def __init__(self, encoder, decoder, img_size, cls_token) -> None:
        # BUG FIX: the previous call forwarded only four positional
        # arguments to AE.__init__, which requires seven, so instantiation
        # always raised a TypeError. Map cls_token onto both encoder and
        # decoder cls-token flags, use no preprocessing, and enable the
        # CLIP path (this subclass exists for the CLIP encoder).
        super().__init__(encoder,
                         decoder,
                         img_size,
                         encoder_cls_token=cls_token,
                         decoder_cls_token=cls_token,
                         preprocess=None,
                         use_clip=True)
| |
|
| |
|
class AE_with_Diffusion(torch.nn.Module):
    """Bundles a trained auto-encoder with a denoising (diffusion) model.

    Like AE, all operations are dispatched through forward() with a
    ``behaviour`` switch so the module works under DDP.
    """

    def __init__(self, auto_encoder, denoise_model) -> None:
        super().__init__()
        self.auto_encoder = auto_encoder
        self.denoise_model = denoise_model

    def forward(self,
                img,
                c,
                behaviour='enc_dec',
                latent=None,
                *args,
                **kwargs):
        """Dispatch on ``behaviour``: 'enc_dec', 'enc', 'dec', 'denoise'."""
        if behaviour == 'enc_dec':
            pred = self.auto_encoder(img, c)
            return pred
        elif behaviour == 'enc':
            latent = self.auto_encoder.encode(img)
            # Project to decoder width here so the diffusion model sees
            # the same latent space the decoder consumes.
            if self.auto_encoder.dim_up_mlp is not None:
                latent = self.auto_encoder.dim_up_mlp(latent)
            return latent
        elif behaviour == 'dec':
            assert latent is not None
            # BUG FIX: img_size lives on the wrapped auto-encoder; this
            # wrapper has no img_size attribute, so self.img_size raised
            # AttributeError.
            pred: dict = self.auto_encoder.decode(
                latent, c, self.auto_encoder.img_size)
            return pred
        elif behaviour == 'denoise':
            assert latent is not None
            pred: dict = self.denoise_model(*args, **kwargs)
            return pred
| |
|
| |
|
def eg3d_options_default():
    """Return the default EG3D generator options as an EasyDict."""
    opts = dnnlib.EasyDict()
    opts.cbase = 32768  # capacity base
    opts.cmax = 512  # max channels per layer
    opts.map_depth = 2  # mapping-network depth
    opts.g_class_name = 'nsr.triplane.TriPlaneGenerator'
    opts.g_num_fp16_res = 0  # no fp16 resolutions by default

    return opts
| |
|
| |
|
def rendering_options_defaults(opts):
    """Build the volume-rendering kwargs dict for the dataset/config named
    by ``opts.cfg``.

    Starts from a shared base dict, then overlays per-config ray-marching
    settings (sampling resolutions, near/far planes, box size, camera
    statistics, super-resolution module). ``opts`` must provide c_scale,
    density_reg, density_reg_p_dist and reg_type; the 'tuneray' configs
    additionally read opts.ray_start/ray_end, and the 'patch'/'auto'
    configs read opts.patch_rendering_resolution.
    """

    # Base options shared by every config; per-config branches override.
    rendering_options = {
        'image_resolution': 256,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        # Zero out generator pose conditioning by default.
        'c_gen_conditioning_zero':
        True,
        'c_scale':
        opts.c_scale,
        'superresolution_noise_mode': 'none',
        'density_reg': opts.density_reg,
        'density_reg_p_dist': opts.
        density_reg_p_dist,
        'reg_type': opts.
        reg_type,
        'decoder_lr_mul': 1,
        'decoder_activation': 'sigmoid',
        'sr_antialias': True,
        'return_triplane_features': False,
        'return_sampling_details_flag': False,
        # Default SR module; several branches substitute their own.
        'superresolution_module': 'utils.torch_utils.components.NearestConvSR',
    }

    if opts.cfg == 'ffhq':
        rendering_options.update({
            'superresolution_module':
            'nsr.superresolution.SuperresolutionHybrid8XDC',
            'focal': 2985.29 / 700,
            'depth_resolution':
            48 - 0,
            'depth_resolution_importance':
            48 - 0,
            'bg_depth_resolution':
            16,
            'ray_start':
            2.25,
            'ray_end':
            3.3,
            'box_warp':
            1,
            'avg_camera_radius':
            2.7,
            'avg_camera_pivot': [
                0, 0, 0.2
            ],
            'superresolution_noise_mode': 'random',
        })
    elif opts.cfg == 'afhq':
        rendering_options.update({
            'superresolution_module':
            'nsr.superresolution.SuperresolutionHybrid8X',
            'superresolution_noise_mode': 'random',
            'focal': 4.2647,
            'depth_resolution': 48,
            'depth_resolution_importance': 48,
            'ray_start': 2.25,
            'ray_end': 3.3,
            'box_warp': 1,
            'avg_camera_radius': 2.7,
            'avg_camera_pivot': [0, 0, -0.06],
        })
    elif opts.cfg == 'shapenet':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': 0.2,
            'ray_end': 2.2,
            'box_warp': 2,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'eg3d_shapenet_aug_resolution':
        rendering_options.update({
            'depth_resolution': 80,
            'depth_resolution_importance': 80,
            'ray_start': 0.1,
            'ray_end': 1.9,
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair':
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            'ray_start': 0.1,
            'ray_end': 1.9,
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_128':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': 0.1,
            'ray_end': 1.9,
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_64':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': 0.1,
            'ray_end': 1.9,
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'srn_shapenet_aug_resolution_chair_128':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': 1.25,
            'ray_end': 2.75,
            'box_warp': 1.5,
            'white_back': True,
            'avg_camera_radius': 2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_128_residualSR':
        rendering_options.update({
            'depth_resolution':
            128,
            'depth_resolution_importance':
            128,
            'ray_start':
            0.1,
            'ray_end':
            1.9,
            'box_warp':
            1.1,
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR_Residual',
        })

    # The 'tuneray' family takes near/far planes from opts and derives
    # the box size from them.
    elif opts.cfg == 'shapenet_tuneray':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution':
        rendering_options.update({
            'depth_resolution': 80,
            'depth_resolution_importance': 80,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    # NOTE(review): despite the '_64' suffix this config uses 128 samples
    # per ray — kept as found.
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96':
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96_nearestSR':
        rendering_options.update({
            'depth_resolution':
            96,
            'depth_resolution_importance':
            96,
            'ray_start':
            opts.ray_start,
            'ray_end':
            opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestSR':
        rendering_options.update({
            'depth_resolution':
            64,
            'depth_resolution_importance':
            64,
            'ray_start':
            opts.ray_start,
            'ray_end':
            opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestSR_patch':
        rendering_options.update({
            'depth_resolution':
            64,
            'depth_resolution_importance':
            64,
            'ray_start':
            opts.ray_start,
            'ray_end':
            opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
            # Patch-based ray sampling for patch-wise supervision.
            'PatchRaySampler':
            True,
            'patch_rendering_resolution':
            opts.patch_rendering_resolution,
        })

    elif opts.cfg == 'objverse_tuneray_aug_resolution_64_64_nearestSR':
        rendering_options.update({
            'depth_resolution':
            64,
            'depth_resolution_importance':
            64,
            'ray_start':
            opts.ray_start,
            'ray_end':
            opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,
            'white_back':
            True,
            'avg_camera_radius':
            1.946,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
        })

    elif opts.cfg == 'objverse_tuneray_aug_resolution_64_64_auto':
        # 'auto' ray bounds: near/far derived per-ray; sampling restricted
        # to an axis-aligned bbox around the origin.
        rendering_options.update({
            'depth_resolution':
            64,
            'depth_resolution_importance':
            64,
            'ray_start':
            'auto',
            'ray_end':
            'auto',
            'box_warp':
            0.9,
            'white_back':
            True,
            'radius_range': [1.5, 2],
            'sampler_bbox_min':
            -0.45,
            'sampler_bbox_max':
            0.45,
            'filter_out_of_bbox':
            True,
            'PatchRaySampler':
            True,
            'patch_rendering_resolution':
            opts.patch_rendering_resolution,
        })
        # Derive near/far planes from camera radius range +/- bbox extent.
        rendering_options['z_near'] = rendering_options['radius_range'][0]+rendering_options['sampler_bbox_min']
        rendering_options['z_far'] = rendering_options['radius_range'][1]+rendering_options['sampler_bbox_max']

    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96_nearestResidualSR':
        rendering_options.update({
            'depth_resolution':
            96,
            'depth_resolution_importance':
            96,
            'ray_start':
            opts.ray_start,
            'ray_end':
            opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR_Residual',
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestResidualSR':
        rendering_options.update({
            'depth_resolution':
            64,
            'depth_resolution_importance':
            64,
            'ray_start':
            opts.ray_start,
            'ray_end':
            opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR_Residual',
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_104':
        rendering_options.update({
            'depth_resolution': 104,
            'depth_resolution_importance': 104,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    # Always enable sampling-detail returns, regardless of config.
    # NOTE(review): this line is duplicated in the original; the second
    # call is a harmless no-op, kept for byte-identical behaviour.
    rendering_options.update({'return_sampling_details_flag': True})
    rendering_options.update({'return_sampling_details_flag': True})

    return rendering_options
| |
|
| |
|
def model_encoder_defaults():
    """Default hyper-parameters for the encoder side of the 3D AE."""
    return {
        'use_clip': False,
        'arch_encoder': "vits",
        'arch_decoder': "vits",
        'load_pretrain_encoder': False,
        'encoder_lr': 1e-5,
        'encoder_weight_decay': 0.001,
        'no_dim_up_mlp': False,
        'dim_up_mlp_as_func': False,
        'decoder_load_pretrained': True,
        'uvit_skip_encoder': False,
        # VAE / latent-space settings.
        'vae_p': 1,
        'ldm_z_channels': 4,
        'ldm_embed_dim': 4,
        'use_conf_map': False,
        # SD-style CNN encoder settings.
        'sd_E_ch': 64,
        'z_channels': 3 * 4,
        'sd_E_num_res_blocks': 1,
        # DiT decoder settings.
        'arch_dit_decoder': 'DiT2-B/2',
        'return_all_dit_layers': False,
        'lrm_decoder': False,
        'gs_rendering': False,
    }
| |
|
| |
|
def triplane_decoder_defaults():
    """Default hyper-parameters for the triplane decoder / renderer."""
    return {
        'triplane_fg_bg': False,
        'cfg': 'shapenet',
        # Density regularisation settings.
        'density_reg': 0.25,
        'density_reg_p_dist': 0.004,
        'reg_type': 'l1',
        # Learning rates for the decoder components.
        'triplane_decoder_lr': 0.0025,
        'super_resolution_lr': 0.0025,
        'c_scale': 1,
        'nsr_lr': 0.02,
        # Geometry / channel layout.
        'triplane_size': 224,
        'decoder_in_chans': 32,
        'triplane_in_chans': -1,
        'decoder_output_dim': 3,
        'out_chans': 96,
        'c_dim': 25,  # camera-conditioning dimension
        # Ray-marching bounds.
        'ray_start': 0.6,
        'ray_end': 1.8,
        'rendering_kwargs': {},
        'sr_training': False,
        'bcg_synthesis': False,
        'bcg_synthesis_kwargs': {},
        # Rendering resolutions.
        'image_size': 128,
        'patch_rendering_resolution': 45,
    }
| |
|
| |
|
def vit_decoder_defaults():
    """Optimiser defaults for the ViT decoder (lr and weight decay)."""
    return {'vit_decoder_lr': 1e-5, 'vit_decoder_wd': 0.001}
| |
|
| |
|
def nsr_decoder_defaults():
    """Combined decoder defaults: triplane renderer + ViT decoder."""
    defaults = {'decomposed': False}
    for extra in (triplane_decoder_defaults(), vit_decoder_defaults()):
        defaults.update(extra)
    return defaults
| |
|
| |
|
def loss_defaults():
    """Default weights and switches for all training losses."""
    return {
        # Reconstruction losses.
        'color_criterion': 'mse',
        'l2_lambda': 1.0,
        'lpips_lambda': 0.,
        'lpips_delay_iter': 0,
        'sr_delay_iter': 0,
        # Latent / KL losses.
        'kl_anneal': False,
        'latent_lambda': 0.,
        'latent_criterion': 'mse',
        'kl_lambda': 0.0,
        'ssim_lambda': 0.,
        'l1_lambda': 0.,
        'id_lambda': 0.0,
        'depth_lambda': 0.0,
        'alpha_lambda': 0.0,
        'fg_mse': False,
        'bg_lamdba': 0.0,
        # Density regularisation (EG3D-style).
        'density_reg': 0.0,
        'density_reg_p_dist': 0.004,
        'density_reg_every': 4,
        # Shape-supervision weights.
        'shape_uniform_lambda': 0.005,
        'shape_importance_lambda': 0.01,
        'shape_depth_lambda': 0.,
        # Adversarial / discriminator weights.
        'rec_cvD_lambda': 0.01,
        'nvs_cvD_lambda': 0.025,
        'patchgan_disc_factor': 0.01,
        'patchgan_disc_g_weight': 0.2,
        'r1_gamma': 1.0,
        'sds_lamdba': 1.0,
        'nvs_D_lr_mul': 1,
        'cano_D_lr_mul': 1,
        # Misc regularisers and switches.
        'ce_balanced_kl': 1.,
        'p_eps_lambda': 1,
        'symmetry_loss': False,
        'depth_smoothness_lambda': 0.0,
        'ce_lambda': 1.0,
        'negative_entropy_lambda': 1.0,
        'grad_clip': False,
        'online_mask': False,
    }
| |
|
| |
|
def dataset_defaults():
    """Default data-loading options (lmdb/webdataset, multi-view, etc.)."""
    return {
        'use_lmdb': False,
        'use_wds': False,
        'use_lmdb_compressed': True,
        'compile': False,
        'interval': 1,
        'objv_dataset': False,
        'decode_encode_img_only': False,
        # WebDataset-specific flags.
        'load_wds_diff': False,
        'load_wds_latent': False,
        'eval_load_wds_instance': True,
        'shards_lst': "",
        'eval_shards_lst': "",
        # Multi-view sampling flags.
        'mv_input': False,
        'duplicate_sample': True,
        'orthog_duplicate': False,
        'split_chunk_input': False,
        'load_real': False,
        'four_view_for_latent': False,
        'single_view_for_i23d': False,
        'shuffle_across_cls': False,
        'load_extra_36_view': False,
        'mv_latent_dir': '',
        # Extra per-sample conditioning.
        'append_depth': False,
        'plucker_embedding': False,
        'gs_cam_format': False,
    }
| |
|
| |
|
def encoder_and_nsr_defaults():
    """
    Defaults for image training.
    """
    # ViT encoder architecture defaults.
    res = {
        'dino_version': 'v1',
        'encoder_in_channels': 3,
        'img_size': [224],
        'patch_size': 16,
        'in_chans': 384,
        'num_classes': 0,
        'embed_dim': 384,  # ViT-S default width
        'depth': 6,
        'num_heads': 16,
        'mlp_ratio': 4.,
        'qkv_bias': False,
        'qk_scale': None,
        'drop_rate': 0.1,
        'attn_drop_rate': 0.,
        'drop_path_rate': 0.,
        'norm_layer': 'nn.LayerNorm',
        'cls_token': False,
        'encoder_cls_token': False,
        'decoder_cls_token': False,
        'sr_kwargs': {},
        'sr_ratio': 2,
    }
    # Merge encoder-side and decoder-side option groups.
    res.update(model_encoder_defaults())
    res.update(nsr_decoder_defaults())
    res['ae_classname'] = 'vit.vit_triplane.ViTTriplaneDecomposed'
    return res
| |
|
| |
|
def create_3DAE_model(
        arch_encoder,
        arch_decoder,
        dino_version='v1',
        img_size=[224],
        patch_size=16,
        in_chans=384,
        num_classes=0,
        embed_dim=1024,  # ViT-L default width
        depth=6,
        num_heads=16,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.1,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_layer='nn.LayerNorm',
        out_chans=96,
        decoder_in_chans=32,
        triplane_in_chans=-1,
        decoder_output_dim=32,
        encoder_cls_token=False,
        decoder_cls_token=False,
        c_dim=25,  # camera-conditioning dimension
        image_size=128,  # rendered image resolution
        img_channels=3,
        rendering_kwargs={},
        load_pretrain_encoder=False,
        decomposed=True,
        triplane_size=224,
        ae_classname='ViTTriplaneDecomposed',
        use_clip=False,
        sr_kwargs={},
        sr_ratio=2,
        no_dim_up_mlp=False,
        dim_up_mlp_as_func=False,
        decoder_load_pretrained=True,
        uvit_skip_encoder=False,
        bcg_synthesis_kwargs={},
        # VAE / latent settings.
        vae_p=1,
        ldm_z_channels=4,
        ldm_embed_dim=4,
        use_conf_map=False,
        triplane_fg_bg=False,
        encoder_in_channels=3,
        sd_E_ch=64,
        z_channels=3*4,
        sd_E_num_res_blocks=1,
        arch_dit_decoder='DiT2-B/2',
        lrm_decoder=False,
        gs_rendering=False,
        return_all_dit_layers=False,
        *args,
        **kwargs):
    """Factory for the full AE model: builds the encoder (DINO / CLIP /
    SD-CNN), the triplane renderer, the ViT/DiT decoder, and wraps them
    in an :class:`AE`.

    The encoder choice is driven by ``load_pretrain_encoder``, ``use_clip``
    and substrings of ``dino_version`` ('v1'/'v2'/'sd'/'mv'/'lgm'/'gs'/'dit').
    ``ae_classname`` names the decoder class constructed via dnnlib.

    NOTE(review): several keyword defaults are mutable ([224], {}); kept
    for interface compatibility — callers should not mutate them.
    """

    preprocess = None  # only set for the CLIP path (its image transform)
    clip_dtype = None
    if load_pretrain_encoder:
        if not use_clip:
            if dino_version == 'v1':
                # Pre-trained DINO v1 from torch hub.
                encoder = torch.hub.load(
                    'facebookresearch/dino:main',
                    'dino_{}{}'.format(arch_encoder, patch_size))
                logger.log(
                    f'loaded pre-trained dino v1 ViT-S{patch_size} encoder ckpt'
                )
            elif dino_version == 'v2':
                # Pre-trained DINO v2 from torch hub.
                encoder = torch.hub.load(
                    'facebookresearch/dinov2',
                    'dinov2_{}{}'.format(arch_encoder, patch_size))
                logger.log(
                    f'loaded pre-trained dino v2 {arch_encoder}{patch_size} encoder ckpt'
                )
            elif 'sd' in dino_version:

                if 'mv' in dino_version:
                    if 'lgm' in dino_version:
                        # NOTE(review): this binds an *instance* of MVUNet
                        # to encoder_cls, which is then *called* below with
                        # Encoder-style kwargs — looks broken; confirm this
                        # branch is ever exercised.
                        encoder_cls = MVUNet(
                            input_size=256,
                            up_channels=(1024, 1024, 512, 256,
                                         128),
                            up_attention=(True, True, True, False, False),
                            splat_size=128,
                            output_size=
                            512,
                            batch_size=8,
                            num_views=8,
                            gradient_accumulation_steps=1,
                        )
                    elif 'gs' in dino_version:
                        # NOTE(review): same class as the plain-mv branch
                        # here, whereas the non-pretrained path below uses
                        # MVEncoderGS for 'gs' — confirm intentional.
                        encoder_cls = MVEncoder
                    else:
                        encoder_cls = MVEncoder

                else:
                    encoder_cls = Encoder

                encoder = encoder_cls(
                    double_z=True,
                    resolution=256,
                    in_channels=encoder_in_channels,
                    ch=64,
                    ch_mult=[1, 2, 4, 4],
                    num_res_blocks=1,
                    dropout=0.0,
                    attn_resolutions=[],
                    out_ch=3,  # unused by the encoder half
                    z_channels=4 * 3,  # 4 channels per plane, 3 planes
                )
            else:
                raise NotImplementedError()

        else:
            # Frozen CLIP visual backbone as the encoder.
            import clip
            model, preprocess = clip.load("ViT-B/16", device=dist_util.dev())
            model.float()  # convert weights to fp32
            clip_dtype = model.dtype
            encoder = getattr(
                model, 'visual')
            encoder.requires_grad_(False)  # keep CLIP frozen
            logger.log(
                f'loaded pre-trained CLIP ViT-B{patch_size} encoder, fixed.')

    elif 'sd' in dino_version:
        # SD-style CNN encoder trained from scratch.
        attn_kwargs = {}
        if 'mv' in dino_version:
            if 'lgm' in dino_version:
                encoder = LGM_MVEncoder(
                    in_channels=9,
                    up_channels=(1024, 1024, 512, 256,
                                 128),
                    up_attention=(True, True, True, False, False),
                )

            elif 'gs' in dino_version:
                encoder_cls = MVEncoderGS
                attn_kwargs = {
                    'n_heads': 8,
                    'd_head': 64,
                }

            else:
                encoder_cls = MVEncoder
                attn_kwargs = {
                    'n_heads': 8,
                    'd_head': 64,
                }

        else:
            encoder_cls = Encoder

        # LGM builds its encoder above; all other sd variants build here.
        if 'lgm' not in dino_version:
            encoder = encoder_cls(
                double_z=True,
                resolution=256,
                in_channels=encoder_in_channels,
                ch=sd_E_ch,
                ch_mult=[1, 2, 4, 4],
                num_res_blocks=sd_E_num_res_blocks,
                dropout=0.0,
                attn_resolutions=[],
                out_ch=3,  # unused by the encoder half
                z_channels=z_channels,
                attn_kwargs=attn_kwargs,
            )

    else:
        # Randomly-initialised DINO-style ViT encoder.
        encoder = vits.__dict__[arch_encoder](
            patch_size=patch_size,
            drop_path_rate=drop_path_rate,  # stochastic depth
            img_size=img_size)

    if triplane_in_chans == -1:
        triplane_in_chans = decoder_in_chans

    triplane_renderer_cls = Triplane

    # Triplane renderer (EG3D-style neural renderer).
    triplane_decoder = triplane_renderer_cls(
        c_dim,
        image_size,
        img_channels,
        rendering_kwargs=rendering_kwargs,
        out_chans=out_chans,
        triplane_size=triplane_size,
        decoder_in_chans=triplane_in_chans,
        decoder_output_dim=decoder_output_dim,
        sr_kwargs=sr_kwargs,
        bcg_synthesis_kwargs=bcg_synthesis_kwargs,
        lrm_decoder=lrm_decoder)

    if load_pretrain_encoder:

        if dino_version == 'v1':
            vit_decoder = torch.hub.load(
                'facebookresearch/dino:main',
                'dino_{}{}'.format(arch_decoder, patch_size))
            logger.log(
                'loaded pre-trained decoder',
                "facebookresearch/dino:main', 'dino_{}{}".format(
                    arch_decoder, patch_size))
        else:

            vit_decoder = torch.hub.load(
                'facebookresearch/dinov2',
                'dinov2_{}{}'.format(arch_decoder, patch_size),
                pretrained=decoder_load_pretrained)
            logger.log(
                'loaded pre-trained decoder',
                "facebookresearch/dinov2', 'dinov2_{}{}".format(
                    arch_decoder,
                    patch_size), 'pretrianed=', decoder_load_pretrained)

    elif 'dit' in dino_version:
        from dit.dit_decoder import DiT2_models

        # DiT transformer decoder; 4 planes for Gaussian-splatting
        # variants, 3 planes otherwise.
        vit_decoder = DiT2_models[arch_dit_decoder](
            input_size=16,
            num_classes=0,
            learn_sigma=False,
            in_channels=embed_dim,
            mixed_prediction=False,
            context_dim=None,
            roll_out=True, plane_n=4 if
            'gs' in dino_version else 3,
            return_all_layers=return_all_dit_layers,
        )

    else:
        # Randomly-initialised ViT decoder.
        vit_decoder = vits.__dict__[arch_decoder](
            patch_size=patch_size,
            drop_path_rate=drop_path_rate,
            img_size=img_size)

    # Assemble the decoder wrapper by class name via dnnlib.
    decoder_kwargs = dict(
        class_name=ae_classname,
        vit_decoder=vit_decoder,
        triplane_decoder=triplane_decoder,
        cls_token=decoder_cls_token,
        sr_ratio=sr_ratio,
        vae_p=vae_p,
        ldm_z_channels=ldm_z_channels,
        ldm_embed_dim=ldm_embed_dim,
    )
    decoder = dnnlib.util.construct_class_by_name(**decoder_kwargs)

    # Optional confidence-map network.
    if use_conf_map:
        confnet = ConfNet(cin=3, cout=1, nf=64, zdim=128)
    else:
        confnet = None

    auto_encoder = AE(
        encoder,
        decoder,
        img_size[0],
        encoder_cls_token,
        decoder_cls_token,
        preprocess,
        use_clip,
        dino_version,
        clip_dtype,
        no_dim_up_mlp=no_dim_up_mlp,
        dim_up_mlp_as_func=dim_up_mlp_as_func,
        uvit_skip_encoder=uvit_skip_encoder,
        confnet=confnet,
    )

    logger.log(auto_encoder)
    torch.cuda.empty_cache()

    return auto_encoder
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
def create_Triplane(
        c_dim=25,
        img_resolution=128,
        img_channels=3,
        rendering_kwargs={},
        decoder_output_dim=32,
        *args,
        **kwargs):
    """Build a standalone Triplane renderer with its own plane parameters.

    NOTE: the mutable default for rendering_kwargs is kept for interface
    compatibility; callers should not mutate it.
    """
    return Triplane(
        c_dim,
        img_resolution,
        img_channels,
        rendering_kwargs=rendering_kwargs,
        create_triplane=True,  # allocate the triplane tensor internally
        decoder_output_dim=decoder_output_dim)
| |
|
| |
|
def DiT_defaults():
    """Default DiT diffusion-backbone options (model variant and VAE)."""
    return dict(dit_model="DiT-B/16", vae="ema")
| |
|