| """This script is used to generate images with different checkpoints and prompts""" |
| from copy import copy |
| import os |
| import re |
| import subprocess |
| import sys |
| from typing import Any, List, Tuple, Union |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts")) |
| from scripts.Utils import Utils |
| from scripts.Logger import Logger |
| from scripts.CivitaihelperPrompts import CivitaihelperPrompts |
| from scripts.Save import Save |
| from scripts.BatchParams import BatchParams, get_all_batch_params |
|
|
|
|
| import gradio as gr |
| import modules |
| import modules.scripts as scripts |
| import modules.shared as shared |
| from modules.shared_state import State as shared_state |
| from modules import processing |
| from modules.processing import process_images |
| from modules.ui_components import (FormColumn, FormRow) |
|
|
| from PIL import Image, ImageDraw, ImageFont |
|
|
| import PIL |
|
|
|
|
|
|
| try: |
| import matplotlib.font_manager as fm |
| except: |
| subprocess.check_call(["pip", "install", "matplotlib"]) |
| import matplotlib.font_manager as fm |
|
|
| class ToolButton(gr.Button, gr.components.FormComponent): |
| """Small button with single emoji as text, fits inside gradio forms""" |
|
|
| def __init__(self, **kwargs: Any) -> None: |
| super().__init__(variant="tool", elem_classes=["batch-checkpoint-prompt"], **kwargs) |
|
|
| def get_block_name(self) -> str: |
| return "button" |
|
|
|
|
| class CheckpointLoopScript(scripts.Script): |
| """Script for generating images with different checkpoints and prompts |
| This calss is called by A1111 |
| """ |
|
|
| def __init__(self) -> None: |
| self.margin_size = 0 |
| self.logger = Logger() |
| self.logger.debug = False |
| self.font = None |
| self.text_margin_left_and_right = 16 |
| self.fill_values_symbol = "\U0001f4d2" |
| self.zero_width_space = '\u200B' |
| self.zero_width_joiner = '\u200D' |
| self.save_symbol = "\U0001F4BE" |
| self.reload_symbol = "\U0001F504" |
| self.index_symbol = "\U0001F522" |
| self.rm_index_symbol = "\U0001F5D1" |
| self.save = Save() |
| self.utils = Utils() |
| self.civitai_helper = CivitaihelperPrompts() |
| self.outdir_txt2img_grids = shared.opts.outdir_txt2img_grids |
| self.outdir_img2img_grids = shared.opts.outdir_img2img_grids |
|
|
|
|
| def title(self) -> str: |
| return "Batch Checkpoint and Prompt" |
|
|
| def save_inputs(self, save_name: str, checkpoints: str, prompt_templates: str, action : str) -> str: |
| """Save the inputs to a file |
| |
| Args: |
| save_name (str): the save name |
| checkpoints (str): the checkpoints |
| prompt_templates (str): the prompt templates |
| action (str): Possible values: "No", "Overwrite existing save", "append existing save" |
| |
| Returns: |
| str: the save status |
| """ |
| overwrite_existing_save = False |
| append_existing_save = False |
| if action == "Overwrite existing save": |
| overwrite_existing_save = True |
| elif action == "append existing save": |
| append_existing_save = True |
| return self.save.store_values( |
| save_name.strip(), checkpoints.strip(), prompt_templates.strip(), overwrite_existing_save, append_existing_save) |
| |
|
|
| """ def load_inputs(self, name: str) -> None: |
| values = self.save.read_value(name.strip()) """ |
|
|
| def get_checkpoints(self) -> str: |
| """Get the checkpoints from the sd_models module. |
| Add the index to the checkpoints |
| |
| Returns: |
| str: the checkpoints |
| """ |
| checkpoint_list_no_index = list(modules.sd_models.checkpoints_list) |
| checkpoint_list_with_index = [] |
| for i in range(len(checkpoint_list_no_index)): |
| checkpoint_list_with_index.append( |
| f"{checkpoint_list_no_index[i]} @index:{i}") |
| return ',\n'.join(checkpoint_list_with_index) |
|
|
| def getCheckpoints_and_prompt_with_index_and_version(self, checkpoint_list: str, prompts: str, add_model_version: bool) -> Tuple[str, str]: |
| """Add the index to the checkpoints and prompts |
| and add the model version to the checkpoints |
| |
| Args: |
| checkpoint_list (str): the checkpoint list |
| prompts (str): the prompts |
| add_model_version (bool): add the model version to the checkpoints. EXPERIMENTAL! |
| |
| Returns: |
| Tuple[str, str]: the checkpoints and prompts |
| """ |
| checkpoints = self.utils.add_index_to_string(checkpoint_list) |
| if add_model_version: |
| checkpoints = self.utils.add_model_version_to_string(checkpoints) |
| prompts = self.utils.add_index_to_string(prompts, is_checkpoint=False) |
| return checkpoints, prompts |
| |
| def refresh_saved(self) -> gr.Dropdown: |
| """Refresh the saved values dropdown |
| |
| Returns: |
| gr.Dropdown: the updated dropdown |
| """ |
| return gr.Dropdown.update(choices=self.save.get_keys()) |
| |
| def remove_checkpoints_prompt_at_index(self, checkpoints: str, prompts: str, index: str) -> List[str]: |
| """Remove the checkpoint and prompt at the specified index |
| |
| Args: |
| checkpoints (str): the checkpoints |
| prompts (str): the prompts |
| index (str): the index |
| |
| Returns: |
| List[str]: the checkpoints and prompts |
| """ |
| index_list = index.split(",") |
| index_list_num = [int(i) for i in index_list] |
| return self.utils.remove_element_at_index(checkpoints, prompts, index_list_num) |
| |
| |
| |
|
|
| def ui(self, is_img2img: bool) -> List[Union[gr.components.Textbox, gr.components.Slider]]: |
| """Create the UI |
| |
| Args: |
| is_img2img (bool): not used. |
| |
| Returns: |
| List[Union[gr.components.Textbox, gr.components.Slider]]: the UI components |
| """ |
| with gr.Tab("Parameters"): |
| with FormRow(): |
| checkpoints_input = gr.components.Textbox( |
| lines=5, label="Checkpoint Names", placeholder="Checkpoint names (separated with comma)") |
| fill_checkpoints_button = ToolButton( |
| value=self.fill_values_symbol, visible=True) |
| with FormRow(): |
| |
| checkpoints_prompt = gr.components.Textbox( |
| lines=5, label="Prompts/prompt templates for Checkpoints", placeholder="prompts/prompt templates (separated with semicolon)") |
| |
| civitai_prompt_fill_button = ToolButton( |
| value=self.fill_values_symbol+self.zero_width_joiner, visible=True) |
| add_index_button = ToolButton( |
| value=self.index_symbol, visible=True) |
| with FormColumn(): |
| with FormRow(): |
| rm_model_prompt_at_indexes_textbox = gr.components.Textbox(lines=1, label="Remove checkpoint and prompt at index", placeholder="Remove checkpoint and prompt at index (separated with comma)") |
| rm_model_prompt_at_indexes_button = ToolButton(value=self.rm_index_symbol, visible=True) |
| margin_size = gr.Slider( |
| label="Grid margins (px)", minimum=0, maximum=10, value=0, step=1) |
|
|
| |
|
|
| with FormRow(): |
| keys = self.save.get_keys() |
| saved_inputs_dropdown = gr.components.Dropdown( |
| choices=keys, label="Saved values") |
| |
| load_button = ToolButton( |
| value=self.fill_values_symbol+self.zero_width_space, visible=True) |
| refresh_button = ToolButton(value=self.reload_symbol, visible=True) |
|
|
|
|
| with FormRow(): |
| save_name = gr.components.Textbox( |
| lines=1, label="save name", placeholder="save name") |
| save_button = ToolButton(value=self.save_symbol, visible=True) |
| with FormRow(): |
| test = gr.components.Radio(["No", "Overwrite existing save", "append existing save"], label="Change saves?") |
|
|
| save_status = gr.Textbox(label="", interactive=False) |
|
|
| |
| |
| |
| with gr.Accordion(label='Advanced settings', open=False): |
| gr.Markdown(""" |
| This can take a long time depending on the number of checkpoints! <br> |
| See the help tab for more information |
| """) |
| add_model_version_checkbox = gr.components.Checkbox(label="Add model version to checkpoint names", interactive=False |
| , info="Not working in current webui versions") |
|
|
| |
|
|
| fill_checkpoints_button.click( |
| fn=self.get_checkpoints, outputs=[checkpoints_input]) |
| save_button.click(fn=self.save_inputs, inputs=[ |
| save_name, checkpoints_input, checkpoints_prompt, test], outputs=[save_status]) |
| load_button.click(fn=self.save.read_value, inputs=[saved_inputs_dropdown], outputs=[ |
| checkpoints_input, checkpoints_prompt]) |
| civitai_prompt_fill_button.click(fn=self.civitai_helper.createCivitaiPromptString, inputs=[ |
| checkpoints_input], outputs=[checkpoints_prompt]) |
| add_index_button.click(fn=self.getCheckpoints_and_prompt_with_index_and_version, inputs=[ |
| checkpoints_input, checkpoints_prompt, add_model_version_checkbox], outputs=[checkpoints_input, checkpoints_prompt]) |
| |
| refresh_button.click(fn=self.refresh_saved, outputs=[saved_inputs_dropdown]) |
|
|
| rm_model_prompt_at_indexes_button.click(fn=self.remove_checkpoints_prompt_at_index, inputs=[ |
| checkpoints_input, checkpoints_prompt, rm_model_prompt_at_indexes_textbox], outputs=[checkpoints_input, checkpoints_prompt]) |
|
|
| with gr.Tab("help"): |
| gr.Markdown(self.utils.get_help_md()) |
|
|
| return [checkpoints_input, checkpoints_prompt, margin_size] |
|
|
| def show(self, is_img2img: bool) -> bool: |
| """Show the UI in text2img and img2img mode |
| |
| Args: |
| is_img2img (bool): not used |
| |
| Returns: |
| bool: True |
| """ |
| return True |
| |
|
|
| def _generate_images_with_SD(self,p: Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img], |
| batch_params: BatchParams, orginal_size: Tuple[int, int]) -> modules.processing.Processed: |
| """ manipulates the StableDiffusionProcessing Obect |
| to generate images with the new checkpoint and prompt |
| and other parameters |
| |
| Args: |
| p (Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img]): the processing object |
| batch_params (BatchParams): the batch parameters |
| orginal_size (Tuple[int, int]): the original size specified in the UI |
| |
| Returns: |
| modules.processing.Processed: the processed object |
| """ |
| self.logger.debug_log(str(batch_params), False) |
| |
| info = None |
| info = modules.sd_models.get_closet_checkpoint_match(batch_params.checkpoint) |
| modules.sd_models.reload_model_weights(shared.sd_model, info) |
| p.override_settings['sd_model_checkpoint'] = info.name |
| p.prompt = batch_params.prompt |
| p.negative_prompt = batch_params.neg_prompt |
| if len(batch_params.style) > 0: |
| p.styles = batch_params.style |
| p.n_iter = batch_params.batch_count |
| shared.opts.data["CLIP_stop_at_last_layers"] = batch_params.clip_skip |
| if batch_params.width > 0 and batch_params.height > 0: |
| self.logger.debug_print_attributes(p, False) |
| p.height = batch_params.height |
| p.width = batch_params.width |
| else: |
| p.width, p.height = orginal_size |
| p.hr_prompt = batch_params.hr_prompt |
| p.hr_negative_prompt = p.negative_prompt |
| self.logger.debug_log(f"batch count {p.n_iter}") |
|
|
| processed = process_images(p) |
|
|
| return processed |
| |
|
|
| def _generate_infotexts(self, pc: Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img], |
| all_infotexts: List[str], n_iter: int) -> List[str]: |
| """Generate the infotexts for the images |
| |
| Args: |
| pc (Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img]): the processing object |
| all_infotexts (List[str]): the infotexts created by A1111 |
| n_iter (int): the number of iterations |
| |
| Returns: |
| List[str]: the infotexts |
| """ |
|
|
| def _a1111_infotext_caller(i: int = 0) -> str: |
| """Call A1111 to create a infotext. This is a helper function. |
| |
| Args: |
| i (int, optional): the index. Defaults to 0. Used to get the correct seed and subseed. |
| |
| Returns: |
| str: the infotext |
| """ |
| return processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds, position_in_batch=i) |
|
|
| self.logger.pretty_debug_log(all_infotexts) |
|
|
|
|
| self.logger.debug_print_attributes(pc) |
|
|
| if n_iter == 1: |
| all_infotexts.append(_a1111_infotext_caller()) |
| else: |
| all_infotexts.append(self.base_prompt) |
| for i in range(n_iter * pc.batch_size): |
| all_infotexts.append(_a1111_infotext_caller(i)) |
|
|
| return all_infotexts |
|
|
|
|
| def run(self, p: Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img], checkpoints_text: str, checkpoints_prompt: str, margin_size: int) -> modules.processing.Processed: |
| """The main function to generate the images |
| |
| Args: |
| p (Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img]): the processing object |
| checkpoints_text (str): the checkpoints |
| checkpoints_prompt (str): the prompts |
| margin_size (int): the margin size for the grid |
| |
| Returns: |
| modules.processing.Processed: the processed object |
| """ |
| image_processed = [] |
| self.margin_size = margin_size |
|
|
| def _get_total_batch_count(batchParams: List[BatchParams]) -> int: |
| """Get the total batch count to update the progress bar |
| |
| Args: |
| batchParams (List[BatchParams]): the batch parameters |
| |
| Returns: |
| int: the total batch count |
| """ |
| summe = 0 |
| for param in batchParams: |
| summe += param.batch_count |
| return summe |
| |
| self.base_prompt: str = p.prompt |
|
|
| all_batchParams = get_all_batch_params(p, checkpoints_text, checkpoints_prompt) |
|
|
| total_batch_count = _get_total_batch_count(all_batchParams) |
| total_steps = p.steps * total_batch_count |
| self.logger.debug_log(f"total steps: {total_steps}") |
|
|
| shared.state.job_count = total_batch_count |
| shared.total_tqdm.updateTotal(total_steps) |
|
|
| all_infotexts = [self.base_prompt] |
|
|
| p.extra_generation_params['Script'] = self.title() |
|
|
| self.logger.log_info(f'will generate {total_batch_count} images over {len(all_batchParams)} checkpoints)') |
|
|
| original_size = p.width, p.height |
| |
|
|
| for i, checkpoint in enumerate(all_batchParams): |
|
|
| |
| self.logger.log_info(f"checkpoint: {i+1}/{len(all_batchParams)} ({checkpoint.checkpoint})") |
|
|
|
|
| self.logger.debug_log( |
| f"Propmpt with replace: {all_batchParams[i].prompt}, neg prompt: {all_batchParams[i].neg_prompt}") |
| |
|
|
| processed_sd_object = self._generate_images_with_SD(p, all_batchParams[i], original_size) |
|
|
| image_processed.append(processed_sd_object) |
|
|
| |
| all_infotexts = self._generate_infotexts(copy(p), all_infotexts, all_batchParams[i].batch_count) |
|
|
|
|
| if shared.state.interrupted or shared.state.stopping_generation: |
| break |
|
|
| img_grid = self._create_grid(image_processed, all_batchParams) |
|
|
| image_processed[0].images.insert(0, img_grid) |
| image_processed[0].index_of_first_image = 1 |
| for i, image in enumerate(image_processed): |
| if i > 0: |
| for j in range(len(image_processed[i].images)): |
| image_processed[0].images.append( |
| image_processed[i].images[j]) |
| |
| image_processed[0].infotexts = all_infotexts |
|
|
|
|
| return image_processed[0] |
|
|
| |
|
|
| def _create_grid(self, image_processed: List[modules.processing.Processed], all_batch_params: List[BatchParams]) -> PIL.Image.Image: |
| """Create the grid with the images |
| |
| Args: |
| image_processed (List[modules.processing.Processed]): the images |
| all_batch_params (List[BatchParams]): the batch parameters |
| |
| Returns: |
| PIL.Image.Image: the grid |
| """ |
| self.logger.log_info( |
| "creating the grid. This can take a while, depending on the amount of images") |
|
|
| def _getFileName(save_path: str) -> str: |
| """Get the file name for the grid. |
| The files are acsending numbered. |
| |
| Args: |
| save_path (str): the save path |
| |
| Returns: |
| str: the file name |
| """ |
| save_path = os.path.join(save_path, "Checkpoint-Prompt-Loop") |
| self.logger.debug_log(f"save path: {save_path}") |
| if not os.path.exists(save_path): |
| os.makedirs(save_path) |
|
|
| files = os.listdir(save_path) |
| pattern = r"img_(\d{4})" |
|
|
| matching_files = [f for f in files if re.match(pattern, f)] |
|
|
| if matching_files: |
|
|
| matching_files.sort() |
| last_file = matching_files[-1] |
| match = re.search(r"\d{4}", last_file) |
| number = int(match.group()) if match else 0 |
| else: |
| number = 0 |
|
|
| new_number = number + 1 |
|
|
| return os.path.join(save_path, f"img_{new_number:04d}.png") |
|
|
| total_width = 0 |
| max_height = 0 |
| min_height = 0 |
|
|
| spacing = self.margin_size |
|
|
| |
| for img in image_processed: |
| total_width += img.images[0].size[0] + spacing |
|
|
| img_with_legend = [] |
| for i, img in enumerate(image_processed): |
| img_with_legend.append(self._add_legend( |
| img.images[0], all_batch_params[i].checkpoint)) |
|
|
| for img in img_with_legend: |
| max_height = max(max_height, img.size[1]) |
| min_height = min(min_height, img.size[1]) |
|
|
| result_img = Image.new('RGB', (total_width, max_height), "white") |
|
|
| x_offset = -spacing |
| for i, img in enumerate(img_with_legend): |
| y_offset = max_height - img.size[1] |
| result_img.paste(((0, 0, 0)), (x_offset, 0, x_offset + |
| img.size[0] + spacing, max_height + spacing)) |
| result_img.paste(((255, 255, 255)), (x_offset, 0, |
| x_offset + img.size[0], max_height - min_height)) |
| result_img.paste(img, (x_offset + spacing, y_offset)) |
|
|
| x_offset += img.size[0] + spacing |
|
|
| if self.is_img2img: |
| result_img.save(_getFileName(self.outdir_img2img_grids)) |
| else: |
| result_img.save(_getFileName(self.outdir_txt2img_grids)) |
|
|
| return result_img |
| |
| def _add_legend(self, img: Image, checkpoint_name: str) -> Image: |
| """Add the checkpoint name to the image |
| |
| Args: |
| img (Image): the image |
| checkpoint_name (str): the checkpoint name |
| |
| Returns: |
| Image: the image with the checkpoint name as legend |
| """ |
|
|
| def _find_available_font() -> str: |
| """Find an available font |
| |
| Returns: |
| str: the font |
| """ |
|
|
| if self.font is None: |
|
|
| self.font = fm.findfont( |
| fm.FontProperties(family='DejaVu Sans')) |
|
|
| if self.font is None: |
| font_list = fm.findSystemFonts( |
| fontpaths=None, fontext='ttf') |
|
|
| for font_file in font_list: |
| self.font = os.path.abspath(font_file) |
| if os.path.isfile(self.font): |
| self.logger.debug_log("font list font") |
| return self.font |
|
|
| self.logger.debug_log("default font") |
| return ImageFont.load_default() |
| self.logger.debug_log("DejaVu font") |
|
|
| return self.font |
|
|
| def _strip_checkpoint_name(checkpoint_name: str) -> str: |
| """Remove the path from the checkpoint name |
| |
| Args: |
| checkpoint_name (str): the checkpoint with path |
| |
| Returns: |
| str: the checkpoint name |
| """ |
| checkpoint_name = os.path.basename(checkpoint_name) |
| return self.utils.get_clean_checkpoint_path(checkpoint_name) |
|
|
| def _calculate_font(draw: ImageDraw, text: str, width: int) -> Tuple[int, int]: |
| """Calculate the font size for the text according to the image width |
| |
| Args: |
| draw (ImageDraw): the draw object |
| text (str): the text |
| width (int): the image width |
| |
| Returns: |
| Tuple[int, int]: the font and the text height |
| """ |
| width -= self.text_margin_left_and_right |
| default_font_path = _find_available_font() |
| font_size = 1 |
| font = ImageFont.truetype( |
| default_font_path, font_size) if default_font_path else ImageFont.load_default() |
| text_width, text_height = draw.textsize(text, font) |
|
|
| while text_width < width: |
| self.logger.debug_log( |
| f"text width: {text_width}, img width: {width}") |
| font_size += 1 |
| font = ImageFont.truetype( |
| default_font_path, font_size) if default_font_path else ImageFont.load_default() |
| text_width, text_height = draw.textsize(text, font) |
|
|
| return font, text_height |
|
|
| checkpoint_name = _strip_checkpoint_name(checkpoint_name) |
|
|
| width, height = img.size |
|
|
| draw = ImageDraw.Draw(img) |
|
|
| font, text_height = _calculate_font(draw, checkpoint_name, width) |
|
|
| new_image = Image.new("RGB", (width, height + text_height), "white") |
| new_image.paste(img, (0, text_height)) |
|
|
| new_draw = ImageDraw.Draw(new_image) |
|
|
| new_draw.text((self.text_margin_left_and_right/4, 0), |
| checkpoint_name, fill="black", font=font) |
|
|
| return new_image |
|
|