from typing import Optional, Union

import torch

from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaForCausalLM

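
# In transformers' modular-model format, inheriting from `LlamaForCausalLM` is
# enough for utils/modular_model_converter.py to copy the full Llama causal-LM
# implementation into the generated modeling file; only the overridden
# `forward` below differs from the parent.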
class SuperModelForCausalLM(LlamaForCausalLM):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, CausalLMOutputWithPast]:
        # Delegate to the stock Llama causal-LM forward pass. Pass everything
        # by keyword: the parent signature has additional parameters (e.g.
        # `labels` sits between `inputs_embeds` and `use_cache`), so positional
        # passing would mis-align the arguments.
        out = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        # Scale the logits by 2**4 = 16: the only behavioral change relative
        # to the parent. Note this assumes a dict-style output (the default);
        # with `return_dict=False` the parent returns a plain tuple.
        out.logits *= 2**4
        return out
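

# --- Usage sketch ------------------------------------------------------------
# A minimal smoke test, not part of the model definition. It assumes this file
# is importable on its own (hence the absolute imports above); inside a
# transformers source checkout, modular files are normally expanded by
# utils/modular_model_converter.py rather than run directly. The tiny
# LlamaConfig values below are arbitrary.
if __name__ == "__main__":
    from transformers.models.llama.configuration_llama import LlamaConfig

    config = LlamaConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = SuperModelForCausalLM(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids)
    print(out.logits.shape)  # torch.Size([1, 8, 128]); logits are 16x the base model's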
|
|