| import torch |
| import torch.nn |
| import torchvision.models as models |
| from copy import deepcopy |
| import cv2 |
|
|
| import cv2 |
| import numpy as np |
| import sys |
| import itertools |
| import os |
| import IPython |
| import matplotlib |
| matplotlib.use("Agg") |
|
|
| import matplotlib.pyplot as plt |
| import pandas as pd |
|
|
| import openai |
| from sklearn.manifold import TSNE |
| from sklearn.decomposition import PCA, KernelPCA |
| import seaborn as sns |
|
|
| import time |
| from matplotlib.offsetbox import OffsetImage, AnnotationBbox |
| import colorsys |
| from torchvision import datasets |
| import argparse |
| import matplotlib.patheffects as PathEffects |
| from sklearn.cluster import KMeans |
|
|
|
|
# ---------------------------------------------------------------------------
# Global plotting style. The statements run in this order on purpose: each
# later call may overwrite rc keys that an earlier one has set.
# ---------------------------------------------------------------------------
sns.set_style("white")
sns.set_palette("muted")

# Base font settings handed to matplotlib.rc right below.
font = {"size": 22}
matplotlib.rc("font", **font)

sns.set_context("paper", font_scale=3.0)

# Text sizes and line widths for the figures produced by this script.
# Applied last, so on any shared key these values take precedence over the
# settings made above.
plt_param = {
    "legend.fontsize": 60,
    "axes.labelsize": 80,
    "axes.titlesize": 80,
    "font.size": 80,
    "xtick.labelsize": 80,
    "ytick.labelsize": 80,
    "lines.linewidth": 10,
    "lines.color": (0, 0, 0),
}
plt.rcParams.update(plt_param)
|
|
# OpenAI client configuration.
#
# The API key is read from the environment instead of being hard-coded: a key
# committed to source control is visible to everyone who can read the file and
# must be treated as compromised (revoke the key that used to be stored here
# and export a fresh one as OPENAI_API_KEY before running this script).
openai.api_key = os.environ.get("OPENAI_API_KEY")
# NOTE(review): OpenAI spells this model id "gpt-4". The constant is not used
# in the visible part of this file, so its value is left unchanged -- confirm
# before passing it to the API.
GPT_MODEL = "gpt4"
# Embedding model requested by compute_embedding().
EMBEDDING_MODEL = "text-embedding-ada-002"
|
|
|
|
def normalize_numpy_array(arr):
    """Scale ``arr`` by its value range along the last axis.

    Every slice along the last axis is divided by ``max - min`` of that
    slice. The minimum is *not* subtracted first, so the result is
    range-scaled rather than mapped onto ``[0, 1]``.
    NOTE(review): if a true min-max normalisation was intended, the minimum
    would also have to be subtracted -- confirm with the callers.

    Args:
        arr: Array-like of numbers with at least one dimension.

    Returns:
        A ``numpy.ndarray`` with the same shape as ``arr``. Slices whose
        values are all equal have zero range and are returned unscaled
        instead of being turned into ``inf``/``nan``.
    """
    arr = np.asarray(arr)
    span = arr.max(axis=-1, keepdims=True) - arr.min(axis=-1, keepdims=True)
    # A constant slice has zero range; divide by 1 there to avoid inf/nan.
    safe_span = np.where(span == 0, 1, span)
    return arr / safe_span
