| """ |
| AIFinder Flask API |
| Serves the trained sklearn ensemble via the AIFinder inference class. |
| """ |
|
|
| import os |
| import re |
| import shutil |
| import uuid |
| import threading |
| from collections import defaultdict |
| from datetime import datetime |
|
|
| import joblib |
| import numpy as np |
| from sklearn.ensemble import RandomForestClassifier |
| from flask import Flask, jsonify, request, send_from_directory, render_template |
| from flask_cors import CORS |
| from flask_limiter import Limiter |
| from flask_limiter.util import get_remote_address |
| from tqdm import tqdm |
|
|
| from config import MODEL_DIR |
| from inference import AIFinder |
|
|
| STYLE_MODEL_DIR = os.path.join(MODEL_DIR, "style") |
| from dataset_evaluator import load_dataset_texts, get_supported_formats |
|
|
# Flask app with CORS enabled for browser clients; rate limiting is keyed by
# the caller's remote address (see the per-route @limiter.limit decorators).
app = Flask(__name__)
CORS(app)
limiter = Limiter(get_remote_address, app=app)

# Populated by load_models(): `finder` is the base ensemble;
# `community_finder` additionally blends user corrections when available.
finder: AIFinder | None = None
community_finder: AIFinder | None = None
# When True (and a community model is loaded), requests are served by the
# community model instead of the base one — see _active_finder().
using_community = False

# Number of providers returned by /v1/classify when the caller omits top_n.
DEFAULT_TOP_N = 4
COMMUNITY_DIR = os.path.join(MODEL_DIR, "community")
# Raw user corrections and the RandomForest retrained from them.
CORRECTIONS_FILE = os.path.join(COMMUNITY_DIR, "corrections.joblib")
CORRECTION_MODEL_FILE = os.path.join(COMMUNITY_DIR, "correction_rf_4provider.joblib")
# Background dataset-evaluation jobs, persisted so they survive restarts.
JOBS_FILE = os.path.join(MODEL_DIR, "jobs.joblib")
corrections: list[dict] = []

# job_id -> job record (status, progress, results). NOTE(review): mutated by
# request handlers and background threads without a lock — confirm the app
# runs single-worker, or add locking.
jobs: dict[str, dict] = {}
|
|
|
|
def _copy_base_models_to_community():
    """Seed the community directory with the base model artifacts.

    Each artifact is copied from the style model directory only when it is
    not already present in the community directory, so previously retrained
    community files are never overwritten.
    """
    for name in (
        "rf_4provider.joblib",
        "pipeline_4provider.joblib",
        "enc_4provider.joblib",
    ):
        source = os.path.join(STYLE_MODEL_DIR, name)
        target = os.path.join(COMMUNITY_DIR, name)
        if not os.path.exists(source):
            continue
        if os.path.exists(target):
            continue
        shutil.copy(source, target)
|
|
|
|
def _update_job_progress(job_id, current, total, stage):
    """Record progress (counts, stage, percent) for a known job and persist it."""
    if job_id not in jobs:
        return
    percent = round(current / total * 100, 1) if total > 0 else 0
    jobs[job_id]["progress"] = {
        "current": current,
        "total": total,
        "stage": stage,
        "percent": percent,
    }
    _save_jobs()
|
|
|
|
def _save_jobs():
    """Persist the in-memory jobs dict to disk so it survives restarts."""
    joblib.dump(jobs, JOBS_FILE)
|
|
|
|
def _active_finder():
    """Return the finder serving requests: community when enabled and loaded, else base."""
    if using_community and community_finder:
        return community_finder
    return finder
|
|
|
|
| def _strip_think_tags(text): |
| text = re.sub(r"<think(?:ing)?>.*?</think(?:ing)?>", "", text, flags=re.DOTALL) |
| return text.strip() |
|
|
|
|
@app.route("/")
def index():
    """Serve the single-page UI."""
    return render_template("index.html")
