| """ |
| AIFinder Dataset Evaluator with Server |
| Runs the Flask server, then allows interactive dataset input. |
| """ |
|
|
| import os |
| import sys |
| import time |
| import argparse |
| import random |
| import threading |
| import requests |
| from collections import defaultdict |
|
|
| from datasets import load_dataset |
| from tqdm import tqdm |
|
|
| from config import MODEL_DIR |
| from inference import AIFinder |
|
|
|
|
# Optional Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# Base URL used for every HTTP request this script sends to the local server.
SERVER_URL = "http://localhost:7860"
|
|
|
|
def start_server():
    """Load the models and run the Flask app on port 7860.

    The working directory is switched to the folder containing this file
    before the ``app`` module is imported.  ``main`` runs this function in a
    daemon thread.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)

    # Imported lazily, only after the chdir above has happened.
    from app import app as flask_app, load_models as prepare_models

    prepare_models()
    # The banner is printed just before the server loop is entered.
    print("Server started on http://localhost:7860")
    flask_app.run(debug=False, use_reloader=False, host="0.0.0.0", port=7860)
|
|
|
|
def wait_for_server(timeout=30):
    """Poll the server's status endpoint until it answers or time runs out.

    Args:
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as ``/api/status`` replies with HTTP 200, False when the
        deadline passes first.
    """
    # A monotonic clock never jumps when the system clock is adjusted, so the
    # timeout cannot be cut short or stretched by wall-clock corrections.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            resp = requests.get(f"{SERVER_URL}/api/status", timeout=2)
        except requests.exceptions.RequestException:
            # Request failed (e.g. server not listening yet); retry below.
            pass
        else:
            if resp.status_code == 200:
                return True
        time.sleep(1)
    return False
|
|
|
|
| def _parse_msg(msg): |
| """Parse a message that may be a dict or a JSON string.""" |
| if isinstance(msg, dict): |
| return msg |
| if isinstance(msg, str): |
| try: |
| import json |
|
|
| parsed = json.loads(msg) |
| if isinstance(parsed, dict): |
| return parsed |
| except (ValueError, Exception): |
| pass |
| return {} |
|
|
|
|
| def _extract_response_only(content): |
| """Extract only the final response, stripping CoT blocks.""" |
| import re |
|
|
| if not content: |
| return "" |
| think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL) |
| if think_match: |
| response = think_match.group(1).strip() |
| if response: |
| return response |
| return content |
|
|
|
|
def _load_dataset_rows(dataset_id):
    """Return the rows of the dataset's train split, or [] if loading fails."""
    load_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    try:
        return list(load_dataset(dataset_id, split="train", **load_kwargs))
    except Exception as e:
        print(f"Failed to load dataset: {e}")

    # Fallback: read the first parquet shard of the default config directly.
    try:
        import pandas as pd

        url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
        return pd.read_parquet(url).to_dict(orient="records")
    except Exception as e2:
        print(f"Parquet fallback also failed: {e2}")
        return []


def _row_messages(row):
    """Return the first non-empty message sequence stored in *row*, else []."""
    for key in ("conversations", "messages"):
        value = row.get(key)
        if value is None:
            continue
        try:
            # len() instead of truthiness: the pandas fallback can yield NumPy
            # arrays, whose truth value is ambiguous and raises ValueError.
            if len(value) > 0:
                return value
        except TypeError:
            # Unsized placeholder (e.g. a scalar) -- not a message list.
            continue
    return []


def extract_texts_from_dataset(dataset_id, max_samples=None):
    """Extract assistant response texts from a HuggingFace dataset.

    Rows are read from the ``train`` split (with a parquet fallback).  Each
    row's ``conversations`` or ``messages`` entry is scanned for messages
    whose role is ``assistant``, ``gpt`` or ``model``.  Both the
    ``role``/``content`` and the ShareGPT-style ``from``/``value`` key pairs
    are understood.  Chain-of-thought is stripped and only answers longer
    than 50 characters are kept.

    Args:
        dataset_id: HuggingFace dataset identifier.
        max_samples: If truthy and smaller than the number of extracted
            texts, a reproducible random subset of this size is returned.

    Returns:
        A list of response strings; empty when the dataset could not be
        loaded or contained no usable assistant messages.
    """
    print(f"\nLoading dataset: {dataset_id}")
    rows = _load_dataset_rows(dataset_id)

    assistant_roles = ("assistant", "gpt", "model")
    texts = []
    for row in rows:
        for msg in _row_messages(row):
            msg = _parse_msg(msg)
            role = msg.get("role") or msg.get("from") or ""
            content = msg.get("content") or msg.get("value") or ""

            # Non-string content cannot be fed to the text extractor.
            if role not in assistant_roles or not isinstance(content, str):
                continue

            response_only = _extract_response_only(content)
            if response_only and len(response_only) > 50:
                texts.append(response_only)

    if max_samples and len(texts) > max_samples:
        # A private generator gives the same seeded sample without reseeding
        # the process-wide random module.
        texts = random.Random(42).sample(texts, max_samples)

    return texts
