| import sys |
| import os |
| import json |
| import logging |
| from typing import List, Dict, Tuple, Optional |
|
|
| import time |
| import numpy as np |
| from tqdm import tqdm |
| import onnxruntime as ort |
| from transformers import AutoTokenizer |
|
|
class StopJudgmentONNXInference:
    """ONNX Runtime inference wrapper for the stop-judgment classifier.

    Loads a tokenizer and an ONNX model, then turns raw text into a predicted
    class index together with the softmax probability of that class.
    """

    _VALID_DEVICES = ('auto', 'cuda', 'cpu')

    def __init__(self, onnx_model_path: str, tokenizer_path: str, device: str = 'auto'):
        """Load the tokenizer and create the ONNX Runtime session.

        Args:
            onnx_model_path: Path to the ONNX model file.
            tokenizer_path: Path to the tokenizer (loaded with
                ``local_files_only=True``).
            device: Execution device: ``'auto'`` (CUDA when available,
                otherwise CPU), ``'cuda'`` (CUDA, falling back to CPU with a
                warning when unavailable) or ``'cpu'`` (never use CUDA).

        Raises:
            ValueError: If ``device`` is not one of the supported values.
        """
        normalized_device = str(device).lower()
        if normalized_device not in self._VALID_DEVICES:
            raise ValueError(
                f"Unsupported device {device!r}; expected one of {self._VALID_DEVICES}"
            )

        self.onnx_model_path = onnx_model_path
        self.tokenizer_path = tokenizer_path
        self.device = normalized_device
        self.setup_logging()
        self.load_model_and_tokenizer()

    def setup_logging(self):
        """Configure basic logging and create this instance's logger."""
        # basicConfig is a no-op when the root logger already has handlers,
        # so an application-level logging setup is left untouched.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _select_providers(device: str, available_providers: List[str]) -> List[str]:
        """Return the ordered execution-provider list for ``device``.

        CUDA is listed first only when it is both permitted by ``device``
        (anything other than ``'cpu'``) and reported as available. The CPU
        provider is always appended as the final fallback.
        """
        providers = []
        if device != 'cpu' and 'CUDAExecutionProvider' in available_providers:
            providers.append('CUDAExecutionProvider')
        providers.append('CPUExecutionProvider')
        return providers

    def load_model_and_tokenizer(self):
        """Load the tokenizer and the ONNX model.

        Sets ``self.tokenizer``, ``self.ort_session``, ``self.input_names``
        and ``self.output_names``. Any loading failure is logged and
        re-raised unchanged.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, local_files_only=True)
            self.logger.info("Tokenizer loaded successfully")
        except Exception as e:
            self.logger.error("Failed to load tokenizer: %s", e)
            raise

        providers = self._select_providers(self.device, ort.get_available_providers())
        if 'CUDAExecutionProvider' in providers:
            self.logger.info("CUDA provider is available and will be used")
        elif self.device == 'cuda':
            # Best-effort: an explicit CUDA request on a CPU-only install
            # still yields a working session instead of a hard failure.
            self.logger.warning(
                "CUDA was requested but CUDAExecutionProvider is not available; falling back to CPU"
            )

        try:
            self.ort_session = ort.InferenceSession(self.onnx_model_path, providers=providers)
            self.logger.info(
                "ONNX model loaded successfully with providers: %s",
                self.ort_session.get_providers(),
            )
        except Exception as e:
            self.logger.error("Failed to load ONNX model: %s", e)
            raise

        self.input_names = [model_input.name for model_input in self.ort_session.get_inputs()]
        self.output_names = [model_output.name for model_output in self.ort_session.get_outputs()]

        self.logger.info("Input names: %s", self.input_names)
        self.logger.info("Output names: %s", self.output_names)

    def preprocess_text(self, texts: List[str], max_length: int = 128) -> Dict[str, np.ndarray]:
        """Tokenize texts into fixed-length int64 arrays.

        Every text is truncated and padded to exactly ``max_length`` tokens.

        Args:
            texts: Texts to tokenize.
            max_length: Sequence length after truncation/padding.

        Returns:
            Dict with ``'input_ids'`` and ``'attention_mask'`` int64 arrays.
        """
        encoding = self.tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='np'
        )

        return {
            'input_ids': encoding['input_ids'].astype(np.int64),
            'attention_mask': encoding['attention_mask'].astype(np.int64)
        }

    def _build_ort_inputs(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Map preprocessed arrays onto the model's input names.

        When the model declares inputs named exactly like the preprocessed
        keys, tensors are bound by name so the export order of the inputs
        does not matter. Otherwise the first model input receives
        ``input_ids`` and the second receives ``attention_mask``.

        Raises:
            ValueError: If positional binding is needed but the model
                declares fewer than two inputs.
        """
        names = self.input_names
        if all(key in names for key in inputs):
            return dict(inputs)
        if len(names) < 2:
            raise ValueError(
                f"Model must declare at least two inputs for positional binding, got {names}"
            )
        return {
            names[0]: inputs['input_ids'],
            names[1]: inputs['attention_mask'],
        }

    def _predict_proba(self, texts: List[str], max_length: int) -> np.ndarray:
        """Run the model on ``texts`` and return softmax of its first output."""
        inputs = self.preprocess_text(texts, max_length)
        ort_outputs = self.ort_session.run(self.output_names, self._build_ort_inputs(inputs))
        return self.softmax(ort_outputs[0])

    def predict_single(self, text: str, max_length: int = 128) -> Tuple[int, float]:
        """Predict the class of a single text.

        Args:
            text: Text to classify.
            max_length: Sequence length used for truncation/padding.

        Returns:
            ``(prediction, confidence)``: the argmax class index and its
            softmax probability.
        """
        probabilities = self._predict_proba([text], max_length)[0]
        prediction = int(np.argmax(probabilities))
        return prediction, float(probabilities[prediction])

    def predict_batch(self, texts: List[str], max_length: int = 128,
                      batch_size: int = 32) -> Tuple[List[int], List[float]]:
        """Predict classes for many texts, processed in mini-batches.

        Args:
            texts: Texts to classify.
            max_length: Sequence length used for truncation/padding.
            batch_size: Number of texts per inference call; must be >= 1.

        Returns:
            ``(predictions, confidences)``: two lists aligned with ``texts``
            holding the argmax class index and its softmax probability.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A negative step would make range() empty and silently drop
            # every input, so reject it up front.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not texts:
            return [], []

        all_predictions = []
        all_confidences = []

        for start in tqdm(range(0, len(texts), batch_size), desc="ONNX Predicting"):
            probabilities = self._predict_proba(texts[start:start + batch_size], max_length)
            predictions = np.argmax(probabilities, axis=1)
            # Pick each row's probability at its predicted class in one shot.
            confidences = probabilities[np.arange(len(predictions)), predictions]

            all_predictions.extend(predictions.tolist())
            # tolist() yields plain Python floats, matching predict_single.
            all_confidences.extend(confidences.tolist())

        return all_predictions, all_confidences

    @staticmethod
    def softmax(x):
        """Numerically stable softmax over the last axis of ``x``."""
        # Subtracting the per-row maximum keeps exp() from overflowing.
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
|
|
def main():
    """Command-line entry point: classify one sentence with the ONNX model.

    Reads ``<tokenizer_path> <onnx_model_path> [test_sentence]`` from
    ``sys.argv``. When fewer than two positional arguments are supplied the
    usage line is printed and the process exits with status 1. Otherwise the
    model is loaded, the sentence (or a built-in default) is classified, and
    the predicted class and its confidence are printed.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        print("Usage: python validate_onnx.py <tokenizer_path> <onnx_model_path> [test_sentence]")
        sys.exit(1)

    tokenizer_path, onnx_model_path = cli_args[0], cli_args[1]
    if len(cli_args) > 2:
        test_sentence = cli_args[2]
    else:
        # Built-in sample sentence used when none is given on the command line.
        test_sentence = "欢迎测试本判停模型有修正建议请随时提出"

    print("\n ONNX Model Inference...")
    onnx_inferencer = StopJudgmentONNXInference(onnx_model_path, tokenizer_path)
    result = onnx_inferencer.predict_single(test_sentence, max_length=128)
    prediction, confidence = result
    print(prediction, confidence)


if __name__ == "__main__":
    main()
|
|