| |
| from typing import Callable, List, Optional |
|
|
| import numpy as np |
|
|
|
|
def ordered_halving(val):
    """Mirror the 64-bit binary digits of ``val`` and scale the result into [0, 1).

    Consecutive integers 0, 1, 2, 3, 4, ... map to 0, 0.5, 0.25, 0.75,
    0.125, ... (the base-2 van der Corput sequence). Only meaningful for
    ``0 <= val < 2**64``:
    - A negative ``val`` formats with a leading ``-``, so the reversed
      string cannot be parsed and ``ValueError`` is raised.
    - Larger values produce more than 64 digits and are not confined to
      [0, 1).
    """
    mirrored_digits = format(val, "064b")[::-1]
    return int(mirrored_digits, 2) / 2**64
|
|
|
|
def uniform(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield windows of frame indices to process at step ``step``.

    If the whole clip fits in one window (``num_frames <= context_size``), a
    single window containing every frame index is yielded.

    Otherwise windows are generated at up to ``context_stride`` dilation
    levels, with frame spacings 1, 2, 4, ... The number of levels is capped
    at ``ceil(log2(num_frames / context_size)) + 1``. At each level:
    - Each window holds ``context_size`` indices spaced by the level's
      spacing.
    - Consecutive windows start ``context_size * spacing - context_overlap``
      frames apart.
    - Every index is wrapped modulo ``num_frames``.

    The starting position is shifted by ``ordered_halving(step)``, so
    different steps place the windows differently.

    Args:
        step: Integer that seeds the window offset via ``ordered_halving``.
        num_steps: Unused by this scheduler. NOTE(review): presumably kept
            for a shared scheduler signature — confirm.
        num_frames: Total number of frames in the clip.
        context_size: Frames per window. Must be given; the ``None`` default
            makes the first comparison raise ``TypeError``.
        context_stride: Maximum number of dilation levels.
        context_overlap: Frames shared by consecutive windows at a level.
            Should be smaller than ``context_size``. Otherwise the first
            level's ``range`` step is non-positive: zero raises
            ``ValueError``, and a negative step yields no windows at that
            level.
        closed_loop: If False, window starts stop ``context_overlap`` frames
            earlier, which reduces wrap-around windows.

    Yields:
        Lists of ``context_size`` frame indices, each in ``[0, num_frames)``.
    """
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    # Use no more dilation levels than needed for one window to span the clip.
    context_stride = min(
        context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
    )

    # The per-step offset and the padding derived from it do not depend on the
    # dilation level, so compute them once rather than on every iteration.
    offset = ordered_halving(step)
    pad = int(round(num_frames * offset))
    stop = num_frames + pad + (0 if closed_loop else -context_overlap)

    for context_step in 1 << np.arange(context_stride):
        window_span = context_size * context_step
        for j in range(
            int(offset * context_step) + pad,
            stop,
            window_span - context_overlap,
        ):
            yield [e % num_frames for e in range(j, j + window_span, context_step)]
|
|
|
|
def get_context_scheduler(name: str) -> Callable:
    """Look up a context scheduler function by name.

    ``"uniform"`` is the only name currently recognized.

    Raises:
        ValueError: If ``name`` is not a recognized scheduler.
    """
    if name != "uniform":
        raise ValueError(f"Unknown context_overlap policy {name}")
    return uniform
|
|
|
|
def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the context windows ``scheduler`` produces across all timesteps.

    ``scheduler`` is called once per entry of ``timesteps``. The entry's
    *index* (not its value) is passed as the step argument. The remaining
    arguments, including ``closed_loop``, are forwarded positionally in the
    order of ``uniform``'s signature.

    Args:
        scheduler: Callable returning an iterable of windows, e.g. ``uniform``.
        timesteps: Sequence whose length sets how many steps are counted.
        num_steps, num_frames, context_size, context_stride, context_overlap,
            closed_loop: Forwarded unchanged to ``scheduler``.

    Returns:
        Total number of windows yielded over all steps.
    """
    return sum(
        # Count lazily instead of materializing every window into a list.
        sum(
            1
            for _ in scheduler(
                i,
                num_steps,
                num_frames,
                context_size,
                context_stride,
                context_overlap,
                # Previously omitted, so the count silently assumed the
                # scheduler's default closed_loop=True.
                closed_loop,
            )
        )
        for i in range(len(timesteps))
    )
|
|