CompactAI commited on
Commit
bb0efe6
·
verified ·
1 Parent(s): 9d7e5cd

Upload 18 files

Browse files
app.py CHANGED
@@ -5,6 +5,11 @@ Serves the trained sklearn ensemble via the AIFinder inference class.
5
 
6
  import os
7
  import re
 
 
 
 
 
8
 
9
  import joblib
10
  import numpy as np
@@ -13,10 +18,14 @@ from flask import Flask, jsonify, request, send_from_directory, render_template
13
  from flask_cors import CORS
14
  from flask_limiter import Limiter
15
  from flask_limiter.util import get_remote_address
 
16
 
17
  from config import MODEL_DIR
18
  from inference import AIFinder
19
 
 
 
 
20
  app = Flask(__name__)
21
  CORS(app)
22
  limiter = Limiter(get_remote_address, app=app)
@@ -28,20 +37,42 @@ using_community = False
28
  DEFAULT_TOP_N = 4
29
  COMMUNITY_DIR = os.path.join(MODEL_DIR, "community")
30
  CORRECTIONS_FILE = os.path.join(COMMUNITY_DIR, "corrections.joblib")
 
 
31
  corrections: list[dict] = []
32
 
 
33
 
34
- def load_models():
35
- global finder, community_finder, corrections
36
- finder = AIFinder(model_dir=MODEL_DIR)
37
- os.makedirs(COMMUNITY_DIR, exist_ok=True)
38
- if os.path.exists(CORRECTIONS_FILE):
39
- corrections = joblib.load(CORRECTIONS_FILE)
40
- if os.path.exists(os.path.join(COMMUNITY_DIR, "rf_4provider.joblib")):
41
- try:
42
- community_finder = AIFinder(model_dir=COMMUNITY_DIR)
43
- except Exception:
44
- community_finder = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
  def _active_finder():
@@ -112,17 +143,20 @@ def correct():
112
  text = _strip_think_tags(data["text"])
113
  corrections.append({"text": text, "provider": provider})
114
 
115
- texts = [c["text"] for c in corrections]
116
- providers = [c["provider"] for c in corrections]
117
- X = finder.pipeline.transform(texts)
118
- y = finder.le.transform(providers)
 
 
 
119
 
120
- rf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
121
- rf.fit(X, y)
 
 
 
122
 
123
- joblib.dump([rf], os.path.join(COMMUNITY_DIR, "rf_4provider.joblib"))
124
- joblib.dump(finder.pipeline, os.path.join(COMMUNITY_DIR, "pipeline_4provider.joblib"))
125
- joblib.dump(finder.le, os.path.join(COMMUNITY_DIR, "enc_4provider.joblib"))
126
  joblib.dump(corrections, CORRECTIONS_FILE)
127
 
128
  community_finder = AIFinder(model_dir=COMMUNITY_DIR)
@@ -143,7 +177,9 @@ def toggle_community():
143
  global using_community
144
  data = request.get_json(silent=True) or {}
145
  using_community = bool(data.get("enabled", not using_community))
146
- return jsonify({"using_community": using_community, "available": community_finder is not None})
 
 
147
 
148
 
149
  @app.route("/models/<filename>")
@@ -178,6 +214,448 @@ def providers():
178
  )
179
 
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  if __name__ == "__main__":
182
  print("Loading models...")
183
  load_models()
 
5
 
6
  import os
7
  import re
8
+ import shutil
9
+ import uuid
10
+ import threading
11
+ from collections import defaultdict
12
+ from datetime import datetime
13
 
14
  import joblib
15
  import numpy as np
 
18
  from flask_cors import CORS
19
  from flask_limiter import Limiter
20
  from flask_limiter.util import get_remote_address
21
+ from tqdm import tqdm
22
 
23
  from config import MODEL_DIR
24
  from inference import AIFinder
25
 
26
+ STYLE_MODEL_DIR = os.path.join(MODEL_DIR, "style")
27
+ from dataset_evaluator import load_dataset_texts, get_supported_formats
28
+
29
  app = Flask(__name__)
30
  CORS(app)
31
  limiter = Limiter(get_remote_address, app=app)
 
37
  DEFAULT_TOP_N = 4
38
  COMMUNITY_DIR = os.path.join(MODEL_DIR, "community")
39
  CORRECTIONS_FILE = os.path.join(COMMUNITY_DIR, "corrections.joblib")
40
+ CORRECTION_MODEL_FILE = os.path.join(COMMUNITY_DIR, "correction_rf_4provider.joblib")
41
+ JOBS_FILE = os.path.join(MODEL_DIR, "jobs.joblib")
42
  corrections: list[dict] = []
43
 
44
+ jobs: dict[str, dict] = {}
45
 
46
+
47
def _copy_base_models_to_community():
    """Seed the community directory with the base style models.

    Files already present in the community directory are never overwritten,
    so community retraining artifacts survive restarts.
    """
    for name in (
        "rf_4provider.joblib",
        "pipeline_4provider.joblib",
        "enc_4provider.joblib",
    ):
        source_path = os.path.join(STYLE_MODEL_DIR, name)
        target_path = os.path.join(COMMUNITY_DIR, name)
        if os.path.exists(source_path) and not os.path.exists(target_path):
            shutil.copy(source_path, target_path)
59
+
60
+
61
def _update_job_progress(job_id, current, total, stage):
    """Record progress for *job_id* and persist the jobs store.

    Unknown job ids are ignored silently.
    """
    job = jobs.get(job_id)
    if job is None:
        return
    percent = round(current / total * 100, 1) if total > 0 else 0
    job["progress"] = {
        "current": current,
        "total": total,
        "stage": stage,
        "percent": percent,
    }
    _save_jobs()
71
+
72
+
73
def _save_jobs():
    """Persist jobs to disk."""
    # Snapshot the in-memory jobs dict so job state survives process restarts.
    # NOTE(review): called from worker threads without a lock — concurrent
    # dumps could race on the file; confirm single-writer assumption.
    joblib.dump(jobs, JOBS_FILE)
76
 
77
 
78
  def _active_finder():
 
143
  text = _strip_think_tags(data["text"])
144
  corrections.append({"text": text, "provider": provider})
145
 
146
+ _copy_base_models_to_community()
147
+
148
+ if len(corrections) > 0:
149
+ texts = [c["text"] for c in corrections]
150
+ providers = [c["provider"] for c in corrections]
151
+ X = finder.pipeline.transform(texts)
152
+ y = finder.le.transform(providers)
153
 
154
+ correction_rf = RandomForestClassifier(
155
+ n_estimators=100, random_state=42, n_jobs=-1
156
+ )
157
+ correction_rf.fit(X, y)
158
+ joblib.dump([correction_rf], CORRECTION_MODEL_FILE)
159
 
 
 
 
160
  joblib.dump(corrections, CORRECTIONS_FILE)
161
 
162
  community_finder = AIFinder(model_dir=COMMUNITY_DIR)
 
177
  global using_community
178
  data = request.get_json(silent=True) or {}
179
  using_community = bool(data.get("enabled", not using_community))
180
+ return jsonify(
181
+ {"using_community": using_community, "available": community_finder is not None}
182
+ )
183
 
184
 
185
  @app.route("/models/<filename>")
 
214
  )
215
 
216
 
217
@app.route("/api/dataset/info", methods=["POST"])
def dataset_info():
    """Get info about a dataset without evaluating.

    Expects JSON with 'dataset_id'; optional 'max_samples', 'evaluate',
    'api_key', 'custom_format'. When 'evaluate' is truthy and the format is
    supported, a background evaluation job is also started and its job_id
    is included in the response.
    """
    data = request.get_json(silent=True)
    if not data or "dataset_id" not in data:
        return jsonify({"error": "Request must include 'dataset_id'"}), 400

    dataset_id = data["dataset_id"]
    max_samples = data.get("max_samples", 1000)
    evaluate = data.get("evaluate", False)
    api_key = data.get("api_key")
    custom_format = data.get("custom_format")

    # sample_size=1: cheap probe only; the full load happens in the job.
    result = load_dataset_texts(
        dataset_id, max_samples=max_samples, sample_size=1, custom_format=custom_format
    )

    response = {
        "dataset_id": dataset_id,
        "total_rows": result["total_rows"],
        "extracted_count": len(result["texts"]),
        "format": result["format"],
        "format_name": result["format_info"]["name"] if result["format_info"] else None,
        "format_description": result["format_info"]["description"]
        if result["format_info"]
        else None,
        "supported": result["supported"],
        "error": result["error"],
        "custom_format": custom_format,
    }

    if evaluate and result["supported"]:
        job_id = str(uuid.uuid4())
        jobs[job_id] = {
            "job_id": job_id,
            "dataset_id": dataset_id,
            "max_samples": max_samples,
            "status": "pending",
            "created_at": datetime.utcnow().isoformat(),
            "api_key": api_key,
            # Fix: record custom_format on the job, consistent with
            # /api/dataset/evaluate, so /api/datasets can report it.
            "custom_format": custom_format,
        }
        _save_jobs()

        thread = threading.Thread(
            target=_run_evaluation_job,
            args=(job_id, dataset_id, max_samples, api_key, custom_format),
        )
        thread.daemon = True
        thread.start()

        response["job_id"] = job_id
        response["status"] = "pending"
        response["message"] = "Evaluation started in background."
        # (fix) dropped a redundant re-assignment of response["custom_format"];
        # it is already set above with the same value.

    return jsonify(response)
273
+
274
+
275
def _run_evaluation_job(
    job_id: str,
    dataset_id: str,
    max_samples: int,
    api_key: str | None,
    custom_format: str | None = None,
):
    """Background task to run dataset evaluation.

    Loads the dataset, scores every extracted text with the active finder,
    aggregates per-provider counts and confidences, and stores the results
    on the job record. Any failure marks the job as failed.

    NOTE(review): mutates the shared ``jobs`` dict from a worker thread
    without a lock — confirm single-process, GIL-protected deployment.
    """
    jobs[job_id]["status"] = "running"
    jobs[job_id]["started_at"] = datetime.utcnow().isoformat()
    jobs[job_id]["custom_format"] = custom_format
    _save_jobs()

    def progress_cb(current, total, stage):
        _update_job_progress(job_id, current, total, stage)

    try:
        load_result = load_dataset_texts(
            dataset_id,
            max_samples=max_samples,
            progress_callback=progress_cb,
            custom_format=custom_format,
        )

        if not load_result["supported"]:
            jobs[job_id].update(
                {
                    "status": "failed",
                    "error": load_result["error"],
                    "dataset_id": dataset_id,
                    "supported": False,
                    "completed_at": datetime.utcnow().isoformat(),
                }
            )
            _save_jobs()
            return

        texts = load_result["texts"]
        if not texts:
            jobs[job_id].update(
                {
                    "status": "failed",
                    "error": "No valid assistant responses found in dataset",
                    "dataset_id": dataset_id,
                    "supported": True,
                    "extracted_count": 0,
                    "completed_at": datetime.utcnow().isoformat(),
                }
            )
            _save_jobs()
            return

        results = {
            "dataset_id": dataset_id,
            "format": load_result["format"],
            "format_name": load_result["format_info"]["name"]
            if load_result["format_info"]
            else None,
            "total_rows": load_result["total_rows"],
            "extracted_count": len(texts),
            "provider_counts": {},
            "provider_confidences": {},
            "top_providers": {},
        }

        provider_counts = defaultdict(int)
        provider_confidences = defaultdict(list)
        top_providers = defaultdict(int)

        af = _active_finder()

        total = len(texts)
        for i, text in enumerate(tqdm(texts, desc="Evaluating")):
            # Fix: progress_cb is always defined here, so the previous
            # truthiness guard was dead code; keep only the every-10th
            # throttle (each update persists jobs to disk).
            if i % 10 == 0 or i == total - 1:
                progress_cb(i + 1, total, "evaluating")
            try:
                proba = af.predict_proba(text)
                sorted_providers = sorted(
                    proba.items(), key=lambda x: x[1], reverse=True
                )

                pred_provider, confidence = sorted_providers[0]

                provider_counts[pred_provider] += 1
                provider_confidences[pred_provider].append(confidence)

                # Tally every provider appearing in this text's top 5.
                for name, _ in sorted_providers[:5]:
                    top_providers[name] += 1
            except Exception:
                # Texts the model cannot score are skipped silently.
                continue

        # Fix: removed a redundant recomputation of `total` here —
        # len(texts) cannot have changed during the loop.
        for provider, count in provider_counts.items():
            results["provider_counts"][provider] = {
                "count": count,
                "percentage": round((count / total) * 100, 2),
            }
            confs = provider_confidences[provider]
            avg_conf = sum(confs) / len(confs) if confs else 0
            results["provider_confidences"][provider] = {
                "average": round(avg_conf * 100, 2),
                "cumulative": round(avg_conf * count, 2),
            }

        results["top_providers"] = dict(
            sorted(top_providers.items(), key=lambda x: -x[1])[:5]
        )

        sorted_by_cumulative = sorted(
            results["provider_confidences"].items(), key=lambda x: -x[1]["cumulative"]
        )
        results["likely_provider"] = (
            sorted_by_cumulative[0][0] if sorted_by_cumulative else None
        )
        # NOTE(review): divides by the total text count, not by the number of
        # successfully scored texts — skipped texts drag the average down;
        # confirm this is intended.
        results["average_confidence"] = (
            round(sum(sum(c) for c in provider_confidences.values()) / total * 100, 2)
            if total > 0
            else 0
        )

        jobs[job_id].update(
            {
                "status": "completed",
                "results": results,
                "api_key": api_key,
                "completed_at": datetime.utcnow().isoformat(),
            }
        )
        _save_jobs()
    except Exception as e:
        jobs[job_id].update(
            {
                "status": "failed",
                "error": str(e),
                "completed_at": datetime.utcnow().isoformat(),
            }
        )
        _save_jobs()
413
+
414
+
415
@app.route("/api/dataset/evaluate", methods=["POST"])
@limiter.limit("10/minute")
def dataset_evaluate():
    """Start a background job to evaluate a HuggingFace dataset."""
    payload = request.get_json(silent=True)
    if not payload or "dataset_id" not in payload:
        return jsonify({"error": "Request must include 'dataset_id'"}), 400

    dataset_id = payload["dataset_id"]
    max_samples = payload.get("max_samples", 1000)
    api_key = payload.get("api_key")
    custom_format = payload.get("custom_format")

    # Validate the dataset synchronously so obvious failures return a 400
    # instead of a doomed background job.
    probe = load_dataset_texts(
        dataset_id, max_samples=max_samples, custom_format=custom_format
    )

    if not probe["supported"]:
        error_body = {
            "error": probe["error"],
            "dataset_id": dataset_id,
            "supported": False,
        }
        return jsonify(error_body), 400

    if not probe["texts"]:
        error_body = {
            "error": "No valid assistant responses found in dataset",
            "dataset_id": dataset_id,
            "supported": True,
            "extracted_count": 0,
        }
        return jsonify(error_body), 400

    job_id = str(uuid.uuid4())
    jobs[job_id] = {
        "job_id": job_id,
        "dataset_id": dataset_id,
        "max_samples": max_samples,
        "status": "pending",
        "created_at": datetime.utcnow().isoformat(),
        "api_key": api_key,
        "custom_format": custom_format,
    }
    _save_jobs()

    worker = threading.Thread(
        target=_run_evaluation_job,
        args=(job_id, dataset_id, max_samples, api_key, custom_format),
    )
    worker.daemon = True
    worker.start()

    return jsonify(
        {
            "job_id": job_id,
            "status": "pending",
            "message": "Evaluation started. Use the job_id to check status later.",
            "custom_format": custom_format,
        }
    )
478
+
479
+
480
@app.route("/api/dataset/job/<job_id>", methods=["GET"])
def dataset_job_status(job_id):
    """Get the status and results of a dataset evaluation job."""
    job = jobs.get(job_id)
    if job is None:
        return jsonify({"error": "Job not found"}), 404

    status = job["status"]
    response = {
        "job_id": job_id,
        "dataset_id": job.get("dataset_id"),
        "status": status,
        "created_at": job.get("created_at"),
        "started_at": job.get("started_at"),
        "completed_at": job.get("completed_at"),
    }

    progress = job.get("progress")
    if progress:
        response["progress"] = progress

    # Results and errors are only meaningful for finished jobs.
    if status == "completed":
        response["results"] = job.get("results")
    elif status == "failed":
        response["error"] = job.get("error")

    return jsonify(response)
505
+
506
+
507
@app.route("/api/datasets", methods=["GET"])
def list_datasets():
    """List all evaluated datasets, optionally filtered by API key."""
    api_key = request.args.get("api_key")

    # Only finished jobs are listed; in-flight jobs are polled individually.
    summaries = [
        {
            "job_id": job_id,
            "dataset_id": job.get("dataset_id"),
            "status": job["status"],
            "created_at": job.get("created_at"),
            "completed_at": job.get("completed_at"),
            "error": job.get("error"),
            "custom_format": job.get("custom_format"),
        }
        for job_id, job in jobs.items()
        if job["status"] in ("completed", "failed")
        and (not api_key or job.get("api_key") == api_key)
    ]

    summaries.sort(key=lambda entry: entry.get("created_at", ""), reverse=True)
    return jsonify({"datasets": summaries})
531
+
532
+
533
@app.route("/api/datasets/clear", methods=["POST"])
def clear_datasets():
    """Clear all evaluated dataset history for the current API key."""
    payload = request.get_json(silent=True) or {}
    api_key = payload.get("api_key")
    if not api_key:
        return jsonify({"error": "API key required"}), 400

    # Only finished jobs owned by this key are eligible for removal.
    removable = [
        job_id
        for job_id, job in jobs.items()
        if job.get("api_key") == api_key and job["status"] in ("completed", "failed")
    ]
    for job_id in removable:
        del jobs[job_id]

    if removable:
        _save_jobs()

    return jsonify({"status": "ok", "cleared": len(removable)})
554
+
555
+
556
@app.route("/api/dataset/formats", methods=["GET"])
def dataset_formats():
    """Get list of supported dataset formats.

    Returns the built-in formats from dataset_evaluator plus a synthetic
    "Custom Format" entry and a help section describing the custom-format
    mini-syntax.
    """
    formats = get_supported_formats()
    # Built-in formats, reduced to their user-facing fields.
    formats_list = [
        {
            "name": info["name"],
            "description": info["description"],
            "examples": info["examples"],
        }
        for info in formats.values()
    ]
    # Synthetic entry: custom formats are parsed from the strings below,
    # not registered in SUPPORTED_FORMATS.
    formats_list.append(
        {
            "name": "Custom Format",
            "description": "Define your own format specification",
            "examples": [
                "column: response",
                "column: prompt, column: response",
                "pattern: user:, pattern: assistant:",
                "user:[startuser]assistant:[startassistant]",
            ],
        }
    )
    return jsonify(
        {
            "formats": formats_list,
            "custom_format_help": {
                "description": "Specify custom format using these patterns:",
                "patterns": [
                    "column: <field_name> - extract single field",
                    "column: <user_field>, column: <assistant_field> - extract from two columns",
                    "pattern: <regex> - use regex to extract",
                    "user:[startuser]assistant:[startassistant] - pattern-based extraction",
                ],
                "examples": [
                    {
                        "input": "column: completion",
                        "description": "Extract from 'completion' field",
                    },
                    {
                        "input": "column: input, column: output",
                        "description": "Extract from 'input' and 'output' columns",
                    },
                    {
                        "input": "user:[INST]assistant:[/INST]",
                        "description": "Extract text between markers",
                    },
                ],
            },
        }
    )
