|
|
| from langchain_experimental.graph_transformers import LLMGraphTransformer |
| from langchain_core.documents import Document |
| from langchain_openai import ChatOpenAI, OpenAIEmbeddings |
| from langchain_community.vectorstores import FAISS |
| from langchain_core.prompts import PromptTemplate |
| from pyvis.network import Network |
| from dotenv import load_dotenv |
| import os |
| import asyncio |
|
|
# Load variables from a local .env file into the process environment before
# any OpenAI client is constructed below.
load_dotenv()
# NOTE(review): api_key is read here but never passed explicitly in the
# visible code; ChatOpenAI / OpenAIEmbeddings presumably pick OPENAI_API_KEY
# up from the environment themselves — confirm.
api_key = os.environ.get("OPENAI_API_KEY")

# Single GPT-4o chat model (temperature 0) reused both by the graph
# transformer and by the question-answering chain.
llm = ChatOpenAI(model_name="gpt-4o", temperature=0)
graph_transformer = LLMGraphTransformer(llm=llm)
|
|
async def extract_graph_data(text):
    """Asynchronously turn raw *text* into graph documents.

    The text is wrapped in a single ``Document`` and handed to the
    module-level ``graph_transformer``; whatever its
    ``aconvert_to_graph_documents`` returns is passed straight back.
    """
    return await graph_transformer.aconvert_to_graph_documents(
        [Document(page_content=text)]
    )
|
|
def visualize_graph(graph_documents):
    """Build an interactive pyvis network from extracted graph documents.

    Nodes and relationships from every document in *graph_documents* are
    merged. Only relationships whose source and target are both declared
    nodes are drawn, and only nodes taking part in such a relationship are
    shown (isolated nodes are omitted). Edge labels are lower-cased.

    Args:
        graph_documents: Sequence of graph documents, each exposing
            ``.nodes`` and ``.relationships``. An empty sequence yields an
            empty (but fully configured) network.

    Returns:
        The configured ``pyvis.network.Network``.
    """
    net = Network(height="600px", width="100%", directed=True, notebook=False,
                  bgcolor="#222222", font_color="white", filter_menu=True,
                  cdn_resources='remote')

    # Merge all documents instead of reading only graph_documents[0], which
    # raised IndexError on an empty list and ignored any further documents.
    # A later node with the same id replaces an earlier one (last wins).
    node_dict = {}
    relationships = []
    for doc in graph_documents:
        for node in doc.nodes:
            node_dict[node.id] = node
        relationships.extend(doc.relationships)

    valid_edges = [
        rel for rel in relationships
        if rel.source.id in node_dict and rel.target.id in node_dict
    ]
    # dict.fromkeys keeps first-seen order, so node insertion order is
    # deterministic (a set's iteration order depends on string hashing).
    valid_node_ids = dict.fromkeys(
        node_id
        for rel in valid_edges
        for node_id in (rel.source.id, rel.target.id)
    )

    added_ids = set()
    for node_id in valid_node_ids:
        node = node_dict[node_id]
        try:
            net.add_node(node.id, label=node.id, title=node.type, group=node.type)
        except (AssertionError, TypeError, ValueError):
            # pyvis validates node ids with asserts; skip a rejected node.
            continue
        added_ids.add(node_id)

    for rel in valid_edges:
        # Don't rely on pyvis asserts (stripped under -O) to reject edges
        # whose endpoint node was never accepted.
        if rel.source.id not in added_ids or rel.target.id not in added_ids:
            continue
        try:
            net.add_edge(rel.source.id, rel.target.id, label=rel.type.lower())
        except (AssertionError, AttributeError, TypeError, ValueError):
            # Best effort: skip a single malformed edge rather than fail the
            # whole visualisation.
            continue

    net.set_options('{"physics": {"forceAtlas2Based": {"gravitationalConstant": -100, "centralGravity": 0.01, "springLength": 200, "springConstant": 0.08}, "minVelocity": 0.75, "solver": "forceAtlas2Based"}}')
    return net
