| """This module provides utility functions.""" |
| from scripts.Logger import Logger |
| import os |
| import re |
| import requests |
| from typing import List |
|
|
| import modules |
| from modules.sd_models import read_state_dict |
| from modules.sd_models_config import (find_checkpoint_config, config_default, config_sd2, config_sd2v, config_sd2_inpainting, |
| config_depth_model, config_unclip, config_unopenclip, config_inpainting, config_instruct_pix2pix, config_alt_diffusion) |
|
|
| import sys |
# Put this file's sibling "scripts" directory first on the module search path.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"))
|
|
|
|
class Utils:
    """
    Methods that are needed in different classes: parsing and re-formatting the
    checkpoint / prompt text fields, and loading the help markdown file.
    """

    # "@index:<n>" tag written by add_index_to_string.
    _INDEX_RE = re.compile(r"@index:\d+")
    # " [...]" suffix (the checkpoint hash) after a checkpoint name.
    _HASH_RE = re.compile(r" \[.*?\]")
    # Known "@version:<tag>" tags. Tags that extend a shorter tag (sd2v,
    # sd2-inpainting, sd1-inpainting) are listed first: regex alternation takes
    # the first alternative that matches, so a bare "sd2" first would leave "v"
    # or "-inpainting" behind.
    _VERSION_RE = re.compile(
        r"@version:(?:sd1-inpainting|sd2-inpainting|sd2v|sd2|sd1"
        r"|depth|unclip|unopenclip|pix2pix|alt)"
    )

    def __init__(self) -> None:
        self.logger = Logger()
        self.logger.debug = False
        # This file lives one directory below the extension root.
        script_path = os.path.dirname(
            os.path.dirname(os.path.abspath(__file__)))
        help_md_name = "HelpBatchCheckpointsPrompt.md"
        self.held_md_file_name = os.path.join(script_path, help_md_name)
        # The URL must use only the file name, not the local absolute path.
        self.held_md_url = ("https://raw.githubusercontent.com/h43lb1t0/"
                            f"BatchCheckpointPrompt/main/{help_md_name}")

    @staticmethod
    def _split_and_clean(text: str, separator: str) -> List[str]:
        """Split on separator, remove newlines, strip, and drop empty parts."""
        cleaned_parts = (part.replace('\n', '').strip()
                         for part in text.split(separator))
        return [part for part in cleaned_parts if part]

    def split_prompts(self, text: str) -> List[str]:
        """Split the prompts by the ; and remove empty strings and newlines

        Args:
            text (str): the input string
        Returns:
            List[str]: a list of prompts
        """
        return self._split_and_clean(text, ";")

    def remove_index_from_string(self, input: str) -> str:
        """Remove every "@index:<n>" tag from the string and strip it

        Args:
            input (str): the input string
        Returns:
            str: the string without the index
        """
        return self._INDEX_RE.sub("", input).strip()

    def remove_model_version_from_string(self, checkpoints_text: str) -> str:
        """Remove every known "@version:<tag>" tag from the string

        Args:
            checkpoints_text (str): the input string with all checkpoints
        Returns:
            str: the string without the model version (not stripped)
        """
        return self._VERSION_RE.sub("", checkpoints_text)

    def get_clean_checkpoint_path(self, checkpoint: str) -> str:
        """Remove the " [hash]" part from the checkpoint name and strip it

        Args:
            checkpoint (str): the input string with hash
        Returns:
            str: the string without the hash
        """
        return self._HASH_RE.sub("", checkpoint).strip()

    def getCheckpointListFromInput(self, checkpoints_text: str, clean: bool = True) -> List[str]:
        """Get a list of checkpoints from the comma separated input string

        Version tags are always removed.

        Args:
            checkpoints_text (str): the input string with all checkpoints
            clean (bool): also remove the index and hash from the string
        Returns:
            List[str]: a list of checkpoints
        """
        self.logger.debug_log(f"checkpoints: {checkpoints_text}")
        checkpoints_text = self.remove_model_version_from_string(checkpoints_text)
        if clean:
            checkpoints_text = self.remove_index_from_string(checkpoints_text)
            checkpoints_text = self.get_clean_checkpoint_path(checkpoints_text)
        return self._split_and_clean(checkpoints_text, ",")

    def get_help_md(self) -> str:
        """Get the help markdown file.

        If the file is not found locally, download it from the GitHub repository
        and try to cache it next to the extension.

        Returns:
            str: the help markdown, or a fallback message if it could not be
            read or downloaded
        """
        md = "could not get help file. Check Github for more information"
        if os.path.isfile(self.held_md_file_name):
            with open(self.held_md_file_name, encoding="utf-8", errors="replace") as f:
                return f.read()

        self.logger.debug_log("downloading help md")
        try:
            # A timeout keeps an unreachable network from hanging the UI.
            result = requests.get(self.held_md_url, timeout=10)
        except requests.RequestException as e:
            self.logger.debug_log(f"could not download help md: {e}")
            return md
        if result.status_code != 200:
            self.logger.debug_log(f"could not download help md: HTTP {result.status_code}")
            return md

        try:
            with open(self.held_md_file_name, "wb") as file:
                file.write(result.content)
        except OSError as e:
            # Caching is best effort; the downloaded text is still usable.
            self.logger.debug_log(f"could not cache help md: {e}")
        return result.content.decode("utf-8", errors="replace")

    def add_index_to_string(self, text: str, is_checkpoint: bool = True) -> str:
        """Add an "@index:<n>" tag to every element of the string

        Args:
            text (str): the input string
            is_checkpoint (bool): if the string is a checkpoint list or a prompt list
        Returns:
            str: the string with the index
        """
        if is_checkpoint:
            checkpoint_list = self.getCheckpointListFromInput(text)
            return "".join(
                f"{self.remove_index_from_string(checkpoint)} @index:{i},\n"
                for i, checkpoint in enumerate(checkpoint_list))
        prompt_list = self.split_prompts(text)
        return "".join(
            f"{self.remove_index_from_string(prompt)} @index:{i};\n\n"
            for i, prompt in enumerate(prompt_list))

    def add_model_version_to_string(self, checkpoints_text: str) -> str:
        """Add the model version to every checkpoint in the string.
        EXPERIMENTAL! Loads each checkpoint's state dict to detect its config.

        Args:
            checkpoints_text (str): the input string with all checkpoints
        Returns:
            str: the string with the model version. Checkpoints that cannot be
            found are kept without a version tag. If the detected config is not
            one of the known ones, the config value itself is used as the tag.
        """
        # Ordered pairs: the first matching config wins, as in an if/elif chain.
        config_tags = (
            (config_default, "sd1"),
            (config_sd2, "sd2"),
            (config_sd2v, "sd2v"),
            (config_sd2_inpainting, "sd2-inpainting"),
            (config_depth_model, "depth"),
            (config_unclip, "unclip"),
            (config_unopenclip, "unopenclip"),
            (config_inpainting, "sd1-inpainting"),
            (config_instruct_pix2pix, "pix2pix"),
            (config_alt_diffusion, "alt"),
        )
        checkpoints_not_cleaned = self.getCheckpointListFromInput(
            checkpoints_text, clean=False)
        checkpoints = self.getCheckpointListFromInput(checkpoints_text)
        parts = []
        for i, checkpoint in enumerate(checkpoints):
            checkpoint_partly_cleaned = checkpoints_not_cleaned[i].replace(
                "\n", "").replace(",", "")
            info = modules.sd_models.get_closet_checkpoint_match(checkpoint)
            if info is None:
                # NOTE(review): assumes a missing match is reported as None.
                self.logger.debug_log(f"could not find checkpoint: {checkpoint}")
                parts.append(f"{checkpoint_partly_cleaned},\n\n")
                continue
            state_dict = read_state_dict(info.filename)
            config = find_checkpoint_config(state_dict, None)
            # Drop the (potentially huge) state dict before loading the next one.
            del state_dict
            version_string = next(
                (tag for known_config, tag in config_tags if config == known_config),
                config)
            parts.append(f"{checkpoint_partly_cleaned} @version:{version_string},\n\n")
        return "".join(parts)

    def remove_element_at_index(self, checkpoints: str, prompts: str, index: List[int]) -> List[str]:
        """Remove the elements at the given indices from both strings and re-index

        Args:
            checkpoints (str): the input string with all checkpoints
            prompts (str): the input string with all prompts
            index (List[int]): the indices to remove
        Returns:
            List[str]: [checkpoints, prompts] re-indexed after the removal, or the
            unchanged inputs if the lengths do not match or an index is invalid
        """
        checkpoints_list = self.getCheckpointListFromInput(checkpoints)
        prompts_list = self.split_prompts(prompts)
        if not (len(checkpoints_list) == len(prompts_list)
                or len(prompts_list) - len(index) <= 0):
            self.logger.debug_log(
                f"checkpoints and prompts are not the same length cp: {len(checkpoints_list)} p: {len(prompts_list)}")
            return [checkpoints, prompts]
        if not index:
            self.logger.debug_log("no index given")
            return [checkpoints, prompts]
        if min(index) < 0 or max(index) > len(checkpoints_list) - 1:
            self.logger.debug_log("index is out of range")
            return [checkpoints, prompts]

        # Pop from the highest index down so earlier removals do not shift the
        # positions of later ones; set() drops duplicate indices.
        for i in sorted(set(index), reverse=True):
            checkpoints_list.pop(i)
            # prompts_list may be shorter here (allowed by the length check
            # above), so only pop prompts that exist.
            if i < len(prompts_list):
                prompts_list.pop(i)

        new_checkpoints = "".join(f"{c}," for c in checkpoints_list)
        new_prompts = "".join(f"{p};" for p in prompts_list)
        result = [self.add_index_to_string(new_checkpoints, True),
                  self.add_index_to_string(new_prompts, False)]
        self.logger.debug_log(f"result: {result}")
        return result