608
+
609
+
610
class CommunityAIFinder(AIFinder):
    """Extended AIFinder that blends base model with correction model."""

    def __init__(self, model_dir, correction_model_path=None):
        super().__init__(model_dir)
        self.correction_models = None
        if correction_model_path and os.path.exists(correction_model_path):
            self.correction_models = joblib.load(correction_model_path)

    def predict_proba(self, text):
        """Blend base model predictions with correction model if available."""
        features = self.pipeline.transform([text])

        base_proba = np.mean([m.predict_proba(features) for m in self.models], axis=0)

        has_corrections = (
            self.correction_models is not None and len(self.correction_models) > 0
        )
        if not has_corrections:
            return dict(zip(self.le.classes_, base_proba[0]))

        correction_proba = np.mean(
            [m.predict_proba(features) for m in self.correction_models], axis=0
        )
        # Correction model dominates the blend (weight 0.7), since it
        # encodes explicit user feedback; renormalize rows afterwards.
        blend_weight = 0.7
        mixed = (1 - blend_weight) * base_proba + blend_weight * correction_proba
        mixed = mixed / mixed.sum(axis=1, keepdims=True)
        return dict(zip(self.le.classes_, mixed[0]))
639
+
640
+
641
def load_models():
    """Initialise global state: base finder, community finder, corrections, jobs."""
    global finder, community_finder, corrections, jobs

    finder = AIFinder(model_dir=STYLE_MODEL_DIR)

    os.makedirs(COMMUNITY_DIR, exist_ok=True)
    _copy_base_models_to_community()

    # Restore persisted state where available.
    if os.path.exists(CORRECTIONS_FILE):
        corrections = joblib.load(CORRECTIONS_FILE)
    if os.path.exists(JOBS_FILE):
        jobs = joblib.load(JOBS_FILE)

    community_rf = os.path.join(COMMUNITY_DIR, "rf_4provider.joblib")
    if os.path.exists(community_rf):
        try:
            community_finder = CommunityAIFinder(
                model_dir=COMMUNITY_DIR, correction_model_path=CORRECTION_MODEL_FILE
            )
        except Exception:
            # A broken community model must not stop the base app from serving.
            community_finder = None
657
+
658
+
659
  if __name__ == "__main__":
660
  print("Loading models...")
661
  load_models()
config.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  AIFinder Configuration
3
- Dataset registry, label mappings, and feature parameters.
4
  """
5
 
6
  import os
@@ -9,10 +9,11 @@ import os
9
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
10
  MODEL_DIR = os.path.join(BASE_DIR, "models")
11
 
12
- # --- Dataset Registry ---
13
- # Each entry: (hf_dataset_id, provider, model_name, optional_kwargs)
14
- # optional_kwargs: subset name, split, etc.
15
- DATASET_REGISTRY = [
 
16
  # Anthropic
17
  ("TeichAI/claude-4.5-opus-high-reasoning-250x", "Anthropic", "Claude 4.5 Opus", {}),
18
  (
@@ -73,7 +74,7 @@ DATASET_REGISTRY = [
73
  ),
74
  # Zhipu
75
  ("TeichAI/Pony-Alpha-15k", "Zhipu", "GLM-5", {"max_samples": 1500}),
76
- # DeepSeek (TeichAI)
77
  ("TeichAI/deepseek-v3.2-speciale-1000x", "DeepSeek", "DeepSeek V3.2 Speciale", {}),
78
  (
79
  "TeichAI/deepseek-v3.2-speciale-openr1-math-3k",
@@ -81,10 +82,7 @@ DATASET_REGISTRY = [
81
  "DeepSeek V3.2 Speciale",
82
  {"max_samples": 1500},
83
  ),
84
- ]
85
-
86
- # DeepSeek (a-m-team) — different format, handled separately
87
- DEEPSEEK_AM_DATASETS = [
88
  (
89
  "a-m-team/AM-DeepSeek-R1-Distilled-1.4M",
90
  "DeepSeek",
@@ -93,24 +91,22 @@ DEEPSEEK_AM_DATASETS = [
93
  ),
94
  ]
95
 
96
- # Conversational datasets disabled
97
- CONVERSATIONAL_DATASETS = []
98
-
99
- # --- All providers and models ---
100
- PROVIDERS = [
101
- "Anthropic",
102
- "OpenAI",
103
- "Google",
104
- "xAI",
105
- "MoonshotAI",
106
- "Mistral",
107
- "MiniMax",
108
- "StepFun",
109
- "Zhipu",
110
- "DeepSeek",
111
  ]
 
112
 
113
- # --- Feature parameters ---
 
 
114
  TFIDF_WORD_PARAMS = {
115
  "analyzer": "word",
116
  "ngram_range": (1, 2),
@@ -130,15 +126,15 @@ TFIDF_CHAR_PARAMS = {
130
  "smooth_idf": True,
131
  }
132
 
133
- # Equal samples per provider
 
 
134
  MAX_SAMPLES_PER_PROVIDER = 1000
135
-
136
- # --- Train/val/test split ---
137
  TEST_SIZE = 0.15
138
  VAL_SIZE = 0.10
139
  RANDOM_STATE = 42
140
 
141
- # --- Neural Network ---
142
  HIDDEN_DIM = 256
143
  EMBED_DIM = 128
144
  DROPOUT = 0.7
 
1
  """
2
  AIFinder Configuration
3
+ Easy configuration for providers and datasets.
4
  """
5
 
6
  import os
 
9
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
10
  MODEL_DIR = os.path.join(BASE_DIR, "models")
11
 
12
+ # ============================================================================
13
+ # EASY PROVIDER CONFIGURATION
14
+ # Add new providers here! Each entry: (huggingface_dataset, provider_name, model_name, kwargs)
15
+ # ============================================================================
16
+ PROVIDER_DATASETS = [
17
  # Anthropic
18
  ("TeichAI/claude-4.5-opus-high-reasoning-250x", "Anthropic", "Claude 4.5 Opus", {}),
19
  (
 
74
  ),
75
  # Zhipu
76
  ("TeichAI/Pony-Alpha-15k", "Zhipu", "GLM-5", {"max_samples": 1500}),
77
+ # DeepSeek
78
  ("TeichAI/deepseek-v3.2-speciale-1000x", "DeepSeek", "DeepSeek V3.2 Speciale", {}),
79
  (
80
  "TeichAI/deepseek-v3.2-speciale-openr1-math-3k",
 
82
  "DeepSeek V3.2 Speciale",
83
  {"max_samples": 1500},
84
  ),
85
+ # DeepSeek (a-m-team) - different format
 
 
 
86
  (
87
  "a-m-team/AM-DeepSeek-R1-Distilled-1.4M",
88
  "DeepSeek",
 
91
  ),
92
  ]
93
 
94
+ # Auto-generate DATASET_REGISTRY and PROVIDERS from PROVIDER_DATASETS
95
+ DEEPSEEK_AM_DATASETS = [
96
+ (ds_id, prov, model, kwargs)
97
+ for ds_id, prov, model, kwargs in PROVIDER_DATASETS
98
+ if "a-m-team" in ds_id
99
+ ]
100
+ DATASET_REGISTRY = [
101
+ (ds_id, prov, model, kwargs)
102
+ for ds_id, prov, model, kwargs in PROVIDER_DATASETS
103
+ if "a-m-team" not in ds_id
 
 
 
 
 
104
  ]
105
+ PROVIDERS = sorted(set(prov for _, prov, _, _ in PROVIDER_DATASETS))
106
 
107
+ # ============================================================================
108
+ # FEATURE PARAMETERS
109
+ # ============================================================================
110
  TFIDF_WORD_PARAMS = {
111
  "analyzer": "word",
112
  "ngram_range": (1, 2),
 
126
  "smooth_idf": True,
127
  }
128
 
129
+ # ============================================================================
130
+ # TRAINING PARAMETERS
131
+ # ============================================================================
132
  MAX_SAMPLES_PER_PROVIDER = 1000
 
 
133
  TEST_SIZE = 0.15
134
  VAL_SIZE = 0.10
135
  RANDOM_STATE = 42
136
 
137
+ # Neural Network (unused currently, but kept for reference)
138
  HIDDEN_DIM = 256
139
  EMBED_DIM = 128
140
  DROPOUT = 0.7
data_loader.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AIFinder Data Loader
3
+ Downloads and parses HuggingFace datasets, extracts assistant responses,
4
+ and labels them with is_ai, provider, and model.
5
+ """
6
+
7
+ import os
8
+ import re
9
+ import time
10
+
11
+ from datasets import load_dataset
12
+ from tqdm import tqdm
13
+
14
+ from config import (
15
+ DATASET_REGISTRY,
16
+ DEEPSEEK_AM_DATASETS,
17
+ MAX_SAMPLES_PER_PROVIDER,
18
+ )
19
+
20
+ HF_TOKEN = os.environ.get("HF_TOKEN")
21
+
22
+
23
+ def _parse_msg(msg):
24
+ """Parse a message that may be a dict or a JSON string."""
25
+ if isinstance(msg, dict):
26
+ return msg
27
+ if isinstance(msg, str):
28
+ try:
29
+ import json as _json
30
+
31
+ parsed = _json.loads(msg)
32
+ if isinstance(parsed, dict):
33
+ return parsed
34
+ except (ValueError, Exception):
35
+ pass
36
+ return {}
37
+
38
+
39
+ def _extract_response_only(content):
40
+ """Extract only the final response, stripping CoT blocks.
41
+ Returns only the text after </think> or </thinking> if present,
42
+ otherwise returns the full content.
43
+ """
44
+ if not content:
45
+ return ""
46
+ think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL)
47
+ if think_match:
48
+ response = think_match.group(1).strip()
49
+ if response:
50
+ return response
51
+ return content
52
+
53
+
54
def _extract_assistant_texts_from_conversations(rows):
    """Extract individual assistant messages from conversation datasets.

    Produces one sample per assistant turn (not concatenated), keeping
    only the response portion (after </think> if present).
    """

    def _turns(row):
        # Prefer "conversations"; fall back to "messages"; else nothing.
        for key in ("conversations", "messages"):
            value = row.get(key)
            if value is not None and not (hasattr(value, "__len__") and len(value) == 0):
                return value
        return []

    texts = []
    for row in rows:
        for raw in _turns(row):
            msg = _parse_msg(raw)
            if msg.get("role", "") not in ("assistant", "gpt", "model"):
                continue
            content = msg.get("content", "")
            if not content:
                continue
            response = _extract_response_only(content)
            if response:
                texts.append(response)
    return texts
75
+
76
+
77
def _extract_from_am_dataset(row):
    """Extract individual assistant texts from a-m-team format.

    Only the response portion (after </think> if present) is kept.
    """
    entries = row.get("messages") or row.get("conversations") or []
    collected = []
    for entry in entries:
        # Non-dict entries carry no role/content and are skipped.
        if not isinstance(entry, dict):
            continue
        if entry.get("role", "") != "assistant":
            continue
        content = entry.get("content", "")
        if not content:
            continue
        response = _extract_response_only(content)
        if response:
            collected.append(response)
    return collected
91
+
92
+
93
def load_teichai_dataset(dataset_id, provider, model_name, kwargs):
    """Load a single conversation-format dataset and return (texts, providers, models).

    The three returned lists are parallel: providers/models repeat the given
    labels once per extracted text. Returns ([], [], []) on any failure.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]

    try:
        ds = load_dataset(dataset_id, split="train", token=HF_TOKEN, **load_kwargs)
        rows = list(ds)
    except Exception as e:
        # Fallback: load from auto-converted parquet via HF API
        try:
            import pandas as pd

            url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
            df = pd.read_parquet(url)
            rows = df.to_dict(orient="records")
        except Exception as e2:
            # Both loaders failed: report and skip this dataset entirely.
            print(f" [SKIP] {dataset_id}: {e} / parquet fallback: {e2}")
            return [], [], []

    if max_samples and len(rows) > max_samples:
        import random

        # Fixed seed keeps the subsample reproducible across runs.
        random.seed(42)
        rows = random.sample(rows, max_samples)

    texts = _extract_assistant_texts_from_conversations(rows)

    # Filter out empty/too-short texts (<= 50 chars carry too little signal)
    filtered = [(t, provider, model_name) for t in texts if len(t) > 50]
    if not filtered:
        print(f" [SKIP] {dataset_id}: no valid texts extracted")
        return [], [], []

    t, p, m = zip(*filtered)
    return list(t), list(p), list(m)
131
+
132
+
133
def load_am_deepseek_dataset(dataset_id, provider, model_name, kwargs):
    """Load a-m-team DeepSeek dataset.

    Returns (texts, providers, models) as parallel lists, or ([], [], [])
    when the dataset cannot be loaded at all.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]

    try:
        ds = load_dataset(dataset_id, split="train", token=HF_TOKEN, **load_kwargs)
    except Exception:
        # Fallback: stream the dataset so we can stop after max_samples rows
        # without downloading everything.
        try:
            ds = load_dataset(
                dataset_id, split="train", streaming=True, token=HF_TOKEN, **load_kwargs
            )
            rows = []
            for row in ds:
                rows.append(row)
                if max_samples and len(rows) >= max_samples:
                    break
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e2}")
            return [], [], []
    else:
        # Non-streaming path: materialize then truncate.
        rows = list(ds)
        if max_samples and len(rows) > max_samples:
            rows = rows[:max_samples]

    texts = []
    for row in rows:
        for text in _extract_from_am_dataset(row):
            # Drop very short responses (<= 50 chars) — too little signal.
            if len(text) > 50:
                texts.append(text)

    providers = [provider] * len(texts)
    models = [model_name] * len(texts)
    return texts, providers, models
169
+
170
+
171
def load_all_data():
    """Load all datasets and return combined lists.

    Pipeline: load every registered dataset, deduplicate by normalized text
    hash, then cap each provider at MAX_SAMPLES_PER_PROVIDER (reproducibly,
    via a fixed RNG seed).

    Returns:
        texts: list of str
        providers: list of str
        models: list of str
        is_ai: list of int (1=AI, 0=Human)
    """
    all_texts = []
    all_providers = []
    all_models = []

    # TeichAI datasets
    print("Loading TeichAI datasets...")
    for dataset_id, provider, model_name, kwargs in tqdm(
        DATASET_REGISTRY, desc="TeichAI"
    ):
        t0 = time.time()
        texts, providers, models = load_teichai_dataset(
            dataset_id, provider, model_name, kwargs
        )
        elapsed = time.time() - t0
        all_texts.extend(texts)
        all_providers.extend(providers)
        all_models.extend(models)
        print(f" {dataset_id}: {len(texts)} samples ({elapsed:.1f}s)")

    # DeepSeek a-m-team datasets
    print("\nLoading DeepSeek (a-m-team) datasets...")
    for dataset_id, provider, model_name, kwargs in tqdm(
        DEEPSEEK_AM_DATASETS, desc="DeepSeek-AM"
    ):
        t0 = time.time()
        texts, providers, models = load_am_deepseek_dataset(
            dataset_id, provider, model_name, kwargs
        )
        elapsed = time.time() - t0
        all_texts.extend(texts)
        all_providers.extend(providers)
        all_models.extend(models)
        print(f" {dataset_id}: {len(texts)} samples ({elapsed:.1f}s)")

    # Deduplicate by text hash
    import hashlib
    import random as _rng

    # Fixed seed so the per-provider shuffle below is reproducible.
    _rng.seed(42)

    seen = set()
    dedup_texts, dedup_providers, dedup_models = [], [], []
    for t, p, m in zip(all_texts, all_providers, all_models):
        # md5 here is only a dedup fingerprint, not security-sensitive.
        h = hashlib.md5(t.strip().lower().encode()).hexdigest()
        if h not in seen:
            seen.add(h)
            dedup_texts.append(t)
            dedup_providers.append(p)
            dedup_models.append(m)

    n_dupes = len(all_texts) - len(dedup_texts)
    if n_dupes > 0:
        print(f"\n Removed {n_dupes} duplicate samples")

    # Equal samples per provider
    from collections import defaultdict

    provider_indices = defaultdict(list)
    for i, p in enumerate(dedup_providers):
        provider_indices[p].append(i)

    # Use min of available or max allowed
    keep_indices = []
    for p, idxs in provider_indices.items():
        _rng.shuffle(idxs)
        n_sample = min(len(idxs), MAX_SAMPLES_PER_PROVIDER)
        idxs = idxs[:n_sample]
        print(f" Sampled {p}: {len(idxs)} samples")
        keep_indices.extend(idxs)
    # Restore original dataset order after per-provider sampling.
    keep_indices.sort()

    all_texts = [dedup_texts[i] for i in keep_indices]
    all_providers = [dedup_providers[i] for i in keep_indices]
    all_models = [dedup_models[i] for i in keep_indices]

    # Build is_ai labels (all AI)
    is_ai = [1] * len(all_texts)

    print(f"\n=== Total: {len(all_texts)} samples ===")
    # Print per-provider counts
    from collections import Counter

    prov_counts = Counter(all_providers)
    for p, c in sorted(prov_counts.items(), key=lambda x: -x[1]):
        print(f" {p}: {c}")

    return all_texts, all_providers, all_models, is_ai
