| | """ |
| | Extraction Critic for Validation |
| | |
| | Validates extracted information against source evidence. |
| | Provides confidence scoring and abstention recommendations. |
| | """ |
| |
|
| | from typing import List, Optional, Dict, Any, Tuple |
| | from enum import Enum |
| | from pydantic import BaseModel, Field |
| | from loguru import logger |
| |
|
| | try: |
| | import httpx |
| | HTTPX_AVAILABLE = True |
| | except ImportError: |
| | HTTPX_AVAILABLE = False |
| |
|
| |
|
class ValidationStatus(str, Enum):
    """Validation status codes for a single field or an overall result."""
    VALID = "valid"              # extracted value is supported by the evidence
    INVALID = "invalid"          # extracted value contradicts the evidence
    UNCERTAIN = "uncertain"      # could not be confirmed or refuted
    ABSTAIN = "abstain"          # critic declines to judge
    NO_EVIDENCE = "no_evidence"  # no evidence was available to check against
| |
|
| |
|
class CriticConfig(BaseModel):
    """Configuration for extraction critic."""

    # LLM backend settings. Only the Ollama fields are read by
    # ExtractionCritic._call_llm in this module.
    # NOTE(review): llm_provider is not referenced anywhere in this module —
    # presumably consumed by a caller; confirm before removing.
    llm_provider: str = Field(default="ollama", description="LLM provider")
    ollama_base_url: str = Field(default="http://localhost:11434")
    ollama_model: str = Field(default="llama3.2:3b")

    # Acceptance policy: results below this mean confidence are rejected
    # (see ExtractionCritic.validate_extraction).
    confidence_threshold: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="Minimum confidence for valid extraction"
    )
    evidence_required: bool = Field(
        default=True,
        description="Require evidence for validation"
    )
    # NOTE(review): strict_mode is not referenced in this module — verify
    # it is used elsewhere.
    strict_mode: bool = Field(
        default=False,
        description="Strict validation mode"
    )

    # Request limits. timeout is passed to the httpx client;
    # NOTE(review): max_fields_per_request is not enforced in this module.
    max_fields_per_request: int = Field(default=10, ge=1)
    timeout: float = Field(default=60.0, ge=1.0)
| |
|
| |
|
class FieldValidation(BaseModel):
    """Validation result for a single field."""

    field_name: str           # name of the field that was validated
    extracted_value: Any      # the value produced by the extractor
    status: ValidationStatus  # per-field verdict
    confidence: float         # critic confidence in the verdict (0.0-1.0)
    reasoning: str            # brief explanation from the critic

    # Evidence linkage. NOTE(review): evidence_snippet / evidence_page are
    # never populated in this module — presumably filled by callers.
    evidence_found: bool = False
    evidence_snippet: Optional[str] = None
    evidence_page: Optional[int] = None

    # Optional correction proposed by the critic when the value is wrong.
    suggested_value: Optional[Any] = None
    correction_reason: Optional[str] = None
| |
|
| |
|
class ValidationResult(BaseModel):
    """Complete validation result for one extraction request."""

    overall_status: ValidationStatus          # aggregate verdict across fields
    overall_confidence: float                 # mean per-field confidence (0.0 if no fields)
    field_validations: List[FieldValidation]  # per-field details

    # Counts of fields by validation status (NO_EVIDENCE is not tallied
    # separately here; see ExtractionCritic.validate_extraction).
    valid_count: int = 0
    invalid_count: int = 0
    uncertain_count: int = 0
    abstain_count: int = 0

    # Final recommendation: accept the extraction or abstain.
    should_accept: bool
    # Human-readable reason when should_accept is False (may be None).
    abstain_reason: Optional[str] = None
