""" llm_chain.py — LangChain LCEL pipeline for multi-LLM code debugging. Flow: 1. Build one LCEL chain per panel model (PANEL_PROMPT | llm | StrOutputParser) 2. Run both chains concurrently (asyncio.gather + ainvoke) 3. Feed results into judge chain (JUDGE_PROMPT | llm | StrOutputParser) 4. Parse judge output into {error_analysis, corrected_code} 5. Return structured dict consumed by FastAPI """ import asyncio import re import time from typing import Any from langchain_core.output_parsers import StrOutputParser from app.config import settings from app.llm_factory import make_panel_llm, make_judge_llm from app.prompts import PANEL_PROMPT, JUDGE_PROMPT # ── Helpers ─────────────────────────────────────────────────────────────────── def _parse_judge_output(raw: str) -> tuple[str, str]: """ Parse judge response into (error_analysis, corrected_code). Splits on the exact section headers defined in JUDGE_PROMPT. """ error_analysis = "" corrected_code = "" if "## Corrected Code" in raw: parts = raw.split("##Corrected Code", 1) corrected_code = parts[1].strip() if "## Error Analysis" in parts[0]: error_analysis = parts[0].split("## Error Analysis", 1)[1].strip() else: error_analysis = parts[0].strip() elif "## Error Analysis" in raw: error_analysis = raw.split("## Error Analysis", 1)[1].strip() corrected_code = raw # fallback: treat whole thing as corrected code # Strip leading/trailing whitespace from each error_analysis = error_analysis.strip() corrected_code = corrected_code.strip() return error_analysis, corrected_code def _error_panel(model: str, label: str, error: str, latency_ms: float) -> dict: return { "model": model, "label": label, "response": None, "latency_ms": latency_ms, "error": error, } # ── Panel chain ─────────────────────────────────────────────────────────────── def build_panel_chain(model: str, temperature: float): """ LCEL chain: PANEL_PROMPT | ChatOpenAI | StrOutputParser Input: {"question": str} Output: str """ llm = make_panel_llm(model, temperature=temperature) return PANEL_PROMPT | llm | StrOutputParser() async def _run_panel(model: str, label: str, question: str, temperature: float) -> dict: chain = build_panel_chain(model, temperature) start = time.perf_counter() try: response = await chain.ainvoke({"question": question}) latency = round((time.perf_counter() - start) * 1000, 1) return { "model": model, "label": label, "response": response.strip(), "latency_ms": latency, "error": None, } except Exception as exc: latency = round((time.perf_counter() - start) * 1000, 1) return _error_panel(model, label, str(exc), latency) # ── Judge chain ─────────────────────────────────────────────────────────────── def build_judge_chain(): """ LCEL chain: JUDGE_PROMPT | ChatOpenAI | StrOutputParser Input: {"question", "label_1", "response_1", "label_2", "response_2"} Output: str """ llm = make_judge_llm() return JUDGE_PROMPT | llm | StrOutputParser() async def _run_judge(question: str, panel_results: list[dict]) -> dict: chain = build_judge_chain() start = time.perf_counter() # Always provide exactly 2 slots padded = list(panel_results) while len(padded) < 2: padded.append({"label": "N/A", "response": "[No response — model failed]"}) def _content(r: dict) -> str: return r["response"] if r.get("response") else f"[Error: {r.get('error', 'unknown')}]" judge_input = { "question": question, "label_1": padded[0]["label"], "response_1": _content(padded[0]), "label_2": padded[1]["label"], "response_2": _content(padded[1]), } try: raw = await chain.ainvoke(judge_input) latency = 
round((time.perf_counter() - start) * 1000, 1) error_analysis, corrected_code = _parse_judge_output(raw) return { "model": settings.JUDGE_MODEL, "label": settings.JUDGE_LABEL, "error_analysis": error_analysis, "corrected_code": corrected_code, "full_response": raw.strip(), "latency_ms": latency, "error": None, } except Exception as exc: latency = round((time.perf_counter() - start) * 1000, 1) return { "model": settings.JUDGE_MODEL, "label": settings.JUDGE_LABEL, "error_analysis": None, "corrected_code": None, "full_response": None, "latency_ms": latency, "error": str(exc), } async def debug_query(question: str, temperature: float = 0.3) -> dict[str, Any]: """ Phase 1 — 2 panel models in parallel (asyncio.gather) Phase 2 — StepFun Flash judge → {error_analysis, corrected_code} """ total_start = time.perf_counter() panel_tasks = [ _run_panel( model = model, label = settings.PANEL_MODEL_LABELS.get(model, model.split("/")[-1]), question = question, temperature = temperature, ) for model in settings.PANEL_MODELS ] panel_results: list[dict] = list(await asyncio.gather(*panel_tasks)) judge_result = await _run_judge(question, panel_results) return { "question": question, "panel": panel_results, "judge": judge_result, "total_ms": round((time.perf_counter() - total_start) * 1000, 1), }
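

# ── Local smoke test (illustrative sketch) ─────────────────────────────────────
# Minimal sketch of how debug_query can be exercised outside FastAPI, e.g. for a
# quick manual check from the command line. It assumes the models and API keys
# referenced by app.config.settings are configured in the environment; the
# _SAMPLE_QUESTION below is a made-up buggy snippet, not part of the real
# prompt set.

if __name__ == "__main__":
    _SAMPLE_QUESTION = (
        "def average(nums):\n"
        "    return sum(nums) / len(nums) - 1  # why is this off by one?\n"
    )

    async def _demo() -> None:
        result = await debug_query(_SAMPLE_QUESTION, temperature=0.2)
        for panel in result["panel"]:
            print(f"{panel['label']}: {panel['latency_ms']} ms, error={panel['error']}")
        print("Judge error analysis:", result["judge"]["error_analysis"])
        print("Total:", result["total_ms"], "ms")

    asyncio.run(_demo())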