|
|
|
|
@app.route("/api/classify", methods=["POST"])
@app.route("/v1/classify", methods=["POST"])
@limiter.limit("60/minute")
def v1_classify():
    """Classify a text's likely AI provider.

    Expects JSON {"text": str, "top_n": int (optional)}. Returns the
    best-guess provider plus the top-N candidates with confidences in
    percent. 400 on malformed input or too-short text, 503 when no model
    has been loaded yet.
    """
    data = request.get_json(silent=True)
    if not data or "text" not in data:
        return jsonify({"error": "Request body must be JSON with a 'text' field."}), 400

    af = _active_finder()
    if af is None:
        # Models load at startup via load_models(); guard against requests
        # that arrive before that completes (previously an unhandled 500).
        return jsonify({"error": "Model not loaded."}), 503

    # Validate top_n BEFORE clamping: a non-int value (e.g. a string) would
    # otherwise make min() raise TypeError and surface as a 500.
    top_n = data.get("top_n", DEFAULT_TOP_N)
    if not isinstance(top_n, int) or isinstance(top_n, bool) or top_n < 1:
        top_n = DEFAULT_TOP_N
    top_n = min(top_n, len(af.le.classes_))

    raw_text = data["text"]
    text = _strip_think_tags(raw_text)
    if len(text) < 20:
        return jsonify(
            {
                "error": "Text too short (minimum 20 characters after stripping think tags)."
            }
        ), 400

    proba = af.predict_proba(text)
    sorted_providers = sorted(proba.items(), key=lambda x: x[1], reverse=True)[:top_n]

    top_providers = [
        {"name": name, "confidence": round(float(conf * 100), 2)}
        for name, conf in sorted_providers
    ]

    return jsonify(
        {
            "provider": top_providers[0]["name"],
            "confidence": top_providers[0]["confidence"],
            "top_providers": top_providers,
        }
    )
|
|
|
|
@app.route("/api/correct", methods=["POST"])
def correct():
    """Accept a user correction and retrain the community correction model.

    Expects JSON {"text": str, "correct_provider": str}. The correction is
    appended to the persistent corrections list, a RandomForest is refit on
    all corrections, and the community finder is reloaded so subsequent
    requests use the updated model. 400 on bad input, 503 if the base model
    is not loaded.
    """
    global community_finder
    data = request.get_json(silent=True)
    if not data or "text" not in data or "correct_provider" not in data:
        return jsonify({"error": "Need 'text' and 'correct_provider'."}), 400

    if finder is None:
        # Without the base model there is no pipeline/encoder to retrain with.
        return jsonify({"error": "Model not loaded."}), 503

    provider = data["correct_provider"]
    if provider not in list(finder.le.classes_):
        return jsonify({"error": f"Unknown provider: {provider}"}), 400

    text = _strip_think_tags(data["text"])
    corrections.append({"text": text, "provider": provider})

    _copy_base_models_to_community()

    # Refit the correction model on every correction collected so far.
    # (The list is never empty here — we just appended to it, so the old
    # `if len(corrections) > 0` guard was dead code.)
    texts = [c["text"] for c in corrections]
    providers = [c["provider"] for c in corrections]
    X = finder.pipeline.transform(texts)
    y = finder.le.transform(providers)

    correction_rf = RandomForestClassifier(
        n_estimators=100, random_state=42, n_jobs=-1
    )
    correction_rf.fit(X, y)
    joblib.dump([correction_rf], CORRECTION_MODEL_FILE)

    joblib.dump(corrections, CORRECTIONS_FILE)

    # Reload via CommunityAIFinder so the freshly trained correction model is
    # actually blended in. A plain AIFinder (as before) ignored the
    # correction model until the next process restart — inconsistent with
    # load_models(), which constructs a CommunityAIFinder.
    community_finder = CommunityAIFinder(
        model_dir=COMMUNITY_DIR, correction_model_path=CORRECTION_MODEL_FILE
    )

    return jsonify({"status": "ok", "loss": 0.0, "corrections": len(corrections)})
|
|
|
|
@app.route("/api/save", methods=["POST"])
def save_model():
    """Report the downloadable community-model filename, if one exists."""
    if community_finder is None:
        return jsonify({"error": "No community model trained yet."}), 400
    return jsonify({"status": "ok", "filename": "community_rf_4provider.joblib"})
|
|
|
|
@app.route("/api/toggle_community", methods=["POST"])
def toggle_community():
    """Enable/disable the community model; with no 'enabled' key, flip the flag."""
    global using_community
    payload = request.get_json(silent=True) or {}
    requested = payload.get("enabled", not using_community)
    using_community = bool(requested)
    return jsonify(
        {"using_community": using_community, "available": community_finder is not None}
    )
