| |
|
| | from typing import Callable, List, Optional
|
| |
|
| | import numpy as np
|
| |
|
| |
|
def ordered_halving(val):
    """Return the base-2 radical inverse of ``val`` as a float in [0, 1).

    The integer is written as a zero-padded 64-digit binary string, the
    digits are mirrored, and the mirrored number is scaled by ``2**64``.
    For example 1 -> 0.5, 2 -> 0.25, 3 -> 0.75.

    Inputs are expected to be non-negative and to fit in 64 bits. A negative
    ``val`` raises ``ValueError``: the mirrored string ends with the ``-``
    sign, which ``int(..., 2)`` rejects.
    """
    mirrored_digits = format(val, "064b")[::-1]
    return int(mirrored_digits, 2) / 2**64
|
| |
|
| |
|
def uniform(
    step: int = ...,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield windows of frame indices covering ``num_frames`` frames for one step.

    If ``context_size`` is None, or ``num_frames <= context_size``, a single
    window containing every frame index is yielded.

    Otherwise windows of ``context_size`` indices are produced for dilations
    ``context_step`` = 1, 2, 4, ... The number of dilation levels is at most
    ``context_stride`` and is further capped at
    ``ceil(log2(num_frames / context_size)) + 1``. Within a window, indices
    advance by ``context_step``. Consecutive window starts advance by
    ``context_size * context_step - context_overlap``. Window starts are
    shifted by an offset derived from ``ordered_halving(step)``, so the
    placement depends on ``step``.

    Indices that run past the last frame are reflected back into range,
    bouncing between frame ``num_frames - 1`` and frame 0. Every yielded index
    therefore lies in ``[0, num_frames)``.

    Args:
        step: Sampling-step index; only used to derive the start offset.
        num_frames: Total number of frames.
        context_size: Window length, or None for one window over all frames.
        context_stride: Maximum number of dilation levels.
        context_overlap: Subtracted from the spacing between window starts.
            When ``closed_loop`` is False it also shortens the range of window
            starts by this amount.
        closed_loop: If False, window starts stop ``context_overlap`` frames
            earlier.

    Yields:
        list[int]: One window of frame indices.

    Raises:
        ValueError: From ``range`` if
            ``context_size * context_step - context_overlap`` is 0.
    """
    if context_size is None or num_frames <= context_size:
        yield list(range(num_frames))
        return

    context_stride = min(
        context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
    )

    # The offset depends only on ``step``, so compute it once rather than per
    # dilation level.
    offset = ordered_halving(step)
    pad = int(round(num_frames * offset))
    # A walk 0 -> num_frames-1 -> 0 repeats every 2*(num_frames-1) indices.
    # Here num_frames > context_size >= 1, so this period is at least 2.
    period = 2 * (num_frames - 1)

    for context_step in 1 << np.arange(context_stride):
        for j in range(
            int(offset * context_step) + pad,
            num_frames + pad + (0 if closed_loop else -context_overlap),
            (context_size * context_step - context_overlap),
        ):
            window = []
            for e in range(j, j + context_size * context_step, context_step):
                if e >= num_frames:
                    # Reflect off the last frame (and off frame 0 if needed).
                    # This matches the old ``num_frames - 2 - e % num_frames``
                    # for e in [num_frames, 2*num_frames - 2], but never
                    # produces -1.
                    e %= period
                    if e >= num_frames:
                        e = period - e
                window.append(e)
            yield window
|
| |
|
| |
|
def get_context_scheduler(name: str) -> Callable:
    """Return the context-scheduler function registered under ``name``.

    Only ``"uniform"`` is recognised.

    Raises:
        ValueError: If ``name`` is not a known scheduler name.
    """
    if name != "uniform":
        # The old message called this a "context_overlap policy", which is a
        # different setting; the lookup here is by scheduler name.
        raise ValueError(f"Unknown context scheduler {name!r}")
    return uniform
|
| |
|
| |
|
def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the context windows ``scheduler`` yields across all timesteps.

    ``scheduler`` is called once per entry of ``timesteps``. Each call gets
    the entry's position as ``step``, and the windows it yields are counted.

    Arguments are passed by keyword so they reach the matching parameters of
    ``uniform`` (``step, num_frames, context_size, context_stride,
    context_overlap, closed_loop``). The previous positional call shifted
    every argument by one slot and dropped ``closed_loop``.

    Args:
        scheduler: Callable returning an iterable of windows, e.g. ``uniform``.
        timesteps: Only its length is used.
        num_steps: Accepted for backward compatibility; not forwarded, since
            ``uniform`` has no such parameter.
        num_frames, context_size, context_stride, context_overlap,
            closed_loop: Forwarded to ``scheduler`` unchanged.

    Returns:
        int: Total number of windows over all timesteps.
    """
    total = 0
    for step_index, _timestep in enumerate(timesteps):
        windows = scheduler(
            step=step_index,
            num_frames=num_frames,
            context_size=context_size,
            context_stride=context_stride,
            context_overlap=context_overlap,
            closed_loop=closed_loop,
        )
        # Count without materialising every window list.
        total += sum(1 for _window in windows)
    return total
|
| |
|