| import gradio as gr |
| import json |
| import numpy as np |
| import re |
| import tarfile |
|
|
| import PIL.Image as PImage |
|
|
| from os import listdir, path, remove |
| from sklearn.cluster import KMeans |
| from urllib import request |
|
|
def download_extract():
    """Download the flowers archive and unpack it into the working directory.

    The archive is saved as ``flowers.tar.gz`` in the current directory,
    extracted relative to the current directory, and then deleted. The
    temporary archive is removed even when the download or the extraction
    fails, so a broken run does not leave a partial file behind.

    Raises:
        OSError: If the server answers with a status other than 200. Network
            and HTTP-level failures raised by ``urllib`` (``URLError`` /
            ``HTTPError``) are ``OSError`` subclasses as well.
        tarfile.TarError: If the downloaded file is not a readable archive.
    """
    url = "https://github.com/PSAM-5020-2025S-A/5020-utils/releases/latest/download/flowers.tar.gz"
    target_path = "flowers.tar.gz"

    try:
        with request.urlopen(request.Request(url), timeout=15.0) as response:
            # Fail loudly here instead of letting tarfile.open() trip over a
            # missing (or stale) archive further down.
            if response.status != 200:
                raise OSError(
                    f"Download of {url} failed with HTTP status {response.status}"
                )
            with open(target_path, "wb") as f:
                f.write(response.read())

        with tarfile.open(target_path, "r:gz") as tar:
            # The archive comes from the network: use the "data" extraction
            # filter where the running Python provides it so that members
            # cannot escape the destination directory.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(filter="data")
            else:
                tar.extractall()
    finally:
        # Always clean up the temporary archive, whatever happened above.
        if path.exists(target_path):
            remove(target_path)
|
|
| |
def top_colors(fpath, n_clusters=8, n_colors=4):
    """Return the dominant colors of an image, most frequent first.

    The image is converted to RGB, its pixels are grouped into ``n_clusters``
    clusters with KMeans, and the cluster centers are ranked by the number of
    pixels assigned to each cluster.

    Args:
        fpath: Path of the image file to analyse.
        n_clusters: Number of KMeans clusters to fit.
        n_colors: Maximum number of colors to return.

    Returns:
        A list of at most ``n_colors`` ``[r, g, b]`` lists holding the rounded
        cluster-center values, ordered from the largest cluster to the
        smallest.
    """
    # Close the underlying file as soon as the pixel data has been copied out.
    with PImage.open(fpath) as img:
        pimg_pxs = list(img.convert("RGB").getdata())

    posterizer = KMeans(n_clusters=n_clusters)
    px_clusters = posterizer.fit_predict(pimg_pxs)
    cluster_colors = posterizer.cluster_centers_

    # np.unique only reports labels that actually occur. Map the argsort
    # positions back to the real cluster labels so the right centers are
    # picked even when some cluster ends up without any pixel.
    labels, counts = np.unique(px_clusters, return_counts=True)
    labels_by_size = labels[np.argsort(-counts)]

    return [
        [round(rgb) for rgb in cluster_colors[label]]
        for label in labels_by_size[:n_colors]
    ]
|
|
| |
def get_top_colors(flower_image_dir):
    """Collect the top colors of every ``.png`` file in a directory.

    Args:
        flower_image_dir: Directory whose ``.png`` files are analysed.

    Returns:
        A list with one ``{"filename": ..., "colors": ...}`` dict per image,
        ordered by file name. ``colors`` is the result of ``top_colors`` for
        that file using 8 clusters and 4 colors.
    """
    png_names = sorted(
        name for name in listdir(flower_image_dir) if name.endswith(".png")
    )
    return [
        {
            "filename": png_name,
            "colors": top_colors(f"{flower_image_dir}/{png_name}", n_clusters=8, n_colors=4),
        }
        for png_name in png_names
    ]
|
|
| |
def color_distance(c0, c1):
    """Return the Euclidean distance between two colors.

    Only the first three components of ``c0`` and ``c1`` are compared.
    """
    squared_sum = sum((c0[channel] - c1[channel]) ** 2 for channel in range(3))
    return squared_sum ** 0.5
|
|
| |
def min_color_distance(ref_color, color_list):
    """Return the distance from ``ref_color`` to its nearest color in ``color_list``.

    Distances are measured with ``color_distance``. As with ``min`` on any
    empty input, an empty ``color_list`` raises ``ValueError``.
    """
    return min(color_distance(ref_color, candidate) for candidate in color_list)
|
|
| |
| |
| |
| |
| |
| |
# "#rgb", "#rgba", "#rrggbb" or "#rrggbbaa".
_HEX_COLOR_RE = re.compile(r"#(?:[0-9a-fA-F]{3,4}|[0-9a-fA-F]{6}|[0-9a-fA-F]{8})")

# "name(a, b, c[, alpha])" or "name(a b c [/ alpha])"; every component may
# carry a trailing "%". Groups: 1 = function name, 2-4 = the three components.
_FUNC_COLOR_RE = re.compile(
    r"([^(]+)\(\s*([0-9.]+%?)(?:\s*,\s*|\s+)([0-9.]+%?)(?:\s*,\s*|\s+)([0-9.]+%?)"
    r"\s*(?:[,/]\s*[0-9.]+%?\s*)?\)"
)


