| import json |
|
|
|
|
| def get_prompt_list(file_path: str) -> list[list[str]]: |
|
|
| |
| return json.load(open(file_path, "r")) |
|
|
| class GetPromptList(object): |
| _SUPPORTED_SOURCE = {'Sachit-descriptors',} |
| def __init__(self, file_path: str, name2idx: dict[str: int] = None, class_names: list[str] = None) -> None: |
| self.class_names = class_names |
| self.file_path = file_path |
| self.desc = get_prompt_list(file_path) |
| if isinstance(self.desc, dict): |
| self.__get_parts() |
| if name2idx is not None: |
| self.name2idx = name2idx |
| elif class_names is not None: |
| self.name2idx = {cls_name: idx for idx, cls_name in enumerate(class_names)} |
| else: |
| self.name2idx = {cls_name: idx for idx, cls_name in enumerate(self.desc.keys())} if isinstance(self.desc, dict) else None |
| |
| |
| |
| |
| |
| |
| def __get_parts(self, ): |
| |
| self.part_names = [d.split(":")[0].strip() for d in self.desc[list(self.desc.keys())[0]]] |
| |
| |
| @staticmethod |
| def replace_class_names(self, descs: dict, target_class: list[str], new_classes: list[str]): |
| new_descs = [] |
| for desc, cls_name, new_name in zip(descs, target_class, new_classes): |
| temp = [d.replace(cls_name, new_name) for d in desc] |
| new_descs.extend(temp) |
| return new_descs |
| |
| def __call__(self, source: str, pad: bool = False, max_len: int = 15, pad_text: str = "", target_classes: list[int] = None, pad_neg_index: bool = True): |
| """ |
| This function will return a list of prompts based on the source (format) and file_path provided. |
| If name2idx is provided, the prompts will be mapped based on the provied class indexes. Otherwise, |
| the prompts will be mapped based on the order of class name in the file. |
| Note: this function is will apply trucation when padding is True to make sure to have fixed length prompts. |
| |
| Args: |
| source (str): The sorce (format) of the prompts. Supported sources are: {self._SUPPORTED_SOURCE} |
| file_path (str): The file that contains the original prompts. |
| pad (bool, optional): Whether to pad the prompts to the same length. Defaults to False. |
| max_len (int, optional): The maximum length of the prompts. Defaults to 15. |
| pad_text (str, optional): The text to pad the prompts. Defaults to "Padding". |
| target_classes (list[int], optional): A list of class indexes to include in the prompts. Defaults to None (include all classes). |
| Returns: |
| prompts (list[str]): A list of engineered prompts. |
| class_idxs (list[int]): A list of class indexes for each prompt. |
| class_mapping (dict[int: str]): A mapping of class indexes to class names. |
| """ |
| org_desc_mapper = None |
| match source: |
| case 'Sachit-descriptors': |
| desc, org_dict = self.__get_sachit_desc(self.file_path) |
|
|
| case 'Sachit-no-template': |
| desc = self.desc |
| |
| case 'Sachit-CLIP-template-5': |
| desc, org_dict = self.__get_sachit_desc(self.file_path) |
| desc = {k: [f'a photo of a {d}.' for d in v] for k, v in desc.items()} |
| case 'cub-12-parts': |
| return self.desc, None, None, None |
| case 'chatgpt-no-template': |
| desc = self.desc |
| case 'chatgpt-template-0': |
| |
| template = 'a {} {}.' |
| desc = {k: [template.format(d.split(":")[1].strip(), d.split(":")[0].strip()) for d in v] for k, v in self.desc.items()} |
| case 'chatgpt-template-8': |
| |
| template = 'a {} {} of {}.' |
| desc = {k: [template.format(d.split(":")[1].strip(), d.split(":")[0].strip(), k) for d in v] for k, v in self.desc.items()} |
| case 'chatgpt-template-5': |
| |
| desc, org_dict = self.__get_sachit_desc(self.file_path) |
| template = 'a photo of a {}' |
| desc = {k: [template.format(d) for d in v] for k, v in desc.items()} |
| case 'chatgpt-template-x': |
| |
| desc = {k: [f'{d.split(":")[1].strip()}. {d.split(":")[0].strip()}. {k}' for d in v] for k, v in self.desc.items()} |
| case 'chatgpt-template-x-2': |
| |
| desc = {k: [f'a {d.split(":")[1].strip()} {d.split(":")[0].strip()}' for d in v] for k, v in self.desc.items()} |
| case 'chatgpt-template-x-3': |
| |
| desc = {k: [f'{d.split(":")[1].strip()}. {d.split(":")[0].strip()}.' for d in v] for k, v in self.desc.items()} |
| case 'chatgpt-template-x-4': |
| |
| desc = {k: [f'a {d.split(":")[0].strip()} of {k}: {d.split(":")[1].strip()}' for d in v] for k, v in self.desc.items()} |
| |
| case _: |
| raise ValueError(f"Source {source} is not supported. Check {self._SUPPORTED_SOURCE}") |
| |
| |
| if len(self.name2idx) < len(desc): |
| desc = {k: desc[k] for k in self.name2idx} |
| |
| prompts, class_idxs, class_list = [], [], [] |
| class_mapping = {v: k for k, v in self.name2idx.items()} |
| for class_name, class_idx in self.name2idx.items(): |
| descriptions = desc[class_name] |
| if target_classes is not None and class_idx not in target_classes: |
| continue |
| if pad: |
| pad_id = -1 if pad_neg_index else class_idx |
| ids = [class_idx] * len(descriptions) + [pad_id] * (max_len - len(descriptions)) if len(descriptions) < max_len else [class_idx] * max_len |
| if len(descriptions) < max_len: |
| descriptions.extend([pad_text] * (max_len - len(descriptions))) |
| else: |
| descriptions = descriptions[:max_len] |
| else: |
| ids = [class_idx] * len(descriptions) |
| prompts.extend(descriptions) |
| class_idxs.extend(ids) |
| class_list.append(class_name) |
| |
| if org_desc_mapper is not None: |
| org_desc_mapper = {des: org_dict[class_name][des] for des in descriptions} |
|
|
| |
| return prompts, class_idxs, class_mapping, org_desc_mapper, class_list |
| |
| |
# Prompt templates with a single "{}" placeholder each (80 entries, order preserved);
# presumably filled with a class name via str.format.
# NOTE(review): appears to be the ImageNet prompt ensemble from the CLIP repository — confirm.
imagenet_templates = [
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
]
|
|
|
|
|
|
|
|
|
|
|
|