|
|
|
|
@app.route("/models/<filename>")
def download_model(filename):
    """Serve a model artifact; 'community_'-prefixed names map to the community dir."""
    prefix = "community_"
    if filename.startswith(prefix):
        return send_from_directory(COMMUNITY_DIR, filename.replace(prefix, "", 1))
    return send_from_directory(MODEL_DIR, filename)
|
|
|
|
@app.route("/api/status", methods=["GET"])
def status():
    """Summarize model availability and community/correction state."""
    af = _active_finder()
    provider_names = list(af.le.classes_) if af else []
    return jsonify(
        {
            "loaded": af is not None,
            "device": "cpu",
            "providers": provider_names,
            "num_providers": len(provider_names),
            "using_community": using_community,
            "community_available": community_finder is not None,
            "corrections_count": len(corrections),
        }
    )
|
|
|
|
@app.route("/api/providers", methods=["GET"])
def providers():
    """List the provider labels known to the base model."""
    names = list(finder.le.classes_) if finder else []
    return jsonify({"providers": names})
|
|
|
|
@app.route("/api/dataset/info", methods=["POST"])
def dataset_info():
    """Get info about a dataset without evaluating.

    Expects JSON with 'dataset_id'; optional 'max_samples', 'evaluate',
    'api_key' and 'custom_format'. Samples a single row to detect the
    dataset format. When 'evaluate' is true and the format is supported, a
    background evaluation job is started and its job_id is returned.
    """
    data = request.get_json(silent=True)
    if not data or "dataset_id" not in data:
        return jsonify({"error": "Request must include 'dataset_id'"}), 400

    dataset_id = data["dataset_id"]
    max_samples = data.get("max_samples", 1000)
    evaluate = data.get("evaluate", False)
    api_key = data.get("api_key")
    custom_format = data.get("custom_format")

    # sample_size=1: only peek at the dataset to detect its format.
    result = load_dataset_texts(
        dataset_id, max_samples=max_samples, sample_size=1, custom_format=custom_format
    )

    response = {
        "dataset_id": dataset_id,
        "total_rows": result["total_rows"],
        "extracted_count": len(result["texts"]),
        "format": result["format"],
        "format_name": result["format_info"]["name"] if result["format_info"] else None,
        "format_description": result["format_info"]["description"]
        if result["format_info"]
        else None,
        "supported": result["supported"],
        "error": result["error"],
        "custom_format": custom_format,
    }

    if evaluate and result["supported"]:
        job_id = str(uuid.uuid4())
        jobs[job_id] = {
            "job_id": job_id,
            "dataset_id": dataset_id,
            "max_samples": max_samples,
            "status": "pending",
            "created_at": datetime.utcnow().isoformat(),
            "api_key": api_key,
            # Recorded here (matching /api/dataset/evaluate) so job listings
            # show the format immediately, not only after the worker starts.
            "custom_format": custom_format,
        }
        _save_jobs()

        thread = threading.Thread(
            target=_run_evaluation_job,
            args=(job_id, dataset_id, max_samples, api_key, custom_format),
        )
        thread.daemon = True
        thread.start()

        response["job_id"] = job_id
        response["status"] = "pending"
        response["message"] = "Evaluation started in background."
        # (The redundant re-assignment of response["custom_format"] that was
        # here duplicated the value already set above and has been removed.)

    return jsonify(response)
