| import streamlit as st |
| import pandas as pd |
| import os |
| from crewai import Crew |
| from langchain_groq import ChatGroq |
| import streamlit_ace as st_ace |
| import traceback |
| import contextlib |
| import io |
| from crewai_tools import FileReadTool |
| import matplotlib.pyplot as plt |
| import glob |
| from dotenv import load_dotenv |
| from autotabml_agents import initialize_agents |
| from autotabml_tasks import create_tasks |
|
|
|
|
| TEMP_DIR = "temp_dir" |
| OUTPUT_DIR = "Output_dir" |
| |
| if not os.path.exists(TEMP_DIR): |
| os.makedirs(TEMP_DIR) |
|
|
| |
| if not os.path.exists(OUTPUT_DIR): |
| os.makedirs(OUTPUT_DIR) |
|
|
| |
| def save_uploaded_file(uploaded_file): |
| file_path = os.path.join(TEMP_DIR, uploaded_file.name) |
| with open(file_path, 'wb') as f: |
| f.write(uploaded_file.getbuffer()) |
| return file_path |
|
|
| |
| load_dotenv() |
| |
| groq_api_key = os.environ.get("GROQ_API_KEY") |
|
|
|
|
| def main(): |
| |
| set_custom_css() |
|
|
| |
| if 'edited_code' not in st.session_state: |
| st.session_state['edited_code'] = "" |
| |
| |
| if 'code_generated' not in st.session_state: |
| st.session_state['code_generated'] = False |
|
|
| |
| st.markdown(""" |
| <div class="header"> |
| <h1>AutoTabML</h1> |
| <p>Automated Machine Learning Code Generation for Tabluar Data</p> |
| </div> |
| """, unsafe_allow_html=True) |
|
|
| |
| st.sidebar.title('LLM Model') |
| model = st.sidebar.selectbox( |
| 'Model', |
| ["llama3-70b-8192"] |
| ) |
|
|
| |
| llm = initialize_llm(model) |
|
|
| |
|
|
| |
| user_question = st.text_area("Describe your ML problem:", key="user_question") |
| uploaded_file = st.file_uploader("Upload a sample .csv of your data", key="uploaded_file") |
| try: |
| file_name = uploaded_file.name |
| except: |
| file_name = "dataset.csv" |
|
|
| |
| agents = initialize_agents(llm,file_name,TEMP_DIR) |
| |
| if uploaded_file: |
| try: |
| file_path = save_uploaded_file(uploaded_file) |
| df = pd.read_csv(uploaded_file) |
| st.write("Data successfully uploaded:") |
| st.dataframe(df.head()) |
| data_upload = True |
| except Exception as e: |
| st.error(f"Error reading the file: {e}") |
| data_upload = False |
| else: |
| df = None |
| data_upload = False |
|
|
| |
| if st.button('Process'): |
| tasks = create_tasks("Process",user_question,file_name, data_upload, df, None, st.session_state['edited_code'], None, agents) |
| with st.spinner('Processing...'): |
| crew = Crew( |
| agents=list(agents.values()), |
| tasks=tasks, |
| verbose=2 |
| ) |
|
|
| result = crew.kickoff() |
|
|
| if result: |
| code = result.strip("```") |
| try: |
| filt_idx = code.index("```") |
| code = code[:filt_idx] |
| except: |
| pass |
| st.session_state['edited_code'] = code |
| st.session_state['code_generated'] = True |
|
|
| st.session_state['edited_code'] = st_ace.st_ace( |
| value=st.session_state['edited_code'], |
| language='python', |
| theme='monokai', |
| keybinding='vscode', |
| min_lines=20, |
| max_lines=50 |
| ) |
|
|
| if st.session_state['code_generated']: |
| |
| suggestion = st.text_area("Suggest modifications to the generated code (optional):", key="suggestion") |
| if st.button('Modify'): |
| if st.session_state['edited_code'] and suggestion: |
| tasks = create_tasks("Modify",user_question,file_name, data_upload, df, suggestion, st.session_state['edited_code'], None, agents) |
| with st.spinner('Modifying code...'): |
| crew = Crew( |
| agents=list(agents.values()), |
| tasks=tasks, |
| verbose=2 |
| ) |
|
|
| result = crew.kickoff() |
|
|
| if result: |
| code = result.strip("```") |
| try: |
| filter_idx = code.index("```") |
| code = code[:filter_idx] |
| except: |
| pass |
| st.session_state['edited_code'] = code |
|
|
| st.write("Modified code:") |
| st.session_state['edited_code']= st_ace.st_ace( |
| value=st.session_state['edited_code'], |
| language='python', |
| theme='monokai', |
| keybinding='vscode', |
| min_lines=20, |
| max_lines=50 |
| ) |
|
|
| debugger = st.text_area("Paste error message here for debugging (optional):", key="debugger") |
| if st.button('Debug'): |
| if st.session_state['edited_code'] and debugger: |
| tasks = create_tasks("Debug",user_question,file_name, data_upload, df, None, st.session_state['edited_code'], debugger, agents) |
| with st.spinner('Debugging code...'): |
| crew = Crew( |
| agents=list(agents.values()), |
| tasks=tasks, |
| verbose=2 |
| ) |
|
|
| result = crew.kickoff() |
|
|
| if result: |
| code = result.strip("```") |
| try: |
| filter_idx = code.index("```") |
| code = code[:filter_idx] |
| except: |
| pass |
| st.session_state['edited_code'] = code |
|
|
| st.write("Debugged code:") |
| st.session_state['edited_code'] = st_ace.st_ace( |
| value=st.session_state['edited_code'], |
| language='python', |
| theme='monokai', |
| keybinding='vscode', |
| min_lines=20, |
| max_lines=50 |
| ) |
|
|
| if st.button('Run'): |
| output = io.StringIO() |
| with contextlib.redirect_stdout(output): |
| try: |
| globals().update({'dataset': df}) |
| final_code = st.session_state["edited_code"] |
| |
| with st.expander("Final Code"): |
| st.code(final_code, language='python') |
|
|
| exec(final_code, globals()) |
| result = output.getvalue() |
| success = True |
| except Exception as e: |
| result = str(e) |
| success = False |
|
|
| st.subheader('Output:') |
| st.text(result) |
|
|
| figs = [manager.canvas.figure for manager in plt._pylab_helpers.Gcf.get_all_fig_managers()] |
| if figs: |
| st.subheader('Generated Plots:') |
| for fig in figs: |
| st.pyplot(fig) |
|
|
| if success: |
| st.success("Code executed successfully!") |
| else: |
| st.error("Code execution failed! Waiting for debugging input...") |
|
|
| |
| with st.sidebar: |
| st.header('Output_dir :') |
| files = glob.glob(os.path.join(OUTPUT_DIR, '*')) |
| for file in files: |
| if os.path.isfile(file): |
| with open(file, 'rb') as f: |
| st.download_button(label=f'Download {os.path.basename(file)}', data=f, file_name=os.path.basename(file)) |
|
|
|
|
|
|
| |
| def set_custom_css(): |
| st.markdown(""" |
| <style> |
| body { |
| background: #0e0e0e; |
| color: #e0e0e0; |
| font-family: 'Roboto', sans-serif; |
| } |
| .header { |
| background: linear-gradient(135deg, #6e3aff, #b839ff); |
| padding: 10px; |
| border-radius: 10px; |
| } |
| .header h1, .header p { |
| color: white; |
| text-align: center; |
| } |
| .stButton button { |
| background-color: #b839ff; |
| color: white; |
| border-radius: 10px; |
| font-size: 16px; |
| padding: 10px 20px; |
| } |
| .stButton button:hover { |
| background-color: #6e3aff; |
| color: #e0e0e0; |
| } |
| .spinner { |
| display: flex; |
| justify-content: center; |
| align-items: center; |
| } |
| </style> |
| """, unsafe_allow_html=True) |
|
|
| |
| def initialize_llm(model): |
| return ChatGroq( |
| temperature=0, |
| groq_api_key=groq_api_key, |
| model_name=model |
| ) |
|
|
| |
| if __name__ == "__main__": |
| main() |