from pathlib import Path

import streamlit as st
import streamlit.components.v1 as components

from generate_knowledge_graph import generate_knowledge_graph, answer_question_with_graph


# Page-level configuration. Kept as the first st.* call in the script, which
# older Streamlit versions require.
st.set_page_config(
    # None (not the string "None") selects Streamlit's default favicon; the
    # string "None" is not a valid emoji, icon, or image.
    page_icon=None,
    layout="wide",
    initial_sidebar_state="auto",
    menu_items=None,
)


st.title("Knowledge Graph From Text")

# Sidebar: choose how the source document is supplied.
st.sidebar.title("Input document")
input_method = st.sidebar.radio(
    "Choose an input method:",
    ("Upload .txt", "Input text"),
)

# `text` is always a str; it stays "" when no usable document was provided,
# which the generation step below treats as "nothing to process".
text = ""
if input_method == "Upload .txt":
    uploaded_file = st.sidebar.file_uploader(label="Upload file", type="txt")
    if uploaded_file is not None:
        try:
            # getvalue() returns the whole upload regardless of the current
            # stream position, unlike read().
            text = uploaded_file.getvalue().decode("utf-8")
        except UnicodeDecodeError:
            # Report the bad encoding instead of crashing the whole script run.
            st.sidebar.error("The uploaded file is not valid UTF-8 text.")
else:
    text = st.sidebar.text_area("Input text", height=300)


# Step 1: build the knowledge graph from `text` when the sidebar button is
# clicked, and remember both the graph documents (used by the Q&A section) and
# the rendered HTML in session state.
if st.sidebar.button("1. Generate Knowledge Graph"):
    if text.strip():
        with st.spinner("Generating knowledge graph..."):
            net, graph_docs = generate_knowledge_graph(text)
            st.session_state['graph_docs'] = graph_docs

            output_file = "knowledge_graph.html"
            net.save_graph(output_file)
            # read_text() opens and closes the file; caching the HTML lets the
            # graph survive the reruns Streamlit performs on every later
            # widget interaction (e.g. typing a question).
            st.session_state['graph_html'] = Path(output_file).read_text(encoding="utf-8")
        st.success("Knowledge graph generated successfully!")
    else:
        st.sidebar.error("Please provide some text to generate the graph.")

# Render the most recently generated graph on every run, not only on the run
# in which the button was clicked.
if 'graph_html' in st.session_state:
    components.html(st.session_state['graph_html'], height=600)


# Step 2: question answering over the generated graph. Only offered once the
# generate step has stored 'graph_docs' in session state.
if 'graph_docs' in st.session_state:
    st.markdown("---")
    st.subheader("Posez une question sur le document")

    col1, col2 = st.columns([3, 1])
    with col1:
        question = st.text_input("Votre question :")
    with col2:
        # Passed to answer_question_with_graph as `k_relations`.
        k_value = st.slider("Relations à analyser (Top K)", min_value=1, max_value=20, value=5)

    if st.button("2. Analyser"):
        if not question.strip():
            # Previously a click with an empty question silently did nothing.
            st.warning("Veuillez saisir une question.")
        else:
            with st.spinner("Recherche sémantique dans le graphe en cours..."):
                # The second return value is not used here; the sub-graph is
                # displayed from the HTML file read below.
                answer, _ = answer_question_with_graph(
                    question,
                    st.session_state['graph_docs'],
                    k_relations=k_value,
                )

            st.info(f"**Réponse :** {answer}")

            st.markdown("**Sous-graphe des relations utilisées pour répondre :**")
            # NOTE(review): nothing in this script writes filtered_graph.html;
            # presumably answer_question_with_graph does — confirm. A missing
            # file is reported instead of crashing the app.
            try:
                filtered_html = Path("filtered_graph.html").read_text(encoding="utf-8")
            except FileNotFoundError:
                st.warning("Le sous-graphe (filtered_graph.html) est introuvable.")
            else:
                components.html(filtered_html, height=450)

