# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Please refer to original code: https://github.com/NVlabs/instant-ngp
# and the pytorch wrapper from https://github.com/ashawkey/torch-ngp
import os
import torch
from .. import custom_ops
from torch.cuda.amp import custom_bwd, custom_fwd
_plugin = None
_null_tensor = torch.empty([0])
def _init():
    """Compile and load the CUDA extension on first call; subsequent calls are no-ops.

    Populates the module-level ``_plugin`` handle. Returns True once the
    plugin is available.
    """
    global _plugin
    if _plugin is not None:
        return True
    _plugin = custom_ops.get_plugin(
        module_name='hash_sample_plugin',
        sources=['hash_sample.cpp', 'hash_sample.cu'],
        headers=['hash_sample.h', 'utils.h'],
        source_dir=os.path.dirname(__file__),
        extra_cuda_cflags=['--use_fast_math'],
    )
    return True
def hash_sample(x, h, offsets, beta=2, base_res=16, calc_grad=True, mode='fast_hash'):
    """Hash-table look-up followed by d-linear interpolation (CUDA only).

    x: B x N x D coordinates
    h: B x L x T x C hash-tables
    offsets: L resolutions
    """
    # The kernel is CUDA-only and currently handles 2D/3D coordinates.
    assert x.device.type == 'cuda'
    assert x.size(-1) in (2, 3), "currently only 2D/3D is implemented"
    _init()  # lazily build/load the extension before first use
    fn = _hash_sample_cuda(mode)
    return fn.apply(x, h, offsets, beta, base_res, calc_grad)
_hash_sample_cuda_cache = dict()
def _hash_sample_cuda(mode='fast_hash'):
"""CUDA implementation of hash-table look-up
"""
if mode in _hash_sample_cuda_cache:
return _hash_sample_cuda_cache[mode]
if mode == 'fast_hash':
h_mode = 0
elif mode == 'grid_hash':
h_mode = 1
else:
raise NotImplementedError('only two types are supported now.')
class HashSampleCuda(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.half)
def forward(ctx, inputs, embeddings, offsets, beta, base_resolution, calc_grad_inputs=False):
# inputs: [B, N, D], float in [0, 1]
# embeddings: [B, sO, C], float
# offsets: [L + 1], int
# RETURN: [B, N, F], float
inputs = inputs.contiguous()
embeddings = embeddings.contiguous()
offsets = offsets.contiguous().to(inputs.device)
B, N, D = inputs.shape # batch size, # of samples, coord dim
L = offsets.shape[0] - 1 # level
C = embeddings.shape[-1] # embedding dim for each level
H = base_resolution # base resolution
outputs = torch.zeros(B, N, L * C, device=inputs.device, dtype=inputs.dtype)
if calc_grad_inputs:
dy_dx = torch.zeros(B, N, L * D * C).to(inputs.device, dtype=inputs.dtype)
else:
dy_dx = torch.zeros(1).to(inputs.device, dtype=inputs.dtype)
_plugin.hash_encode_forward(inputs, embeddings, offsets, outputs, beta, B, N, D, C, L, H, calc_grad_inputs, dy_dx, h_mode)
ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
ctx.dims = [B, N, D, C, L, H, beta]
ctx.calc_grad_inputs = calc_grad_inputs
return outputs
@staticmethod
@custom_bwd
def backward(ctx, grad):
# grad: [B, L * C]
grad = grad.contiguous()
inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
B, N, D, C, L, H, beta = ctx.dims
calc_grad_inputs = ctx.calc_grad_inputs
grad_embeddings = torch.zeros_like(embeddings)
if calc_grad_inputs:
grad_inputs = torch.zeros_like(inputs)
else:
grad_inputs = torch.zeros(1).to(inputs.device, dtype=inputs.dtype)
_plugin.hash_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, beta, B, N, D, C, L, H, calc_grad_inputs, dy_dx, grad_inputs, h_mode)
if calc_grad_inputs:
return grad_inputs, grad_embeddings, None, None, None, None
else:
return None, grad_embeddings, None, None, None, None
# Add to cache.
_hash_sample_cuda_cache[mode] = HashSampleCuda
return HashSampleCuda |