| | import time |
| | import streamlit as st |
| | import numpy as np |
| | from pathlib import Path |
| | from experiments.gmm_dataset import GeneralizedGaussianMixture |
| | import plotly.graph_objects as go |
| | from plotly.subplots import make_subplots |
| | from typing import List, Tuple |
| | import torch |
| | import os |
| | import sys |
| | import matplotlib.pyplot as plt |
| |
|
| | |
# Give `torch.classes` a concrete, list-valued `__path__`.
# NOTE(review): presumably a workaround so code that walks module `__path__`
# attributes (e.g. Streamlit's file watcher) does not fall into torch.classes'
# dynamic attribute lookup — confirm.
_classes_file = torch.classes.__file__ or ""
torch.classes.__path__ = [os.path.join(torch.__path__[0], _classes_file)]
del _classes_file
| |
|
| | |
# Make the vendored pykan checkout (<this file's grandparent>/third_party/pykan)
# importable, so the `from kan import ...` lines that follow resolve to it.
# Streamlit re-executes this script on every interaction while `sys.path` is
# process-global, so guard against appending a duplicate entry on each rerun.
pykan_path = Path(__file__).parent.parent / 'third_party' / 'pykan'
if str(pykan_path) not in sys.path:
    sys.path.append(str(pykan_path))
| |
|
| | |
| | from kan import KAN |
| | from kan.utils import create_dataset, ex_round |
| |
|
| | |
# Use double precision for every tensor/parameter created without an explicit
# dtype (the KAN model and the tensors built from NumPy float64 arrays).
torch.set_default_dtype(torch.double)
| |
|
def show_kan_prediction(model, device, samples, placeholder, phase_name):
    """Render the KAN's predicted density on a 100x100 grid over [-5, 5]^2.

    The left panel is a 3-D surface of the prediction. The right panel is a
    labelled contour plot, with the training points overlaid when ``samples``
    is given.

    Args:
        model: Callable network. Its output for the (10000, 2) grid tensor is
            reshaped to the 100x100 grid, so it must yield one value per point.
        device: Torch device the grid tensor is moved to before evaluation.
        samples: Training points as a torch tensor or NumPy array, or None.
        placeholder: Streamlit element (e.g. ``st.empty()``) that receives the chart.
        phase_name: Label included in the chart's widget key.
    """
    axis = np.linspace(-5, 5, 100)
    X, Y = np.meshgrid(axis, axis)
    grid = torch.from_numpy(np.column_stack((X.ravel(), Y.ravel()))).to(device)

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        Z_pred = model(grid).cpu().numpy().reshape(X.shape)

    fig_kan = make_subplots(
        rows=1,
        cols=2,
        specs=[[{'type': 'surface'}, {'type': 'contour'}]],
        subplot_titles=('KAN预测的3D概率密度曲面', 'KAN预测的等高线图'),
    )

    fig_kan.add_trace(
        go.Surface(
            x=X, y=Y, z=Z_pred,
            colorscale='viridis',
            showscale=True,
            colorbar={'x': 0.45},
        ),
        row=1, col=1,
    )

    fig_kan.add_trace(
        go.Contour(
            x=axis, y=axis, z=Z_pred,
            colorscale='viridis',
            showscale=True,
            colorbar={'x': 1.0},
            contours={'showlabels': True, 'labelfont': {'size': 12}},
        ),
        row=1, col=2,
    )

    if samples is not None:
        points = samples.cpu().numpy() if torch.is_tensor(samples) else samples
        fig_kan.add_trace(
            go.Scatter(
                x=points[:, 0], y=points[:, 1],
                mode='markers',
                marker={'size': 8, 'color': 'yellow', 'line': {'color': 'black', 'width': 1}},
                name='训练点',
            ),
            row=1, col=2,
        )

    fig_kan.update_layout(
        title='KAN预测分布',
        showlegend=True,
        width=1200,
        height=600,
        scene={'xaxis_title': 'X', 'yaxis_title': 'Y', 'zaxis_title': '密度'},
    )
    fig_kan.update_xaxes(title_text='X', row=1, col=2)
    fig_kan.update_yaxes(title_text='Y', row=1, col=2)

    # Timestamp in the key keeps repeated redraws within one phase distinct.
    chart_key = f"kan_plot_{phase_name}_{time.time()}"
    placeholder.plotly_chart(fig_kan, use_container_width=False, key=chart_key)