267
+
268
+
269
+ if __name__ == "__main__":
270
+ texts, providers, models, is_ai = load_all_data()
dataset_evaluator.py ADDED
@@ -0,0 +1,769 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AIFinder Dataset Evaluator
3
+ Supports various HuggingFace dataset formats for evaluation.
4
+ """
5
+
6
+ import os
7
+ import re
8
+ import json
9
+ import random
10
+ from collections import defaultdict
11
+ from typing import Any
12
+
13
+ from datasets import load_dataset
14
+ from tqdm import tqdm
15
+
16
+ HF_TOKEN = os.environ.get("HF_TOKEN")
17
+
18
+
19
# Registry of recognized dataset layouts. Order matters: detect_format()
# iterates in insertion order and returns the first format whose "check"
# matches at least 60% of the sampled rows, so specific formats come first.
SUPPORTED_FORMATS = {
    "teichai_healer": {
        "name": "TeichAI Healer Format",
        "description": "TeichAI Healer-Alpha format with 'prompt' and 'response' fields",
        "examples": ["TeichAI/Healer-Alpha-16k"],
        "check": lambda row: (
            "prompt" in row
            and "response" in row
            and isinstance(row.get("prompt"), (str, dict))
            and isinstance(row.get("response"), (str, dict))
        ),
    },
    "teichai": {
        "name": "TeichAI Format",
        "description": "TeichAI dataset format with 'conversations' or 'messages' containing role/content",
        "examples": [
            "TeichAI/claude-4.5-opus-high-reasoning-250x",
            "TeichAI/Claude-3.5-Sonnet-128k",
        ],
        "check": lambda row: _check_conversations_format(row),
    },
    # BUGFIX: this key used to appear twice in the literal; the second, looser
    # definition (bare "any output-like key present") silently overwrote this
    # one. Only the specific check — which excludes prompt/response and
    # conversation-style rows — is kept, so it no longer shadows later formats.
    "combined": {
        "name": "Combined Outputs",
        "description": "Dataset with 'output', 'outputs', 'generated' or 'completion' field",
        "examples": ["jacobmorrison/gpt-oss-20b-combined-outputs"],
        "check": lambda row: (
            "prompt" not in row
            and "response" not in row
            and not _check_conversations_format(row)
            and (
                any(k in row for k in ["output", "outputs", "generated", "completion"])
                or (
                    isinstance(row.get("data"), str)
                    or isinstance(row.get("example"), str)
                )
            )
        ),
    },
    "conversations": {
        "name": "Conversations Format",
        "description": "Dataset with 'conversations' or 'messages' field containing role/content pairs",
        "examples": [
            "TeichAI/claude-4.5-opus-high-reasoning-250x",
            "ianncity/Hunter-Alpha-SFT-300000x",
        ],
        "check": lambda row: _check_conversations_format(row),
    },
    "chat": {
        "name": "Chat Format",
        "description": "Dataset with 'chat' or 'dialogue' field",
        "examples": ["some/chat-dataset"],
        "check": lambda row: ("chat" in row.keys() or "dialogue" in row.keys()),
    },
    "text": {
        "name": "Text Field",
        "description": "Dataset with a 'text' field containing the response",
        "examples": ["some/text-dataset"],
        "check": lambda row: "text" in row and isinstance(row.get("text"), str),
    },
    "response": {
        "name": "Response Field",
        "description": "Dataset with 'response' or 'output' field",
        "examples": ["some/response-dataset"],
        "check": lambda row: "response" in row or "output" in row,
    },
    "content": {
        "name": "Content Field",
        "description": "Dataset with 'content' field (single message)",
        "examples": ["some/content-dataset"],
        "check": lambda row: "content" in row and isinstance(row.get("content"), str),
    },
    "messages": {
        "name": "Messages Array",
        "description": "Dataset where each row is an array of message objects",
        "examples": ["some/messages-dataset"],
        "check": lambda row: isinstance(row, list)
        and len(row) > 0
        and isinstance(row[0], dict),
    },
    "sft": {
        "name": "SFT Format",
        "description": "Supervised Fine-Tuning format with 'prompt' and 'response' or 'completion'",
        "examples": ["some/sft-dataset"],
        "check": lambda row: "prompt" in row
        and ("response" in row or "completion" in row),
    },
    "qa": {
        "name": "Q&A Format",
        "description": "Question-Answer format with 'question' and 'answer' fields",
        "examples": ["some/qa-dataset"],
        "check": lambda row: "question" in row and "answer" in row,
    },
    "completion": {
        "name": "Completion Format",
        "description": "Dataset with 'completion' field (like OpenAI fine-tuning)",
        "examples": ["some/completion-dataset"],
        "check": lambda row: "completion" in row
        and isinstance(row.get("completion"), str),
    },
    "generations": {
        "name": "Generations Format",
        "description": "Dataset with 'generations' or 'generation' field (LLM outputs)",
        "examples": ["some/generations-dataset"],
        "check": lambda row: "generations" in row or "generation" in row,
    },
}
135
+
136
+
137
+ def _check_conversations_format(row):
138
+ """Check if row has conversations/messages with proper role/content structure."""
139
+ conv_key = (
140
+ "conversations"
141
+ if "conversations" in row
142
+ else "messages"
143
+ if "messages" in row
144
+ else None
145
+ )
146
+ if not conv_key:
147
+ return False
148
+ convos = row.get(conv_key)
149
+ if not isinstance(convos, list) or not convos:
150
+ return False
151
+ first_msg = convos[0]
152
+ if isinstance(first_msg, dict):
153
+ return "role" in first_msg and "content" in first_msg
154
+ return False
155
+
156
+
157
def detect_format(rows, sample_size=10):
    """Detect the dataset format from sample rows.

    Runs every registered "check" predicate over the first ``sample_size``
    rows and returns ``(format_name, format_info)`` for the first format
    (in SUPPORTED_FORMATS insertion order) matching at least 60% of the
    sample. Returns ``(None, [])`` when nothing matches or rows is empty.
    """
    if not rows:
        return None, []

    sample = rows[:sample_size]
    threshold = len(sample) * 0.6

    for fmt_name, fmt_info in SUPPORTED_FORMATS.items():
        check_func = fmt_info["check"]
        matches = 0
        for row in sample:
            try:
                if check_func(row):
                    matches += 1
            # BUGFIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Malformed rows simply don't count.
            except Exception:
                pass

        if matches >= threshold:
            return fmt_name, fmt_info

    return None, []
178
+
179
+
180
+ def _parse_msg(msg):
181
+ """Parse a message that may be a dict or a JSON string."""
182
+ if isinstance(msg, dict):
183
+ return msg
184
+ if isinstance(msg, str):
185
+ try:
186
+ parsed = json.loads(msg)
187
+ if isinstance(parsed, dict):
188
+ return parsed
189
+ except (ValueError, Exception):
190
+ pass
191
+ return {}
192
+
193
+
194
+ def _extract_response_only(content):
195
+ """Extract only the final response, stripping CoT blocks."""
196
+ if not content:
197
+ return ""
198
+ think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL)
199
+ if think_match:
200
+ response = think_match.group(1).strip()
201
+ if response:
202
+ return response
203
+ return content
204
+
205
+
206
def extract_texts_conversations(rows):
    """Collect assistant-authored replies from conversations/messages rows.

    Only messages whose role marks them as model output are kept, and only
    when the CoT-stripped reply exceeds 50 characters.
    """
    collected = []
    model_roles = ("assistant", "gpt", "model", "ai")
    for row in rows:
        for raw in row.get("conversations") or row.get("messages") or []:
            parsed = _parse_msg(raw)
            if parsed.get("role", "") not in model_roles:
                continue
            body = parsed.get("content", "")
            if not body:
                continue
            reply = _extract_response_only(body)
            if reply and len(reply) > 50:
                collected.append(reply)
    return collected
225
+
226
+
227
def extract_texts_chat(rows):
    """Collect assistant replies from rows carrying a 'chat' or 'dialogue' list."""
    collected = []
    for row in rows:
        turns = row.get("chat") or row.get("dialogue") or []
        if not isinstance(turns, list):
            continue  # non-list payloads are silently ignored
        for raw in turns:
            parsed = _parse_msg(raw)
            if parsed.get("role", "") not in ("assistant", "ai"):
                continue
            body = parsed.get("content", "")
            if not body:
                continue
            reply = _extract_response_only(body)
            if reply and len(reply) > 50:
                collected.append(reply)
    return collected
244
+
245
+
246
def extract_texts_text_field(rows, field="text"):
    """Pull responses out of a single named column (CoT-stripped, > 50 chars)."""
    kept = []
    for row in rows:
        raw = row.get(field, "")
        if not raw:
            continue
        as_str = str(raw)
        if len(as_str) <= 50:
            continue
        reply = _extract_response_only(as_str)
        if reply and len(reply) > 50:
            kept.append(reply)
    return kept
256
+
257
+
258
def extract_texts_sft(rows):
    """Collect the answer side of SFT rows ('response', else 'completion').

    Each candidate is CoT-stripped and kept only when longer than 50 chars.
    """
    kept = []
    for row in rows:
        answer = row.get("response") or row.get("completion") or ""
        if not answer:
            continue
        as_str = str(answer)
        if len(as_str) <= 50:
            continue
        reply = _extract_response_only(as_str)
        if reply and len(reply) > 50:
            kept.append(reply)
    return kept
268
+
269
+
270
def extract_texts_qa(rows):
    """Collect the 'answer' column of Q&A rows (CoT-stripped, > 50 chars)."""
    kept = []
    for row in rows:
        answer = row.get("answer", "")
        if not answer:
            continue
        as_str = str(answer)
        if len(as_str) <= 50:
            continue
        reply = _extract_response_only(as_str)
        if reply and len(reply) > 50:
            kept.append(reply)
    return kept
280
+
281
+
282
def extract_texts_messages_array(rows):
    """Collect assistant replies from rows that are themselves message lists."""
    collected = []
    model_roles = ("assistant", "ai", "model")
    for row in rows:
        if not isinstance(row, list):
            continue
        for raw in row:
            parsed = _parse_msg(raw)
            if parsed.get("role", "") not in model_roles:
                continue
            body = parsed.get("content", "")
            if not body:
                continue
            reply = _extract_response_only(body)
            if reply and len(reply) > 50:
                collected.append(reply)
    return collected
297
+
298
+
299
def extract_texts_teichai_healer(rows):
    """Collect responses from TeichAI Healer-Alpha rows.

    The 'response' cell may be a plain string or a dict carrying the text
    under 'content' or 'text'. Replies are CoT-stripped and kept only when
    longer than 50 characters.
    """
    kept = []
    for row in rows:
        value = row.get("response")
        if not value:
            continue
        if isinstance(value, dict):
            value = value.get("content") or value.get("text") or ""
        if not value:
            continue
        as_str = str(value)
        if len(as_str) <= 50:
            continue
        reply = _extract_response_only(as_str)
        if reply and len(reply) > 50:
            kept.append(reply)
    return kept
312
+
313
+
314
def _get_dataset_size(dataset_id, load_kwargs):
    """Best-effort row count for a dataset without downloading it fully.

    Tries the streaming metadata first, then falls back to counting the
    first parquet shard. Always returns an int; 0 means "unknown/empty".
    """
    try:
        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        num_rows = ds.info.num_rows
        # BUGFIX: streaming metadata may legitimately be missing (None); the
        # old code returned it as-is, breaking `total_rows > 0` comparisons
        # in callers. Treat a falsy count as unknown and try the fallback.
        if num_rows:
            return num_rows
    except Exception:
        pass
    try:
        import pandas as pd

        url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
        df = pd.read_parquet(url)
        return len(df)
    except Exception:
        return 0
329
+
330
+
331
def _streaming_download_with_progress(dataset_id, load_kwargs, progress_callback=None):
    """Download dataset using streaming with progress tracking.

    Tries the streaming API first, materializing rows one by one while
    reporting progress via ``progress_callback(current, total, stage)``;
    on any failure it falls back to reading the first parquet shard from
    the datasets-server API. Returns ``(rows, total_rows)``.
    """
    import pandas as pd

    # Best-effort size probe so progress can be reported as a fraction.
    total_rows = _get_dataset_size(dataset_id, load_kwargs)
    print(f"[PROGRESS] Dataset size: {total_rows} rows", flush=True)

    if total_rows > 0 and progress_callback:
        progress_callback(0, total_rows, "fetching_info")
        print(f"[PROGRESS] Initial callback: 0/{total_rows}", flush=True)

    try:
        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        rows = []
        for i, row in enumerate(tqdm(ds, desc="Downloading", unit="rows")):
            rows.append(row)
            if progress_callback and total_rows > 0:
                progress_callback(i + 1, total_rows, "downloading")
            if i % 100 == 0:
                # NOTE(review): divides by total_rows without the `> 0` guard
                # used above; if the size probe returned 0 this raises on the
                # first iteration and silently diverts the whole download to
                # the parquet fallback — confirm this is intended.
                print(
                    f"[PROGRESS] Downloaded {i + 1}/{total_rows} ({100 * (i + 1) / total_rows:.1f}%)",
                    flush=True,
                )
        return rows, total_rows
    except Exception as e:
        print(f"[PROGRESS] Streaming failed: {e}", flush=True)
        pass

    # Fallback: fetch only the first parquet shard (may truncate large sets).
    try:
        url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
        df = pd.read_parquet(url)
        total = len(df)
        if progress_callback:
            progress_callback(0, total, "downloading")
        rows = []
        for i, row in enumerate(df.to_dict(orient="records")):
            rows.append(row)
            if progress_callback:
                progress_callback(i + 1, total, "downloading")
        return rows, total
    except Exception as e:
        raise e
373
+
374
+
375
def _load_sample_rows(dataset_id, sample_size, load_kwargs):
    """Load just a few rows for format detection.

    Returns up to ``sample_size`` rows, or [] when the dataset cannot be read.
    """
    from itertools import islice

    try:
        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        # BUGFIX: the old code did `[next(iter(ds)) for _ in range(sample_size)]`,
        # creating a FRESH iterator per element and thus returning the first
        # row repeated sample_size times. islice over one iterator also stops
        # gracefully when the dataset has fewer rows than requested.
        return list(islice(iter(ds), sample_size))
    except Exception:
        pass
    try:
        import pandas as pd

        url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
        df = pd.read_parquet(url)
        return df.head(sample_size).to_dict(orient="records")
    except Exception:
        return []
390
+
391
+
392
def load_dataset_texts(
    dataset_id,
    max_samples=None,
    sample_size=None,
    progress_callback=None,
    custom_format=None,
):
    """
    Load a HuggingFace dataset and extract assistant response texts.
    Returns: {
        "texts": list of extracted texts,
        "format": detected format name,
        "format_info": format info dict,
        "total_rows": total rows in dataset,
        "supported": bool,
        "error": error message if failed,
    }
    progress_callback: optional function(current, total, stage) -> None
        stage can be: "fetching_info", "downloading", "extracting"
    custom_format: optional custom format specification string
        Examples:
        - "column: response"
        - "column: prompt, column: response"
        - "pattern: user:, pattern: assistant:"
        - "user:[startuser]assistant:[startassistant]"
    """
    load_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    rows = []
    total_rows = 0

    # Stage 1: acquire rows. Three paths — quick sample, streaming-with-progress
    # (with a parquet fallback), or a plain full download (with the same fallback).
    if sample_size:
        total_rows = _get_dataset_size(dataset_id, load_kwargs)
        if total_rows == 0:
            return {
                "texts": [],
                "format": None,
                "format_info": None,
                "total_rows": 0,
                "supported": False,
                "error": "Dataset is empty",
            }

        rows = _load_sample_rows(dataset_id, sample_size, load_kwargs)
    else:
        if progress_callback:
            try:
                rows, total_rows = _streaming_download_with_progress(
                    dataset_id, load_kwargs, progress_callback
                )
            except Exception as e:
                # Streaming (and its internal fallback) failed; try the first
                # parquet shard directly before giving up.
                fallback_error = None
                try:
                    import pandas as pd

                    url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
                    df = pd.read_parquet(url)
                    total_rows = len(df)
                    if progress_callback:
                        progress_callback(0, total_rows, "downloading")
                    rows = []
                    for i, row in enumerate(df.to_dict(orient="records")):
                        rows.append(row)
                        if progress_callback:
                            progress_callback(i + 1, total_rows, "downloading")
                except Exception as e2:
                    fallback_error = str(e2)
                    return {
                        "texts": [],
                        "format": None,
                        "format_info": None,
                        "total_rows": 0,
                        "supported": False,
                        "error": f"Failed to load: {e}. Parquet fallback also failed: {fallback_error}",
                    }
        else:
            try:
                ds = load_dataset(dataset_id, split="train", **load_kwargs)
                total_rows = len(ds)
                rows = list(ds)
            except Exception as e:
                fallback_error = None
                try:
                    import pandas as pd

                    url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
                    df = pd.read_parquet(url)
                    total_rows = len(df)
                    rows = df.to_dict(orient="records")
                except Exception as e2:
                    fallback_error = str(e2)
                    return {
                        "texts": [],
                        "format": None,
                        "format_info": None,
                        "total_rows": 0,
                        "supported": False,
                        "error": f"Failed to load: {e}. Parquet fallback also failed: {fallback_error}",
                    }

    if not rows:
        return {
            "texts": [],
            "format": None,
            "format_info": None,
            "total_rows": 0,
            "supported": False,
            "error": "Dataset is empty",
        }

    # Stage 2: decide how to interpret the rows. An explicit custom format
    # wins when it actually applies; otherwise auto-detect from a sample.
    detect_rows = rows[:sample_size] if sample_size else rows

    custom_format_spec = custom_format
    if custom_format_spec and check_custom_format(detect_rows, custom_format_spec):
        fmt_name = "custom"
        fmt_info = {
            "name": "Custom Format",
            "description": f"Custom format: {custom_format_spec}",
            "examples": [],
        }
    else:
        fmt_name, fmt_info = detect_format(detect_rows, sample_size=sample_size or 10)

    if fmt_name is None:
        return {
            "texts": [],
            "format": None,
            "format_info": None,
            "total_rows": total_rows,
            "supported": False,
            "error": "Unknown dataset format. Supported formats: "
            + ", ".join(f["name"] for f in SUPPORTED_FORMATS.values()),
        }

    # Stage 3: run the extractor matching the detected format.
    extractors = {
        "teichai_healer": extract_texts_teichai_healer,
        "teichai": extract_texts_conversations,
        "conversations": extract_texts_conversations,
        "chat": extract_texts_chat,
        "text": lambda r: extract_texts_text_field(r, "text"),
        "response": lambda r: extract_texts_text_field(r, "response")
        or extract_texts_text_field(r, "output"),
        "content": lambda r: extract_texts_text_field(r, "content"),
        "messages": extract_texts_messages_array,
        "sft": extract_texts_sft,
        "qa": extract_texts_qa,
        "combined": lambda r: (
            extract_texts_text_field(r, "output")
            or extract_texts_text_field(r, "outputs")
            or extract_texts_text_field(r, "generated")
            or extract_texts_text_field(r, "completion")
            or extract_texts_text_field(r, "combined")
            or extract_texts_text_field(r, "data")
            or extract_texts_text_field(r, "example")
        ),
        "completion": lambda r: extract_texts_text_field(r, "completion"),
        "generations": lambda r: (
            extract_texts_text_field(r, "generations")
            or extract_texts_text_field(r, "generation")
        ),
        "custom": lambda r: extract_texts_custom(r, custom_format_spec),
    }

    extractor = extractors.get(fmt_name)
    texts = extractor(rows) if extractor else []

    # Stage 4: optional deterministic downsampling (fixed seed for
    # reproducible evaluations).
    if max_samples and len(texts) > max_samples:
        random.seed(42)
        texts = random.sample(texts, max_samples)

    return {
        "texts": texts,
        "format": fmt_name,
        "format_info": fmt_info,
        "total_rows": total_rows,
        "supported": True,
        "error": None,
    }
569
+
570
+
571
def parse_custom_format_spec(spec):
    """
    Parse custom format specification.

    Supported formats:
    - "column: <field_name>" - extract single field as text
    - "column: <user_col>, column: <assistant_col>" - extract from two columns (user/assistant)
    - "pattern: <start_marker>user<end_marker>, pattern: <start_marker>assistant<end_marker>" - use regex patterns
    - "delimiter: <delim>" - use delimiter to split columns

    Examples:
    - "column: response"
    - "column: prompt, column: response"
    - "pattern: user:, pattern: assistant:"
    - "user:[startuser]assistant:[startassistant]"

    Returns a dict with keys: type ("single_column" | "two_column" |
    "single_pattern" | "two_pattern" | None), user_field, assistant_field,
    user_pattern, assistant_pattern — or None for an empty spec.
    """
    if not spec:
        return None

    spec = spec.strip()
    result = {
        "type": None,
        "user_field": None,
        "assistant_field": None,
        "user_pattern": None,
        "assistant_pattern": None,
    }

    # "column:"/"col:" prefix — one or two comma-separated field names.
    if spec.startswith("column:") or spec.startswith("col:"):
        cols_spec = spec.replace("column:", "").replace("col:", "").strip()
        if "," in cols_spec:
            parts = [p.strip() for p in cols_spec.split(",")]
            if len(parts) >= 2:
                result["type"] = "two_column"
                result["user_field"] = parts[0]
                result["assistant_field"] = parts[1]
        else:
            result["type"] = "single_column"
            result["assistant_field"] = cols_spec
        return result

    # "pattern:"/"regex:" prefix — one or two comma-separated regexes.
    if spec.startswith("pattern:") or spec.startswith("regex:"):
        patterns_spec = spec.replace("pattern:", "").replace("regex:", "").strip()
        if "," in patterns_spec:
            parts = [p.strip() for p in patterns_spec.split(",")]
            if len(parts) >= 2:
                result["type"] = "two_pattern"
                result["user_pattern"] = parts[0]
                result["assistant_pattern"] = parts[1]
        else:
            result["type"] = "single_pattern"
            result["assistant_pattern"] = patterns_spec
        return result

    # Inline "user: ... assistant: ..." marker pairs.
    if "user:" in spec.lower() and "assistant:" in spec.lower():
        import re  # shadows the module-level import; kept as-is

        user_match = re.search(
            r"user:\s*(\[.*?\]|(?:(?!\s+assistant:).)+)",
            spec,
            re.IGNORECASE | re.DOTALL,
        )
        # NOTE(review): the `(?:(?:\s+user:|$).)+` alternation here looks
        # suspect (it *requires* each matched char to be preceded by
        # " user:" or end-of-string) — confirm against intended specs.
        assistant_match = re.search(
            r"assistant:\s*(\[.*?\]|(?:(?:\s+user:|$).)+)",
            spec,
            re.IGNORECASE | re.DOTALL,
        )

        if user_match and assistant_match:
            result["type"] = "two_pattern"
            result["user_pattern"] = user_match.group(1).strip()
            result["assistant_pattern"] = assistant_match.group(1).strip()
            return result

    # Literal "[startuser]...[startassistant]" sentinel markers.
    if "[startuser]" in spec and "[startassistant]" in spec:
        result["type"] = "two_pattern"
        result["user_pattern"] = re.escape("[startuser]")
        result["assistant_pattern"] = re.escape("[startassistant]")
        return result

    # Bare "a, b" — treated as two column names.
    if "," in spec:
        parts = [p.strip() for p in spec.split(",")]
        if len(parts) >= 2:
            result["type"] = "two_column"
            result["user_field"] = parts[0]
            result["assistant_field"] = parts[1]
            return result

    # Fallback: the whole spec is a single column name.
    result["type"] = "single_column"
    result["assistant_field"] = spec
    return result
662
+
663
+
664
def extract_texts_custom(rows, format_spec):
    """Extract texts using custom format specification.

    Dispatches on the parsed spec type: column specs read named fields,
    pattern specs run a regex over ``str(row)``. In every branch the
    candidate is CoT-stripped and kept only when longer than 50 chars.
    Returns [] for an unusable spec.
    """
    parsed = parse_custom_format_spec(format_spec)
    if not parsed or not parsed.get("type"):
        return []

    texts = []

    if parsed["type"] == "single_column":
        field = parsed["assistant_field"]
        for row in rows:
            content = row.get(field, "")
            if content and len(str(content)) > 50:
                response_only = _extract_response_only(str(content))
                if response_only and len(response_only) > 50:
                    texts.append(response_only)

    elif parsed["type"] == "two_column":
        user_field = parsed.get("user_field")
        assistant_field = parsed["assistant_field"]
        for row in rows:
            # NOTE(review): user_content is computed but never used — only the
            # assistant column contributes to the output; confirm intended.
            user_content = row.get(user_field, "") if user_field else ""
            assistant_content = row.get(assistant_field, "")
            if assistant_content and len(str(assistant_content)) > 50:
                response_only = _extract_response_only(str(assistant_content))
                if response_only and len(response_only) > 50:
                    texts.append(response_only)

    elif parsed["type"] == "single_pattern":
        pattern = parsed.get("assistant_pattern")
        if pattern:
            try:
                regex = re.compile(pattern, re.DOTALL | re.IGNORECASE)
                for row in rows:
                    # The regex is applied to the row's repr, not a field.
                    row_str = str(row)
                    match = regex.search(row_str)
                    if match:
                        # First capture group if present, else the whole match.
                        content = match.group(1) if match.groups() else match.group(0)
                        if content and len(content) > 50:
                            response_only = _extract_response_only(content)
                            if response_only and len(response_only) > 50:
                                texts.append(response_only)
            except re.error:
                pass  # invalid user-supplied regex -> no matches

    elif parsed["type"] == "two_pattern":
        user_pattern = parsed.get("user_pattern")
        assistant_pattern = parsed.get("assistant_pattern")
        if assistant_pattern:
            try:
                # NOTE(review): user_regex is compiled but never applied; only
                # the assistant pattern drives extraction.
                user_regex = (
                    re.compile(user_pattern, re.DOTALL | re.IGNORECASE)
                    if user_pattern
                    else None
                )
                assistant_regex = re.compile(
                    assistant_pattern, re.DOTALL | re.IGNORECASE
                )

                for row in rows:
                    row_str = str(row)
                    match = assistant_regex.search(row_str)
                    if match:
                        content = match.group(1) if match.groups() else match.group(0)
                        if content and len(content) > 50:
                            response_only = _extract_response_only(content)
                            if response_only and len(response_only) > 50:
                                texts.append(response_only)
            except re.error:
                pass

    return texts
736
+
737
+
738
def check_custom_format(rows, format_spec):
    """Return True when the parsed custom spec is usable against the first row.

    Column specs require the assistant field to be present; pattern specs
    require the assistant regex to compile and match ``str(first_row)``.
    """
    parsed = parse_custom_format_spec(format_spec)
    if not rows or not parsed:
        return False
    spec_type = parsed.get("type")
    if not spec_type:
        return False

    first_row = rows[0]

    if spec_type in ("single_column", "two_column"):
        return parsed.get("assistant_field") in first_row

    if spec_type in ("single_pattern", "two_pattern"):
        pattern = parsed.get("assistant_pattern")
        if pattern:
            try:
                compiled = re.compile(pattern, re.DOTALL | re.IGNORECASE)
            except re.error:
                return False
            return compiled.search(str(first_row)) is not None

    return False
765
+
766
+
767
def get_supported_formats():
    """Return the registry mapping format keys to their info dicts."""
    return SUPPORTED_FORMATS
evaluate_dataset.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AIFinder Dataset Evaluator with Server
3
+ Runs the Flask server, then allows interactive dataset input.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import time
9
+ import argparse
10
+ import random
11
+ import threading
12
+ import requests
13
+ from collections import defaultdict
14
+
15
+ from datasets import load_dataset
16
+ from tqdm import tqdm
17
+
18
+ from config import MODEL_DIR
19
+ from inference import AIFinder
20
+
21
+
22
+ HF_TOKEN = os.environ.get("HF_TOKEN")
23
+ SERVER_URL = "http://localhost:7860"
24
+
25
+
26
def start_server():
    """Start Flask server in background thread.

    Blocks forever in ``app.run`` — intended to be launched via
    ``threading.Thread(target=start_server, daemon=True)``.
    """
    # Run from the script's own directory so app.py's relative paths resolve.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    from app import app, load_models

    load_models()
    print("Server started on http://localhost:7860")
    # use_reloader=False: the reloader would fork and break thread usage.
    app.run(host="0.0.0.0", port=7860, debug=False, use_reloader=False)
34
+
35
+
36
def wait_for_server(timeout=30):
    """Poll the status endpoint until the server answers or the timeout expires.

    Returns True once a 200 response is seen, False after ``timeout`` seconds.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            resp = requests.get(f"{SERVER_URL}/api/status", timeout=2)
            if resp.status_code == 200:
                return True
        except requests.exceptions.RequestException:
            pass  # server not up yet; keep polling
        time.sleep(1)
    return False
