""" AIFinder Dataset Evaluator with Server Runs the Flask server, then allows interactive dataset input. """ import os import sys import time import argparse import random import threading import requests from collections import defaultdict from datasets import load_dataset from tqdm import tqdm from config import MODEL_DIR from inference import AIFinder HF_TOKEN = os.environ.get("HF_TOKEN") SERVER_URL = "http://localhost:7860" def start_server(): """Start Flask server in background thread.""" os.chdir(os.path.dirname(os.path.abspath(__file__))) from app import app, load_models load_models() print("Server started on http://localhost:7860") app.run(host="0.0.0.0", port=7860, debug=False, use_reloader=False) def wait_for_server(timeout=30): """Wait for server to be ready.""" start = time.time() while time.time() - start < timeout: try: resp = requests.get(f"{SERVER_URL}/api/status", timeout=2) if resp.status_code == 200: return True except requests.exceptions.RequestException: pass time.sleep(1) return False def _parse_msg(msg): """Parse a message that may be a dict or a JSON string.""" if isinstance(msg, dict): return msg if isinstance(msg, str): try: import json parsed = json.loads(msg) if isinstance(parsed, dict): return parsed except (ValueError, Exception): pass return {} def _extract_response_only(content): """Extract only the final response, stripping CoT blocks.""" import re if not content: return "" think_match = re.search(r"(.*)$", content, re.DOTALL) if think_match: response = think_match.group(1).strip() if response: return response return content def extract_texts_from_dataset(dataset_id, max_samples=None): """Extract assistant response texts from a HuggingFace dataset.""" print(f"\nLoading dataset: {dataset_id}") load_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {} rows = [] try: ds = load_dataset(dataset_id, split="train", **load_kwargs) rows = list(ds) except Exception as e: print(f"Failed to load dataset: {e}") try: import pandas as pd url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet" df = pd.read_parquet(url) rows = df.to_dict(orient="records") except Exception as e2: print(f"Parquet fallback also failed: {e2}") return [] texts = [] for row in rows: convos = row.get("conversations") or row.get("messages") or [] if not convos: continue for msg in convos: msg = _parse_msg(msg) role = msg.get("role", "") content = msg.get("content", "") if role in ("assistant", "gpt", "model") and content: response_only = _extract_response_only(content) if response_only and len(response_only) > 50: texts.append(response_only) if max_samples and len(texts) > max_samples: random.seed(42) texts = random.sample(texts, max_samples) return texts def evaluate_dataset(texts): """Evaluate all texts via API and aggregate results.""" results = { "total": len(texts), "provider_counts": defaultdict(int), "confidences": defaultdict(list), } for text in tqdm(texts, desc="Evaluating"): try: resp = requests.post( f"{SERVER_URL}/api/classify", json={"text": text, "top_n": 5}, timeout=30, ) if resp.status_code == 200: result = resp.json() pred_provider = result.get("provider") confidence = result.get("confidence", 0) / 100.0 if pred_provider: results["provider_counts"][pred_provider] += 1 results["confidences"][pred_provider].append(confidence) except Exception as e: print(f"Error: {e}") continue return results def print_results(results): """Print aggregated evaluation results.""" total = results["total"] print("\n" + "=" * 60) print(f"EVALUATION RESULTS ({total} samples)") print("=" * 60) print("\n--- Predicted Provider Distribution ---") for provider, count in sorted( results["provider_counts"].items(), key=lambda x: -x[1] ): pct = (count / total) * 100 avg_conf = sum(results["confidences"][provider]) / len( results["confidences"][provider] ) print( f" {provider}: {count} ({pct:.1f}%) - Avg confidence: {avg_conf * 100:.1f}%" ) if results["confidences"]: print("\n--- Top Providers (by cumulative confidence) ---") provider_scores = {} for provider, confs in results["confidences"].items(): if confs: avg_conf = sum(confs) / len(confs) count = results["provider_counts"][provider] provider_scores[provider] = avg_conf * count for provider, score in sorted(provider_scores.items(), key=lambda x: -x[1])[:3]: print(f" {provider}: {score:.2f}") print("\n" + "=" * 60) def main(): parser = argparse.ArgumentParser( description="AIFinder Dataset Evaluator with Server" ) parser.add_argument( "--max-samples", type=int, default=None, help="Max samples to test" ) args = parser.parse_args() print("Starting AIFinder server...") server_thread = threading.Thread(target=start_server, daemon=True) server_thread.start() print("Waiting for server...") if not wait_for_server(): print("Server failed to start!") sys.exit(1) print("\n" + "=" * 60) print("AIFinder Server Ready!") print("=" * 60) print(f"Server running at: {SERVER_URL}") print("Enter a HuggingFace dataset ID to evaluate.") print("Examples: ianncity/Hunter-Alpha-SFT-300000x") print("Type 'quit' or 'exit' to stop.") print("=" * 60 + "\n") while True: try: dataset_id = input("Dataset ID: ").strip() if dataset_id.lower() in ("quit", "exit", "q"): print("Goodbye!") break if not dataset_id: continue texts = extract_texts_from_dataset(dataset_id, args.max_samples) if not texts: print("No valid texts found in dataset.") continue print(f"Testing {len(texts)} responses...") results = evaluate_dataset(texts) print_results(results) except KeyboardInterrupt: print("\nGoodbye!") break except Exception as e: print(f"Error: {e}") if __name__ == "__main__": main()