| |
|
def create_gmm_plot(dataset, centers, K, samples=None):
    """Build the two-panel Plotly figure for a mixture density on [-5, 5]^2.

    Args:
        dataset: Object whose ``pdf`` maps a (10000, 2) array of grid points to
            one density per point (the result is reshaped to the 100x100 grid).
        centers: Array of component centres; the first ``K`` rows are marked.
        K: Number of component centres to mark and label ``C1..CK``.
        samples: Optional array of sample points (indexed as ``[:, 0]`` /
            ``[:, 1]``) overlaid on the contour and labelled ``S1..SN``.

    Returns:
        Plotly figure with a 3-D surface (left) and a contour plot (right).
    """
    grid_1d = np.linspace(-5, 5, 100)
    X, Y = np.meshgrid(grid_1d, grid_1d)
    density = dataset.pdf(np.column_stack((X.ravel(), Y.ravel()))).reshape(X.shape)

    fig = make_subplots(
        rows=1,
        cols=2,
        specs=[[{'type': 'surface'}, {'type': 'contour'}]],
        subplot_titles=('3D概率密度曲面', '等高线图与分量中心'),
    )

    fig.add_trace(
        go.Surface(
            x=X, y=Y, z=density,
            colorscale='viridis',
            showscale=True,
            colorbar={'x': 0.45},
        ),
        row=1, col=1,
    )

    fig.add_trace(
        go.Contour(
            x=grid_1d, y=grid_1d, z=density,
            colorscale='viridis',
            showscale=True,
            colorbar={'x': 1.0},
            contours={'showlabels': True, 'labelfont': {'size': 12}},
        ),
        row=1, col=2,
    )

    center_labels = [f'C{idx}' for idx in range(1, K + 1)]
    fig.add_trace(
        go.Scatter(
            x=centers[:K, 0], y=centers[:K, 1],
            mode='markers+text',
            marker={'size': 10, 'color': 'red'},
            text=center_labels,
            textposition="top center",
            name='分量中心',
        ),
        row=1, col=2,
    )

    if samples is not None:
        sample_labels = [f'S{idx}' for idx in range(1, len(samples) + 1)]
        fig.add_trace(
            go.Scatter(
                x=samples[:, 0], y=samples[:, 1],
                mode='markers+text',
                marker={'size': 8, 'color': 'yellow', 'line': {'color': 'black', 'width': 1}},
                text=sample_labels,
                textposition="bottom center",
                name='采样点',
            ),
            row=1, col=2,
        )

    fig.update_layout(
        title='广义高斯混合分布',
        showlegend=True,
        width=1200,
        height=600,
        scene={'xaxis_title': 'X', 'yaxis_title': 'Y', 'zaxis_title': '密度'},
    )

    # Axis titles for the 2-D contour panel (the 3-D scene is titled above).
    fig.update_xaxes(title_text='X', row=1, col=2)
    fig.update_yaxes(title_text='Y', row=1, col=2)

    return fig
