Upload 18 files
Browse files- app.py +499 -21
- config.py +26 -30
- data_loader.py +270 -0
- dataset_evaluator.py +769 -0
- evaluate_dataset.py +246 -0
- features.py +293 -294
- models/community/enc_4provider.joblib +3 -0
- models/community/pipeline_4provider.joblib +3 -0
- models/community/rf_4provider.joblib +3 -0
- models/jobs.joblib +3 -0
- models/style/enc_4provider.joblib +3 -0
- models/style/pipeline_4provider.joblib +3 -0
- models/style/rf_4provider.joblib +3 -0
- templates/index.html +645 -5
app.py
CHANGED
|
@@ -5,6 +5,11 @@ Serves the trained sklearn ensemble via the AIFinder inference class.
|
|
| 5 |
|
| 6 |
import os
|
| 7 |
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
import joblib
|
| 10 |
import numpy as np
|
|
@@ -13,10 +18,14 @@ from flask import Flask, jsonify, request, send_from_directory, render_template
|
|
| 13 |
from flask_cors import CORS
|
| 14 |
from flask_limiter import Limiter
|
| 15 |
from flask_limiter.util import get_remote_address
|
|
|
|
| 16 |
|
| 17 |
from config import MODEL_DIR
|
| 18 |
from inference import AIFinder
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
app = Flask(__name__)
|
| 21 |
CORS(app)
|
| 22 |
limiter = Limiter(get_remote_address, app=app)
|
|
@@ -28,20 +37,42 @@ using_community = False
|
|
| 28 |
DEFAULT_TOP_N = 4
|
| 29 |
COMMUNITY_DIR = os.path.join(MODEL_DIR, "community")
|
| 30 |
CORRECTIONS_FILE = os.path.join(COMMUNITY_DIR, "corrections.joblib")
|
|
|
|
|
|
|
| 31 |
corrections: list[dict] = []
|
| 32 |
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
def _active_finder():
|
|
@@ -112,17 +143,20 @@ def correct():
|
|
| 112 |
text = _strip_think_tags(data["text"])
|
| 113 |
corrections.append({"text": text, "provider": provider})
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
joblib.dump([rf], os.path.join(COMMUNITY_DIR, "rf_4provider.joblib"))
|
| 124 |
-
joblib.dump(finder.pipeline, os.path.join(COMMUNITY_DIR, "pipeline_4provider.joblib"))
|
| 125 |
-
joblib.dump(finder.le, os.path.join(COMMUNITY_DIR, "enc_4provider.joblib"))
|
| 126 |
joblib.dump(corrections, CORRECTIONS_FILE)
|
| 127 |
|
| 128 |
community_finder = AIFinder(model_dir=COMMUNITY_DIR)
|
|
@@ -143,7 +177,9 @@ def toggle_community():
|
|
| 143 |
global using_community
|
| 144 |
data = request.get_json(silent=True) or {}
|
| 145 |
using_community = bool(data.get("enabled", not using_community))
|
| 146 |
-
return jsonify(
|
|
|
|
|
|
|
| 147 |
|
| 148 |
|
| 149 |
@app.route("/models/<filename>")
|
|
@@ -178,6 +214,448 @@ def providers():
|
|
| 178 |
)
|
| 179 |
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
if __name__ == "__main__":
|
| 182 |
print("Loading models...")
|
| 183 |
load_models()
|
|
|
|
| 5 |
|
| 6 |
import os
|
| 7 |
import re
|
| 8 |
+
import shutil
|
| 9 |
+
import uuid
|
| 10 |
+
import threading
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
from datetime import datetime
|
| 13 |
|
| 14 |
import joblib
|
| 15 |
import numpy as np
|
|
|
|
| 18 |
from flask_cors import CORS
|
| 19 |
from flask_limiter import Limiter
|
| 20 |
from flask_limiter.util import get_remote_address
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
|
| 23 |
from config import MODEL_DIR
|
| 24 |
from inference import AIFinder
|
| 25 |
|
| 26 |
+
STYLE_MODEL_DIR = os.path.join(MODEL_DIR, "style")
|
| 27 |
+
from dataset_evaluator import load_dataset_texts, get_supported_formats
|
| 28 |
+
|
| 29 |
app = Flask(__name__)
|
| 30 |
CORS(app)
|
| 31 |
limiter = Limiter(get_remote_address, app=app)
|
|
|
|
| 37 |
DEFAULT_TOP_N = 4
|
| 38 |
COMMUNITY_DIR = os.path.join(MODEL_DIR, "community")
|
| 39 |
CORRECTIONS_FILE = os.path.join(COMMUNITY_DIR, "corrections.joblib")
|
| 40 |
+
CORRECTION_MODEL_FILE = os.path.join(COMMUNITY_DIR, "correction_rf_4provider.joblib")
|
| 41 |
+
JOBS_FILE = os.path.join(MODEL_DIR, "jobs.joblib")
|
| 42 |
corrections: list[dict] = []
|
| 43 |
|
| 44 |
+
jobs: dict[str, dict] = {}
|
| 45 |
|
| 46 |
+
|
| 47 |
+
def _copy_base_models_to_community():
    """Seed the community model directory with the base (style) models.

    Each artifact is copied only when the source exists and no community
    copy is present yet, so community models retrained from user
    corrections are never clobbered.
    """
    artifacts = (
        "rf_4provider.joblib",
        "pipeline_4provider.joblib",
        "enc_4provider.joblib",
    )
    for artifact in artifacts:
        source_path = os.path.join(STYLE_MODEL_DIR, artifact)
        target_path = os.path.join(COMMUNITY_DIR, artifact)
        if not os.path.exists(source_path):
            continue
        if os.path.exists(target_path):
            continue
        shutil.copy(source_path, target_path)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _update_job_progress(job_id, current, total, stage):
    """Record progress for *job_id* and persist it; no-op for unknown jobs."""
    job = jobs.get(job_id)
    if job is None:
        return
    # Guard against division by zero when the total is not yet known.
    percent = round(current / total * 100, 1) if total > 0 else 0
    job["progress"] = {
        "current": current,
        "total": total,
        "stage": stage,
        "percent": percent,
    }
    _save_jobs()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _save_jobs():
    """Persist the in-memory ``jobs`` dict to disk via joblib.

    NOTE(review): this is called from both request handlers and background
    evaluation threads without a lock — presumably last-writer-wins is
    acceptable here, but confirm concurrent dumps cannot corrupt JOBS_FILE.
    """
    joblib.dump(jobs, JOBS_FILE)
|
| 76 |
|
| 77 |
|
| 78 |
def _active_finder():
|
|
|
|
| 143 |
text = _strip_think_tags(data["text"])
|
| 144 |
corrections.append({"text": text, "provider": provider})
|
| 145 |
|
| 146 |
+
_copy_base_models_to_community()
|
| 147 |
+
|
| 148 |
+
if len(corrections) > 0:
|
| 149 |
+
texts = [c["text"] for c in corrections]
|
| 150 |
+
providers = [c["provider"] for c in corrections]
|
| 151 |
+
X = finder.pipeline.transform(texts)
|
| 152 |
+
y = finder.le.transform(providers)
|
| 153 |
|
| 154 |
+
correction_rf = RandomForestClassifier(
|
| 155 |
+
n_estimators=100, random_state=42, n_jobs=-1
|
| 156 |
+
)
|
| 157 |
+
correction_rf.fit(X, y)
|
| 158 |
+
joblib.dump([correction_rf], CORRECTION_MODEL_FILE)
|
| 159 |
|
|
|
|
|
|
|
|
|
|
| 160 |
joblib.dump(corrections, CORRECTIONS_FILE)
|
| 161 |
|
| 162 |
community_finder = AIFinder(model_dir=COMMUNITY_DIR)
|
|
|
|
| 177 |
global using_community
|
| 178 |
data = request.get_json(silent=True) or {}
|
| 179 |
using_community = bool(data.get("enabled", not using_community))
|
| 180 |
+
return jsonify(
|
| 181 |
+
{"using_community": using_community, "available": community_finder is not None}
|
| 182 |
+
)
|
| 183 |
|
| 184 |
|
| 185 |
@app.route("/models/<filename>")
|
|
|
|
| 214 |
)
|
| 215 |
|
| 216 |
|
| 217 |
+
@app.route("/api/dataset/info", methods=["POST"])
def dataset_info():
    """Get info about a dataset without evaluating.

    Expects a JSON body with at least ``dataset_id``. Optional keys:
    ``max_samples`` (default 1000), ``evaluate`` (bool — also kick off a
    background evaluation job), ``api_key``, ``custom_format``.
    Returns dataset metadata; when ``evaluate`` is truthy and the dataset
    format is supported, also returns a ``job_id`` for polling.
    """
    data = request.get_json(silent=True)
    if not data or "dataset_id" not in data:
        return jsonify({"error": "Request must include 'dataset_id'"}), 400

    dataset_id = data["dataset_id"]
    max_samples = data.get("max_samples", 1000)
    evaluate = data.get("evaluate", False)
    api_key = data.get("api_key")
    custom_format = data.get("custom_format")

    # sample_size=1: only probe the dataset's shape/format here, not load it all.
    result = load_dataset_texts(
        dataset_id, max_samples=max_samples, sample_size=1, custom_format=custom_format
    )

    response = {
        "dataset_id": dataset_id,
        "total_rows": result["total_rows"],
        "extracted_count": len(result["texts"]),
        "format": result["format"],
        "format_name": result["format_info"]["name"] if result["format_info"] else None,
        "format_description": result["format_info"]["description"]
        if result["format_info"]
        else None,
        "supported": result["supported"],
        "error": result["error"],
        "custom_format": custom_format,
    }

    # Optionally start a full evaluation in the background (same flow as
    # the /api/dataset/evaluate endpoint).
    if evaluate and result["supported"]:
        job_id = str(uuid.uuid4())
        jobs[job_id] = {
            "job_id": job_id,
            "dataset_id": dataset_id,
            "max_samples": max_samples,
            "status": "pending",
            "created_at": datetime.utcnow().isoformat(),
            "api_key": api_key,
        }
        _save_jobs()

        # Daemon thread so a pending evaluation never blocks server shutdown.
        thread = threading.Thread(
            target=_run_evaluation_job,
            args=(job_id, dataset_id, max_samples, api_key, custom_format),
        )
        thread.daemon = True
        thread.start()

        response["job_id"] = job_id
        response["status"] = "pending"
        response["message"] = "Evaluation started in background."
        response["custom_format"] = custom_format

    return jsonify(response)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _run_evaluation_job(
    job_id: str,
    dataset_id: str,
    max_samples: int,
    api_key: str | None,
    custom_format: str | None = None,
):
    """Background task to run dataset evaluation.

    Loads up to ``max_samples`` assistant texts from *dataset_id*, runs the
    active finder's per-text provider prediction over them, aggregates
    counts/confidences, and writes the outcome back onto ``jobs[job_id]``
    (status "completed" or "failed"), persisting after each state change.
    """
    jobs[job_id]["status"] = "running"
    jobs[job_id]["started_at"] = datetime.utcnow().isoformat()
    jobs[job_id]["custom_format"] = custom_format
    _save_jobs()

    # Bound the job_id so the loader/evaluation loop can report progress.
    progress_cb = lambda c, t, s: _update_job_progress(job_id, c, t, s)

    try:
        load_result = load_dataset_texts(
            dataset_id,
            max_samples=max_samples,
            progress_callback=progress_cb,
            custom_format=custom_format,
        )

        # Unsupported format: fail the job with the loader's error message.
        if not load_result["supported"]:
            jobs[job_id].update(
                {
                    "status": "failed",
                    "error": load_result["error"],
                    "dataset_id": dataset_id,
                    "supported": False,
                    "completed_at": datetime.utcnow().isoformat(),
                }
            )
            _save_jobs()
            return

        texts = load_result["texts"]
        if not texts:
            jobs[job_id].update(
                {
                    "status": "failed",
                    "error": "No valid assistant responses found in dataset",
                    "dataset_id": dataset_id,
                    "supported": True,
                    "extracted_count": 0,
                    "completed_at": datetime.utcnow().isoformat(),
                }
            )
            _save_jobs()
            return

        results = {
            "dataset_id": dataset_id,
            "format": load_result["format"],
            "format_name": load_result["format_info"]["name"]
            if load_result["format_info"]
            else None,
            "total_rows": load_result["total_rows"],
            "extracted_count": len(texts),
            "provider_counts": {},
            "provider_confidences": {},
            "top_providers": {},
        }

        provider_counts = defaultdict(int)        # top-1 prediction tally
        provider_confidences = defaultdict(list)  # top-1 confidences per provider
        top_providers = defaultdict(int)          # appearances in any top-5

        af = _active_finder()

        total = len(texts)
        for i, text in enumerate(tqdm(texts, desc="Evaluating")):
            # progress_cb is always truthy here; throttle persistence to
            # every 10th sample (plus the final one).
            if progress_cb and (i % 10 == 0 or i == total - 1):
                progress_cb(i + 1, total, "evaluating")
            try:
                proba = af.predict_proba(text)
                sorted_providers = sorted(
                    proba.items(), key=lambda x: x[1], reverse=True
                )

                pred_provider = sorted_providers[0][0]
                confidence = sorted_providers[0][1]

                provider_counts[pred_provider] += 1
                provider_confidences[pred_provider].append(confidence)

                for name, conf in sorted_providers[:5]:
                    top_providers[name] += 1
            # Best-effort per sample: a single bad text must not kill the job.
            except Exception:
                continue

        # NOTE(review): `total` counts all texts, including ones skipped by
        # the except above, so percentages/averages are relative to the
        # attempted count, not the successful count — confirm this is intended.
        total = len(texts)
        for provider, count in provider_counts.items():
            results["provider_counts"][provider] = {
                "count": count,
                "percentage": round((count / total) * 100, 2),
            }
            confs = provider_confidences[provider]
            avg_conf = sum(confs) / len(confs) if confs else 0
            results["provider_confidences"][provider] = {
                "average": round(avg_conf * 100, 2),
                "cumulative": round(avg_conf * count, 2),
            }

        results["top_providers"] = dict(
            sorted(top_providers.items(), key=lambda x: -x[1])[:5]
        )

        # "Likely" provider = largest cumulative (avg confidence * count).
        sorted_by_cumulative = sorted(
            results["provider_confidences"].items(), key=lambda x: -x[1]["cumulative"]
        )
        results["likely_provider"] = (
            sorted_by_cumulative[0][0] if sorted_by_cumulative else None
        )
        results["average_confidence"] = (
            round(sum(sum(c) for c in provider_confidences.values()) / total * 100, 2)
            if total > 0
            else 0
        )

        jobs[job_id].update(
            {
                "status": "completed",
                "results": results,
                "api_key": api_key,
                "completed_at": datetime.utcnow().isoformat(),
            }
        )
        _save_jobs()
    # Top-level boundary: record any unexpected failure on the job record.
    except Exception as e:
        jobs[job_id].update(
            {
                "status": "failed",
                "error": str(e),
                "completed_at": datetime.utcnow().isoformat(),
            }
        )
        _save_jobs()
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
@app.route("/api/dataset/evaluate", methods=["POST"])
@limiter.limit("10/minute")
def dataset_evaluate():
    """Start a background job to evaluate a HuggingFace dataset.

    JSON body: ``dataset_id`` (required), ``max_samples`` (default 1000),
    ``api_key``, ``custom_format``. Validates the dataset up front (format
    supported, at least one assistant text), then spawns a daemon thread
    running :func:`_run_evaluation_job` and returns the ``job_id``.

    NOTE(review): the full dataset is loaded here for validation and loaded
    again inside the job thread — presumably acceptable, but it doubles the
    download/parse work per request.
    """
    data = request.get_json(silent=True)
    if not data or "dataset_id" not in data:
        return jsonify({"error": "Request must include 'dataset_id'"}), 400

    dataset_id = data["dataset_id"]
    max_samples = data.get("max_samples", 1000)
    api_key = data.get("api_key")
    custom_format = data.get("custom_format")

    load_result = load_dataset_texts(
        dataset_id, max_samples=max_samples, custom_format=custom_format
    )

    if not load_result["supported"]:
        return jsonify(
            {
                "error": load_result["error"],
                "dataset_id": dataset_id,
                "supported": False,
            }
        ), 400

    if not load_result["texts"]:
        return jsonify(
            {
                "error": "No valid assistant responses found in dataset",
                "dataset_id": dataset_id,
                "supported": True,
                "extracted_count": 0,
            }
        ), 400

    job_id = str(uuid.uuid4())
    jobs[job_id] = {
        "job_id": job_id,
        "dataset_id": dataset_id,
        "max_samples": max_samples,
        "status": "pending",
        "created_at": datetime.utcnow().isoformat(),
        "api_key": api_key,
        "custom_format": custom_format,
    }
    _save_jobs()

    # Daemon thread: evaluation runs in the background; clients poll
    # /api/dataset/job/<job_id> for progress and results.
    thread = threading.Thread(
        target=_run_evaluation_job,
        args=(job_id, dataset_id, max_samples, api_key, custom_format),
    )
    thread.daemon = True
    thread.start()

    return jsonify(
        {
            "job_id": job_id,
            "status": "pending",
            "message": "Evaluation started. Use the job_id to check status later.",
            "custom_format": custom_format,
        }
    )
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
@app.route("/api/dataset/job/<job_id>", methods=["GET"])
def dataset_job_status(job_id):
    """Report the current status — and, once finished, the outcome — of a job."""
    record = jobs.get(job_id)
    if record is None:
        return jsonify({"error": "Job not found"}), 404

    payload = {
        "job_id": job_id,
        "dataset_id": record.get("dataset_id"),
        "status": record["status"],
        "created_at": record.get("created_at"),
        "started_at": record.get("started_at"),
        "completed_at": record.get("completed_at"),
    }

    progress = record.get("progress")
    if progress:
        payload["progress"] = progress

    status = record["status"]
    if status == "completed":
        payload["results"] = record.get("results")
    if status == "failed":
        payload["error"] = record.get("error")

    return jsonify(payload)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
@app.route("/api/datasets", methods=["GET"])
def list_datasets():
    """List finished (completed or failed) dataset evaluations.

    When an ``api_key`` query parameter is supplied, only jobs created with
    that key are returned. Newest first.
    """
    requested_key = request.args.get("api_key")

    def _visible(job):
        # Skip other users' jobs when a key filter is active, and skip
        # jobs that are still pending/running.
        if requested_key and job.get("api_key") != requested_key:
            return False
        return job["status"] in ("completed", "failed")

    summaries = [
        {
            "job_id": jid,
            "dataset_id": job.get("dataset_id"),
            "status": job["status"],
            "created_at": job.get("created_at"),
            "completed_at": job.get("completed_at"),
            "error": job.get("error"),
            "custom_format": job.get("custom_format"),
        }
        for jid, job in jobs.items()
        if _visible(job)
    ]

    summaries.sort(key=lambda item: item.get("created_at", ""), reverse=True)
    return jsonify({"datasets": summaries})
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
@app.route("/api/datasets/clear", methods=["POST"])
def clear_datasets():
    """Delete all finished evaluation records belonging to the given API key."""
    payload = request.get_json(silent=True) or {}
    caller_key = payload.get("api_key")

    if not caller_key:
        return jsonify({"error": "API key required"}), 400

    # Collect first, then delete — never mutate `jobs` while iterating it.
    removable = [
        jid
        for jid, job in jobs.items()
        if job.get("api_key") == caller_key
        and job["status"] in ("completed", "failed")
    ]

    for jid in removable:
        del jobs[jid]

    if removable:
        _save_jobs()

    return jsonify({"status": "ok", "cleared": len(removable)})
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
@app.route("/api/dataset/formats", methods=["GET"])
def dataset_formats():
    """Get list of supported dataset formats.

    Returns the formats known to the dataset evaluator plus a synthetic
    "Custom Format" entry, along with help text describing the custom
    format mini-language (column/pattern/marker specs).
    """
    formats = get_supported_formats()
    formats_list = [
        {
            "name": info["name"],
            "description": info["description"],
            "examples": info["examples"],
        }
        for info in formats.values()
    ]
    # Always advertise the user-defined format option alongside built-ins.
    formats_list.append(
        {
            "name": "Custom Format",
            "description": "Define your own format specification",
            "examples": [
                "column: response",
                "column: prompt, column: response",
                "pattern: user:, pattern: assistant:",
                "user:[startuser]assistant:[startassistant]",
            ],
        }
    )
    return jsonify(
        {
            "formats": formats_list,
            "custom_format_help": {
                "description": "Specify custom format using these patterns:",
                "patterns": [
                    "column: <field_name> - extract single field",
                    "column: <user_field>, column: <assistant_field> - extract from two columns",
                    "pattern: <regex> - use regex to extract",
                    "user:[startuser]assistant:[startassistant] - pattern-based extraction",
                ],
                "examples": [
                    {
                        "input": "column: completion",
                        "description": "Extract from 'completion' field",
                    },
                    {
                        "input": "column: input, column: output",
                        "description": "Extract from 'input' and 'output' columns",
                    },
                    {
                        "input": "user:[INST]assistant:[/INST]",
                        "description": "Extract text between markers",
                    },
                ],
            },
        }
    )
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
class CommunityAIFinder(AIFinder):
    """Extended AIFinder that blends base model with correction model.

    The correction model is a list of classifiers trained on community
    corrections; when present, its probabilities are mixed with the base
    ensemble's before renormalizing.
    """

    def __init__(self, model_dir, correction_model_path=None):
        # Load the base ensemble/pipeline/encoder from model_dir as usual.
        super().__init__(model_dir)
        self.correction_models = None
        if correction_model_path and os.path.exists(correction_model_path):
            self.correction_models = joblib.load(correction_model_path)

    def predict_proba(self, text):
        """Blend base model predictions with correction model if available.

        Returns a dict mapping provider label -> probability for *text*.
        """
        X = self.pipeline.transform([text])

        # Average the per-model probabilities of the base ensemble.
        base_proba = np.mean([m.predict_proba(X) for m in self.models], axis=0)

        if self.correction_models is not None and len(self.correction_models) > 0:
            correction_proba = np.mean(
                [m.predict_proba(X) for m in self.correction_models], axis=0
            )

            # NOTE(review): blend_weight applies to the CORRECTION model,
            # giving it 70% of the mix even when trained on very few
            # corrections — confirm this weighting is intended.
            blend_weight = 0.7
            final_proba = (
                1 - blend_weight
            ) * base_proba + blend_weight * correction_proba
            # Renormalize so the blended row sums to 1.
            final_proba = final_proba / final_proba.sum(axis=1, keepdims=True)
        else:
            final_proba = base_proba

        return dict(zip(self.le.classes_, final_proba[0]))
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def load_models():
    """Load the base finder, persisted corrections/jobs, and the community finder.

    Mutates the module-level ``finder``, ``community_finder``,
    ``corrections``, and ``jobs`` globals. The community finder is optional:
    it is only built when community model files exist, and any load failure
    leaves it as None rather than crashing startup.
    """
    global finder, community_finder, corrections, jobs
    finder = AIFinder(model_dir=STYLE_MODEL_DIR)
    os.makedirs(COMMUNITY_DIR, exist_ok=True)
    # Ensure the community dir starts from the base models on first run.
    _copy_base_models_to_community()
    if os.path.exists(CORRECTIONS_FILE):
        corrections = joblib.load(CORRECTIONS_FILE)
    if os.path.exists(JOBS_FILE):
        jobs = joblib.load(JOBS_FILE)
    if os.path.exists(os.path.join(COMMUNITY_DIR, "rf_4provider.joblib")):
        try:
            community_finder = CommunityAIFinder(
                model_dir=COMMUNITY_DIR, correction_model_path=CORRECTION_MODEL_FILE
            )
        # Best-effort: a corrupt/incompatible community model disables the
        # community finder instead of preventing startup.
        except Exception:
            community_finder = None
