from langfuse import get_client, Langfuse, propagate_attributes
from langfuse.langchain import CallbackHandler
import os
from config.constant import LangfuseConstants
from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from typing import Dict
from services.llms.LLM import model_5mini, model_4omini
from utils.decorator import trace_runtime
from utils.logger import get_logger

logger = get_logger("base generator")

# Set Langfuse environment variables at module level so every client/handler picks them up.
os.environ["LANGFUSE_PUBLIC_KEY"] = LangfuseConstants.PUBLIC_KEY
os.environ["LANGFUSE_SECRET_KEY"] = LangfuseConstants.SECRET_KEY
os.environ["LANGFUSE_HOST"] = LangfuseConstants.HOST or "https://us.cloud.langfuse.com"


class MetadataObservability(BaseModel):
    fullname: str
    task_id: str
    agent: str
    user_id: str


class BaseAIGenerator:
    def __init__(
        self,
        task_name: str,
        prompt: ChatPromptTemplate,
        input_llm: Dict,
        metadata_observability: MetadataObservability,
        # Primary model with fallback: try model_5mini first, switch to model_4omini on failure.
        # (Piping two chat models with `|` would feed the first model's AIMessage into the
        # second model and raise at runtime.)
        llm: Runnable = model_5mini.with_fallbacks([model_4omini]),
    ):
        self.metadata_observability = metadata_observability
        self.llm = llm
        self.prompt = prompt
        self.input_llm = input_llm
        self.name = task_name

    def _get_langfuse_client(self):
        try:
            # Environment variables already set at module level
            return get_client()
        except Exception as e:
            logger.warning(f"⚠️ Langfuse unavailable, skipping observability: {e}")
            return None

    def _get_langfuse_config(self):
        try:
            # Environment variables already set at module level
            handler = CallbackHandler()
            return {
                "callbacks": [handler],
                "metadata": {
                    "langfuse_session_id": self.metadata_observability.task_id,
                    "langfuse_user_id": self.metadata_observability.user_id,
                    "langfuse_tags": [self.metadata_observability.agent],
                },
            }
        except Exception as e:
            logger.warning(f"⚠️ Langfuse unavailable, skipping observability: {e}")
            return {}

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception),
    )
    async def _asafe_invoke(self, chain, input_llm, config):
        return await chain.ainvoke(input_llm, config=config)

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception),
    )
    def _safe_invoke(self, chain, input_llm, config):
        return chain.invoke(input_llm, config=config)

    @trace_runtime
    async def agenerate(self):
        try:
            config = self._get_langfuse_config()
            chain = self.prompt | self.llm
            langfuse_client = self._get_langfuse_client()
            if not langfuse_client:
                return await self._asafe_invoke(chain, self.input_llm, config)

            trace_id = Langfuse.create_trace_id(seed=self.metadata_observability.task_id)
            with langfuse_client.start_as_current_observation(
                as_type="generation",
                name=self.name,
                trace_context={"trace_id": trace_id},
                # Pass a plain dict so the metadata is trivially JSON-serializable.
                metadata=self.metadata_observability.model_dump(),
            ) as span:
                with propagate_attributes(
                    user_id=self.metadata_observability.user_id,
                    session_id=self.metadata_observability.task_id,
                    tags=[self.metadata_observability.agent],
                ):
                    span.update_trace(input=self.input_llm)
                    output = await self._asafe_invoke(
                        chain=chain,
                        input_llm=self.input_llm,
                        config=config,
                    )
                    span.update_trace(output=output)
                    return output
        except Exception:
            logger.exception("❌ BaseGenerator agenerate error")
            return None

    @trace_runtime
    def generate(self):
        try:
            config = self._get_langfuse_config()
            chain = self.prompt | self.llm
            langfuse_client = self._get_langfuse_client()
            if not langfuse_client:
                return self._safe_invoke(chain, self.input_llm, config)

            trace_id = Langfuse.create_trace_id(seed=self.metadata_observability.task_id)
            with langfuse_client.start_as_current_observation(
                as_type="generation",
                name=self.name,
                trace_context={"trace_id": trace_id},
                # Pass a plain dict so the metadata is trivially JSON-serializable.
                metadata=self.metadata_observability.model_dump(),
            ) as span:
                with propagate_attributes(
                    user_id=self.metadata_observability.user_id,
                    session_id=self.metadata_observability.task_id,
                    tags=[self.metadata_observability.agent],
                ):
                    span.update_trace(input=self.input_llm)
                    output = self._safe_invoke(
                        chain=chain,
                        input_llm=self.input_llm,
                        config=config,
                    )
                    span.update_trace(output=output)
                    return output
        except Exception:
            logger.exception("❌ BaseGenerator generate error")
            return None
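

# Usage sketch (hypothetical values): shows how BaseAIGenerator might be wired up with a
# prompt whose variables match `input_llm`. The task/user IDs and prompt text below are
# illustrative only, and running this requires valid Azure OpenAI and Langfuse credentials.
if __name__ == "__main__":
    example_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a concise assistant."),
        ("human", "Summarize the following text: {text}"),
    ])
    generator = BaseAIGenerator(
        task_name="summarize_document",
        prompt=example_prompt,
        input_llm={"text": "Langfuse records one generation span per task in this service."},
        metadata_observability=MetadataObservability(
            fullname="Jane Doe",
            task_id="task-123",
            agent="summarizer",
            user_id="user-456",
        ),
    )
    # Synchronous path; inside async code use `await generator.agenerate()` instead.
    result = generator.generate()
    logger.info(f"Example generation result: {result}")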