Spaces:
Sleeping
Sleeping
| """Unified Claude API scoring for both AI/ML and security domains.""" | |
| import json | |
| import logging | |
| import re | |
| import time | |
| import anthropic | |
| log = logging.getLogger(__name__) | |
| # Import the module so we always read live config values (not stale copies). | |
| import src.config as config | |
| from src.config import SECURITY_LLM_RE | |
| from src.db import get_unscored_papers, update_paper_scores | |
def score_run(run_id: int, domain: str) -> int:
    """Score all unscored papers in a run.

    Papers come from ``get_unscored_papers(run_id)``, are sent to Claude in
    batches of ``config.BATCH_SIZE`` using ``config.SCORING_MODEL``, and the
    returned scores are written back through ``_apply_scores``. A batch for
    which Claude yields no usable scores is skipped; the remaining batches
    are still processed.

    Args:
        run_id: Identifier of the run whose unscored papers are processed.
        domain: Key into ``config.SCORING_CONFIGS``. ``"aiml"`` selects
            ``MAX_ABSTRACT_CHARS_AIML``; any other value selects
            ``MAX_ABSTRACT_CHARS_SECURITY``.

    Returns:
        Number of papers whose scores were applied. ``0`` when the API key
        is not configured or the run has no unscored papers.

    Raises:
        KeyError: If ``domain`` has no entry in ``config.SCORING_CONFIGS``.
    """
    if not config.ANTHROPIC_API_KEY:
        log.warning("ANTHROPIC_API_KEY not set — skipping scoring")
        return 0

    scoring_model = config.SCORING_MODEL
    # A batch size of 0 would make range() raise ValueError and a negative
    # one would silently score nothing, so always send at least one paper.
    batch_size = max(1, config.BATCH_SIZE)
    scoring_config = config.SCORING_CONFIGS[domain]

    papers = get_unscored_papers(run_id)
    if not papers:
        log.info("No unscored papers for run %d", run_id)
        return 0

    log.info("Scoring %d %s papers with %s ...", len(papers), domain, scoring_model)
    client = anthropic.Anthropic(timeout=120.0)
    max_chars = config.MAX_ABSTRACT_CHARS_AIML if domain == "aiml" else config.MAX_ABSTRACT_CHARS_SECURITY

    # Ceiling division; invariant across the loop, so computed once.
    total_batches = (len(papers) + batch_size - 1) // batch_size
    scored_count = 0
    t0 = time.monotonic()

    for batch_num, start in enumerate(range(0, len(papers), batch_size), start=1):
        batch = papers[start : start + batch_size]
        log.info("Batch %d/%d (%d papers) ...", batch_num, total_batches, len(batch))

        user_content = _build_batch_content(batch, domain, max_chars)
        scores = _call_claude(client, scoring_config["prompt"], user_content, model=scoring_model)
        if not scores:
            # _call_claude has already logged the failure; move on to the next batch.
            continue

        # Map scores back to papers and update the DB.
        scored_count += _apply_scores(batch, scores, domain, scoring_config)

    elapsed = time.monotonic() - t0
    log.info("Scored %d/%d papers with %s in %.0fs", scored_count, len(papers), scoring_model, elapsed)
    return scored_count