|
|
|
|
def _run_evaluation_job(
    job_id: str,
    dataset_id: str,
    max_samples: int,
    api_key: str | None,
    custom_format: str | None = None,
):
    """Background task to run dataset evaluation.

    Loads up to ``max_samples`` assistant responses from ``dataset_id``,
    classifies each with the active finder, aggregates per-provider counts
    and confidences, and writes the outcome into ``jobs[job_id]`` (status
    "completed" or "failed"). The job record is persisted after every
    state change.
    """
    jobs[job_id]["status"] = "running"
    jobs[job_id]["started_at"] = datetime.utcnow().isoformat()
    jobs[job_id]["custom_format"] = custom_format
    _save_jobs()

    # Forwards loader/evaluation progress into the job record.
    progress_cb = lambda c, t, s: _update_job_progress(job_id, c, t, s)

    try:
        load_result = load_dataset_texts(
            dataset_id,
            max_samples=max_samples,
            progress_callback=progress_cb,
            custom_format=custom_format,
        )

        # Unrecognized dataset format: record the loader's error and stop.
        if not load_result["supported"]:
            jobs[job_id].update(
                {
                    "status": "failed",
                    "error": load_result["error"],
                    "dataset_id": dataset_id,
                    "supported": False,
                    "completed_at": datetime.utcnow().isoformat(),
                }
            )
            _save_jobs()
            return

        # Supported format but nothing extracted: fail with a clear message.
        texts = load_result["texts"]
        if not texts:
            jobs[job_id].update(
                {
                    "status": "failed",
                    "error": "No valid assistant responses found in dataset",
                    "dataset_id": dataset_id,
                    "supported": True,
                    "extracted_count": 0,
                    "completed_at": datetime.utcnow().isoformat(),
                }
            )
            _save_jobs()
            return

        # Result skeleton; the per-provider sections are filled in below.
        results = {
            "dataset_id": dataset_id,
            "format": load_result["format"],
            "format_name": load_result["format_info"]["name"]
            if load_result["format_info"]
            else None,
            "total_rows": load_result["total_rows"],
            "extracted_count": len(texts),
            "provider_counts": {},
            "provider_confidences": {},
            "top_providers": {},
        }

        provider_counts = defaultdict(int)  # how often each provider was the top-1 pick
        provider_confidences = defaultdict(list)  # top-1 confidences, keyed by provider
        top_providers = defaultdict(int)  # appearances in any text's top-5 list

        af = _active_finder()

        total = len(texts)
        for i, text in enumerate(tqdm(texts, desc="Evaluating")):
            # Report progress every 10 texts and on the last one. progress_cb
            # is always set here, so the truthiness check is redundant but
            # harmless.
            if progress_cb and (i % 10 == 0 or i == total - 1):
                progress_cb(i + 1, total, "evaluating")
            try:
                proba = af.predict_proba(text)
                sorted_providers = sorted(
                    proba.items(), key=lambda x: x[1], reverse=True
                )

                pred_provider = sorted_providers[0][0]
                confidence = sorted_providers[0][1]

                provider_counts[pred_provider] += 1
                provider_confidences[pred_provider].append(confidence)

                for name, conf in sorted_providers[:5]:
                    top_providers[name] += 1
            except Exception:
                # Best-effort: skip any text the model fails on. Skipped
                # texts still count toward `total` in the stats below.
                continue

        total = len(texts)
        for provider, count in provider_counts.items():
            results["provider_counts"][provider] = {
                "count": count,
                "percentage": round((count / total) * 100, 2),
            }
            confs = provider_confidences[provider]
            avg_conf = sum(confs) / len(confs) if confs else 0
            results["provider_confidences"][provider] = {
                "average": round(avg_conf * 100, 2),
                # Average confidence scaled by win count — used below to pick
                # the overall "likely_provider".
                "cumulative": round(avg_conf * count, 2),
            }

        # The five providers that appeared most often in per-text top-5 lists.
        results["top_providers"] = dict(
            sorted(top_providers.items(), key=lambda x: -x[1])[:5]
        )

        sorted_by_cumulative = sorted(
            results["provider_confidences"].items(), key=lambda x: -x[1]["cumulative"]
        )
        results["likely_provider"] = (
            sorted_by_cumulative[0][0] if sorted_by_cumulative else None
        )
        # Mean top-1 confidence over ALL texts: texts skipped by the except
        # above contribute nothing to the numerator but remain in the
        # denominator — NOTE(review): confirm this dilution is intended.
        results["average_confidence"] = (
            round(sum(sum(c) for c in provider_confidences.values()) / total * 100, 2)
            if total > 0
            else 0
        )

        jobs[job_id].update(
            {
                "status": "completed",
                "results": results,
                "api_key": api_key,
                "completed_at": datetime.utcnow().isoformat(),
            }
        )
        _save_jobs()
    except Exception as e:
        # Any unexpected error (loader, model, persistence) marks the job
        # failed with the exception message.
        jobs[job_id].update(
            {
                "status": "failed",
                "error": str(e),
                "completed_at": datetime.utcnow().isoformat(),
            }
        )
        _save_jobs()
