| | import torch |
| | from safetensors.torch import load_file |
| |
|
| | from torch.nn.utils import remove_weight_norm |
| | import warnings |
| | warnings.simplefilter(action='ignore', category=FutureWarning) |
| |
|
| |
|
def load_ckpt_state_dict(ckpt_path):
    """Load a model state dict from a checkpoint file.

    Supports both safetensors files and regular torch pickle checkpoints.
    For torch checkpoints the weights are looked up under a "state_dict"
    key; if that key is absent, the loaded dict is assumed to already be
    a state dict and is returned as-is.

    Args:
        ckpt_path (str): Path to a ".safetensors" file or a torch
            checkpoint (e.g. ".ckpt" / ".pt").
    Returns:
        dict: Mapping of parameter names to tensors.
    """
    if ckpt_path.endswith(".safetensors"):
        return load_file(ckpt_path)
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # trusted checkpoints (consider weights_only=True where supported).
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Fall back to the whole dict for checkpoints that store the state
    # dict at the top level instead of under "state_dict".
    return checkpoint.get("state_dict", checkpoint)
| |
|
def remove_weight_norm_from_model(model):
    """Strip weight normalization from every module of *model*, in place.

    The original implementation called ``remove_weight_norm`` on every
    module exposing a ``weight`` attribute, which raises ``ValueError``
    for modules that never had weight norm applied (e.g. a plain
    ``Linear``). Such modules are now skipped instead of crashing.

    Args:
        model (torch.nn.Module): Model to clean up; modified in place.
    Returns:
        torch.nn.Module: The same model object, for chaining.
    """
    for module in model.modules():
        if hasattr(module, "weight"):
            try:
                remove_weight_norm(module)
                # Log only modules that actually had weight norm removed.
                print(f"Removing weight norm from {module}")
            except ValueError:
                # Module has a weight but no weight-norm hook; leave it.
                pass
    return model
| |
|
| | |
| | |
| |
|
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial that accepts inputs with any number of leading dims.

    Candidate probabilities live on the last dimension; every leading
    dimension is treated as an independent batch entry.

    Args:
        input (torch.Tensor): Tensor of (unnormalized) probabilities.
        num_samples (int): How many indices to draw per distribution.
        replacement (bool): Whether sampling is done with replacement.
    Keywords args:
        generator (torch.Generator): Optional RNG for reproducibility.
    Returns:
        torch.Tensor: Same leading shape as ``input``, with the last
            dimension holding the ``num_samples`` sampled indices.
    """
    # Single draw: exponential-race trick (argmax of p/E with E ~ Exp(1)
    # follows the multinomial distribution) -- no reshaping needed.
    if num_samples == 1:
        noise = torch.empty_like(input).exponential_(1, generator=generator)
        return torch.argmax(input / noise, dim=-1, keepdim=True).to(torch.int64)

    # General case: collapse leading dims, sample, then restore them.
    flat = input.reshape(-1, input.shape[-1])
    drawn = torch.multinomial(flat, num_samples=num_samples, replacement=replacement, generator=generator)
    return drawn.reshape(*input.shape[:-1], -1)
| |
|
| |
|
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample the next token restricted to the top-k candidates.

    The original implementation renormalized ``probs`` in place
    (``*=`` / ``div_``), silently mutating the caller's tensor; this
    version filters out-of-place and leaves the input untouched.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates
            on the last dimension.
        k (int): The k in "top-k".
    Returns:
        torch.Tensor: Sampled token indices, shape ``probs.shape[:-1] + (1,)``.
    """
    top_k_value, _ = torch.topk(probs, k, dim=-1)
    # Smallest probability still inside the top-k, per distribution.
    min_value_top_k = top_k_value[..., [-1]]
    # Zero out everything below the cutoff, then renormalize -- both
    # done out-of-place so the caller's tensor is not modified.
    filtered = probs * (probs >= min_value_top_k).float()
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    return multinomial(filtered, num_samples=1)
| |
|
| |
|
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample the next token from the smallest nucleus of mass >= p.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates
            on the last dimension.
        p (int): The p in "top-p".
    Returns:
        torch.Tensor: Sampled token indices, shape ``probs.shape[:-1] + (1,)``.
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep a candidate iff the cumulative mass *before* it is still <= p,
    # so the first candidate always survives.
    keep = cumulative - sorted_probs <= p
    sorted_probs = sorted_probs * keep.float()
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    # Sample in sorted space, then map back to original vocabulary ids.
    picked = multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, picked)
| |
|
def next_power_of_two(n):
    """Return the smallest power of two that is >= ``n`` (for n >= 1)."""
    # 2 ** k expressed as a shift; (n - 1).bit_length() gives the exponent.
    return 1 << (n - 1).bit_length()
| |
|
def next_multiple_of_64(n):
    """Round ``n`` up to the nearest multiple of 64."""
    # Ceiling division by 64, then scale back up.
    blocks = (n + 63) // 64
    return blocks * 64