| | import os |
| | from abc import ABC, abstractmethod |
| | from google import genai |
| | from google.genai import types |
| | from pydantic import BaseModel |
| | import numpy as np |
| | from typing import List, Any, Dict, Tuple |
| | import time |
| | from langchain_text_splitters import HTMLHeaderTextSplitter |
| | from sentence_transformers import SentenceTransformer |
| |
|
class LLMClient(ABC):
    """
    Interface that every LLM backend adapter implements.

    The base class only stores configuration; concrete subclasses decide how
    a prompt is turned into a model response by implementing ``call_api``.
    """

    def __init__(self, config: dict = None):
        """
        Store the client configuration.

        Args:
            config (dict): Backend-specific settings. Passing ``None`` (or any
                falsy value such as an empty dict) stores a fresh empty dict.
        """
        self.config = config if config else {}

    @abstractmethod
    def call_api(self, prompt: str) -> str:
        """
        Send a prompt to the backing model and return its reply.

        Args:
            prompt (str): Input text for the model.

        Returns:
            str: The model's textual response.
        """
        ...
| |
|
| |
|
class GeminiLLMClient(LLMClient):
    """
    Concrete implementation of LLMClient for the Gemini API.
    """

    def __init__(self, config: dict = None):
        """
        Initializes the GeminiLLMClient with an API key, model name, and optional generation settings.

        Args:
            config (dict): Configuration containing:
                - 'api_key': (optional) API key for Gemini (falls back to GEMINI_API_KEY env var)
                - 'model_name': (optional) the model to use (default 'gemini-2.0-flash')
                - 'generation_config': (optional) dict of GenerateContentConfig parameters
                  (response_mime_type, temperature, max_output_tokens, top_p, top_k);
                  other keys are ignored.
                May be None when the API key comes from the environment.

        Raises:
            ValueError: If no API key is found in the config or the environment.
        """
        # Let the base class normalise and store the config so that
        # ``self.config`` exists, as it does for every other LLMClient.
        super().__init__(config)
        config = self.config

        api_key = config.get("api_key") or os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError(
                "API key for Gemini must be provided in config['api_key'] or GEMINI_API_KEY env var."
            )
        self.client = genai.Client(api_key=api_key)
        self.model_name = config.get("model_name", "gemini-2.0-flash")

        # ``or {}`` also tolerates an explicit ``"generation_config": None``.
        gen_conf = config.get("generation_config") or {}
        # Keys missing from gen_conf are passed as None (only the MIME type has a
        # default here).
        self.generate_config = types.GenerateContentConfig(
            response_mime_type=gen_conf.get("response_mime_type", "text/plain"),
            temperature=gen_conf.get("temperature"),
            max_output_tokens=gen_conf.get("max_output_tokens"),
            top_p=gen_conf.get("top_p"),
            top_k=gen_conf.get("top_k"),
        )

    def call_api(self, prompt: str) -> str:
        """
        Call the Gemini API with the given prompt (non-streaming).

        The prompt is sent as a single user-role Content with one text Part,
        using the model name and generation config chosen in ``__init__``.

        Args:
            prompt (str): The input text for the API.

        Returns:
            str: The generated text from the Gemini API.
        """
        contents = [
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=prompt)],
            )
        ]

        response = self.client.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=self.generate_config,
        )

        # NOTE(review): the SDK's ``response.text`` presumably returns None when
        # the response carries no text parts (e.g. blocked output). Confirm this,
        # and decide whether callers should guard against it.
        return response.text
| |
|
| | |
| |
|
class AIExtractor:
    """
    Prompt-driven extractor.

    Fills a prompt template with the raw content and a Pydantic schema, then
    forwards the resulting prompt to an LLM client.
    """

    def __init__(self, llm_client: LLMClient, prompt_template: str):
        """
        Bind the extractor to an LLM backend and a prompt template.

        Args:
            llm_client (LLMClient): Backend used to answer the filled-in prompt.
            prompt_template (str): ``str.format`` template exposing the
                ``{content}`` and ``{schema}`` placeholders, e.g.
                "Extract the following information: {content} based on schema: {schema}".
                Literal braces in the template must be doubled (``{{``/``}}``).
        """
        self.prompt_template = prompt_template
        self.llm_client = llm_client

    def extract(self, content: str, schema: BaseModel) -> str:
        """
        Ask the LLM to extract structured data from ``content``.

        Args:
            content (str): Raw text to extract information from.
            schema (BaseModel): Pydantic model whose ``model_json_schema()`` is
                substituted for ``{schema}``. The schema dict is inserted via
                ``str()``, i.e. as a Python dict repr.

        Returns:
            str: Whatever the LLM client returns, which is expected to be a JSON string.
        """
        schema_spec = schema.model_json_schema()
        filled_prompt = self.prompt_template.format(content=content, schema=schema_spec)
        return self.llm_client.call_api(filled_prompt)
