| from typing import Dict, List, Union |
| from transformers import BertPreTrainedModel, BertModel,PreTrainedTokenizer |
| import torch.nn as nn |
| import torch |
class BertForStorySkillClassification(BertPreTrainedModel):
    """BERT encoder with a linear classification head on the first token.

    The hidden state at sequence position 0 (the ``[CLS]`` slot for standard
    BERT tokenizers) is projected to ``config.num_labels`` logits.
    """

    def __init__(self, config):
        """Build the encoder and the classification head from ``config``.

        Args:
            config: Model configuration; ``num_labels`` and ``hidden_size``
                are read here.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        # Transformers hook that finalizes construction of the model.
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the encoder and the classification head.

        Args:
            input_ids: Token id tensor passed to the BERT encoder.
            attention_mask: Optional attention mask passed to the encoder.
            labels: Optional class-index tensor. When given, the
                cross-entropy loss is returned instead of the logits.
            **kwargs: Accepted and ignored, so tokenizer outputs such as
                ``token_type_ids`` can be splatted into the call.

        Returns:
            The scalar cross-entropy loss when ``labels`` is provided,
            otherwise the logits tensor whose last dimension is
            ``num_labels``.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # Use the first position of the final layer as the sequence summary.
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_hidden_state)
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return logits

    def predict(
        self,
        texts: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        batch_size: int = 32,
        return_probabilities: bool = False,
        device: Union[str, torch.device, None] = None,
        max_length: int = 512,
    ) -> List[Dict]:
        """Classify one text or a list of texts.

        Args:
            texts: A single string or a list of strings, e.g.
                ``"Who is the character in the story?"``.
            tokenizer: Tokenizer instance compatible with this model.
            batch_size: Number of texts tokenized and scored per forward
                pass. Must be at least 1.
            return_probabilities: When True, each result also carries the
                softmax probability of the predicted class under ``"score"``.
            device: Device the tokenized inputs are moved to (e.g. ``"cuda"``
                or ``"cpu"``). Defaults to the model's current device. The
                model itself is not moved.
            max_length: Maximum tokenized sequence length; longer inputs are
                truncated.

        Returns:
            One dict per input text, in input order:
            ``{"text": ..., "label": ...}``, plus ``"score"`` when
            ``return_probabilities`` is True.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        if device is None:
            device = self.device

        if isinstance(texts, str):
            texts = [texts]

        predictions: List[Dict] = []

        # Dropout must be disabled for deterministic inference; remember the
        # previous mode so a predict() call in the middle of training does not
        # silently leave the model in eval mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for start in range(0, len(texts), batch_size):
                    batch_texts = texts[start : start + batch_size]

                    inputs = tokenizer(
                        batch_texts,
                        padding=True,
                        truncation=True,
                        return_tensors="pt",
                        max_length=max_length,
                    ).to(device)

                    logits = self(**inputs)
                    probs = torch.softmax(logits, dim=-1)
                    scores, class_ids = torch.max(probs, dim=-1)

                    # Convert once per batch instead of calling .item() per
                    # element inside the loop.
                    id2label = self.config.id2label
                    for text, class_id, score in zip(
                        batch_texts, class_ids.tolist(), scores.tolist()
                    ):
                        result = {"text": text, "label": id2label[class_id]}
                        if return_probabilities:
                            result["score"] = score
                        predictions.append(result)
        finally:
            self.train(was_training)

        return predictions