| |
|
def train_kan(samples, gmm_dataset, device='cuda', plot_placeholder=None):
    """Fit a KAN to the mixture density at ``samples``, streaming progress to the page.

    Training runs in four phases, each followed by progress and plot updates:
    1. regularised fit;
    2. prune, then fit;
    3. grid refinement, then fit;
    4. symbolic snapping, then fit.
    Finally the network's symbolic formula and the train/test MSE are rendered.

    Args:
        samples: NumPy array of input points. It is split 80/20 (in the given
            order) into train/test sets.
        gmm_dataset: Object exposing ``pdf(points)``; its values at ``samples``
            are the regression targets.
        device: ``'cuda'`` to use the GPU when one is available; any other
            value (or no GPU) selects the CPU.
        plot_placeholder: Streamlit element that receives the prediction plots.
            When None, falls back to the module-level
            ``kan_distribution_plot_placeholder`` created by the page script.

    Returns:
        The trained (pruned, refined, symbolic) KAN model.
    """
    if torch.cuda.is_available() and device == 'cuda':
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    st.info(f"使用设备: {device} 训练网络")

    if plot_placeholder is None:
        # Backward-compatible default: the page script defines this global
        # before the "拟合KAN" button can call us.
        plot_placeholder = kan_distribution_plot_placeholder

    # Regression data: inputs are the samples, targets are the mixture density
    # there. Cast to the default dtype so the data always matches the model's
    # parameters, even if the sampler or pdf returns float32.
    dtype = torch.get_default_dtype()
    samples = torch.as_tensor(np.asarray(samples), dtype=dtype, device=device)
    labels = torch.as_tensor(
        np.asarray(gmm_dataset.pdf(samples.cpu().numpy())), dtype=dtype, device=device
    ).reshape(-1, 1)

    model = KAN(width=[2, 5, 1], grid=3, k=3, seed=42, device=device)

    # Deterministic 80/20 split in sample order.
    train_size = int(0.8 * samples.shape[0])
    train_dataset = {
        'train_input': samples[:train_size],
        'train_label': labels[:train_size],
        'test_input': samples[train_size:],
        'test_label': labels[train_size:],
    }

    st.write("网络图形结构:")
    kan_network_arch_placeholder = st.empty()

    progress_container = st.container()

    total_steps = 50
    steps_per_update = 10

    def calculate_error(model, x, y):
        """Return the mean squared error of ``model(x)`` against ``y`` as a float."""
        with torch.no_grad():
            pred = model(x)
            return torch.mean((pred - y) ** 2).item()

    def train_phase(phase_name, steps, lamb=None, show_plot=True):
        """Train the current ``model`` for ``steps`` LBFGS steps.

        Steps run in chunks of ``steps_per_update``. After each chunk the
        progress bar, the train/test MSE table, the network diagram (if
        ``show_plot``) and the prediction plot are refreshed. ``lamb``, when
        given, is forwarded to ``model.fit`` as the regularisation strength.
        """
        with progress_container:
            progress_bar = st.progress(0)
            status_text = st.empty()

        fit_kwargs = {} if lamb is None else {'lamb': lamb}
        for step in range(0, steps, steps_per_update):
            # Never overshoot `steps` (or push the bar past 100%) when `steps`
            # is not a multiple of `steps_per_update`.
            chunk = min(steps_per_update, steps - step)
            model.fit(train_dataset, opt="LBFGS", steps=chunk, **fit_kwargs)

            progress = (step + chunk) / steps
            progress_bar.progress(progress)

            train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
            test_error = calculate_error(model, train_dataset['test_input'], train_dataset['test_label'])

            status_text.markdown("\n".join([
                f"### {phase_name}",
                "| 项目 | 值 |",
                "|:---|:---|",
                f"| 进度 | {progress:.0%} |",
                f"| 训练误差 | {train_error:.8f} |",
                f"| 测试误差 | {test_error:.8f} |",
            ]))

            if show_plot:
                try:
                    model.plot()
                    kan_fig = plt.gcf()
                    kan_network_arch_placeholder.pyplot(kan_fig, use_container_width=False)
                except Exception as e:
                    # Best-effort diagram: warn once per phase, keep training.
                    if step == 0:
                        st.warning(f"注意:网络结构图显示失败 ({str(e)})")
                finally:
                    # model.plot() draws into a Matplotlib figure that is never
                    # closed otherwise; st.pyplot has already rendered it, so
                    # close it to avoid accumulating open figures across refreshes.
                    plt.close("all")

            show_kan_prediction(model, device, samples, plot_placeholder, phase_name)

    with progress_container:
        st.markdown("#### 训练过程")
        error_text = st.empty()

    with st.spinner("参数调整中..."):
        train_phase("第一阶段: 正则化训练", total_steps, lamb=0.001, show_plot=True)

    # prune()/refine() return new model objects; train_phase reads the rebound
    # `model` through its closure.
    with st.spinner("正在进行网络剪枝优化..."):
        model = model.prune()
    progress_container.info("网络剪枝完成")

    with st.spinner("参数调整中..."):
        train_phase("第二阶段: 剪枝适应性训练", total_steps, show_plot=True)

    with st.spinner("正在进行网格精细化..."):
        model = model.refine(10)
    progress_container.info("网格精细化完成")

    with st.spinner("参数调整中..."):
        train_phase("第三阶段: 网格适应性训练", total_steps, show_plot=True)

    with st.spinner("符号简化中..."):
        lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
        model.auto_symbolic(lib=lib)
    progress_container.info("符号简化完成")

    with st.spinner("参数调整中..."):
        train_phase("第四阶段:符号适应性训练", total_steps, show_plot=True)

    # Local import: sympy is only needed for this final rendering step.
    from sympy import latex
    formula = ex_round(model.symbolic_formula()[0][0], 4)

    st.write("网络公式:")
    st.latex(latex(formula))

    train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
    test_error = calculate_error(model, train_dataset['test_input'], train_dataset['test_label'])
    error_text.markdown("\n".join([
        "#### 训练结果",
        f"- 训练集误差: {train_error:.6f}",
        f"- 测试集误差: {test_error:.6f}",
    ]))

    progress_container.success("🎉 训练完成!")
    return model