|
|
def generate_knowledge_graph(text):
    """Extract a knowledge graph from *text* and build its visualisation.

    This is the synchronous entry point around the async
    ``extract_graph_data``.

    Returns:
        Tuple ``(net, graph_documents)``: the pyvis network produced by
        ``visualize_graph`` and the raw graph documents it was built from.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Usual case: no loop is running in this thread, so asyncio.run can
        # create and own one.
        graph_documents = asyncio.run(extract_graph_data(text))
    else:
        # asyncio.run() refuses to start while a loop is already running in
        # this thread (e.g. inside Jupyter). Run the coroutine on a fresh
        # loop in a worker thread and block until it finishes.
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            graph_documents = pool.submit(
                lambda: asyncio.run(extract_graph_data(text))
            ).result()

    net = visualize_graph(graph_documents)
    return net, graph_documents
|
|
def _relationship_sentence(rel):
    """Render one relationship as the French sentence that is embedded and shown to the LLM."""
    return (
        f"L'entité '{rel.source.id}' a pour relation '{rel.type}' "
        f"avec l'entité '{rel.target.id}'."
    )


def _build_filtered_network(relationships):
    """Build (without saving) a pyvis network holding only *relationships* and their endpoints."""
    net = Network(height="450px", width="100%", directed=True, bgcolor="#222222", font_color="white")
    nodes_added = set()
    for rel in relationships:
        # Source before target, each endpoint added at most once.
        for endpoint in (rel.source, rel.target):
            if endpoint.id not in nodes_added:
                net.add_node(endpoint.id, label=endpoint.id, title=endpoint.type, group=endpoint.type)
                nodes_added.add(endpoint.id)
        try:
            # Lower-cased to match the edge labels drawn by visualize_graph.
            net.add_edge(rel.source.id, rel.target.id, label=rel.type.lower())
        except (AssertionError, AttributeError, TypeError, ValueError):
            # Best effort: skip an edge pyvis rejects instead of failing the
            # whole answer (the original used a bare `except: pass`).
            continue
    net.set_options('{"physics": {"forceAtlas2Based": {"gravitationalConstant": -50}}}')
    return net


def answer_question_with_graph(question, graph_documents, k_relations=5):
    """Answer *question* from the graph relationships most similar to it.

    Every relationship in *graph_documents* is rendered as a sentence. The
    sentences are embedded with ``text-embedding-3-small`` into an in-memory
    FAISS index, and the ``k_relations`` nearest ones become the only context
    given to the LLM. A network of just those relationships is also written
    to ``filtered_graph.html`` in the current working directory.

    Args:
        question: Natural-language question.
        graph_documents: Graph documents exposing ``.relationships``.
        k_relations: Number of relationships to retrieve as context.

    Returns:
        ``(answer, net)``: the LLM's answer text and the filtered pyvis
        network. If no relationship exists, returns a fixed French message
        and ``visualize_graph(graph_documents)`` instead.
    """
    all_relationships = [rel for doc in graph_documents for rel in doc.relationships]
    if not all_relationships:
        return "Aucune relation trouvée dans le graphe.", visualize_graph(graph_documents)

    # rel_index lets each retrieved sentence be mapped back to its relationship.
    rel_docs = [
        Document(page_content=_relationship_sentence(rel), metadata={"rel_index": i})
        for i, rel in enumerate(all_relationships)
    ]

    # NOTE: the index is rebuilt, and every relationship re-embedded, on each call.
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    vectorstore = FAISS.from_documents(rel_docs, embeddings)
    retrieved_docs = vectorstore.similarity_search(question, k=k_relations)

    used_relationships = [all_relationships[doc.metadata["rel_index"]] for doc in retrieved_docs]
    context = "\n".join(doc.page_content for doc in retrieved_docs)

    prompt = PromptTemplate(
        template='''Tu es un assistant expert qui répond aux questions en se basant UNIQUEMENT sur ce sous-ensemble de relations extraites d'un graphe de connaissances.\n\nContexte (Relations pertinentes trouvées) :\n{context}\n\nQuestion : {question}\n\nRéponds de manière claire et concise en français. Si la réponse n'est pas dans le contexte fourni, dis-le explicitement.''',
        input_variables=["context", "question"]
    )
    chain = prompt | llm
    answer = chain.invoke({"context": context, "question": question}).content

    net = _build_filtered_network(used_relationships)
    net.save_graph("filtered_graph.html")
    return answer, net
|
|