|
|
|
|
@app.route("/api/dataset/evaluate", methods=["POST"])
@limiter.limit("10/minute")
def dataset_evaluate():
    """Start a background job to evaluate a HuggingFace dataset."""
    payload = request.get_json(silent=True)
    if not payload or "dataset_id" not in payload:
        return jsonify({"error": "Request must include 'dataset_id'"}), 400

    dataset_id = payload["dataset_id"]
    max_samples = payload.get("max_samples", 1000)
    api_key = payload.get("api_key")
    custom_format = payload.get("custom_format")

    # Validate the dataset up front so obvious failures return immediately
    # instead of surfacing later as a failed job.
    preview = load_dataset_texts(
        dataset_id, max_samples=max_samples, custom_format=custom_format
    )

    if not preview["supported"]:
        error_body = {
            "error": preview["error"],
            "dataset_id": dataset_id,
            "supported": False,
        }
        return jsonify(error_body), 400

    if not preview["texts"]:
        error_body = {
            "error": "No valid assistant responses found in dataset",
            "dataset_id": dataset_id,
            "supported": True,
            "extracted_count": 0,
        }
        return jsonify(error_body), 400

    job_id = str(uuid.uuid4())
    jobs[job_id] = {
        "job_id": job_id,
        "dataset_id": dataset_id,
        "max_samples": max_samples,
        "status": "pending",
        "created_at": datetime.utcnow().isoformat(),
        "api_key": api_key,
        "custom_format": custom_format,
    }
    _save_jobs()

    worker = threading.Thread(
        target=_run_evaluation_job,
        args=(job_id, dataset_id, max_samples, api_key, custom_format),
        daemon=True,
    )
    worker.start()

    return jsonify(
        {
            "job_id": job_id,
            "status": "pending",
            "message": "Evaluation started. Use the job_id to check status later.",
            "custom_format": custom_format,
        }
    )
|
|
|
|
@app.route("/api/dataset/job/<job_id>", methods=["GET"])
def dataset_job_status(job_id):
    """Get the status and results of a dataset evaluation job."""
    job = jobs.get(job_id)
    if job is None:
        return jsonify({"error": "Job not found"}), 404

    response = {
        "job_id": job_id,
        "dataset_id": job.get("dataset_id"),
        "status": job["status"],
        "created_at": job.get("created_at"),
        "started_at": job.get("started_at"),
        "completed_at": job.get("completed_at"),
    }

    progress = job.get("progress")
    if progress:
        response["progress"] = progress

    state = job["status"]
    if state == "completed":
        response["results"] = job.get("results")
    elif state == "failed":
        response["error"] = job.get("error")

    return jsonify(response)
|
|
|
|
@app.route("/api/datasets", methods=["GET"])
def list_datasets():
    """List all evaluated datasets, optionally filtered by API key."""
    api_key = request.args.get("api_key")

    finished = []
    for jid, job in jobs.items():
        if api_key and job.get("api_key") != api_key:
            continue
        if job["status"] not in ("completed", "failed"):
            continue
        finished.append(
            {
                "job_id": jid,
                "dataset_id": job.get("dataset_id"),
                "status": job["status"],
                "created_at": job.get("created_at"),
                "completed_at": job.get("completed_at"),
                "error": job.get("error"),
                "custom_format": job.get("custom_format"),
            }
        )

    # Newest first.
    finished.sort(key=lambda entry: entry.get("created_at", ""), reverse=True)
    return jsonify({"datasets": finished})
|
|
|
|
@app.route("/api/datasets/clear", methods=["POST"])
def clear_datasets():
    """Clear all evaluated dataset history for the current API key."""
    payload = request.get_json(silent=True) or {}
    api_key = payload.get("api_key")
    if not api_key:
        return jsonify({"error": "API key required"}), 400

    # Collect first, then delete — never mutate a dict while iterating it.
    doomed = [
        jid
        for jid, job in jobs.items()
        if job.get("api_key") == api_key and job["status"] in ("completed", "failed")
    ]
    for jid in doomed:
        del jobs[jid]

    if doomed:
        _save_jobs()

    return jsonify({"status": "ok", "cleared": len(doomed)})
