import matplotlib.pyplot as plt
import networkx as nx


def plot_network(nodes, path):
    """Draw the directed chain ``nodes[0] -> nodes[1] -> ...`` and highlight *path*.

    Args:
        nodes: Sequence of hashable node identifiers. Each consecutive pair is
            joined by a directed edge. Every entry is drawn, including the
            single node of a one-element sequence.
        path: Sequence of node identifiers to highlight. Chain nodes that
            appear in *path* are filled red. Consecutive pairs of *path* that
            are also edges of the chain are redrawn as thicker red edges. Pairs
            that are not chain edges are silently ignored.

    Returns:
        The matplotlib ``Figure`` containing the drawing.
    """
    node_size = 1200

    graph = nx.DiGraph()
    # Add nodes explicitly. Otherwise a one-element ``nodes`` would create no
    # edge and the graph (and plot) would be empty. Insertion order matches
    # first appearance in ``nodes``, so the seeded layout is deterministic.
    graph.add_nodes_from(nodes)
    graph.add_edges_from(zip(nodes, nodes[1:]))

    # A set gives O(1) membership tests when colouring each node.
    highlighted = set(path)
    color_map = [
        "red" if node in highlighted else "lightblue" for node in graph.nodes()
    ]

    # Fixed seed so repeated calls produce the same layout.
    pos = nx.spring_layout(graph, seed=42)

    fig, ax = plt.subplots(figsize=(10, 6))
    nx.draw(
        graph,
        pos,
        ax=ax,
        with_labels=True,
        node_color=color_map,
        node_size=node_size,
        font_size=10,
        arrows=True,
        edge_color="gray",
    )

    # Overlay the path edges in red. The same node_size must be passed here:
    # arrow endpoints are shrunk by the node radius, and with the default size
    # the red arrowheads would end up hidden underneath the node markers.
    path_edges = [(u, v) for u, v in zip(path, path[1:]) if graph.has_edge(u, v)]
    nx.draw_networkx_edges(
        graph,
        pos,
        ax=ax,
        edgelist=path_edges,
        edge_color="red",
        width=2,
        arrows=True,
        node_size=node_size,
    )
    return fig