def _build_batch_content(papers: list[dict], domain: str, max_chars: int) -> str:
    """Render a batch of papers as the text block used as Claude's user message.

    Each paper becomes a record that starts with ``---``, followed by
    ``key: value`` lines (ID, title, authors, categories, code URL,
    domain-specific fields, abstract, comment) and a trailing blank line.

    Fields that are missing *or* explicitly ``None`` fall back to neutral
    placeholders (empty string, ``0``, ``unknown``, ``N/A``) so that neither
    a crash nor the literal text ``None`` reaches the prompt.

    Side effect: in the security domain, a paper dict without an
    ``llm_adjacent`` key gets one computed from ``SECURITY_LLM_RE`` over its
    title and full abstract, stored back on the dict.

    Args:
        papers: Paper dicts to render.
        domain: ``"security"`` emits ``entry_id`` and ``llm_adjacent``;
            ``"aiml"`` emits ``arxiv_id`` and the Hugging Face fields.
        max_chars: Maximum number of abstract characters included per paper.

    Returns:
        The newline-joined records; an empty string for an empty batch.
    """
    is_security = domain == "security"
    lines = []
    for p in papers:
        title = p.get("title") or ""
        full_abstract = p.get("abstract") or ""

        lines.append("---")
        if is_security:
            # Security entries may lack an explicit entry_id; fall back to the URL, then the arXiv ID.
            id_field = p.get("entry_id") or p.get("arxiv_url") or p.get("arxiv_id") or ""
            lines.append(f"entry_id: {id_field}")
        else:
            lines.append(f"arxiv_id: {p.get('arxiv_id') or ''}")

        # Both fields may arrive as a pre-joined string, a list, or None.
        authors = p.get("authors") or []
        if isinstance(authors, str):
            authors_str = authors
        else:
            authors_str = ", ".join(str(a) for a in authors[:5])

        cats = p.get("categories") or []
        if isinstance(cats, str):
            cats_str = cats
        else:
            cats_str = ", ".join(str(c) for c in cats)

        lines.append(f"title: {title}")
        lines.append(f"authors: {authors_str}")
        lines.append(f"categories: {cats_str}")

        code_url = p.get("github_repo") or p.get("code_url") or "none found"
        lines.append(f"code_url_found: {code_url}")

        if is_security:
            if "llm_adjacent" not in p:
                # Cache the flag on the paper so it is computed at most once.
                p["llm_adjacent"] = bool(SECURITY_LLM_RE.search(f"{title} {full_abstract}"))
            lines.append(f"llm_adjacent: {str(p['llm_adjacent']).lower()}")

        if domain == "aiml":
            lines.append(f"hf_upvotes: {p.get('hf_upvotes') or 0}")
            hf_models = p.get("hf_models") or []
            if hf_models:
                model_ids = [m["id"] if isinstance(m, dict) else str(m) for m in hf_models[:3]]
                lines.append(f"hf_models: {', '.join(model_ids)}")
            hf_spaces = p.get("hf_spaces") or []
            if hf_spaces:
                space_ids = [s["id"] if isinstance(s, dict) else str(s) for s in hf_spaces[:3]]
                lines.append(f"hf_spaces: {', '.join(space_ids)}")
            lines.append(f"source: {p.get('source') or 'unknown'}")

        lines.append(f"abstract: {full_abstract[:max_chars]}")
        lines.append(f"comment: {p.get('comment') or 'N/A'}")
        lines.append("")
    return "\n".join(lines)
def _call_claude(client: anthropic.Anthropic, system_prompt: str, user_content: str, *, model: str) -> list[dict]:
    """Call the Claude Messages API and extract the JSON array of scores.

    The request is attempted up to three times, sleeping 2s and then 4s
    between attempts. An attempt fails when the API raises, when the reply
    contains no ``[...]`` span, or when that span is not valid JSON.

    Args:
        client: Anthropic client used to send the request.
        system_prompt: Scoring instructions passed as the system prompt.
        user_content: Rendered batch of papers, sent as the single user message.
        model: Model identifier to use for this call.

    Returns:
        The dict entries of the parsed JSON array (non-dict elements are
        dropped), or an empty list when every attempt failed.
    """
    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            response = client.messages.create(
                model=model,
                max_tokens=4096,
                system=system_prompt,
                messages=[{"role": "user", "content": user_content}],
            )
            if getattr(response, "stop_reason", None) == "max_tokens":
                # A truncated reply usually means the JSON array is cut off mid-way.
                log.warning("Response hit max_tokens and may be truncated (attempt %d)", attempt)

            # Concatenate every text block rather than assuming content[0]
            # exists and is a text block (it may be empty or a non-text block).
            text = "".join(getattr(part, "text", None) or "" for part in (response.content or []))

            # Greedy match: from the first "[" to the last "]" in the reply.
            json_match = re.search(r"\[.*\]", text, re.DOTALL)
            if json_match:
                parsed = json.loads(json_match.group())
                # Callers look entries up with .get(), so keep only JSON objects.
                return [entry for entry in parsed if isinstance(entry, dict)]
            log.warning("No JSON array in response (attempt %d)", attempt)
        except (anthropic.APIError, json.JSONDecodeError) as e:
            log.error("Scoring API error (attempt %d): %s", attempt, e)

        if attempt < max_attempts:
            # Exponential backoff: 2s after the first failure, 4s after the second.
            time.sleep(2 ** attempt)
        else:
            log.error("Skipping batch after 3 failures")
    return []
def _normalize_score_id(value) -> str:
    """Return ``value`` as a stripped string, or ``""`` when it is missing/empty."""
    if not value:
        return ""
    return str(value).strip()


def _axis_number(raw):
    """Return ``raw`` unchanged when numeric, parsed when numeric text, else ``0``."""
    if isinstance(raw, (int, float)):
        return raw
    try:
        return float(raw)
    except (TypeError, ValueError):
        return 0


