| """ |
| LLM Inference Module |
| |
| This module handles all interactions with the Groq API via LangChain, |
| allowing the application to generate EDA insights and feature engineering |
| recommendations from dataset analysis. |
| """ |
|
|
| import os |
| from dotenv import load_dotenv |
| import logging |
| import time |
| from typing import Dict, Any, List, Optional |
| from langchain_community.callbacks.manager import get_openai_callback |
|
|
| |
| from langchain_groq import ChatGroq |
| from langchain_core.messages import HumanMessage |
| from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate |
| from langchain_community.callbacks.manager import get_openai_callback |
| from langchain_core.runnables import RunnableSequence |
|
|
| |
# Configure root logging at import time and obtain this module's logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Load variables from a .env file (if present) into the process environment,
# then fail fast when the Groq API key is unavailable.
load_dotenv()
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found in environment variables. Please add it to your .env file.")

# Module-level Groq chat client shared with LLMInference. Construction errors
# are logged and then re-raised, so importing this module fails loudly.
try:
    llm = ChatGroq(model_name="llama3-8b-8192", groq_api_key=GROQ_API_KEY)
except Exception as e:
    logger.error(f"Failed to initialize Groq client: {str(e)}")
    raise
else:
    logger.info("Successfully initialized Groq client")
|
|
class LLMInference:
    """Generate dataset insights with a Groq-hosted LLM through LangChain.

    Every public method turns a ``dataset_info`` dictionary into prompt
    variables, runs a ``prompt | llm`` chain and returns the model's text.
    Chain failures are logged and returned as an ``"Error executing ..."``
    string instead of being raised.

    Keys read from ``dataset_info`` (all optional): ``shape``, ``columns``,
    ``dtypes``, ``missing_values``, ``basic_stats``, ``correlations`` and
    ``sample_data``.
    """

    # Fallback text used when no usable conversation history is available.
    _NO_HISTORY = "No previous conversation."

    # Prompt texts are kept flush-left so that source indentation does not
    # leak into the prompts sent to the model.
    _EDA_PROMPT = """You are a data scientist tasked with performing Exploratory Data Analysis (EDA) on a dataset.
Based on the following dataset information, provide comprehensive EDA insights:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Missing values:
{missing_info}

- Basic statistics:
{basic_stats}

- Top correlations:
{correlations}

- Sample data:
{sample_data}

Please provide a detailed EDA analysis that includes:

1. Summary of the dataset (what it appears to be about, key features, etc.)
2. Distribution analysis of key variables
3. Relationship analysis between variables
4. Identification of patterns, outliers, or anomalies
5. Recommended visualizations that would be insightful
6. Initial hypotheses based on the data

Your analysis should be structured, thorough, and provide actionable insights for further investigation.
"""

    _FEATURE_ENGINEERING_PROMPT = """You are a machine learning engineer specializing in feature engineering.
Based on the following dataset information, provide recommendations for feature engineering:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

- Top correlations:
{correlations}

Please provide comprehensive feature engineering recommendations that include:

1. Numerical feature transformations (scaling, normalization, log transforms, etc.)
2. Categorical feature encoding strategies
3. Feature interaction suggestions
4. Dimensionality reduction approaches if applicable
5. Time-based feature creation if applicable
6. Text processing techniques if there are text fields
7. Feature selection recommendations

For each recommendation, explain why it would be beneficial and how it could improve model performance.
Be specific to this dataset's characteristics rather than providing generic advice.
"""

    _DATA_QUALITY_PROMPT = """You are a data quality expert.
Based on the following dataset information, provide data quality insights and recommendations:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Missing values:
{missing_info}

- Basic statistics:
{basic_stats}

Please provide a comprehensive data quality assessment that includes:

1. Assessment of data completeness (missing values)
2. Identification of potential data inconsistencies or errors
3. Recommendations for data cleaning and preprocessing
4. Advice on handling outliers
5. Suggestions for data validation checks
6. Recommendations to improve data quality

Your assessment should be specific to this dataset and provide actionable recommendations.
"""

    _QA_PROMPT = """You are a data scientist answering questions about a dataset.
Based on the following dataset information, please answer the user's question:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

User's question: {question}

Please provide a clear, informative answer to the user's question based on the dataset information provided.
"""

    _MEMORY_PROMPT = """You are a data scientist answering questions about a dataset.
The following is information about the dataset:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

Previous conversation:
{chat_history}

User's new question: {question}

Please provide a clear, informative answer to the user's question. Take into account the previous conversation for context. Make your answer specific to the dataset information provided."""

    def __init__(self, model_id: str = "llama3-8b-8192"):
        """Initialize prompt templates and chains for the requested model.

        Args:
            model_id: Groq model name. The module-level client is reused when
                it already serves this model; otherwise a dedicated
                ``ChatGroq`` client is created so ``model_id`` is honoured
                (it used to be silently ignored).
        """
        self.model_id = model_id
        if model_id == getattr(llm, "model_name", None):
            self.llm = llm
        else:
            self.llm = ChatGroq(model_name=model_id, groq_api_key=GROQ_API_KEY)

        self._init_prompt_templates()
        self._init_chains()

        logger.info(f"LLMInference initialized with model: {model_id}")

    @staticmethod
    def _single_message_template(text: str):
        """Wrap ``text`` as a one-message (human) chat prompt template."""
        return ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(text)
        ])

    def _init_prompt_templates(self):
        """Build every prompt template once per instance."""
        self.eda_prompt_template = self._single_message_template(self._EDA_PROMPT)
        self.feature_engineering_prompt_template = self._single_message_template(
            self._FEATURE_ENGINEERING_PROMPT
        )
        self.data_quality_prompt_template = self._single_message_template(
            self._DATA_QUALITY_PROMPT
        )
        self.qa_prompt_template = self._single_message_template(self._QA_PROMPT)
        # Previously rebuilt on every answer_with_memory call.
        self.memory_prompt_template = self._single_message_template(self._MEMORY_PROMPT)

    def _init_chains(self):
        """Compose each prompt template with the LLM (``prompt | llm``)."""
        self.eda_chain = self.eda_prompt_template | self.llm
        self.feature_engineering_chain = self.feature_engineering_prompt_template | self.llm
        self.data_quality_chain = self.data_quality_prompt_template | self.llm
        self.qa_chain = self.qa_prompt_template | self.llm
        self.memory_chain = self.memory_prompt_template | self.llm

    def _format_columns_info(self, columns: List[str], dtypes: Dict[str, str]) -> str:
        """Render one ``- name (dtype)`` line per column.

        Columns missing from ``dtypes`` are shown as ``unknown``; ``None`` for
        either argument is treated as empty.
        """
        dtypes = dtypes or {}
        return "\n".join(f"- {col} ({dtypes.get(col, 'unknown')})" for col in columns or [])

    def _format_missing_info(self, missing_values: Dict[str, tuple]) -> str:
        """Render columns with a positive missing count, or a "none" message.

        Args:
            missing_values: Mapping of column name to a ``(count, percent)``
                pair; ``None`` is treated as empty.
        """
        missing_info = "\n".join(
            f"- {col}: {count} missing values ({percent}%)"
            for col, (count, percent) in (missing_values or {}).items()
            if count > 0
        )
        return missing_info or "No missing values detected."

    def _base_inputs(self, dataset_info: Dict[str, Any]) -> Dict[str, Any]:
        """Build the prompt variables shared by every chain (shape, columns, stats)."""
        return {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": self._format_columns_info(
                dataset_info.get("columns", []),
                dataset_info.get("dtypes", {}),
            ),
            "basic_stats": dataset_info.get("basic_stats", ""),
        }

    def _invoke_chain(
        self,
        chain: "RunnableSequence",
        input_data: Dict[str, Any],
        operation_name: str,
    ) -> str:
        """Invoke ``chain``, log latency and token usage, and return its text.

        Unlike ``_execute_chain`` this lets exceptions propagate, so callers
        can tell a failure apart from a genuine answer.
        """
        start_time = time.perf_counter()
        with get_openai_callback() as cb:
            result = chain.invoke(input_data).content
        elapsed_time = time.perf_counter() - start_time

        logger.info(f"{operation_name} generated in {elapsed_time:.2f} seconds")
        logger.info(f"Tokens used: {cb.total_tokens}, "
                    f"Prompt tokens: {cb.prompt_tokens}, "
                    f"Completion tokens: {cb.completion_tokens}")
        return result

    @staticmethod
    def _report_failure(operation_name: str, error: Exception) -> str:
        """Log ``error`` with its traceback and return the user-facing message.

        Must be called from inside an ``except`` block so the traceback is
        available to ``logger.exception``.
        """
        error_msg = f"Error executing {operation_name.lower()}: {str(error)}"
        logger.exception(error_msg)
        return error_msg

    def _execute_chain(
        self,
        chain: "RunnableSequence",
        input_data: Dict[str, Any],
        operation_name: str,
    ) -> str:
        """Run ``chain`` with tracking, converting any failure into a message.

        Args:
            chain: The LangChain runnable to execute.
            input_data: Prompt variables for the chain.
            operation_name: Human-readable operation name used in logs.

        Returns:
            str: The generated text, or ``"Error executing <operation>: <error>"``
            when invoking the chain raised.
        """
        try:
            return self._invoke_chain(chain, input_data, operation_name)
        except Exception as e:
            return self._report_failure(operation_name, e)

    @classmethod
    def _load_chat_history(cls, memory) -> str:
        """Render the conversation stored in ``memory`` as plain text.

        Reads the variable named by ``memory.memory_key`` (falling back to
        ``"chat_history"`` when the attribute is absent). String histories
        (``return_messages=False``) are used as-is; message lists are
        rendered as ``"<type>: <content>"`` lines. Empty or unreadable
        history yields ``"No previous conversation."``.
        """
        memory_key = getattr(memory, "memory_key", "chat_history")
        try:
            history = memory.load_memory_variables({})[memory_key]
            if isinstance(history, str):
                history_str = history
            else:
                history_str = "\n".join(f"{msg.type}: {msg.content}" for msg in history)
        except Exception as e:
            logger.warning(f"Error loading memory: {str(e)}. Using empty chat history.")
            return cls._NO_HISTORY
        return history_str.strip() or cls._NO_HISTORY

    def generate_eda_insights(self, dataset_info: Dict[str, Any]) -> str:
        """Generate EDA insights for the dataset.

        Args:
            dataset_info: Dictionary containing dataset analysis.

        Returns:
            str: Detailed EDA insights, or an error message on failure.
        """
        logger.info("Generating EDA insights")
        input_data = {
            **self._base_inputs(dataset_info),
            "missing_info": self._format_missing_info(dataset_info.get("missing_values", {})),
            "correlations": dataset_info.get("correlations", ""),
            "sample_data": dataset_info.get("sample_data", "N/A"),
        }
        return self._execute_chain(self.eda_chain, input_data, "EDA insights")

    def generate_feature_engineering_recommendations(self, dataset_info: Dict[str, Any]) -> str:
        """Generate feature engineering recommendations for the dataset.

        Args:
            dataset_info: Dictionary containing dataset analysis.

        Returns:
            str: Feature engineering recommendations, or an error message on failure.
        """
        logger.info("Generating feature engineering recommendations")
        input_data = {
            **self._base_inputs(dataset_info),
            "correlations": dataset_info.get("correlations", ""),
        }
        return self._execute_chain(
            self.feature_engineering_chain,
            input_data,
            "Feature engineering recommendations",
        )

    def generate_data_quality_insights(self, dataset_info: Dict[str, Any]) -> str:
        """Generate data quality insights for the dataset.

        Args:
            dataset_info: Dictionary containing dataset analysis.

        Returns:
            str: Data quality insights and recommendations, or an error message on failure.
        """
        logger.info("Generating data quality insights")
        input_data = {
            **self._base_inputs(dataset_info),
            "missing_info": self._format_missing_info(dataset_info.get("missing_values", {})),
        }
        return self._execute_chain(self.data_quality_chain, input_data, "Data quality insights")

    def answer_dataset_question(self, question: str, dataset_info: Dict[str, Any]) -> str:
        """Answer a single question about the dataset (no conversation memory).

        Args:
            question: User's question about the dataset.
            dataset_info: Dictionary containing dataset analysis.

        Returns:
            str: Answer to the question, or an error message on failure.
        """
        logger.info(f"Answering dataset question: {question[:50]}...")
        input_data = {
            **self._base_inputs(dataset_info),
            "question": question,
        }
        return self._execute_chain(self.qa_chain, input_data, "Answer")

    def answer_with_memory(self, question: str, dataset_info: Dict[str, Any], memory) -> str:
        """Answer a question using the conversation stored in ``memory``.

        Args:
            question: User's question about the dataset.
            dataset_info: Dictionary containing dataset analysis.
            memory: ConversationBufferMemory instance (anything exposing
                ``load_memory_variables`` and ``save_context``).

        Returns:
            str: Answer to the question, or an error message on failure.
            Only successful answers are saved to ``memory``, so error text
            never pollutes the conversation history.
        """
        logger.info(f"Answering with memory: {question[:50]}...")
        input_data = {
            **self._base_inputs(dataset_info),
            "question": question,
            "chat_history": self._load_chat_history(memory),
        }

        operation_name = "Answer with memory"
        try:
            response = self._invoke_chain(self.memory_chain, input_data, operation_name)
        except Exception as e:
            return self._report_failure(operation_name, e)

        memory.save_context({"input": question}, {"output": response})
        return response
|
|
|
|