| |
|
|
| import torch |
| import torch.nn as nn |
| from typing import List, Optional |
| from transformers import ( |
| AutoModel, |
| PreTrainedModel, |
| PretrainedConfig, |
| AutoConfig, |
| AutoModel, |
| ) |
|
|
| |
| |
| |
|
|
class RQAModelConfig(PretrainedConfig):
    """Configuration for the RQA model (an encoder plus two output heads).

    Args:
        base_model_name: Name or path of the base encoder checkpoint.
        num_error_types: Number of outputs of the error-type head.
        has_issue_projection_dim: Width of the projection that feeds the
            "has issue" head.
        errors_projection_dim: Width of the projection that feeds the
            error-type head.
        has_issue_dropout: Dropout probability in the "has issue" projection.
        errors_dropout: Dropout probability in the error-type projection.
        temperature_has_issue: Temperature stored for the "has issue" logit.
            It is only stored here; presumably applied by calibration or
            inference code elsewhere -- TODO confirm.
        temperature_errors: One temperature per error type. When ``None``,
            a list of ``num_error_types`` ones is used.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "rqa"

    def __init__(
        self,
        base_model_name: str = "FacebookAI/xlm-roberta-large",
        num_error_types: int = 6,
        has_issue_projection_dim: int = 256,
        errors_projection_dim: int = 512,
        has_issue_dropout: float = 0.25,
        errors_dropout: float = 0.3,
        temperature_has_issue: float = 1.0,
        temperature_errors: Optional[List[float]] = None,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Fall back to neutral temperatures (1.0) when none were supplied.
        if temperature_errors is None:
            temperature_errors = [1.0] * num_error_types

        # Architecture hyper-parameters.
        self.base_model_name = base_model_name
        self.num_error_types = num_error_types
        self.has_issue_projection_dim = has_issue_projection_dim
        self.errors_projection_dim = errors_projection_dim
        self.has_issue_dropout = has_issue_dropout
        self.errors_dropout = errors_dropout

        # Temperatures are stored on the config so they are saved with it.
        self.temperature_has_issue = temperature_has_issue
        self.temperature_errors = temperature_errors
|
|
| |
| |
| |
|
|
class MeanPooling(nn.Module):
    """Average token embeddings over the positions kept by the attention mask."""

    def forward(self, last_hidden_state, attention_mask=None):
        """Return the masked mean of ``last_hidden_state`` along dimension 1.

        Args:
            last_hidden_state: Token embeddings; dimension 1 is the one that
                gets averaged (expected layout: batch, sequence, hidden).
            attention_mask: Tensor with the same leading two dimensions,
                non-zero where a token should count. ``None`` means every
                position counts.

        Returns:
            Tensor with dimension 1 reduced away, in the same dtype as
            ``last_hidden_state``. Rows whose mask is all zeros yield zeros.
        """
        if attention_mask is None:
            # No mask supplied: every position contributes equally.
            return last_hidden_state.mean(dim=1)

        # The float32 mask promotes half-precision inputs, so the sum is
        # accumulated in at least float32.
        mask = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(last_hidden_state * mask, dim=1)
        # Clamp so that a fully masked row divides by a tiny number, not zero.
        denom = torch.clamp(mask.sum(dim=1), min=1e-9)
        # Cast back so fp16/bf16 models get an output matching their layers'
        # dtype instead of a silently promoted float32 tensor.
        return (summed / denom).to(last_hidden_state.dtype)
|
|
| |
| |
| |
|
|
class RQAModelHF(PreTrainedModel):
    """Encoder with mean pooling and two heads.

    The pooled encoder output is fed through two independent projection
    blocks (Linear -> LayerNorm -> GELU -> Dropout). One feeds a single-logit
    "has issue" head, the other a head with one logit per error type.
    """

    config_class = RQAModelConfig

    def __init__(self, config: RQAModelConfig):
        """Build the encoder, pooling layer, projections and heads.

        Args:
            config: Model configuration; supplies the base checkpoint name,
                projection widths, dropout rates and number of error types.
        """
        super().__init__(config)

        # NOTE(review): this loads the base checkpoint's weights on every
        # construction, including when this model is itself being restored
        # from a saved checkpoint. Kept as-is because callers may rely on a
        # freshly constructed model starting from the pretrained encoder.
        self.encoder = AutoModel.from_pretrained(config.base_model_name)
        hidden_size = self.encoder.config.hidden_size

        self.pooler = MeanPooling()

        self.has_issue_projection = nn.Sequential(
            nn.Linear(hidden_size, config.has_issue_projection_dim),
            nn.LayerNorm(config.has_issue_projection_dim),
            nn.GELU(),
            nn.Dropout(config.has_issue_dropout),
        )

        self.errors_projection = nn.Sequential(
            nn.Linear(hidden_size, config.errors_projection_dim),
            nn.LayerNorm(config.errors_projection_dim),
            nn.GELU(),
            nn.Dropout(config.errors_dropout),
        )

        self.has_issue_head = nn.Linear(config.has_issue_projection_dim, 1)
        self.errors_head = nn.Linear(
            config.errors_projection_dim, config.num_error_types
        )

        self._init_custom_weights()

    def _init_custom_weights(self):
        """Xavier-initialise the weights and zero the biases of the new Linear layers.

        Only the layers added on top of the encoder are touched; the encoder
        keeps the weights it was loaded with.
        """
        for module in [
            self.has_issue_projection[0],
            self.errors_projection[0],
            self.has_issue_head,
            self.errors_head,
        ]:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Encode the inputs and return the logits of both heads.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Mask passed to the encoder and used for pooling.
                When ``None``, all positions are pooled.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Dict with ``"has_issue_logits"`` (the single head output with its
            trailing size-1 dimension squeezed away) and ``"errors_logits"``
            (one logit per error type).
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden = outputs.last_hidden_state

        # The encoder accepts a missing mask, but pooling needs a tensor:
        # without one, treat every position as a real token.
        if attention_mask is None:
            attention_mask = torch.ones(
                hidden.shape[:2], dtype=torch.long, device=hidden.device
            )

        pooled = self.pooler(hidden, attention_mask)

        has_issue_logits = self.has_issue_head(
            self.has_issue_projection(pooled)
        ).squeeze(-1)

        errors_logits = self.errors_head(
            self.errors_projection(pooled)
        )

        return {
            "has_issue_logits": has_issue_logits,
            "errors_logits": errors_logits,
        }
|
|
| |
| |
| |
|
|
# Register the custom config/model pair with the Auto* factories. The model
# type key is read from the config class so it is spelled in one place only.
AutoConfig.register(RQAModelConfig.model_type, RQAModelConfig)
AutoModel.register(RQAModelConfig, RQAModelHF)

# Confirmation message printed once both registrations have succeeded.
print("✅ RQA зарегистрирован в Transformers")