48
+
49
+
50
+ def _parse_msg(msg):
51
+ """Parse a message that may be a dict or a JSON string."""
52
+ if isinstance(msg, dict):
53
+ return msg
54
+ if isinstance(msg, str):
55
+ try:
56
+ import json
57
+
58
+ parsed = json.loads(msg)
59
+ if isinstance(parsed, dict):
60
+ return parsed
61
+ except (ValueError, Exception):
62
+ pass
63
+ return {}
64
+
65
+
66
+ def _extract_response_only(content):
67
+ """Extract only the final response, stripping CoT blocks."""
68
+ import re
69
+
70
+ if not content:
71
+ return ""
72
+ think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL)
73
+ if think_match:
74
+ response = think_match.group(1).strip()
75
+ if response:
76
+ return response
77
+ return content
78
+
79
+
80
def extract_texts_from_dataset(dataset_id, max_samples=None):
    """Extract assistant response texts from a HuggingFace dataset.

    Loads the 'train' split (falling back to the first parquet shard on
    failure) and collects CoT-stripped assistant messages longer than 50
    characters from conversations/messages rows. Optionally downsamples to
    ``max_samples`` with a fixed seed. Returns a list of strings ([] on
    failure).
    """
    print(f"\nLoading dataset: {dataset_id}")

    load_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    rows = []

    try:
        ds = load_dataset(dataset_id, split="train", **load_kwargs)
        rows = list(ds)
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        # Fallback: fetch only the first parquet shard via the datasets server.
        try:
            import pandas as pd

            url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
            df = pd.read_parquet(url)
            rows = df.to_dict(orient="records")
        except Exception as e2:
            print(f"Parquet fallback also failed: {e2}")
            return []

    texts = []
    for row in rows:
        convos = row.get("conversations") or row.get("messages") or []

        if not convos:
            continue

        for msg in convos:
            msg = _parse_msg(msg)
            role = msg.get("role", "")
            content = msg.get("content", "")

            # Keep only model-authored turns with substantive content.
            if role in ("assistant", "gpt", "model") and content:
                response_only = _extract_response_only(content)
                if response_only and len(response_only) > 50:
                    texts.append(response_only)

    # Deterministic downsampling for reproducible evaluations.
    if max_samples and len(texts) > max_samples:
        random.seed(42)
        texts = random.sample(texts, max_samples)

    return texts
124
+
125
+
126
def evaluate_dataset(texts):
    """Evaluate all texts via API and aggregate results.

    POSTs each text to the local /api/classify endpoint and tallies
    per-provider prediction counts and confidence values (rescaled from
    percent to 0..1). Individual request failures are logged and skipped.
    Returns {"total": int, "provider_counts": defaultdict(int),
    "confidences": defaultdict(list)}.
    """
    results = {
        "total": len(texts),
        "provider_counts": defaultdict(int),
        "confidences": defaultdict(list),
    }

    for text in tqdm(texts, desc="Evaluating"):
        try:
            resp = requests.post(
                f"{SERVER_URL}/api/classify",
                json={"text": text, "top_n": 5},
                timeout=30,
            )
            if resp.status_code == 200:
                result = resp.json()
                pred_provider = result.get("provider")
                # API reports confidence as a percentage; normalize to 0..1.
                confidence = result.get("confidence", 0) / 100.0

                if pred_provider:
                    results["provider_counts"][pred_provider] += 1
                    results["confidences"][pred_provider].append(confidence)
        except Exception as e:
            # Best-effort: one bad request must not abort the whole run.
            print(f"Error: {e}")
            continue

    return results
154
+
155
+
156
def print_results(results):
    """Print aggregated evaluation results.

    Expects the dict produced by evaluate_dataset(). Shows the per-provider
    prediction distribution with average confidence, then the top 3
    providers ranked by cumulative confidence (avg confidence x count).
    """
    total = results["total"]
    print("\n" + "=" * 60)
    print(f"EVALUATION RESULTS ({total} samples)")
    print("=" * 60)

    print("\n--- Predicted Provider Distribution ---")
    # Sorted by descending prediction count.
    for provider, count in sorted(
        results["provider_counts"].items(), key=lambda x: -x[1]
    ):
        pct = (count / total) * 100
        avg_conf = sum(results["confidences"][provider]) / len(
            results["confidences"][provider]
        )
        print(
            f" {provider}: {count} ({pct:.1f}%) - Avg confidence: {avg_conf * 100:.1f}%"
        )

    if results["confidences"]:
        print("\n--- Top Providers (by cumulative confidence) ---")
        provider_scores = {}
        for provider, confs in results["confidences"].items():
            if confs:
                avg_conf = sum(confs) / len(confs)
                count = results["provider_counts"][provider]
                # Cumulative score weights confidence by how often it won.
                provider_scores[provider] = avg_conf * count

        for provider, score in sorted(provider_scores.items(), key=lambda x: -x[1])[:3]:
            print(f" {provider}: {score:.2f}")

    print("\n" + "=" * 60)
+
189
+
190
def main():
    """Run the interactive evaluator.

    Starts the Flask server in a daemon thread, waits for it to answer,
    then loops reading dataset IDs from stdin, extracting texts and
    printing aggregated classification results. 'quit'/'exit'/'q' or
    Ctrl-C ends the loop.
    """
    parser = argparse.ArgumentParser(
        description="AIFinder Dataset Evaluator with Server"
    )
    parser.add_argument(
        "--max-samples", type=int, default=None, help="Max samples to test"
    )
    args = parser.parse_args()

    print("Starting AIFinder server...")
    # Daemon thread: the server dies automatically when this script exits.
    server_thread = threading.Thread(target=start_server, daemon=True)
    server_thread.start()

    print("Waiting for server...")
    if not wait_for_server():
        print("Server failed to start!")
        sys.exit(1)

    print("\n" + "=" * 60)
    print("AIFinder Server Ready!")
    print("=" * 60)
    print(f"Server running at: {SERVER_URL}")
    print("Enter a HuggingFace dataset ID to evaluate.")
    print("Examples: ianncity/Hunter-Alpha-SFT-300000x")
    print("Type 'quit' or 'exit' to stop.")
    print("=" * 60 + "\n")

    while True:
        try:
            dataset_id = input("Dataset ID: ").strip()

            if dataset_id.lower() in ("quit", "exit", "q"):
                print("Goodbye!")
                break

            if not dataset_id:
                continue

            texts = extract_texts_from_dataset(dataset_id, args.max_samples)

            if not texts:
                print("No valid texts found in dataset.")
                continue

            print(f"Testing {len(texts)} responses...")
            results = evaluate_dataset(texts)
            print_results(results)

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            # Keep the REPL alive on any per-dataset failure.
            print(f"Error: {e}")