|
|
|
|
@app.route("/api/dataset/formats", methods=["GET"])
def dataset_formats():
    """Get list of supported dataset formats."""
    # Built-in formats reported by the evaluator, plus a synthetic
    # "Custom Format" entry describing the user-specification syntax.
    known = [
        {
            "name": info["name"],
            "description": info["description"],
            "examples": info["examples"],
        }
        for info in get_supported_formats().values()
    ]
    custom_entry = {
        "name": "Custom Format",
        "description": "Define your own format specification",
        "examples": [
            "column: response",
            "column: prompt, column: response",
            "pattern: user:, pattern: assistant:",
            "user:[startuser]assistant:[startassistant]",
        ],
    }
    help_section = {
        "description": "Specify custom format using these patterns:",
        "patterns": [
            "column: <field_name> - extract single field",
            "column: <user_field>, column: <assistant_field> - extract from two columns",
            "pattern: <regex> - use regex to extract",
            "user:[startuser]assistant:[startassistant] - pattern-based extraction",
        ],
        "examples": [
            {
                "input": "column: completion",
                "description": "Extract from 'completion' field",
            },
            {
                "input": "column: input, column: output",
                "description": "Extract from 'input' and 'output' columns",
            },
            {
                "input": "user:[INST]assistant:[/INST]",
                "description": "Extract text between markers",
            },
        ],
    }
    return jsonify(
        {
            "formats": known + [custom_entry],
            "custom_format_help": help_section,
        }
    )
|
|
|
|
class CommunityAIFinder(AIFinder):
    """Extended AIFinder that blends base model with correction model.

    The base ensemble's probabilities are mixed with those of a correction
    model trained on user-submitted corrections (when one has been saved),
    weighting the correction model more heavily.
    """

    def __init__(self, model_dir, correction_model_path=None):
        """Load the base models; optionally load the saved correction-model list."""
        super().__init__(model_dir)
        self.correction_models = None
        if correction_model_path and os.path.exists(correction_model_path):
            self.correction_models = joblib.load(correction_model_path)

    def _correction_proba(self, X, n_classes):
        """Average correction-model probabilities, expanded to all classes.

        A correction RandomForest may have been fit on only the subset of
        providers that appear in the corrections; its predict_proba columns
        then cover just those classes, so blending it directly against the
        base model's full-width output would misalign or fail to broadcast.
        Each model's ``classes_`` holds label-encoder indices — scatter its
        columns into a full-width array before averaging.
        """
        expanded = []
        for clf in self.correction_models:
            proba = np.zeros((X.shape[0], n_classes))
            proba[:, clf.classes_.astype(int)] = clf.predict_proba(X)
            expanded.append(proba)
        return np.mean(expanded, axis=0)

    def predict_proba(self, text):
        """Blend base model predictions with correction model if available.

        Returns a dict mapping each provider label to its probability.
        """
        X = self.pipeline.transform([text])

        base_proba = np.mean([m.predict_proba(X) for m in self.models], axis=0)

        if self.correction_models:
            n_classes = len(self.le.classes_)
            correction_proba = self._correction_proba(X, n_classes)

            # Correction model gets 70% of the vote.
            blend_weight = 0.7
            final_proba = (
                1 - blend_weight
            ) * base_proba + blend_weight * correction_proba
            # Renormalize so the blended row sums to 1.
            final_proba = final_proba / final_proba.sum(axis=1, keepdims=True)
        else:
            final_proba = base_proba

        return dict(zip(self.le.classes_, final_proba[0]))
|
|
|
|
def load_models():
    """Load the base finder, saved corrections/jobs, and the community model."""
    global finder, community_finder, corrections, jobs
    finder = AIFinder(model_dir=STYLE_MODEL_DIR)
    os.makedirs(COMMUNITY_DIR, exist_ok=True)
    _copy_base_models_to_community()
    if os.path.exists(CORRECTIONS_FILE):
        corrections = joblib.load(CORRECTIONS_FILE)
    if os.path.exists(JOBS_FILE):
        jobs = joblib.load(JOBS_FILE)
    community_rf = os.path.join(COMMUNITY_DIR, "rf_4provider.joblib")
    if os.path.exists(community_rf):
        try:
            community_finder = CommunityAIFinder(
                model_dir=COMMUNITY_DIR, correction_model_path=CORRECTION_MODEL_FILE
            )
        except Exception:
            # Community artifacts may be stale or corrupt; serve base only.
            community_finder = None
|
|
|
|
if __name__ == "__main__":
    # Load models synchronously before accepting traffic.
    print("Loading models...")
    load_models()
    provider_list = ", ".join(finder.le.classes_)
    print(f"Ready on cpu — {len(finder.le.classes_)} providers: {provider_list}")
    app.run(host="0.0.0.0", port=7860)
|
|