AITelecomAnalyzer / utils /visualization.py
FrederickSundeep's picture
updated commit 010
b5b8247
raw
history blame contribute delete
872 Bytes
import matplotlib.pyplot as plt
import networkx as nx
def plot_network(nodes, path, *, seed=42, figsize=(10, 6), node_size=1200):
    """Draw a chain of nodes as a directed graph and highlight a path on it.

    Consecutive entries of ``nodes`` are joined by a directed edge
    (``nodes[0] -> nodes[1] -> ...``). Nodes that appear in ``path`` are
    coloured red (all others light blue), and every consecutive pair in
    ``path`` that is also an edge of the graph is redrawn as a thicker red
    edge on top of the gray ones.

    Args:
        nodes: Ordered sequence (or iterable) of hashable node identifiers.
            ``None`` is treated as an empty sequence.
        path: Ordered sequence (or iterable) of node identifiers to
            highlight. ``None`` is treated as empty. Consecutive pairs that
            are not edges of the chain are silently skipped.
        seed: Random seed for ``nx.spring_layout`` so the layout is
            reproducible between calls.
        figsize: Figure size in inches, passed to ``plt.figure``.
        node_size: Marker size for the nodes; also given to the highlighted
            edge pass so its arrows stop at the node boundary.

    Returns:
        The ``matplotlib.figure.Figure`` holding the drawing. It is created
        through pyplot, so callers producing many plots should
        ``plt.close(fig)`` when finished with it.
    """
    # Materialise once: tolerates None and one-shot iterables, and lets us
    # both test membership and walk consecutive pairs.
    nodes = [] if nodes is None else list(nodes)
    path = [] if path is None else list(path)

    G = nx.DiGraph()
    # Add nodes explicitly so an input with a single node still appears (it
    # has no edge that would introduce it). Insertion order is the same as
    # adding edges one by one, so the seeded layout is unchanged.
    G.add_nodes_from(nodes)
    G.add_edges_from(zip(nodes, nodes[1:]))

    # Red if the node lies on the path, light blue otherwise.
    path_nodes = set(path)
    color_map = ['red' if node in path_nodes else 'lightblue'
                 for node in G.nodes()]

    pos = nx.spring_layout(G, seed=seed)

    fig = plt.figure(figsize=figsize)
    # Full-figure axes: the same placement nx.draw uses when handed a
    # figure without axes, made explicit so both draw calls share it.
    ax = fig.add_axes((0, 0, 1, 1))
    nx.draw(G, pos, ax=ax, with_labels=True, node_color=color_map,
            node_size=node_size, font_size=10, arrows=True,
            edge_color='gray')

    # Overlay the path's edges that actually exist in the chain.
    path_edges = [(u, v) for u, v in zip(path, path[1:]) if G.has_edge(u, v)]
    # node_size must match the nodes drawn above: draw_networkx_edges uses it
    # to shorten arrows, and its default (300) would leave the red arrowheads
    # ending underneath the larger node markers.
    nx.draw_networkx_edges(G, pos, edgelist=path_edges, ax=ax,
                           edge_color='red', width=2, node_size=node_size)
    return fig