| """ |
| Extraction Schemas for Document Intelligence |
| |
| Pydantic models for schema-based field extraction, tables, and charts. |
| """ |
|
|
| from enum import Enum |
| from typing import List, Dict, Any, Optional, Union |
| from pydantic import BaseModel, Field |
|
|
| from .core import BoundingBox, EvidenceRef |
|
|
|
|
class FieldType(str, Enum):
    """Value types a field definition may declare.

    Because the enum mixes in ``str``, each member compares equal to its
    lowercase string value (e.g. ``FieldType.DATE == "date"``).
    """

    # Scalar primitives
    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"

    # Semantically-typed scalars
    DATE = "date"
    CURRENCY = "currency"
    PERCENTAGE = "percentage"
    EMAIL = "email"
    PHONE = "phone"
    ADDRESS = "address"

    # Composite values
    LIST = "list"
    OBJECT = "object"
|
|
|
|
class FieldDefinition(BaseModel):
    """
    Declarative description of a single field to pull out of a document.

    Definitions are grouped into extraction schemas. Beyond the
    name/type/description triple, a definition may carry validation
    constraints, lookup hints for locating the value, and child definitions
    for complex (composite) types.
    """

    # --- identity -----------------------------------------------------------
    name: str = Field(..., description="Field name/key")
    type: FieldType = Field(..., description="Expected data type")
    description: str = Field(..., description="Human-readable description")
    required: bool = Field(False, description="Whether field is required")

    # --- validation constraints (each one optional) -------------------------
    pattern: Union[str, None] = Field(None, description="Regex pattern for validation")
    min_value: Union[float, None] = Field(None, description="Minimum numeric value")
    max_value: Union[float, None] = Field(None, description="Maximum numeric value")
    enum_values: Union[List[str], None] = Field(None, description="Allowed values")

    # --- lookup hints -------------------------------------------------------
    aliases: List[str] = Field(default_factory=list, description="Alternative names/labels for the field")
    search_context: Union[str, None] = Field(None, description="Context hint for where to find this field")

    # --- composite types ----------------------------------------------------
    # Self-referential: children are themselves FieldDefinitions.
    nested_fields: Union[List["FieldDefinition"], None] = Field(
        None,
        description="Nested field definitions for complex types",
    )
|
|
|
|
class ExtractionSchema(BaseModel):
    """
    Schema defining fields to extract from a document.

    Bundles the field definitions with the document types the schema applies
    to, cross-field validation expressions, and the evidence/confidence policy
    an extractor is expected to follow. This model only stores the policy
    flags; enforcing them is up to whatever consumes the schema.
    """

    # Identity
    schema_id: str = Field(..., description="Unique schema identifier")
    name: str = Field(..., description="Human-readable schema name")
    description: str = Field(..., description="Schema description")
    version: str = Field(default="1.0", description="Schema version")

    # Field definitions, in schema order
    fields: List[FieldDefinition] = Field(
        default_factory=list,
        description="Fields to extract"
    )

    # Document types this schema applies to
    document_types: List[str] = Field(
        default_factory=list,
        description="Applicable document types"
    )

    # Cross-field validation expressions; their syntax and evaluation are
    # defined by the consumer of the schema, not here.
    cross_field_validations: List[str] = Field(
        default_factory=list,
        description="Cross-field validation expressions"
    )

    # Extraction policy
    require_evidence: bool = Field(
        default=True,
        description="Require evidence for all extracted fields"
    )
    min_confidence: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="Minimum confidence threshold"
    )
    abstain_on_low_confidence: bool = Field(
        default=True,
        description="Abstain rather than guess when confidence is low"
    )

    def get_field(self, name: str) -> Optional[FieldDefinition]:
        """Look up a field definition by canonical name or alias.

        An exact match on a definition's ``name`` always wins over an alias
        match. Otherwise a field listed earlier whose alias happens to equal
        a later field's canonical name would shadow that field.

        Args:
            name: Canonical field name or one of its aliases.

        Returns:
            The matching definition, or ``None`` if nothing matches.
        """
        for field_def in self.fields:
            if field_def.name == name:
                return field_def
        for field_def in self.fields:
            if name in field_def.aliases:
                return field_def
        return None

    def get_required_fields(self) -> List[FieldDefinition]:
        """Return the definitions marked ``required``, in schema order."""
        return [f for f in self.fields if f.required]
