| import gradio as gr |
| import numpy as np |
| import matplotlib.pyplot as plt |
|
|
| from sklearn.datasets import load_digits |
| from sklearn.neighbors import KernelDensity |
| from sklearn.decomposition import PCA |
| from sklearn.model_selection import GridSearchCV |
|
|
def generate_digits(bandwidth, num_samples):
    """Plot real digits next to digits sampled from a kernel density model.

    The scikit-learn digits images are projected onto 15 principal
    components. A ``KernelDensity`` model with the requested bandwidth is
    fitted in that space. ``num_samples`` points are drawn from it
    (``random_state=0``) and mapped back to 64-pixel images.

    Args:
        bandwidth: KDE bandwidth; converted with ``float`` so fractional
            values are honoured (whole numbers convert unchanged).
        num_samples: How many real and generated digits to take; converted
            with ``int``. At most 44 of each (4 rows x 11 columns) are drawn.

    Returns:
        The path ``"digits_plot.png"`` of the saved figure, relative to the
        current working directory.

    Raises:
        ValueError: If ``num_samples`` is negative.
    """
    bandwidth = float(bandwidth)
    num_samples = int(num_samples)
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    digits = load_digits()

    # Fit the KDE in a 15-dimensional PCA space rather than on raw pixels.
    pca = PCA(n_components=15, whiten=False)
    data = pca.fit_transform(digits.data)

    # The bandwidth is supplied by the caller and used directly; no
    # cross-validated bandwidth search is performed.
    kde = KernelDensity(bandwidth=bandwidth)
    kde.fit(data)

    # Sample in PCA space, then map back to 64-pixel rows.
    new_data = pca.inverse_transform(kde.sample(num_samples, random_state=0))
    real_data = digits.data[:num_samples]

    # Layout: rows 0-3 real digits, row 4 a hidden spacer, rows 5-8 samples.
    n_rows, n_cols = 4, 11
    fig, ax = plt.subplots(
        2 * n_rows + 1, n_cols, subplot_kw=dict(xticks=[], yticks=[])
    )
    # Fix the colour scale to 0-16 so every tile shares the same mapping.
    image_kw = dict(cmap=plt.cm.binary, interpolation="nearest", vmin=0, vmax=16)
    for j in range(n_cols):
        ax[n_rows, j].set_visible(False)
        for i in range(n_rows):
            index = i * n_cols + j
            if index < num_samples:
                ax[i, j].imshow(real_data[index].reshape((8, 8)), **image_kw)
                ax[i + n_rows + 1, j].imshow(
                    new_data[index].reshape((8, 8)), **image_kw
                )
            else:
                ax[i, j].axis("off")
                ax[i + n_rows + 1, j].axis("off")

    ax[0, n_cols // 2].set_title("Selection from the input data")
    ax[n_rows + 1, n_cols // 2].set_title(
        '"New" digits drawn from the kernel density model'
    )

    path = "digits_plot.png"
    try:
        fig.savefig(path)
    finally:
        # Release the figure: this function may be called many times and
        # unclosed pyplot figures accumulate in memory.
        plt.close(fig)
    return path
|
|
| |
# Gradio interface wiring: two sliders feed generate_digits, and its result
# is shown through an image output component.
inputs = [
    gr.inputs.Slider(minimum=1, maximum=10, step=1, label="Bandwidth"),
    gr.inputs.Slider(minimum=1, maximum=100, step=1, label="Number of Samples"),
]
output = gr.outputs.Image(type="pil")

title = "Kernel Density Estimation"
description = (
    "This example shows how kernel density estimation (KDE), a powerful "
    "non-parametric density estimation technique, can be used to learn a "
    "generative model for a dataset. With this generative model in place, "
    "new samples can be drawn. These new samples reflect the underlying "
    "model of the data. See the original scikit-learn example here: "
    "https://scikit-learn.org/stable/auto_examples/neighbors/plot_digits_kde_sampling.html"
)
# Each example row lists the slider values in order: [bandwidth, num_samples].
examples = [[1, 44], [8, 22], [7, 51]]

gr.Interface(
    generate_digits,
    inputs,
    output,
    title=title,
    description=description,
    examples=examples,
    live=True,
).launch()
|
|