def _apply_scores(papers: list[dict], scores: list[dict], domain: str, config: dict) -> int:
    """Write Claude's scores for a batch of papers to the DB.

    Scores are matched to papers by ID: ``entry_id`` in the response for the
    security domain, ``arxiv_id`` otherwise. For a security paper the ID is
    ``entry_id``, then ``arxiv_url``, then ``arxiv_id`` - the same fallback
    chain used when the prompt is built. Entries or papers without a usable
    ID are never matched, so an ID-less score cannot land on an ID-less paper.

    The composite is the sum of each axis score multiplied by the weight at
    the same position in ``config["weights"]`` (dict order), rounded to two
    decimals. Axis values that are missing or not numeric count as ``0``.

    Args:
        papers: Paper dicts of the batch; each matched paper must have ``"id"``.
        scores: Entries parsed from Claude's response; non-dict entries are ignored.
        domain: ``"security"`` or any other value (treated as arXiv-keyed).
        config: Scoring config for the domain, providing ``"axes"`` and ``"weights"``.

    Returns:
        Number of papers for which ``update_paper_scores`` was called.
    """
    axes = config["axes"]
    weights = config["weights"]
    weight_values = list(weights.values())
    is_security = domain == "security"
    response_id_field = "entry_id" if is_security else "arxiv_id"

    # Build lookup by ID; a later duplicate ID overrides an earlier one.
    score_map = {}
    for entry in scores or []:
        if not isinstance(entry, dict):
            continue
        entry_key = _normalize_score_id(entry.get(response_id_field))
        if entry_key:
            score_map[entry_key] = entry

    applied = 0
    for paper in papers:
        if is_security:
            raw_key = paper.get("entry_id") or paper.get("arxiv_url") or paper.get("arxiv_id")
        else:
            raw_key = paper.get("arxiv_id")
        score = score_map.get(_normalize_score_id(raw_key))
        if not score:
            continue

        # Extract axis scores, tolerating null or string values from the model.
        axis_scores = [_axis_number(score.get(ax, 0)) for ax in axes]
        composite = sum(s * w for s, w in zip(axis_scores, weight_values))

        update_paper_scores(paper["id"], {
            "score_axis_1": axis_scores[0] if len(axis_scores) > 0 else None,
            "score_axis_2": axis_scores[1] if len(axis_scores) > 1 else None,
            "score_axis_3": axis_scores[2] if len(axis_scores) > 2 else None,
            "composite": round(composite, 2),
            "summary": score.get("summary", ""),
            "reasoning": score.get("reasoning", ""),
            "code_url": score.get("code_url"),
        })
        applied += 1
    return applied
def rescore_top(run_id: int, domain: str, n: int = 0) -> int:
    """Run a second scoring pass over a run's papers with the rescore model.

    The papers returned by ``get_top_papers(domain, run_id=run_id, limit=n)``
    are sent to ``config.RESCORE_MODEL`` in a single request and the results
    are written back with ``_apply_scores``.

    Nothing is done (and ``0`` is returned) when the effective ``n`` is not
    positive, when no API key is configured, when the rescore model equals
    the scoring model, when there are no papers, or when Claude returns no
    scores.

    Args:
        run_id: Identifier of the run to re-score.
        domain: Key into ``config.SCORING_CONFIGS``.
        n: How many papers to re-score; ``0`` means use ``config.RESCORE_TOP_N``.

    Returns:
        Count of re-scored papers.
    """
    strong_model = config.RESCORE_MODEL
    base_model = config.SCORING_MODEL

    if not n:
        n = config.RESCORE_TOP_N
    if n <= 0:
        return 0

    if not config.ANTHROPIC_API_KEY:
        log.warning("ANTHROPIC_API_KEY not set — skipping re-scoring")
        return 0

    if strong_model == base_model:
        # A second pass with the same model would only repeat the first one.
        log.info("Rescore model same as scoring model — skipping re-score")
        return 0

    from src.db import get_top_papers

    domain_cfg = config.SCORING_CONFIGS[domain]
    candidates = get_top_papers(domain, run_id=run_id, limit=n)
    if not candidates:
        log.info("No papers to re-score for run %d", run_id)
        return 0

    log.info("Re-scoring top %d %s papers with %s ...", len(candidates), domain, strong_model)
    client = anthropic.Anthropic(timeout=120.0)

    if domain == "aiml":
        abstract_limit = config.MAX_ABSTRACT_CHARS_AIML
    else:
        abstract_limit = config.MAX_ABSTRACT_CHARS_SECURITY

    started = time.monotonic()
    prompt_body = _build_batch_content(candidates, domain, abstract_limit)
    results = _call_claude(client, domain_cfg["prompt"], prompt_body, model=strong_model)
    if not results:
        log.warning("Re-scoring returned no results")
        return 0

    rescored = _apply_scores(candidates, results, domain, domain_cfg)
    log.info(
        "Re-scored %d/%d papers with %s in %.0fs",
        rescored,
        len(candidates),
        strong_model,
        time.monotonic() - started,
    )
    return rescored