|
|
|
|
class TableCell(BaseModel):
    """
    One cell of an extracted table, addressed by 0-based (row, col) position.
    """

    # Position and content
    cell_id: str = Field(..., description="Unique cell identifier")
    row: int = Field(..., ge=0, description="Row index (0-based)")
    col: int = Field(..., ge=0, description="Column index (0-based)")
    text: str = Field(..., description="Cell text content")
    bbox: BoundingBox = Field(..., description="Cell bounding box")

    # Merged-cell extent; 1 means a single row / column
    row_span: int = Field(1, ge=1, description="Number of rows spanned")
    col_span: int = Field(1, ge=1, description="Number of columns spanned")

    # Classification flags
    is_header: bool = Field(False, description="Whether cell is a header")
    is_empty: bool = Field(False, description="Whether cell is empty")

    # Confidence score, constrained to [0, 1]
    confidence: float = Field(1.0, ge=0.0, le=1.0)
|
|
|
|
class TableData(BaseModel):
    """
    Structured table data extracted from a document.

    Cells are stored sparsely in ``cells``. ``to_markdown`` and
    ``to_dict_list`` lay them out on a ``num_rows`` x ``num_cols`` grid.
    Spans (``row_span``/``col_span``) are not expanded: a spanning cell only
    fills its anchor (row, col) slot.
    """

    # Location
    table_id: str = Field(..., description="Unique table identifier")
    page: int = Field(..., ge=0, description="Page number")
    bbox: BoundingBox = Field(..., description="Table bounding box")

    # Structure
    num_rows: int = Field(..., ge=1, description="Number of rows")
    num_cols: int = Field(..., ge=1, description="Number of columns")
    cells: List[TableCell] = Field(default_factory=list, description="All cells")

    # Header layout
    header_rows: List[int] = Field(
        default_factory=list,
        description="Row indices that are headers"
    )
    header_cols: List[int] = Field(
        default_factory=list,
        description="Column indices that are headers"
    )

    # Caption
    caption: Optional[str] = Field(default=None, description="Table caption")
    caption_bbox: Optional[BoundingBox] = Field(default=None)

    # Confidence score in [0, 1]
    confidence: float = Field(default=1.0, ge=0.0, le=1.0)

    # Provenance
    evidence: Optional[EvidenceRef] = Field(default=None)

    def _build_grid(self) -> List[List[Optional[str]]]:
        """Return a num_rows x num_cols grid of cell texts (None where no cell).

        Cells whose (row, col) lies outside the declared dimensions are
        dropped. If two cells share a position, the later one wins.
        """
        grid: List[List[Optional[str]]] = [
            [None] * self.num_cols for _ in range(self.num_rows)
        ]
        for cell in self.cells:
            if cell.row < self.num_rows and cell.col < self.num_cols:
                grid[cell.row][cell.col] = cell.text
        return grid

    @staticmethod
    def _markdown_cell(text: Optional[str]) -> str:
        """Render one cell's text so it cannot break a markdown table row.

        ``|`` is escaped so it is not read as a column delimiter. Line breaks
        are collapsed to spaces because a markdown table row must fit on a
        single line.
        """
        if not text:
            return ""
        return " ".join(str(text).replace("|", "\\|").splitlines())

    def to_markdown(self) -> str:
        """Convert the table to a (GitHub-flavored) markdown table.

        Row 0 becomes the markdown header row and is followed by exactly one
        delimiter row. Markdown cannot express more than one header row, so
        further indices in ``header_rows`` are rendered as ordinary rows.
        Emitting extra ``|---|`` lines would only show up as data rows full
        of dashes.

        Returns:
            The markdown text, or ``""`` if the table has no cells.
        """
        if not self.cells:
            return ""

        delimiter = "|" + "|".join(["---"] * self.num_cols) + "|"
        lines: List[str] = []
        for i, row in enumerate(self._build_grid()):
            rendered = " | ".join(self._markdown_cell(c) for c in row)
            lines.append("| " + rendered + " |")
            if i == 0:
                lines.append(delimiter)
        return "\n".join(lines)

    def to_dict_list(self) -> List[Dict[str, str]]:
        """Convert the table to one dict per data row, keyed by row-0 text.

        Blank header cells become ``col_<index>``. Repeated header texts are
        made unique with an ``_<index>`` suffix so no column's values are
        silently overwritten. Missing cells map to ``""``.

        Returns:
            One dict per row after the first, or ``[]`` if the table has no
            cells or fewer than two rows.
        """
        if not self.cells or self.num_rows < 2:
            return []

        grid = self._build_grid()

        headers: List[str] = []
        seen = set()
        for i, h in enumerate(grid[0]):
            key = str(h) if h else f"col_{i}"
            # Two columns titled e.g. "Amount" would otherwise collapse into
            # one dict key and drop data; suffix until the key is unique.
            while key in seen:
                key = f"{key}_{i}"
            seen.add(key)
            headers.append(key)

        return [
            {key: str(v) if v else "" for key, v in zip(headers, row)}
            for row in grid[1:]
        ]
