# Setup — run in a notebook cell before executing this script:
# !pip install transformers==4.48.3 tokenizers sentencepiece accelerate
|
|
| import os |
| import json |
| import csv |
| import torch |
| from typing import List, Optional |
| from transformers import AutoTokenizer, AutoModel |
|
|
|
|
| |
| |
| |
|
|
# Canonical error-type identifiers, in the order the classifier's output
# head is expected to emit them. Used as a fallback when the loaded model's
# config does not declare its own `error_types` list.
ERROR_TYPES = [
    "false_causality",
    "unsupported_claim",
    "overgeneralization",
    "missing_premise",
    "contradiction",
    "circular_reasoning",
]


# Russian display names for each error type, used only for console output.
# Keys must match the identifiers in ERROR_TYPES / the model config.
ERROR_NAMES_RU = {
    "false_causality": "Ложная причинно-следственная связь",
    "unsupported_claim": "Неподкрепленное утверждение",
    "overgeneralization": "Чрезмерное обобщение",
    "missing_premise": "Отсутствующая предпосылка",
    "contradiction": "Противоречие",
    "circular_reasoning": "Круговое рассуждение",
}
|
|
|
|
| |
| |
| |
|
|
class RQAJudge:
    """Wrapper around the `skatzR/RQA-R2` model that classifies a text as
    logically sound, as having a "hidden" problem, or as containing one or
    more explicit logical errors.

    The model is loaded with ``trust_remote_code=True``, so its forward pass
    returns a dict with (at least) the keys ``has_issue_logits``,
    ``is_hidden_logits`` and ``errors_logits`` — this class relies on that
    contract. Per-head temperatures and decision thresholds are read from
    the model config when present, with sensible defaults otherwise.
    """

    def __init__(self, model_name="skatzR/RQA-R2", device=None, max_length: int = 512):
        # Prefer CUDA when available unless the caller pins a device.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = int(max_length)

        # trust_remote_code=True: the repo ships custom model/tokenizer code.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True
        ).to(self.device)

        # Inference-only usage: disable dropout etc.
        self.model.eval()

        cfg = self.model.config
        # Error-type ordering must match the model's output head; fall back
        # to the module-level default when the config does not declare it.
        self.error_types = list(getattr(cfg, "error_types", ERROR_TYPES))

        # Per-head temperature scaling factors (1.0 = no rescaling).
        # Presumably calibration values baked into the config — logits are
        # divided by these before the sigmoid.
        self.temp_issue = float(getattr(cfg, "temperature_has_issue", 1.0))
        self.temp_hidden = float(getattr(cfg, "temperature_is_hidden", 1.0))
        self.temp_errors = list(
            getattr(cfg, "temperature_errors", [1.0] * len(self.error_types))
        )

        # Decision thresholds; `threshold_errors` may override the shared
        # `threshold_error` on a per-error-type basis.
        self.threshold_issue = float(getattr(cfg, "threshold_has_issue", 0.5))
        self.threshold_hidden = float(getattr(cfg, "threshold_is_hidden", 0.5))
        self.threshold_error = float(getattr(cfg, "threshold_error", 0.5))
        self.threshold_errors = list(
            getattr(cfg, "threshold_errors", [self.threshold_error] * len(self.error_types))
        )

    @torch.no_grad()
    def infer(
        self,
        text: str,
        issue_threshold: Optional[float] = None,
        hidden_threshold: Optional[float] = None,
        error_threshold: Optional[float] = None,
        error_thresholds: Optional[List[float]] = None,
        issue_uncertain_margin: float = 0.05,
        hidden_uncertain_margin: float = 0.05,
        error_uncertain_margin: float = 0.05,
    ):
        """Classify ``text`` and return a result dict.

        Decision cascade (each stage can short-circuit):
          1. ``has_issue`` head — below threshold → class "logical", done.
          2. ``is_hidden`` head — above threshold → class "hidden", done.
          3. per-error-type heads — class "explicit" with the errors whose
             calibrated probability clears their threshold.

        Any probability landing within the corresponding ``*_uncertain_margin``
        of its threshold marks the result ``status="uncertain"`` and sets
        ``review_required=True``.

        Thresholds default to the model-config values captured in __init__;
        callers may override any of them per call.
        """
        issue_threshold = self.threshold_issue if issue_threshold is None else float(issue_threshold)
        hidden_threshold = self.threshold_hidden if hidden_threshold is None else float(hidden_threshold)
        error_threshold = self.threshold_error if error_threshold is None else float(error_threshold)
        error_thresholds = self.threshold_errors if error_thresholds is None else list(error_thresholds)

        # NOTE: padding="max_length" always pads to self.max_length; for a
        # single text this is simple albeit not the most compact encoding.
        inputs = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        ).to(self.device)

        outputs = self.model(**inputs)

        # Stage 1: binary "has issue" head, temperature-scaled then sigmoided.
        issue_logit = outputs["has_issue_logits"] / self.temp_issue
        issue_prob = torch.sigmoid(issue_logit).item()
        has_issue = issue_prob >= issue_threshold

        # Result skeleton; later stages fill in hidden/error fields.
        result = {
            "text": text,
            "class": None,
            "status": "ok",
            "review_required": False,
            "has_issue": has_issue,
            "issue_probability": issue_prob,
            "hidden_problem": False,
            "hidden_probability": None,
            "errors": [],
            "num_errors": 0,
            "schema_version": getattr(self.model.config, "schema_version", "unknown"),
            "threshold_issue": issue_threshold,
            "threshold_hidden": hidden_threshold,
            "threshold_error": error_threshold,
            "threshold_errors": error_thresholds,
        }

        # Borderline w.r.t. the issue threshold → flag for human review.
        if abs(issue_prob - issue_threshold) <= issue_uncertain_margin:
            result["status"] = "uncertain"
            result["review_required"] = True

        # No issue detected: the text is classified as logically sound.
        if not has_issue:
            result["class"] = "logical"
            return result

        # Stage 2: "hidden problem" head, same temperature/sigmoid treatment.
        hidden_logit = outputs["is_hidden_logits"] / self.temp_hidden
        hidden_prob = torch.sigmoid(hidden_logit).item()
        is_hidden = hidden_prob >= hidden_threshold

        result["hidden_problem"] = is_hidden
        result["hidden_probability"] = hidden_prob

        if abs(hidden_prob - hidden_threshold) <= hidden_uncertain_margin:
            result["status"] = "uncertain"
            result["review_required"] = True

        # Hidden problem: stop before the explicit-error heads.
        if is_hidden:
            result["class"] = "hidden"
            return result

        # Stage 3: per-error-type logits for the first (only) batch item.
        raw_error_logits = outputs["errors_logits"][0].clone()
        error_probs = {}

        # Calibrate each error logit with its own temperature.
        # NOTE(review): assumes temp_errors has one entry per error type —
        # a shorter config list would raise IndexError here.
        for i, logit in enumerate(raw_error_logits):
            calibrated = logit / self.temp_errors[i]
            prob = torch.sigmoid(calibrated).item()
            error_probs[self.error_types[i]] = prob

        explicit_errors = []
        for i, err_name in enumerate(self.error_types):
            prob = float(error_probs[err_name])
            # Per-type threshold when provided, else the shared fallback.
            threshold_i = float(error_thresholds[i] if i < len(error_thresholds) else error_threshold)

            if abs(prob - threshold_i) <= error_uncertain_margin:
                result["status"] = "uncertain"
                result["review_required"] = True

            if prob >= threshold_i:
                explicit_errors.append((err_name, prob))

        # Most confident errors first.
        explicit_errors.sort(key=lambda x: x[1], reverse=True)

        # Note: class is "explicit" even when no individual error cleared
        # its threshold (the top-level head still said "has issue").
        result["class"] = "explicit"
        result["errors"] = explicit_errors
        result["num_errors"] = len(explicit_errors)
        return result

    def pretty_print(self, r):
        """Print a human-readable (Russian) report for an infer() result dict."""
        print("\n" + "=" * 72)
        print("📄 Текст:")
        print(r["text"])

        print(
            f"\n🔎 Обнаружена проблема: {'ДА' if r['has_issue'] else 'НЕТ'} "
            f"({r['issue_probability'] * 100:.2f}%)"
        )
        print(f"🧠 Класс: {r['class']}")

        if r["status"] == "uncertain":
            print("⚠️ Пограничный случай: review recommended")

        # hidden_probability is None when stage 2 was never reached.
        if r["hidden_probability"] is not None:
            print(
                f"🟡 Hidden-проблема: {'ДА' if r['hidden_problem'] else 'НЕТ'} "
                f"({r['hidden_probability'] * 100:.2f}%)"
            )

        if r["errors"]:
            print("\n❌ Явные логические ошибки:")
            for name, prob in r["errors"]:
                # Fall back to the raw identifier for unknown error types.
                print(f" • {ERROR_NAMES_RU.get(name, name)} — {prob * 100:.2f}%")
        else:
            print("\n✅ Явных логических ошибок не обнаружено")

        print("=" * 72)