| |
|
def init_session_state():
    """Seed ``st.session_state`` with defaults for any key that is still missing.

    Keys already present (e.g. from an earlier rerun) are left untouched. The
    defaults describe a 3-component mixture with p = 2.0, no samples and no
    trained model.
    """
    defaults = {
        'prev_K': 3,
        'p': 2.0,
        'centers': np.array([[-2, -2], [0, 0], [2, 2]], dtype=np.float64),
        'scales': np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]], dtype=np.float64),
        'weights': np.ones(3, dtype=np.float64) / 3,
        'sample_points': None,
        'kan_model': None,
    }
    for name, value in defaults.items():
        if name not in st.session_state:
            st.session_state[name] = value
| |
|
def create_default_parameters(K: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return default ``(centers, scales, weights)`` for a K-component 2-D mixture.

    Centres are spaced evenly along the diagonal from (-3, -3) to (3, 3). Every
    scale is 3.0. Weights are drawn from ``np.random.random`` and normalised to
    sum to 1, so they depend on NumPy's global RNG state.

    Shapes: centers (K, 2), scales (K, 2), weights (K,).
    """
    diagonal = np.linspace(-3, 3, K)
    centers = np.stack([diagonal, diagonal], axis=1)
    scales = np.full((K, 2), 3.0, dtype=np.float64)
    raw_weights = np.random.random(size=K).astype(np.float64)
    return centers, scales, raw_weights / raw_weights.sum()
| |
|
def generate_latex_formula(p: float, K: int, centers: np.ndarray,
                           scales: np.ndarray, weights: np.ndarray) -> str:
    """Return LaTeX source describing the mixture and each of its K components.

    The string starts with the generic mixture and component formulas, then
    lists, per component k, its density with centre/scale values substituted and
    its weight ``pi_k``. Numbers are printed with 1 decimal (p, centres, scales)
    or 2 decimals (weights). Lines are separated by ``\\\\``.
    """
    parts = [
        r"P(x) = \sum_{k=1}^{" + str(K) + r"} \pi_k P_{\theta_k}(x) \\",
        r"P_{\theta_k}(x) = \eta_k \exp(-s_k d_k(x)) = \frac{p}{2\alpha_k \Gamma(1/p) }\exp(-\frac{|x-c_k|^p}{\alpha_k^p})= \frac{p}{2\alpha_k \Gamma(1/p) }\exp(-|\frac{x-c_k}{\alpha_k}|^p) \\",
        r"\text{where: }",
    ]
    p_txt = f"{p:.1f}"
    for k in range(K):
        idx = k + 1
        c, s, w = centers[k], scales[k], weights[k]
        parts.append(
            f"P_{{\\theta_{idx}}}(x) = \\frac{{{p_txt}}}{{2\\alpha_{idx} \\Gamma(1/{p_txt})}}"
            f"\\exp(-|\\frac{{x-({c[0]:.1f}, {c[1]:.1f})}}{{{s[0]:.1f}, {s[1]:.1f}}}|^{{{p_txt}}}) \\\\"
        )
        parts.append(f"\\pi_{idx} = {w:.2f} \\\\")
    return "".join(parts)
| |
|
# Page chrome must be configured before any other Streamlit output.
st.set_page_config(layout="wide", page_title="GMM Distribution Visualization")
st.title("广义高斯混合分布可视化")

# Make sure every session-state key read by the widgets below exists.
init_session_state()
| |
|
| | |
with st.sidebar:
    st.header("分布参数")

    # Shape parameter shared by all components, and the number of components.
    st.session_state.p = st.slider("形状参数 (p)", 0.1, 5.0, st.session_state.p, 0.1,
                                   help="p=1: 拉普拉斯分布, p=2: 高斯分布, p→∞: 均匀分布")
    K = st.slider("分量数 (K)", 1, 5, st.session_state.prev_K)

    # A new K invalidates the per-component arrays, so regenerate defaults.
    if K != st.session_state.prev_K:
        centers, scales, weights = create_default_parameters(K)
        st.session_state.centers = centers
        st.session_state.scales = scales
        st.session_state.weights = weights
        st.session_state.prev_K = K

    st.subheader("高级设置")
    show_advanced = st.checkbox("显示分量参数", value=False)

    if show_advanced:
        # Per-component editors, pre-filled from session state.
        centers_list: List[List[float]] = []
        scales_list: List[List[float]] = []
        weights_list: List[float] = []

        for k in range(K):
            st.write(f"分量 {k+1}")
            col1, col2 = st.columns(2)
            with col1:
                cx = st.number_input(f"中心X_{k+1}", -5.0, 5.0, float(st.session_state.centers[k][0]), 0.1)
                cy = st.number_input(f"中心Y_{k+1}", -5.0, 5.0, float(st.session_state.centers[k][1]), 0.1)
            with col2:
                sx = st.number_input(f"尺度X_{k+1}", 0.1, 3.0, float(st.session_state.scales[k][0]), 0.1)
                sy = st.number_input(f"尺度Y_{k+1}", 0.1, 3.0, float(st.session_state.scales[k][1]), 0.1)
            w = st.slider(f"权重_{k+1}", 0.0, 1.0, float(st.session_state.weights[k]), 0.1)

            centers_list.append([cx, cy])
            scales_list.append([sx, sy])
            weights_list.append(w)

        centers = np.array(centers_list, dtype=np.float64)
        scales = np.array(scales_list, dtype=np.float64)
        weights = np.array(weights_list, dtype=np.float64)

        total_weight = weights.sum()
        if total_weight > 0:
            weights = weights / total_weight
        else:
            # All sliders at 0 would give 0/0 = NaN weights. That breaks the
            # density plots and feeds NaN back into the weight sliders on the
            # next rerun, so fall back to equal weights instead.
            st.warning("所有权重均为0,已使用均匀权重")
            weights = np.full(K, 1.0 / K, dtype=np.float64)

        st.session_state.centers = centers
        st.session_state.scales = scales
        st.session_state.weights = weights
    else:
        centers = st.session_state.centers
        scales = st.session_state.scales
        weights = st.session_state.weights

    st.subheader("采样设置")
    n_samples = st.slider("采样点数", 5, 1000, 100)
    if st.button("重新采样"):
        gmm = GeneralizedGaussianMixture(
            D=2,
            K=K,
            p=st.session_state.p,
            centers=centers[:K],
            scales=scales[:K],
            weights=weights[:K],
        )
        samples, _ = gmm.generate_samples(n_samples)
        st.session_state.sample_points = samples
        # New samples make any previously trained KAN stale.
        st.session_state.kan_model = None
| | st.session_state.kan_model = None |
| |
|
| | |
# Mixture described by the current sidebar parameters. `dataset` is also used
# further down (KAN training targets and the sample-point table).
dataset = GeneralizedGaussianMixture(
    D=2,
    K=K,
    p=st.session_state.p,
    centers=centers[:K],
    scales=scales[:K],
    weights=weights[:K],
)

# Surface + contour of the density with component centres and, if present, the
# drawn samples. This reuses create_gmm_plot instead of repeating its code.
fig = create_gmm_plot(dataset, centers, K, samples=st.session_state.sample_points)
st.plotly_chart(fig, use_container_width=False)
| |
|
| |
|
| | |
if st.session_state.sample_points is not None:
    st.markdown("---")
    st.subheader("KAN网络训练与预测")

    # Created before the buttons so the prediction plots that train_kan() draws
    # into this module-level placeholder appear above the controls.
    kan_distribution_plot_placeholder = st.empty()

    fit_col, _spacer_col, clear_col = st.columns([1, 2, 1])

    with fit_col:
        if st.button("拟合KAN", use_container_width=False):
            with st.spinner('训练KAN网络中...'):
                st.session_state.kan_model = train_kan(st.session_state.sample_points, dataset)
            st.balloons()

    with clear_col:
        # The clear button is only rendered once a model exists.
        has_model = st.session_state.kan_model is not None
        if has_model and st.button("清除KAN结果", use_container_width=False):
            st.session_state.kan_model = None
            st.rerun()
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
st.markdown("---")

# Sample-point table: mixture density and per-component posterior
# responsibilities for every drawn sample.
if st.session_state.sample_points is not None:
    samples = st.session_state.sample_points
    probs = dataset.pdf(samples)

    # Responsibility of component k for sample x:
    #   r_k(x) = pi_k * P_k(x) / sum_j pi_j * P_j(x),
    #   P_k(x) ∝ prod_d exp(-|(x_d - c_kd) / alpha_kd|^p) / alpha_kd.
    # The 1/alpha factor is each axis's share of eta_k in the formula shown in
    # the expander below; the common p / (2 Gamma(1/p)) factor cancels.
    # NOTE(review): assumes GeneralizedGaussianMixture.pdf uses this
    # product-over-axes form — confirm.
    # The absolute value is essential: for odd or non-integer p, a signed
    # difference raised to p gives a wrong sign or NaN. Work in log space and
    # subtract the per-row maximum so points far from every centre don't
    # underflow to 0 / 0.
    pts = np.asarray(samples, dtype=np.float64)
    comp_centers = np.asarray(centers[:K], dtype=np.float64)
    comp_scales = np.asarray(scales[:K], dtype=np.float64)
    scaled = np.abs((pts[:, None, :] - comp_centers[None, :, :]) / comp_scales[None, :, :])
    with np.errstate(divide="ignore"):  # a zero weight -> log(0) = -inf -> r_k = 0
        log_w = np.log(np.asarray(weights[:K], dtype=np.float64))
    log_terms = (
        log_w[None, :]
        - np.log(comp_scales).sum(axis=1)[None, :]
        - np.sum(scaled ** st.session_state.p, axis=2)
    )
    log_terms -= log_terms.max(axis=1, keepdims=True)
    posteriors = np.exp(log_terms)
    posteriors /= posteriors.sum(axis=1, keepdims=True)

    with st.expander("采样点信息"):
        point_data = []
        for i, (sample, prob, post) in enumerate(zip(samples, probs, posteriors)):
            row = {
                '采样点': f'S{i+1}',
                'X坐标': f'{sample[0]:.2f}',
                'Y坐标': f'{sample[1]:.2f}',
                '概率密度': f'{prob:.4f}',
            }
            for k in range(K):
                row[f'分量{k+1}后验概率'] = f'{post[k]:.4f}'
            point_data.append(row)

        st.dataframe(point_data)
| |
|
| | |
# Parameter explanation. The list is built line by line so the nested
# sub-items carry the indentation Markdown needs, independent of how this
# source file happens to be indented.
with st.expander("分布参数说明"):
    st.markdown("\n".join([
        "- **形状参数 (p)**:控制分布的形状",
        "  - p = 1: 拉普拉斯分布",
        "  - p = 2: 高斯分布",
        "  - p → ∞: 均匀分布",
        "- **分量参数**:每个分量由以下参数确定",
        "  - 中心 (μ): 峰值位置,通过X和Y坐标确定",
        "  - 尺度 (α): 分布的展宽程度,X和Y方向可不同",
        "  - 权重 (π): 混合系数,所有分量权重和为1",
    ]))

# LaTeX rendering of the mixture density with the current parameter values.
with st.expander("分布概率密度函数公式"):
    st.latex(generate_latex_formula(st.session_state.p, K, centers[:K], scales[:K], weights[:K]))