|
|
def fashion_scatter(
    x, class_labels, fig_name, class_names, add_text=True
):
    """Draw a 2-D scatter plot, one colour per class, and save it as a PDF.

    Args:
        x: Array-like of shape ``(n_points, 2)`` holding the coordinates.
        class_labels: One non-negative integer label per row of ``x``.
            Negative labels are not plotted.
        fig_name: Output path without extension; ``".pdf"`` is appended.
        class_names: Sequence indexed by label to obtain the legend label of
            each class. When ``add_text`` is true, its i-th entry is also
            written next to the i-th point.
        add_text: Whether to annotate the points with ``class_names``.

    The figure is written to ``fig_name + ".pdf"`` and then closed, so
    repeated calls do not accumulate open matplotlib figures.
    """
    points = np.array(x)
    labels = np.array(class_labels)

    # Create the figure directly; clearing "the current figure" first would
    # either wipe an unrelated figure or create a throw-away one.
    fig = plt.figure(figsize=(140 * 0.8, 80 * 0.6))
    try:
        ax = fig.add_subplot(111)

        # One scatter call per class that is present, in ascending label
        # order, so every class takes the next colour of the colour cycle.
        # np.unique also copes with an empty label array.
        for label in np.unique(labels):
            if label < 0:
                continue
            mask = labels == label
            ax.scatter(
                points[mask, 0],
                points[mask, 1],
                lw=0,
                s=1500,
                label=class_names[label],
            )

        if add_text:
            # class_names[i] annotates point i. zip stops at the shorter of
            # the two, so a length mismatch cannot index past the data.
            for (xtext, ytext), name in zip(points, class_names):
                txt = ax.text(xtext, ytext, str(name), fontsize=40)
                # White outline keeps the text readable on top of markers.
                txt.set_path_effects(
                    [PathEffects.Stroke(linewidth=5, foreground="w"), PathEffects.Normal()]
                )

        ax.axis("on")

        out_path = fig_name + ".pdf"
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Release the (very large) figure even if plotting or saving failed.
        plt.close(fig)
    print("save figure to ", out_path)