| |
|
| | |
class RAGExtractor(AIExtractor):
    """
    RAG-enhanced extractor that uses similarity search to find relevant chunks
    before performing extraction, utilizing HTML header-based chunking and
    SentenceTransformer embeddings.
    """

    def __init__(self,
                 llm_client: LLMClient,
                 prompt_template: str,
                 embedding_model_path: str = "sentence-transformers/all-mpnet-base-v2",
                 top_k: int = 3):
        """
        Initialize RAG extractor with embedding and chunking capabilities.

        Args:
            llm_client: LLM client for generation.
            prompt_template: Template for prompts (``{content}`` / ``{schema}`` placeholders).
            embedding_model_path: Path/name for the SentenceTransformer embedding model.
            top_k: Number of top similar chunks to retrieve.
        """
        super().__init__(llm_client, prompt_template)
        self.embedding_model_path = embedding_model_path
        # Loaded once here, so every extract() call reuses the same model instance.
        self.embedding_model_instance = SentenceTransformer(self.embedding_model_path)
        self.top_k = top_k

    @staticmethod
    def _langchain_HHTS(text: str) -> List[str]:
        """
        Chunks HTML text using Langchain's HTMLHeaderTextSplitter based on h1 and h2 headers.

        Args:
            text (str): The HTML content to chunk.

        Returns:
            List[str]: A list of chunked text strings (extracted from Document objects' page_content).
        """
        headers_to_split_on = [
            ("h1", "Header 1"),
            ("h2", "Header 2"),
        ]
        html_splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
        return [doc.page_content for doc in html_splitter.split_text(text)]

    def embed_text(self, text: str) -> np.ndarray:
        """
        Generate embeddings for text using the initialized SentenceTransformer model.

        Embedding is best-effort. If encoding fails, a warning is printed and an
        all-zero vector of the model's embedding dimension is returned. Cosine
        similarity treats a zero vector as "similarity 0" instead of crashing.

        Args:
            text: The text string to embed.

        Returns:
            np.ndarray: The embedding vector for the input text.

        Raises:
            Exception: The original encoding error, re-raised when the model
                does not report its embedding dimension (no fallback vector can
                be built).
        """
        try:
            return self.embedding_model_instance.encode(text)
        except Exception as e:
            dim = self.embedding_model_instance.get_sentence_embedding_dimension()
            if not dim:
                raise
            print(f"Warning: Embedding failed for text: '{text[:50]}...', using zero embedding: {e}")
            return np.zeros(dim, dtype=np.float32)

    def search_similar_chunks(self,
                              query: str,
                              chunks: List[str],
                              embeddings: np.ndarray) -> List[str]:
        """
        Find the most similar chunks to the query within the given list of chunks
        by calculating cosine similarity between their embeddings.

        Each chunk embedding is flattened to a vector. A zero-norm query or chunk
        vector gets similarity 0.0. Ties keep the chunks' original order.

        Args:
            query (str): The query text whose embedding will be used for similarity comparison.
            chunks (List[str]): A list of text chunks to search within.
            embeddings (np.ndarray): Precomputed embeddings for the chunks, corresponding to the 'chunks' list.

        Returns:
            List[str]: A list of the 'top_k' most similar chunks to the query.
        """
        if len(embeddings) == 0:
            return []

        query_vec = np.asarray(self.embed_text(query), dtype=np.float64).ravel()
        # One row per chunk; higher-rank per-chunk embeddings are flattened.
        chunk_matrix = np.asarray(embeddings, dtype=np.float64).reshape(len(embeddings), -1)

        # The query norm is loop-invariant; compute everything in one vectorized pass.
        denominators = np.linalg.norm(chunk_matrix, axis=1) * np.linalg.norm(query_vec)
        dots = chunk_matrix @ query_vec
        similarities = np.divide(
            dots, denominators, out=np.zeros_like(dots), where=denominators != 0
        )

        # Stable sort on the negated scores gives descending order with ties
        # left in the original chunk order.
        ranked = np.argsort(-similarities, kind="stable")
        return [chunks[int(i)] for i in ranked[:self.top_k]]

    def extract(self, content: str, schema: BaseModel, query: str = None) -> str:
        """
        Overrides the base AIExtractor's method to implement RAG-enhanced extraction.
        This function first chunks the input HTML content, then uses a query to find
        the most relevant chunks via embedding similarity, and finally sends these
        relevant chunks as context to the LLM for structured information extraction.

        Args:
            content (str): The raw HTML content from which to extract information.
            schema (BaseModel): A Pydantic model defining the desired output structure for the LLM.
            query (str, optional): An optional query string to guide the retrieval of relevant chunks.
                If not provided, a default query based on the schema will be used.

        Returns:
            str: The structured JSON object as a string, as generated by the LLM.
        """
        # perf_counter is monotonic and meant for measuring elapsed time.
        start_time = time.perf_counter()
        schema_json = schema.model_json_schema()

        if not query:
            query = f"Extract information based on the following JSON schema: {schema_json}"
            print(f"No explicit query provided for retrieval. Using default: '{query[:100]}...'")

        chunks = self._langchain_HHTS(content)
        print(f"Content successfully chunked into {len(chunks)} pieces.")

        if not chunks:
            print("Warning: No chunks were generated from the provided content. The entire original content will be sent to the LLM.")
            combined_content_for_llm = content
        else:
            chunk_embeddings = np.array([self.embed_text(chunk) for chunk in chunks])
            print(f"Generated embeddings for {len(chunks)} chunks.")

            similar_chunks = self.search_similar_chunks(query, chunks, chunk_embeddings)
            print(f"Retrieved {len(similar_chunks)} similar chunks based on the query.")

            combined_content_for_llm = "\n\n".join(similar_chunks)
            print(f"Combined content for LLM (truncated): '{combined_content_for_llm[:200]}...'")

        prompt = self.prompt_template.format(content=combined_content_for_llm, schema=schema_json)
        print(f"Sending prompt to LLM (truncated): '{prompt[:500]}...'")
        llm_response = self.llm_client.call_api(prompt)

        execution_time = (time.perf_counter() - start_time) * 1000
        print(f"Extraction process completed in {execution_time:.2f} milliseconds.")
        print(f"LLM's final response: {llm_response}")
        print("=" * 78)

        return llm_response