| |
|
| |
|
class ExtractionCritic:
    """
    Critic for validating extracted information.

    Features:
    - Validates extracted fields against source evidence
    - Provides confidence scores
    - Recommends abstention when uncertain
    - Suggests corrections when possible

    Validation is LLM-backed (Ollama /api/generate). When the LLM call or
    its output is unusable, a substring-matching heuristic is used instead.
    """

    # Prompt template; {evidence} and {fields} are substituted per request.
    # Doubled braces escape the literal JSON example for str.format().
    VALIDATION_PROMPT = """You are a critical validator for document extraction.
Your task is to validate extracted information against the source evidence.

For each field, determine:
1. Is the extracted value supported by the evidence? (yes/no/partially)
2. Confidence score (0.0 to 1.0)
3. Brief reasoning
4. If incorrect, suggest the correct value

Be strict and skeptical. Only mark as valid if clearly supported.

Evidence:
{evidence}

Extracted Fields to Validate:
{fields}

Respond in JSON format:
{{
    "validations": [
        {{
            "field": "field_name",
            "status": "valid|invalid|uncertain|no_evidence",
            "confidence": 0.0-1.0,
            "reasoning": "explanation",
            "suggested_value": null or corrected value
        }}
    ]
}}"""

    def __init__(self, config: Optional[CriticConfig] = None):
        """Initialize extraction critic.

        Args:
            config: Optional configuration; defaults to CriticConfig().
        """
        self.config = config or CriticConfig()

    def validate_extraction(
        self,
        extracted_fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> ValidationResult:
        """
        Validate extracted fields against evidence.

        Args:
            extracted_fields: Dictionary of field_name -> value
            evidence: List of evidence chunks with text, page, etc.

        Returns:
            ValidationResult with per-field details and an accept/abstain
            recommendation.
        """
        # Nothing to validate: abstain outright.
        if not extracted_fields:
            return ValidationResult(
                overall_status=ValidationStatus.ABSTAIN,
                overall_confidence=0.0,
                field_validations=[],
                should_accept=False,
                abstain_reason="No fields to validate",
            )

        # Evidence is mandatory by default; without it we cannot validate.
        if not evidence and self.config.evidence_required:
            return self._create_no_evidence_result(extracted_fields)

        field_validations = self._validate_with_llm(extracted_fields, evidence)

        # Tally per-status counts for the summary.
        valid_count = sum(1 for v in field_validations if v.status == ValidationStatus.VALID)
        invalid_count = sum(1 for v in field_validations if v.status == ValidationStatus.INVALID)
        uncertain_count = sum(1 for v in field_validations if v.status == ValidationStatus.UNCERTAIN)
        abstain_count = sum(1 for v in field_validations if v.status == ValidationStatus.ABSTAIN)

        # Overall confidence is the mean of per-field confidences.
        if field_validations:
            overall_confidence = sum(v.confidence for v in field_validations) / len(field_validations)
        else:
            overall_confidence = 0.0

        # Any invalid field taints the whole result; otherwise abstain or
        # uncertain wins whenever it outnumbers valid.
        if invalid_count > 0:
            overall_status = ValidationStatus.INVALID
        elif abstain_count > valid_count:
            overall_status = ValidationStatus.ABSTAIN
        elif uncertain_count > valid_count:
            overall_status = ValidationStatus.UNCERTAIN
        else:
            overall_status = ValidationStatus.VALID

        # Accept only confident results with zero invalid fields.
        should_accept = (
            overall_confidence >= self.config.confidence_threshold
            and invalid_count == 0
            and overall_status in [ValidationStatus.VALID, ValidationStatus.UNCERTAIN]
        )

        # Build a human-readable rejection reason (confidence checked first,
        # so a low-confidence + invalid result reports the confidence gap).
        abstain_reason = None
        if not should_accept:
            if overall_confidence < self.config.confidence_threshold:
                abstain_reason = f"Confidence ({overall_confidence:.2f}) below threshold ({self.config.confidence_threshold})"
            elif invalid_count > 0:
                abstain_reason = f"{invalid_count} field(s) validated as invalid"
            elif overall_status == ValidationStatus.ABSTAIN:
                abstain_reason = "Insufficient evidence to validate"

        return ValidationResult(
            overall_status=overall_status,
            overall_confidence=overall_confidence,
            field_validations=field_validations,
            valid_count=valid_count,
            invalid_count=invalid_count,
            uncertain_count=uncertain_count,
            abstain_count=abstain_count,
            should_accept=should_accept,
            abstain_reason=abstain_reason,
        )

    def _validate_with_llm(
        self,
        fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> List[FieldValidation]:
        """Validate fields using the LLM, falling back to heuristics on error."""
        evidence_text = self._format_evidence(evidence)

        fields_text = "\n".join(
            f"- {name}: {value}"
            for name, value in fields.items()
        )

        prompt = self.VALIDATION_PROMPT.format(
            evidence=evidence_text,
            fields=fields_text,
        )

        try:
            response = self._call_llm(prompt)
            validations = self._parse_validation_response(response, fields, evidence)
        except Exception as e:
            # Deliberate best-effort: any LLM/transport failure degrades to
            # the substring-matching heuristic instead of crashing.
            logger.error(f"LLM validation failed: {e}")
            validations = self._heuristic_validation(fields, evidence)

        return validations

    def _format_evidence(self, evidence: List[Dict[str, Any]]) -> str:
        """Format evidence for the prompt (first 10 chunks, 500 chars each)."""
        parts = []
        for i, ev in enumerate(evidence[:10], 1):
            page = ev.get("page", "?")
            # Chunks may carry their text under "text" or "snippet".
            text = ev.get("text", ev.get("snippet", ""))[:500]
            parts.append(f"[{i}] Page {page}: {text}")
        return "\n\n".join(parts)

    def _call_llm(self, prompt: str) -> str:
        """Call the Ollama generate endpoint and return the raw response text.

        Raises:
            ImportError: if httpx is not installed.
            httpx.HTTPStatusError: on a non-2xx response.
        """
        if not HTTPX_AVAILABLE:
            raise ImportError("httpx required for LLM calls")

        with httpx.Client(timeout=self.config.timeout) as client:
            response = client.post(
                f"{self.config.ollama_base_url}/api/generate",
                json={
                    "model": self.config.ollama_model,
                    "prompt": prompt,
                    # Non-streaming request; low temperature for determinism.
                    "stream": False,
                    "options": {"temperature": 0.1},
                },
            )
            response.raise_for_status()
            return response.json().get("response", "")

    def _parse_validation_response(
        self,
        response: str,
        fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> List[FieldValidation]:
        """Parse the LLM validation response.

        Extracts the first JSON object from the raw text, maps each reported
        validation onto a FieldValidation, and fills in an UNCERTAIN
        placeholder for any field the LLM did not cover. Malformed per-field
        data degrades that field gracefully rather than discarding the whole
        response.
        """
        import json
        import re

        validations = []

        # Grab the outermost {...} span; the LLM may wrap JSON in prose.
        json_match = re.search(r'\{[\s\S]*\}', response)
        if json_match:
            try:
                data = json.loads(json_match.group())
                llm_validations = data.get("validations", [])

                for v in llm_validations:
                    # FIX: a non-dict entry previously raised AttributeError
                    # on .get(), which escaped to the caller's broad handler
                    # and threw away every parsed validation.
                    if not isinstance(v, dict):
                        continue

                    field_name = v.get("field", "")
                    if field_name not in fields:
                        # Ignore hallucinated fields we never asked about.
                        continue

                    status_str = v.get("status", "uncertain").lower()
                    try:
                        status = ValidationStatus(status_str)
                    except ValueError:
                        # Unknown status string -> treat as uncertain.
                        status = ValidationStatus.UNCERTAIN

                    # FIX: the LLM sometimes emits a non-numeric or null
                    # confidence (e.g. echoing "0.0-1.0" from the prompt).
                    # float() then raised an uncaught TypeError/ValueError,
                    # discarding ALL validations via the heuristic fallback.
                    # Parse defensively and clamp into [0.0, 1.0].
                    try:
                        confidence = float(v.get("confidence", 0.5))
                    except (TypeError, ValueError):
                        confidence = 0.5
                    confidence = max(0.0, min(1.0, confidence))

                    validation = FieldValidation(
                        field_name=field_name,
                        extracted_value=fields[field_name],
                        status=status,
                        confidence=confidence,
                        reasoning=v.get("reasoning", ""),
                        evidence_found=status != ValidationStatus.NO_EVIDENCE,
                        suggested_value=v.get("suggested_value"),
                    )
                    validations.append(validation)

            except json.JSONDecodeError:
                # Unparseable JSON: fall through and mark all fields uncertain.
                pass

        # Any field the LLM skipped (or all, if parsing failed) gets a
        # neutral UNCERTAIN entry so every input field is accounted for.
        validated_fields = {v.field_name for v in validations}
        for field_name, value in fields.items():
            if field_name not in validated_fields:
                validations.append(FieldValidation(
                    field_name=field_name,
                    extracted_value=value,
                    status=ValidationStatus.UNCERTAIN,
                    confidence=0.5,
                    reasoning="Could not validate",
                    evidence_found=False,
                ))

        return validations

    def _heuristic_validation(
        self,
        fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> List[FieldValidation]:
        """Heuristic validation when the LLM is unavailable or fails.

        Case-insensitive substring match of each stringified value against
        the concatenated evidence text.
        """
        validations = []
        evidence_text = " ".join(
            ev.get("text", ev.get("snippet", "")).lower()
            for ev in evidence
        )

        for field_name, value in fields.items():
            # FIX: str(None) == "none" could falsely substring-match the
            # evidence and mark a missing value as VALID; treat None as
            # "nothing to match".
            value_str = str(value).lower() if value is not None else ""
            found = value_str in evidence_text if value_str else False

            if found:
                status = ValidationStatus.VALID
                confidence = 0.7
                reasoning = "Value found in evidence"
            elif evidence:
                status = ValidationStatus.UNCERTAIN
                confidence = 0.4
                reasoning = "Value not directly found in evidence"
            else:
                status = ValidationStatus.NO_EVIDENCE
                confidence = 0.2
                reasoning = "No evidence available"

            validations.append(FieldValidation(
                field_name=field_name,
                extracted_value=value,
                status=status,
                confidence=confidence,
                reasoning=reasoning,
                evidence_found=found,
            ))

        return validations

    def _create_no_evidence_result(
        self,
        fields: Dict[str, Any],
    ) -> ValidationResult:
        """Create an abstaining result when no evidence is available."""
        validations = [
            FieldValidation(
                field_name=name,
                extracted_value=value,
                status=ValidationStatus.NO_EVIDENCE,
                confidence=0.0,
                reasoning="No evidence provided for validation",
                evidence_found=False,
            )
            for name, value in fields.items()
        ]

        return ValidationResult(
            overall_status=ValidationStatus.ABSTAIN,
            overall_confidence=0.0,
            field_validations=validations,
            abstain_count=len(validations),
            should_accept=False,
            abstain_reason="No evidence available for validation",
        )
| |
|
| |
|
| | |
# Module-level singleton, created lazily by get_extraction_critic().
_extraction_critic: Optional[ExtractionCritic] = None


def get_extraction_critic(
    config: Optional[CriticConfig] = None,
) -> ExtractionCritic:
    """Get or create singleton extraction critic.

    Note: ``config`` is only honored on the FIRST call; once the singleton
    exists, later configs are silently ignored. Call
    reset_extraction_critic() first to rebuild with a new config.
    """
    global _extraction_critic
    if _extraction_critic is None:
        _extraction_critic = ExtractionCritic(config)
    return _extraction_critic
| |
|
| |
|
def reset_extraction_critic():
    """Reset the global critic instance; the next get_extraction_critic()
    call will create a fresh one (useful in tests or after config changes)."""
    global _extraction_critic
    _extraction_critic = None
| |
|