|
| 657 |
+
|
| 658 |
+
|
| 659 |
if __name__ == "__main__":
|
| 660 |
print("Loading models...")
|
| 661 |
load_models()
|
config.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
AIFinder Configuration
|
| 3 |
-
|
| 4 |
"""
|
| 5 |
|
| 6 |
import os
|
|
@@ -9,10 +9,11 @@ import os
|
|
| 9 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 10 |
MODEL_DIR = os.path.join(BASE_DIR, "models")
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
#
|
| 14 |
-
#
|
| 15 |
-
|
|
|
|
| 16 |
# Anthropic
|
| 17 |
("TeichAI/claude-4.5-opus-high-reasoning-250x", "Anthropic", "Claude 4.5 Opus", {}),
|
| 18 |
(
|
|
@@ -73,7 +74,7 @@ DATASET_REGISTRY = [
|
|
| 73 |
),
|
| 74 |
# Zhipu
|
| 75 |
("TeichAI/Pony-Alpha-15k", "Zhipu", "GLM-5", {"max_samples": 1500}),
|
| 76 |
-
# DeepSeek
|
| 77 |
("TeichAI/deepseek-v3.2-speciale-1000x", "DeepSeek", "DeepSeek V3.2 Speciale", {}),
|
| 78 |
(
|
| 79 |
"TeichAI/deepseek-v3.2-speciale-openr1-math-3k",
|
|
@@ -81,10 +82,7 @@ DATASET_REGISTRY = [
|
|
| 81 |
"DeepSeek V3.2 Speciale",
|
| 82 |
{"max_samples": 1500},
|
| 83 |
),
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
# DeepSeek (a-m-team) — different format, handled separately
|
| 87 |
-
DEEPSEEK_AM_DATASETS = [
|
| 88 |
(
|
| 89 |
"a-m-team/AM-DeepSeek-R1-Distilled-1.4M",
|
| 90 |
"DeepSeek",
|
|
@@ -93,24 +91,22 @@ DEEPSEEK_AM_DATASETS = [
|
|
| 93 |
),
|
| 94 |
]
|
| 95 |
|
| 96 |
-
#
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
"
|
| 106 |
-
"Mistral",
|
| 107 |
-
"MiniMax",
|
| 108 |
-
"StepFun",
|
| 109 |
-
"Zhipu",
|
| 110 |
-
"DeepSeek",
|
| 111 |
]
|
|
|
|
| 112 |
|
| 113 |
-
#
|
|
|
|
|
|
|
| 114 |
TFIDF_WORD_PARAMS = {
|
| 115 |
"analyzer": "word",
|
| 116 |
"ngram_range": (1, 2),
|
|
@@ -130,15 +126,15 @@ TFIDF_CHAR_PARAMS = {
|
|
| 130 |
"smooth_idf": True,
|
| 131 |
}
|
| 132 |
|
| 133 |
-
#
|
|
|
|
|
|
|
| 134 |
MAX_SAMPLES_PER_PROVIDER = 1000
|
| 135 |
-
|
| 136 |
-
# --- Train/val/test split ---
|
| 137 |
TEST_SIZE = 0.15
|
| 138 |
VAL_SIZE = 0.10
|
| 139 |
RANDOM_STATE = 42
|
| 140 |
|
| 141 |
-
#
|
| 142 |
HIDDEN_DIM = 256
|
| 143 |
EMBED_DIM = 128
|
| 144 |
DROPOUT = 0.7
|
|
|
|
| 1 |
"""
|
| 2 |
AIFinder Configuration
|
| 3 |
+
Easy configuration for providers and datasets.
|
| 4 |
"""
|
| 5 |
|
| 6 |
import os
|
|
|
|
| 9 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 10 |
MODEL_DIR = os.path.join(BASE_DIR, "models")
|
| 11 |
|
| 12 |
+
# ============================================================================
|
| 13 |
+
# EASY PROVIDER CONFIGURATION
|
| 14 |
+
# Add new providers here! Each entry: (huggingface_dataset, provider_name, model_name, kwargs)
|
| 15 |
+
# ============================================================================
|
| 16 |
+
PROVIDER_DATASETS = [
|
| 17 |
# Anthropic
|
| 18 |
("TeichAI/claude-4.5-opus-high-reasoning-250x", "Anthropic", "Claude 4.5 Opus", {}),
|
| 19 |
(
|
|
|
|
| 74 |
),
|
| 75 |
# Zhipu
|
| 76 |
("TeichAI/Pony-Alpha-15k", "Zhipu", "GLM-5", {"max_samples": 1500}),
|
| 77 |
+
# DeepSeek
|
| 78 |
("TeichAI/deepseek-v3.2-speciale-1000x", "DeepSeek", "DeepSeek V3.2 Speciale", {}),
|
| 79 |
(
|
| 80 |
"TeichAI/deepseek-v3.2-speciale-openr1-math-3k",
|
|
|
|
| 82 |
"DeepSeek V3.2 Speciale",
|
| 83 |
{"max_samples": 1500},
|
| 84 |
),
|
| 85 |
+
# DeepSeek (a-m-team) - different format
|
|
|
|
|
|
|
|
|
|
| 86 |
(
|
| 87 |
"a-m-team/AM-DeepSeek-R1-Distilled-1.4M",
|
| 88 |
"DeepSeek",
|
|
|
|
| 91 |
),
|
| 92 |
]
|
| 93 |
|
| 94 |
+
# Auto-generate DATASET_REGISTRY and PROVIDERS from PROVIDER_DATASETS
|
| 95 |
+
DEEPSEEK_AM_DATASETS = [
|
| 96 |
+
(ds_id, prov, model, kwargs)
|
| 97 |
+
for ds_id, prov, model, kwargs in PROVIDER_DATASETS
|
| 98 |
+
if "a-m-team" in ds_id
|
| 99 |
+
]
|
| 100 |
+
DATASET_REGISTRY = [
|
| 101 |
+
(ds_id, prov, model, kwargs)
|
| 102 |
+
for ds_id, prov, model, kwargs in PROVIDER_DATASETS
|
| 103 |
+
if "a-m-team" not in ds_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
]
|
| 105 |
+
PROVIDERS = sorted(set(prov for _, prov, _, _ in PROVIDER_DATASETS))
|
| 106 |
|
| 107 |
+
# ============================================================================
|
| 108 |
+
# FEATURE PARAMETERS
|
| 109 |
+
# ============================================================================
|
| 110 |
TFIDF_WORD_PARAMS = {
|
| 111 |
"analyzer": "word",
|
| 112 |
"ngram_range": (1, 2),
|
|
|
|
| 126 |
"smooth_idf": True,
|
| 127 |
}
|
| 128 |
|
| 129 |
+
# ============================================================================
|
| 130 |
+
# TRAINING PARAMETERS
|
| 131 |
+
# ============================================================================
|
| 132 |
MAX_SAMPLES_PER_PROVIDER = 1000
|
|
|
|
|
|
|
| 133 |
TEST_SIZE = 0.15
|
| 134 |
VAL_SIZE = 0.10
|
| 135 |
RANDOM_STATE = 42
|
| 136 |
|
| 137 |
+
# Neural Network (unused currently, but kept for reference)
|
| 138 |
HIDDEN_DIM = 256
|
| 139 |
EMBED_DIM = 128
|
| 140 |
DROPOUT = 0.7
|
data_loader.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AIFinder Data Loader
|
| 3 |
+
Downloads and parses HuggingFace datasets, extracts assistant responses,
|
| 4 |
+
and labels them with is_ai, provider, and model.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import time
|
| 10 |
+
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from config import (
|
| 15 |
+
DATASET_REGISTRY,
|
| 16 |
+
DEEPSEEK_AM_DATASETS,
|
| 17 |
+
MAX_SAMPLES_PER_PROVIDER,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _parse_msg(msg):
|
| 24 |
+
"""Parse a message that may be a dict or a JSON string."""
|
| 25 |
+
if isinstance(msg, dict):
|
| 26 |
+
return msg
|
| 27 |
+
if isinstance(msg, str):
|
| 28 |
+
try:
|
| 29 |
+
import json as _json
|
| 30 |
+
|
| 31 |
+
parsed = _json.loads(msg)
|
| 32 |
+
if isinstance(parsed, dict):
|
| 33 |
+
return parsed
|
| 34 |
+
except (ValueError, Exception):
|
| 35 |
+
pass
|
| 36 |
+
return {}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _extract_response_only(content):
|
| 40 |
+
"""Extract only the final response, stripping CoT blocks.
|
| 41 |
+
Returns only the text after </think> or </thinking> if present,
|
| 42 |
+
otherwise returns the full content.
|
| 43 |
+
"""
|
| 44 |
+
if not content:
|
| 45 |
+
return ""
|
| 46 |
+
think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL)
|
| 47 |
+
if think_match:
|
| 48 |
+
response = think_match.group(1).strip()
|
| 49 |
+
if response:
|
| 50 |
+
return response
|
| 51 |
+
return content
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _extract_assistant_texts_from_conversations(rows):
    """Extract individual assistant messages from conversation datasets.

    Returns one text per assistant turn (not concatenated) for cleaner
    samples. Only the response portion (after </think> if present) is kept.
    """

    def _empty(container):
        # Treat None and zero-length containers as "no conversation here".
        return container is None or (
            hasattr(container, "__len__") and len(container) == 0
        )

    assistant_roles = ("assistant", "gpt", "model")
    collected = []
    for row in rows:
        turns = row.get("conversations")
        if _empty(turns):
            turns = row.get("messages")
        if _empty(turns):
            turns = []
        for raw in turns:
            parsed = _parse_msg(raw)
            if parsed.get("role", "") not in assistant_roles:
                continue
            body = parsed.get("content", "")
            if not body:
                continue
            reply = _extract_response_only(body)
            if reply:
                collected.append(reply)
    return collected
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _extract_from_am_dataset(row):
    """Extract individual assistant texts from the a-m-team row format.

    Only the response portion (after </think> if present) is kept.
    """
    turns = row.get("messages") or row.get("conversations") or []
    results = []
    for turn in turns:
        # Non-dict entries carry no role/content and are skipped.
        if not isinstance(turn, dict):
            continue
        if turn.get("role", "") != "assistant":
            continue
        content = turn.get("content", "")
        if content:
            reply = _extract_response_only(content)
            if reply:
                results.append(reply)
    return results
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def load_teichai_dataset(dataset_id, provider, model_name, kwargs):
    """Load a single conversation-format dataset and return (texts, providers, models).

    Args:
        dataset_id: HuggingFace dataset identifier (e.g. "org/name").
        provider: provider label attached to every extracted sample.
        model_name: model label attached to every extracted sample.
        kwargs: options dict; honours "max_samples" (random subsample cap)
            and "name" (HF dataset config name).

    Returns:
        Three parallel lists (texts, providers, models); all empty on failure.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]

    try:
        ds = load_dataset(dataset_id, split="train", token=HF_TOKEN, **load_kwargs)
        rows = list(ds)
    except Exception as e:
        # Fallback: load from auto-converted parquet via HF API
        # (only the first shard is fetched — presumably enough for sampling;
        # TODO confirm that is acceptable for multi-shard datasets).
        try:
            import pandas as pd

            url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
            df = pd.read_parquet(url)
            rows = df.to_dict(orient="records")
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e} / parquet fallback: {e2}")
            return [], [], []

    if max_samples and len(rows) > max_samples:
        import random

        # Fixed seed so repeated runs subsample the same rows.
        random.seed(42)
        rows = random.sample(rows, max_samples)

    texts = _extract_assistant_texts_from_conversations(rows)

    # Filter out empty/too-short texts
    filtered = [(t, provider, model_name) for t in texts if len(t) > 50]
    if not filtered:
        print(f" [SKIP] {dataset_id}: no valid texts extracted")
        return [], [], []

    t, p, m = zip(*filtered)
    return list(t), list(p), list(m)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def load_am_deepseek_dataset(dataset_id, provider, model_name, kwargs):
    """Load a-m-team DeepSeek dataset.

    Tries a normal (materialized) load first; on failure falls back to a
    streaming load capped at kwargs["max_samples"] rows.

    Args:
        dataset_id: HuggingFace dataset identifier.
        provider: provider label attached to every extracted sample.
        model_name: model label attached to every extracted sample.
        kwargs: options dict; honours "max_samples" and "name" (HF config).

    Returns:
        (texts, providers, models) parallel lists; all empty on failure.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]

    try:
        ds = load_dataset(dataset_id, split="train", token=HF_TOKEN, **load_kwargs)
    except Exception:
        try:
            # Streaming fallback: pull rows one by one so the whole dataset
            # is never downloaded just to take the first max_samples rows.
            ds = load_dataset(
                dataset_id, split="train", streaming=True, token=HF_TOKEN, **load_kwargs
            )
            rows = []
            for row in ds:
                rows.append(row)
                if max_samples and len(rows) >= max_samples:
                    break
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e2}")
            return [], [], []
    else:
        rows = list(ds)
        if max_samples and len(rows) > max_samples:
            rows = rows[:max_samples]

    texts = []
    for row in rows:
        for text in _extract_from_am_dataset(row):
            # Drop trivially short responses (same 50-char floor used elsewhere).
            if len(text) > 50:
                texts.append(text)

    providers = [provider] * len(texts)
    models = [model_name] * len(texts)
    return texts, providers, models
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def load_all_data():
    """Load all datasets and return combined lists.

    Pipeline: load every registered dataset, deduplicate by normalized text
    hash, then cap each provider at MAX_SAMPLES_PER_PROVIDER to balance
    classes.

    Returns:
        texts: list of str
        providers: list of str
        models: list of str
        is_ai: list of int (1=AI, 0=Human) — all 1 here, since every
            registered dataset contains model-generated text.
    """
    all_texts = []
    all_providers = []
    all_models = []

    # TeichAI datasets
    print("Loading TeichAI datasets...")
    for dataset_id, provider, model_name, kwargs in tqdm(
        DATASET_REGISTRY, desc="TeichAI"
    ):
        t0 = time.time()
        texts, providers, models = load_teichai_dataset(
            dataset_id, provider, model_name, kwargs
        )
        elapsed = time.time() - t0
        all_texts.extend(texts)
        all_providers.extend(providers)
        all_models.extend(models)
        print(f" {dataset_id}: {len(texts)} samples ({elapsed:.1f}s)")

    # DeepSeek a-m-team datasets
    print("\nLoading DeepSeek (a-m-team) datasets...")
    for dataset_id, provider, model_name, kwargs in tqdm(
        DEEPSEEK_AM_DATASETS, desc="DeepSeek-AM"
    ):
        t0 = time.time()
        texts, providers, models = load_am_deepseek_dataset(
            dataset_id, provider, model_name, kwargs
        )
        elapsed = time.time() - t0
        all_texts.extend(texts)
        all_providers.extend(providers)
        all_models.extend(models)
        print(f" {dataset_id}: {len(texts)} samples ({elapsed:.1f}s)")

    # Deduplicate by text hash
    import hashlib
    import random as _rng

    # Fixed seed so the per-provider shuffle below is reproducible.
    _rng.seed(42)

    seen = set()
    dedup_texts, dedup_providers, dedup_models = [], [], []
    for t, p, m in zip(all_texts, all_providers, all_models):
        # md5 is used only as a fast dedup fingerprint, not for security;
        # strip+lower so whitespace/case variants collapse to one sample.
        h = hashlib.md5(t.strip().lower().encode()).hexdigest()
        if h not in seen:
            seen.add(h)
            dedup_texts.append(t)
            dedup_providers.append(p)
            dedup_models.append(m)

    n_dupes = len(all_texts) - len(dedup_texts)
    if n_dupes > 0:
        print(f"\n Removed {n_dupes} duplicate samples")

    # Equal samples per provider
    from collections import defaultdict

    provider_indices = defaultdict(list)
    for i, p in enumerate(dedup_providers):
        provider_indices[p].append(i)

    # Use min of available or max allowed
    keep_indices = []
    for p, idxs in provider_indices.items():
        _rng.shuffle(idxs)
        n_sample = min(len(idxs), MAX_SAMPLES_PER_PROVIDER)
        idxs = idxs[:n_sample]
        print(f" Sampled {p}: {len(idxs)} samples")
        keep_indices.extend(idxs)
    # Restore original order after per-provider shuffling/truncation.
    keep_indices.sort()

    all_texts = [dedup_texts[i] for i in keep_indices]
    all_providers = [dedup_providers[i] for i in keep_indices]
    all_models = [dedup_models[i] for i in keep_indices]

    # Build is_ai labels (all AI)
    is_ai = [1] * len(all_texts)

    print(f"\n=== Total: {len(all_texts)} samples ===")
    # Print per-provider counts
    from collections import Counter

    prov_counts = Counter(all_providers)
    for p, c in sorted(prov_counts.items(), key=lambda x: -x[1]):
        print(f" {p}: {c}")

    return all_texts, all_providers, all_models, is_ai
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
if __name__ == "__main__":
    # Smoke-test entry point: run the full loading pipeline (results are
    # printed by load_all_data itself; the returned lists are discarded).
    texts, providers, models, is_ai = load_all_data()
|
dataset_evaluator.py
ADDED
|
@@ -0,0 +1,769 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AIFinder Dataset Evaluator
|
| 3 |
+
Supports various HuggingFace dataset formats for evaluation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import json
|
| 9 |
+
import random
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
from datasets import load_dataset
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Registry of recognised dataset layouts. Order matters: detect_format()
# returns the first entry whose "check" matches >= 60% of sampled rows.
# NOTE: the original literal defined "combined" twice; in a dict literal the
# later value silently overwrites the earlier one, so only the later check was
# ever live. The dead first definition has been removed; the surviving entry
# keeps the original key position, so iteration order and runtime behavior
# are unchanged.
SUPPORTED_FORMATS = {
    "teichai_healer": {
        "name": "TeichAI Healer Format",
        "description": "TeichAI Healer-Alpha format with 'prompt' and 'response' fields",
        "examples": ["TeichAI/Healer-Alpha-16k"],
        "check": lambda row: (
            "prompt" in row
            and "response" in row
            and isinstance(row.get("prompt"), (str, dict))
            and isinstance(row.get("response"), (str, dict))
        ),
    },
    "teichai": {
        "name": "TeichAI Format",
        "description": "TeichAI dataset format with 'conversations' or 'messages' containing role/content",
        "examples": [
            "TeichAI/claude-4.5-opus-high-reasoning-250x",
            "TeichAI/Claude-3.5-Sonnet-128k",
        ],
        "check": lambda row: _check_conversations_format(row),
    },
    "combined": {
        "name": "Combined Outputs",
        "description": "Dataset with 'input', 'output', 'outputs' or combined text field",
        "examples": ["jacobmorrison/gpt-oss-20b-combined-outputs"],
        "check": lambda row: any(
            k in row
            for k in ["output", "outputs", "combined", "generated", "completion"]
        )
        or (isinstance(row.get("data"), str) or isinstance(row.get("example"), str)),
    },
    "conversations": {
        "name": "Conversations Format",
        "description": "Dataset with 'conversations' or 'messages' field containing role/content pairs",
        "examples": [
            "TeichAI/claude-4.5-opus-high-reasoning-250x",
            "ianncity/Hunter-Alpha-SFT-300000x",
        ],
        "check": lambda row: _check_conversations_format(row),
    },
    "chat": {
        "name": "Chat Format",
        "description": "Dataset with 'chat' or 'dialogue' field",
        "examples": ["some/chat-dataset"],
        "check": lambda row: ("chat" in row.keys() or "dialogue" in row.keys()),
    },
    "text": {
        "name": "Text Field",
        "description": "Dataset with a 'text' field containing the response",
        "examples": ["some/text-dataset"],
        "check": lambda row: "text" in row and isinstance(row.get("text"), str),
    },
    "response": {
        "name": "Response Field",
        "description": "Dataset with 'response' or 'output' field",
        "examples": ["some/response-dataset"],
        "check": lambda row: "response" in row or "output" in row,
    },
    "content": {
        "name": "Content Field",
        "description": "Dataset with 'content' field (single message)",
        "examples": ["some/content-dataset"],
        "check": lambda row: "content" in row and isinstance(row.get("content"), str),
    },
    "messages": {
        "name": "Messages Array",
        "description": "Dataset where each row is an array of message objects",
        "examples": ["some/messages-dataset"],
        "check": lambda row: isinstance(row, list)
        and len(row) > 0
        and isinstance(row[0], dict),
    },
    "sft": {
        "name": "SFT Format",
        "description": "Supervised Fine-Tuning format with 'prompt' and 'response' or 'completion'",
        "examples": ["some/sft-dataset"],
        "check": lambda row: "prompt" in row
        and ("response" in row or "completion" in row),
    },
    "qa": {
        "name": "Q&A Format",
        "description": "Question-Answer format with 'question' and 'answer' fields",
        "examples": ["some/qa-dataset"],
        "check": lambda row: "question" in row and "answer" in row,
    },
    "completion": {
        "name": "Completion Format",
        "description": "Dataset with 'completion' field (like OpenAI fine-tuning)",
        "examples": ["some/completion-dataset"],
        "check": lambda row: "completion" in row
        and isinstance(row.get("completion"), str),
    },
    "generations": {
        "name": "Generations Format",
        "description": "Dataset with 'generations' or 'generation' field (LLM outputs)",
        "examples": ["some/generations-dataset"],
        "check": lambda row: "generations" in row or "generation" in row,
    },
}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _check_conversations_format(row):
|
| 138 |
+
"""Check if row has conversations/messages with proper role/content structure."""
|
| 139 |
+
conv_key = (
|
| 140 |
+
"conversations"
|
| 141 |
+
if "conversations" in row
|
| 142 |
+
else "messages"
|
| 143 |
+
if "messages" in row
|
| 144 |
+
else None
|
| 145 |
+
)
|
| 146 |
+
if not conv_key:
|
| 147 |
+
return False
|
| 148 |
+
convos = row.get(conv_key)
|
| 149 |
+
if not isinstance(convos, list) or not convos:
|
| 150 |
+
return False
|
| 151 |
+
first_msg = convos[0]
|
| 152 |
+
if isinstance(first_msg, dict):
|
| 153 |
+
return "role" in first_msg and "content" in first_msg
|
| 154 |
+
return False
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def detect_format(rows, sample_size=10):
    """Detect the dataset format from sample rows.

    Returns (format_name, format_info) for the first registered format whose
    check matches at least 60% of the sampled rows, otherwise (None, []).
    """
    if not rows:
        return None, []

    sample = rows[:sample_size]

    for fmt_name, fmt_info in SUPPORTED_FORMATS.items():
        check_func = fmt_info["check"]
        matches = 0
        for row in sample:
            try:
                if check_func(row):
                    matches += 1
            except Exception:
                # A check that blows up on this row simply does not match it.
                # (Previously a bare `except:` which also swallowed
                # KeyboardInterrupt/SystemExit.)
                pass

        if matches >= len(sample) * 0.6:
            return fmt_name, SUPPORTED_FORMATS[fmt_name]

    return None, []
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _parse_msg(msg):
|
| 181 |
+
"""Parse a message that may be a dict or a JSON string."""
|
| 182 |
+
if isinstance(msg, dict):
|
| 183 |
+
return msg
|
| 184 |
+
if isinstance(msg, str):
|
| 185 |
+
try:
|
| 186 |
+
parsed = json.loads(msg)
|
| 187 |
+
if isinstance(parsed, dict):
|
| 188 |
+
return parsed
|
| 189 |
+
except (ValueError, Exception):
|
| 190 |
+
pass
|
| 191 |
+
return {}
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _extract_response_only(content):
|
| 195 |
+
"""Extract only the final response, stripping CoT blocks."""
|
| 196 |
+
if not content:
|
| 197 |
+
return ""
|
| 198 |
+
think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL)
|
| 199 |
+
if think_match:
|
| 200 |
+
response = think_match.group(1).strip()
|
| 201 |
+
if response:
|
| 202 |
+
return response
|
| 203 |
+
return content
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def extract_texts_conversations(rows):
    """Extract assistant responses from conversations/messages-style rows."""
    wanted_roles = {"assistant", "gpt", "model", "ai"}
    out = []
    for row in rows:
        for raw in row.get("conversations") or row.get("messages") or []:
            parsed = _parse_msg(raw)
            if parsed.get("role", "") not in wanted_roles:
                continue
            content = parsed.get("content", "")
            if not content:
                continue
            reply = _extract_response_only(content)
            # Keep only substantive responses (> 50 chars).
            if reply and len(reply) > 50:
                out.append(reply)
    return out
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def extract_texts_chat(rows):
    """Extract assistant responses from chat/dialogue-style rows."""
    out = []
    for row in rows:
        turns = row.get("chat") or row.get("dialogue") or []
        # Non-list payloads (e.g. a plain string) are ignored.
        if not isinstance(turns, list):
            continue
        for raw in turns:
            parsed = _parse_msg(raw)
            if parsed.get("role", "") not in ("assistant", "ai"):
                continue
            content = parsed.get("content", "")
            if not content:
                continue
            reply = _extract_response_only(content)
            if reply and len(reply) > 50:
                out.append(reply)
    return out
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def extract_texts_text_field(rows, field="text"):
    """Extract responses from rows carrying a single text column."""
    out = []
    for row in rows:
        cell = row.get(field, "")
        # Skip falsy cells and anything too short to be a real response.
        if not cell or len(str(cell)) <= 50:
            continue
        reply = _extract_response_only(str(cell))
        if reply and len(reply) > 50:
            out.append(reply)
    return out
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def extract_texts_sft(rows):
    """Extract responses from SFT rows (prompt + response/completion)."""
    out = []
    for row in rows:
        answer = row.get("response") or row.get("completion") or ""
        if not answer or len(str(answer)) <= 50:
            continue
        reply = _extract_response_only(str(answer))
        if reply and len(reply) > 50:
            out.append(reply)
    return out
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def extract_texts_qa(rows):
    """Extract responses from Q&A rows (the 'answer' field)."""
    out = []
    for row in rows:
        answer = row.get("answer", "")
        if not (answer and len(str(answer)) > 50):
            continue
        reply = _extract_response_only(str(answer))
        if reply and len(reply) > 50:
            out.append(reply)
    return out
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def extract_texts_messages_array(rows):
    """Extract assistant responses when each row is itself a message list."""
    wanted_roles = {"assistant", "ai", "model"}
    out = []
    for row in rows:
        if not isinstance(row, list):
            continue
        for raw in row:
            parsed = _parse_msg(raw)
            content = parsed.get("content", "")
            if parsed.get("role", "") in wanted_roles and content:
                reply = _extract_response_only(content)
                if reply and len(reply) > 50:
                    out.append(reply)
    return out
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def extract_texts_teichai_healer(rows):
    """Extract responses from TeichAI Healer rows ('response' field, str or dict)."""
    out = []
    for row in rows:
        resp = row.get("response")
        if not resp:
            continue
        if isinstance(resp, dict):
            # Dict payloads carry the text under "content" or "text".
            resp = resp.get("content") or resp.get("text") or ""
        if resp and len(str(resp)) > 50:
            reply = _extract_response_only(str(resp))
            if reply and len(reply) > 50:
                out.append(reply)
    return out
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def _get_dataset_size(dataset_id, load_kwargs):
    """Get dataset size without loading all data.

    Returns 0 when the size cannot be determined, so callers can safely
    compare the result with `> 0`.
    """
    try:
        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        # Streaming datasets can report num_rows as None when split metadata
        # is missing; returning None would make callers' `total_rows > 0`
        # comparisons raise TypeError, so fall through to the parquet count.
        if ds.info.num_rows:
            return ds.info.num_rows
    except Exception:
        pass
    try:
        import pandas as pd

        # Count rows of the first auto-converted parquet shard.
        url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
        df = pd.read_parquet(url)
        return len(df)
    except Exception:
        return 0
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def _streaming_download_with_progress(dataset_id, load_kwargs, progress_callback=None):
    """Download dataset using streaming with progress tracking.

    progress_callback, if given, is called as callback(done, total, stage)
    with stage in {"fetching_info", "downloading"}.

    Returns:
        (rows, total_rows). Falls back to the auto-converted parquet export
        when streaming fails; re-raises the parquet error if both fail.
    """
    import pandas as pd

    total_rows = _get_dataset_size(dataset_id, load_kwargs)
    print(f"[PROGRESS] Dataset size: {total_rows} rows", flush=True)

    if total_rows > 0 and progress_callback:
        progress_callback(0, total_rows, "fetching_info")
        print(f"[PROGRESS] Initial callback: 0/{total_rows}", flush=True)

    try:
        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        rows = []
        for i, row in enumerate(tqdm(ds, desc="Downloading", unit="rows")):
            rows.append(row)
            # Progress prints stay inside the total_rows > 0 guard: the
            # percentage below divides by total_rows.
            if progress_callback and total_rows > 0:
                progress_callback(i + 1, total_rows, "downloading")
                if i % 100 == 0:
                    print(
                        f"[PROGRESS] Downloaded {i + 1}/{total_rows} ({100 * (i + 1) / total_rows:.1f}%)",
                        flush=True,
                    )
        return rows, total_rows
    except Exception as e:
        # Fall through to the parquet fallback below.
        print(f"[PROGRESS] Streaming failed: {e}", flush=True)

    try:
        url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
        df = pd.read_parquet(url)
        total = len(df)
        if progress_callback:
            progress_callback(0, total, "downloading")
        rows = []
        for i, row in enumerate(df.to_dict(orient="records")):
            rows.append(row)
            if progress_callback:
                progress_callback(i + 1, total, "downloading")
        return rows, total
    except Exception:
        # Bare `raise` (was `raise e`) preserves the original traceback.
        raise
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def _load_sample_rows(dataset_id, sample_size, load_kwargs):
    """Load just a few rows for format detection.

    Returns up to sample_size rows; [] when the dataset cannot be read.
    """
    try:
        from itertools import islice

        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        # BUG FIX: the previous version called next(iter(ds)) inside a list
        # comprehension, creating a fresh iterator on every call and thus
        # returning the *first* row sample_size times. islice over a single
        # iterator yields distinct rows and stops early on short datasets.
        rows = list(islice(iter(ds), sample_size))
        if rows:
            return rows
    except Exception:
        pass
    try:
        import pandas as pd

        url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
        df = pd.read_parquet(url)
        return df.head(sample_size).to_dict(orient="records")
    except Exception:
        return []
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def _dataset_result(texts=None, fmt=None, fmt_info=None, total_rows=0,
                    supported=False, error=None):
    """Build the uniform result dict returned by load_dataset_texts."""
    return {
        "texts": texts if texts is not None else [],
        "format": fmt,
        "format_info": fmt_info,
        "total_rows": total_rows,
        "supported": supported,
        "error": error,
    }


def _parquet_fallback(dataset_id, progress_callback=None):
    """Download the first train parquet shard directly from the HF API.

    Returns (rows, total_rows); raises on any failure so the caller can
    report both the primary and the fallback error.
    """
    import pandas as pd

    url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
    df = pd.read_parquet(url)
    total_rows = len(df)
    if progress_callback:
        progress_callback(0, total_rows, "downloading")
    rows = []
    for i, row in enumerate(df.to_dict(orient="records")):
        rows.append(row)
        if progress_callback:
            progress_callback(i + 1, total_rows, "downloading")
    return rows, total_rows


def load_dataset_texts(
    dataset_id,
    max_samples=None,
    sample_size=None,
    progress_callback=None,
    custom_format=None,
):
    """
    Load a HuggingFace dataset and extract assistant response texts.
    Returns: {
        "texts": list of extracted texts,
        "format": detected format name,
        "format_info": format info dict,
        "total_rows": total rows in dataset,
        "supported": bool,
        "error": error message if failed,
    }
    progress_callback: optional function(current, total, stage) -> None
        stage can be: "fetching_info", "downloading", "extracting"
    custom_format: optional custom format specification string
        Examples:
        - "column: response"
        - "column: prompt, column: response"
        - "pattern: user:, pattern: assistant:"
        - "user:[startuser]assistant:[startassistant]"
    """
    load_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    rows = []
    total_rows = 0

    if sample_size:
        # Cheap path: only fetch a small sample plus the dataset's row count.
        total_rows = _get_dataset_size(dataset_id, load_kwargs)
        if total_rows == 0:
            return _dataset_result(error="Dataset is empty")
        rows = _load_sample_rows(dataset_id, sample_size, load_kwargs)
    elif progress_callback:
        try:
            rows, total_rows = _streaming_download_with_progress(
                dataset_id, load_kwargs, progress_callback
            )
        except Exception as e:
            try:
                rows, total_rows = _parquet_fallback(dataset_id, progress_callback)
            except Exception as e2:
                return _dataset_result(
                    error=f"Failed to load: {e}. Parquet fallback also failed: {e2}"
                )
    else:
        try:
            ds = load_dataset(dataset_id, split="train", **load_kwargs)
            total_rows = len(ds)
            rows = list(ds)
        except Exception as e:
            try:
                rows, total_rows = _parquet_fallback(dataset_id)
            except Exception as e2:
                return _dataset_result(
                    error=f"Failed to load: {e}. Parquet fallback also failed: {e2}"
                )

    if not rows:
        return _dataset_result(error="Dataset is empty")

    # Detect on a prefix when sampling; otherwise on everything we loaded.
    detect_rows = rows[:sample_size] if sample_size else rows

    custom_format_spec = custom_format
    if custom_format_spec and check_custom_format(detect_rows, custom_format_spec):
        fmt_name = "custom"
        fmt_info = {
            "name": "Custom Format",
            "description": f"Custom format: {custom_format_spec}",
            "examples": [],
        }
    else:
        fmt_name, fmt_info = detect_format(detect_rows, sample_size=sample_size or 10)

    if fmt_name is None:
        return _dataset_result(
            total_rows=total_rows,
            error="Unknown dataset format. Supported formats: "
            + ", ".join(f["name"] for f in SUPPORTED_FORMATS.values()),
        )

    # Dispatch table: format name -> extractor over the raw rows.
    extractors = {
        "teichai_healer": extract_texts_teichai_healer,
        "teichai": extract_texts_conversations,
        "conversations": extract_texts_conversations,
        "chat": extract_texts_chat,
        "text": lambda r: extract_texts_text_field(r, "text"),
        "response": lambda r: extract_texts_text_field(r, "response")
        or extract_texts_text_field(r, "output"),
        "content": lambda r: extract_texts_text_field(r, "content"),
        "messages": extract_texts_messages_array,
        "sft": extract_texts_sft,
        "qa": extract_texts_qa,
        "combined": lambda r: (
            extract_texts_text_field(r, "output")
            or extract_texts_text_field(r, "outputs")
            or extract_texts_text_field(r, "generated")
            or extract_texts_text_field(r, "completion")
            or extract_texts_text_field(r, "combined")
            or extract_texts_text_field(r, "data")
            or extract_texts_text_field(r, "example")
        ),
        "completion": lambda r: extract_texts_text_field(r, "completion"),
        "generations": lambda r: (
            extract_texts_text_field(r, "generations")
            or extract_texts_text_field(r, "generation")
        ),
        "custom": lambda r: extract_texts_custom(r, custom_format_spec),
    }

    extractor = extractors.get(fmt_name)
    texts = extractor(rows) if extractor else []

    if max_samples and len(texts) > max_samples:
        # Fixed-seed local RNG: reproducible subsample without mutating
        # the global random state.
        texts = random.Random(42).sample(texts, max_samples)

    return _dataset_result(
        texts=texts,
        fmt=fmt_name,
        fmt_info=fmt_info,
        total_rows=total_rows,
        supported=True,
    )
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def parse_custom_format_spec(spec):
    """
    Parse custom format specification.

    Supported formats:
    - "column: <field_name>" - extract single field as text
    - "column: <user_col>, column: <assistant_col>" - extract from two columns (user/assistant)
    - "pattern: <start_marker>user<end_marker>, pattern: <start_marker>assistant<end_marker>" - use regex patterns
    - "delimiter: <delim>" - use delimiter to split columns

    Examples:
    - "column: response"
    - "column: prompt, column: response"
    - "pattern: user:, pattern: assistant:"
    - "user:[startuser]assistant:[startassistant]"

    Returns a dict with keys type/user_field/assistant_field/user_pattern/
    assistant_pattern, or None for an empty spec.
    """
    if not spec:
        return None

    # BUGFIX: `import re` used to live inside the "user:/assistant:" branch,
    # making `re` function-local; reaching the "[startuser]" branch without
    # executing that branch raised UnboundLocalError. Import once, up front.
    import re

    spec = spec.strip()
    result = {
        "type": None,
        "user_field": None,
        "assistant_field": None,
        "user_pattern": None,
        "assistant_pattern": None,
    }

    # Explicit column syntax.
    if spec.startswith("column:") or spec.startswith("col:"):
        cols_spec = spec.replace("column:", "").replace("col:", "").strip()
        if "," in cols_spec:
            parts = [p.strip() for p in cols_spec.split(",")]
            if len(parts) >= 2:
                result["type"] = "two_column"
                result["user_field"] = parts[0]
                result["assistant_field"] = parts[1]
        else:
            result["type"] = "single_column"
            result["assistant_field"] = cols_spec
        return result

    # Explicit regex syntax.
    if spec.startswith("pattern:") or spec.startswith("regex:"):
        patterns_spec = spec.replace("pattern:", "").replace("regex:", "").strip()
        if "," in patterns_spec:
            parts = [p.strip() for p in patterns_spec.split(",")]
            if len(parts) >= 2:
                result["type"] = "two_pattern"
                result["user_pattern"] = parts[0]
                result["assistant_pattern"] = parts[1]
        else:
            result["type"] = "single_pattern"
            result["assistant_pattern"] = patterns_spec
        return result

    # Free-form "user: ... assistant: ..." marker pair.
    if "user:" in spec.lower() and "assistant:" in spec.lower():
        user_match = re.search(
            r"user:\s*(\[.*?\]|(?:(?!\s+assistant:).)+)",
            spec,
            re.IGNORECASE | re.DOTALL,
        )
        assistant_match = re.search(
            r"assistant:\s*(\[.*?\]|(?:(?:\s+user:|$).)+)",
            spec,
            re.IGNORECASE | re.DOTALL,
        )

        if user_match and assistant_match:
            result["type"] = "two_pattern"
            result["user_pattern"] = user_match.group(1).strip()
            result["assistant_pattern"] = assistant_match.group(1).strip()
            return result

    # Literal sentinel markers.
    if "[startuser]" in spec and "[startassistant]" in spec:
        result["type"] = "two_pattern"
        result["user_pattern"] = re.escape("[startuser]")
        result["assistant_pattern"] = re.escape("[startassistant]")
        return result

    # Bare comma-separated column names.
    if "," in spec:
        parts = [p.strip() for p in spec.split(",")]
        if len(parts) >= 2:
            result["type"] = "two_column"
            result["user_field"] = parts[0]
            result["assistant_field"] = parts[1]
            return result

    # Default: treat the whole spec as a single column name.
    result["type"] = "single_column"
    result["assistant_field"] = spec
    return result
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def extract_texts_custom(rows, format_spec):
    """Extract assistant texts from rows according to a custom format spec."""
    parsed = parse_custom_format_spec(format_spec)
    if not parsed or not parsed.get("type"):
        return []

    texts = []

    def _keep(raw):
        # Shared filter: drop short snippets, strip CoT, re-check length.
        if raw and len(str(raw)) > 50:
            cleaned = _extract_response_only(str(raw))
            if cleaned and len(cleaned) > 50:
                texts.append(cleaned)

    kind = parsed["type"]

    if kind in ("single_column", "two_column"):
        # Either way only the assistant column contributes text.
        col = parsed["assistant_field"]
        for row in rows:
            _keep(row.get(col, ""))
        return texts

    if kind == "single_pattern":
        pattern = parsed.get("assistant_pattern")
        if pattern:
            try:
                rx = re.compile(pattern, re.DOTALL | re.IGNORECASE)
                for row in rows:
                    hit = rx.search(str(row))
                    if hit:
                        _keep(hit.group(1) if hit.groups() else hit.group(0))
            except re.error:
                pass
        return texts

    if kind == "two_pattern":
        user_pattern = parsed.get("user_pattern")
        assistant_pattern = parsed.get("assistant_pattern")
        if assistant_pattern:
            try:
                # The user pattern is compiled only to validate it; an
                # invalid user regex skips the whole extraction, as before.
                if user_pattern:
                    re.compile(user_pattern, re.DOTALL | re.IGNORECASE)
                rx = re.compile(assistant_pattern, re.DOTALL | re.IGNORECASE)
                for row in rows:
                    hit = rx.search(str(row))
                    if hit:
                        _keep(hit.group(1) if hit.groups() else hit.group(0))
            except re.error:
                pass

    return texts
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
def check_custom_format(rows, format_spec):
    """Return True if the custom format spec matches the dataset's first row."""
    parsed = parse_custom_format_spec(format_spec)
    if not parsed or not parsed.get("type") or not rows:
        return False

    first = rows[0]
    kind = parsed["type"]

    # Column specs match when the assistant column exists in the row.
    if kind in ("single_column", "two_column"):
        return parsed.get("assistant_field") in first

    # Pattern specs match when the assistant regex hits the stringified row.
    if kind in ("single_pattern", "two_pattern"):
        pattern = parsed.get("assistant_pattern")
        if pattern:
            try:
                rx = re.compile(pattern, re.DOTALL | re.IGNORECASE)
            except re.error:
                return False
            return rx.search(str(first)) is not None

    return False
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
def get_supported_formats():
    """Return the SUPPORTED_FORMATS registry (format name -> info dict)."""
    return SUPPORTED_FORMATS
