# Knowledge_Graph_Generator / generate_knowledge_graph_v1.py
# Uploaded by Demosthene-OR (revision a2110a1)
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from pyvis.network import Network
from dotenv import load_dotenv
import os
import asyncio
# Pull any variables defined in a .env file into the process environment.
load_dotenv()

# OpenAI key as currently visible in the environment (None when unset).
# NOTE(review): api_key is never passed to ChatOpenAI below; presumably ChatOpenAI
# picks up OPENAI_API_KEY from the environment on its own — confirm, or pass it explicitly.
api_key = os.environ.get("OPENAI_API_KEY")

# Shared module-level model and transformer used by extract_graph_data;
# temperature 0 keeps the extraction as deterministic as the model allows.
llm = ChatOpenAI(model_name="gpt-4o", temperature=0)
graph_transformer = LLMGraphTransformer(llm=llm)
# Extract graph data from input text
async def extract_graph_data(text):
    """
    Run the LLM graph transformer over a piece of text, asynchronously.

    The text is wrapped in a single langchain ``Document`` and passed to the
    module-level ``graph_transformer``.

    Args:
        text (str): Raw text to turn into graph form.

    Returns:
        list: GraphDocument objects (nodes and relationships) produced by the transformer.
    """
    single_doc_batch = [Document(page_content=text)]
    return await graph_transformer.aconvert_to_graph_documents(single_doc_batch)
def _select_connected_graph(graph_documents):
    """Pick the nodes and relationships worth drawing from a list of GraphDocuments.

    Nodes and relationships from every document are merged (for a repeated node id
    the last occurrence wins). A relationship is kept only if both endpoint ids are
    among the merged nodes. A node is kept only if at least one kept relationship
    touches it, so isolated nodes are dropped.

    Args:
        graph_documents (list): GraphDocument objects exposing ``nodes`` and
            ``relationships``. May be empty.

    Returns:
        tuple[list, list]: ``(nodes, relationships)`` to draw, in first-seen order.
    """
    node_by_id = {}
    relationships = []
    for graph_document in graph_documents:
        for node in graph_document.nodes:
            node_by_id[node.id] = node
        relationships.extend(graph_document.relationships)

    kept_edges = [
        rel
        for rel in relationships
        if rel.source.id in node_by_id and rel.target.id in node_by_id
    ]
    # dict.fromkeys de-duplicates while keeping first-seen order (a set would not),
    # so nodes are added in the same order on every run.
    kept_ids = dict.fromkeys(
        endpoint_id
        for rel in kept_edges
        for endpoint_id in (rel.source.id, rel.target.id)
    )
    kept_nodes = [node_by_id[node_id] for node_id in kept_ids]
    return kept_nodes, kept_edges


def visualize_graph(graph_documents, *, output_file="knowledge_graph.html"):
    """
    Visualizes a knowledge graph using PyVis based on the extracted graph documents.

    Only relationships whose source and target both appear among the documents'
    nodes are drawn, together with the nodes they connect. Nodes without such a
    relationship are omitted. Any node or edge that PyVis fails to add is skipped,
    and a message is printed.

    Args:
        graph_documents (list): A list of GraphDocument objects with nodes and
            relationships. All documents are merged; an empty list yields an
            empty graph.
        output_file (str): Path of the HTML file to write.
            Defaults to "knowledge_graph.html".

    Returns:
        pyvis.network.Network | None: The network graph object, or None if
        writing the HTML file failed.
    """
    net = Network(height="1200px", width="100%", directed=True,
                  notebook=False, bgcolor="#222222", font_color="white",
                  filter_menu=True, cdn_resources='remote')

    nodes, edges = _select_connected_graph(graph_documents)

    for node in nodes:
        try:
            # The node type is used as both the PyVis title and the group.
            net.add_node(node.id, label=node.id, title=node.type, group=node.type)
        except Exception as exc:  # best-effort: one bad node must not abort the graph
            print(f"Skipping node {node.id!r}: {exc}")

    for rel in edges:
        try:
            net.add_edge(rel.source.id, rel.target.id, label=rel.type.lower())
        except Exception as exc:  # best-effort: one bad edge must not abort the graph
            print(f"Skipping edge {rel.source.id!r} -> {rel.target.id!r}: {exc}")

    # Configure graph layout and physics
    net.set_options("""
{
"physics": {
"forceAtlas2Based": {
"gravitationalConstant": -100,
"centralGravity": 0.01,
"springLength": 200,
"springConstant": 0.08
},
"minVelocity": 0.75,
"solver": "forceAtlas2Based"
}
}
""")

    try:
        net.save_graph(output_file)
    except Exception as e:
        print(f"Error saving graph: {e}")
        return None
    print(f"Graph saved to {os.path.abspath(output_file)}")
    return net
def generate_knowledge_graph(text):
    """
    Build and render a knowledge graph for the given text.

    Extraction goes through ``extract_graph_data`` under ``asyncio.run``, so this
    function must be called from synchronous code. ``asyncio.run`` raises
    RuntimeError when an event loop is already running in the current thread.
    The extracted documents are then drawn and saved by ``visualize_graph``.

    Args:
        text (str): Input text to convert into a knowledge graph.

    Returns:
        pyvis.network.Network: Whatever ``visualize_graph`` returns (None if it
        could not save the HTML file).
    """
    return visualize_graph(asyncio.run(extract_graph_data(text)))