|
|
|
|
| |
| |
| |
|
|
def load_texts_from_file(path: str) -> List[str]:
    """Load a list of texts to analyze from a .txt, .csv or .json file.

    Supported formats (selected by file extension, case-insensitive):
      * .txt  — one text per line; blank lines are skipped.
      * .csv  — must contain a ``text`` column; rows with an empty or
        missing ``text`` value are skipped.
      * .json — either a list of strings, or a list of objects each
        carrying a ``text`` key (objects without the key are skipped).

    Raises:
        ValueError: if the extension is unsupported, or if a .json file
            does not contain a list.
    """
    ext = os.path.splitext(path)[1].lower()

    if ext == ".txt":
        with open(path, encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]

    if ext == ".csv":
        # newline="" is required by the csv module so that newlines embedded
        # in quoted fields are parsed correctly instead of splitting rows.
        with open(path, encoding="utf-8", newline="") as f:
            reader = csv.DictReader(f)
            return [row["text"] for row in reader if row.get("text")]

    if ext == ".json":
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, list):
            if all(isinstance(item, str) for item in data):
                return data
            # List of objects: collect the "text" field where present.
            return [
                str(item["text"])
                for item in data
                if isinstance(item, dict) and "text" in item
            ]

    # Include the offending extension so the error is actionable.
    raise ValueError(f"Неподдерживаемый формат файла: {ext or path}")