+
244
+
245
+ if __name__ == "__main__":
246
+ main()
features.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  AIFinder Feature Extraction
3
- TF-IDF and stylometric features for AI model detection.
4
  """
5
 
6
  import re
@@ -12,25 +12,198 @@ from sklearn.preprocessing import MaxAbsScaler
12
 
13
  from config import TFIDF_WORD_PARAMS, TFIDF_CHAR_PARAMS
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def strip_cot(text):
17
- text = re.sub(r"<think(?:ing)?>.*?</think(?:ing)?>", "", text, flags=re.DOTALL)
18
- return text.strip()
19
 
20
 
21
  def strip_markdown(text):
22
- text = re.sub(r"```[\s\S]*?```", "", text)
23
- text = re.sub(r"`[^`]+`", "", text)
24
- text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
25
- text = re.sub(r"\*([^*]+)\*", r"\1", text)
26
- text = re.sub(r"__([^_]+)__", r"\1", text)
27
- text = re.sub(r"_([^_]+)_", r"\1", text)
28
- text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
29
- text = re.sub(r"^[\s]*[-*+]\s+", "", text, flags=re.MULTILINE)
30
- text = re.sub(r"^\s*\d+[.)]\s+", "", text, flags=re.MULTILINE)
31
- text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text)
32
- text = re.sub(r"^>.*$", "", text, flags=re.MULTILINE)
33
- text = re.sub(r"^---+$", "", text, flags=re.MULTILINE)
34
  return text.strip()
35
 
36
 
@@ -39,18 +212,14 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
39
  return self
40
 
41
  def transform(self, X):
42
- features = []
43
- for text in X:
44
- features.append(self._extract(text))
45
- return csr_matrix(np.array(features, dtype=np.float32))
46
 
47
  def _extract(self, text):
48
- words = text.split()
49
  n_chars = max(len(text), 1)
 
50
  n_words = max(len(words), 1)
51
 
52
- sentences = re.split(r"[.!?]+", text)
53
- sentences = [s.strip() for s in sentences if s.strip()]
54
  n_sentences = max(len(sentences), 1)
55
 
56
  paragraphs = text.split("\n\n")
@@ -58,17 +227,21 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
58
  n_paragraphs = len(non_empty_paras)
59
 
60
  lines = text.split("\n")
61
- non_empty_lines = [l for l in lines if l.strip()]
62
  n_lines = max(len(non_empty_lines), 1)
63
 
64
- # === Word-level stats ===
65
  word_lens = [len(w) for w in words]
66
- avg_word_len = np.mean(word_lens) if words else 0
67
- word_len_std = np.std(word_lens) if len(words) > 1 else 0
68
- median_word_len = np.median(word_lens) if words else 0
 
 
 
 
 
 
69
  avg_sent_len = n_words / n_sentences
70
 
71
- # === Punctuation density ===
72
  n_commas = text.count(",") / n_chars
73
  n_semicolons = text.count(";") / n_chars
74
  n_colons = text.count(":") / n_chars
@@ -84,16 +257,14 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
84
  comma_period_ratio = n_commas / (n_period + 0.001)
85
  excl_question_ratio = n_exclaim / (n_question + 0.001)
86
 
87
- # === Markdown/formatting features ===
88
- n_headers = len(re.findall(r"^#{1,6}\s", text, re.MULTILINE)) / n_sentences
89
- n_bold = len(re.findall(r"\*\*.*?\*\*", text)) / n_sentences
90
- n_code_blocks = len(re.findall(r"```", text)) / n_sentences
91
- n_inline_code = len(re.findall(r"`[^`]+`", text)) / n_sentences
92
- n_bullet = len(re.findall(r"^[\s]*[-*+]\s", text, re.MULTILINE)) / n_sentences
93
- n_numbered = len(re.findall(r"^\s*\d+[.)]\s", text, re.MULTILINE)) / n_sentences
94
- n_tables = len(re.findall(r"\|.*\|", text)) / n_sentences
95
 
96
- # === Whitespace & structure ===
97
  newline_density = text.count("\n") / n_chars
98
  double_newline_ratio = text.count("\n\n") / (text.count("\n") + 1)
99
  uppercase_ratio = sum(1 for c in text if c.isupper()) / n_chars
@@ -103,59 +274,40 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
103
  unique_chars = len(set(text)) / n_chars
104
  unique_chars_ratio = len(set(text.lower())) / n_chars
105
 
106
- # === Sentence-level stats ===
107
- sent_lens = [len(s.split()) for s in sentences]
108
- sent_len_std = np.std(sent_lens) if len(sent_lens) > 1 else 0
109
  sent_len_max = max(sent_lens) if sent_lens else 0
110
  sent_len_min = min(sent_lens) if sent_lens else 0
111
- sent_len_median = np.median(sent_lens) if sent_lens else 0
112
  sent_len_range = sent_len_max - sent_len_min
113
 
114
- # === Structural markers ===
115
- has_think = 1.0 if re.search(r"<think>", text) else 0.0
116
- has_xml = 1.0 if re.search(r"<[^>]+>", text) else 0.0
117
- has_hr = 1.0 if re.search(r"^---+", text, re.MULTILINE) else 0.0
118
- has_url = 1.0 if re.search(r"https?://", text) else 0.0
119
 
120
- # === Pronoun and person features ===
121
  words_lower = [w.lower().strip(".,!?;:'\"()[]{}") for w in words]
122
-
123
- first_person = {
124
- "i",
125
- "me",
126
- "my",
127
- "mine",
128
- "myself",
129
- "we",
130
- "us",
131
- "our",
132
- "ours",
133
- "ourselves",
134
- }
135
- second_person = {"you", "your", "yours", "yourself", "yourselves"}
136
- third_person = {"he", "she", "it", "they", "them", "his", "her", "its", "their"}
137
-
138
- first_person_ratio = sum(1 for w in words_lower if w in first_person) / n_words
139
  second_person_ratio = (
140
- sum(1 for w in words_lower if w in second_person) / n_words
141
  )
142
- third_person_ratio = sum(1 for w in words_lower if w in third_person) / n_words
143
 
144
- # === Vocabulary richness ===
145
  unique_words = len(set(words_lower))
146
- ttr = unique_words / n_words if n_words > 0 else 0
147
- hapax = sum(1 for w in set(words_lower) if words_lower.count(w) == 1)
148
- hapax_ratio = hapax / n_words if n_words > 0 else 0
 
 
 
149
 
150
- contraction_count = len(re.findall(r"\b\w+'\w+\b", text))
151
- contraction_ratio = contraction_count / n_words if n_words > 0 else 0
152
 
153
- # === Sentence starters ===
154
  sentences_starters = [
155
  s.split()[0].lower() if s.split() else "" for s in sentences
156
  ]
157
  starter_vocab = (
158
- len(set(sentences_starters)) / n_sentences if n_sentences > 0 else 0
159
  )
160
 
161
  and_starts = sum(1 for s in sentences_starters if s == "and") / n_sentences
@@ -170,281 +322,119 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
170
  / n_sentences
171
  )
172
 
173
- # === Word length distributions ===
174
  short_word_ratio = sum(1 for w in words_lower if len(w) <= 2) / n_words
175
  medium_word_ratio = sum(1 for w in words_lower if 3 <= len(w) <= 6) / n_words
176
  long_word_ratio = sum(1 for w in words_lower if len(w) >= 7) / n_words
177
  very_long_word_ratio = sum(1 for w in words_lower if len(w) >= 10) / n_words
178
 
179
- # === Paragraph stats ===
180
  para_lens = (
181
  [len(p.split()) for p in non_empty_paras] if non_empty_paras else [0]
182
  )
183
  avg_para_len = np.mean(para_lens)
184
- para_len_std = np.std(para_lens) if len(para_lens) > 1 else 0
185
 
186
- # === Discourse markers ===
187
- conjunctions = {
188
- "and",
189
- "but",
190
- "or",
191
- "nor",
192
- "for",
193
- "yet",
194
- "so",
195
- "because",
196
- "although",
197
- "while",
198
- "if",
199
- "when",
200
- "where",
201
- }
202
- discourse = {
203
- "however",
204
- "therefore",
205
- "moreover",
206
- "furthermore",
207
- "nevertheless",
208
- "consequently",
209
- "thus",
210
- "hence",
211
- }
212
- hedging = {
213
- "perhaps",
214
- "maybe",
215
- "might",
216
- "could",
217
- "possibly",
218
- "seemingly",
219
- "apparently",
220
- "arguably",
221
- "potentially",
222
- }
223
- certainty = {
224
- "definitely",
225
- "certainly",
226
- "absolutely",
227
- "clearly",
228
- "obviously",
229
- "undoubtedly",
230
- "indeed",
231
- "surely",
232
- }
233
- transition = {
234
- "additionally",
235
- "meanwhile",
236
- "subsequently",
237
- "alternatively",
238
- "specifically",
239
- "notably",
240
- "importantly",
241
- "essentially",
242
- }
243
-
244
- conjunction_ratio = sum(1 for w in words_lower if w in conjunctions) / n_words
245
- discourse_ratio = sum(1 for w in words_lower if w in discourse) / n_words
246
- hedging_ratio = sum(1 for w in words_lower if w in hedging) / n_words
247
- certainty_ratio = sum(1 for w in words_lower if w in certainty) / n_words
248
- transition_ratio = sum(1 for w in words_lower if w in transition) / n_words
249
 
250
- # === Question patterns ===
251
  question_starts = sum(
252
- 1
253
- for s in sentences
254
- if s
255
- and s.strip()
256
- .lower()
257
- .startswith(("who", "what", "when", "where", "why", "how"))
258
  )
259
 
260
- # === List features ===
261
  has_list = 1.0 if n_bullet > 0 or n_numbered > 0 else 0.0
262
  list_items = n_bullet + n_numbered
263
 
264
- # === Emoji and special chars ===
265
- emoji_count = len(re.findall(r"[\U00010000-\U0010ffff]", text))
266
  has_emoji = 1.0 if emoji_count > 0 else 0.0
267
 
268
- # === Specific style markers ===
269
- # ALL CAPS words (emphasis style)
270
  all_caps_words = sum(
271
  1 for w in words if len(w) > 1 and w.isupper() and w.isalpha()
272
  )
273
  all_caps_ratio = all_caps_words / n_words
274
 
275
- # Parenthetical asides
276
- paren_count = len(re.findall(r"\([^)]+\)", text))
277
  paren_ratio = paren_count / n_sentences
278
 
279
- # Rhetorical questions (sentences ending with ?)
280
  rhetorical_q = sum(1 for s in text.split("\n") if s.strip().endswith("?"))
281
  rhetorical_ratio = rhetorical_q / n_sentences
282
 
283
- # Direct address / casual markers
284
- casual_markers = {
285
- "okay",
286
- "ok",
287
- "hey",
288
- "hi",
289
- "cool",
290
- "awesome",
291
- "wow",
292
- "basically",
293
- "actually",
294
- "literally",
295
- "right",
296
- "yeah",
297
- }
298
- casual_ratio = sum(1 for w in words_lower if w in casual_markers) / n_words
299
-
300
- # Formal markers
301
- formal_markers = {
302
- "regarding",
303
- "concerning",
304
- "pertaining",
305
- "aforementioned",
306
- "respectively",
307
- "accordingly",
308
- "henceforth",
309
- "whereby",
310
- "notwithstanding",
311
- "pursuant",
312
- }
313
- formal_ratio = sum(1 for w in words_lower if w in formal_markers) / n_words
314
 
315
- # Chinese character detection
316
- chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
317
  has_chinese = 1.0 if chinese_chars > 0 else 0.0
318
  chinese_ratio = chinese_chars / n_chars
319
 
320
- # Self-identification patterns
321
- has_self_id_ai = (
322
- 1.0
323
- if re.search(
324
- r"\b(I'm|I am)\s+(an?\s+)?(AI|language model|assistant|chatbot)\b",
325
- text,
326
- re.IGNORECASE,
327
- )
328
- else 0.0
329
- )
330
- has_provider_mention = (
331
- 1.0
332
- if re.search(
333
- r"\b(Claude|Anthropic|GPT|OpenAI|ChatGPT|Gemini|Google|Bard|Grok|xAI"
334
- r"|DeepSeek|Kimi|Moonshot|Mistral|MiniMax|Zhipu|GLM|深度求索)\b",
335
- text,
336
- re.IGNORECASE,
337
- )
338
- else 0.0
339
- )
340
 
341
- # Response ending patterns
342
  ends_with_question = 1.0 if text.rstrip().endswith("?") else 0.0
343
- has_closing_offer = (
344
- 1.0
345
- if re.search(
346
- r"(let me know|feel free|happy to help|don't hesitate|hope this helps)",
347
- text,
348
- re.IGNORECASE,
349
- )
350
- else 0.0
351
- )
352
 
353
- # Sentence complexity (approximation via commas per sentence)
354
  commas_per_sentence = text.count(",") / n_sentences
355
 
356
- # Line-level features
357
  avg_line_len = (
358
- np.mean([len(l) for l in non_empty_lines]) if non_empty_lines else 0
359
  )
360
  short_lines_ratio = (
361
- sum(1 for l in non_empty_lines if len(l.split()) <= 5) / n_lines
362
  )
363
 
364
- # Capitalized word ratio (proper nouns, emphasis)
365
- cap_words = len(re.findall(r"\b[A-Z][a-z]+\b", text))
366
  cap_word_ratio = cap_words / n_words
367
 
368
- # Multi-word phrases per sentence
369
- four_word_phrases = len(re.findall(r"\b\w+\s+\w+\s+\w+\s+\w+\b", text))
370
  phrase_ratio = four_word_phrases / n_sentences
371
 
372
- # Sentence boundary patterns
373
- sent_boundaries = len(re.findall(r"[.!?]\s+[A-Z]", text))
374
  sent_boundary_ratio = sent_boundaries / n_sentences
375
 
376
- # Special punctuation
377
- has_checkmark = (
378
- 1.0 if "✓" in text or "✗" in text or "" in text or "✘" in text else 0.0
379
- )
380
- has_arrow = 1.0 if "→" in text or "←" in text or "➡" in text else 0.0
381
- has_star = 1.0 if "⭐" in text or "★" in text or "☆" in text else 0.0
382
- special_unicode = len(re.findall(r"[^\x00-\x7F]", text)) / n_chars
383
 
384
- # Colon-based definitions (common in some providers)
385
- colon_definitions = len(re.findall(r"\b\w+:\s+\w+", text)) / n_sentences
386
 
387
- # Quotation usage
388
- double_quote_pairs = len(re.findall(r'"[^"]*"', text)) / n_sentences
389
- single_quote_pairs = len(re.findall(r"'[^']*'", text)) / n_sentences
390
 
391
- # Greeting patterns
392
- greeting_patterns = len(
393
- re.findall(
394
- r"\b(hi|hello|hey|hiya|greetings|howdy|yo)\b", text, re.IGNORECASE
395
- )
396
- )
397
  greeting_ratio = greeting_patterns / n_sentences
398
 
399
- # Response length categories
400
  is_short = 1.0 if n_words < 100 else 0.0
401
  is_medium = 1.0 if 100 <= n_words < 500 else 0.0
402
  is_long = 1.0 if n_words >= 500 else 0.0
403
 
404
- # Exclamation usage
405
  excl_sentences = sum(1 for s in sentences if s.strip().endswith("!"))
406
  excl_sentence_ratio = excl_sentences / n_sentences
407
 
408
- # Question-only responses
409
- question_lines = [l for l in non_empty_lines if l.strip().endswith("?")]
410
  question_line_ratio = len(question_lines) / n_lines if n_lines > 0 else 0.0
411
 
412
- # Common conversational phrases
413
- conversational_phrases = len(
414
- re.findall(
415
- r"\b(great|perfect|sure|definitely|certainly|absolutely|of course"
416
- r"|no problem|sounds good|got it|understood|okay|alright)\b",
417
- text,
418
- re.IGNORECASE,
419
- )
420
- )
421
  conv_phrase_ratio = conversational_phrases / n_words
422
 
423
- # Helpful/closing phrases
424
- helpful_phrases = len(
425
- re.findall(
426
- r"\b(let me know|feel free|happy to|glad to|happy to help"
427
- r"|don't hesitate|let me know if|please let me|reach out)\b",
428
- text,
429
- re.IGNORECASE,
430
- )
431
- )
432
  helpful_ratio = helpful_phrases / n_sentences
433
 
434
  return [
435
- # Basic word stats (0-3)
436
  avg_word_len,
437
  word_len_std,
438
  median_word_len,
439
  avg_sent_len,
440
- # Sentence stats (4-9)
441
  sent_len_std,
442
  sent_len_max,
443
  sent_len_min,
444
  sent_len_median,
445
  sent_len_range,
446
  commas_per_sentence,
447
- # Punctuation density (10-22)
448
  n_commas,
449
  n_semicolons,
450
  n_colons,
@@ -458,7 +448,6 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
458
  comma_colon_ratio,
459
  comma_period_ratio,
460
  excl_question_ratio,
461
- # Markdown features (23-30)
462
  n_headers,
463
  n_bold,
464
  n_code_blocks,
@@ -467,7 +456,6 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
467
  n_numbered,
468
  n_tables,
469
  has_list,
470
- # Structure (31-40)
471
  newline_density,
472
  double_newline_ratio,
473
  uppercase_ratio,
@@ -478,47 +466,37 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
478
  list_items,
479
  n_paragraphs,
480
  n_lines / n_sentences,
481
- # Sentence level (41-44)
482
  has_think,
483
  has_xml,
484
  has_hr,
485
  has_url,
486
- # Pronoun features (45-47)
487
  first_person_ratio,
488
  second_person_ratio,
489
  third_person_ratio,
490
- # Vocabulary (48-52)
491
  ttr,
492
  hapax_ratio,
493
  contraction_ratio,
494
  short_word_ratio,
495
  medium_word_ratio,
496
- # Word length distributions (53-54)
497
  long_word_ratio,
498
  very_long_word_ratio,
499
- # Sentence starters (55-60)
500
  starter_vocab,
501
  and_starts,
502
  but_starts,
503
  so_starts,
504
  the_starts,
505
  it_starts,
506
- # Paragraph stats (61-62)
507
  avg_para_len,
508
  para_len_std,
509
- # Discourse markers (63-67)
510
  conjunction_ratio,
511
  discourse_ratio,
512
  hedging_ratio,
513
  certainty_ratio,
514
  transition_ratio,
515
- # Questions (68)
516
  question_starts / n_sentences if n_sentences > 0 else 0,
517
- # Emoji/special (69-71)
518
  emoji_count,
519
  has_emoji,
520
  special_unicode,
521
- # Style markers (72-79)
522
  all_caps_ratio,
523
  paren_ratio,
524
  rhetorical_ratio,
@@ -527,25 +505,21 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
527
  has_chinese,
528
  chinese_ratio,
529
  has_self_id_ai,
530
- # Provider mention & response patterns (80-83)
531
  has_provider_mention,
532
  ends_with_question,
533
  has_closing_offer,
534
  has_checkmark,
535
- # More structure (84-89)
536
  has_arrow,
537
  has_star,
538
  avg_line_len,
539
  short_lines_ratio,
540
  cap_word_ratio,
541
  phrase_ratio,
542
- # Final features (90-94)
543
  sent_boundary_ratio,
544
  colon_definitions,
545
  double_quote_pairs,
546
  single_quote_pairs,
547
  i_starts,
548
- # New features (95-102)
549
  greeting_ratio,
550
  is_short,
551
  is_medium,
@@ -557,6 +531,32 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
557
  ]
558
 
559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
560
  class FeaturePipeline:
561
  def __init__(self, use_tfidf=True):
562
  word_params = dict(TFIDF_WORD_PARAMS)
@@ -577,7 +577,6 @@ class FeaturePipeline:
577
  )
578
 
579
  def _clean_for_tfidf(self, text):
580
- """Strip CoT and markdown for TF-IDF (remove formatting artifacts, keep content)."""
581
  return strip_markdown(strip_cot(text))
582
 
583
  def fit_transform(self, texts):
@@ -585,8 +584,8 @@ class FeaturePipeline:
585
 
586
  print(f" Input: {len(texts)} texts", flush=True)
587
 
588
- texts_tfidf = [self._clean_for_tfidf(t) for t in texts]
589
- texts_stylo = [strip_markdown(strip_cot(t)) for t in texts]
590
 
591
  use_word_tfidf = (
592
  self.word_tfidf.max_features is not None
@@ -613,7 +612,7 @@ class FeaturePipeline:
613
  char_features = csr_matrix((len(texts), 0), dtype=np.float32)
614
 
615
  t0 = time.time()
616
- stylo_features = self.stylo.fit_transform(texts_stylo)
617
  print(
618
  f" stylometric: {stylo_features.shape[1]} features ({time.time() - t0:.1f}s)",
619
  flush=True,
@@ -625,8 +624,8 @@ class FeaturePipeline:
625
  return combined
626
 
627
  def transform(self, texts):
628
- texts_tfidf = [self._clean_for_tfidf(t) for t in texts]
629
- texts_stylo = [strip_markdown(strip_cot(t)) for t in texts]
630
 
631
  use_word_tfidf = (
632
  self.word_tfidf.max_features is not None
@@ -642,6 +641,6 @@ class FeaturePipeline:
642
  else:
643
  char_features = csr_matrix((len(texts), 0), dtype=np.float32)
644
 
645
- stylo_features = self.stylo.transform(texts_stylo)
646
  combined = hstack([word_features, char_features, stylo_features])
647
  return self.scaler.transform(combined)
 
1
  """
2
  AIFinder Feature Extraction
