import torch
import torch.nn as nn
from typing import List, Optional
from transformers import (
    AutoConfig,
    AutoModel,
    PreTrainedModel,
    PretrainedConfig,
)


class RQAModelConfig(PretrainedConfig):
    """Configuration for the RQA v2.2 multi-head classifier.

    Stores the base encoder name plus per-head projection sizes, dropout
    rates, and the calibration temperatures/thresholds used at inference.
    """

    model_type = "rqa_v2_2"

    def __init__(
        self,
        base_model_name: str = "FacebookAI/xlm-roberta-large",
        encoder_config: Optional[dict] = None,
        error_types: Optional[List[str]] = None,
        schema_version: str = "rqa.v2.2",
        has_issue_projection_dim: int = 256,
        hidden_projection_dim: int = 256,
        errors_projection_dim: int = 512,
        has_issue_dropout: float = 0.25,
        hidden_dropout: float = 0.25,
        errors_dropout: float = 0.3,
        temperature_has_issue: float = 1.0,
        temperature_is_hidden: float = 1.0,
        temperature_errors: Optional[List[float]] = None,
        threshold_has_issue: float = 0.5,
        threshold_is_hidden: float = 0.5,
        threshold_error: float = 0.5,
        threshold_errors: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        # Reserved for a serialized encoder config; currently unused when
        # building the encoder (it is always loaded from base_model_name).
        self.encoder_config = encoder_config
        self.error_types = error_types or [
            "false_causality",
            "unsupported_claim",
            "overgeneralization",
            "missing_premise",
            "contradiction",
            "circular_reasoning",
        ]
        self.num_error_types = len(self.error_types)
        self.schema_version = schema_version
        self.has_issue_projection_dim = has_issue_projection_dim
        self.hidden_projection_dim = hidden_projection_dim
        self.errors_projection_dim = errors_projection_dim
        self.has_issue_dropout = has_issue_dropout
        self.hidden_dropout = hidden_dropout
        self.errors_dropout = errors_dropout
        # Calibration parameters: temperatures rescale logits, thresholds
        # binarize the resulting probabilities. The per-error-type lists
        # fall back to uniform defaults when not provided.
        self.temperature_has_issue = float(temperature_has_issue)
        self.temperature_is_hidden = float(temperature_is_hidden)
        self.temperature_errors = (
            temperature_errors
            if temperature_errors is not None
            else [1.0] * self.num_error_types
        )
        self.threshold_has_issue = float(threshold_has_issue)
        self.threshold_is_hidden = float(threshold_is_hidden)
        self.threshold_error = float(threshold_error)
        self.threshold_errors = (
            threshold_errors
            if threshold_errors is not None
            else [float(threshold_error)] * self.num_error_types
        )
        # Force the eager experts implementation; guarded because some
        # transformers versions expose these as managed attributes.
        try:
            self._experts_implementation = "eager"
            self._experts_implementation_internal = "eager"
        except Exception:
            pass


class MeanPooling(nn.Module):
    """Mask-aware mean pooling over the token dimension."""

    def forward(self, last_hidden_state, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(last_hidden_state * mask, dim=1)
        # Clamp avoids division by zero for fully padded sequences.
        denom = torch.clamp(mask.sum(dim=1), min=1e-9)
        return summed / denom


class RQAModelHF(PreTrainedModel):
    """Shared encoder with three heads: has_issue, is_hidden, per-type errors."""

    config_class = RQAModelConfig
    _supports_grouped_mm = False

    def __init__(self, config: RQAModelConfig):
        super().__init__(config)
        try:
            config._experts_implementation = "eager"
            config._experts_implementation_internal = "eager"
        except Exception:
            pass
        # Note: this downloads the base encoder weights even when the model
        # is later restored from a fine-tuned checkpoint (in that case they
        # are immediately overwritten by from_pretrained).
        self.encoder = AutoModel.from_pretrained(config.base_model_name)
        hidden_size = self.encoder.config.hidden_size
        self.pooler = MeanPooling()
        # One projection block per head: Linear -> LayerNorm -> GELU -> Dropout.
        self.has_issue_projection = nn.Sequential(
            nn.Linear(hidden_size, config.has_issue_projection_dim),
            nn.LayerNorm(config.has_issue_projection_dim),
            nn.GELU(),
            nn.Dropout(config.has_issue_dropout),
        )
        self.hidden_projection = nn.Sequential(
            nn.Linear(hidden_size, config.hidden_projection_dim),
            nn.LayerNorm(config.hidden_projection_dim),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout),
        )
        self.errors_projection = nn.Sequential(
            nn.Linear(hidden_size, config.errors_projection_dim),
            nn.LayerNorm(config.errors_projection_dim),
            nn.GELU(),
            nn.Dropout(config.errors_dropout),
        )
        self.has_issue_head = nn.Linear(config.has_issue_projection_dim, 1)
        self.is_hidden_head = nn.Linear(config.hidden_projection_dim, 1)
        self.errors_head = nn.Linear(
            config.errors_projection_dim,
            config.num_error_types,
        )
        # Learnable log-variances for uncertainty-weighted multi-task loss.
        self.log_var_has_issue = nn.Parameter(torch.zeros(1))
        self.log_var_is_hidden = nn.Parameter(torch.zeros(1))
        self.log_var_errors = nn.Parameter(torch.zeros(1))
        self._init_custom_weights()

    def _init_custom_weights(self):
        # Xavier-init only the custom layers; encoder weights stay pretrained.
        for module in [
            self.has_issue_projection[0],
            self.hidden_projection[0],
            self.errors_projection[0],
            self.has_issue_head,
            self.is_hidden_head,
            self.errors_head,
        ]:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # Fall back to an all-ones mask so the pooler never receives None.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        pooled = self.pooler(outputs.last_hidden_state, attention_mask)
        has_issue_logits = self.has_issue_head(
            self.has_issue_projection(pooled)
        ).squeeze(-1)
        is_hidden_logits = self.is_hidden_head(
            self.hidden_projection(pooled)
        ).squeeze(-1)
        errors_logits = self.errors_head(self.errors_projection(pooled))
        return {
            "has_issue_logits": has_issue_logits,
            "is_hidden_logits": is_hidden_logits,
            "errors_logits": errors_logits,
        }


# Register the custom architecture so AutoConfig/AutoModel can resolve it
# by its model_type without trust_remote_code.
AutoConfig.register("rqa_v2_2", RQAModelConfig)
AutoModel.register(RQAModelConfig, RQAModelHF)
print("✅ RQA-R2 registered in Transformers")
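

# The log_var_* parameters above suggest homoscedastic uncertainty weighting
# (Kendall et al., 2018). The training loss is not part of this file, so the
# following is a minimal sketch under that assumption: each per-head BCE loss
# is scaled by exp(-log_var), with log_var itself added as a regularizer. The
# function name and the labels dict layout are illustrative, not original.
def multitask_loss_sketch(model: RQAModelHF, outputs: dict, labels: dict) -> torch.Tensor:
    bce = nn.functional.binary_cross_entropy_with_logits
    l_has_issue = bce(outputs["has_issue_logits"], labels["has_issue"].float())
    l_is_hidden = bce(outputs["is_hidden_logits"], labels["is_hidden"].float())
    l_errors = bce(outputs["errors_logits"], labels["errors"].float())
    total = (
        torch.exp(-model.log_var_has_issue) * l_has_issue + model.log_var_has_issue
        + torch.exp(-model.log_var_is_hidden) * l_is_hidden + model.log_var_is_hidden
        + torch.exp(-model.log_var_errors) * l_errors + model.log_var_errors
    )
    return total.squeeze()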
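

# forward() returns raw logits, while the config carries calibration
# temperatures and per-type thresholds, so a decoding step is implied but not
# included here. This sketch (an assumption, not the original pipeline) shows
# how those fields are presumably meant to be consumed at inference time.
def decode_outputs_sketch(outputs: dict, config: RQAModelConfig) -> dict:
    # Temperature-scale each head's logits, apply a sigmoid, then binarize
    # against the thresholds stored in the config.
    p_has_issue = torch.sigmoid(outputs["has_issue_logits"] / config.temperature_has_issue)
    p_is_hidden = torch.sigmoid(outputs["is_hidden_logits"] / config.temperature_is_hidden)
    temps = torch.tensor(config.temperature_errors)        # (num_error_types,)
    thresholds = torch.tensor(config.threshold_errors)     # (num_error_types,)
    p_errors = torch.sigmoid(outputs["errors_logits"] / temps)  # broadcasts over the batch
    return {
        "has_issue": p_has_issue > config.threshold_has_issue,
        "is_hidden": p_is_hidden > config.threshold_is_hidden,
        "errors": p_errors > thresholds,  # bool matrix, one column per error type
    }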
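

# Smoke test (left commented out because instantiating RQAModelHF downloads
# the xlm-roberta-large encoder weights). Once registered, the model should
# round-trip through save_pretrained/from_pretrained like a built-in one;
# the checkpoint path is an arbitrary example:
#
#     model = RQAModelHF(RQAModelConfig())
#     model.save_pretrained("rqa_v2_2_checkpoint")
#     reloaded = AutoModel.from_pretrained("rqa_v2_2_checkpoint")
#     assert isinstance(reloaded, RQAModelHF)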