skatzR committed on
Commit
67d909f
·
verified ·
1 Parent(s): d9c4c42

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +104 -285
inference.py CHANGED
@@ -1,13 +1,11 @@
1
- import os
2
- from typing import Any, Dict, List, Optional
 
3
 
4
- import torch
5
- from transformers import AutoModel, AutoTokenizer
6
 
7
- try:
8
- from huggingface_hub import hf_hub_download
9
- except Exception:
10
- hf_hub_download = None
11
 
12
 
13
  ERROR_NAMES_RU = {
@@ -20,335 +18,156 @@ ERROR_NAMES_RU = {
20
  }
21
 
22
 
23
- def _resolve_calibration_path(model_path: str) -> Optional[str]:
24
- local_path = os.path.join(model_path, "calibration_data.pth")
25
- if os.path.exists(local_path):
26
- return local_path
27
-
28
- if hf_hub_download is None or os.path.isdir(model_path):
29
- return None
30
-
31
- try:
32
- return hf_hub_download(
33
- repo_id=model_path,
34
- filename="calibration_data.pth",
35
- )
36
- except Exception:
37
- return None
38
-
39
-
40
- class RQAInferenceHF:
41
- def __init__(
42
- self,
43
- model_path: str,
44
- device: Optional[torch.device] = None,
45
- max_length: int = 512,
46
- issue_uncertain_margin: float = 0.05,
47
- hidden_uncertain_margin: float = 0.05,
48
- error_uncertain_margin: float = 0.05,
49
- ):
50
- self.model_path = model_path
51
- self.device = device or torch.device(
52
- "cuda" if torch.cuda.is_available() else "cpu"
53
- )
54
  self.max_length = int(max_length)
55
- self.issue_uncertain_margin = float(issue_uncertain_margin)
56
- self.hidden_uncertain_margin = float(hidden_uncertain_margin)
57
- self.error_uncertain_margin = float(error_uncertain_margin)
58
 
 
 
 
 
59
  self.model = AutoModel.from_pretrained(
60
- model_path,
61
- trust_remote_code=True,
62
- ).to(self.device).eval()
63
- self.tokenizer = AutoTokenizer.from_pretrained(model_path)
 
64
 
65
  cfg = self.model.config
66
- self.schema_version = str(getattr(cfg, "schema_version", "unknown"))
67
- self.error_types = list(getattr(cfg, "error_types", []))
68
- self.t_issue = float(getattr(cfg, "temperature_has_issue", 1.0))
69
- self.t_hidden = float(getattr(cfg, "temperature_is_hidden", 1.0))
70
- self.t_errors = list(
71
- getattr(cfg, "temperature_errors", [1.0] * len(self.error_types))
72
- )
73
- self.th_issue = float(getattr(cfg, "threshold_has_issue", 0.5))
74
- self.th_hidden = float(getattr(cfg, "threshold_is_hidden", 0.5))
75
- self.th_error = float(getattr(cfg, "threshold_error", 0.5))
76
- self.th_errors = list(
77
- getattr(cfg, "threshold_errors", [self.th_error] * len(self.error_types))
78
- )
79
 
80
- calibration_path = _resolve_calibration_path(model_path)
81
- if calibration_path:
82
- calibration = torch.load(calibration_path, map_location="cpu")
83
- calibration_error_types = calibration.get("error_types", None)
84
- if calibration_error_types is not None:
85
- if list(calibration_error_types) != self.error_types:
86
- raise ValueError(
87
- "Calibration artifact error_types mismatch with model.config.error_types."
88
- )
89
-
90
- self.schema_version = str(
91
- calibration.get("schema_version", self.schema_version)
92
- )
93
- self.t_issue = float(
94
- calibration.get("temperature_has_issue", self.t_issue)
95
- )
96
- self.t_hidden = float(
97
- calibration.get("temperature_is_hidden", self.t_hidden)
98
- )
99
- self.t_errors = list(
100
- calibration.get("temperature_errors", self.t_errors)
101
- )
102
- self.th_issue = float(
103
- calibration.get("threshold_has_issue", self.th_issue)
104
- )
105
- self.th_hidden = float(
106
- calibration.get("threshold_is_hidden", self.th_hidden)
107
- )
108
- self.th_error = float(
109
- calibration.get("threshold_error", self.th_error)
110
- )
111
- self.th_errors = list(
112
- calibration.get("threshold_errors", self.th_errors)
113
- )
114
 
115
- def _apply_temperature(
116
- self,
117
- issue_logits: torch.Tensor,
118
- hidden_logits: torch.Tensor,
119
- errors_logits: torch.Tensor,
120
- ):
121
- calibrated_issue = issue_logits / float(self.t_issue)
122
- calibrated_hidden = hidden_logits / float(self.t_hidden)
123
- calibrated_errors = errors_logits.clone()
124
- for idx in range(calibrated_errors.size(1)):
125
- temperature = float(self.t_errors[idx]) if idx < len(self.t_errors) else 1.0
126
- calibrated_errors[:, idx] = calibrated_errors[:, idx] / temperature
127
- return calibrated_issue, calibrated_hidden, calibrated_errors
128
 
129
  @torch.no_grad()
130
- def predict(
131
  self,
132
  text: str,
133
- return_probs: bool = False,
134
- threshold_issue: Optional[float] = None,
135
- threshold_hidden: Optional[float] = None,
136
- threshold_error: Optional[float] = None,
137
- threshold_errors: Optional[List[float]] = None,
138
- ) -> Dict[str, Any]:
139
- issue_threshold = self.th_issue if threshold_issue is None else float(threshold_issue)
140
- hidden_threshold = self.th_hidden if threshold_hidden is None else float(threshold_hidden)
141
- error_threshold = self.th_error if threshold_error is None else float(threshold_error)
142
- error_thresholds = self.th_errors if threshold_errors is None else list(threshold_errors)
143
-
144
- encoded = self.tokenizer(
 
 
145
  text,
146
  truncation=True,
147
  max_length=self.max_length,
148
  padding="max_length",
149
- return_tensors="pt",
150
- )
151
- input_ids = encoded["input_ids"].to(self.device)
152
- attention_mask = encoded["attention_mask"].to(self.device)
153
-
154
- outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
155
- issue_logits, hidden_logits, errors_logits = self._apply_temperature(
156
- outputs["has_issue_logits"],
157
- outputs["is_hidden_logits"],
158
- outputs["errors_logits"],
159
- )
160
 
161
- issue_probability = float(torch.sigmoid(issue_logits).item())
162
- has_issue = issue_probability >= issue_threshold
 
163
 
164
- result: Dict[str, Any] = {
165
- "schema_version": self.schema_version,
 
 
166
  "text": text,
167
  "class": None,
168
  "status": "ok",
169
  "review_required": False,
170
- "has_logical_issue": bool(has_issue),
171
- "has_issue_probability": issue_probability,
172
- "threshold_has_issue": issue_threshold,
173
- "temperature_has_issue": float(self.t_issue),
174
- "is_hidden_problem": False,
175
  "hidden_probability": None,
176
- "threshold_is_hidden": hidden_threshold,
177
- "temperature_is_hidden": float(self.t_hidden),
178
  "errors": [],
179
  "num_errors": 0,
 
 
180
  "threshold_error": error_threshold,
181
  "threshold_errors": error_thresholds,
182
- "calibrated": (
183
- abs(self.t_issue - 1.0) > 1e-6
184
- or abs(self.t_hidden - 1.0) > 1e-6
185
- or any(abs(float(t) - 1.0) > 1e-6 for t in self.t_errors)
186
- ),
187
  }
188
 
189
- if abs(issue_probability - issue_threshold) <= self.issue_uncertain_margin:
190
  result["status"] = "uncertain"
191
  result["review_required"] = True
192
 
193
  if not has_issue:
194
  result["class"] = "logical"
195
- if return_probs:
196
- result["raw"] = {"p_issue": issue_probability}
197
  return result
198
 
199
- hidden_probability = float(torch.sigmoid(hidden_logits).item())
200
- is_hidden = hidden_probability >= hidden_threshold
201
- result["hidden_probability"] = hidden_probability
202
- result["is_hidden_problem"] = bool(is_hidden)
 
203
 
204
- if abs(hidden_probability - hidden_threshold) <= self.hidden_uncertain_margin:
205
  result["status"] = "uncertain"
206
  result["review_required"] = True
207
 
208
  if is_hidden:
209
  result["class"] = "hidden"
210
- if return_probs:
211
- result["raw"] = {
212
- "p_issue": issue_probability,
213
- "p_hidden": hidden_probability,
214
- }
215
  return result
216
 
217
- error_probabilities = torch.sigmoid(errors_logits).cpu().numpy()[0]
218
- detected_errors = []
219
- for idx, error_type in enumerate(self.error_types):
220
- probability = float(error_probabilities[idx])
221
- threshold_i = float(
222
- error_thresholds[idx] if idx < len(error_thresholds) else error_threshold
223
- )
224
- if abs(probability - threshold_i) <= self.error_uncertain_margin:
225
  result["status"] = "uncertain"
226
  result["review_required"] = True
227
- if probability >= threshold_i:
228
- detected_errors.append(
229
- {
230
- "type": error_type,
231
- "probability": probability,
232
- "threshold": threshold_i,
233
- "temperature": float(self.t_errors[idx]) if idx < len(self.t_errors) else 1.0,
234
- }
235
- )
236
-
237
- detected_errors.sort(key=lambda item: item["probability"], reverse=True)
238
- result["class"] = "explicit"
239
- result["errors"] = detected_errors
240
- result["num_errors"] = len(detected_errors)
241
-
242
- if return_probs:
243
- result["error_probabilities"] = {
244
- error_type: float(probability)
245
- for error_type, probability in zip(self.error_types, error_probabilities)
246
- }
247
- result["raw"] = {
248
- "p_issue": issue_probability,
249
- "p_hidden": hidden_probability,
250
- }
251
 
252
- return result
 
253
 
254
- def pretty_print(self, prediction: Dict[str, Any], use_russian_names: bool = True) -> None:
255
- print("-" * 70)
256
- print(
257
- f"Class: {prediction['class']} | status={prediction['status']} "
258
- f"| review_required={prediction['review_required']}"
259
- )
260
- print(
261
- f"Issue: {prediction['has_logical_issue']} "
262
- f"({prediction['has_issue_probability'] * 100:.2f}%) "
263
- f"th={prediction['threshold_has_issue']:.3f}"
264
- )
265
- if prediction["hidden_probability"] is not None:
266
- print(
267
- f"Hidden: {prediction['is_hidden_problem']} "
268
- f"({prediction['hidden_probability'] * 100:.2f}%) "
269
- f"th={prediction['threshold_is_hidden']:.3f}"
270
- )
271
 
272
- if prediction["errors"]:
273
- printable_errors = []
274
- for item in prediction["errors"]:
275
- label = (
276
- ERROR_NAMES_RU.get(item["type"], item["type"])
277
- if use_russian_names
278
- else item["type"]
279
- )
280
- printable_errors.append((label, round(item["probability"], 3)))
281
- print(f"Top errors: {printable_errors}")
282
 
 
 
 
 
283
 
284
- class RQAJudge:
285
- def __init__(
286
- self,
287
- model_name: str = "skatzR/RQA-R2",
288
- device: Optional[torch.device] = None,
289
- max_length: int = 512,
290
- ):
291
- self.runner = RQAInferenceHF(
292
- model_path=model_name,
293
- device=device,
294
- max_length=max_length,
295
  )
 
296
 
297
- def infer(
298
- self,
299
- text: str,
300
- issue_threshold: Optional[float] = None,
301
- hidden_threshold: Optional[float] = None,
302
- error_threshold: Optional[float] = None,
303
- error_thresholds: Optional[List[float]] = None,
304
- ) -> Dict[str, Any]:
305
- prediction = self.runner.predict(
306
- text=text,
307
- return_probs=True,
308
- threshold_issue=issue_threshold,
309
- threshold_hidden=hidden_threshold,
310
- threshold_error=error_threshold,
311
- threshold_errors=error_thresholds,
312
- )
313
- return {
314
- "text": text,
315
- "class": prediction["class"],
316
- "status": prediction["status"],
317
- "review_required": prediction["review_required"],
318
- "has_issue": prediction["has_logical_issue"],
319
- "issue_probability": prediction["has_issue_probability"],
320
- "hidden_problem": prediction["is_hidden_problem"],
321
- "hidden_probability": prediction["hidden_probability"],
322
- "errors": [
323
- (item["type"], item["probability"])
324
- for item in prediction["errors"]
325
- ],
326
- "num_errors": prediction["num_errors"],
327
- "threshold_has_issue": prediction["threshold_has_issue"],
328
- "threshold_is_hidden": prediction["threshold_is_hidden"],
329
- "threshold_error": prediction["threshold_error"],
330
- }
331
 
332
- def pretty_print(self, result: Dict[str, Any], use_russian_names: bool = True) -> None:
333
- converted = {
334
- "class": result["class"],
335
- "status": result["status"],
336
- "review_required": result["review_required"],
337
- "has_logical_issue": result["has_issue"],
338
- "has_issue_probability": result["issue_probability"],
339
- "threshold_has_issue": result["threshold_has_issue"],
340
- "is_hidden_problem": result["hidden_problem"],
341
- "hidden_probability": result["hidden_probability"],
342
- "threshold_is_hidden": result["threshold_is_hidden"],
343
- "errors": [
344
- {
345
- "type": error_type,
346
- "probability": probability,
347
- }
348
- for error_type, probability in result["errors"]
349
- ],
350
- }
351
- self.runner.pretty_print(converted, use_russian_names=use_russian_names)
352
 
 
 
 
 
 
 
353
 
354
- __all__ = ["RQAInferenceHF", "RQAJudge", "ERROR_NAMES_RU"]
 
1
# Requirements — install in your environment BEFORE importing this module.
# (`!pip install ...` is an IPython/Jupyter shell magic; left as a literal
# statement it is a syntax error in a plain .py file, so it is kept here
# only as documentation.)
#
#   pip install torch==2.8.0 torchvision==0.17.2
#   pip install transformers==4.48.3 tokenizers sentencepiece accelerate
#
# NOTE(review): torchvision==0.17.2 is the companion release of torch 2.2.x,
# not torch 2.8.0 — confirm the intended version pair before installing.


import torch
from typing import List, Optional
from transformers import AutoTokenizer, AutoModel
 
9
 
10
 
11
  ERROR_NAMES_RU = {
 
18
  }
19
 
20
 
21
class RQAJudge:
    """Judge a piece of text for logical issues using the RQA-R2 model.

    The underlying model exposes three calibrated heads:

    * ``has_issue_logits``  — does the text contain a logical issue at all;
    * ``is_hidden_logits``  — is the issue a "hidden" problem;
    * ``errors_logits``     — one logit per explicit error type in
      ``config.error_types``.

    Temperature factors and decision thresholds are read from the model
    config.  A prediction lands in one of three classes: ``"logical"``
    (no issue found), ``"hidden"`` (hidden problem), or ``"explicit"``
    (explicit logical errors, listed per type).
    """

    def __init__(self, model_name="skatzR/RQA-R2", device=None, max_length: int = 512):
        """Load tokenizer, model, and calibration data from *model_name*.

        :param model_name: HF Hub repo id (or local path) of the model.
        :param device: torch device string/object; autodetects CUDA if None.
        :param max_length: tokenizer truncation/padding length.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = int(max_length)

        # trust_remote_code is required: the repo ships a custom model class.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        self.model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
        ).to(self.device)
        self.model.eval()

        cfg = self.model.config
        self.error_types = list(cfg.error_types)

        # Per-head temperature-scaling factors baked into the model config.
        self.temp_issue = float(cfg.temperature_has_issue)
        self.temp_hidden = float(cfg.temperature_is_hidden)
        self.temp_errors = list(cfg.temperature_errors)

        # Calibrated decision thresholds from the model config.
        self.threshold_issue = float(cfg.threshold_has_issue)
        self.threshold_hidden = float(cfg.threshold_is_hidden)
        self.threshold_error = float(cfg.threshold_error)
        self.threshold_errors = list(cfg.threshold_errors)

    @torch.no_grad()
    def infer(
        self,
        text: str,
        issue_threshold: Optional[float] = None,
        hidden_threshold: Optional[float] = None,
        error_threshold: Optional[float] = None,
        error_thresholds: Optional[List[float]] = None,
        issue_uncertain_margin: float = 0.05,
        hidden_uncertain_margin: float = 0.05,
        error_uncertain_margin: float = 0.05,
    ):
        """Classify *text* and return a result dict.

        Thresholds default to the calibrated values loaded in ``__init__``;
        any of them can be overridden per call.  When a probability falls
        within the corresponding ``*_uncertain_margin`` of its threshold,
        ``status`` is set to ``"uncertain"`` and ``review_required`` to True.

        :returns: dict with keys ``text``, ``class``, ``status``,
            ``review_required``, ``has_issue``, ``issue_probability``,
            ``hidden_problem``, ``hidden_probability``, ``errors``
            (list of ``(error_type, probability)`` sorted by probability,
            descending), ``num_errors``, the effective thresholds, and
            ``schema_version``.
        """
        issue_threshold = self.threshold_issue if issue_threshold is None else float(issue_threshold)
        hidden_threshold = self.threshold_hidden if hidden_threshold is None else float(hidden_threshold)
        error_threshold = self.threshold_error if error_threshold is None else float(error_threshold)
        error_thresholds = self.threshold_errors if error_thresholds is None else list(error_thresholds)

        inputs = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)

        outputs = self.model(**inputs)

        # Temperature-scale each head before the sigmoid (calibration).
        issue_logit = outputs["has_issue_logits"] / self.temp_issue
        hidden_logit = outputs["is_hidden_logits"] / self.temp_hidden

        error_logits = outputs["errors_logits"][0].clone()
        for i in range(len(self.error_types)):
            # Fall back to 1.0 if the config's temperature list is shorter
            # than error_types (mirrors the threshold fallback below and
            # avoids an IndexError on a mismatched calibration artifact).
            temp_i = float(self.temp_errors[i]) if i < len(self.temp_errors) else 1.0
            error_logits[i] = error_logits[i] / temp_i

        issue_prob = torch.sigmoid(issue_logit).item()
        has_issue = issue_prob >= issue_threshold

        result = {
            "text": text,
            "class": None,
            "status": "ok",
            "review_required": False,
            "has_issue": has_issue,
            "issue_probability": issue_prob,
            "hidden_problem": False,
            "hidden_probability": None,
            "errors": [],
            "num_errors": 0,
            "threshold_issue": issue_threshold,
            "threshold_hidden": hidden_threshold,
            "threshold_error": error_threshold,
            "threshold_errors": error_thresholds,
            "schema_version": getattr(self.model.config, "schema_version", "unknown"),
        }

        if abs(issue_prob - issue_threshold) <= issue_uncertain_margin:
            result["status"] = "uncertain"
            result["review_required"] = True

        # No issue detected -> classify as plain "logical" text and stop.
        if not has_issue:
            result["class"] = "logical"
            return result

        hidden_prob = torch.sigmoid(hidden_logit).item()
        is_hidden = hidden_prob >= hidden_threshold

        result["hidden_problem"] = is_hidden
        result["hidden_probability"] = hidden_prob

        if abs(hidden_prob - hidden_threshold) <= hidden_uncertain_margin:
            result["status"] = "uncertain"
            result["review_required"] = True

        # A hidden problem short-circuits: explicit error heads are skipped.
        if is_hidden:
            result["class"] = "hidden"
            return result

        error_probs = torch.sigmoid(error_logits).tolist()
        detected = []
        for i, err_name in enumerate(self.error_types):
            prob = float(error_probs[i])
            threshold_i = float(error_thresholds[i] if i < len(error_thresholds) else error_threshold)

            if abs(prob - threshold_i) <= error_uncertain_margin:
                result["status"] = "uncertain"
                result["review_required"] = True

            if prob >= threshold_i:
                detected.append((err_name, prob))

        # Most probable error first.
        detected.sort(key=lambda x: x[1], reverse=True)

        result["class"] = "explicit"
        result["errors"] = detected
        result["num_errors"] = len(detected)
        return result

    def pretty_print(self, r):
        """Print a human-readable (Russian) summary of an ``infer`` result."""
        print("\n" + "=" * 72)
        print("📄 Текст:")
        print(r["text"])

        print(
            f"\n🔎 Обнаружена проблема: {'ДА' if r['has_issue'] else 'НЕТ'} "
            f"({r['issue_probability'] * 100:.2f}%)"
        )
        print(f"🧠 Класс: {r['class']}")

        if r["status"] == "uncertain":
            print("⚠️ Статус: uncertain")

        if r["hidden_probability"] is not None:
            print(
                f"🟡 Hidden: {'ДА' if r['hidden_problem'] else 'НЕТ'} "
                f"({r['hidden_probability'] * 100:.2f}%)"
            )

        if r["errors"]:
            print("\n❌ Явные логические ошибки:")
            for name, prob in r["errors"]:
                print(f" • {ERROR_NAMES_RU.get(name, name)} — {prob * 100:.2f}%")
        else:
            print("\n✅ Явных логических ошибок не обнаружено")

        print("=" * 72)