import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import GenerationConfig
from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging

logger = logging.get_logger(__name__)


def top_p_logits(logits, top_p=None):
    """Mask logits outside the top-p (nucleus) set with the dtype minimum."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the first token that crosses the threshold is kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # Scatter the removal decisions back to the original vocabulary order.
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)


def top_k_logits(logits, top_k=None):
    """Mask logits outside the top-k set with the dtype minimum."""
    if top_k is None or top_k == 0:
        return logits
    top_k = min(top_k, logits.size(-1))
    # Remove every logit strictly below the k-th largest value.
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)


def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Sample one token per position and return a per-position confidence score.

    With `temperature == 0` the argmax is taken; otherwise logits are scaled,
    filtered by top-p/top-k, and sampled categorically. Confidence defaults to
    the chosen token's probability; `margin_confidence` uses the top1 - top2
    probability margin instead, and `neg_entropy` the negative entropy of the
    full distribution.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits.float(), dim=-1)

    if temperature > 0:
        x0 = dists.Categorical(probs=probs).sample()
    else:
        _, x0 = probs.max(dim=-1)

    confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Margin between the two most likely tokens.
        confidence = sorted_probs[..., 0] - sorted_probs[..., 1]
    elif neg_entropy:
        log_probs = torch.log(probs.clamp(min=1e-10))
        # Negative entropy: closer to 0 means a more peaked distribution.
        confidence = (probs * log_probs).sum(dim=-1)

    return confidence, x0


@dataclass
class MDMModelOutput(ModelOutput):
    sequences: torch.LongTensor = None
    history: Optional[Tuple[torch.FloatTensor]] = None

class MDMGenerationConfig(GenerationConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # `super().__init__` receives its own copy of `kwargs`, so the pops
        # below still see the caller's values and override the base defaults.
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)
        self.eps: float = kwargs.pop("eps", 1e-3)
        self.steps: int = kwargs.pop("steps", 512)
        self.alg: str = kwargs.pop("alg", 'entropy')
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", 0.0)
        self.output_history: bool = kwargs.pop("output_history", False)
        self.mask_token_id = kwargs.pop("mask_token_id", None)
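
# A minimal construction sketch (the mask token id below is a placeholder, not
# a real tokenizer value; pass your tokenizer's actual mask token id):
#
#     config = MDMGenerationConfig(
#         max_new_tokens=128, steps=128, temperature=0.2,
#         alg="entropy", mask_token_id=32000,
#     )
#
# Keywords not consumed above (e.g. `max_new_tokens`) are handled by the base
# `GenerationConfig`.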


class MDMGenerationMixin:
    """
    Mixin class for Masked Diffusion Model generation, adapted from the Dream model's generation utils.
    """

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple[Optional[torch.LongTensor], Optional[torch.LongTensor]]:
        """Repeat each batch row `expand_size` times (for num_return_sequences > 1)."""
        if expand_size == 1:
            return input_ids, attention_mask

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
        return input_ids, attention_mask

    def _prepare_generation_config(
        self, generation_config: Optional[GenerationConfig], **kwargs
    ) -> MDMGenerationConfig:
        if generation_config is None:
            generation_config = self.generation_config

        # Promote a plain GenerationConfig, and work on a copy so that
        # `update(**kwargs)` does not mutate the model's default config.
        if not isinstance(generation_config, MDMGenerationConfig):
            generation_config = MDMGenerationConfig.from_dict(generation_config.to_dict())
        else:
            generation_config = copy.deepcopy(generation_config)

        generation_config.update(**kwargs)
        return generation_config

    @torch.no_grad()
    def diffusion_generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[MDMGenerationConfig] = None,
        **kwargs,
    ) -> Union[MDMModelOutput, torch.LongTensor]:
        generation_config = self._prepare_generation_config(generation_config, **kwargs)

        input_ids = inputs
        attention_mask = kwargs.get("attention_mask", None)

        if input_ids is None:
            raise ValueError("`inputs` must be provided for diffusion generation.")

        if generation_config.max_new_tokens is not None:
            generation_config.max_length = input_ids.shape[-1] + generation_config.max_new_tokens

        input_ids, attention_mask = self._expand_inputs_for_generation(
            expand_size=generation_config.num_return_sequences,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        return self._sample(
            input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
        )

    def _sample(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor],
        generation_config: MDMGenerationConfig,
    ) -> Union[MDMModelOutput, torch.LongTensor]:
        max_length = generation_config.max_length
        mask_token_id = generation_config.mask_token_id
        if mask_token_id is None:
            raise ValueError("`mask_token_id` must be set in the generation config.")

        steps = generation_config.steps
        eps = generation_config.eps
        alg = generation_config.alg
        alg_temp = generation_config.alg_temp
        temperature = generation_config.temperature
        top_p = generation_config.top_p
        top_k = generation_config.top_k

        histories = [] if generation_config.output_history else None

        # Extend the prompt to `max_length` with mask tokens; the loop below
        # progressively replaces them with sampled tokens.
        x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)

        # Mask tokens must be attended to, so the attention mask is recomputed
        # over the full padded sequence instead of reusing the prompt's mask.
        gen_attention_mask = (x != self.config.pad_token_id).long() if self.config.pad_token_id is not None else None

        # Diffusion time runs from t=1 (fully masked) down to t=eps.
        timesteps = torch.linspace(1, eps, steps + 1, device=x.device)

        for i in range(steps):
            mask_index = (x == mask_token_id)
            if not mask_index.any():
                break

            outputs = self(input_ids=x, attention_mask=gen_attention_mask, is_causal=False)
            logits = outputs.logits

            # Shift logits one position right so that position j holds the
            # prediction for token j (the model predicts the next token).
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

            mask_logits = logits[mask_index]
            t = timesteps[i]
            s = timesteps[i + 1]

            if alg == 'origin':
                # Vanilla reverse process: each masked token is independently
                # unmasked with probability 1 - s/t (everything on the last step).
                p_transfer = 1 - s / t if i < steps - 1 else 1
                x0 = torch.full_like(x[mask_index], fill_value=mask_token_id)
                transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
                _, sampled_tokens = sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
                x0[transfer_index_t_s] = sampled_tokens
                x[mask_index] = x0
            else:
                # Confidence-based unmasking. 'maskgit_plus' scores positions by
                # the sampled token's probability, 'topk_margin' by the top1-top2
                # margin, and 'entropy' by negative entropy; the two flags are
                # mutually exclusive.
                is_margin_conf = alg == 'topk_margin'
                is_neg_entropy = alg == 'entropy'

                confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=is_margin_conf, neg_entropy=is_neg_entropy)

                num_masked = mask_index.sum(dim=-1, keepdim=True)
                # Reveal a 1 - s/t fraction of the masked tokens this step; on
                # the final step, reveal everything that is still masked.
                gamma = 1 - s / t
                num_to_unmask = (num_masked * gamma).long() if i < steps - 1 else num_masked

                # Confidence over the full sequence, with -inf at positions that
                # are already unmasked so they can never be selected.
                full_confidence = torch.full_like(x, -torch.inf, dtype=confidence.dtype)
                full_confidence[mask_index] = confidence

                k_max = int(num_to_unmask.max())
                if k_max > 0:
                    if alg_temp is not None and alg_temp > 0:
                        # Stochastic selection: sample positions in proportion to
                        # softmax(confidence / alg_temp).
                        unmask_probs = F.softmax(full_confidence / alg_temp, dim=-1)
                        # Rows with no masked tokens softmax to NaN; give them a
                        # dummy uniform distribution (their num_to_unmask is 0,
                        # so nothing is selected from them anyway).
                        unmask_probs = torch.where(torch.isnan(unmask_probs), torch.ones_like(unmask_probs), unmask_probs)
                        unmask_indices = torch.multinomial(unmask_probs, num_samples=k_max, replacement=False)
                    else:
                        # Greedy selection: take the highest-confidence positions.
                        _, unmask_indices = torch.topk(full_confidence, k=k_max, dim=-1)

                    # Rows may need fewer than k_max positions: keep only each
                    # row's first num_to_unmask columns, which are the preferred
                    # ones since topk (and sequential multinomial draws without
                    # replacement) return indices in selection order.
                    col_mask = torch.arange(k_max, device=x.device).unsqueeze(0) < num_to_unmask
                    unmask_selection_mask = torch.zeros_like(x, dtype=torch.bool)
                    unmask_selection_mask.scatter_(-1, unmask_indices, col_mask)

                    # Commit the sampled tokens at the selected positions.
                    x_unmasked_proposals = torch.full_like(x, fill_value=mask_token_id)
                    x_unmasked_proposals[mask_index] = x0
                    x[unmask_selection_mask] = x_unmasked_proposals[unmask_selection_mask]

            if histories is not None:
                histories.append(x.clone())

        if generation_config.return_dict_in_generate:
            return MDMModelOutput(sequences=x, history=histories)
        else:
            return x
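
# Minimal usage sketch (assumptions: `MyMDMModel` is a hypothetical
# PreTrainedModel subclass that also inherits MDMGenerationMixin and accepts
# `is_causal=False` in its forward; the tokenizer defines a mask token):
#
#     model = MyMDMModel.from_pretrained("path/to/checkpoint").eval()
#     tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")
#     batch = tokenizer("The capital of France is", return_tensors="pt")
#     out = model.diffusion_generate(
#         batch.input_ids,
#         attention_mask=batch.attention_mask,
#         max_new_tokens=64,
#         steps=64,
#         temperature=0.2,
#         alg="entropy",
#         mask_token_id=tokenizer.mask_token_id,
#         return_dict_in_generate=True,
#         output_history=True,
#     )
#     print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))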