File size: 4,417 Bytes
a2110a1
3f5d323
 
a2110a1
 
 
3f5d323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2110a1
3f5d323
 
 
 
 
 
a2110a1
3f5d323
 
 
 
 
 
 
 
 
 
a2110a1
3f5d323
 
 
 
 
a2110a1
3f5d323
a2110a1
 
3f5d323
 
 
 
a2110a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d1a8c7
a2110a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105

from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from pyvis.network import Network
from dotenv import load_dotenv
import os
import asyncio

# Load environment variables from a local .env file (expects OPENAI_API_KEY).
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")

# Deterministic (temperature=0) GPT-4o chat model shared by all functions below.
llm = ChatOpenAI(temperature=0, model_name="gpt-4o")
# Transformer that converts free text into graph documents (nodes + relationships).
graph_transformer = LLMGraphTransformer(llm=llm)

async def extract_graph_data(text):
    """Convert raw text into graph documents using the module-level transformer.

    Args:
        text: Free-form text to extract entities and relationships from.

    Returns:
        The list of GraphDocument objects produced by the LLM transformer.
    """
    docs = [Document(page_content=text)]
    return await graph_transformer.aconvert_to_graph_documents(docs)

def visualize_graph(graph_documents):
    """Build an interactive pyvis network from the first graph document.

    Only relationships whose source AND target both appear in the document's
    node list are drawn, and only nodes that participate in at least one such
    relationship are added to the canvas.

    Args:
        graph_documents: Non-empty list of GraphDocument objects; only the
            first one is visualized.

    Returns:
        A configured pyvis ``Network`` (dark theme, forceAtlas2Based physics).
    """
    net = Network(height="600px", width="100%", directed=True, notebook=False, bgcolor="#222222", font_color="white", filter_menu=True, cdn_resources='remote')
    nodes = graph_documents[0].nodes
    relationships = graph_documents[0].relationships

    node_dict = {node.id: node for node in nodes}

    # Keep only edges whose endpoints are both known nodes; collect the set of
    # node ids actually referenced so isolated nodes are not rendered.
    valid_edges = []
    valid_node_ids = set()
    for rel in relationships:
        if rel.source.id in node_dict and rel.target.id in node_dict:
            valid_edges.append(rel)
            valid_node_ids.update([rel.source.id, rel.target.id])

    # pyvis can raise on malformed/duplicate entries; skip those entries but
    # use `except Exception` (not a bare `except:`) so KeyboardInterrupt and
    # SystemExit still propagate.
    for node_id in valid_node_ids:
        node = node_dict[node_id]
        try:
            net.add_node(node.id, label=node.id, title=node.type, group=node.type)
        except Exception:
            continue

    for rel in valid_edges:
        try:
            net.add_edge(rel.source.id, rel.target.id, label=rel.type.lower())
        except Exception:
            continue

    net.set_options('{"physics": {"forceAtlas2Based": {"gravitationalConstant": -100, "centralGravity": 0.01, "springLength": 200, "springConstant": 0.08}, "minVelocity": 0.75, "solver": "forceAtlas2Based"}}')
    return net

def generate_knowledge_graph(text):
    """Extract a knowledge graph from *text* and build its visualization.

    Runs the async extraction synchronously via ``asyncio.run``.

    Args:
        text: Free-form text to turn into a knowledge graph.

    Returns:
        Tuple of (pyvis Network, list of GraphDocument objects).
    """
    documents = asyncio.run(extract_graph_data(text))
    network = visualize_graph(documents)
    return network, documents

def answer_question_with_graph(question, graph_documents, k_relations=5):
    """Answer *question* using only relations retrieved from the knowledge graph.

    Flattens all relationships across the graph documents, embeds a French
    textual rendering of each one into a FAISS index, retrieves the
    ``k_relations`` most similar to the question, asks the LLM to answer from
    that context alone, and renders the retrieved sub-graph with pyvis.

    Args:
        question: User question (French prompt template is used downstream).
        graph_documents: List of GraphDocument objects to search.
        k_relations: Maximum number of relations to retrieve as context.

    Returns:
        Tuple ``(answer, net)`` — the LLM's textual answer and a pyvis
        ``Network`` of the relations that were used (also saved to
        ``filtered_graph.html`` in the working directory).
    """
    all_relationships = []
    for doc in graph_documents:
        all_relationships.extend(doc.relationships)

    # No relations at all: return an explanatory message plus the full graph.
    if not all_relationships:
        return "Aucune relation trouvée dans le graphe.", visualize_graph(graph_documents)

    # Render each relation as a one-sentence document; keep its index in the
    # metadata so retrieved docs can be mapped back to relationship objects.
    rel_docs = []
    for i, rel in enumerate(all_relationships):
        text_rep = f"L'entité '{rel.source.id}' a pour relation '{rel.type}' avec l'entité '{rel.target.id}'."
        rel_docs.append(Document(page_content=text_rep, metadata={"rel_index": i}))

    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    vectorstore = FAISS.from_documents(rel_docs, embeddings)
    # Clamp k so we never request more neighbors than documents indexed.
    retrieved_docs = vectorstore.similarity_search(question, k=min(k_relations, len(rel_docs)))

    used_relationships = [all_relationships[doc.metadata["rel_index"]] for doc in retrieved_docs]
    context = "\n".join([doc.page_content for doc in retrieved_docs])

    prompt = PromptTemplate(
        template='''Tu es un assistant expert qui répond aux questions en se basant UNIQUEMENT sur ce sous-ensemble de relations extraites d'un graphe de connaissances.\n\nContexte (Relations pertinentes trouvées) :\n{context}\n\nQuestion : {question}\n\nRéponds de manière claire et concise en français. Si la réponse n'est pas dans le contexte fourni, dis-le explicitement.''',
        input_variables=["context", "question"]
    )

    chain = prompt | llm
    answer = chain.invoke({"context": context, "question": question}).content

    # Build a small sub-graph showing only the relations used for the answer.
    net = Network(height="450px", width="100%", directed=True, bgcolor="#222222", font_color="white")

    nodes_added = set()
    for rel in used_relationships:
        if rel.source.id not in nodes_added:
            net.add_node(rel.source.id, label=rel.source.id, title=rel.source.type, group=rel.source.type)
            nodes_added.add(rel.source.id)
        if rel.target.id not in nodes_added:
            net.add_node(rel.target.id, label=rel.target.id, title=rel.target.type, group=rel.target.type)
            nodes_added.add(rel.target.id)
        # `except Exception` (not bare `except:`) so interrupts still propagate;
        # pyvis may raise on duplicate edges.
        try:
            net.add_edge(rel.source.id, rel.target.id, label=rel.type)
        except Exception:
            pass

    net.set_options('{"physics": {"forceAtlas2Based": {"gravitationalConstant": -50}}}')
    net.save_graph("filtered_graph.html")

    return answer, net