def _css_channel(token):
    """Convert one rgb() component (number or percentage) to an int in 0..255."""
    if token.endswith("%"):
        value = float(token[:-1]) * 255 / 100
    else:
        value = float(token)
    # Truncate like int(float(...)) and clamp out-of-range values.
    return max(0, min(255, int(value)))


def _css_fraction(token):
    """Convert a percentage token such as ``"50%"`` to a float in 0..1."""
    return max(0.0, min(1.0, float(token.rstrip("%")) / 100))


def css_to_rgb(css_str):
    """Convert a CSS color string to an ``[r, g, b]`` list of ints (0-255).

    Supported notations:
      * hex: ``#rgb``, ``#rgba``, ``#rrggbb``, ``#rrggbbaa``
      * ``rgb()`` / ``rgba()`` with numeric or percentage components
      * ``hsl()`` / ``hsla()`` with the hue in degrees

    Any alpha component is ignored. Fractional rgb components are truncated
    (``20.5`` becomes ``20``).

    Args:
        css_str: The color string; surrounding whitespace is ignored.

    Returns:
        The color as ``[r, g, b]``. Empty, malformed or unsupported input
        yields black, ``[0, 0, 0]``.
    """
    import colorsys  # stdlib; imported here so the block stays self-contained

    black = [0, 0, 0]
    if not css_str:
        return black
    css_str = css_str.strip()

    if css_str.startswith("#"):
        if not _HEX_COLOR_RE.fullmatch(css_str):
            return black
        digits = css_str[1:]
        if len(digits) < 6:
            # Short notation: each digit stands for a doubled pair (#f80 -> #ff8800).
            digits = "".join(ch * 2 for ch in digits)
        return [int(digits[i:i + 2], 16) for i in range(0, 6, 2)]

    match = _FUNC_COLOR_RE.match(css_str)
    if not match:
        return black

    func_name = match.group(1).lower()
    tokens = match.group(2, 3, 4)
    try:
        if "rgb" in func_name:
            return [_css_channel(token) for token in tokens]
        if "hsl" in func_name:
            hue = (float(tokens[0]) % 360) / 360
            saturation = _css_fraction(tokens[1])
            lightness = _css_fraction(tokens[2])
            # colorsys uses HLS argument order: hue, lightness, saturation.
            rgb = colorsys.hls_to_rgb(hue, lightness, saturation)
            return [round(component * 255) for component in rgb]
    except (ValueError, OverflowError):
        # Components such as "1.2.3" pass the pattern but are not numbers.
        return black

    # Some other functional notation (lab(), hwb(), ...): not supported.
    return black
|
|
def order_by_color(center_color_str):
    """Rank the known image files by closeness to a CSS color.

    Each entry of the module-level ``FILE_COLORS`` list is scored by the
    smallest distance between the requested color and any of the entry's
    colors; entries are returned closest first.

    Args:
        center_color_str: The reference color as a CSS color string.

    Returns:
        A JSON string with two keys: ``"color"`` (the parsed reference color)
        and ``"files"`` (the file names in ranked order).
    """
    target_rgb = css_to_rgb(center_color_str)

    ranked_entries = sorted(
        FILE_COLORS,
        key=lambda entry: min_color_distance(target_rgb, entry["colors"]),
    )

    return json.dumps({
        "color": target_rgb,
        "files": [entry["filename"] for entry in ranked_entries],
    })
|
|
# Gradio components and example values used when the interface is assembled.

# A single color picker supplies the reference color.
my_inputs = [gr.ColorPicker(label="center_color", value="#ffdf00", interactive=True)]

# The result is shown as a bare JSON view.
my_outputs = [gr.JSON(height=200, container=False, show_label=False, show_indices=False)]

# One example row per color; each row holds the value for the single input.
my_examples = [[hex_color] for hex_color in ("#FFFFFF", "#FFD700", "#7814BE")]
|
|
def setup():
    """Make sure the flower images exist and compute their top colors.

    When ``./data/image/flowers`` is not a directory, the image archive is
    downloaded and extracted first (presumably the archive unpacks to that
    path; confirm against the archive layout). The per-file colors are then
    stored in the module-level ``FILE_COLORS``.
    """
    global FILE_COLORS

    flower_dir = "./data/image/flowers"
    dataset_missing = not path.isdir(flower_dir)
    if dataset_missing:
        download_extract()

    FILE_COLORS = get_top_colors(flower_dir)
|
|
# Fill FILE_COLORS before the UI is built: order_by_color reads it, and with
# cache_examples=True the examples are presumably evaluated ahead of the
# first user request (confirm against the Gradio version in use).
setup()


with gr.Blocks() as demo:
    gr.Interface(
        fn=order_by_color,
        inputs=my_inputs,
        outputs=my_outputs,
        examples=my_examples,
        cache_examples=True,
        flagging_mode="never",
        fill_width=True,
    )


# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(ssr_mode=False)
|
|