|
|
|
|
class ChartType(str, Enum):
    """Kinds of chart or graphic that can be extracted.

    Because the enum mixes in ``str``, members compare equal to their
    lowercase string values (e.g. ``ChartType.PIE == "pie"``).
    """

    # Quantitative plots
    BAR = "bar"
    LINE = "line"
    PIE = "pie"
    SCATTER = "scatter"
    AREA = "area"
    HISTOGRAM = "histogram"
    BOX = "box"
    HEATMAP = "heatmap"
    TREEMAP = "treemap"

    # Structural / schematic graphics
    FLOWCHART = "flowchart"
    DIAGRAM = "diagram"

    # Fallback for anything unrecognized
    OTHER = "other"
|
|
|
|
class ChartData(BaseModel):
    """
    Chart or graph recovered from a document page, with its labels,
    extracted data series, and any observed trends.
    """

    # --- location and kind --------------------------------------------------
    chart_id: str = Field(..., description="Unique chart identifier")
    page: int = Field(..., ge=0, description="Page number")
    bbox: BoundingBox = Field(..., description="Chart bounding box")
    chart_type: ChartType = Field(..., description="Type of chart")

    # --- labels -------------------------------------------------------------
    title: Union[str, None] = Field(None, description="Chart title")
    x_axis_label: Union[str, None] = Field(None, description="X-axis label")
    y_axis_label: Union[str, None] = Field(None, description="Y-axis label")

    # --- data ---------------------------------------------------------------
    # Each series is a free-form dict; its keys are not fixed by this model.
    series: List[Dict[str, Any]] = Field(default_factory=list, description="Data series extracted from chart")

    # --- interpretation -----------------------------------------------------
    trends: List[str] = Field(default_factory=list, description="Identified trends or patterns")

    # --- caption ------------------------------------------------------------
    caption: Union[str, None] = Field(None, description="Chart caption")

    # --- quality / provenance -----------------------------------------------
    confidence: float = Field(1.0, ge=0.0, le=1.0)
    evidence: Union[EvidenceRef, None] = Field(None)

    # --- free-text summary --------------------------------------------------
    description: Union[str, None] = Field(None, description="Natural language description of the chart")
|
|
|
|
class ExtractedField(BaseModel):
    """
    Value produced for one schema field, together with the evidence backing
    it, its validation outcome, and whether the extractor abstained.
    """

    # --- core result --------------------------------------------------------
    field_name: str = Field(..., description="Field name from schema")
    value: Any = Field(..., description="Extracted value")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Extraction confidence")
    evidence: List[EvidenceRef] = Field(default_factory=list, description="Evidence supporting the extraction")

    # --- validation outcome -------------------------------------------------
    is_valid: bool = Field(True, description="Whether value passed validation")
    validation_errors: List[str] = Field(default_factory=list, description="Validation error messages")

    # --- abstention ---------------------------------------------------------
    abstained: bool = Field(False, description="Whether extraction was abstained")
    abstain_reason: Union[str, None] = Field(None, description="Reason for abstention")

    @property
    def is_grounded(self) -> bool:
        """True when the value was not abstained and has at least one evidence ref."""
        if self.abstained:
            return False
        return bool(self.evidence)
|
|