|
|
def compute_embedding(response, max_retries=8, initial_delay=1.0):
    """Return the OpenAI embedding of ``response`` as a NumPy array.

    The request is retried on failure with exponential backoff. Retrying
    forever without a pause would hammer the API and never terminate on
    permanent errors (invalid key, input too long), so the number of
    attempts is bounded.

    Args:
        response: Text to embed; passed as ``input`` to the embedding API.
        max_retries: Maximum number of attempts (at least 1).
        initial_delay: Seconds to wait after the first failure; the wait
            doubles after every further failure, capped at 60 seconds.

    Returns:
        ``numpy.ndarray`` built from the first embedding in the API reply.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        Exception: The error of the last attempt once all attempts failed.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    delay = initial_delay
    for attempt in range(1, max_retries + 1):
        try:
            print('ping openai api')
            result = openai.Embedding.create(
                model=EMBEDDING_MODEL,
                input=response,
            )
        except Exception as e:
            print(e)
            if attempt == max_retries:
                # Out of attempts: surface the real cause to the caller.
                raise
            time.sleep(delay)
            delay = min(delay * 2, 60.0)
        else:
            return np.array(result["data"][0]['embedding'])
|
def _list_task_files(task_dirs=("cliport/tasks", "cliport/generated_tasks")):
    """Return the task source paths to embed, skipping support files."""
    excluded = (
        'pycache', 'init', 'README', 'extended', 'gripper', 'primitive',
        'task.py', 'camera', 'seq',
    )
    paths = []
    for task_dir in task_dirs:
        # os.listdir order is arbitrary; sort so label indices, colours and
        # clustering input order are reproducible between runs.
        paths.extend(
            os.path.join(task_dir, name) for name in sorted(os.listdir(task_dir))
        )
    # The exclusion tokens are matched against the whole path.
    return [p for p in paths if not any(token in p for token in excluded)]


def _project_to_2d(latents, method, max_num):
    """Reduce ``latents`` of shape (n_samples, n_features) to 2-D."""
    n_samples, n_features = latents.shape

    if method == "pca+tsne":
        # PCA cannot return more components than samples or features.
        n_components = min(50, max_num, n_samples, n_features)
        pca = PCA(random_state=123, n_components=n_components)
        reduced = pca.fit_transform(latents)
        print(
            "Variance explained per principal component: {}".format(
                pca.explained_variance_ratio_[:5]
            )
        )
        print("PCA data shape:", reduced.shape)
        # t-SNE requires perplexity < n_samples.
        perplexity = min(20, n_samples - 1)
        return TSNE(random_state=123, perplexity=perplexity).fit_transform(reduced)

    if method == "pca":
        # NOTE(review): only the first five embedding dimensions are fed to
        # KernelPCA here -- kept as in the original; confirm this is intended.
        return KernelPCA(random_state=123, n_components=2).fit_transform(latents[:, :5])

    if method == "tsne":
        # 30.0 is scikit-learn's default perplexity; clamp it for small inputs.
        perplexity = min(30.0, n_samples - 1)
        return TSNE(random_state=123, perplexity=perplexity).fit_transform(latents)

    raise ValueError(f"unknown method {method!r}; expected 'pca+tsne', 'pca' or 'tsne'")


def draw_latent_plot(
    max_num=80,
    method="pca+tsne",
    fig_name="",
):
    """Embed every task source file and plot the embeddings in 2-D.

    Task files are read from ``cliport/tasks`` and
    ``cliport/generated_tasks``. Each file's text is embedded with
    ``compute_embedding``; embeddings are cached by file path in
    ``output/output_embedding/task_cache_embedding.npz``. Two figures are
    written through ``fashion_scatter``: ``fig_name`` (one class per task)
    and ``fig_name + "_cluster"`` (coloured by KMeans cluster).

    NOTE: the cache is keyed by path only, so a file edited after it was
    cached keeps its old embedding until the cache file is removed.

    Args:
        max_num: Upper bound on the number of PCA components used by the
            ``"pca+tsne"`` method (also bounded by 50 and by the data size).
        method: ``"pca+tsne"``, ``"pca"`` or ``"tsne"``.
        fig_name: Output path prefix for the figures, without extension.

    Raises:
        ValueError: If ``method`` is unknown or fewer than two task files
            are found.
    """
    # Validate first so a typo does not cost a full round of API calls.
    if method not in ("pca+tsne", "pca", "tsne"):
        raise ValueError(f"unknown method {method!r}; expected 'pca+tsne', 'pca' or 'tsne'")

    total_tasks = _list_task_files()

    cache_embedding_path = "output/output_embedding/task_cache_embedding.npz"
    cache_embedding = {}
    if os.path.exists(cache_embedding_path):
        # Materialise the arrays inside the context so the archive is closed.
        with np.load(cache_embedding_path) as cached:
            cache_embedding = dict(cached)

    print(total_tasks)

    latents = []
    label_sets = []
    for task_name in total_tasks:
        if task_name in cache_embedding:
            code_embedding = cache_embedding[task_name]
        else:
            with open(task_name, encoding="utf-8") as task_file:
                code = task_file.read()
            code_embedding = compute_embedding(code)
            cache_embedding[task_name] = code_embedding

        latents.append(code_embedding)
        # File name without directory and extension, e.g. "stack_blocks".
        label_sets.append(os.path.splitext(os.path.basename(task_name))[0])

    # Every task is its own class in the first plot.
    class_labels = list(range(len(total_tasks)))

    latents = np.array(latents)
    print("latents shape:", latents.shape)
    os.makedirs(os.path.dirname(cache_embedding_path), exist_ok=True)
    np.savez(cache_embedding_path, **cache_embedding)

    n_samples = len(total_tasks)
    if n_samples < 2:
        raise ValueError(
            f"need at least two task files to cluster and project, found {n_samples}"
        )

    # KMeans cannot form more clusters than there are samples.
    n_clusters = min(6, n_samples)
    kmeans = KMeans(n_clusters=n_clusters, init="k-means++", random_state=42)
    kmeans.fit(latents)
    cluster_labels = kmeans.labels_

    X_embedded = _project_to_2d(latents, method, max_num)

    fashion_scatter(X_embedded, class_labels, fig_name, label_sets)
    fashion_scatter(X_embedded, cluster_labels, fig_name + "_cluster", label_sets)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load the task sources from the tasks folders, embed
    # them and write the plots under output/output_embedding/<--file>.
    cli = argparse.ArgumentParser(description="Generate chat-gpt embeddings")
    cli.add_argument("--file", type=str, default="task_embedding")
    output_stem = cli.parse_args().file
    draw_latent_plot(fig_name=f'output/output_embedding/{output_stem}')