# CandidateExplorer/services/base/BaseGenerator.py
# Author: ishaq101 — commit 4324a46: [NOTICKET] Fix: model cost display in langfuse
from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type
)
from typing import Dict
from externals.observability.langfuse import langfuse_handler, langfuse
from services.llms.LLM import model_5mini, model_4omini
from utils.decorator import trace_runtime
from utils.logger import get_logger
# Module-level logger shared by all generators defined in this file.
logger = get_logger("base generator")
class MetadataObservability(BaseModel):
    """Identity/task metadata attached to Langfuse traces.

    All fields are also serialized via ``model_dump()`` into the trace
    ``metadata`` payload.
    """
    # Display name of the user; used as the Langfuse trace user_id.
    fullname: str
    # Task identifier; used as the Langfuse trace session_id.
    task_id: str
    # Name of the agent performing the generation (trace metadata only).
    agent: str
class BaseAIGenerator:
    """
    Run a ``prompt | llm`` LangChain chain with Langfuse tracing and
    tenacity-based retries.

    Args:
        task_name (str): Name given to the Langfuse trace.
        prompt (ChatPromptTemplate): Prompt template piped into the LLM.
        input_llm (Dict): Template variables passed to the chain.
        metadata_observability (MetadataObservability): User/task/agent
            metadata attached to the trace.
        llm (AzureChatOpenAI): Chat model; defaults to ``model_5mini`` with
            ``model_4omini`` as fallback.
    """

    def __init__(self,
                 task_name: str,
                 prompt: ChatPromptTemplate,
                 input_llm: Dict,
                 metadata_observability: MetadataObservability,
                 llm: AzureChatOpenAI = model_5mini | model_4omini,
                 ):
        self.name = task_name
        self.llm = llm
        self.prompt = prompt
        self.input_llm = input_llm
        self.metadata_observability = metadata_observability

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),  # 2 attempts total: 1 initial call + 1 retry
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception),  # retry on any exception from the LLM
    )
    async def _asafe_invoke(self, chain, input_llm, config):
        """Async chain invocation with retry/backoff."""
        return await chain.ainvoke(input_llm, config=config)

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),  # 2 attempts total: 1 initial call + 1 retry
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception),  # retry on any exception from the LLM
    )
    def _safe_invoke(self, chain, input_llm, config):
        """Sync chain invocation with retry/backoff.

        NOTE: this was previously ``async def`` while returning the
        un-awaited result of the blocking ``chain.invoke`` — callers that
        did not ``await`` it received a coroutine instead of the output.
        It is now a plain method so the existing call site in ``generate``
        returns the actual LLM result.
        """
        return chain.invoke(input_llm, config=config)

    def _start_trace(self):
        """Create the Langfuse trace and return ``(trace, callback_config)``.

        Shared by ``agenerate``/``generate`` so trace creation stays
        consistent between the two paths. No context manager / ``end()``
        is used by design.
        """
        trace = langfuse.trace(
            name=self.name,
            input=self.input_llm,
            user_id=self.metadata_observability.fullname,
            session_id=self.metadata_observability.task_id,
            metadata=self.metadata_observability.model_dump(),
        )
        handler = trace.get_langchain_handler()
        return trace, {"callbacks": [handler]}

    @trace_runtime
    async def agenerate(self):
        """Run the chain asynchronously.

        Returns:
            The LLM output on success, ``None`` on failure (the error is
            logged and recorded on the trace).
        """
        trace = None
        try:
            trace, config = self._start_trace()
            chain = self.prompt | self.llm
            output = await self._asafe_invoke(
                chain=chain,
                input_llm=self.input_llm,
                config=config,
            )
            trace.update(output=output)
            return output
        except Exception as e:
            logger.exception("❌ BaseGenerator agenerate error")
            if trace:
                trace.update(
                    status="error",
                    error=str(e),
                )
            return None

    @trace_runtime
    async def generate(self):
        """Run the chain with the synchronous ``invoke`` path.

        Kept ``async`` for backward compatibility with existing callers.

        BUG FIX: the original assigned ``self._safe_invoke(...)`` — then an
        ``async`` method — without ``await``, so ``output`` was a coroutine
        that never ran and the trace recorded the coroutine object rather
        than the LLM result. ``_safe_invoke`` is now synchronous, so this
        call returns the real output. Note the blocking ``chain.invoke``
        will stall the event loop; prefer ``agenerate`` in async contexts.

        Returns:
            The LLM output on success, ``None`` on failure (the error is
            logged and recorded on the trace).
        """
        trace = None
        try:
            trace, config = self._start_trace()
            chain = self.prompt | self.llm
            output = self._safe_invoke(
                chain=chain,
                input_llm=self.input_llm,
                config=config,
            )
            trace.update(output=output)
            return output
        except Exception as e:
            logger.exception("❌ BaseGenerator generate error")
            if trace:
                trace.update(
                    status="error",
                    error=str(e),
                )
            return None