3
+ Optimized TF-IDF and stylometric features for AI model detection.
4
  """
5
 
6
  import re
 
12
 
13
  from config import TFIDF_WORD_PARAMS, TFIDF_CHAR_PARAMS
14
 
15
+ _RE_COMPILED = {
16
+ "cot": re.compile(r"<think(?:ing)?>.*?</think(?:ing)?>", re.DOTALL),
17
+ "code_block": re.compile(r"```[\s\S]*?```"),
18
+ "inline_code": re.compile(r"`[^`]+`"),
19
+ "bold": re.compile(r"\*\*([^*]+)\*\*"),
20
+ "italic_ast": re.compile(r"\*([^*]+)\*"),
21
+ "italic_under": re.compile(r"__([^_]+)__"),
22
+ "under": re.compile(r"_([^_]+)_"),
23
+ "header": re.compile(r"^#{1,6}\s+", re.MULTILINE),
24
+ "bullet": re.compile(r"^[\s]*[-*+]\s+", re.MULTILINE),
25
+ "numbered": re.compile(r"^\s*\d+[.)]\s+", re.MULTILINE),
26
+ "link": re.compile(r"\[([^\]]+)\]\([^)]+\)"),
27
+ "quote": re.compile(r"^>.*$", re.MULTILINE),
28
+ "hr": re.compile(r"^---+$", re.MULTILINE),
29
+ "think_tag": re.compile(r"<think>"),
30
+ "xml_tag": re.compile(r"<[^>]+>"),
31
+ "url": re.compile(r"https?://"),
32
+ "contraction": re.compile(r"\b\w+'\w+\b"),
33
+ "markdown_header": re.compile(r"^#{1,6}\s", re.MULTILINE),
34
+ "markdown_bold": re.compile(r"\*\*.*?\*\*"),
35
+ "markdown_code_block": re.compile(r"```"),
36
+ "markdown_inline_code": re.compile(r"`[^`]+`"),
37
+ "markdown_bullet": re.compile(r"^[\s]*[-*+]\s", re.MULTILINE),
38
+ "markdown_numbered": re.compile(r"^\s*\d+[.)]\s", re.MULTILINE),
39
+ "markdown_table": re.compile(r"\|.*\|"),
40
+ "question_start": re.compile(
41
+ r"^(who|what|when|where|why|how)\b", re.IGNORECASE | re.MULTILINE
42
+ ),
43
+ "emoji": re.compile(r"[\U00010000-\U0010ffff]"),
44
+ "chinese": re.compile(r"[\u4e00-\u9fff]"),
45
+ "all_caps": re.compile(r"\b[A-Z][a-z]+\b"),
46
+ "four_word": re.compile(r"\b\w+\s+\w+\s+\w+\s+\w+\b"),
47
+ "sent_boundary": re.compile(r"[.!?]\s+[A-Z]"),
48
+ "paren": re.compile(r"\([^)]+\)"),
49
+ "colon_def": re.compile(r"\b\w+:\s+\w+"),
50
+ "double_quote": re.compile(r'"[^"]*"'),
51
+ "single_quote": re.compile(r"'[^']*'"),
52
+ "greeting": re.compile(
53
+ r"\b(hi|hello|hey|hiya|greetings|howdy|yo)\b", re.IGNORECASE
54
+ ),
55
+ "conv_phrase": re.compile(
56
+ r"\b(great|perfect|sure|definitely|certainly|absolutely|of course|no problem|sounds good|got it|understood|okay|alright)\b",
57
+ re.IGNORECASE,
58
+ ),
59
+ "helpful": re.compile(
60
+ r"\b(let me know|feel free|happy to|glad to|happy to help|don't hesitate|let me know if|please let me|reach out)\b",
61
+ re.IGNORECASE,
62
+ ),
63
+ "closing_offer": re.compile(
64
+ r"(let me know|feel free|happy to help|don't hesitate|hope this helps)",
65
+ re.IGNORECASE,
66
+ ),
67
+ "self_id_ai": re.compile(
68
+ r"\b(I'm|I am)\s+(an?\s+)?(AI|language model|assistant|chatbot)\b",
69
+ re.IGNORECASE,
70
+ ),
71
+ "provider_mention": re.compile(
72
+ r"\b(Claude|Anthropic|GPT|OpenAI|ChatGPT|Gemini|Google|Bard|Grok|xAI|DeepSeek|Kimi|Moonshot|Mistral|MiniMax|Zhipu|GLM|深度求索)\b",
73
+ re.IGNORECASE,
74
+ ),
75
+ "special_unicode": re.compile(r"[^\x00-\x7F]"),
76
+ }
77
+
78
+ _PRONOUN_SETS = {
79
+ "first": frozenset(
80
+ {"i", "me", "my", "mine", "myself", "we", "us", "our", "ours", "ourselves"}
81
+ ),
82
+ "second": frozenset({"you", "your", "yours", "yourself", "yourselves"}),
83
+ "third": frozenset(
84
+ {"he", "she", "it", "they", "them", "his", "her", "its", "their"}
85
+ ),
86
+ }
87
+
88
+ _DISCOURSE_SETS = {
89
+ "conjunctions": frozenset(
90
+ {
91
+ "and",
92
+ "but",
93
+ "or",
94
+ "nor",
95
+ "for",
96
+ "yet",
97
+ "so",
98
+ "because",
99
+ "although",
100
+ "while",
101
+ "if",
102
+ "when",
103
+ "where",
104
+ }
105
+ ),
106
+ "discourse": frozenset(
107
+ {
108
+ "however",
109
+ "therefore",
110
+ "moreover",
111
+ "furthermore",
112
+ "nevertheless",
113
+ "consequently",
114
+ "thus",
115
+ "hence",
116
+ }
117
+ ),
118
+ "hedging": frozenset(
119
+ {
120
+ "perhaps",
121
+ "maybe",
122
+ "might",
123
+ "could",
124
+ "possibly",
125
+ "seemingly",
126
+ "apparently",
127
+ "arguably",
128
+ "potentially",
129
+ }
130
+ ),
131
+ "certainty": frozenset(
132
+ {
133
+ "definitely",
134
+ "certainly",
135
+ "absolutely",
136
+ "clearly",
137
+ "obviously",
138
+ "undoubtedly",
139
+ "indeed",
140
+ "surely",
141
+ }
142
+ ),
143
+ "transition": frozenset(
144
+ {
145
+ "additionally",
146
+ "meanwhile",
147
+ "subsequently",
148
+ "alternatively",
149
+ "specifically",
150
+ "notably",
151
+ "importantly",
152
+ "essentially",
153
+ }
154
+ ),
155
+ "casual": frozenset(
156
+ {
157
+ "okay",
158
+ "ok",
159
+ "hey",
160
+ "hi",
161
+ "cool",
162
+ "awesome",
163
+ "wow",
164
+ "basically",
165
+ "actually",
166
+ "literally",
167
+ "right",
168
+ "yeah",
169
+ }
170
+ ),
171
+ "formal": frozenset(
172
+ {
173
+ "regarding",
174
+ "concerning",
175
+ "pertaining",
176
+ "aforementioned",
177
+ "respectively",
178
+ "accordingly",
179
+ "henceforth",
180
+ "whereby",
181
+ "notwithstanding",
182
+ "pursuant",
183
+ }
184
+ ),
185
+ }
186
+
187
+ _PUNC_STRIP = frozenset(".,!?;:'\"()[]{}")
188
+
189
 
190
  def strip_cot(text):
191
+ return _RE_COMPILED["cot"].sub("", text).strip()
 
192
 
193
 
194
  def strip_markdown(text):
195
+ text = _RE_COMPILED["code_block"].sub("", text)
196
+ text = _RE_COMPILED["inline_code"].sub("", text)
197
+ text = _RE_COMPILED["bold"].sub(r"\1", text)
198
+ text = _RE_COMPILED["italic_ast"].sub(r"\1", text)
199
+ text = _RE_COMPILED["italic_under"].sub(r"\1", text)
200
+ text = _RE_COMPILED["under"].sub(r"\1", text)
201
+ text = _RE_COMPILED["header"].sub("", text)
202
+ text = _RE_COMPILED["bullet"].sub("", text)
203
+ text = _RE_COMPILED["numbered"].sub("", text)
204
+ text = _RE_COMPILED["link"].sub(r"\1", text)
205
+ text = _RE_COMPILED["quote"].sub("", text)
206
+ text = _RE_COMPILED["hr"].sub("", text)
207
  return text.strip()
208
 
209
 
 
212
  return self
213
 
214
  def transform(self, X):
215
+ return csr_matrix(np.array([self._extract(t) for t in X], dtype=np.float32))
 
 
 
216
 
217
  def _extract(self, text):
 
218
  n_chars = max(len(text), 1)
219
+ words = text.split()
220
  n_words = max(len(words), 1)
221
 
222
+ sentences = [s.strip() for s in re.split(r"[.!?]+", text) if s.strip()]
 
223
  n_sentences = max(len(sentences), 1)
224
 
225
  paragraphs = text.split("\n\n")
 
227
  n_paragraphs = len(non_empty_paras)
228
 
229
  lines = text.split("\n")
230
+ non_empty_lines = [ln for ln in lines if ln.strip()]
231
  n_lines = max(len(non_empty_lines), 1)
232
 
 
233
  word_lens = [len(w) for w in words]
234
+ sent_lens = [len(s.split()) for s in sentences]
235
+
236
+ _rc = _RE_COMPILED
237
+ _ps = _PRONOUN_SETS
238
+ _ds = _DISCOURSE_SETS
239
+
240
+ avg_word_len = np.mean(word_lens) if words else 0.0
241
+ word_len_std = np.std(word_lens) if len(words) > 1 else 0.0
242
+ median_word_len = np.median(word_lens) if words else 0.0
243
  avg_sent_len = n_words / n_sentences
244
 
 
245
  n_commas = text.count(",") / n_chars
246
  n_semicolons = text.count(";") / n_chars
247
  n_colons = text.count(":") / n_chars
 
257
  comma_period_ratio = n_commas / (n_period + 0.001)
258
  excl_question_ratio = n_exclaim / (n_question + 0.001)
259
 
260
+ n_headers = len(_rc["markdown_header"].findall(text)) / n_sentences
261
+ n_bold = len(_rc["markdown_bold"].findall(text)) / n_sentences
262
+ n_code_blocks = len(_rc["markdown_code_block"].findall(text)) / n_sentences
263
+ n_inline_code = len(_rc["markdown_inline_code"].findall(text)) / n_sentences
264
+ n_bullet = len(_rc["markdown_bullet"].findall(text)) / n_sentences
265
+ n_numbered = len(_rc["markdown_numbered"].findall(text)) / n_sentences
266
+ n_tables = len(_rc["markdown_table"].findall(text)) / n_sentences
 
267
 
 
268
  newline_density = text.count("\n") / n_chars
269
  double_newline_ratio = text.count("\n\n") / (text.count("\n") + 1)
270
  uppercase_ratio = sum(1 for c in text if c.isupper()) / n_chars
 
274
  unique_chars = len(set(text)) / n_chars
275
  unique_chars_ratio = len(set(text.lower())) / n_chars
276
 
277
+ sent_len_std = np.std(sent_lens) if len(sent_lens) > 1 else 0.0
 
 
278
  sent_len_max = max(sent_lens) if sent_lens else 0
279
  sent_len_min = min(sent_lens) if sent_lens else 0
280
+ sent_len_median = np.median(sent_lens) if sent_lens else 0.0
281
  sent_len_range = sent_len_max - sent_len_min
282
 
283
+ has_think = 1.0 if _rc["think_tag"].search(text) else 0.0
284
+ has_xml = 1.0 if _rc["xml_tag"].search(text) else 0.0
285
+ has_hr = 1.0 if _rc["hr"].search(text) else 0.0
286
+ has_url = 1.0 if _rc["url"].search(text) else 0.0
 
287
 
 
288
  words_lower = [w.lower().strip(".,!?;:'\"()[]{}") for w in words]
289
+ first_person_ratio = sum(1 for w in words_lower if w in _ps["first"]) / n_words
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  second_person_ratio = (
291
+ sum(1 for w in words_lower if w in _ps["second"]) / n_words
292
  )
293
+ third_person_ratio = sum(1 for w in words_lower if w in _ps["third"]) / n_words
294
 
 
295
  unique_words = len(set(words_lower))
296
+ ttr = unique_words / n_words if n_words > 0 else 0.0
297
+ word_counts = {}
298
+ for w in words_lower:
299
+ word_counts[w] = word_counts.get(w, 0) + 1
300
+ hapax = sum(1 for c in word_counts.values() if c == 1)
301
+ hapax_ratio = hapax / n_words if n_words > 0 else 0.0
302
 
303
+ contraction_count = len(_rc["contraction"].findall(text))
304
+ contraction_ratio = contraction_count / n_words if n_words > 0 else 0.0
305
 
 
306
  sentences_starters = [
307
  s.split()[0].lower() if s.split() else "" for s in sentences
308
  ]
309
  starter_vocab = (
310
+ len(set(sentences_starters)) / n_sentences if n_sentences > 0 else 0.0
311
  )
312
 
313
  and_starts = sum(1 for s in sentences_starters if s == "and") / n_sentences
 
322
  / n_sentences
323
  )
324
 
 
325
  short_word_ratio = sum(1 for w in words_lower if len(w) <= 2) / n_words
326
  medium_word_ratio = sum(1 for w in words_lower if 3 <= len(w) <= 6) / n_words
327
  long_word_ratio = sum(1 for w in words_lower if len(w) >= 7) / n_words
328
  very_long_word_ratio = sum(1 for w in words_lower if len(w) >= 10) / n_words
329
 
 
330
  para_lens = (
331
  [len(p.split()) for p in non_empty_paras] if non_empty_paras else [0]
332
  )
333
  avg_para_len = np.mean(para_lens)
334
+ para_len_std = np.std(para_lens) if len(para_lens) > 1 else 0.0
335
 
336
+ conjunction_ratio = (
337
+ sum(1 for w in words_lower if w in _ds["conjunctions"]) / n_words
338
+ )
339
+ discourse_ratio = sum(1 for w in words_lower if w in _ds["discourse"]) / n_words
340
+ hedging_ratio = sum(1 for w in words_lower if w in _ds["hedging"]) / n_words
341
+ certainty_ratio = sum(1 for w in words_lower if w in _ds["certainty"]) / n_words
342
+ transition_ratio = (
343
+ sum(1 for w in words_lower if w in _ds["transition"]) / n_words
344
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
 
346
  question_starts = sum(
347
+ 1 for s in sentences if s and _rc["question_start"].search(s.lower())
 
 
 
 
 
348
  )
349
 
 
350
  has_list = 1.0 if n_bullet > 0 or n_numbered > 0 else 0.0
351
  list_items = n_bullet + n_numbered
352
 
353
+ emoji_count = len(_rc["emoji"].findall(text))
 
354
  has_emoji = 1.0 if emoji_count > 0 else 0.0
355
 
 
 
356
  all_caps_words = sum(
357
  1 for w in words if len(w) > 1 and w.isupper() and w.isalpha()
358
  )
359
  all_caps_ratio = all_caps_words / n_words
360
 
361
+ paren_count = len(_rc["paren"].findall(text))
 
362
  paren_ratio = paren_count / n_sentences
363
 
 
364
  rhetorical_q = sum(1 for s in text.split("\n") if s.strip().endswith("?"))
365
  rhetorical_ratio = rhetorical_q / n_sentences
366
 
367
+ casual_ratio = sum(1 for w in words_lower if w in _ds["casual"]) / n_words
368
+ formal_ratio = sum(1 for w in words_lower if w in _ds["formal"]) / n_words
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
+ chinese_chars = len(_rc["chinese"].findall(text))
 
371
  has_chinese = 1.0 if chinese_chars > 0 else 0.0
372
  chinese_ratio = chinese_chars / n_chars
373
 
374
+ has_self_id_ai = 1.0 if _rc["self_id_ai"].search(text) else 0.0
375
+ has_provider_mention = 1.0 if _rc["provider_mention"].search(text) else 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
 
377
  ends_with_question = 1.0 if text.rstrip().endswith("?") else 0.0
378
+ has_closing_offer = 1.0 if _rc["closing_offer"].search(text) else 0.0
 
 
 
 
 
 
 
 
379
 
 
380
  commas_per_sentence = text.count(",") / n_sentences
381
 
 
382
  avg_line_len = (
383
+ np.mean([len(ln) for ln in non_empty_lines]) if non_empty_lines else 0.0
384
  )
385
  short_lines_ratio = (
386
+ sum(1 for ln in non_empty_lines if len(ln.split()) <= 5) / n_lines
387
  )
388
 
389
+ cap_words = len(_rc["all_caps"].findall(text))
 
390
  cap_word_ratio = cap_words / n_words
391
 
392
+ four_word_phrases = len(_rc["four_word"].findall(text))
 
393
  phrase_ratio = four_word_phrases / n_sentences
394
 
395
+ sent_boundaries = len(_rc["sent_boundary"].findall(text))
 
396
  sent_boundary_ratio = sent_boundaries / n_sentences
397
 
398
+ has_checkmark = 1.0 if any(c in text for c in "✓✗✔✘") else 0.0
399
+ has_arrow = 1.0 if any(c in text for c in "→←➡") else 0.0
400
+ has_star = 1.0 if any(c in text for c in "⭐★☆") else 0.0
401
+ special_unicode = len(_rc["special_unicode"].findall(text)) / n_chars
 
 
 
402
 
403
+ colon_definitions = len(_rc["colon_def"].findall(text)) / n_sentences
 
404
 
405
+ double_quote_pairs = len(_rc["double_quote"].findall(text)) / n_sentences
406
+ single_quote_pairs = len(_rc["single_quote"].findall(text)) / n_sentences
 
407
 
408
+ greeting_patterns = len(_rc["greeting"].findall(text))
 
 
 
 
 
409
  greeting_ratio = greeting_patterns / n_sentences
410
 
 
411
  is_short = 1.0 if n_words < 100 else 0.0
412
  is_medium = 1.0 if 100 <= n_words < 500 else 0.0
413
  is_long = 1.0 if n_words >= 500 else 0.0
414
 
 
415
  excl_sentences = sum(1 for s in sentences if s.strip().endswith("!"))
416
  excl_sentence_ratio = excl_sentences / n_sentences
417
 
418
+ question_lines = [ln for ln in non_empty_lines if ln.strip().endswith("?")]
 
419
  question_line_ratio = len(question_lines) / n_lines if n_lines > 0 else 0.0
420
 
421
+ conversational_phrases = len(_rc["conv_phrase"].findall(text))
 
 
 
 
 
 
 
 
422
  conv_phrase_ratio = conversational_phrases / n_words
423
 
424
+ helpful_phrases = len(_rc["helpful"].findall(text))
 
 
 
 
 
 
 
 
425
  helpful_ratio = helpful_phrases / n_sentences
426
 
427
  return [
 
428
  avg_word_len,
429
  word_len_std,
430
  median_word_len,
431
  avg_sent_len,
 
432
  sent_len_std,
433
  sent_len_max,
434
  sent_len_min,
435
  sent_len_median,
436
  sent_len_range,
437
  commas_per_sentence,
 
438
  n_commas,
439
  n_semicolons,
440
  n_colons,
 
448
  comma_colon_ratio,
449
  comma_period_ratio,
450
  excl_question_ratio,
 
451
  n_headers,
452
  n_bold,
453
  n_code_blocks,
 
456
  n_numbered,
457
  n_tables,
458
  has_list,
 
459
  newline_density,
460
  double_newline_ratio,
461
  uppercase_ratio,
 
466
  list_items,
467
  n_paragraphs,
468
  n_lines / n_sentences,
 
469
  has_think,
470
  has_xml,
471
  has_hr,
472
  has_url,
 
473
  first_person_ratio,
474
  second_person_ratio,
475
  third_person_ratio,
 
476
  ttr,
477
  hapax_ratio,
478
  contraction_ratio,
479
  short_word_ratio,
480
  medium_word_ratio,
 
481
  long_word_ratio,
482
  very_long_word_ratio,
 
483
  starter_vocab,
484
  and_starts,
485
  but_starts,
486
  so_starts,
487
  the_starts,
488
  it_starts,
 
489
  avg_para_len,
490
  para_len_std,
 
491
  conjunction_ratio,
492
  discourse_ratio,
493
  hedging_ratio,
494
  certainty_ratio,
495
  transition_ratio,
 
496
  question_starts / n_sentences if n_sentences > 0 else 0,
 
497
  emoji_count,
498
  has_emoji,
499
  special_unicode,
 
500
  all_caps_ratio,
501
  paren_ratio,
502
  rhetorical_ratio,
 
505
  has_chinese,
506
  chinese_ratio,
507
  has_self_id_ai,
 
508
  has_provider_mention,
509
  ends_with_question,
510
  has_closing_offer,
511
  has_checkmark,
 
512
  has_arrow,
513
  has_star,
514
  avg_line_len,
515
  short_lines_ratio,
516
  cap_word_ratio,
517
  phrase_ratio,
 
518
  sent_boundary_ratio,
519
  colon_definitions,
520
  double_quote_pairs,
521
  single_quote_pairs,
522
  i_starts,
 
523
  greeting_ratio,
524
  is_short,
525
  is_medium,
 
531
  ]
532
 
533
 
534
+ class StyleOnlyPipeline:
535
+ """Feature pipeline using ONLY stylometric features — no TF-IDF."""
536
+
537
+ def __init__(self):
538
+ self.stylo = StylometricFeatures()
539
+ self.scaler = MaxAbsScaler()
540
+
541
+ def fit_transform(self, texts):
542
+ import time
543
+
544
+ texts_clean = [strip_markdown(strip_cot(t)) for t in texts]
545
+ t0 = time.time()
546
+ stylo_features = self.stylo.transform(texts_clean)
547
+ print(
548
+ f" Stylometric: {stylo_features.shape[1]} features ({time.time() - t0:.1f}s)"
549
+ )
550
+ result = self.scaler.fit_transform(stylo_features)
551
+ print(f" Final feature matrix: {result.shape}")
552
+ return result
553
+
554
+ def transform(self, texts):
555
+ texts_clean = [strip_markdown(strip_cot(t)) for t in texts]
556
+ stylo_features = self.stylo.transform(texts_clean)
557
+ return self.scaler.transform(stylo_features)
558
+
559
+
560
  class FeaturePipeline:
561
  def __init__(self, use_tfidf=True):
562
  word_params = dict(TFIDF_WORD_PARAMS)
 
577
  )
578
 
579
  def _clean_for_tfidf(self, text):
 
580
  return strip_markdown(strip_cot(text))
581
 
582
  def fit_transform(self, texts):
 
584
 
585
  print(f" Input: {len(texts)} texts", flush=True)
586
 
587
+ texts_clean = [strip_markdown(strip_cot(t)) for t in texts]
588
+ texts_tfidf = texts_clean
589
 
590
  use_word_tfidf = (
591
  self.word_tfidf.max_features is not None
 
612
  char_features = csr_matrix((len(texts), 0), dtype=np.float32)
613
 
614
  t0 = time.time()
615
+ stylo_features = self.stylo.transform(texts_clean)
616
  print(
617
  f" stylometric: {stylo_features.shape[1]} features ({time.time() - t0:.1f}s)",
618
  flush=True,
 
624
  return combined
625
 
626
  def transform(self, texts):
627
+ texts_clean = [strip_markdown(strip_cot(t)) for t in texts]
628
+ texts_tfidf = texts_clean
629
 
630
  use_word_tfidf = (
631
  self.word_tfidf.max_features is not None
 
641
  else:
642
  char_features = csr_matrix((len(texts), 0), dtype=np.float32)
643
 
644
+ stylo_features = self.stylo.transform(texts_clean)
645
  combined = hstack([word_features, char_features, stylo_features])
646
  return self.scaler.transform(combined)
models/community/enc_4provider.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9eaeb0dde6cecb8561c2f4c47e1aeafb9dab1a6262390c9735716408f2231761
3
+ size 767
models/community/pipeline_4provider.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e916f828358de5f129731d1784b14d4e1f82ef9e315d9b972a88490b980d3ef
3
+ size 1365
models/community/rf_4provider.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da22e16dbf57465c4a178fc7e69a248486c3d32281fd76c395e6e74da8a51a2e
3
+ size 106139474
models/jobs.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24c91be6078f7fd4303d7f28e8a7212cea5f2113e05ab1335dacb6382c62c21e
3
+ size 7254
models/style/enc_4provider.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9eaeb0dde6cecb8561c2f4c47e1aeafb9dab1a6262390c9735716408f2231761
3
+ size 767
models/style/pipeline_4provider.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e916f828358de5f129731d1784b14d4e1f82ef9e315d9b972a88490b980d3ef
3
+ size 1365
models/style/rf_4provider.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da22e16dbf57465c4a178fc7e69a248486c3d32281fd76c395e6e74da8a51a2e
3
+ size 106139474
templates/index.html CHANGED
@@ -653,6 +653,16 @@
653
  animation: fadeIn 0.3s ease;
654
  }
655
 
 
 
 
 
 
 
 
 
 
 
656
  @media (max-width: 600px) {
657
  .container {
658
  padding: 1rem;
@@ -687,6 +697,7 @@
687
 
688
  <div class="tabs">
689
  <button class="tab active" data-tab="classify">Classify</button>
 
690
  <button class="tab" data-tab="docs">API Docs</button>
691
  </div>
692
 
@@ -695,6 +706,7 @@
695
  <div class="status-indicator">
696
  <span class="status-dot" id="statusDot"></span>
697
  <span id="statusText">Connecting to API...</span>
 
698
  </div>
699
 
700
  <div class="card">
@@ -751,6 +763,159 @@
751
  </div>
752
  </div>
753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
  <!-- ═══ API Docs Tab ═══ -->
755
  <div class="tab-content" id="tab-docs">
756
 
@@ -1008,11 +1173,6 @@ async function classify(text, topN = 5) {
1008
 
1009
  <div class="footer">
1010
  <p>AIFinder &mdash; Train on corrections to improve accuracy</p>
1011
- <p style="margin-top: 0.5rem;">
1012
- Want to contribute? Test this and post to the
1013
- <a href="https://huggingface.co/spaces" target="_blank">HuggingFace Spaces Community</a>
1014
- if you want it merged!
1015
- </p>
1016
  </div>
1017
  </div>
1018
 
@@ -1050,6 +1210,7 @@ async function classify(text, topN = 5) {
1050
  const toast = document.getElementById('toast');
1051
  const statusDot = document.getElementById('statusDot');
1052
  const statusText = document.getElementById('statusText');
 
1053
  let usingCommunity = false;
1054
 
1055
  function showToast(message, type = 'info') {
@@ -1067,6 +1228,9 @@ async function classify(text, topN = 5) {
1067
  if (data.loaded) {
1068
  statusDot.classList.remove('loading');
1069
  statusText.textContent = data.using_community ? 'Ready — Community Model (cpu)' : `Ready (${data.device})`;
 
 
 
1070
  classifyBtn.disabled = false;
1071
  usingCommunity = data.using_community;
1072
  updateCommunityUI(data.community_available);
@@ -1367,7 +1531,483 @@ async function classify(text, topN = 5) {
1367
  populateDocsProviders();
1368
  };
1369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1370
  checkStatus();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1371
  </script>
1372
  </body>
1373
  </html>
 
653
  animation: fadeIn 0.3s ease;
654
  }
655
 
656
+ .format-option:hover {
657
+ border-color: var(--border-light) !important;
658
+ background: var(--bg-elevated) !important;
659
+ }
660
+
661
+ .format-option:has(input:checked) {
662
+ border-color: var(--accent-muted) !important;
663
+ background: rgba(232, 93, 4, 0.08) !important;
664
+ }
665
+
666
  @media (max-width: 600px) {
667
  .container {
668
  padding: 1rem;
 
697
 
698
  <div class="tabs">
699
  <button class="tab active" data-tab="classify">Classify</button>
700
+ <button class="tab" data-tab="dataset">Evaluate Dataset</button>
701
  <button class="tab" data-tab="docs">API Docs</button>
702
  </div>
703
 
 
706
  <div class="status-indicator">
707
  <span class="status-dot" id="statusDot"></span>
708
  <span id="statusText">Connecting to API...</span>
709
+ <span id="providerCount" style="margin-left:auto;font-size:0.75rem;color:var(--text-muted);"></span>
710
  </div>
711
 
712
  <div class="card">
 
763
  </div>
764
  </div>
765
 
766
+ <!-- ═══ Dataset Evaluation Tab ═══ -->
767
+ <div class="tab-content" id="tab-dataset">
768
+ <div class="card">
769
+ <div class="card-label">HuggingFace Dataset ID</div>
770
+ <input type="text" id="datasetId" placeholder="e.g., ianncity/Hunter-Alpha-SFT-300000x"
771
+ style="width:100%; padding:0.75rem 1rem; background:var(--bg-tertiary); border:1px solid var(--border); border-radius:8px; color:var(--text-primary); font-family:'Outfit',sans-serif;font-size:0.9rem;margin-bottom:0.75rem;">
772
+ <div style="display:flex;gap:0.75rem;flex-wrap:wrap;align-items:center;">
773
+ <button class="btn btn-secondary" id="checkDatasetBtn">Check Format</button>
774
+ <button class="btn btn-primary" id="evaluateDatasetBtn" disabled>Evaluate</button>
775
+ <input type="number" id="maxSamples" value="1000" min="1" max="10000"
776
+ style="width:100px;padding:0.5rem;background:var(--bg-tertiary);border:1px solid var(--border);border-radius:8px;color:var(--text-primary);font-size:0.85rem;">
777
+ <span style="color:var(--text-muted);font-size:0.8rem;">max samples</span>
778
+ </div>
779
+
780
+ <div style="margin-top:1rem;padding-top:1rem;border-top:1px solid var(--border);">
781
+ <div style="display:flex;justify-content:space-between;align-items:center;margin-bottom:0.5rem;">
782
+ <div class="card-label" style="margin-bottom:0;">Dataset Format</div>
783
+ <label style="display:flex;align-items:center;gap:0.5rem;cursor:pointer;">
784
+ <input type="checkbox" id="useCustomFormat" style="width:16px;height:16px;accent-color:var(--accent);">
785
+ <span style="font-size:0.8rem;color:var(--text-secondary);">Use custom format</span>
786
+ </label>
787
+ </div>
788
+
789
+ <div id="customFormatSection" style="display:none;background:var(--bg-tertiary);border:1px solid var(--border);border-radius:8px;padding:1rem;">
790
+ <div style="font-size:0.75rem;color:var(--text-muted);margin-bottom:0.75rem;">
791
+ How is your dataset structured? Choose a format below:
792
+ </div>
793
+
794
+ <div style="display:grid;gap:0.5rem;margin-bottom:1rem;">
795
+ <label style="display:flex;align-items:center;gap:0.5rem;cursor:pointer;padding:0.5rem;background:var(--bg-secondary);border-radius:6px;border:1px solid transparent;" class="format-option" data-format="auto">
796
+ <input type="radio" name="customFormatType" value="auto" checked style="accent-color:var(--accent);">
797
+ <div>
798
+ <div style="font-weight:500;font-size:0.85rem;">Auto-detect</div>
799
+ <div style="font-size:0.75rem;color:var(--text-muted);">Try to detect format automatically</div>
800
+ </div>
801
+ </label>
802
+
803
+ <label style="display:flex;align-items:center;gap:0.5rem;cursor:pointer;padding:0.5rem;background:var(--bg-secondary);border-radius:6px;border:1px solid transparent;" class="format-option" data-format="column">
804
+ <input type="radio" name="customFormatType" value="column" style="accent-color:var(--accent);">
805
+ <div>
806
+ <div style="font-weight:500;font-size:0.85rem;">Single column</div>
807
+ <div style="font-size:0.75rem;color:var(--text-muted);">Extract from one field (e.g., "response")</div>
808
+ </div>
809
+ </label>
810
+
811
+ <label style="display:flex;align-items:center;gap:0.5rem;cursor:pointer;padding:0.5rem;background:var(--bg-secondary);border-radius:6px;border:1px solid transparent;" class="format-option" data-format="two_column">
812
+ <input type="radio" name="customFormatType" value="two_column" style="accent-color:var(--accent);">
813
+ <div>
814
+ <div style="font-weight:500;font-size:0.85rem;">Two columns</div>
815
+ <div style="font-size:0.75rem;color:var(--text-muted);">User column + Assistant column</div>
816
+ </div>
817
+ </label>
818
+
819
+ <label style="display:flex;align-items:center;gap:0.5rem;cursor:pointer;padding:0.5rem;background:var(--bg-secondary);border-radius:6px;border:1px solid transparent;" class="format-option" data-format="pattern">
820
+ <input type="radio" name="customFormatType" value="pattern" style="accent-color:var(--accent);">
821
+ <div>
822
+ <div style="font-weight:500;font-size:0.85rem;">Text markers</div>
823
+ <div style="font-size:0.75rem;color:var(--text-muted);">Extract between text markers</div>
824
+ </div>
825
+ </label>
826
+ </div>
827
+
828
+ <div id="columnInput" style="display:none;">
829
+ <input type="text" id="customColumnName" placeholder="e.g., response, output, completion"
830
+ style="width:100%; padding:0.6rem 0.75rem; background:var(--bg-primary); border:1px solid var(--border); border-radius:6px; color:var(--text-primary); font-family:'JetBrains Mono',monospace;font-size:0.85rem;">
831
+ </div>
832
+
833
+ <div id="twoColumnInput" style="display:none;">
834
+ <div style="display:flex;gap:0.5rem;flex-wrap:wrap;">
835
+ <input type="text" id="customUserColumn" placeholder="User column (e.g., prompt, input)"
836
+ style="flex:1;min-width:150px;padding:0.6rem 0.75rem; background:var(--bg-primary); border:1px solid var(--border); border-radius:6px; color:var(--text-primary); font-family:'JetBrains Mono',monospace;font-size:0.85rem;">
837
+ <input type="text" id="customAssistantColumn" placeholder="Assistant column (e.g., response, output)"
838
+ style="flex:1;min-width:150px;padding:0.6rem 0.75rem; background:var(--bg-primary); border:1px solid var(--border); border-radius:6px; color:var(--text-primary); font-family:'JetBrains Mono',monospace;font-size:0.85rem;">
839
+ </div>
840
+ </div>
841
+
842
+ <div id="patternInput" style="display:none;">
843
+ <input type="text" id="customPattern" placeholder="e.g., user:[INST] assistant:[/INST] or [startuser] [startassistant]"
844
+ style="width:100%; padding:0.6rem 0.75rem; background:var(--bg-primary); border:1px solid var(--border); border-radius:6px; color:var(--text-primary); font-family:'JetBrains Mono',monospace;font-size:0.85rem;">
845
+ <div style="font-size:0.7rem;color:var(--text-muted);margin-top:0.5rem;">
846
+ Use <code style="background:var(--bg-primary);padding:0.1rem 0.3rem;border-radius:3px;">[startuser]</code> and <code style="background:var(--bg-primary);padding:0.1rem 0.3rem;border-radius:3px;">[startassistant]</code> as placeholders, or raw text like <code style="background:var(--bg-primary);padding:0.1rem 0.3rem;border-radius:3px;">user: assistant:</code>
847
+ </div>
848
+ </div>
849
+
850
+ <div style="margin-top:0.75rem;padding:0.5rem;background:var(--bg-primary);border-radius:6px;">
851
+ <div style="font-size:0.7rem;color:var(--text-muted);margin-bottom:0.25rem;">Format string preview:</div>
852
+ <code id="formatPreview" style="font-family:'JetBrains Mono',monospace;font-size:0.8rem;color:var(--accent);">column: response</code>
853
+ </div>
854
+ </div>
855
+ </div>
856
+ </div>
857
+
858
+ <div id="datasetFormatInfo" class="card" style="display:none;">
859
+ <div class="card-label">Dataset Format</div>
860
+ <div id="formatName" style="font-weight:600;margin-bottom:0.5rem;"></div>
861
+ <div id="formatDescription" style="color:var(--text-secondary);font-size:0.9rem;"></div>
862
+ <div style="margin-top:0.75rem;display:flex;gap:1rem;">
863
+ <div class="stat" style="padding:0.5rem 1rem;min-width:auto;">
864
+ <div class="stat-value" id="totalRows" style="font-size:1rem;">-</div>
865
+ <div class="stat-label" style="font-size:0.65rem;">Total Rows</div>
866
+ </div>
867
+ <div class="stat" style="padding:0.5rem 1rem;min-width:auto;">
868
+ <div class="stat-value" id="extractedCount" style="font-size:1rem;">-</div>
869
+ <div class="stat-label" style="font-size:0.65rem;">Responses</div>
870
+ </div>
871
+ </div>
872
+ <div id="formatError" style="display:none;margin-top:1rem;padding:0.75rem;background:rgba(232,93,4,0.12);border:1px solid var(--accent-muted);border-radius:8px;color:var(--text-secondary);font-size:0.85rem;"></div>
873
+ </div>
874
+
875
+ <div id="datasetResults" class="card" style="display:none;">
876
+ <div class="card-label">Evaluation Results</div>
877
+
878
+ <div style="display:flex;gap:1rem;margin-bottom:1.5rem;flex-wrap:wrap;">
879
+ <div class="stat">
880
+ <div class="stat-value" id="evalTotal">-</div>
881
+ <div class="stat-label">Samples</div>
882
+ </div>
883
+ <div class="stat">
884
+ <div class="stat-value" id="evalLikelyProvider">-</div>
885
+ <div class="stat-label">Likely Provider</div>
886
+ </div>
887
+ <div class="stat">
888
+ <div class="stat-value" id="evalAvgConfidence">-</div>
889
+ <div class="stat-label">Avg Confidence</div>
890
+ </div>
891
+ </div>
892
+
893
+ <div class="card-label" style="margin-top:1rem;">Provider Distribution</div>
894
+ <div id="providerDistribution"></div>
895
+
896
+ <div class="card-label" style="margin-top:1.5rem;">Top Providers (by cumulative score)</div>
897
+ <div id="topProvidersList"></div>
898
+ </div>
899
+
900
+ <div id="datasetLoading" style="display:none;text-align:center;padding:2rem;">
901
+ <span class="loading" style="width:24px;height:24px;border-width:3px;"></span>
902
+ <div style="margin-top:1rem;color:var(--text-secondary);" id="datasetLoadingText">Evaluating...</div>
903
+ </div>
904
+
905
+ <div class="docs-section" style="margin-top:2rem;">
906
+ <h2 style="font-size:1rem;font-weight:500;color:var(--text-secondary);margin-bottom:0.75rem;">Supported Dataset Formats</h2>
907
+ <div id="supportedFormatsList" style="display:grid;grid-template-columns:repeat(auto-fill,minmax(250px,1fr));gap:0.75rem;"></div>
908
+ </div>
909
+
910
+ <div class="card" style="margin-top:2rem;">
911
+ <div class="card-label" style="display:flex;justify-content:space-between;align-items:center;">
912
+ <span>Your Evaluated Datasets</span>
913
+ <button class="btn btn-secondary" id="clearHistoryBtn" style="padding:0.4rem 0.75rem;font-size:0.75rem;">Clear History</button>
914
+ </div>
915
+ <div id="datasetHistory" style="color:var(--text-muted);font-size:0.85rem;">Loading...</div>
916
+ </div>
917
+ </div>
918
+
919
  <!-- ═══ API Docs Tab ═══ -->
920
  <div class="tab-content" id="tab-docs">
921
 
 
1173
 
1174
  <div class="footer">
1175
  <p>AIFinder &mdash; Train on corrections to improve accuracy</p>
 
 
 
 
 
1176
  </div>
1177
  </div>
1178
 
 
1210
  const toast = document.getElementById('toast');
1211
  const statusDot = document.getElementById('statusDot');
1212
  const statusText = document.getElementById('statusText');
1213
+ const providerCountEl = document.getElementById('providerCount');
1214
  let usingCommunity = false;
1215
 
1216
  function showToast(message, type = 'info') {
 
1228
  if (data.loaded) {
1229
  statusDot.classList.remove('loading');
1230
  statusText.textContent = data.using_community ? 'Ready — Community Model (cpu)' : `Ready (${data.device})`;
1231
+ if (data.num_providers) {
1232
+ providerCountEl.textContent = `${data.num_providers} providers`;
1233
+ }
1234
  classifyBtn.disabled = false;
1235
  usingCommunity = data.using_community;
1236
  updateCommunityUI(data.community_available);
 
1531
  populateDocsProviders();
1532
  };
1533
 
1534
+ // ── Dataset Evaluation ──
1535
+ const datasetIdInput = document.getElementById('datasetId');
1536
+ const maxSamplesInput = document.getElementById('maxSamples');
1537
+ const checkDatasetBtn = document.getElementById('checkDatasetBtn');
1538
+ const evaluateDatasetBtn = document.getElementById('evaluateDatasetBtn');
1539
+ const datasetFormatInfo = document.getElementById('datasetFormatInfo');
1540
+ const formatName = document.getElementById('formatName');
1541
+ const formatDescription = document.getElementById('formatDescription');
1542
+ const totalRowsEl = document.getElementById('totalRows');
1543
+ const extractedCountEl = document.getElementById('extractedCount');
1544
+ const formatError = document.getElementById('formatError');
1545
+ const datasetResults = document.getElementById('datasetResults');
1546
+ const datasetLoading = document.getElementById('datasetLoading');
1547
+ const datasetLoadingText = document.getElementById('datasetLoadingText');
1548
+ const datasetHistory = document.getElementById('datasetHistory');
1549
+
1550
+ let currentDatasetInfo = null;
1551
+ let currentJobId = null;
1552
+ let jobPollingInterval = null;
1553
+
1554
+ function saveJobId(jobId) {
1555
+ localStorage.setItem('aifinder_current_job', jobId);
1556
+ }
1557
+
1558
+ function getSavedJobId() {
1559
+ return localStorage.getItem('aifinder_current_job');
1560
+ }
1561
+
1562
+ function clearSavedJobId() {
1563
+ localStorage.removeItem('aifinder_current_job');
1564
+ }
1565
+
1566
+ function generateApiKey() {
1567
+ const existing = localStorage.getItem('aifinder_api_key');
1568
+ if (existing) return existing;
1569
+ const key = 'usr_' + Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
1570
+ localStorage.setItem('aifinder_api_key', key);
1571
+ return key;
1572
+ }
1573
+
1574
+ function getApiKey() {
1575
+ return localStorage.getItem('aifinder_api_key') || generateApiKey();
1576
+ }
1577
+
1578
+ getApiKey();
1579
+
1580
+ async function loadDatasetHistory() {
1581
+ const apiKey = getApiKey();
1582
+ if (!apiKey) {
1583
+ datasetHistory.innerHTML = '<span style="color:var(--text-muted);">No evaluated datasets yet.</span>';
1584
+ return;
1585
+ }
1586
+
1587
+ try {
1588
+ const res = await fetch(`${API_BASE}/api/datasets?api_key=${encodeURIComponent(apiKey)}`);
1589
+ const data = await res.json();
1590
+
1591
+ if (!data.datasets || data.datasets.length === 0) {
1592
+ datasetHistory.innerHTML = '<span style="color:var(--text-muted);">Your evaluated datasets will appear here. Start by checking a dataset format above.</span>';
1593
+ return;
1594
+ }
1595
+
1596
+ datasetHistory.innerHTML = data.datasets.map(ds => `
1597
+ <div style="display:flex;justify-content:space-between;align-items:center;padding:0.75rem;background:var(--bg-tertiary);border:1px solid var(--border);border-radius:8px;margin-bottom:0.5rem;cursor:pointer;"
1598
+ onclick="loadDatasetResult('${ds.job_id}')">
1599
+ <div>
1600
+ <div style="font-weight:500;">${ds.dataset_id}</div>
1601
+ <div style="font-size:0.75rem;color:var(--text-muted);">${ds.completed_at ? new Date(ds.completed_at).toLocaleString() : ''}</div>
1602
+ </div>
1603
+ <span style="padding:0.25rem 0.5rem;border-radius:4px;font-size:0.75rem;${ds.status === 'completed' ? 'background:var(--success-muted);color:var(--success);' : 'background:var(--accent-muted);color:var(--accent-hover);'}">${ds.status}</span>
1604
+ </div>
1605
+ `).join('');
1606
+ } catch (e) {
1607
+ datasetHistory.innerHTML = '<span style="color:var(--text-muted);">Failed to load history.</span>';
1608
+ }
1609
+ }
1610
+
1611
+ async function loadDatasetResult(jobId) {
1612
+ try {
1613
+ const res = await fetch(`${API_BASE}/api/dataset/job/${jobId}`);
1614
+ const data = await res.json();
1615
+
1616
+ if (data.status === 'completed' && data.results) {
1617
+ showEvaluationResults(data.results);
1618
+ } else if (data.status === 'failed') {
1619
+ showToast('Evaluation failed: ' + data.error);
1620
+ } else if (data.status === 'running' || data.status === 'pending') {
1621
+ datasetIdInput.value = data.dataset_id || '';
1622
+ currentJobId = jobId;
1623
+ saveJobId(currentJobId);
1624
+ datasetLoading.style.display = 'block';
1625
+ evaluateDatasetBtn.disabled = true;
1626
+ if (data.progress) {
1627
+ datasetLoadingText.textContent = `${data.progress.stage === 'downloading' ? 'Downloading' : data.progress.stage === 'evaluating' ? 'Evaluating' : 'Processing'}: ${data.progress.percent}%`;
1628
+ } else {
1629
+ datasetLoadingText.textContent = 'Evaluation running, please wait...';
1630
+ }
1631
+ startJobPolling();
1632
+ }
1633
+ } catch (e) {
1634
+ showToast('Error: ' + e.message);
1635
+ }
1636
+ }
1637
+
1638
+ function showEvaluationResults(data) {
1639
+ document.getElementById('evalTotal').textContent = data.extracted_count?.toLocaleString() || '-';
1640
+ document.getElementById('evalLikelyProvider').textContent = data.likely_provider || '-';
1641
+ document.getElementById('evalAvgConfidence').textContent = (data.average_confidence || 0) + '%';
1642
+
1643
+ const distContainer = document.getElementById('providerDistribution');
1644
+ distContainer.innerHTML = '';
1645
+
1646
+ const sortedProviders = Object.entries(data.provider_counts || {})
1647
+ .sort((a, b) => b[1].count - a[1].count);
1648
+
1649
+ for (const [provider, info] of sortedProviders) {
1650
+ const conf = data.provider_confidences?.[provider]?.average || 0;
1651
+ const html = `
1652
+ <div style="margin-bottom:1rem;">
1653
+ <div style="display:flex;justify-content:space-between;margin-bottom:0.25rem;">
1654
+ <span style="font-weight:500;">${provider}</span>
1655
+ <span style="color:var(--text-secondary);font-size:0.85rem;">${info.count} (${info.percentage}%) · ${conf}% avg</span>
1656
+ </div>
1657
+ <div class="result-bar">
1658
+ <div class="result-bar-fill" style="width:${info.percentage}%"></div>
1659
+ </div>
1660
+ </div>
1661
+ `;
1662
+ distContainer.innerHTML += html;
1663
+ }
1664
+
1665
+ const topContainer = document.getElementById('topProvidersList');
1666
+ topContainer.innerHTML = '';
1667
+
1668
+ const sortedTop = Object.entries(data.top_providers || {})
1669
+ .sort((a, b) => b[1] - a[1])
1670
+ .slice(0, 5);
1671
+
1672
+ for (const [provider, count] of sortedTop) {
1673
+ const conf = data.provider_confidences?.[provider]?.cumulative || 0;
1674
+ topContainer.innerHTML += `
1675
+ <div class="result-item">
1676
+ <span class="result-name">${provider}</span>
1677
+ <span class="result-percent">${conf.toFixed(2)} pts</span>
1678
+ </div>
1679
+ `;
1680
+ }
1681
+
1682
+ datasetResults.style.display = 'block';
1683
+ datasetLoading.style.display = 'none';
1684
+ }
1685
+
1686
+ function startJobPolling() {
1687
+ if (jobPollingInterval) clearInterval(jobPollingInterval);
1688
+ jobPollingInterval = setInterval(async () => {
1689
+ if (!currentJobId) return;
1690
+ try {
1691
+ const res = await fetch(`${API_BASE}/api/dataset/job/${currentJobId}`);
1692
+ const data = await res.json();
1693
+ console.log('Polling response:', data);
1694
+
1695
+ if (data.status === 'completed') {
1696
+ clearInterval(jobPollingInterval);
1697
+ jobPollingInterval = null;
1698
+ currentJobId = null;
1699
+ clearSavedJobId();
1700
+ showEvaluationResults(data.results);
1701
+ loadDatasetHistory();
1702
+ showToast('Evaluation complete!', 'success');
1703
+ } else if (data.status === 'failed') {
1704
+ clearInterval(jobPollingInterval);
1705
+ jobPollingInterval = null;
1706
+ currentJobId = null;
1707
+ clearSavedJobId();
1708
+ datasetLoading.style.display = 'none';
1709
+ evaluateDatasetBtn.disabled = false;
1710
+ showToast('Evaluation failed: ' + data.error);
1711
+ } else {
1712
+ const prog = data.progress;
1713
+ if (prog) {
1714
+ datasetLoadingText.textContent = `${prog.stage === 'downloading' ? 'Downloading' : prog.stage === 'evaluating' ? 'Evaluating' : 'Processing'}: ${prog.percent}%`;
1715
+ } else {
1716
+ datasetLoadingText.textContent = 'Evaluating... ' + (data.started_at ? new Date(data.started_at).toLocaleTimeString() : '');
1717
+ }
1718
+ }
1719
+ } catch (e) {
1720
+ console.error('Polling error:', e);
1721
+ }
1722
+ }, 2000);
1723
+ }
1724
+
1725
// Ask the backend to detect the dataset's format and report whether any
// assistant responses can be extracted for evaluation. Updates the format
// info panel and enables/disables the Evaluate button accordingly.
async function checkDatasetFormat() {
    const datasetId = datasetIdInput.value.trim();
    if (!datasetId) {
        showToast('Please enter a dataset ID');
        return;
    }

    checkDatasetBtn.disabled = true;
    checkDatasetBtn.innerHTML = '<span class="loading"></span>';

    const customFormat = buildFormatString();

    try {
        const res = await fetch(`${API_BASE}/api/dataset/info`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                dataset_id: datasetId,
                // Explicit radix 10 — parseInt without it can misparse some inputs.
                max_samples: parseInt(maxSamplesInput.value, 10) || 1000,
                custom_format: customFormat
            })
        });
        const data = await res.json();

        currentDatasetInfo = data;

        // A format may be detected while still yielding zero usable responses.
        const formatDetectedButNoTexts = data.supported && (data.extracted_count === 0);

        if (data.supported && !formatDetectedButNoTexts) {
            formatName.textContent = data.format_name || data.format || 'Unknown';
            formatDescription.textContent = data.format_description || '';
            totalRowsEl.textContent = data.total_rows?.toLocaleString() || '-';
            extractedCountEl.textContent = data.extracted_count?.toLocaleString() || '-';
            formatError.style.display = 'none';
            evaluateDatasetBtn.disabled = false;
        } else {
            if (formatDetectedButNoTexts) {
                formatName.textContent = data.format_name || data.format || 'Unknown';
                formatDescription.textContent = 'Format detected but no valid assistant responses found. Try a custom format below.';
                totalRowsEl.textContent = data.total_rows?.toLocaleString() || '-';
                extractedCountEl.textContent = '0';
                formatError.style.display = 'block';
                formatError.textContent = 'No valid assistant responses extracted (minimum 50 chars required). The detected format may not match the actual data structure.';
            } else {
                formatName.textContent = 'Unsupported Format';
                formatDescription.textContent = '';
                totalRowsEl.textContent = '-';
                extractedCountEl.textContent = '-';
                formatError.style.display = 'block';
                formatError.textContent = data.error || 'Unknown error';
            }
            evaluateDatasetBtn.disabled = true;

            // Nudge the user toward the custom-format escape hatch.
            useCustomFormatCheckbox.checked = true;
            customFormatSection.style.display = 'block';
            showToast('Could not extract responses. Please specify a custom format below.');
        }

        datasetFormatInfo.style.display = 'block';
        datasetResults.style.display = 'none';

    } catch (e) {
        showToast('Error: ' + e.message);
    } finally {
        checkDatasetBtn.disabled = false;
        checkDatasetBtn.textContent = 'Check Format';
    }
}
1793
+
1794
// Start a background dataset-evaluation job on the server and begin polling
// for its progress. The job survives a page close; a notice says so.
async function evaluateDataset() {
    const datasetId = datasetIdInput.value.trim();
    if (!datasetId || !currentDatasetInfo?.supported) return;

    evaluateDatasetBtn.disabled = true;
    datasetLoading.style.display = 'block';
    datasetResults.style.display = 'none';
    datasetLoadingText.textContent = 'Starting evaluation...';

    const apiKey = getApiKey();
    const customFormat = buildFormatString();

    try {
        const requestBody = {
            dataset_id: datasetId,
            max_samples: parseInt(maxSamplesInput.value) || 1000,
            api_key: apiKey || null,
            custom_format: customFormat
        };
        const response = await fetch(`${API_BASE}/api/dataset/evaluate`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(requestBody)
        });
        const data = await response.json();
        console.log('Evaluate response:', data);

        if (data.error) {
            showToast(data.error);
            datasetLoading.style.display = 'none';
            evaluateDatasetBtn.disabled = false;
            return;
        }

        currentJobId = data.job_id;
        saveJobId(currentJobId);
        console.log('Job ID saved:', currentJobId);
        datasetLoadingText.textContent = 'Evaluation started. Processing in background...';

        // Tell the user the evaluation keeps running server-side.
        const notice = document.createElement('div');
        notice.style.cssText = 'margin-top:1rem;color:var(--text-muted);font-size:0.85rem;';
        notice.innerHTML = '✓ You can close this page — evaluation will continue in the background.';
        notice.className = 'close-page-msg';
        const loadingEl = document.getElementById('datasetLoading');
        // Remove any stale notice from a previous run before attaching this one.
        loadingEl.querySelectorAll('.close-page-msg').forEach(el => el.remove());
        loadingEl.appendChild(notice);

        startJobPolling();
        loadDatasetHistory();

    } catch (e) {
        showToast('Error: ' + e.message);
        datasetLoading.style.display = 'none';
        evaluateDatasetBtn.disabled = false;
    }
}
1850
+
1851
// Wire up the dataset-evaluation controls.
checkDatasetBtn.addEventListener('click', checkDatasetFormat);
evaluateDatasetBtn.addEventListener('click', evaluateDataset);

// Clears all stored dataset evaluations on the server (after confirmation).
async function handleClearHistory() {
    if (!confirm('Clear all dataset evaluation history?')) return;
    try {
        const resp = await fetch(`${API_BASE}/api/datasets/clear`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ api_key: getApiKey() })
        });
        const data = await resp.json();
        if (data.error) {
            showToast(data.error);
            return;
        }
        clearSavedJobId();
        showToast(`Cleared ${data.cleared} datasets`, 'success');
        loadDatasetHistory();
    } catch (e) {
        showToast('Error: ' + e.message);
    }
}
document.getElementById('clearHistoryBtn').addEventListener('click', handleClearHistory);

// Enter in the dataset-id field triggers a format check.
datasetIdInput.addEventListener('keydown', (e) => {
    if (e.key === 'Enter') checkDatasetFormat();
});

loadDatasetHistory();
1880
+
1881
// Fetch the list of dataset formats the backend understands and render
// one small card per format into the supported-formats panel.
async function loadSupportedFormats() {
    try {
        const resp = await fetch(`${API_BASE}/api/dataset/formats`);
        const { formats } = await resp.json();
        const cards = formats.map(f => `
            <div style="background:var(--bg-tertiary);border:1px solid var(--border);border-radius:8px;padding:0.75rem;">
                <div style="font-weight:500;font-size:0.85rem;">${f.name}</div>
                <div style="font-size:0.75rem;color:var(--text-muted);margin-top:0.25rem;">${f.description}</div>
            </div>
        `);
        document.getElementById('supportedFormatsList').innerHTML = cards.join('');
    } catch (e) {
        console.error('Failed to load formats:', e);
    }
}
1897
+
1898
// ── Custom Format UI Handling ──
// Handles the "custom format" controls: a checkbox that reveals the section,
// radio buttons that choose a format type, and inputs whose values are
// compiled into a format string sent to the backend.
const useCustomFormatCheckbox = document.getElementById('useCustomFormat');
const customFormatSection = document.getElementById('customFormatSection');
const formatPreview = document.getElementById('formatPreview');
const columnInput = document.getElementById('columnInput');
const twoColumnInput = document.getElementById('twoColumnInput');
const patternInput = document.getElementById('patternInput');
const customColumnName = document.getElementById('customColumnName');
const customUserColumn = document.getElementById('customUserColumn');
const customAssistantColumn = document.getElementById('customAssistantColumn');
const customPattern = document.getElementById('customPattern');

// Compile the custom-format controls into a backend format string,
// or null to let the server auto-detect.
function buildFormatString() {
    if (!useCustomFormatCheckbox.checked) return null;

    const selected = document.querySelector('input[name="customFormatType"]:checked');
    const formatType = selected?.value || 'auto';

    switch (formatType) {
        case 'column': {
            const col = customColumnName.value.trim();
            return col ? `column: ${col}` : null;
        }
        case 'two_column': {
            const userCol = customUserColumn.value.trim();
            const assistantCol = customAssistantColumn.value.trim();
            if (!assistantCol) return null;
            return userCol
                ? `column: ${userCol}, column: ${assistantCol}`
                : `column: ${assistantCol}`;
        }
        case 'pattern': {
            const pat = customPattern.value.trim();
            if (!pat) return null;
            // A fully specified role-tag pattern is passed through verbatim.
            if (pat.includes('[startuser]') && pat.includes('[startassistant]')) {
                return pat;
            }
            // Otherwise treat whitespace-separated tokens as two patterns,
            // or a single token as a column name.
            const parts = pat.split(/\s+/);
            return parts.length >= 2
                ? `pattern: ${parts[0]}, pattern: ${parts[1]}`
                : `column: ${pat}`;
        }
        default:
            // 'auto' (or anything unexpected) falls back to auto-detect.
            return null;
    }
}

// Mirror the compiled format string in the live preview element.
function updateFormatPreview() {
    const fmt = buildFormatString();
    formatPreview.textContent = fmt || '(auto-detect)';
    formatPreview.style.color = fmt ? 'var(--accent)' : 'var(--text-muted)';
}

useCustomFormatCheckbox?.addEventListener('change', () => {
    customFormatSection.style.display = useCustomFormatCheckbox.checked ? 'block' : 'none';
    updateFormatPreview();
});

// Show only the input group matching the selected format type.
document.querySelectorAll('input[name="customFormatType"]').forEach(radio => {
    radio.addEventListener('change', (e) => {
        const kind = e.target.value;
        columnInput.style.display = kind === 'column' ? 'block' : 'none';
        twoColumnInput.style.display = kind === 'two_column' ? 'block' : 'none';
        patternInput.style.display = kind === 'pattern' ? 'block' : 'none';
        updateFormatPreview();
    });
});

[customColumnName, customUserColumn, customAssistantColumn, customPattern].forEach(input => {
    input?.addEventListener('input', updateFormatPreview);
});
1973
checkStatus();

// On page load, resume tracking a previously started evaluation job if one
// was saved. Running/pending jobs restart polling; finished jobs either show
// results or clear the saved id.
async function restoreJobState() {
    const savedJobId = getSavedJobId();
    if (!savedJobId) return;
    console.log('Restoring job state, savedJobId:', savedJobId);
    try {
        const resp = await fetch(`${API_BASE}/api/dataset/job/${savedJobId}`);
        const job = await resp.json();
        console.log('Job data:', job);

        switch (job.status) {
            case 'running':
            case 'pending': {
                currentJobId = savedJobId;
                datasetIdInput.value = job.dataset_id || '';
                datasetLoading.style.display = 'block';
                evaluateDatasetBtn.disabled = true;

                const prog = job.progress;
                console.log('Progress:', prog);
                if (prog) {
                    const stageLabel = prog.stage === 'downloading' ? 'Downloading'
                        : prog.stage === 'evaluating' ? 'Evaluating'
                        : 'Processing';
                    datasetLoadingText.textContent = `${stageLabel}: ${prog.percent}%`;
                } else {
                    datasetLoadingText.textContent = 'Starting evaluation...';
                }

                startJobPolling();
                break;
            }
            case 'completed':
                clearSavedJobId();
                showEvaluationResults(job.results);
                break;
            case 'failed':
                clearSavedJobId();
                break;
        }
    } catch (e) {
        console.error('Restore error:', e);
        clearSavedJobId();
    }
}
restoreJobState();
2011
  </script>
2012
  </body>
2013
  </html>