| from typing import Dict, Any, List |
| import torch |
| from transformers import AutoTokenizer, AutoModel |
| import os |
| import json |
|
|
class EndpointHandler:
    """Score how well candidate labels match one or more text queries.

    Every (query, candidate) pair is rendered as
    ``"[QUERY] {query} [LABEL_NAME] {label} [LABEL_DESCRIPTION] {description}"``.
    The text is encoded by the transformer loaded from ``path``. A
    single-output linear head is applied to the position-0 hidden state, and a
    sigmoid turns the result into a score in (0, 1).

    Request formats accepted by ``__call__`` (optionally nested under an
    ``"inputs"`` key):

    * single: ``{"query": str, "candidates": [{"label": ..., "description": ...}, ...]}``
      returns one list of result dicts sorted by descending score;
    * batch: ``{"queries": [str, ...], "candidates": [...]}``
      returns one such list per query, in query order.
    """

    def __init__(self, path: str = "", max_batch_size: int = 128, max_length: int = 64):
        """Load the tokenizer, encoder and classifier head from ``path``.

        Args:
            path: Directory holding the pretrained tokenizer/model and a
                ``classifier_head.json`` with keys ``scorer_weight`` and
                ``scorer_bias``.
            max_batch_size: Maximum number of texts encoded per forward pass.
            max_length: Token length every text is truncated/padded to.

        Raises:
            ValueError: If ``max_batch_size`` or ``max_length`` is not positive,
                or if the head does not supply ``hidden_size`` weights and
                exactly one bias value.
        """
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        self.max_batch_size = max_batch_size
        self.max_length = max_length

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE(review): if the tokenizer saved at `path` does not already contain
        # these markers, they receive new ids and the model's embedding matrix
        # would need `resize_token_embeddings`. Confirm the checkpoint was saved
        # with them.
        self.tokenizer.add_special_tokens({
            "additional_special_tokens": ["[QUERY]", "[LABEL_NAME]", "[LABEL_DESCRIPTION]"]
        })
        self.model = AutoModel.from_pretrained(path).to(self.device)
        self.model.eval()

        head_path = os.path.join(path, "classifier_head.json")
        with open(head_path, "r", encoding="utf-8") as f:
            head = json.load(f)

        hidden_size = self.model.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, 1).to(self.device)
        weight = torch.as_tensor(head["scorer_weight"], dtype=self.classifier.weight.dtype)
        bias = torch.as_tensor(head["scorer_bias"], dtype=self.classifier.bias.dtype)
        if weight.numel() != hidden_size or bias.numel() != 1:
            raise ValueError(
                f"classifier head expects {hidden_size} weights and 1 bias, "
                f"got {weight.numel()} and {bias.numel()}"
            )
        with torch.no_grad():
            # copy_ keeps the registered Parameters (and their device) intact;
            # reshape accepts either a flat or a nested [[...]] weight list.
            self.classifier.weight.copy_(weight.reshape(self.classifier.weight.shape))
            self.classifier.bias.copy_(bias.reshape(self.classifier.bias.shape))
        self.classifier.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Dispatch a request to the single-query or multi-query path.

        Returns a list of result dicts for a single query. For a ``"queries"``
        payload, returns a list of such lists, one per query.
        """
        payload = data.get("inputs", data)
        if "queries" in payload:
            return self._process_batch(payload)
        return self._process_single(payload)

    @staticmethod
    def _format_text(query: str, candidate: Dict[str, Any]) -> str:
        """Render one (query, candidate) pair into the model's input template."""
        return (
            f"[QUERY] {query} [LABEL_NAME] {candidate['label']} "
            f"[LABEL_DESCRIPTION] {candidate['description']}"
        )

    def _score_texts(self, texts: List[str]) -> List[float]:
        """Return one sigmoid score per text, encoding at most max_batch_size texts per pass."""
        scores: List[float] = []
        with torch.no_grad():
            for start in range(0, len(texts), self.max_batch_size):
                chunk = texts[start:start + self.max_batch_size]
                tokens = self.tokenizer(
                    chunk,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length,
                ).to(self.device)
                out = self.model(**tokens)
                cls = out.last_hidden_state[:, 0, :]
                # reshape(-1) always yields one score per row, even for a chunk of 1.
                chunk_scores = torch.sigmoid(self.classifier(cls)).reshape(-1)
                scores.extend(chunk_scores.cpu().tolist())
        return scores

    @staticmethod
    def _build_results(candidates: List[Dict[str, Any]], scores: List[float]) -> List[Dict[str, Any]]:
        """Pair candidates with their scores (rounded to 4 places), sorted best-first."""
        results = [
            {
                "label": candidate["label"],
                "description": candidate["description"],
                "score": round(score, 4),
            }
            for candidate, score in zip(candidates, scores)
        ]
        # Stable sort: equal scores keep the input candidate order.
        results.sort(key=lambda r: r["score"], reverse=True)
        return results

    def _process_single(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score every candidate against one query (original request format)."""
        query = payload["query"]
        candidates = payload["candidates"]
        texts = [self._format_text(query, candidate) for candidate in candidates]
        return self._build_results(candidates, self._score_texts(texts))

    def _process_batch(self, payload: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Score every candidate against every query, batching all pairs together."""
        queries = payload["queries"]
        candidates = payload["candidates"]

        # Pairs are laid out query-major, so query i owns scores [i*n, (i+1)*n).
        texts = [self._format_text(query, candidate) for query in queries for candidate in candidates]
        scores = self._score_texts(texts)

        n = len(candidates)
        return [
            self._build_results(candidates, scores[q_idx * n:(q_idx + 1) * n])
            for q_idx in range(len(queries))
        ]

    def get_batch_stats(self) -> Dict[str, Any]:
        """Return the batching configuration, device and model identifier."""
        return {
            "max_batch_size": self.max_batch_size,
            "max_length": self.max_length,
            "device": str(self.device),
            "model_name": getattr(self.model.config, "name_or_path", "unknown"),
        }
|
|
|
|