|
|
|
|
| |
| |
| |
|
|
class InteractiveCLI:
    """Interactive console front-end for RQAJudge.

    Offers three modes: single-text analysis, multi-line batch input, and
    loading texts from a file. All prompts and reports are in Russian.
    Mode handlers return True to go back to the main menu, False to stay
    in the current mode for another round.
    """

    def __init__(self, model_name="skatzR/RQA-R2"):
        # Loads the model immediately (may download weights on first run).
        self.judge = RQAJudge(model_name=model_name)

    def clear_screen(self):
        # Simulates a screen clear by printing blank lines (works in
        # notebooks/consoles alike without terminal escape codes).
        print("\n" * 2)

    def show_mode_menu(self):
        """Render the top-level mode-selection menu."""
        self.clear_screen()
        print("=" * 60)
        print("🤖 RQA-R2 — АНАЛИЗ ЛОГИЧЕСКИХ ОШИБОК")
        print("=" * 60)
        print("\nВыберите режим работы:")
        print("1. 📝 Одиночный ввод (одна фраза для анализа)")
        print("2. 📄 Множественный ввод (несколько фраз, каждая с новой строки)")
        print("3. 📂 Загрузка из файла (.txt, .csv, .json)")
        print("\nНажмите Enter без ввода для выхода.")
        print("-" * 60)

    def process_single_mode(self):
        """Prompt for one text, analyze and print it.

        Returns True when the user enters nothing (back to menu),
        False to stay in single-input mode.
        """
        self.clear_screen()
        print("[📝 РЕЖИМ: ОДИНОЧНЫЙ ВВОД]")
        print("Введите текст для анализа:")
        print("(Нажмите Enter без ввода для возврата в меню)")
        print("-" * 40)

        text = input("> ").strip()
        if not text:
            return True

        result = self.judge.infer(text)
        self.judge.pretty_print(result)

        print("\n" + "-" * 40)
        input("Нажмите Enter для продолжения...")
        return False

    def process_multiline_mode(self):
        """Collect texts line by line until an empty line, then analyze each.

        Returns True if the very first line is empty (back to menu),
        False otherwise.
        """
        self.clear_screen()
        print("[📄 РЕЖИМ: МНОЖЕСТВЕННЫЙ ВВОД]")
        print("Введите тексты для анализа (каждый с новой строки).")
        print("Оставьте строку пустой для завершения ввода.")
        print("(Нажмите Enter без ввода для возврата в меню)")
        print("-" * 40)

        texts = []
        print("Ввод текстов:")
        while True:
            line = input("> ").strip()
            if not line:
                # Empty first line → back to menu; otherwise end of input.
                if not texts:
                    return True
                break
            texts.append(line)

        self.clear_screen()
        print(f"[📄 РЕЖИМ: МНОЖЕСТВЕННЫЙ ВВОД] — найдено {len(texts)} текстов")
        print("-" * 40)

        for i, text in enumerate(texts, 1):
            print(f"\n🔍 Текст #{i}:")
            result = self.judge.infer(text)
            self.judge.pretty_print(result)

        print("\n" + "=" * 60)
        input("Нажмите Enter для продолжения...")
        return False

    def process_file_mode(self):
        """Ask for a file path, load its texts and analyze each.

        Returns True when the user enters no path (back to menu),
        False otherwise. Errors are reported, never raised to the caller.
        """
        self.clear_screen()
        print("[📂 РЕЖИМ: ЗАГРУЗКА ИЗ ФАЙЛА]")
        print("Поддерживаемые форматы: .txt, .csv, .json")
        print("Укажите путь к файлу:")
        print("(Нажмите Enter без ввода для возврата в меню)")
        print("-" * 40)

        file_path = input("Путь к файлу> ").strip()
        if not file_path:
            return True

        try:
            if not os.path.exists(file_path):
                print(f"\n❌ Ошибка: Файл '{file_path}' не найден!")
                input("\nНажмите Enter для продолжения...")
                return False

            texts = load_texts_from_file(file_path)
            if not texts:
                print(f"\n⚠️ Файл '{file_path}' пуст или не содержит текстов!")
                input("\nНажмите Enter для продолжения...")
                return False

            self.clear_screen()
            print(f"[📂 РЕЖИМ: ЗАГРУЗКА ИЗ ФАЙЛА] — загружено {len(texts)} текстов")
            print(f"Файл: {file_path}")
            print("-" * 40)

            for i, text in enumerate(texts, 1):
                print(f"\n🔍 Текст #{i}:")
                result = self.judge.infer(text)
                self.judge.pretty_print(result)

            print("\n" + "=" * 60)
            input("Нажмите Enter для продолжения...")

        except Exception as e:
            # Broad catch by design: keep the interactive session alive on
            # any parsing/инference failure and report it to the user.
            print(f"\n❌ Ошибка при обработке файла: {str(e)}")
            input("\nНажмите Enter для продолжения...")

        return False

    def run_interactive(self):
        """Main loop: show the menu, dispatch to the chosen mode handler,
        and stay in that mode until the handler asks to return to the menu.
        An empty menu choice exits the program.
        """
        current_mode = None

        while True:
            if not current_mode:
                self.show_mode_menu()
                choice = input("Ваш выбор (1-3)> ").strip()

                if not choice:
                    print("\n👋 Выход из программы...")
                    break

                if choice == "1":
                    current_mode = "single"
                elif choice == "2":
                    current_mode = "multiline"
                elif choice == "3":
                    current_mode = "file"
                else:
                    print("\n❌ Неверный выбор! Попробуйте снова.")
                    input("Нажмите Enter для продолжения...")
                    continue

            should_return_to_menu = False

            if current_mode == "single":
                should_return_to_menu = self.process_single_mode()
            elif current_mode == "multiline":
                should_return_to_menu = self.process_multiline_mode()
            elif current_mode == "file":
                should_return_to_menu = self.process_file_mode()

            if should_return_to_menu:
                current_mode = None
|
|
|
|
| |
| |
| |
|
|
def main():
    """Entry point: build the interactive CLI and hand control to its loop."""
    InteractiveCLI().run_interactive()
|
|
|
|
| |
| |
| |
|
|
| if __name__ == "__main__": |
| main() |
|
|