|
|
|
|
def evaluate_dataset(texts):
    """Classify every text through the server API and aggregate the results.

    Args:
        texts: Sequence of response strings to classify.

    Returns:
        A dict with the keys
        ``"total"`` (number of input texts),
        ``"provider_counts"`` (predicted provider -> number of predictions),
        ``"confidences"`` (predicted provider -> list of confidences scaled
        to the 0-1 range) and
        ``"failed"`` (number of texts whose request raised or returned a
        non-200 status).
    """
    results = {
        "total": len(texts),
        "provider_counts": defaultdict(int),
        "confidences": defaultdict(list),
        "failed": 0,
    }

    # One Session for the whole run so the connection to the server is reused
    # instead of being re-established for every text.
    with requests.Session() as session:
        for text in tqdm(texts, desc="Evaluating"):
            try:
                resp = session.post(
                    f"{SERVER_URL}/api/classify",
                    json={"text": text, "top_n": 5},
                    timeout=30,
                )
                if resp.status_code != 200:
                    # Counted rather than dropped silently; summarised below.
                    results["failed"] += 1
                    continue

                result = resp.json()
                pred_provider = result.get("provider")
                # "or 0" also covers an explicit null confidence.
                confidence = (result.get("confidence") or 0) / 100.0
            except Exception as e:
                # Best effort: one bad request must not abort the whole run.
                results["failed"] += 1
                # tqdm.write keeps the progress bar intact while printing.
                tqdm.write(f"Error: {e}")
                continue

            if pred_provider:
                results["provider_counts"][pred_provider] += 1
                results["confidences"][pred_provider].append(confidence)

    if results["failed"]:
        print(f"{results['failed']} of {results['total']} requests failed.")

    return results
|
|
|
|
def print_results(results):
    """Write a human-readable summary of *results* to stdout.

    Prints a header with the sample total, one line per predicted provider
    (count, share of the total, mean confidence) ordered by count, and, when
    any confidences exist, the three providers with the highest
    mean-confidence-times-count score.
    """
    rule = "=" * 60
    sample_total = results["total"]
    counts = results["provider_counts"]
    confidences = results["confidences"]

    print("\n" + rule)
    print(f"EVALUATION RESULTS ({sample_total} samples)")
    print(rule)

    print("\n--- Predicted Provider Distribution ---")
    by_count = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    for provider, count in by_count:
        share = (count / sample_total) * 100
        provider_confs = confidences[provider]
        mean_conf = sum(provider_confs) / len(provider_confs)
        print(
            f" {provider}: {count} ({share:.1f}%) - Avg confidence: {mean_conf * 100:.1f}%"
        )

    if confidences:
        print("\n--- Top Providers (by cumulative confidence) ---")
        # Score = mean confidence multiplied by the number of predictions.
        scores = {
            provider: (sum(confs) / len(confs)) * counts[provider]
            for provider, confs in confidences.items()
            if confs
        }
        top_three = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:3]
        for provider, score in top_three:
            print(f" {provider}: {score:.2f}")

    print("\n" + rule)
|
|
|
|
def main():
    """Start the server in a daemon thread and run the interactive prompt.

    Command-line options:
        --max-samples N   Upper bound on the number of responses evaluated
                          per dataset (default: no limit).

    The prompt repeatedly asks for a HuggingFace dataset ID, evaluates it and
    prints the aggregated results.  It ends on ``quit``/``exit``/``q``,
    Ctrl-C, or end-of-input.  The process exits with status 1 if the server
    does not become ready in time.
    """
    parser = argparse.ArgumentParser(
        description="AIFinder Dataset Evaluator with Server"
    )
    parser.add_argument(
        "--max-samples", type=int, default=None, help="Max samples to test"
    )
    args = parser.parse_args()

    print("Starting AIFinder server...")
    # Daemon thread: the server dies together with the interactive prompt.
    server_thread = threading.Thread(target=start_server, daemon=True)
    server_thread.start()

    print("Waiting for server...")
    if not wait_for_server():
        print("Server failed to start!")
        sys.exit(1)

    print("\n" + "=" * 60)
    print("AIFinder Server Ready!")
    print("=" * 60)
    print(f"Server running at: {SERVER_URL}")
    print("Enter a HuggingFace dataset ID to evaluate.")
    print("Examples: ianncity/Hunter-Alpha-SFT-300000x")
    print("Type 'quit' or 'exit' to stop.")
    print("=" * 60 + "\n")

    while True:
        try:
            dataset_id = input("Dataset ID: ").strip()

            if dataset_id.lower() in ("quit", "exit", "q"):
                print("Goodbye!")
                break

            if not dataset_id:
                continue

            texts = extract_texts_from_dataset(dataset_id, args.max_samples)

            if not texts:
                print("No valid texts found in dataset.")
                continue

            print(f"Testing {len(texts)} responses...")
            results = evaluate_dataset(texts)
            print_results(results)

        except (KeyboardInterrupt, EOFError):
            # EOFError must end the loop too: once stdin is closed, input()
            # raises it on every call, and the generic handler below would
            # otherwise print "Error:" forever.
            print("\nGoodbye!")
            break
        except Exception as e:
            # Top-level boundary: report the problem and keep the prompt alive.
            print(f"Error: {e}")


if __name__ == "__main__":
    main()
|
|