from typing import Dict

from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

from externals.observability.langfuse import langfuse_handler, langfuse
from services.llms.LLM import model_5mini, model_4omini
from utils.decorator import trace_runtime
from utils.logger import get_logger

logger = get_logger("base generator")


class MetadataObservability(BaseModel):
    """Observability metadata attached to each Langfuse trace."""

    # Display name of the user the generation is attributed to (Langfuse user_id).
    fullname: str
    # Task identifier, used as the Langfuse session_id.
    task_id: str
    # Name of the agent issuing the generation.
    agent: str


class BaseAIGenerator:
    """Run a ``prompt | llm`` chain under a Langfuse trace with tenacity retries.

    Args:
        task_name: Trace name reported to Langfuse.
        prompt: Prompt template piped into the model.
        input_llm: Variables fed to the prompt template.
        metadata_observability: User/task/agent metadata for the trace.
        llm: Runnable model. Defaults to ``model_5mini | model_4omini``.
            NOTE(review): the ``|`` operator *pipes* the two models
            sequentially (the first model's output becomes the second's
            input); if fallback behavior was intended, use
            ``model_5mini.with_fallbacks([model_4omini])`` — confirm.
            The ``AzureChatOpenAI`` annotation is also inaccurate for the
            piped default, which is a composite Runnable.
    """

    def __init__(
        self,
        task_name: str,
        prompt: ChatPromptTemplate,
        input_llm: Dict,
        metadata_observability: MetadataObservability,
        llm: AzureChatOpenAI = model_5mini | model_4omini,
    ):
        self.name = task_name
        self.llm = llm
        self.prompt = prompt
        self.input_llm = input_llm
        self.metadata_observability = metadata_observability

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),  # 2 attempts total (1 retry)
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception),  # retry on any exception from LLM
    )
    async def _asafe_invoke(self, chain, input_llm, config):
        """Invoke *chain* asynchronously, retrying once on any exception."""
        return await chain.ainvoke(input_llm, config=config)

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),  # 2 attempts total (1 retry)
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception),  # retry on any exception from LLM
    )
    async def _safe_invoke(self, chain, input_llm, config):
        """Invoke *chain* with the blocking ``invoke``, retrying once.

        Declared ``async`` so both helpers share one calling convention;
        the underlying ``invoke`` call itself is synchronous.
        """
        return chain.invoke(input_llm, config=config)

    @trace_runtime
    async def agenerate(self):
        """Run the chain asynchronously under a Langfuse trace.

        Returns:
            The chain output, or ``None`` if invocation failed after retries
            (the failure is logged and recorded on the trace).
        """
        trace = None
        try:
            # Create trace explicitly (no context manager, no end()).
            trace = langfuse.trace(
                name=self.name,
                input=self.input_llm,
                user_id=self.metadata_observability.fullname,
                session_id=self.metadata_observability.task_id,
                metadata=self.metadata_observability.model_dump(),
            )
            handler = trace.get_langchain_handler()
            config = {"callbacks": [handler]}

            chain = self.prompt | self.llm
            output = await self._asafe_invoke(
                chain=chain,
                input_llm=self.input_llm,
                config=config,
            )

            trace.update(output=output)
            return output
        except Exception as e:
            logger.exception("❌ BaseGenerator agenerate error")
            if trace:
                trace.update(
                    status="error",
                    error=str(e),
                )
            return None

    @trace_runtime
    async def generate(self):
        """Run the chain (blocking ``invoke`` path) under a Langfuse trace.

        Returns:
            The chain output, or ``None`` if invocation failed after retries
            (the failure is logged and recorded on the trace).
        """
        trace = None
        try:
            trace = langfuse.trace(
                name=self.name,
                input=self.input_llm,
                user_id=self.metadata_observability.fullname,
                session_id=self.metadata_observability.task_id,
                metadata=self.metadata_observability.model_dump(),
            )
            handler = trace.get_langchain_handler()
            config = {"callbacks": [handler]}

            chain = self.prompt | self.llm
            # BUG FIX: the original omitted `await`, so `output` was an
            # unexecuted coroutine object — the chain never ran, tenacity
            # never retried, and the trace recorded a coroutine repr.
            output = await self._safe_invoke(
                chain=chain,
                input_llm=self.input_llm,
                config=config,
            )

            trace.update(output=output)
            return output
        except Exception as e:
            logger.exception("❌ BaseGenerator generate error")
            if trace:
                trace.update(
                    status="error",
                    error=str(e),
                )
            return None