Spaces:
Running
Running
| from pydantic import BaseModel | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_openai import AzureChatOpenAI | |
| from tenacity import ( | |
| retry, | |
| stop_after_attempt, | |
| wait_exponential, | |
| retry_if_exception_type | |
| ) | |
| from typing import Dict | |
| from externals.observability.langfuse import langfuse_handler, langfuse | |
| from services.llms.LLM import model_5mini, model_4omini | |
| from utils.decorator import trace_runtime | |
| from utils.logger import get_logger | |
# Logger used by BaseAIGenerator to report generation failures (with traceback).
logger = get_logger("base generator")
class MetadataObservability(BaseModel):
    """Caller identity attached to each Langfuse trace opened by BaseAIGenerator.

    BaseAIGenerator copies ``fullname`` into the trace ``user_id`` and
    ``task_id`` into its ``session_id``. It also attaches the whole model,
    via ``model_dump()``, as trace metadata.
    """

    # Reported to Langfuse as the trace's user_id.
    fullname: str
    # Reported to Langfuse as the trace's session_id.
    task_id: str
    # Only forwarded as part of the trace metadata in this module.
    agent: str
class BaseAIGenerator:
    """Run ``prompt | llm`` on a fixed input and trace the call in Langfuse.

    Args:
        task_name: Name of the Langfuse trace (stored as ``self.name``).
        prompt: Prompt template whose rendered output is piped into ``llm``.
        input_llm: Variables used to fill ``prompt``; also recorded as the
            trace input.
        metadata_observability: Caller identity. ``fullname`` becomes the
            trace ``user_id`` and ``task_id`` its ``session_id``; the whole
            model is attached as trace metadata.
        llm: Chat model invoked on the rendered prompt. Defaults to
            ``model_5mini``.
    """

    def __init__(
        self,
        task_name: str,
        prompt: ChatPromptTemplate,
        input_llm: Dict,
        metadata_observability: MetadataObservability,
        # NOTE: ``a | b`` on LangChain runnables builds a sequence that feeds
        # a's reply into b; it does not pick one of the two models. Pass a
        # single model here (or e.g. ``model_5mini.with_fallbacks([model_4omini])``
        # if a fallback is wanted).
        llm: AzureChatOpenAI = model_5mini,
    ):
        self.name = task_name
        self.llm = llm
        self.prompt = prompt
        self.input_llm = input_llm
        self.metadata_observability = metadata_observability

    async def _asafe_invoke(self, chain, input_llm, config):
        """Await ``chain.ainvoke(input_llm, config=config)`` and return its result.

        Hook for adding a retry policy; currently a plain pass-through, so
        any exception propagates to the caller.
        """
        return await chain.ainvoke(input_llm, config=config)

    async def _safe_invoke(self, chain, input_llm, config):
        """Run the synchronous ``chain.invoke(input_llm, config=config)``.

        Declared ``async``, so callers must ``await`` it. The underlying
        ``invoke`` blocks the event loop while it runs. Like
        ``_asafe_invoke``, no retry policy is applied yet.
        """
        return chain.invoke(input_llm, config=config)

    def _start_trace(self):
        """Open the Langfuse trace for this run plus the LangChain config feeding it.

        Tracing is best-effort: if the trace or its LangChain handler cannot
        be created, the error is logged and the LLM call still runs, untraced.

        Returns:
            ``(trace, config)``. On failure this is ``(None, {})``; ``config``
            is passed straight to ``invoke``/``ainvoke``.
        """
        try:
            trace = langfuse.trace(
                name=self.name,
                input=self.input_llm,
                user_id=self.metadata_observability.fullname,
                session_id=self.metadata_observability.task_id,
                metadata=self.metadata_observability.model_dump(),
            )
            handler = trace.get_langchain_handler()
        except Exception:
            logger.exception("❌ BaseGenerator trace creation error")
            return None, {}
        return trace, {"callbacks": [handler]}

    @staticmethod
    def _update_trace(trace, **fields) -> None:
        """Best-effort ``trace.update(**fields)``.

        Does nothing when ``trace`` is None and never raises, so an
        observability problem cannot mask the generation result.
        """
        if trace is None:
            return
        try:
            trace.update(**fields)
        except Exception:
            logger.exception("❌ BaseGenerator trace update error")

    async def _generate_traced(self, invoke, error_message):
        """Shared body of :meth:`agenerate` and :meth:`generate`.

        Args:
            invoke: ``self._asafe_invoke`` or ``self._safe_invoke``; awaited
                with the ``prompt | llm`` chain, ``input_llm`` and the
                tracing config.
            error_message: Message logged (with traceback) if the chain raises.

        Returns:
            The chain's output, or ``None`` if the chain raised.
        """
        trace, config = self._start_trace()
        try:
            # Both invoke helpers are ``async def``; awaiting is what actually
            # runs the chain.
            output = await invoke(
                chain=self.prompt | self.llm,
                input_llm=self.input_llm,
                config=config,
            )
        except Exception as e:
            logger.exception(error_message)
            # NOTE(review): ``status``/``error`` are not among the documented
            # Langfuse v2 trace fields and may be ignored by the SDK; confirm.
            self._update_trace(trace, status="error", error=str(e))
            return None
        self._update_trace(trace, output=output)
        return output

    async def agenerate(self):
        """Run the chain asynchronously (``ainvoke``) under a Langfuse trace.

        Returns:
            The result of ``(prompt | llm).ainvoke(input_llm)``, or ``None``
            if the call failed. The failure is logged and, when a trace
            exists, recorded on it.
        """
        return await self._generate_traced(
            self._asafe_invoke, "❌ BaseGenerator agenerate error"
        )

    async def generate(self):
        """Run the chain with the synchronous ``invoke`` under a Langfuse trace.

        Declared ``async`` so existing ``await generator.generate()`` callers
        keep working. The blocking ``invoke`` call holds the event loop while
        the LLM responds, so prefer :meth:`agenerate` inside async code.

        Returns:
            The result of ``(prompt | llm).invoke(input_llm)``, or ``None`` if
            the call failed (the failure is logged).
        """
        return await self._generate_traced(
            self._safe_invoke, "❌ BaseGenerator generate error"
        )