from collections import deque
from typing import Any, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from transformers import (
    GenerationConfig,
    GenerationMixin,
    LogitsProcessorList,
    StoppingCriteriaList,
)
from transformers.generation.logits_process import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from transformers.generation.utils import (
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
)


def _deepconf_generate(
    model: Any,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList],
    stopping_criteria: Optional[StoppingCriteriaList],
    generation_config: Optional[GenerationConfig],
    synced_gpus: bool = False,
    streamer: Optional[Any] = None,
    **model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
    """Custom decoding with DeepConf (confidence-based early stopping).

    Args:
        model: PreTrainedModel with a LM head.
        input_ids: Prompt ids of shape (batch, seq_len).
        logits_processor: Optional logits processors.
        stopping_criteria: Optional stopping criteria.
        generation_config: GenerationConfig controlling sampling/outputs.
        synced_gpus: Keep looping to max length for distributed setups.
        streamer: Optional streamer for incremental tokens.
        **model_kwargs: Forward-pass kwargs (e.g., attention_mask).

    Returns:
        GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, or LongTensor
        depending on `return_dict_in_generate` and model type.
    """
    enable_conf = getattr(generation_config, "enable_conf", False)
    enable_early_stopping = getattr(generation_config, "enable_early_stopping", True)
    window_size = getattr(generation_config, "window_size", 2048)
    threshold = getattr(generation_config, "threshold", 17.0)
    conf_topk = getattr(generation_config, "conf_topk", 20)

    if not enable_conf:
        return model._sample(
            input_ids,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )

    pad_token_id = generation_config._pad_token_tensor

    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    output_confidences = getattr(generation_config, "output_confidences", False)

    deepconf_variant = getattr(generation_config, "deepconf_variant", None)
    deepconf_eta = getattr(generation_config, "deepconf_eta", None)
    deepconf_warmup_confidences = getattr(
        generation_config, "deepconf_warmup_confidences", None
    )
    has_eos_stopping_criteria = any(
        hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
    )
    do_sample = generation_config.do_sample
    # If no explicit threshold is configured, calibrate one from warm-up trace
    # confidences: keep roughly the top-eta fraction of traces by taking the
    # (100 - eta * 100)th percentile as the stopping threshold.
    if (
        threshold is None
        and deepconf_variant is not None
        and deepconf_warmup_confidences is not None
    ):
        confs = deepconf_warmup_confidences
        if hasattr(confs, "detach"):
            confs = confs.detach().cpu().numpy()
        elif isinstance(confs, torch.Tensor):
            confs = confs.cpu().numpy()
        confs = np.asarray(confs, dtype=np.float32).ravel()
        eta = deepconf_eta
        if eta is None:
            eta = 0.1 if deepconf_variant == "low" else 0.9 if deepconf_variant == "high" else 0.5
        pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))
        threshold = float(np.percentile(confs, pct))
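    # For example: deepconf_variant="low" defaults to eta=0.1, so pct=90.0 and
    # the threshold becomes the 90th percentile of the warm-up confidences,
    # stopping all but the most confident ~10% of traces; "high" gives eta=0.9
    # and the 10th percentile, which stops far fewer traces.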

    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = (
        () if (return_dict_in_generate and output_hidden_states) else None
    )

    if return_dict_in_generate and model.config.is_encoder_decoder:
        encoder_attentions = (
            model_kwargs["encoder_outputs"].get("attentions")
            if output_attentions
            else None
        )
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states")
            if output_hidden_states
            else None
        )

    batch_size, cur_len = input_ids.shape[:2]
    unfinished_sequences = torch.ones(
        batch_size, dtype=torch.long, device=input_ids.device
    )

    conf_group_lists = [deque(maxlen=window_size) for _ in range(batch_size)]
    conf_grouped_sums = [0.0 for _ in range(batch_size)]
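    # Each sequence keeps a sliding window (deque with maxlen=window_size) of
    # per-step confidences plus a running sum, so the windowed mean used for
    # early stopping is an O(1) update per generated token.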

    step_confidences = [] if (return_dict_in_generate and output_confidences) else None

    steps = 0
    max_new_tokens = getattr(generation_config, "max_new_tokens", None) or 512

    model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
    while steps < max_new_tokens and unfinished_sequences.max() != 0:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

        model_inputs.update(
            {"output_attentions": output_attentions} if output_attentions else {}
        )
        model_inputs.update(
            {"output_hidden_states": output_hidden_states}
            if output_hidden_states
            else {}
        )

        with torch.no_grad():
            outputs = model(**model_inputs, return_dict=True)
        next_token_logits = outputs.logits[:, -1, :].detach()

        if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None:
            model_kwargs["past_key_values"] = outputs.past_key_values

        next_token_scores = logits_processor(input_ids, next_token_logits)
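        # Re-apply temperature / top-k / top-p from the generation config
        # before sampling, using the stock Hugging Face warpers.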
        warpers = LogitsProcessorList()

        temperature = getattr(generation_config, "temperature", 1.0)
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(temperature))

        top_k = getattr(generation_config, "top_k", None)
        if top_k is not None and isinstance(top_k, int) and top_k > 0:
            warpers.append(TopKLogitsWarper(top_k))

        top_p = getattr(generation_config, "top_p", None)
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p))
        if len(warpers) > 0:
            next_token_scores = warpers(input_ids, next_token_scores)

        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,)
                    if model.config.is_encoder_decoder
                    else (outputs.attentions,)
                )
                if model.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if model.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        if do_sample:
            probs = F.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # Confidence is computed from the raw model distribution (before
        # processors and warpers), not from the sampling distribution.
        probs = F.softmax(next_token_logits, dim=-1)

        deepconf_stopping = torch.ones(
            batch_size, dtype=torch.bool, device=input_ids.device
        )
        step_conf_values = [0.0] * batch_size
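        # Per-sequence DeepConf signal: the step confidence is the negative mean
        # log-probability over the top `conf_topk` candidate tokens, and the
        # windowed mean of these values drives early stopping below `threshold`.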
        for i in range(batch_size):
            if not unfinished_sequences[i]:
                continue

            top_probs, _ = torch.topk(probs[i], k=conf_topk, dim=-1)

            eps = (
                torch.finfo(top_probs.dtype).eps
                if top_probs.dtype == torch.float32
                else 1e-7
            )
            top_probs = torch.clamp(top_probs, min=eps)
            log_probs = torch.log(top_probs)

            conf = -log_probs.mean().item()

            if len(conf_group_lists[i]) >= window_size:
                conf_grouped_sums[i] -= conf_group_lists[i][0]
            conf_group_lists[i].append(conf)
            conf_grouped_sums[i] += conf

            if enable_early_stopping and len(conf_group_lists[i]) >= window_size:
                avg_conf = conf_grouped_sums[i] / len(conf_group_lists[i])
                if avg_conf < threshold:
                    deepconf_stopping[i] = False

            if step_confidences is not None:
                step_conf_values[i] = conf

        if step_confidences is not None:
            step_confidences.append(
                torch.tensor(step_conf_values, device=input_ids.device)
            )

        # Finished sequences keep emitting the pad token.
        if has_eos_stopping_criteria and pad_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                1 - unfinished_sequences
            )

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        if model_kwargs.get("attention_mask") is not None:
            attn = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [
                    attn,
                    torch.ones((batch_size, 1), dtype=attn.dtype, device=attn.device),
                ],
                dim=-1,
            )

        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        sc = stopping_criteria(input_ids, scores)
        if isinstance(sc, torch.Tensor):
            unfinished_sequences = unfinished_sequences & ~sc
        elif sc:
            unfinished_sequences = torch.zeros_like(unfinished_sequences)

        # Fold DeepConf early stopping into the standard stopping state.
        unfinished_sequences = unfinished_sequences & deepconf_stopping

        if unfinished_sequences.max() == 0 and not synced_gpus:
            break
        cur_len += 1
        steps += 1

        del outputs

    if streamer is not None:
        streamer.end()

    if not return_dict_in_generate:
        return input_ids

    confidences_tensor = None
    if step_confidences is not None and len(step_confidences) > 0:
        # (steps, batch) -> (batch, steps)
        confidences_tensor = torch.stack(step_confidences, dim=0).transpose(0, 1)

    if model.config.is_encoder_decoder:
        output = GenerateEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    else:
        output = GenerateDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )

    if confidences_tensor is not None:
        output["confidences"] = confidences_tensor
        try:
            setattr(output, "confidences", confidences_tensor)
        except Exception:
            pass
    return output


def generate(model, *args, **kwargs):
    """Custom `generate` entry point that routes decoding through DeepConf.

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        *args: Positional arguments forwarded to `GenerationMixin.generate`.
        **kwargs: Keyword arguments forwarded to `GenerationMixin.generate`,
            including the DeepConf-specific generation-config fields read above
            (`enable_conf`, `enable_early_stopping`, `window_size`, `threshold`,
            `conf_topk`, `output_confidences`, `deepconf_variant`,
            `deepconf_eta`, `deepconf_warmup_confidences`).
    """
    generation_outputs = GenerationMixin.generate(
        model, *args, custom_generate=_deepconf_generate, **kwargs
    )
    return generation_outputs
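

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). It assumes a transformers version whose
# `generate(custom_generate=...)` accepts a callable, and it uses a placeholder
# checkpoint name plus arbitrary DeepConf settings; adjust both for real runs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    lm = AutoModelForCausalLM.from_pretrained(model_name)

    prompt = tokenizer("The capital of France is", return_tensors="pt")

    # DeepConf knobs are read from the GenerationConfig via getattr(), so the
    # extra fields can simply be attached to a stock config object.
    gen_cfg = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        max_new_tokens=64,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    gen_cfg.enable_conf = True
    gen_cfg.enable_early_stopping = True
    gen_cfg.window_size = 16  # small window for a short demo
    gen_cfg.threshold = 1.0   # arbitrary demo value
    gen_cfg.conf_topk = 20
    gen_cfg.output_confidences = True

    out = generate(lm, **prompt, generation_config=gen_cfg)
    print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))
    if getattr(out, "confidences", None) is not None:
        print("per-step confidences:", out.confidences.shape)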