from __future__ import annotations

import math
from typing import Any, Dict, List, Optional, Tuple

from ...models.model_configs import LLMConfig
from ...agents.customize_agent import CustomizeAgent


def _tokenize(text: str) -> List[str]:
    """Lowercase and whitespace-split, dropping empty tokens."""
    return [t for t in text.lower().split() if t.strip()]


def _tf_vector(tokens: List[str]) -> Dict[str, float]:
    """Build an L2-normalised term-frequency vector from a token list."""
    vec: Dict[str, float] = {}
    for t in tokens:
        vec[t] = vec.get(t, 0.0) + 1.0
    # Normalise so that dot products between vectors are cosine similarities.
    norm = math.sqrt(sum(v * v for v in vec.values())) or 1.0
    for k in vec:
        vec[k] /= norm
    return vec


def _cosine_sim(a: Dict[str, float], b: Dict[str, float]) -> float:
    """Dot product of two L2-normalised sparse vectors, i.e. cosine similarity."""
    # Iterate over the smaller dict for efficiency (the original swap was
    # inverted and iterated over the larger one).
    if len(a) > len(b):
        a, b = b, a
    return sum(v * b.get(k, 0.0) for k, v in a.items())
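
# Quick sanity check (illustrative, not executed at import): with two-token
# inputs sharing one token, e.g. _tf_vector(["a", "b"]) and _tf_vector(["a", "c"]),
# each vector carries weight 1/sqrt(2) per token, so the cosine similarity is
# (1/sqrt(2)) * (1/sqrt(2)) = 0.5; identical token lists score exactly 1.0.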


def _js_divergence(p: Dict[str, float], q: Dict[str, float]) -> float:
    """Jensen-Shannon divergence between two sparse distributions (natural log)."""
    vocab = set(p.keys()) | set(q.keys())
    eps = 1e-9

    def _norm(d: Dict[str, float]) -> Dict[str, float]:
        # Renormalise over the shared vocabulary with epsilon smoothing.
        s = sum(d.get(w, 0.0) for w in vocab) or 1.0
        return {w: (d.get(w, 0.0) + eps) / (s + eps * len(vocab)) for w in vocab}

    P = _norm(p)
    Q = _norm(q)
    M = {w: 0.5 * (P[w] + Q[w]) for w in vocab}

    def _kl(X: Dict[str, float], Y: Dict[str, float]) -> float:
        return sum(X[w] * math.log((X[w] + eps) / (Y[w] + eps)) for w in vocab)

    return 0.5 * _kl(P, M) + 0.5 * _kl(Q, M)
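
# For reference (an assumed usage; this helper is not called elsewhere in this
# section): identical distributions give _js_divergence(p, p) ~= 0.0, while
# fully disjoint ones approach ln(2) ~= 0.693, the maximum under natural log.
# It could serve as a drop-in distance for diversity pruning if desired.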


class PruningPipeline:
    """Pluggable pruning pipeline: quality pruning (QP) → diversity pruning (DP) → misunderstanding rebuttal (MR).

    Candidate input format: List[{"agent_id": int, "text": str}].
    The output preserves the same structure and fills optional metrics into
    each entry, e.g. qp_score and dup_removed.
    """

    def __init__(
        self,
        enable_qp: bool = True,
        enable_dp: bool = True,
        enable_mr: bool = False,
        qp_threshold: float = 0.15,
        qp_top_k: Optional[int] = None,
        dp_similarity_threshold: float = 0.92,
        dp_max_candidates: Optional[int] = None,
        mr_llm_config: Optional[LLMConfig] = None,
        min_keep_count: Optional[int] = None,
    ) -> None:
        self.enable_qp = enable_qp
        self.enable_dp = enable_dp
        self.enable_mr = enable_mr
        self.qp_threshold = qp_threshold
        self.qp_top_k = qp_top_k
        self.dp_similarity_threshold = dp_similarity_threshold
        self.dp_max_candidates = dp_max_candidates
        self.mr_llm_config = mr_llm_config
        self.min_keep_count = min_keep_count

    def _qp_score(self, problem: str, text: str) -> float:
        # Lexical relevance proxy: cosine similarity between TF vectors of
        # the problem statement and the candidate text.
        qv = _tf_vector(_tokenize(problem))
        tv = _tf_vector(_tokenize(text))
        return _cosine_sim(qv, tv)

    def _quality_prune(self, problem: str, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        if not self.enable_qp or len(candidates) <= 1:
            return candidates
        scored: List[Tuple[float, Dict[str, Any]]] = []
        for c in candidates:
            s = self._qp_score(problem, c.get("text", ""))
            c = dict(c)  # copy so the caller's dicts are not mutated
            c["qp_score"] = s
            scored.append((s, c))
        scored.sort(key=lambda x: x[0], reverse=True)
        if self.qp_top_k is not None and self.qp_top_k > 0:
            scored = scored[: self.qp_top_k]
        kept = [c for s, c in scored if s >= self.qp_threshold]
        # Never prune everything: fall back to the single best candidate.
        if not kept:
            kept = [scored[0][1]]
        # Backfill from the score ranking until min_keep_count is reached.
        if self.min_keep_count and len(kept) < self.min_keep_count:
            existing_ids = set(id(obj) for obj in kept)
            for _, c in scored:
                if id(c) not in existing_ids:
                    kept.append(c)
                if len(kept) >= self.min_keep_count:
                    break
        return kept

    def _diversity_prune(self, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        if not self.enable_dp or len(candidates) <= 1:
            return candidates
        vecs = [_tf_vector(_tokenize(c.get("text", ""))) for c in candidates]
        # Greedy selection: keep a candidate only if it is sufficiently
        # dissimilar to everything kept so far.
        kept: List[int] = []
        for i, v in enumerate(vecs):
            diverse = True
            for j in kept:
                sim = _cosine_sim(v, vecs[j])
                if sim >= self.dp_similarity_threshold:
                    diverse = False
                    break
            if diverse:
                kept.append(i)
                if self.dp_max_candidates and len(kept) >= self.dp_max_candidates:
                    break
        pruned = [candidates[i] for i in kept]
        # Backfill by descending qp_score until min_keep_count is reached.
        if self.min_keep_count and len(pruned) < self.min_keep_count:
            ranked = sorted(
                range(len(candidates)),
                key=lambda idx: float(candidates[idx].get("qp_score") or 0.0),
                reverse=True,
            )
            chosen = set(kept)
            for idx in ranked:
                if idx in chosen:
                    continue
                pruned.append(candidates[idx])
                chosen.add(idx)
                if len(pruned) >= self.min_keep_count:
                    break
        return pruned

    def _build_critic(self) -> Optional[CustomizeAgent]:
        if not self.mr_llm_config:
            return None
        prompt = (
            """
You are a critical reviewer. Given a problem and a set of condensed candidate answers, identify common misunderstandings or mistakes, and propose a corrected consolidated answer.

Problem:
{problem}

Candidates:
{candidates_text}

Return XML:
<response>
<issues>Common mistakes found</issues>
<rebuttal>How to fix them</rebuttal>
<corrected>Single corrected final answer</corrected>
</response>
"""
        ).strip()
        inputs = [
            {"name": "problem", "type": "str", "description": "Problem statement"},
            {"name": "candidates_text", "type": "str", "description": "Concatenated candidates"},
        ]
        outputs = [
            {"name": "issues", "type": "str", "description": "Common mistakes", "required": True},
            {"name": "rebuttal", "type": "str", "description": "Corrections", "required": True},
            {"name": "corrected", "type": "str", "description": "Corrected final answer", "required": True},
        ]
        return CustomizeAgent(
            name="CriticAgent",
            description="Detects misunderstandings and proposes corrected answer",
            prompt=prompt,
            llm_config=self.mr_llm_config,
            inputs=inputs,
            outputs=outputs,
            parse_mode="xml",
        )

    def _misunderstanding_rebuttal(self, problem: str, candidates: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, str]]]:
        if not self.enable_mr:
            return candidates, None
        critic = self._build_critic()
        if critic is None:
            return candidates, None
        concat = "\n\n".join(f"#{c.get('agent_id')}: {c.get('text', '').strip()}" for c in candidates)
        msg = critic(inputs={"problem": problem, "candidates_text": concat})
        st = msg.content.get_structured_data()
        # Annotate each candidate in place with the critic's findings.
        for c in candidates:
            c["mr_issues"] = st.get("issues", "")
            c["mr_rebuttal"] = st.get("rebuttal", "")
        suggested = {
            "issues": st.get("issues", ""),
            "rebuttal": st.get("rebuttal", ""),
            "corrected": st.get("corrected", ""),
        }
        return candidates, suggested

    def apply(self, problem: str, candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Run QP → DP → MR and return {"candidates": pruned, "mr_suggested": optional}."""
        step1 = self._quality_prune(problem, candidates)
        step2 = self._diversity_prune(step1)
        step3, suggested = self._misunderstanding_rebuttal(problem, step2)
        return {"candidates": step3, "mr_suggested": suggested}
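

# Minimal usage sketch (illustrative only; the candidate texts below are made
# up). MR is left disabled because it needs an LLMConfig and a live model; the
# relative imports above mean this file must be run as a module inside its
# package (python -m ...), not as a standalone script.
if __name__ == "__main__":
    pipeline = PruningPipeline(
        enable_qp=True,
        enable_dp=True,
        enable_mr=False,
        qp_threshold=0.1,
        dp_similarity_threshold=0.9,
        min_keep_count=2,
    )
    demo = [
        {"agent_id": 0, "text": "The answer is 42 because 6 * 7 = 42."},
        {"agent_id": 1, "text": "The answer is 42 because 6 * 7 = 42."},  # exact duplicate
        {"agent_id": 2, "text": "Unrelated musings that never address the question."},
    ]
    result = pipeline.apply("What is 6 * 7?", demo)
    # Expect: the off-topic candidate is dropped by QP; the duplicate is
    # removed by DP, then backfilled to honour min_keep_count.
    for c in result["candidates"]:
        print(c["agent_id"], round(c.get("qp_score", 0.0), 3))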