from typing import Any, Dict

import torch
import torch.nn as nn

from .sdf import CrossAttentionPointCloudSDFModel
| from .transformer import ( |
| CLIPImageGridPointDiffusionTransformer, |
| CLIPImageGridUpsamplePointDiffusionTransformer, |
| CLIPImagePointDiffusionTransformer, |
| PointDiffusionTransformer, |
| UpsamplePointDiffusionTransformer, |
| ) |
|
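# Hyperparameters for each released checkpoint, keyed by model name. The
# "name" entry selects the class that model_from_config instantiates; every
# other entry is forwarded to that class's constructor. For the diffusion
# models, input_channels=6 corresponds to XYZ + RGB per point, and
# output_channels=12 is double that, consistent with predicting both noise
# and a variance term for each channel.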
| MODEL_CONFIGS = { |
| "base40M-imagevec": { |
| "cond_drop_prob": 0.1, |
| "heads": 8, |
| "init_scale": 0.25, |
| "input_channels": 6, |
| "layers": 12, |
| "n_ctx": 1024, |
| "name": "CLIPImagePointDiffusionTransformer", |
| "output_channels": 12, |
| "time_token_cond": True, |
| "token_cond": True, |
| "width": 512, |
| }, |
| "base40M-textvec": { |
| "cond_drop_prob": 0.1, |
| "heads": 8, |
| "init_scale": 0.25, |
| "input_channels": 6, |
| "layers": 12, |
| "n_ctx": 1024, |
| "name": "CLIPImagePointDiffusionTransformer", |
| "output_channels": 12, |
| "time_token_cond": True, |
| "token_cond": True, |
| "width": 512, |
| }, |
| "base40M-uncond": { |
| "heads": 8, |
| "init_scale": 0.25, |
| "input_channels": 6, |
| "layers": 12, |
| "n_ctx": 1024, |
| "name": "PointDiffusionTransformer", |
| "output_channels": 12, |
| "time_token_cond": True, |
| "width": 512, |
| }, |
| "base40M": { |
| "cond_drop_prob": 0.1, |
| "heads": 8, |
| "init_scale": 0.25, |
| "input_channels": 6, |
| "layers": 12, |
| "n_ctx": 1024, |
| "name": "CLIPImageGridPointDiffusionTransformer", |
| "output_channels": 12, |
| "time_token_cond": True, |
| "width": 512, |
| }, |
| "base300M": { |
| "cond_drop_prob": 0.1, |
| "heads": 16, |
| "init_scale": 0.25, |
| "input_channels": 6, |
| "layers": 24, |
| "n_ctx": 1024, |
| "name": "CLIPImageGridPointDiffusionTransformer", |
| "output_channels": 12, |
| "time_token_cond": True, |
| "width": 1024, |
| }, |
| "base1B": { |
| "cond_drop_prob": 0.1, |
| "heads": 32, |
| "init_scale": 0.25, |
| "input_channels": 6, |
| "layers": 24, |
| "n_ctx": 1024, |
| "name": "CLIPImageGridPointDiffusionTransformer", |
| "output_channels": 12, |
| "time_token_cond": True, |
| "width": 2048, |
| }, |
| "upsample": { |
| "channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0], |
| "channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255], |
| "cond_ctx": 1024, |
| "cond_drop_prob": 0.1, |
| "heads": 8, |
| "init_scale": 0.25, |
| "input_channels": 6, |
| "layers": 12, |
| "n_ctx": 3072, |
| "name": "CLIPImageGridUpsamplePointDiffusionTransformer", |
| "output_channels": 12, |
| "time_token_cond": True, |
| "width": 512, |
| }, |
| "sdf": { |
| "decoder_heads": 4, |
| "decoder_layers": 4, |
| "encoder_heads": 4, |
| "encoder_layers": 8, |
| "init_scale": 0.25, |
| "n_ctx": 4096, |
| "name": "CrossAttentionPointCloudSDFModel", |
| "width": 256, |
| }, |
| } |
|
|
|
|
| def model_from_config(config: Dict[str, Any], device: torch.device) -> nn.Module: |
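    """Construct a model from a MODEL_CONFIGS-style dict.

    The "name" key selects the model class; every remaining key is passed
    through as a keyword argument to that class's constructor.
    """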
| config = config.copy() |
| name = config.pop("name") |
| if name == "PointDiffusionTransformer": |
| return PointDiffusionTransformer(device=device, dtype=torch.float32, **config) |
| elif name == "CLIPImagePointDiffusionTransformer": |
| return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config) |
| elif name == "CLIPImageGridPointDiffusionTransformer": |
| return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config) |
| elif name == "UpsamplePointDiffusionTransformer": |
| return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config) |
| elif name == "CLIPImageGridUpsamplePointDiffusionTransformer": |
| return CLIPImageGridUpsamplePointDiffusionTransformer( |
| device=device, dtype=torch.float32, **config |
| ) |
| elif name == "CrossAttentionPointCloudSDFModel": |
| return CrossAttentionPointCloudSDFModel(device=device, dtype=torch.float32, **config) |
| raise ValueError(f"unknown model name: {name}") |
|
|
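if __name__ == "__main__":
    # Minimal smoke test; a sketch only. Assumes the module is run as part of
    # its package (e.g. via `python -m <package>.configs`) so the relative
    # imports above resolve, and does not load any pretrained weights.
    device = torch.device("cpu")
    model = model_from_config(MODEL_CONFIGS["base40M-uncond"], device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"base40M-uncond: {n_params:,} parameters")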