|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
import os
from typing import List, Optional, Union

import numpy as np
import torch
from transformers import AutoTokenizer, is_torch_npu_available

from .modeling import CrossEncoder

os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
|
|
|
|
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of *x*.

    Args:
        x: A number or NumPy array; anything supporting unary minus and
            ``np.exp``.

    Returns:
        The element-wise sigmoid, with the same shape as *x*.
    """
    # For large negative x, exp(-x) overflows to inf and the quotient becomes
    # 0.0, which is the correct limit. Only the floating-point overflow
    # warning/error is suppressed; the computed values are unchanged.
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-x))
|
|
|
|
|
class ListConRanker:
    """Inference wrapper that scores passages against a query.

    The constructor loads a tokenizer and a ``CrossEncoder`` model, selects a
    device and puts the model in eval mode. ``compute_score`` scores whole
    groups of ``[query, passage_1, ..., passage_n]`` in a single forward
    pass; ``iterative_inference`` scores one group in several rounds,
    dropping the lowest-scored passages after each round.
    """

    # ``iterative_inference`` keeps filtering while more passages than this remain.
    _MAX_FINAL_PASSAGES = 20
    # Fraction of the remaining passages dropped per round ...
    _FILTER_RATIO = 0.2
    # ... but never fewer than this many per round.
    _MIN_FILTER_NUM = 10

    def __init__(
        self,
        model_name_or_path: Optional[str] = None,
        use_fp16: bool = False,
        cache_dir: Optional[str] = None,
        device: Optional[Union[str, int]] = None,
        list_transformer_layer=None,
    ) -> None:
        """Load tokenizer and model and move the model to the chosen device.

        Args:
            model_name_or_path: Passed to ``AutoTokenizer.from_pretrained``
                and ``CrossEncoder.from_pretrained_for_eval``.
            use_fp16: Convert the model to half precision. Forced to
                ``False`` when ``device == 'cpu'`` or when device
                auto-detection falls back to CPU.
            cache_dir: Forwarded to the tokenizer loader only.
            device: A non-empty string is handed to ``torch.device`` as is.
                Otherwise the device is auto-detected in the order CUDA,
                MPS, NPU, CPU; an integer is used as the CUDA index when
                CUDA is available.
            list_transformer_layer: Forwarded unchanged to
                ``CrossEncoder.from_pretrained_for_eval``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        self.model = CrossEncoder.from_pretrained_for_eval(model_name_or_path, list_transformer_layer)

        if device and isinstance(device, str):
            self.device = torch.device(device)
            if device == 'cpu':
                use_fp16 = False
        elif torch.cuda.is_available():
            if device is not None:
                self.device = torch.device(f"cuda:{device}")
            else:
                self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            self.device = torch.device("cpu")
            use_fp16 = False

        if use_fp16:
            self.model.half()

        self.model = self.model.to(self.device)
        self.model.eval()

        if device is None:
            # Only an auto-detected device may fan out over every visible GPU.
            self.num_gpus = torch.cuda.device_count()
            if self.num_gpus > 1:
                print(f"----------using {self.num_gpus}*GPUs----------")
                self.model = torch.nn.DataParallel(self.model)
        else:
            self.num_gpus = 1

    @torch.no_grad()
    def compute_score(self, sentence_pairs: List[List[str]], max_length: int = 512) -> List[List[float]]:
        """Score every group of ``[query, passage_1, ..., passage_n]``.

        All texts of all groups are tokenized as one flat batch; the number
        of passages per group is handed to the model as ``inputs['pair_num']``.

        Args:
            sentence_pairs: One list per group; the first element is counted
                as the query and the rest as its passages.
            max_length: Truncation length passed to the tokenizer.

        Returns:
            One list of scores per group, obtained by slicing the model
            output into consecutive chunks of each group's passage count.
            If the model output converts to a bare float, ``[that_float]``
            is returned instead.
        """
        pair_nums = [len(pairs) - 1 for pairs in sentence_pairs]
        sentences_batch = [text for pairs in sentence_pairs for text in pairs]
        inputs = self.tokenizer(
            sentences_batch,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
        ).to(self.device)
        # Added after the ``.to(...)`` call above, so this tensor is not moved
        # to ``self.device`` here.
        inputs['pair_num'] = torch.LongTensor(pair_nums)
        scores = self.model(inputs).float()
        all_scores = scores.cpu().numpy().tolist()

        if isinstance(all_scores, float):
            return [all_scores]

        result = []
        start = 0
        for count in pair_nums:
            result.append(all_scores[start:start + count])
            start += count
        return result

    def _score_group(self, query: str, passages: List[str], max_length: int) -> List[float]:
        """Score one query/passages group and always return a list of scores."""
        scores = self.compute_score([[query] + passages], max_length)[0]
        # ``compute_score`` returns ``[float]`` for a scalar model output, so
        # ``[0]`` may be a bare float rather than a list.
        if isinstance(scores, float):
            scores = [scores]
        return scores

    @torch.no_grad()
    def iterative_inference(self, sentence_pairs: List[str], max_length: int = 512) -> List[float]:
        """Score passages in rounds, dropping the lowest-scored ones each round.

        While more than 20 passages remain, the remaining passages are
        scored together and the lowest ``max(ceil(20%), 10)`` of them (plus
        any passages tied with the last dropped one) receive their final
        score: the round's score plus the number of completed rounds. The
        survivors are re-scored in ascending-score order in the next round.
        Passages that survive every round are scored once more and get that
        score plus the total number of rounds.

        Args:
            sentence_pairs: ``[query, passage_1, ..., passage_n]``.
            max_length: Truncation length forwarded to ``compute_score``.

        Returns:
            Exactly one score per passage, in the input passage order.
            Duplicate passage texts are scored independently. An input with
            no passages yields ``[]`` without calling the model.

        Raises:
            IndexError: If ``sentence_pairs`` is empty (no query).
        """
        query = sentence_pairs[0]
        passages = sentence_pairs[1:]

        # Track passages by their original position rather than by text so
        # that duplicate texts each receive exactly one score.
        final_scores = [None] * len(passages)
        remaining = list(range(len(passages)))
        filter_times = 0

        while len(remaining) > self._MAX_FINAL_PASSAGES:
            pred_scores = self._score_group(query, [passages[i] for i in remaining], max_length)
            order = np.argsort(pred_scores).tolist()

            total = len(remaining)
            cut = max(math.ceil(total * self._FILTER_RATIO), self._MIN_FILTER_NUM)
            # Also drop passages tied with the last dropped one. The bound
            # check prevents running off the end when the tie extends to the
            # highest-scored passage (e.g. all scores are equal).
            while cut < total and pred_scores[order[cut - 1]] == pred_scores[order[cut]]:
                cut += 1

            for pos in order[:cut]:
                final_scores[remaining[pos]] = pred_scores[pos] + filter_times
            # Survivors continue in ascending-score order.
            remaining = [remaining[pos] for pos in order[cut:]]
            filter_times += 1

        if remaining:
            pred_scores = self._score_group(query, [passages[i] for i in remaining], max_length)
            for pos, original_idx in enumerate(remaining):
                final_scores[original_idx] = pred_scores[pos] + filter_times

        return final_scores
|
|
|