|
evaluate_dataset.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AIFinder Dataset Evaluator with Server
|
| 3 |
+
Runs the Flask server, then allows interactive dataset input.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
import argparse
|
| 10 |
+
import random
|
| 11 |
+
import threading
|
| 12 |
+
import requests
|
| 13 |
+
from collections import defaultdict
|
| 14 |
+
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
from config import MODEL_DIR
|
| 19 |
+
from inference import AIFinder
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 23 |
+
SERVER_URL = "http://localhost:7860"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def start_server():
    """Start the Flask server (blocking call; intended for a daemon thread).

    Loads the models before serving so the first request doesn't pay the
    model-load cost.
    """
    # chdir to this file's directory so relative paths (models/, templates/)
    # resolve regardless of the caller's working directory.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    # NOTE(review): imported inside the function — presumably to defer the
    # app import until server start; confirm there is no circular import.
    from app import app, load_models

    load_models()
    print("Server started on http://localhost:7860")
    # use_reloader=False: the reloader would re-exec the process, which does
    # not work from a background thread.
    app.run(host="0.0.0.0", port=7860, debug=False, use_reloader=False)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def wait_for_server(timeout=30):
    """Poll the status endpoint until the server answers or `timeout` seconds pass."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            # Any 200 from /api/status means the server is up.
            if requests.get(f"{SERVER_URL}/api/status", timeout=2).status_code == 200:
                return True
        except requests.exceptions.RequestException:
            pass  # Not up yet — retry after a short pause.
        time.sleep(1)
    return False
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _parse_msg(msg):
|
| 51 |
+
"""Parse a message that may be a dict or a JSON string."""
|
| 52 |
+
if isinstance(msg, dict):
|
| 53 |
+
return msg
|
| 54 |
+
if isinstance(msg, str):
|
| 55 |
+
try:
|
| 56 |
+
import json
|
| 57 |
+
|
| 58 |
+
parsed = json.loads(msg)
|
| 59 |
+
if isinstance(parsed, dict):
|
| 60 |
+
return parsed
|
| 61 |
+
except (ValueError, Exception):
|
| 62 |
+
pass
|
| 63 |
+
return {}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _extract_response_only(content):
|
| 67 |
+
"""Extract only the final response, stripping CoT blocks."""
|
| 68 |
+
import re
|
| 69 |
+
|
| 70 |
+
if not content:
|
| 71 |
+
return ""
|
| 72 |
+
think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL)
|
| 73 |
+
if think_match:
|
| 74 |
+
response = think_match.group(1).strip()
|
| 75 |
+
if response:
|
| 76 |
+
return response
|
| 77 |
+
return content
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def extract_texts_from_dataset(dataset_id, max_samples=None):
    """Extract assistant response texts from a HuggingFace dataset.

    Tries `datasets.load_dataset` first, then falls back to downloading the
    first train parquet shard via the HF API. Keeps assistant/gpt/model
    messages whose CoT-stripped text exceeds 50 chars. When `max_samples`
    is set, returns a reproducible random subsample of that size.
    """
    print(f"\nLoading dataset: {dataset_id}")

    load_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    rows = []

    try:
        ds = load_dataset(dataset_id, split="train", **load_kwargs)
        rows = list(ds)
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        try:
            import pandas as pd

            url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
            df = pd.read_parquet(url)
            rows = df.to_dict(orient="records")
        except Exception as e2:
            print(f"Parquet fallback also failed: {e2}")
            return []

    texts = []
    for row in rows:
        convos = row.get("conversations") or row.get("messages") or []
        for msg in convos:
            msg = _parse_msg(msg)
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role in ("assistant", "gpt", "model") and content:
                response_only = _extract_response_only(content)
                if response_only and len(response_only) > 50:
                    texts.append(response_only)

    if max_samples and len(texts) > max_samples:
        # Fixed-seed local RNG: same subsample as the old
        # random.seed(42)/random.sample pair, without mutating global state.
        texts = random.Random(42).sample(texts, max_samples)

    return texts
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def evaluate_dataset(texts):
    """POST each text to /api/classify and tally predictions per provider."""
    results = {
        "total": len(texts),
        "provider_counts": defaultdict(int),
        "confidences": defaultdict(list),
    }

    for text in tqdm(texts, desc="Evaluating"):
        try:
            resp = requests.post(
                f"{SERVER_URL}/api/classify",
                json={"text": text, "top_n": 5},
                timeout=30,
            )
            if resp.status_code != 200:
                continue
            payload = resp.json()
            provider = payload.get("provider")
            # Server reports confidence as a percentage; store as fraction.
            conf = payload.get("confidence", 0) / 100.0
            if provider:
                results["provider_counts"][provider] += 1
                results["confidences"][provider].append(conf)
        except Exception as e:
            # Best-effort: a failed request skips one sample, not the run.
            print(f"Error: {e}")
            continue

    return results
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def print_results(results):
    """Pretty-print aggregated evaluation results to stdout."""
    total = results["total"]
    sep = "=" * 60
    print("\n" + sep)
    print(f"EVALUATION RESULTS ({total} samples)")
    print(sep)

    # Distribution of predicted providers, most frequent first.
    print("\n--- Predicted Provider Distribution ---")
    ranked_counts = sorted(results["provider_counts"].items(), key=lambda kv: -kv[1])
    for provider, count in ranked_counts:
        pct = (count / total) * 100
        confs = results["confidences"][provider]
        avg_conf = sum(confs) / len(confs)
        print(
            f" {provider}: {count} ({pct:.1f}%) - Avg confidence: {avg_conf * 100:.1f}%"
        )

    # Rank providers by avg confidence weighted by prediction count.
    if results["confidences"]:
        print("\n--- Top Providers (by cumulative confidence) ---")
        provider_scores = {
            provider: (sum(confs) / len(confs)) * results["provider_counts"][provider]
            for provider, confs in results["confidences"].items()
            if confs
        }
        top3 = sorted(provider_scores.items(), key=lambda kv: -kv[1])[:3]
        for provider, score in top3:
            print(f" {provider}: {score:.2f}")

    print("\n" + sep)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def main():
    """CLI entry point: boot the server, then evaluate datasets interactively."""
    parser = argparse.ArgumentParser(
        description="AIFinder Dataset Evaluator with Server"
    )
    parser.add_argument(
        "--max-samples", type=int, default=None, help="Max samples to test"
    )
    args = parser.parse_args()

    print("Starting AIFinder server...")
    threading.Thread(target=start_server, daemon=True).start()

    print("Waiting for server...")
    if not wait_for_server():
        print("Server failed to start!")
        sys.exit(1)

    banner = "=" * 60
    print("\n" + banner)
    print("AIFinder Server Ready!")
    print(banner)
    print(f"Server running at: {SERVER_URL}")
    print("Enter a HuggingFace dataset ID to evaluate.")
    print("Examples: ianncity/Hunter-Alpha-SFT-300000x")
    print("Type 'quit' or 'exit' to stop.")
    print(banner + "\n")

    # Interactive loop: one dataset evaluation per prompt.
    while True:
        try:
            dataset_id = input("Dataset ID: ").strip()

            if dataset_id.lower() in ("quit", "exit", "q"):
                print("Goodbye!")
                break
            if not dataset_id:
                continue

            texts = extract_texts_from_dataset(dataset_id, args.max_samples)
            if not texts:
                print("No valid texts found in dataset.")
                continue

            print(f"Testing {len(texts)} responses...")
            print_results(evaluate_dataset(texts))

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            # Keep the REPL alive on any per-dataset failure.
            print(f"Error: {e}")
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
# Script entry point: start the server and the interactive evaluation loop.
if __name__ == "__main__":
    main()
|
features.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
AIFinder Feature Extraction
|
| 3 |
-
TF-IDF and stylometric features for AI model detection.
|
| 4 |
"""
|
| 5 |
|
| 6 |
import re
|
|
@@ -12,25 +12,198 @@ from sklearn.preprocessing import MaxAbsScaler
|
|
| 12 |
|
| 13 |
from config import TFIDF_WORD_PARAMS, TFIDF_CHAR_PARAMS
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def strip_cot(text):
|
| 17 |
-
|
| 18 |
-
return text.strip()
|
| 19 |
|
| 20 |
|
| 21 |
def strip_markdown(text):
|
| 22 |
-
text =
|
| 23 |
-
text =
|
| 24 |
-
text =
|
| 25 |
-
text =
|
| 26 |
-
text =
|
| 27 |
-
text =
|
| 28 |
-
text =
|
| 29 |
-
text =
|
| 30 |
-
text =
|
| 31 |
-
text =
|
| 32 |
-
text =
|
| 33 |
-
text =
|
| 34 |
return text.strip()
|
| 35 |
|
| 36 |
|
|
@@ -39,18 +212,14 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
|
|
| 39 |
return self
|
| 40 |
|
| 41 |
def transform(self, X):
|
| 42 |
-
|
| 43 |
-
for text in X:
|
| 44 |
-
features.append(self._extract(text))
|
| 45 |
-
return csr_matrix(np.array(features, dtype=np.float32))
|
| 46 |
|
| 47 |
def _extract(self, text):
|
| 48 |
-
words = text.split()
|
| 49 |
n_chars = max(len(text), 1)
|
|
|
|
| 50 |
n_words = max(len(words), 1)
|
| 51 |
|
| 52 |
-
sentences = re.split(r"[.!?]+", text)
|
| 53 |
-
sentences = [s.strip() for s in sentences if s.strip()]
|
| 54 |
n_sentences = max(len(sentences), 1)
|
| 55 |
|
| 56 |
paragraphs = text.split("\n\n")
|
|
@@ -58,17 +227,21 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
|
|
| 58 |
n_paragraphs = len(non_empty_paras)
|
| 59 |
|
| 60 |
lines = text.split("\n")
|
| 61 |
-
non_empty_lines = [
|
| 62 |
n_lines = max(len(non_empty_lines), 1)
|
| 63 |
|
| 64 |
-
# === Word-level stats ===
|
| 65 |
word_lens = [len(w) for w in words]
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
avg_sent_len = n_words / n_sentences
|
| 70 |
|
| 71 |
-
# === Punctuation density ===
|
| 72 |
n_commas = text.count(",") / n_chars
|
| 73 |
n_semicolons = text.count(";") / n_chars
|
| 74 |
n_colons = text.count(":") / n_chars
|
|
@@ -84,16 +257,14 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
|
|
| 84 |
comma_period_ratio = n_commas / (n_period + 0.001)
|
| 85 |
excl_question_ratio = n_exclaim / (n_question + 0.001)
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
n_tables = len(re.findall(r"\|.*\|", text)) / n_sentences
|
| 95 |
|
| 96 |
-
# === Whitespace & structure ===
|
| 97 |
newline_density = text.count("\n") / n_chars
|
| 98 |
double_newline_ratio = text.count("\n\n") / (text.count("\n") + 1)
|
| 99 |
uppercase_ratio = sum(1 for c in text if c.isupper()) / n_chars
|
|
@@ -103,59 +274,40 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
|
|
| 103 |
unique_chars = len(set(text)) / n_chars
|
| 104 |
unique_chars_ratio = len(set(text.lower())) / n_chars
|
| 105 |
|
| 106 |
-
|
| 107 |
-
sent_lens = [len(s.split()) for s in sentences]
|
| 108 |
-
sent_len_std = np.std(sent_lens) if len(sent_lens) > 1 else 0
|
| 109 |
sent_len_max = max(sent_lens) if sent_lens else 0
|
| 110 |
sent_len_min = min(sent_lens) if sent_lens else 0
|
| 111 |
-
sent_len_median = np.median(sent_lens) if sent_lens else 0
|
| 112 |
sent_len_range = sent_len_max - sent_len_min
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
has_url = 1.0 if re.search(r"https?://", text) else 0.0
|
| 119 |
|
| 120 |
-
# === Pronoun and person features ===
|
| 121 |
words_lower = [w.lower().strip(".,!?;:'\"()[]{}") for w in words]
|
| 122 |
-
|
| 123 |
-
first_person = {
|
| 124 |
-
"i",
|
| 125 |
-
"me",
|
| 126 |
-
"my",
|
| 127 |
-
"mine",
|
| 128 |
-
"myself",
|
| 129 |
-
"we",
|
| 130 |
-
"us",
|
| 131 |
-
"our",
|
| 132 |
-
"ours",
|
| 133 |
-
"ourselves",
|
| 134 |
-
}
|
| 135 |
-
second_person = {"you", "your", "yours", "yourself", "yourselves"}
|
| 136 |
-
third_person = {"he", "she", "it", "they", "them", "his", "her", "its", "their"}
|
| 137 |
-
|
| 138 |
-
first_person_ratio = sum(1 for w in words_lower if w in first_person) / n_words
|
| 139 |
second_person_ratio = (
|
| 140 |
-
sum(1 for w in words_lower if w in
|
| 141 |
)
|
| 142 |
-
third_person_ratio = sum(1 for w in words_lower if w in
|
| 143 |
|
| 144 |
-
# === Vocabulary richness ===
|
| 145 |
unique_words = len(set(words_lower))
|
| 146 |
-
ttr = unique_words / n_words if n_words > 0 else 0
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
contraction_count = len(
|
| 151 |
-
contraction_ratio = contraction_count / n_words if n_words > 0 else 0
|
| 152 |
|
| 153 |
-
# === Sentence starters ===
|
| 154 |
sentences_starters = [
|
| 155 |
s.split()[0].lower() if s.split() else "" for s in sentences
|
| 156 |
]
|
| 157 |
starter_vocab = (
|
| 158 |
-
len(set(sentences_starters)) / n_sentences if n_sentences > 0 else 0
|
| 159 |
)
|
| 160 |
|
| 161 |
and_starts = sum(1 for s in sentences_starters if s == "and") / n_sentences
|
|
@@ -170,281 +322,119 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
|
|
| 170 |
/ n_sentences
|
| 171 |
)
|
| 172 |
|
| 173 |
-
# === Word length distributions ===
|
| 174 |
short_word_ratio = sum(1 for w in words_lower if len(w) <= 2) / n_words
|
| 175 |
medium_word_ratio = sum(1 for w in words_lower if 3 <= len(w) <= 6) / n_words
|
| 176 |
long_word_ratio = sum(1 for w in words_lower if len(w) >= 7) / n_words
|
| 177 |
very_long_word_ratio = sum(1 for w in words_lower if len(w) >= 10) / n_words
|
| 178 |
|
| 179 |
-
# === Paragraph stats ===
|
| 180 |
para_lens = (
|
| 181 |
[len(p.split()) for p in non_empty_paras] if non_empty_paras else [0]
|
| 182 |
)
|
| 183 |
avg_para_len = np.mean(para_lens)
|
| 184 |
-
para_len_std = np.std(para_lens) if len(para_lens) > 1 else 0
|
| 185 |
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
"
|
| 194 |
-
|
| 195 |
-
"because",
|
| 196 |
-
"although",
|
| 197 |
-
"while",
|
| 198 |
-
"if",
|
| 199 |
-
"when",
|
| 200 |
-
"where",
|
| 201 |
-
}
|
| 202 |
-
discourse = {
|
| 203 |
-
"however",
|
| 204 |
-
"therefore",
|
| 205 |
-
"moreover",
|
| 206 |
-
"furthermore",
|
| 207 |
-
"nevertheless",
|
| 208 |
-
"consequently",
|
| 209 |
-
"thus",
|
| 210 |
-
"hence",
|
| 211 |
-
}
|
| 212 |
-
hedging = {
|
| 213 |
-
"perhaps",
|
| 214 |
-
"maybe",
|
| 215 |
-
"might",
|
| 216 |
-
"could",
|
| 217 |
-
"possibly",
|
| 218 |
-
"seemingly",
|
| 219 |
-
"apparently",
|
| 220 |
-
"arguably",
|
| 221 |
-
"potentially",
|
| 222 |
-
}
|
| 223 |
-
certainty = {
|
| 224 |
-
"definitely",
|
| 225 |
-
"certainly",
|
| 226 |
-
"absolutely",
|
| 227 |
-
"clearly",
|
| 228 |
-
"obviously",
|
| 229 |
-
"undoubtedly",
|
| 230 |
-
"indeed",
|
| 231 |
-
"surely",
|
| 232 |
-
}
|
| 233 |
-
transition = {
|
| 234 |
-
"additionally",
|
| 235 |
-
"meanwhile",
|
| 236 |
-
"subsequently",
|
| 237 |
-
"alternatively",
|
| 238 |
-
"specifically",
|
| 239 |
-
"notably",
|
| 240 |
-
"importantly",
|
| 241 |
-
"essentially",
|
| 242 |
-
}
|
| 243 |
-
|
| 244 |
-
conjunction_ratio = sum(1 for w in words_lower if w in conjunctions) / n_words
|
| 245 |
-
discourse_ratio = sum(1 for w in words_lower if w in discourse) / n_words
|
| 246 |
-
hedging_ratio = sum(1 for w in words_lower if w in hedging) / n_words
|
| 247 |
-
certainty_ratio = sum(1 for w in words_lower if w in certainty) / n_words
|
| 248 |
-
transition_ratio = sum(1 for w in words_lower if w in transition) / n_words
|
| 249 |
|
| 250 |
-
# === Question patterns ===
|
| 251 |
question_starts = sum(
|
| 252 |
-
1
|
| 253 |
-
for s in sentences
|
| 254 |
-
if s
|
| 255 |
-
and s.strip()
|
| 256 |
-
.lower()
|
| 257 |
-
.startswith(("who", "what", "when", "where", "why", "how"))
|
| 258 |
)
|
| 259 |
|
| 260 |
-
# === List features ===
|
| 261 |
has_list = 1.0 if n_bullet > 0 or n_numbered > 0 else 0.0
|
| 262 |
list_items = n_bullet + n_numbered
|
| 263 |
|
| 264 |
-
|
| 265 |
-
emoji_count = len(re.findall(r"[\U00010000-\U0010ffff]", text))
|
| 266 |
has_emoji = 1.0 if emoji_count > 0 else 0.0
|
| 267 |
|
| 268 |
-
# === Specific style markers ===
|
| 269 |
-
# ALL CAPS words (emphasis style)
|
| 270 |
all_caps_words = sum(
|
| 271 |
1 for w in words if len(w) > 1 and w.isupper() and w.isalpha()
|
| 272 |
)
|
| 273 |
all_caps_ratio = all_caps_words / n_words
|
| 274 |
|
| 275 |
-
|
| 276 |
-
paren_count = len(re.findall(r"\([^)]+\)", text))
|
| 277 |
paren_ratio = paren_count / n_sentences
|
| 278 |
|
| 279 |
-
# Rhetorical questions (sentences ending with ?)
|
| 280 |
rhetorical_q = sum(1 for s in text.split("\n") if s.strip().endswith("?"))
|
| 281 |
rhetorical_ratio = rhetorical_q / n_sentences
|
| 282 |
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
"okay",
|
| 286 |
-
"ok",
|
| 287 |
-
"hey",
|
| 288 |
-
"hi",
|
| 289 |
-
"cool",
|
| 290 |
-
"awesome",
|
| 291 |
-
"wow",
|
| 292 |
-
"basically",
|
| 293 |
-
"actually",
|
| 294 |
-
"literally",
|
| 295 |
-
"right",
|
| 296 |
-
"yeah",
|
| 297 |
-
}
|
| 298 |
-
casual_ratio = sum(1 for w in words_lower if w in casual_markers) / n_words
|
| 299 |
-
|
| 300 |
-
# Formal markers
|
| 301 |
-
formal_markers = {
|
| 302 |
-
"regarding",
|
| 303 |
-
"concerning",
|
| 304 |
-
"pertaining",
|
| 305 |
-
"aforementioned",
|
| 306 |
-
"respectively",
|
| 307 |
-
"accordingly",
|
| 308 |
-
"henceforth",
|
| 309 |
-
"whereby",
|
| 310 |
-
"notwithstanding",
|
| 311 |
-
"pursuant",
|
| 312 |
-
}
|
| 313 |
-
formal_ratio = sum(1 for w in words_lower if w in formal_markers) / n_words
|
| 314 |
|
| 315 |
-
|
| 316 |
-
chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
|
| 317 |
has_chinese = 1.0 if chinese_chars > 0 else 0.0
|
| 318 |
chinese_ratio = chinese_chars / n_chars
|
| 319 |
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
1.0
|
| 323 |
-
if re.search(
|
| 324 |
-
r"\b(I'm|I am)\s+(an?\s+)?(AI|language model|assistant|chatbot)\b",
|
| 325 |
-
text,
|
| 326 |
-
re.IGNORECASE,
|
| 327 |
-
)
|
| 328 |
-
else 0.0
|
| 329 |
-
)
|
| 330 |
-
has_provider_mention = (
|
| 331 |
-
1.0
|
| 332 |
-
if re.search(
|
| 333 |
-
r"\b(Claude|Anthropic|GPT|OpenAI|ChatGPT|Gemini|Google|Bard|Grok|xAI"
|
| 334 |
-
r"|DeepSeek|Kimi|Moonshot|Mistral|MiniMax|Zhipu|GLM|深度求索)\b",
|
| 335 |
-
text,
|
| 336 |
-
re.IGNORECASE,
|
| 337 |
-
)
|
| 338 |
-
else 0.0
|
| 339 |
-
)
|
| 340 |
|
| 341 |
-
# Response ending patterns
|
| 342 |
ends_with_question = 1.0 if text.rstrip().endswith("?") else 0.0
|
| 343 |
-
has_closing_offer = (
|
| 344 |
-
1.0
|
| 345 |
-
if re.search(
|
| 346 |
-
r"(let me know|feel free|happy to help|don't hesitate|hope this helps)",
|
| 347 |
-
text,
|
| 348 |
-
re.IGNORECASE,
|
| 349 |
-
)
|
| 350 |
-
else 0.0
|
| 351 |
-
)
|
| 352 |
|
| 353 |
-
# Sentence complexity (approximation via commas per sentence)
|
| 354 |
commas_per_sentence = text.count(",") / n_sentences
|
| 355 |
|
| 356 |
-
# Line-level features
|
| 357 |
avg_line_len = (
|
| 358 |
-
np.mean([len(
|
| 359 |
)
|
| 360 |
short_lines_ratio = (
|
| 361 |
-
sum(1 for
|
| 362 |
)
|
| 363 |
|
| 364 |
-
|
| 365 |
-
cap_words = len(re.findall(r"\b[A-Z][a-z]+\b", text))
|
| 366 |
cap_word_ratio = cap_words / n_words
|
| 367 |
|
| 368 |
-
|
| 369 |
-
four_word_phrases = len(re.findall(r"\b\w+\s+\w+\s+\w+\s+\w+\b", text))
|
| 370 |
phrase_ratio = four_word_phrases / n_sentences
|
| 371 |
|
| 372 |
-
|
| 373 |
-
sent_boundaries = len(re.findall(r"[.!?]\s+[A-Z]", text))
|
| 374 |
sent_boundary_ratio = sent_boundaries / n_sentences
|
| 375 |
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
)
|
| 380 |
-
has_arrow = 1.0 if "→" in text or "←" in text or "➡" in text else 0.0
|
| 381 |
-
has_star = 1.0 if "⭐" in text or "★" in text or "☆" in text else 0.0
|
| 382 |
-
special_unicode = len(re.findall(r"[^\x00-\x7F]", text)) / n_chars
|
| 383 |
|
| 384 |
-
|
| 385 |
-
colon_definitions = len(re.findall(r"\b\w+:\s+\w+", text)) / n_sentences
|
| 386 |
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
single_quote_pairs = len(re.findall(r"'[^']*'", text)) / n_sentences
|
| 390 |
|
| 391 |
-
|
| 392 |
-
greeting_patterns = len(
|
| 393 |
-
re.findall(
|
| 394 |
-
r"\b(hi|hello|hey|hiya|greetings|howdy|yo)\b", text, re.IGNORECASE
|
| 395 |
-
)
|
| 396 |
-
)
|
| 397 |
greeting_ratio = greeting_patterns / n_sentences
|
| 398 |
|
| 399 |
-
# Response length categories
|
| 400 |
is_short = 1.0 if n_words < 100 else 0.0
|
| 401 |
is_medium = 1.0 if 100 <= n_words < 500 else 0.0
|
| 402 |
is_long = 1.0 if n_words >= 500 else 0.0
|
| 403 |
|
| 404 |
-
# Exclamation usage
|
| 405 |
excl_sentences = sum(1 for s in sentences if s.strip().endswith("!"))
|
| 406 |
excl_sentence_ratio = excl_sentences / n_sentences
|
| 407 |
|
| 408 |
-
|
| 409 |
-
question_lines = [l for l in non_empty_lines if l.strip().endswith("?")]
|
| 410 |
question_line_ratio = len(question_lines) / n_lines if n_lines > 0 else 0.0
|
| 411 |
|
| 412 |
-
|
| 413 |
-
conversational_phrases = len(
|
| 414 |
-
re.findall(
|
| 415 |
-
r"\b(great|perfect|sure|definitely|certainly|absolutely|of course"
|
| 416 |
-
r"|no problem|sounds good|got it|understood|okay|alright)\b",
|
| 417 |
-
text,
|
| 418 |
-
re.IGNORECASE,
|
| 419 |
-
)
|
| 420 |
-
)
|
| 421 |
conv_phrase_ratio = conversational_phrases / n_words
|
| 422 |
|
| 423 |
-
|
| 424 |
-
helpful_phrases = len(
|
| 425 |
-
re.findall(
|
| 426 |
-
r"\b(let me know|feel free|happy to|glad to|happy to help"
|
| 427 |
-
r"|don't hesitate|let me know if|please let me|reach out)\b",
|
| 428 |
-
text,
|
| 429 |
-
re.IGNORECASE,
|
| 430 |
-
)
|
| 431 |
-
)
|
| 432 |
helpful_ratio = helpful_phrases / n_sentences
|
| 433 |
|
| 434 |
return [
|
| 435 |
-
# Basic word stats (0-3)
|
| 436 |
avg_word_len,
|
| 437 |
word_len_std,
|
| 438 |
median_word_len,
|
| 439 |
avg_sent_len,
|
| 440 |
-
# Sentence stats (4-9)
|
| 441 |
sent_len_std,
|
| 442 |
sent_len_max,
|
| 443 |
sent_len_min,
|
| 444 |
sent_len_median,
|
| 445 |
sent_len_range,
|
| 446 |
commas_per_sentence,
|
| 447 |
-
# Punctuation density (10-22)
|
| 448 |
n_commas,
|
| 449 |
n_semicolons,
|
| 450 |
n_colons,
|
|
@@ -458,7 +448,6 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
|
|
| 458 |
comma_colon_ratio,
|
| 459 |
comma_period_ratio,
|
| 460 |
excl_question_ratio,
|
| 461 |
-
# Markdown features (23-30)
|
| 462 |
n_headers,
|
| 463 |
n_bold,
|
| 464 |
n_code_blocks,
|
|
@@ -467,7 +456,6 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
|
|
| 467 |
n_numbered,
|
| 468 |
n_tables,
|
| 469 |
has_list,
|
| 470 |
-
# Structure (31-40)
|
| 471 |
newline_density,
|
| 472 |
double_newline_ratio,
|
| 473 |
uppercase_ratio,
|
|
@@ -478,47 +466,37 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
|
|
| 478 |
list_items,
|
| 479 |
n_paragraphs,
|
| 480 |
n_lines / n_sentences,
|
| 481 |
-
# Sentence level (41-44)
|
| 482 |
has_think,
|
| 483 |
has_xml,
|
| 484 |
has_hr,
|
| 485 |
has_url,
|
| 486 |
-
# Pronoun features (45-47)
|
| 487 |
first_person_ratio,
|
| 488 |
second_person_ratio,
|
| 489 |
third_person_ratio,
|
| 490 |
-
# Vocabulary (48-52)
|
| 491 |
ttr,
|
| 492 |
hapax_ratio,
|
| 493 |
contraction_ratio,
|
| 494 |
short_word_ratio,
|
| 495 |
medium_word_ratio,
|
| 496 |
-
# Word length distributions (53-54)
|
| 497 |
long_word_ratio,
|
| 498 |
very_long_word_ratio,
|
| 499 |
-
# Sentence starters (55-60)
|
| 500 |
starter_vocab,
|
| 501 |
and_starts,
|
| 502 |
but_starts,
|
| 503 |
so_starts,
|
| 504 |
the_starts,
|
| 505 |
it_starts,
|
| 506 |
-
# Paragraph stats (61-62)
|
| 507 |
avg_para_len,
|
| 508 |
para_len_std,
|
| 509 |
-
# Discourse markers (63-67)
|
| 510 |
conjunction_ratio,
|
| 511 |
discourse_ratio,
|
| 512 |
hedging_ratio,
|
| 513 |
certainty_ratio,
|
| 514 |
transition_ratio,
|
| 515 |
-
# Questions (68)
|
| 516 |
question_starts / n_sentences if n_sentences > 0 else 0,
|
| 517 |
-
# Emoji/special (69-71)
|
| 518 |
emoji_count,
|
| 519 |
has_emoji,
|
| 520 |
special_unicode,
|
| 521 |
-
# Style markers (72-79)
|
| 522 |
all_caps_ratio,
|
| 523 |
paren_ratio,
|
| 524 |
rhetorical_ratio,
|
|
@@ -527,25 +505,21 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
|
|
| 527 |
has_chinese,
|
| 528 |
chinese_ratio,
|
| 529 |
has_self_id_ai,
|
| 530 |
-
# Provider mention & response patterns (80-83)
|
| 531 |
has_provider_mention,
|
| 532 |
ends_with_question,
|
| 533 |
has_closing_offer,
|
| 534 |
has_checkmark,
|
| 535 |
-
# More structure (84-89)
|
| 536 |
has_arrow,
|
| 537 |
has_star,
|
| 538 |
avg_line_len,
|
| 539 |
short_lines_ratio,
|
| 540 |
cap_word_ratio,
|
| 541 |
phrase_ratio,
|
| 542 |
-
# Final features (90-94)
|
| 543 |
sent_boundary_ratio,
|
| 544 |
colon_definitions,
|
| 545 |
double_quote_pairs,
|
| 546 |
single_quote_pairs,
|
| 547 |
i_starts,
|
| 548 |
-
# New features (95-102)
|
| 549 |
greeting_ratio,
|
| 550 |
is_short,
|
| 551 |
is_medium,
|
|
@@ -557,6 +531,32 @@ class StylometricFeatures(BaseEstimator, TransformerMixin):
|
|
| 557 |
]
|
| 558 |
|
| 559 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
class FeaturePipeline:
|
| 561 |
def __init__(self, use_tfidf=True):
|
| 562 |
word_params = dict(TFIDF_WORD_PARAMS)
|
|
@@ -577,7 +577,6 @@ class FeaturePipeline:
|
|
| 577 |
)
|
| 578 |
|
| 579 |
def _clean_for_tfidf(self, text):
|
| 580 |
-
"""Strip CoT and markdown for TF-IDF (remove formatting artifacts, keep content)."""
|
| 581 |
return strip_markdown(strip_cot(text))
|
| 582 |
|
| 583 |
def fit_transform(self, texts):
|
|
@@ -585,8 +584,8 @@ class FeaturePipeline:
|
|
| 585 |
|
| 586 |
print(f" Input: {len(texts)} texts", flush=True)
|
| 587 |
|
| 588 |
-
|
| 589 |
-
|
| 590 |
|
| 591 |
use_word_tfidf = (
|
| 592 |
self.word_tfidf.max_features is not None
|
|
@@ -613,7 +612,7 @@ class FeaturePipeline:
|
|
| 613 |
char_features = csr_matrix((len(texts), 0), dtype=np.float32)
|
| 614 |
|
| 615 |
t0 = time.time()
|
| 616 |
-
stylo_features = self.stylo.
|
| 617 |
print(
|
| 618 |
f" stylometric: {stylo_features.shape[1]} features ({time.time() - t0:.1f}s)",
|
| 619 |
flush=True,
|
|
@@ -625,8 +624,8 @@ class FeaturePipeline:
|
|
| 625 |
return combined
|
| 626 |
|
| 627 |
def transform(self, texts):
|
| 628 |
-
|
| 629 |
-
|
| 630 |
|
| 631 |
use_word_tfidf = (
|
| 632 |
self.word_tfidf.max_features is not None
|
|
@@ -642,6 +641,6 @@ class FeaturePipeline:
|
|
| 642 |
else:
|
| 643 |
char_features = csr_matrix((len(texts), 0), dtype=np.float32)
|
| 644 |
|
| 645 |
-
stylo_features = self.stylo.transform(
|
| 646 |
combined = hstack([word_features, char_features, stylo_features])
|
| 647 |
return self.scaler.transform(combined)
|
|
|
|
| 1 |
"""
|
| 2 |
AIFinder Feature Extraction
|
| 3 |
+
Optimized TF-IDF and stylometric features for AI model detection.
|
| 4 |
"""
|
| 5 |
|
| 6 |
import re
|
|
|
|
| 12 |
|
| 13 |
from config import TFIDF_WORD_PARAMS, TFIDF_CHAR_PARAMS
|
| 14 |
|
| 15 |
+
_RE_COMPILED = {
|
| 16 |
+
"cot": re.compile(r"<think(?:ing)?>.*?</think(?:ing)?>", re.DOTALL),
|
| 17 |
+
"code_block": re.compile(r"```[\s\S]*?```"),
|
| 18 |
+
"inline_code": re.compile(r"`[^`]+`"),
|
| 19 |
+
"bold": re.compile(r"\*\*([^*]+)\*\*"),
|
| 20 |
+
"italic_ast": re.compile(r"\*([^*]+)\*"),
|
| 21 |
+
"italic_under": re.compile(r"__([^_]+)__"),
|
| 22 |
+
"under": re.compile(r"_([^_]+)_"),
|
| 23 |
+
"header": re.compile(r"^#{1,6}\s+", re.MULTILINE),
|
| 24 |
+
"bullet": re.compile(r"^[\s]*[-*+]\s+", re.MULTILINE),
|
| 25 |
+
"numbered": re.compile(r"^\s*\d+[.)]\s+", re.MULTILINE),
|
| 26 |
+
"link": re.compile(r"\[([^\]]+)\]\([^)]+\)"),
|
| 27 |
+
"quote": re.compile(r"^>.*$", re.MULTILINE),
|
| 28 |
+
"hr": re.compile(r"^---+$", re.MULTILINE),
|
| 29 |
+
"think_tag": re.compile(r"<think>"),
|
| 30 |
+
"xml_tag": re.compile(r"<[^>]+>"),
|
| 31 |
+
"url": re.compile(r"https?://"),
|
| 32 |
+
"contraction": re.compile(r"\b\w+'\w+\b"),
|
| 33 |
+
"markdown_header": re.compile(r"^#{1,6}\s", re.MULTILINE),
|
| 34 |
+
"markdown_bold": re.compile(r"\*\*.*?\*\*"),
|
| 35 |
+
"markdown_code_block": re.compile(r"```"),
|
| 36 |
+
"markdown_inline_code": re.compile(r"`[^`]+`"),
|
| 37 |
+
"markdown_bullet": re.compile(r"^[\s]*[-*+]\s", re.MULTILINE),
|
| 38 |
+
"markdown_numbered": re.compile(r"^\s*\d+[.)]\s", re.MULTILINE),
|
| 39 |
+
"markdown_table": re.compile(r"\|.*\|"),
|
| 40 |
+
"question_start": re.compile(
|
| 41 |
+
r"^(who|what|when|where|why|how)\b", re.IGNORECASE | re.MULTILINE
|
| 42 |
+
),
|
| 43 |
+
"emoji": re.compile(r"[\U00010000-\U0010ffff]"),
|
| 44 |
+
"chinese": re.compile(r"[\u4e00-\u9fff]"),
|
| 45 |
+
"all_caps": re.compile(r"\b[A-Z][a-z]+\b"),
|
| 46 |
+
"four_word": re.compile(r"\b\w+\s+\w+\s+\w+\s+\w+\b"),
|
| 47 |
+
"sent_boundary": re.compile(r"[.!?]\s+[A-Z]"),
|
| 48 |
+
"paren": re.compile(r"\([^)]+\)"),
|
| 49 |
+
"colon_def": re.compile(r"\b\w+:\s+\w+"),
|
| 50 |
+
"double_quote": re.compile(r'"[^"]*"'),
|
| 51 |
+
"single_quote": re.compile(r"'[^']*'"),
|
| 52 |
+
"greeting": re.compile(
|
| 53 |
+
r"\b(hi|hello|hey|hiya|greetings|howdy|yo)\b", re.IGNORECASE
|
| 54 |
+
),
|
| 55 |
+
"conv_phrase": re.compile(
|
| 56 |
+
r"\b(great|perfect|sure|definitely|certainly|absolutely|of course|no problem|sounds good|got it|understood|okay|alright)\b",
|
| 57 |
+
re.IGNORECASE,
|
| 58 |
+
),
|
| 59 |
+
"helpful": re.compile(
|
| 60 |
+
r"\b(let me know|feel free|happy to|glad to|happy to help|don't hesitate|let me know if|please let me|reach out)\b",
|
| 61 |
+
re.IGNORECASE,
|
| 62 |
+
),
|
| 63 |
+
"closing_offer": re.compile(
|
| 64 |
+
r"(let me know|feel free|happy to help|don't hesitate|hope this helps)",
|
| 65 |
+
re.IGNORECASE,
|
| 66 |
+
),
|
| 67 |
+
"self_id_ai": re.compile(
|
| 68 |
+
r"\b(I'm|I am)\s+(an?\s+)?(AI|language model|assistant|chatbot)\b",
|
| 69 |
+
re.IGNORECASE,
|
| 70 |
+
),
|
| 71 |
+
"provider_mention": re.compile(
|
| 72 |
+
r"\b(Claude|Anthropic|GPT|OpenAI|ChatGPT|Gemini|Google|Bard|Grok|xAI|DeepSeek|Kimi|Moonshot|Mistral|MiniMax|Zhipu|GLM|深度求索)\b",
|
| 73 |
+
re.IGNORECASE,
|
| 74 |
+
),
|
| 75 |
+
"special_unicode": re.compile(r"[^\x00-\x7F]"),
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
_PRONOUN_SETS = {
|
| 79 |
+
"first": frozenset(
|
| 80 |
+
{"i", "me", "my", "mine", "myself", "we", "us", "our", "ours", "ourselves"}
|
| 81 |
+
),
|
| 82 |
+
"second": frozenset({"you", "your", "yours", "yourself", "yourselves"}),
|
| 83 |
+
"third": frozenset(
|
| 84 |
+
{"he", "she", "it", "they", "them", "his", "her", "its", "their"}
|
| 85 |
+
),
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
_DISCOURSE_SETS = {
|
| 89 |
+
"conjunctions": frozenset(
|
| 90 |
+
{
|
| 91 |
+
"and",
|
| 92 |
+
"but",
|
| 93 |
+
"or",
|
| 94 |
+
"nor",
|
| 95 |
+
"for",
|
| 96 |
+
"yet",
|
| 97 |
+
"so",
|
| 98 |
+
"because",
|
| 99 |
+
"although",
|
| 100 |
+
"while",
|
| 101 |
+
"if",
|
| 102 |
+
"when",
|
| 103 |
+
"where",
|
| 104 |
+
}
|
| 105 |
+
),
|
| 106 |
+
"discourse": frozenset(
|
| 107 |
+
{
|
| 108 |
+
"however",
|
| 109 |
+
"therefore",
|
| 110 |
+
"moreover",
|
| 111 |
+
"furthermore",
|
| 112 |
+
"nevertheless",
|
| 113 |
+
"consequently",
|
| 114 |
+
"thus",
|
| 115 |
+
"hence",
|
| 116 |
+
}
|
| 117 |
+
),
|
| 118 |
+
"hedging": frozenset(
|
| 119 |
+
{
|
| 120 |
+
"perhaps",
|
| 121 |
+
"maybe",
|
| 122 |
+
"might",
|
| 123 |
+
"could",
|
| 124 |
+
"possibly",
|
| 125 |
+
"seemingly",
|
| 126 |
+
"apparently",
|
| 127 |
+
"arguably",
|
| 128 |
+
"potentially",
|
| 129 |
+
}
|
| 130 |
+
),
|
| 131 |
+
"certainty": frozenset(
|
| 132 |
+
{
|
| 133 |
+
"definitely",
|
| 134 |
+
"certainly",
|
| 135 |
+
"absolutely",
|
| 136 |
+
"clearly",
|
| 137 |
+
"obviously",
|
| 138 |
+
"undoubtedly",
|
| 139 |
+
"indeed",
|
| 140 |
+
"surely",
|
| 141 |
+
}
|
| 142 |
+
),
|
| 143 |
+
"transition": frozenset(
|
| 144 |
+
{
|
| 145 |
+
"additionally",
|
| 146 |
+
"meanwhile",
|
| 147 |
+
"subsequently",
|
| 148 |
+
"alternatively",
|
| 149 |
+
"specifically",
|
| 150 |
+
"notably",
|
| 151 |
+
"importantly",
|
| 152 |
+
"essentially",
|
| 153 |
+
}
|
| 154 |
+
),
|
| 155 |
+
"casual": frozenset(
|
| 156 |
+
{
|
| 157 |
+
"okay",
|
| 158 |
+
"ok",
|
| 159 |
+
"hey",
|
| 160 |
+
"hi",
|
| 161 |
+
"cool",
|
| 162 |
+
"awesome",
|
| 163 |
+
"wow",
|
| 164 |
+
"basically",
|
| 165 |
+
"actually",
|
| 166 |
+
"literally",
|
| 167 |
+
"right",
|
| 168 |
+
"yeah",
|
| 169 |
+
}
|
| 170 |
+
),
|
| 171 |
+
"formal": frozenset(
|
| 172 |
+
{
|
| 173 |
+
"regarding",
|
| 174 |
+
"concerning",
|
| 175 |
+
"pertaining",
|
| 176 |
+
"aforementioned",
|
| 177 |
+
"respectively",
|
| 178 |
+
"accordingly",
|
| 179 |
+
"henceforth",
|
| 180 |
+
"whereby",
|
| 181 |
+
"notwithstanding",
|
| 182 |
+
"pursuant",
|
| 183 |
+
}
|
| 184 |
+
),
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
_PUNC_STRIP = frozenset(".,!?;:'\"()[]{}")
|
| 188 |
+
|
| 189 |
|
| 190 |
def strip_cot(text):
|
| 191 |
+
return _RE_COMPILED["cot"].sub("", text).strip()
|
|
|
|
| 192 |
|
| 193 |
|
| 194 |
def strip_markdown(text):
|
| 195 |
+
text = _RE_COMPILED["code_block"].sub("", text)
|
| 196 |
+
text = _RE_COMPILED["inline_code"].sub("", text)
|
| 197 |
+
text = _RE_COMPILED["bold"].sub(r"\1", text)
|
| 198 |
+
text = _RE_COMPILED["italic_ast"].sub(r"\1", text)
|
| 199 |
+
text = _RE_COMPILED["italic_under"].sub(r"\1", text)
|
| 200 |
+
text = _RE_COMPILED["under"].sub(r"\1", text)
|
| 201 |
+
text = _RE_COMPILED["header"].sub("", text)
|
| 202 |
+
text = _RE_COMPILED["bullet"].sub("", text)
|
| 203 |
+
text = _RE_COMPILED["numbered"].sub("", text)
|
| 204 |
+
text = _RE_COMPILED["link"].sub(r"\1", text)
|
| 205 |
+
text = _RE_COMPILED["quote"].sub("", text)
|
| 206 |
+
text = _RE_COMPILED["hr"].sub("", text)
|
| 207 |
return text.strip()
|
| 208 |
|
| 209 |
|
|
|
|
| 212 |
return self
|
| 213 |
|
| 214 |
def transform(self, X):
|
| 215 |
+
return csr_matrix(np.array([self._extract(t) for t in X], dtype=np.float32))
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
def _extract(self, text):
|
|
|
|
| 218 |
n_chars = max(len(text), 1)
|
| 219 |
+
words = text.split()
|
| 220 |
n_words = max(len(words), 1)
|
| 221 |
|
| 222 |
+
sentences = [s.strip() for s in re.split(r"[.!?]+", text) if s.strip()]
|
|
|
|
| 223 |
n_sentences = max(len(sentences), 1)
|
| 224 |
|
| 225 |
paragraphs = text.split("\n\n")
|
|
|
|
| 227 |
n_paragraphs = len(non_empty_paras)
|
| 228 |
|
| 229 |
lines = text.split("\n")
|
| 230 |
+
non_empty_lines = [ln for ln in lines if ln.strip()]
|
| 231 |
n_lines = max(len(non_empty_lines), 1)
|
| 232 |
|
|
|
|
| 233 |
word_lens = [len(w) for w in words]
|
| 234 |
+
sent_lens = [len(s.split()) for s in sentences]
|
| 235 |
+
|
| 236 |
+
_rc = _RE_COMPILED
|
| 237 |
+
_ps = _PRONOUN_SETS
|
| 238 |
+
_ds = _DISCOURSE_SETS
|
| 239 |
+
|
| 240 |
+
avg_word_len = np.mean(word_lens) if words else 0.0
|
| 241 |
+
word_len_std = np.std(word_lens) if len(words) > 1 else 0.0
|
| 242 |
+
median_word_len = np.median(word_lens) if words else 0.0
|
| 243 |
avg_sent_len = n_words / n_sentences
|
| 244 |
|
|
|
|
| 245 |
n_commas = text.count(",") / n_chars
|
| 246 |
n_semicolons = text.count(";") / n_chars
|
| 247 |
n_colons = text.count(":") / n_chars
|
|
|
|
| 257 |
comma_period_ratio = n_commas / (n_period + 0.001)
|
| 258 |
excl_question_ratio = n_exclaim / (n_question + 0.001)
|
| 259 |
|
| 260 |
+
n_headers = len(_rc["markdown_header"].findall(text)) / n_sentences
|
| 261 |
+
n_bold = len(_rc["markdown_bold"].findall(text)) / n_sentences
|
| 262 |
+
n_code_blocks = len(_rc["markdown_code_block"].findall(text)) / n_sentences
|
| 263 |
+
n_inline_code = len(_rc["markdown_inline_code"].findall(text)) / n_sentences
|
| 264 |
+
n_bullet = len(_rc["markdown_bullet"].findall(text)) / n_sentences
|
| 265 |
+
n_numbered = len(_rc["markdown_numbered"].findall(text)) / n_sentences
|
| 266 |
+
n_tables = len(_rc["markdown_table"].findall(text)) / n_sentences
|
|
|
|
| 267 |
|
|
|
|
| 268 |
newline_density = text.count("\n") / n_chars
|
| 269 |
double_newline_ratio = text.count("\n\n") / (text.count("\n") + 1)
|
| 270 |
uppercase_ratio = sum(1 for c in text if c.isupper()) / n_chars
|
|
|
|
| 274 |
unique_chars = len(set(text)) / n_chars
|
| 275 |
unique_chars_ratio = len(set(text.lower())) / n_chars
|
| 276 |
|
| 277 |
+
sent_len_std = np.std(sent_lens) if len(sent_lens) > 1 else 0.0
|
|
|
|
|
|
|
| 278 |
sent_len_max = max(sent_lens) if sent_lens else 0
|
| 279 |
sent_len_min = min(sent_lens) if sent_lens else 0
|
| 280 |
+
sent_len_median = np.median(sent_lens) if sent_lens else 0.0
|
| 281 |
sent_len_range = sent_len_max - sent_len_min
|
| 282 |
|
| 283 |
+
has_think = 1.0 if _rc["think_tag"].search(text) else 0.0
|
| 284 |
+
has_xml = 1.0 if _rc["xml_tag"].search(text) else 0.0
|
| 285 |
+
has_hr = 1.0 if _rc["hr"].search(text) else 0.0
|
| 286 |
+
has_url = 1.0 if _rc["url"].search(text) else 0.0
|
|
|
|
| 287 |
|
|
|
|
| 288 |
words_lower = [w.lower().strip(".,!?;:'\"()[]{}") for w in words]
|
| 289 |
+
first_person_ratio = sum(1 for w in words_lower if w in _ps["first"]) / n_words
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
second_person_ratio = (
|
| 291 |
+
sum(1 for w in words_lower if w in _ps["second"]) / n_words
|
| 292 |
)
|
| 293 |
+
third_person_ratio = sum(1 for w in words_lower if w in _ps["third"]) / n_words
|
| 294 |
|
|
|
|
| 295 |
unique_words = len(set(words_lower))
|
| 296 |
+
ttr = unique_words / n_words if n_words > 0 else 0.0
|
| 297 |
+
word_counts = {}
|
| 298 |
+
for w in words_lower:
|
| 299 |
+
word_counts[w] = word_counts.get(w, 0) + 1
|
| 300 |
+
hapax = sum(1 for c in word_counts.values() if c == 1)
|
| 301 |
+
hapax_ratio = hapax / n_words if n_words > 0 else 0.0
|
| 302 |
|
| 303 |
+
contraction_count = len(_rc["contraction"].findall(text))
|
| 304 |
+
contraction_ratio = contraction_count / n_words if n_words > 0 else 0.0
|
| 305 |
|
|
|
|
| 306 |
sentences_starters = [
|
| 307 |
s.split()[0].lower() if s.split() else "" for s in sentences
|
| 308 |
]
|
| 309 |
starter_vocab = (
|
| 310 |
+
len(set(sentences_starters)) / n_sentences if n_sentences > 0 else 0.0
|
| 311 |
)
|
| 312 |
|
| 313 |
and_starts = sum(1 for s in sentences_starters if s == "and") / n_sentences
|
|
|
|
| 322 |
/ n_sentences
|
| 323 |
)
|
| 324 |
|
|
|
|
| 325 |
short_word_ratio = sum(1 for w in words_lower if len(w) <= 2) / n_words
|
| 326 |
medium_word_ratio = sum(1 for w in words_lower if 3 <= len(w) <= 6) / n_words
|
| 327 |
long_word_ratio = sum(1 for w in words_lower if len(w) >= 7) / n_words
|
| 328 |
very_long_word_ratio = sum(1 for w in words_lower if len(w) >= 10) / n_words
|
| 329 |
|
|
|
|
| 330 |
para_lens = (
|
| 331 |
[len(p.split()) for p in non_empty_paras] if non_empty_paras else [0]
|
| 332 |
)
|
| 333 |
avg_para_len = np.mean(para_lens)
|
| 334 |
+
para_len_std = np.std(para_lens) if len(para_lens) > 1 else 0.0
|
| 335 |
|
| 336 |
+
conjunction_ratio = (
|
| 337 |
+
sum(1 for w in words_lower if w in _ds["conjunctions"]) / n_words
|
| 338 |
+
)
|
| 339 |
+
discourse_ratio = sum(1 for w in words_lower if w in _ds["discourse"]) / n_words
|
| 340 |
+
hedging_ratio = sum(1 for w in words_lower if w in _ds["hedging"]) / n_words
|
| 341 |
+
certainty_ratio = sum(1 for w in words_lower if w in _ds["certainty"]) / n_words
|
| 342 |
+
transition_ratio = (
|
| 343 |
+
sum(1 for w in words_lower if w in _ds["transition"]) / n_words
|
| 344 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
|
|
|
| 346 |
question_starts = sum(
|
| 347 |
+
1 for s in sentences if s and _rc["question_start"].search(s.lower())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
)
|
| 349 |
|
|
|
|
| 350 |
has_list = 1.0 if n_bullet > 0 or n_numbered > 0 else 0.0
|
| 351 |
list_items = n_bullet + n_numbered
|
| 352 |
|
| 353 |
+
emoji_count = len(_rc["emoji"].findall(text))
|
|
|
|
| 354 |
has_emoji = 1.0 if emoji_count > 0 else 0.0
|
| 355 |
|
|
|
|
|
|
|
| 356 |
all_caps_words = sum(
|
| 357 |
1 for w in words if len(w) > 1 and w.isupper() and w.isalpha()
|
| 358 |
)
|
| 359 |
all_caps_ratio = all_caps_words / n_words
|
| 360 |
|
| 361 |
+
paren_count = len(_rc["paren"].findall(text))
|
|
|
|
| 362 |
paren_ratio = paren_count / n_sentences
|
| 363 |
|
|
|
|
| 364 |
rhetorical_q = sum(1 for s in text.split("\n") if s.strip().endswith("?"))
|
| 365 |
rhetorical_ratio = rhetorical_q / n_sentences
|
| 366 |
|
| 367 |
+
casual_ratio = sum(1 for w in words_lower if w in _ds["casual"]) / n_words
|
| 368 |
+
formal_ratio = sum(1 for w in words_lower if w in _ds["formal"]) / n_words
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
+
chinese_chars = len(_rc["chinese"].findall(text))
|
|
|
|
| 371 |
has_chinese = 1.0 if chinese_chars > 0 else 0.0
|
| 372 |
chinese_ratio = chinese_chars / n_chars
|
| 373 |
|
| 374 |
+
has_self_id_ai = 1.0 if _rc["self_id_ai"].search(text) else 0.0
|
| 375 |
+
has_provider_mention = 1.0 if _rc["provider_mention"].search(text) else 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
|
|
|
| 377 |
ends_with_question = 1.0 if text.rstrip().endswith("?") else 0.0
|
| 378 |
+
has_closing_offer = 1.0 if _rc["closing_offer"].search(text) else 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
|
|
|
|
| 380 |
commas_per_sentence = text.count(",") / n_sentences
|
| 381 |
|
|
|
|
| 382 |
avg_line_len = (
|
| 383 |
+
np.mean([len(ln) for ln in non_empty_lines]) if non_empty_lines else 0.0
|
| 384 |
)
|
| 385 |
short_lines_ratio = (
|
| 386 |
+
sum(1 for ln in non_empty_lines if len(ln.split()) <= 5) / n_lines
|
| 387 |
)
|
| 388 |
|
| 389 |
+
cap_words = len(_rc["all_caps"].findall(text))
|
|
|
|
| 390 |
cap_word_ratio = cap_words / n_words
|
| 391 |
|
| 392 |
+
four_word_phrases = len(_rc["four_word"].findall(text))
|
|
|
|
| 393 |
phrase_ratio = four_word_phrases / n_sentences
|
| 394 |
|
| 395 |
+
sent_boundaries = len(_rc["sent_boundary"].findall(text))
|
|
|
|
| 396 |
sent_boundary_ratio = sent_boundaries / n_sentences
|
| 397 |
|
| 398 |
+
has_checkmark = 1.0 if any(c in text for c in "✓✗✔✘") else 0.0
|
| 399 |
+
has_arrow = 1.0 if any(c in text for c in "→←➡") else 0.0
|
| 400 |
+
has_star = 1.0 if any(c in text for c in "⭐★☆") else 0.0
|
| 401 |
+
special_unicode = len(_rc["special_unicode"].findall(text)) / n_chars
|
|
|
|
|
|
|
|
|
|
| 402 |
|
| 403 |
+
colon_definitions = len(_rc["colon_def"].findall(text)) / n_sentences
|
|
|
|
| 404 |
|
| 405 |
+
double_quote_pairs = len(_rc["double_quote"].findall(text)) / n_sentences
|
| 406 |
+
single_quote_pairs = len(_rc["single_quote"].findall(text)) / n_sentences
|
|
|
|
| 407 |
|
| 408 |
+
greeting_patterns = len(_rc["greeting"].findall(text))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
greeting_ratio = greeting_patterns / n_sentences
|
| 410 |
|
|
|
|
| 411 |
is_short = 1.0 if n_words < 100 else 0.0
|
| 412 |
is_medium = 1.0 if 100 <= n_words < 500 else 0.0
|
| 413 |
is_long = 1.0 if n_words >= 500 else 0.0
|
| 414 |
|
|
|
|
| 415 |
excl_sentences = sum(1 for s in sentences if s.strip().endswith("!"))
|
| 416 |
excl_sentence_ratio = excl_sentences / n_sentences
|
| 417 |
|
| 418 |
+
question_lines = [ln for ln in non_empty_lines if ln.strip().endswith("?")]
|
|
|
|
| 419 |
question_line_ratio = len(question_lines) / n_lines if n_lines > 0 else 0.0
|
| 420 |
|
| 421 |
+
conversational_phrases = len(_rc["conv_phrase"].findall(text))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
conv_phrase_ratio = conversational_phrases / n_words
|
| 423 |
|
| 424 |
+
helpful_phrases = len(_rc["helpful"].findall(text))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
helpful_ratio = helpful_phrases / n_sentences
|
| 426 |
|
| 427 |
return [
|
|
|
|
| 428 |
avg_word_len,
|
| 429 |
word_len_std,
|
| 430 |
median_word_len,
|
| 431 |
avg_sent_len,
|
|
|
|
| 432 |
sent_len_std,
|
| 433 |
sent_len_max,
|
| 434 |
sent_len_min,
|
| 435 |
sent_len_median,
|
| 436 |
sent_len_range,
|
| 437 |
commas_per_sentence,
|
|
|
|
| 438 |
n_commas,
|
| 439 |
n_semicolons,
|
| 440 |
n_colons,
|
|
|
|
| 448 |
comma_colon_ratio,
|
| 449 |
comma_period_ratio,
|
| 450 |
excl_question_ratio,
|
|
|
|
| 451 |
n_headers,
|
| 452 |
n_bold,
|
| 453 |
n_code_blocks,
|
|
|
|
| 456 |
n_numbered,
|
| 457 |
n_tables,
|
| 458 |
has_list,
|
|
|
|
| 459 |
newline_density,
|
| 460 |
double_newline_ratio,
|
| 461 |
uppercase_ratio,
|
|
|
|
| 466 |
list_items,
|
| 467 |
n_paragraphs,
|
| 468 |
n_lines / n_sentences,
|
|
|
|
| 469 |
has_think,
|
| 470 |
has_xml,
|
| 471 |
has_hr,
|
| 472 |
has_url,
|
|
|
|
| 473 |
first_person_ratio,
|
| 474 |
second_person_ratio,
|
| 475 |
third_person_ratio,
|
|
|
|
| 476 |
ttr,
|
| 477 |
hapax_ratio,
|
| 478 |
contraction_ratio,
|
| 479 |
short_word_ratio,
|
| 480 |
medium_word_ratio,
|
|
|
|
| 481 |
long_word_ratio,
|
| 482 |
very_long_word_ratio,
|
|
|
|
| 483 |
starter_vocab,
|
| 484 |
and_starts,
|
| 485 |
but_starts,
|
| 486 |
so_starts,
|
| 487 |
the_starts,
|
| 488 |
it_starts,
|
|
|
|
| 489 |
avg_para_len,
|
| 490 |
para_len_std,
|
|
|
|
| 491 |
conjunction_ratio,
|
| 492 |
discourse_ratio,
|
| 493 |
hedging_ratio,
|
| 494 |
certainty_ratio,
|
| 495 |
transition_ratio,
|
|
|
|
| 496 |
question_starts / n_sentences if n_sentences > 0 else 0,
|
|
|
|
| 497 |
emoji_count,
|
| 498 |
has_emoji,
|
| 499 |
special_unicode,
|
|
|
|
| 500 |
all_caps_ratio,
|
| 501 |
paren_ratio,
|
| 502 |
rhetorical_ratio,
|
|
|
|
| 505 |
has_chinese,
|
| 506 |
chinese_ratio,
|
| 507 |
has_self_id_ai,
|
|
|
|
| 508 |
has_provider_mention,
|
| 509 |
ends_with_question,
|
| 510 |
has_closing_offer,
|
| 511 |
has_checkmark,
|
|
|
|
| 512 |
has_arrow,
|
| 513 |
has_star,
|
| 514 |
avg_line_len,
|
| 515 |
short_lines_ratio,
|
| 516 |
cap_word_ratio,
|
| 517 |
phrase_ratio,
|
|
|
|
| 518 |
sent_boundary_ratio,
|
| 519 |
colon_definitions,
|
| 520 |
double_quote_pairs,
|
| 521 |
single_quote_pairs,
|
| 522 |
i_starts,
|
|
|
|
| 523 |
greeting_ratio,
|
| 524 |
is_short,
|
| 525 |
is_medium,
|
|
|
|
| 531 |
]
|
| 532 |
|
| 533 |
|
| 534 |
+
class StyleOnlyPipeline:
|
| 535 |
+
"""Feature pipeline using ONLY stylometric features — no TF-IDF."""
|
| 536 |
+
|
| 537 |
+
def __init__(self):
|
| 538 |
+
self.stylo = StylometricFeatures()
|
| 539 |
+
self.scaler = MaxAbsScaler()
|
| 540 |
+
|
| 541 |
+
def fit_transform(self, texts):
|
| 542 |
+
import time
|
| 543 |
+
|
| 544 |
+
texts_clean = [strip_markdown(strip_cot(t)) for t in texts]
|
| 545 |
+
t0 = time.time()
|
| 546 |
+
stylo_features = self.stylo.transform(texts_clean)
|
| 547 |
+
print(
|
| 548 |
+
f" Stylometric: {stylo_features.shape[1]} features ({time.time() - t0:.1f}s)"
|
| 549 |
+
)
|
| 550 |
+
result = self.scaler.fit_transform(stylo_features)
|
| 551 |
+
print(f" Final feature matrix: {result.shape}")
|
| 552 |
+
return result
|
| 553 |
+
|
| 554 |
+
def transform(self, texts):
|
| 555 |
+
texts_clean = [strip_markdown(strip_cot(t)) for t in texts]
|
| 556 |
+
stylo_features = self.stylo.transform(texts_clean)
|
| 557 |
+
return self.scaler.transform(stylo_features)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
class FeaturePipeline:
|
| 561 |
def __init__(self, use_tfidf=True):
|
| 562 |
word_params = dict(TFIDF_WORD_PARAMS)
|
|
|
|
| 577 |
)
|
| 578 |
|
| 579 |
def _clean_for_tfidf(self, text):
|
|
|
|
| 580 |
return strip_markdown(strip_cot(text))
|
| 581 |
|
| 582 |
def fit_transform(self, texts):
|
|
|
|
| 584 |
|
| 585 |
print(f" Input: {len(texts)} texts", flush=True)
|
| 586 |
|
| 587 |
+
texts_clean = [strip_markdown(strip_cot(t)) for t in texts]
|
| 588 |
+
texts_tfidf = texts_clean
|
| 589 |
|
| 590 |
use_word_tfidf = (
|
| 591 |
self.word_tfidf.max_features is not None
|
|
|
|
| 612 |
char_features = csr_matrix((len(texts), 0), dtype=np.float32)
|
| 613 |
|
| 614 |
t0 = time.time()
|
| 615 |
+
stylo_features = self.stylo.transform(texts_clean)
|
| 616 |
print(
|
| 617 |
f" stylometric: {stylo_features.shape[1]} features ({time.time() - t0:.1f}s)",
|
| 618 |
flush=True,
|
|
|
|
| 624 |
return combined
|
| 625 |
|
| 626 |
def transform(self, texts):
|
| 627 |
+
texts_clean = [strip_markdown(strip_cot(t)) for t in texts]
|
| 628 |
+
texts_tfidf = texts_clean
|
| 629 |
|
| 630 |
use_word_tfidf = (
|
| 631 |
self.word_tfidf.max_features is not None
|
|
|
|
| 641 |
else:
|
| 642 |
char_features = csr_matrix((len(texts), 0), dtype=np.float32)
|
| 643 |
|
| 644 |
+
stylo_features = self.stylo.transform(texts_clean)
|
| 645 |
combined = hstack([word_features, char_features, stylo_features])
|
| 646 |
return self.scaler.transform(combined)
|
models/community/enc_4provider.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9eaeb0dde6cecb8561c2f4c47e1aeafb9dab1a6262390c9735716408f2231761
|
| 3 |
+
size 767
|
models/community/pipeline_4provider.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1e916f828358de5f129731d1784b14d4e1f82ef9e315d9b972a88490b980d3ef
|
| 3 |
+
size 1365
|
models/community/rf_4provider.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:da22e16dbf57465c4a178fc7e69a248486c3d32281fd76c395e6e74da8a51a2e
|
| 3 |
+
size 106139474
|
models/jobs.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24c91be6078f7fd4303d7f28e8a7212cea5f2113e05ab1335dacb6382c62c21e
|
| 3 |
+
size 7254
|
models/style/enc_4provider.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9eaeb0dde6cecb8561c2f4c47e1aeafb9dab1a6262390c9735716408f2231761
|
| 3 |
+
size 767
|
models/style/pipeline_4provider.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1e916f828358de5f129731d1784b14d4e1f82ef9e315d9b972a88490b980d3ef
|
| 3 |
+
size 1365
|
models/style/rf_4provider.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:da22e16dbf57465c4a178fc7e69a248486c3d32281fd76c395e6e74da8a51a2e
|
| 3 |
+
size 106139474
|
templates/index.html
CHANGED
|
@@ -653,6 +653,16 @@
|
|
| 653 |
animation: fadeIn 0.3s ease;
|
| 654 |
}
|
| 655 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
@media (max-width: 600px) {
|
| 657 |
.container {
|
| 658 |
padding: 1rem;
|
|
@@ -687,6 +697,7 @@
|
|
| 687 |
|
| 688 |
<div class="tabs">
|
| 689 |
<button class="tab active" data-tab="classify">Classify</button>
|
|
|
|
| 690 |
<button class="tab" data-tab="docs">API Docs</button>
|
| 691 |
</div>
|
| 692 |
|
|
@@ -695,6 +706,7 @@
|
|
| 695 |
<div class="status-indicator">
|
| 696 |
<span class="status-dot" id="statusDot"></span>
|
| 697 |
<span id="statusText">Connecting to API...</span>
|
|
|
|
| 698 |
</div>
|
| 699 |
|
| 700 |
<div class="card">
|
|
@@ -751,6 +763,159 @@
|
|
| 751 |
</div>
|
| 752 |
</div>
|
| 753 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 754 |
<!-- ═══ API Docs Tab ═══ -->
|
| 755 |
<div class="tab-content" id="tab-docs">
|
| 756 |
|
|
@@ -1008,11 +1173,6 @@ async function classify(text, topN = 5) {
|
|
| 1008 |
|
| 1009 |
<div class="footer">
|
| 1010 |
<p>AIFinder — Train on corrections to improve accuracy</p>
|
| 1011 |
-
<p style="margin-top: 0.5rem;">
|
| 1012 |
-
Want to contribute? Test this and post to the
|
| 1013 |
-
<a href="https://huggingface.co/spaces" target="_blank">HuggingFace Spaces Community</a>
|
| 1014 |
-
if you want it merged!
|
| 1015 |
-
</p>
|
| 1016 |
</div>
|
| 1017 |
</div>
|
| 1018 |
|
|
@@ -1050,6 +1210,7 @@ async function classify(text, topN = 5) {
|
|
| 1050 |
const toast = document.getElementById('toast');
|
| 1051 |
const statusDot = document.getElementById('statusDot');
|
| 1052 |
const statusText = document.getElementById('statusText');
|
|
|
|
| 1053 |
let usingCommunity = false;
|
| 1054 |
|
| 1055 |
function showToast(message, type = 'info') {
|
|
@@ -1067,6 +1228,9 @@ async function classify(text, topN = 5) {
|
|
| 1067 |
if (data.loaded) {
|
| 1068 |
statusDot.classList.remove('loading');
|
| 1069 |
statusText.textContent = data.using_community ? 'Ready — Community Model (cpu)' : `Ready (${data.device})`;
|
|
|
|
|
|
|
|
|
|
| 1070 |
classifyBtn.disabled = false;
|
| 1071 |
usingCommunity = data.using_community;
|
| 1072 |
updateCommunityUI(data.community_available);
|
|
@@ -1367,7 +1531,483 @@ async function classify(text, topN = 5) {
|
|
| 1367 |
populateDocsProviders();
|
| 1368 |
};
|
| 1369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1370 |
checkStatus();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1371 |
</script>
|
| 1372 |
</body>
|
| 1373 |
</html>
|
|
|
|
| 653 |
animation: fadeIn 0.3s ease;
|
| 654 |
}
|
| 655 |
|
| 656 |
+
.format-option:hover {
|
| 657 |
+
border-color: var(--border-light) !important;
|
| 658 |
+
background: var(--bg-elevated) !important;
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
.format-option:has(input:checked) {
|
| 662 |
+
border-color: var(--accent-muted) !important;
|
| 663 |
+
background: rgba(232, 93, 4, 0.08) !important;
|
| 664 |
+
}
|
| 665 |
+
|
| 666 |
@media (max-width: 600px) {
|
| 667 |
.container {
|
| 668 |
padding: 1rem;
|
|
|
|
| 697 |
|
| 698 |
<div class="tabs">
|
| 699 |
<button class="tab active" data-tab="classify">Classify</button>
|
| 700 |
+
<button class="tab" data-tab="dataset">Evaluate Dataset</button>
|
| 701 |
<button class="tab" data-tab="docs">API Docs</button>
|
| 702 |
</div>
|
| 703 |
|
|
|
|
| 706 |
<div class="status-indicator">
|
| 707 |
<span class="status-dot" id="statusDot"></span>
|
| 708 |
<span id="statusText">Connecting to API...</span>
|
| 709 |
+
<span id="providerCount" style="margin-left:auto;font-size:0.75rem;color:var(--text-muted);"></span>
|
| 710 |
</div>
|
| 711 |
|
| 712 |
<div class="card">
|
|
|
|
| 763 |
</div>
|
| 764 |
</div>
|
| 765 |
|
| 766 |
+
<!-- ═══ Dataset Evaluation Tab ═══ -->
|
| 767 |
+
<div class="tab-content" id="tab-dataset">
|
| 768 |
+
<div class="card">
|
| 769 |
+
<div class="card-label">HuggingFace Dataset ID</div>
|
| 770 |
+
<input type="text" id="datasetId" placeholder="e.g., ianncity/Hunter-Alpha-SFT-300000x"
|
| 771 |
+
style="width:100%; padding:0.75rem 1rem; background:var(--bg-tertiary); border:1px solid var(--border); border-radius:8px; color:var(--text-primary); font-family:'Outfit',sans-serif;font-size:0.9rem;margin-bottom:0.75rem;">
|
| 772 |
+
<div style="display:flex;gap:0.75rem;flex-wrap:wrap;align-items:center;">
|
| 773 |
+
<button class="btn btn-secondary" id="checkDatasetBtn">Check Format</button>
|
| 774 |
+
<button class="btn btn-primary" id="evaluateDatasetBtn" disabled>Evaluate</button>
|
| 775 |
+
<input type="number" id="maxSamples" value="1000" min="1" max="10000"
|
| 776 |
+
style="width:100px;padding:0.5rem;background:var(--bg-tertiary);border:1px solid var(--border);border-radius:8px;color:var(--text-primary);font-size:0.85rem;">
|
| 777 |
+
<span style="color:var(--text-muted);font-size:0.8rem;">max samples</span>
|
| 778 |
+
</div>
|
| 779 |
+
|
| 780 |
+
<div style="margin-top:1rem;padding-top:1rem;border-top:1px solid var(--border);">
|
| 781 |
+
<div style="display:flex;justify-content:space-between;align-items:center;margin-bottom:0.5rem;">
|
| 782 |
+
<div class="card-label" style="margin-bottom:0;">Dataset Format</div>
|
| 783 |
+
<label style="display:flex;align-items:center;gap:0.5rem;cursor:pointer;">
|
| 784 |
+
<input type="checkbox" id="useCustomFormat" style="width:16px;height:16px;accent-color:var(--accent);">
|
| 785 |
+
<span style="font-size:0.8rem;color:var(--text-secondary);">Use custom format</span>
|
| 786 |
+
</label>
|
| 787 |
+
</div>
|
| 788 |
+
|
| 789 |
+
<div id="customFormatSection" style="display:none;background:var(--bg-tertiary);border:1px solid var(--border);border-radius:8px;padding:1rem;">
|
| 790 |
+
<div style="font-size:0.75rem;color:var(--text-muted);margin-bottom:0.75rem;">
|
| 791 |
+
How is your dataset structured? Choose a format below:
|
| 792 |
+
</div>
|
| 793 |
+
|
| 794 |
+
<div style="display:grid;gap:0.5rem;margin-bottom:1rem;">
|
| 795 |
+
<label style="display:flex;align-items:center;gap:0.5rem;cursor:pointer;padding:0.5rem;background:var(--bg-secondary);border-radius:6px;border:1px solid transparent;" class="format-option" data-format="auto">
|
| 796 |
+
<input type="radio" name="customFormatType" value="auto" checked style="accent-color:var(--accent);">
|
| 797 |
+
<div>
|
| 798 |
+
<div style="font-weight:500;font-size:0.85rem;">Auto-detect</div>
|
| 799 |
+
<div style="font-size:0.75rem;color:var(--text-muted);">Try to detect format automatically</div>
|
| 800 |
+
</div>
|
| 801 |
+
</label>
|
| 802 |
+
|
| 803 |
+
<label style="display:flex;align-items:center;gap:0.5rem;cursor:pointer;padding:0.5rem;background:var(--bg-secondary);border-radius:6px;border:1px solid transparent;" class="format-option" data-format="column">
|
| 804 |
+
<input type="radio" name="customFormatType" value="column" style="accent-color:var(--accent);">
|
| 805 |
+
<div>
|
| 806 |
+
<div style="font-weight:500;font-size:0.85rem;">Single column</div>
|
| 807 |
+
<div style="font-size:0.75rem;color:var(--text-muted);">Extract from one field (e.g., "response")</div>
|
| 808 |
+
</div>
|
| 809 |
+
</label>
|
| 810 |
+
|
| 811 |
+
<label style="display:flex;align-items:center;gap:0.5rem;cursor:pointer;padding:0.5rem;background:var(--bg-secondary);border-radius:6px;border:1px solid transparent;" class="format-option" data-format="two_column">
|
| 812 |
+
<input type="radio" name="customFormatType" value="two_column" style="accent-color:var(--accent);">
|
| 813 |
+
<div>
|
| 814 |
+
<div style="font-weight:500;font-size:0.85rem;">Two columns</div>
|
| 815 |
+
<div style="font-size:0.75rem;color:var(--text-muted);">User column + Assistant column</div>
|
| 816 |
+
</div>
|
| 817 |
+
</label>
|
| 818 |
+
|
| 819 |
+
<label style="display:flex;align-items:center;gap:0.5rem;cursor:pointer;padding:0.5rem;background:var(--bg-secondary);border-radius:6px;border:1px solid transparent;" class="format-option" data-format="pattern">
|
| 820 |
+
<input type="radio" name="customFormatType" value="pattern" style="accent-color:var(--accent);">
|
| 821 |
+
<div>
|
| 822 |
+
<div style="font-weight:500;font-size:0.85rem;">Text markers</div>
|
| 823 |
+
<div style="font-size:0.75rem;color:var(--text-muted);">Extract between text markers</div>
|
| 824 |
+
</div>
|
| 825 |
+
</label>
|
| 826 |
+
</div>
|
| 827 |
+
|
| 828 |
+
<div id="columnInput" style="display:none;">
|
| 829 |
+
<input type="text" id="customColumnName" placeholder="e.g., response, output, completion"
|
| 830 |
+
style="width:100%; padding:0.6rem 0.75rem; background:var(--bg-primary); border:1px solid var(--border); border-radius:6px; color:var(--text-primary); font-family:'JetBrains Mono',monospace;font-size:0.85rem;">
|
| 831 |
+
</div>
|
| 832 |
+
|
| 833 |
+
<div id="twoColumnInput" style="display:none;">
|
| 834 |
+
<div style="display:flex;gap:0.5rem;flex-wrap:wrap;">
|
| 835 |
+
<input type="text" id="customUserColumn" placeholder="User column (e.g., prompt, input)"
|
| 836 |
+
style="flex:1;min-width:150px;padding:0.6rem 0.75rem; background:var(--bg-primary); border:1px solid var(--border); border-radius:6px; color:var(--text-primary); font-family:'JetBrains Mono',monospace;font-size:0.85rem;">
|
| 837 |
+
<input type="text" id="customAssistantColumn" placeholder="Assistant column (e.g., response, output)"
|
| 838 |
+
style="flex:1;min-width:150px;padding:0.6rem 0.75rem; background:var(--bg-primary); border:1px solid var(--border); border-radius:6px; color:var(--text-primary); font-family:'JetBrains Mono',monospace;font-size:0.85rem;">
|
| 839 |
+
</div>
|
| 840 |
+
</div>
|
| 841 |
+
|
| 842 |
+
<div id="patternInput" style="display:none;">
|
| 843 |
+
<input type="text" id="customPattern" placeholder="e.g., user:[INST] assistant:[/INST] or [startuser] [startassistant]"
|
| 844 |
+
style="width:100%; padding:0.6rem 0.75rem; background:var(--bg-primary); border:1px solid var(--border); border-radius:6px; color:var(--text-primary); font-family:'JetBrains Mono',monospace;font-size:0.85rem;">
|
| 845 |
+
<div style="font-size:0.7rem;color:var(--text-muted);margin-top:0.5rem;">
|
| 846 |
+
Use <code style="background:var(--bg-primary);padding:0.1rem 0.3rem;border-radius:3px;">[startuser]</code> and <code style="background:var(--bg-primary);padding:0.1rem 0.3rem;border-radius:3px;">[startassistant]</code> as placeholders, or raw text like <code style="background:var(--bg-primary);padding:0.1rem 0.3rem;border-radius:3px;">user: assistant:</code>
|
| 847 |
+
</div>
|
| 848 |
+
</div>
|
| 849 |
+
|
| 850 |
+
<div style="margin-top:0.75rem;padding:0.5rem;background:var(--bg-primary);border-radius:6px;">
|
| 851 |
+
<div style="font-size:0.7rem;color:var(--text-muted);margin-bottom:0.25rem;">Format string preview:</div>
|
| 852 |
+
<code id="formatPreview" style="font-family:'JetBrains Mono',monospace;font-size:0.8rem;color:var(--accent);">column: response</code>
|
| 853 |
+
</div>
|
| 854 |
+
</div>
|
| 855 |
+
</div>
|
| 856 |
+
</div>
|
| 857 |
+
|
| 858 |
+
<div id="datasetFormatInfo" class="card" style="display:none;">
|
| 859 |
+
<div class="card-label">Dataset Format</div>
|
| 860 |
+
<div id="formatName" style="font-weight:600;margin-bottom:0.5rem;"></div>
|
| 861 |
+
<div id="formatDescription" style="color:var(--text-secondary);font-size:0.9rem;"></div>
|
| 862 |
+
<div style="margin-top:0.75rem;display:flex;gap:1rem;">
|
| 863 |
+
<div class="stat" style="padding:0.5rem 1rem;min-width:auto;">
|
| 864 |
+
<div class="stat-value" id="totalRows" style="font-size:1rem;">-</div>
|
| 865 |
+
<div class="stat-label" style="font-size:0.65rem;">Total Rows</div>
|
| 866 |
+
</div>
|
| 867 |
+
<div class="stat" style="padding:0.5rem 1rem;min-width:auto;">
|
| 868 |
+
<div class="stat-value" id="extractedCount" style="font-size:1rem;">-</div>
|
| 869 |
+
<div class="stat-label" style="font-size:0.65rem;">Responses</div>
|
| 870 |
+
</div>
|
| 871 |
+
</div>
|
| 872 |
+
<div id="formatError" style="display:none;margin-top:1rem;padding:0.75rem;background:rgba(232,93,4,0.12);border:1px solid var(--accent-muted);border-radius:8px;color:var(--text-secondary);font-size:0.85rem;"></div>
|
| 873 |
+
</div>
|
| 874 |
+
|
| 875 |
+
<div id="datasetResults" class="card" style="display:none;">
|
| 876 |
+
<div class="card-label">Evaluation Results</div>
|
| 877 |
+
|
| 878 |
+
<div style="display:flex;gap:1rem;margin-bottom:1.5rem;flex-wrap:wrap;">
|
| 879 |
+
<div class="stat">
|
| 880 |
+
<div class="stat-value" id="evalTotal">-</div>
|
| 881 |
+
<div class="stat-label">Samples</div>
|
| 882 |
+
</div>
|
| 883 |
+
<div class="stat">
|
| 884 |
+
<div class="stat-value" id="evalLikelyProvider">-</div>
|
| 885 |
+
<div class="stat-label">Likely Provider</div>
|
| 886 |
+
</div>
|
| 887 |
+
<div class="stat">
|
| 888 |
+
<div class="stat-value" id="evalAvgConfidence">-</div>
|
| 889 |
+
<div class="stat-label">Avg Confidence</div>
|
| 890 |
+
</div>
|
| 891 |
+
</div>
|
| 892 |
+
|
| 893 |
+
<div class="card-label" style="margin-top:1rem;">Provider Distribution</div>
|
| 894 |
+
<div id="providerDistribution"></div>
|
| 895 |
+
|
| 896 |
+
<div class="card-label" style="margin-top:1.5rem;">Top Providers (by cumulative score)</div>
|
| 897 |
+
<div id="topProvidersList"></div>
|
| 898 |
+
</div>
|
| 899 |
+
|
| 900 |
+
<div id="datasetLoading" style="display:none;text-align:center;padding:2rem;">
|
| 901 |
+
<span class="loading" style="width:24px;height:24px;border-width:3px;"></span>
|
| 902 |
+
<div style="margin-top:1rem;color:var(--text-secondary);" id="datasetLoadingText">Evaluating...</div>
|
| 903 |
+
</div>
|
| 904 |
+
|
| 905 |
+
<div class="docs-section" style="margin-top:2rem;">
|
| 906 |
+
<h2 style="font-size:1rem;font-weight:500;color:var(--text-secondary);margin-bottom:0.75rem;">Supported Dataset Formats</h2>
|
| 907 |
+
<div id="supportedFormatsList" style="display:grid;grid-template-columns:repeat(auto-fill,minmax(250px,1fr));gap:0.75rem;"></div>
|
| 908 |
+
</div>
|
| 909 |
+
|
| 910 |
+
<div class="card" style="margin-top:2rem;">
|
| 911 |
+
<div class="card-label" style="display:flex;justify-content:space-between;align-items:center;">
|
| 912 |
+
<span>Your Evaluated Datasets</span>
|
| 913 |
+
<button class="btn btn-secondary" id="clearHistoryBtn" style="padding:0.4rem 0.75rem;font-size:0.75rem;">Clear History</button>
|
| 914 |
+
</div>
|
| 915 |
+
<div id="datasetHistory" style="color:var(--text-muted);font-size:0.85rem;">Loading...</div>
|
| 916 |
+
</div>
|
| 917 |
+
</div>
|
| 918 |
+
|
| 919 |
<!-- ═══ API Docs Tab ═══ -->
|
| 920 |
<div class="tab-content" id="tab-docs">
|
| 921 |
|
|
|
|
| 1173 |
|
| 1174 |
<div class="footer">
|
| 1175 |
<p>AIFinder — Train on corrections to improve accuracy</p>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1176 |
</div>
|
| 1177 |
</div>
|
| 1178 |
|
|
|
|
| 1210 |
const toast = document.getElementById('toast');
|
| 1211 |
const statusDot = document.getElementById('statusDot');
|
| 1212 |
const statusText = document.getElementById('statusText');
|
| 1213 |
+
const providerCountEl = document.getElementById('providerCount');
|
| 1214 |
let usingCommunity = false;
|
| 1215 |
|
| 1216 |
function showToast(message, type = 'info') {
|
|
|
|
| 1228 |
if (data.loaded) {
|
| 1229 |
statusDot.classList.remove('loading');
|
| 1230 |
statusText.textContent = data.using_community ? 'Ready — Community Model (cpu)' : `Ready (${data.device})`;
|
| 1231 |
+
if (data.num_providers) {
|
| 1232 |
+
providerCountEl.textContent = `${data.num_providers} providers`;
|
| 1233 |
+
}
|
| 1234 |
classifyBtn.disabled = false;
|
| 1235 |
usingCommunity = data.using_community;
|
| 1236 |
updateCommunityUI(data.community_available);
|
|
|
|
| 1531 |
populateDocsProviders();
|
| 1532 |
};
|
| 1533 |
|
| 1534 |
+
// ── Dataset Evaluation ──
|
| 1535 |
+
const datasetIdInput = document.getElementById('datasetId');
|
| 1536 |
+
const maxSamplesInput = document.getElementById('maxSamples');
|
| 1537 |
+
const checkDatasetBtn = document.getElementById('checkDatasetBtn');
|
| 1538 |
+
const evaluateDatasetBtn = document.getElementById('evaluateDatasetBtn');
|
| 1539 |
+
const datasetFormatInfo = document.getElementById('datasetFormatInfo');
|
| 1540 |
+
const formatName = document.getElementById('formatName');
|
| 1541 |
+
const formatDescription = document.getElementById('formatDescription');
|
| 1542 |
+
const totalRowsEl = document.getElementById('totalRows');
|
| 1543 |
+
const extractedCountEl = document.getElementById('extractedCount');
|
| 1544 |
+
const formatError = document.getElementById('formatError');
|
| 1545 |
+
const datasetResults = document.getElementById('datasetResults');
|
| 1546 |
+
const datasetLoading = document.getElementById('datasetLoading');
|
| 1547 |
+
const datasetLoadingText = document.getElementById('datasetLoadingText');
|
| 1548 |
+
const datasetHistory = document.getElementById('datasetHistory');
|
| 1549 |
+
|
| 1550 |
+
let currentDatasetInfo = null;
|
| 1551 |
+
let currentJobId = null;
|
| 1552 |
+
let jobPollingInterval = null;
|
| 1553 |
+
|
| 1554 |
+
function saveJobId(jobId) {
|
| 1555 |
+
localStorage.setItem('aifinder_current_job', jobId);
|
| 1556 |
+
}
|
| 1557 |
+
|
| 1558 |
+
function getSavedJobId() {
|
| 1559 |
+
return localStorage.getItem('aifinder_current_job');
|
| 1560 |
+
}
|
| 1561 |
+
|
| 1562 |
+
function clearSavedJobId() {
|
| 1563 |
+
localStorage.removeItem('aifinder_current_job');
|
| 1564 |
+
}
|
| 1565 |
+
|
| 1566 |
+
function generateApiKey() {
|
| 1567 |
+
const existing = localStorage.getItem('aifinder_api_key');
|
| 1568 |
+
if (existing) return existing;
|
| 1569 |
+
const key = 'usr_' + Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
|
| 1570 |
+
localStorage.setItem('aifinder_api_key', key);
|
| 1571 |
+
return key;
|
| 1572 |
+
}
|
| 1573 |
+
|
| 1574 |
+
function getApiKey() {
|
| 1575 |
+
return localStorage.getItem('aifinder_api_key') || generateApiKey();
|
| 1576 |
+
}
|
| 1577 |
+
|
| 1578 |
+
getApiKey();
|
| 1579 |
+
|
| 1580 |
+
async function loadDatasetHistory() {
|
| 1581 |
+
const apiKey = getApiKey();
|
| 1582 |
+
if (!apiKey) {
|
| 1583 |
+
datasetHistory.innerHTML = '<span style="color:var(--text-muted);">No evaluated datasets yet.</span>';
|
| 1584 |
+
return;
|
| 1585 |
+
}
|
| 1586 |
+
|
| 1587 |
+
try {
|
| 1588 |
+
const res = await fetch(`${API_BASE}/api/datasets?api_key=${encodeURIComponent(apiKey)}`);
|
| 1589 |
+
const data = await res.json();
|
| 1590 |
+
|
| 1591 |
+
if (!data.datasets || data.datasets.length === 0) {
|
| 1592 |
+
datasetHistory.innerHTML = '<span style="color:var(--text-muted);">Your evaluated datasets will appear here. Start by checking a dataset format above.</span>';
|
| 1593 |
+
return;
|
| 1594 |
+
}
|
| 1595 |
+
|
| 1596 |
+
datasetHistory.innerHTML = data.datasets.map(ds => `
|
| 1597 |
+
<div style="display:flex;justify-content:space-between;align-items:center;padding:0.75rem;background:var(--bg-tertiary);border:1px solid var(--border);border-radius:8px;margin-bottom:0.5rem;cursor:pointer;"
|
| 1598 |
+
onclick="loadDatasetResult('${ds.job_id}')">
|
| 1599 |
+
<div>
|
| 1600 |
+
<div style="font-weight:500;">${ds.dataset_id}</div>
|
| 1601 |
+
<div style="font-size:0.75rem;color:var(--text-muted);">${ds.completed_at ? new Date(ds.completed_at).toLocaleString() : ''}</div>
|
| 1602 |
+
</div>
|
| 1603 |
+
<span style="padding:0.25rem 0.5rem;border-radius:4px;font-size:0.75rem;${ds.status === 'completed' ? 'background:var(--success-muted);color:var(--success);' : 'background:var(--accent-muted);color:var(--accent-hover);'}">${ds.status}</span>
|
| 1604 |
+
</div>
|
| 1605 |
+
`).join('');
|
| 1606 |
+
} catch (e) {
|
| 1607 |
+
datasetHistory.innerHTML = '<span style="color:var(--text-muted);">Failed to load history.</span>';
|
| 1608 |
+
}
|
| 1609 |
+
}
|
| 1610 |
+
|
| 1611 |
+
async function loadDatasetResult(jobId) {
|
| 1612 |
+
try {
|
| 1613 |
+
const res = await fetch(`${API_BASE}/api/dataset/job/${jobId}`);
|
| 1614 |
+
const data = await res.json();
|
| 1615 |
+
|
| 1616 |
+
if (data.status === 'completed' && data.results) {
|
| 1617 |
+
showEvaluationResults(data.results);
|
| 1618 |
+
} else if (data.status === 'failed') {
|
| 1619 |
+
showToast('Evaluation failed: ' + data.error);
|
| 1620 |
+
} else if (data.status === 'running' || data.status === 'pending') {
|
| 1621 |
+
datasetIdInput.value = data.dataset_id || '';
|
| 1622 |
+
currentJobId = jobId;
|
| 1623 |
+
saveJobId(currentJobId);
|
| 1624 |
+
datasetLoading.style.display = 'block';
|
| 1625 |
+
evaluateDatasetBtn.disabled = true;
|
| 1626 |
+
if (data.progress) {
|
| 1627 |
+
datasetLoadingText.textContent = `${data.progress.stage === 'downloading' ? 'Downloading' : data.progress.stage === 'evaluating' ? 'Evaluating' : 'Processing'}: ${data.progress.percent}%`;
|
| 1628 |
+
} else {
|
| 1629 |
+
datasetLoadingText.textContent = 'Evaluation running, please wait...';
|
| 1630 |
+
}
|
| 1631 |
+
startJobPolling();
|
| 1632 |
+
}
|
| 1633 |
+
} catch (e) {
|
| 1634 |
+
showToast('Error: ' + e.message);
|
| 1635 |
+
}
|
| 1636 |
+
}
|
| 1637 |
+
|
| 1638 |
+
function showEvaluationResults(data) {
|
| 1639 |
+
document.getElementById('evalTotal').textContent = data.extracted_count?.toLocaleString() || '-';
|
| 1640 |
+
document.getElementById('evalLikelyProvider').textContent = data.likely_provider || '-';
|
| 1641 |
+
document.getElementById('evalAvgConfidence').textContent = (data.average_confidence || 0) + '%';
|
| 1642 |
+
|
| 1643 |
+
const distContainer = document.getElementById('providerDistribution');
|
| 1644 |
+
distContainer.innerHTML = '';
|
| 1645 |
+
|
| 1646 |
+
const sortedProviders = Object.entries(data.provider_counts || {})
|
| 1647 |
+
.sort((a, b) => b[1].count - a[1].count);
|
| 1648 |
+
|
| 1649 |
+
for (const [provider, info] of sortedProviders) {
|
| 1650 |
+
const conf = data.provider_confidences?.[provider]?.average || 0;
|
| 1651 |
+
const html = `
|
| 1652 |
+
<div style="margin-bottom:1rem;">
|
| 1653 |
+
<div style="display:flex;justify-content:space-between;margin-bottom:0.25rem;">
|
| 1654 |
+
<span style="font-weight:500;">${provider}</span>
|
| 1655 |
+
<span style="color:var(--text-secondary);font-size:0.85rem;">${info.count} (${info.percentage}%) · ${conf}% avg</span>
|
| 1656 |
+
</div>
|
| 1657 |
+
<div class="result-bar">
|
| 1658 |
+
<div class="result-bar-fill" style="width:${info.percentage}%"></div>
|
| 1659 |
+
</div>
|
| 1660 |
+
</div>
|
| 1661 |
+
`;
|
| 1662 |
+
distContainer.innerHTML += html;
|
| 1663 |
+
}
|
| 1664 |
+
|
| 1665 |
+
const topContainer = document.getElementById('topProvidersList');
|
| 1666 |
+
topContainer.innerHTML = '';
|
| 1667 |
+
|
| 1668 |
+
const sortedTop = Object.entries(data.top_providers || {})
|
| 1669 |
+
.sort((a, b) => b[1] - a[1])
|
| 1670 |
+
.slice(0, 5);
|
| 1671 |
+
|
| 1672 |
+
for (const [provider, count] of sortedTop) {
|
| 1673 |
+
const conf = data.provider_confidences?.[provider]?.cumulative || 0;
|
| 1674 |
+
topContainer.innerHTML += `
|
| 1675 |
+
<div class="result-item">
|
| 1676 |
+
<span class="result-name">${provider}</span>
|
| 1677 |
+
<span class="result-percent">${conf.toFixed(2)} pts</span>
|
| 1678 |
+
</div>
|
| 1679 |
+
`;
|
| 1680 |
+
}
|
| 1681 |
+
|
| 1682 |
+
datasetResults.style.display = 'block';
|
| 1683 |
+
datasetLoading.style.display = 'none';
|
| 1684 |
+
}
|
| 1685 |
+
|
| 1686 |
+
function startJobPolling() {
    // Poll the background job every 2 s and update the UI until the job
    // reaches a terminal state ('completed' or 'failed').
    if (jobPollingInterval) clearInterval(jobPollingInterval);
    jobPollingInterval = setInterval(async () => {
        if (!currentJobId) return;
        try {
            const res = await fetch(`${API_BASE}/api/dataset/job/${currentJobId}`);
            const data = await res.json();
            console.log('Polling response:', data);

            if (data.status === 'completed') {
                clearInterval(jobPollingInterval);
                jobPollingInterval = null;
                currentJobId = null;
                clearSavedJobId();
                // Fix: re-enable the evaluate button on success too — previously
                // only the 'failed' branch did this, leaving the button stuck
                // disabled after a successful run.
                evaluateDatasetBtn.disabled = false;
                showEvaluationResults(data.results);
                loadDatasetHistory();
                showToast('Evaluation complete!', 'success');
            } else if (data.status === 'failed') {
                clearInterval(jobPollingInterval);
                jobPollingInterval = null;
                currentJobId = null;
                clearSavedJobId();
                datasetLoading.style.display = 'none';
                evaluateDatasetBtn.disabled = false;
                showToast('Evaluation failed: ' + data.error);
            } else {
                // Still running: surface backend progress when reported,
                // otherwise fall back to a generic "evaluating" message.
                const prog = data.progress;
                if (prog) {
                    datasetLoadingText.textContent = `${prog.stage === 'downloading' ? 'Downloading' : prog.stage === 'evaluating' ? 'Evaluating' : 'Processing'}: ${prog.percent}%`;
                } else {
                    datasetLoadingText.textContent = 'Evaluating... ' + (data.started_at ? new Date(data.started_at).toLocaleTimeString() : '');
                }
            }
        } catch (e) {
            // Transient network errors are tolerated; the next tick retries.
            console.error('Polling error:', e);
        }
    }, 2000);
}
|
| 1724 |
+
|
| 1725 |
+
// Ask the backend to detect the dataset's structure and report how many
// assistant responses it can extract; on success, unlock the evaluate button.
async function checkDatasetFormat() {
    const datasetId = datasetIdInput.value.trim();
    if (!datasetId) {
        showToast('Please enter a dataset ID');
        return;
    }

    checkDatasetBtn.disabled = true;
    checkDatasetBtn.innerHTML = '<span class="loading"></span>';

    const customFormat = buildFormatString();

    try {
        const response = await fetch(`${API_BASE}/api/dataset/info`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                dataset_id: datasetId,
                max_samples: parseInt(maxSamplesInput.value) || 1000,
                custom_format: customFormat
            })
        });
        const info = await response.json();

        currentDatasetInfo = info;

        // A format can be recognized yet still yield zero usable texts.
        const detectedButEmpty = info.supported && info.extracted_count === 0;

        if (info.supported && !detectedButEmpty) {
            formatName.textContent = info.format_name || info.format || 'Unknown';
            formatDescription.textContent = info.format_description || '';
            totalRowsEl.textContent = info.total_rows?.toLocaleString() || '-';
            extractedCountEl.textContent = info.extracted_count?.toLocaleString() || '-';
            formatError.style.display = 'none';
            evaluateDatasetBtn.disabled = false;
        } else {
            if (detectedButEmpty) {
                formatName.textContent = info.format_name || info.format || 'Unknown';
                formatDescription.textContent = 'Format detected but no valid assistant responses found. Try a custom format below.';
                totalRowsEl.textContent = info.total_rows?.toLocaleString() || '-';
                extractedCountEl.textContent = '0';
                formatError.style.display = 'block';
                formatError.textContent = 'No valid assistant responses extracted (minimum 50 chars required). The detected format may not match the actual data structure.';
            } else {
                formatName.textContent = 'Unsupported Format';
                formatDescription.textContent = '';
                totalRowsEl.textContent = '-';
                extractedCountEl.textContent = '-';
                formatError.style.display = 'block';
                formatError.textContent = info.error || 'Unknown error';
            }
            evaluateDatasetBtn.disabled = true;

            // Nudge the user toward the manual format controls.
            useCustomFormatCheckbox.checked = true;
            customFormatSection.style.display = 'block';
            showToast('Could not extract responses. Please specify a custom format below.');
        }

        datasetFormatInfo.style.display = 'block';
        datasetResults.style.display = 'none';

    } catch (e) {
        showToast('Error: ' + e.message);
    } finally {
        checkDatasetBtn.disabled = false;
        checkDatasetBtn.textContent = 'Check Format';
    }
}
|
| 1793 |
+
|
| 1794 |
+
// Kick off a background evaluation job for the chosen dataset and start
// polling for its progress. Requires a prior successful format check.
async function evaluateDataset() {
    const datasetId = datasetIdInput.value.trim();
    if (!datasetId || !currentDatasetInfo?.supported) return;

    evaluateDatasetBtn.disabled = true;
    datasetLoading.style.display = 'block';
    datasetResults.style.display = 'none';
    datasetLoadingText.textContent = 'Starting evaluation...';

    const apiKey = getApiKey();
    const customFormat = buildFormatString();

    try {
        const response = await fetch(`${API_BASE}/api/dataset/evaluate`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                dataset_id: datasetId,
                max_samples: parseInt(maxSamplesInput.value) || 1000,
                api_key: apiKey || null,
                custom_format: customFormat
            })
        });
        const payload = await response.json();
        console.log('Evaluate response:', payload);

        if (payload.error) {
            showToast(payload.error);
            datasetLoading.style.display = 'none';
            evaluateDatasetBtn.disabled = false;
            return;
        }

        currentJobId = payload.job_id;
        saveJobId(currentJobId);
        console.log('Job ID saved:', currentJobId);
        datasetLoadingText.textContent = 'Evaluation started. Processing in background...';

        // Show info that user can close the page
        const loadingEl = document.getElementById('datasetLoading');
        loadingEl.querySelectorAll('.close-page-msg').forEach(el => el.remove());
        const closePageMsg = document.createElement('div');
        closePageMsg.className = 'close-page-msg';
        closePageMsg.style.cssText = 'margin-top:1rem;color:var(--text-muted);font-size:0.85rem;';
        closePageMsg.innerHTML = '✓ You can close this page — evaluation will continue in the background.';
        loadingEl.appendChild(closePageMsg);

        startJobPolling();
        loadDatasetHistory();

    } catch (e) {
        showToast('Error: ' + e.message);
        datasetLoading.style.display = 'none';
        evaluateDatasetBtn.disabled = false;
    }
}
|
| 1850 |
+
|
| 1851 |
+
// Wire up the dataset-evaluation controls.
checkDatasetBtn.addEventListener('click', checkDatasetFormat);
evaluateDatasetBtn.addEventListener('click', evaluateDataset);

// Enter in the dataset-id field triggers a format check.
datasetIdInput.addEventListener('keydown', (e) => {
    if (e.key === 'Enter') checkDatasetFormat();
});

// "Clear history" asks for confirmation, then wipes server-side results
// and any locally saved job id.
document.getElementById('clearHistoryBtn').addEventListener('click', async () => {
    if (!confirm('Clear all dataset evaluation history?')) return;
    try {
        const response = await fetch(`${API_BASE}/api/datasets/clear`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ api_key: getApiKey() })
        });
        const payload = await response.json();
        if (payload.error) {
            showToast(payload.error);
        } else {
            clearSavedJobId();
            showToast(`Cleared ${payload.cleared} datasets`, 'success');
            loadDatasetHistory();
        }
    } catch (e) {
        showToast('Error: ' + e.message);
    }
});

// Populate the history panel on initial load.
loadDatasetHistory();
|
| 1880 |
+
|
| 1881 |
+
// Load supported formats
|
| 1882 |
+
// Fetch the list of dataset formats the backend can parse and render them
// as cards inside the supported-formats panel.
async function loadSupportedFormats() {
    try {
        const res = await fetch(`${API_BASE}/api/dataset/formats`);
        const data = await res.json();
        const cards = data.formats.map(f => `
            <div style="background:var(--bg-tertiary);border:1px solid var(--border);border-radius:8px;padding:0.75rem;">
                <div style="font-weight:500;font-size:0.85rem;">${f.name}</div>
                <div style="font-size:0.75rem;color:var(--text-muted);margin-top:0.25rem;">${f.description}</div>
            </div>
        `);
        document.getElementById('supportedFormatsList').innerHTML = cards.join('');
    } catch (e) {
        console.error('Failed to load formats:', e);
    }
}
|
| 1897 |
+
|
| 1898 |
+
// ── Custom Format UI Handling ──
|
| 1899 |
+
// Cached references to the custom-format controls.
const useCustomFormatCheckbox = document.getElementById('useCustomFormat');
const customFormatSection = document.getElementById('customFormatSection');
const formatPreview = document.getElementById('formatPreview');

// Per-mode input containers (toggled by the format-type radio buttons).
const columnInput = document.getElementById('columnInput');
const twoColumnInput = document.getElementById('twoColumnInput');
const patternInput = document.getElementById('patternInput');

// Text fields feeding buildFormatString().
const customColumnName = document.getElementById('customColumnName');
const customUserColumn = document.getElementById('customUserColumn');
const customAssistantColumn = document.getElementById('customAssistantColumn');
const customPattern = document.getElementById('customPattern');
|
| 1909 |
+
|
| 1910 |
+
// Translate the custom-format UI state into the backend's format-string
// syntax. Returns null whenever auto-detection should be used instead.
function buildFormatString() {
    if (!useCustomFormatCheckbox.checked) return null;

    const selected = document.querySelector('input[name="customFormatType"]:checked');
    const formatType = selected?.value || 'auto';

    switch (formatType) {
        case 'column': {
            const name = customColumnName.value.trim();
            return name ? `column: ${name}` : null;
        }
        case 'two_column': {
            const user = customUserColumn.value.trim();
            const assistant = customAssistantColumn.value.trim();
            if (!assistant) return null;
            return user ? `column: ${user}, column: ${assistant}` : `column: ${assistant}`;
        }
        case 'pattern': {
            const raw = customPattern.value.trim();
            if (!raw) return null;

            // Already in full marker syntax: pass through unchanged.
            if (raw.includes('[startuser]') && raw.includes('[startassistant]')) {
                return raw;
            }

            // Two whitespace-separated tokens -> user/assistant patterns;
            // a single token is treated as a column name.
            const tokens = raw.split(/\s+/);
            if (tokens.length >= 2) {
                return `pattern: ${tokens[0]}, pattern: ${tokens[1]}`;
            }
            return `column: ${raw}`;
        }
        default:
            // 'auto' or anything unexpected falls back to auto-detection.
            return null;
    }
}
|
| 1948 |
+
|
| 1949 |
+
// Mirror the effective format string into the live preview element.
function updateFormatPreview() {
    const spec = buildFormatString();
    if (spec) {
        formatPreview.textContent = spec;
        formatPreview.style.color = 'var(--accent)';
    } else {
        formatPreview.textContent = '(auto-detect)';
        formatPreview.style.color = 'var(--text-muted)';
    }
}
|
| 1954 |
+
|
| 1955 |
+
// Toggle the custom-format panel and keep the preview in sync.
useCustomFormatCheckbox?.addEventListener('change', () => {
    customFormatSection.style.display = useCustomFormatCheckbox.checked ? 'block' : 'none';
    updateFormatPreview();
});

// Show only the input group matching the selected format type.
document.querySelectorAll('input[name="customFormatType"]').forEach(radio => {
    radio.addEventListener('change', (e) => {
        const mode = e.target.value;
        columnInput.style.display = mode === 'column' ? 'block' : 'none';
        twoColumnInput.style.display = mode === 'two_column' ? 'block' : 'none';
        patternInput.style.display = mode === 'pattern' ? 'block' : 'none';
        updateFormatPreview();
    });
});

// Any edit to a custom-format field refreshes the preview.
for (const field of [customColumnName, customUserColumn, customAssistantColumn, customPattern]) {
    field?.addEventListener('input', updateFormatPreview);
}

checkStatus();
|
| 1974 |
+
|
| 1975 |
+
// On page load, resume tracking a job saved from a previous session:
// re-attach polling if it is still running, or show its results if it
// already finished; stale/broken job ids are cleared.
async function restoreJobState() {
    const savedJobId = getSavedJobId();
    if (!savedJobId) return;
    console.log('Restoring job state, savedJobId:', savedJobId);
    try {
        const res = await fetch(`${API_BASE}/api/dataset/job/${savedJobId}`);
        const data = await res.json();
        console.log('Job data:', data);

        if (data.status === 'running' || data.status === 'pending') {
            currentJobId = savedJobId;
            datasetIdInput.value = data.dataset_id || '';
            datasetLoading.style.display = 'block';
            evaluateDatasetBtn.disabled = true;

            const prog = data.progress;
            console.log('Progress:', prog);
            if (prog) {
                const stageLabel = prog.stage === 'downloading' ? 'Downloading'
                    : prog.stage === 'evaluating' ? 'Evaluating'
                    : 'Processing';
                datasetLoadingText.textContent = `${stageLabel}: ${prog.percent}%`;
            } else {
                datasetLoadingText.textContent = 'Starting evaluation...';
            }

            startJobPolling();
        } else if (data.status === 'completed') {
            clearSavedJobId();
            showEvaluationResults(data.results);
        } else if (data.status === 'failed') {
            clearSavedJobId();
        }
    } catch (e) {
        console.error('Restore error:', e);
        clearSavedJobId();
    }
}
restoreJobState();
|
| 2011 |
</script>
|
| 2012 |
</body>
|
| 2013 |
</html>
|