| import ast |
| import os |
| from copy import deepcopy |
|
|
| import dhg |
| import gradio as gr |
| import matplotlib.pyplot as plt |
| import pandas as pd |
| from dhg.visualization.structure.defaults import (default_hypergraph_strength, |
| default_hypergraph_style, |
| default_size) |
| from dhg.visualization.structure.layout import force_layout |
| from dhg.visualization.structure.utils import draw_circle_edge, draw_vertex |
| from huggingface_hub import hf_hub_download |
|
|
|
|
def draw_hypergraph(
    hg: "dhg.Hypergraph",
    e_style="circle",
    v_label=None,
    v_size=1.0,
    v_color="r",
    v_line_width=1.0,
    e_color="gray",
    e_fill_color="whitesmoke",
    e_line_width=1.0,
    font_size=1.0,
    font_family="sans-serif",
    push_v_strength=1.0,
    push_e_strength=1.0,
    pull_e_strength=1.0,
    pull_center_strength=1.0,
):
    """Draw ``hg`` with a force-directed layout and return the new figure.

    The pipeline fills in default styles, sizes and layout strengths via the
    dhg ``default_*`` helpers, computes vertex coordinates with
    ``force_layout``, then draws hyperedges as circles and the vertices.

    Args:
        hg: The hypergraph to draw (``hg.num_v``, ``hg.num_e`` and
            ``hg.e[0]`` are read).
        e_style: Hyperedge drawing style. Only ``"circle"`` is implemented.
        v_label: Optional vertex labels, forwarded to ``draw_vertex``.
        v_size, v_line_width, e_line_width: Passed through ``default_size``,
            which returns the values actually used for drawing.
        v_color, e_color, e_fill_color: Passed through
            ``default_hypergraph_style``, which returns the values actually
            used for drawing.
        font_size: NOTE(review): currently has no effect. It is replaced by
            ``default_size``'s return value and is not passed into that call.
        font_family: Font family forwarded to ``draw_vertex``.
        push_v_strength, push_e_strength, pull_e_strength,
        pull_center_strength: Passed through ``default_hypergraph_strength``
            and then into ``force_layout``.

    Returns:
        A new 6x6-inch matplotlib figure whose axes span the unit square with
        the axis decorations turned off.

    Raises:
        ValueError: If ``e_style`` is anything other than ``"circle"``.
    """
    # Only the circle renderer is wired up below; reject other styles instead
    # of silently drawing circles anyway.
    if e_style != "circle":
        raise ValueError(
            f"Unsupported e_style {e_style!r}; only 'circle' is supported"
        )

    fig, ax = plt.subplots(figsize=(6, 6))

    # Work on a copy of the hyperedge list so the drawing helpers cannot
    # modify the hypergraph's own edge storage.
    num_v, e_list = hg.num_v, deepcopy(hg.e[0])

    v_color, e_color, e_fill_color = default_hypergraph_style(
        num_v, hg.num_e, v_color, e_color, e_fill_color
    )
    # NOTE(review): font_size is not forwarded here, so the caller's value is
    # replaced by whatever default_size returns.
    v_size, v_line_width, e_line_width, font_size = default_size(
        num_v, e_list, v_size, v_line_width, e_line_width
    )
    (
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    ) = default_hypergraph_strength(
        num_v,
        e_list,
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    )

    v_coor = force_layout(
        num_v,
        e_list,
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    )

    # Hyperedges first, then vertices.
    draw_circle_edge(
        ax,
        v_coor,
        v_size,
        e_list,
        e_color,
        e_fill_color,
        e_line_width,
    )
    draw_vertex(
        ax,
        v_coor,
        v_label,
        font_size,
        font_family,
        v_size,
        v_color,
        v_line_width,
    )

    # Configure this figure's own axes rather than pyplot's global "current"
    # axes. plt.xlim/plt.ylim/plt.axis go through process-wide pyplot state,
    # which could point elsewhere if plots are ever built concurrently.
    ax.set_xlim(0, 1.0)
    ax.set_ylim(0, 1.0)
    ax.axis("off")
    fig.tight_layout()

    return fig
|
|
|
|
def plot_dataset(dataset_choice: str, sampling_choice: str, split_choice: str):
    """Download one split of a hypergraph dataset and return its plot.

    Args:
        dataset_choice: Dataset name; the file is fetched from the Hugging
            Face dataset repo ``SauravMaheshkar/<dataset_choice>``.
        sampling_choice: Sub-directory under ``processed/`` in that repo.
        split_choice: Prefix of the ``<split>_df.csv`` file to load.

    Returns:
        The matplotlib figure produced by ``draw_hypergraph``. It has been
        released from pyplot's figure registry (``plt.close``), so it can
        still be rendered/saved but is no longer tracked by pyplot.
    """
    # Every repo uses the same ``processed/<sampling>/<split>_df.csv`` layout,
    # so a shared local_dir would make different datasets overwrite each
    # other's files. Give each dataset its own directory.
    local_dir = os.path.join("artifacts", dataset_choice)
    os.makedirs(local_dir, exist_ok=True)
    csv_path = hf_hub_download(
        filename=f"processed/{sampling_choice}/{split_choice}_df.csv",
        local_dir=local_dir,
        repo_id=f"SauravMaheshkar/{dataset_choice}",
        repo_type="dataset",
    )
    # Use the path hf_hub_download reports instead of rebuilding it by hand.
    df = pd.read_csv(csv_path)

    # Each "nodes" cell is parsed with literal_eval (literals only, so it is
    # safe on downloaded text); presumably each cell is a list of vertex ids
    # forming one hyperedge — TODO confirm against the dataset schema.
    edge_list = [ast.literal_eval(cell) for cell in df["nodes"]]

    # NOTE(review): the row count is used as the vertex count, although rows
    # look like hyperedges. Keep it as a lower bound, but grow it so every
    # vertex id referenced by a hyperedge is a valid index.
    max_vertex_id = max((v for edge in edge_list for v in edge), default=-1)
    num_vertices = max(len(df), max_vertex_id + 1)

    hypergraph = dhg.Hypergraph(num_vertices, edge_list)
    fig = draw_hypergraph(hypergraph)

    # draw_hypergraph creates the figure through pyplot, which keeps a global
    # reference to it; release that so repeated clicks don't accumulate open
    # figures. The returned Figure object itself remains usable.
    plt.close(fig)
    return fig
|
|
|
|
# Dataset names offered in the UI; plot_dataset fetches each one from the
# Hugging Face repo "SauravMaheshkar/<name>".
DATASET_NAMES = [
    "email-Eu",
    "email-Enron",
    "NDC-classes",
    "tags-math-sx",
    "email-Eu-25",
    "NDC-substances",
    "congress-bills",
    "tags-ask-ubuntu",
    "email-Enron-25",
    "NDC-classes-25",
    "threads-ask-ubuntu",
    "contact-high-school",
    "NDC-substances-25",
    "congress-bills-25",
    "contact-primary-school",
]
# Sub-directories under processed/ in each dataset repo.
SAMPLING_TYPES = ["transductive", "inductive"]
# Split file prefixes (<split>_df.csv).
SPLITS = ["train", "valid", "test"]

with gr.Blocks() as demo:
    with gr.Row():
        dataset_choices = gr.Dropdown(
            choices=DATASET_NAMES,
            value="email-Enron-25",
            label="Please choose a dataset",
            interactive=True,
        )
        sampling_choice = gr.Dropdown(
            choices=SAMPLING_TYPES,
            value="inductive",
            label="Choose sampling type",
            interactive=True,
        )
        split_choice = gr.Dropdown(
            choices=SPLITS,
            value="test",
            label="Choose split",
            interactive=True,
        )

    output_plot = gr.Plot(label="Hypergraph plot")

    btn = gr.Button("Visualise")
    # The dropdown values are passed positionally, in this order, as
    # plot_dataset(dataset_choice, sampling_choice, split_choice).
    btn.click(
        fn=plot_dataset,
        inputs=[dataset_choices, sampling_choice, split_choice],
        outputs=output_plot,
    )

# Start the server only when run as a script, so importing this module (e.g.
# by tooling that only needs the `demo` object) does not launch it.
if __name__ == "__main__":
    demo.launch()
|
|