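"""This module provides a function to get all batch parameters (checkpoint, prompts, styles, batch count, clip skip and image size) from the user input."""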
| from dataclasses import dataclass |
| from typing import Union, List, Tuple |
| import re |
| import os |
| from scripts.Logger import Logger |
| from scripts.Utils import Utils |
|
|
| import modules |
|
|
| import modules.shared as shared |
|
|
@dataclass()
class BatchParams:
    """Dataclass to store the parameters for a single batch (one checkpoint/prompt pair).

    Attributes:
        checkpoint (str): the checkpoint name (only its basename is shown by ``repr``)
        prompt (str): the prompt
        hr_prompt (str): the hires. fix prompt
        neg_prompt (str): the negative prompt
        style (List[str]): the style (A1111 styles)
        batch_count (int, optional): the batch count. Defaults to -1 (don't overwrite the UI value).
        clip_skip (int, optional): the clip skip. Defaults to 1.
        width (int, optional): the width. Defaults to -1 (don't overwrite the UI value).
        height (int, optional): the height. Defaults to -1 (don't overwrite the UI value).
    """

    checkpoint: str
    prompt: str
    hr_prompt: str
    neg_prompt: str
    style: List[str]
    batch_count: int = -1
    clip_skip: int = 1
    width: int = -1
    height: int = -1

    def __repr__(self) -> str:
        """Return a multi-line, human-readable summary with one field per line.

        Defining ``__repr__`` in the class body makes ``@dataclass`` keep this
        implementation instead of generating its own.
        """
        checkpoint_name = os.path.basename(self.checkpoint)
        return (
            f"BatchParams: {checkpoint_name},\n"
            f"prompt: {self.prompt},\n"
            f"hr_prompt: {self.hr_prompt},\n"
            f"style: {self.style},\n"
            f"neg_prompt: {self.neg_prompt},\n"
            f"batch_count: {self.batch_count},\n"
            f"clip_skip: {self.clip_skip},\n"
            f"size: {self.width}x{self.height}"
        )
|
|
# Module-level logger used for the debug output of the batch-parameter parsing in this module.
logger = Logger()
|
|
# Locates, inside a user-configured regex, either a capture target -- a character
# class such as "[0-9]" or the digit escape "\d", including any quantifier and lazy
# marker that follow it -- or any other escape sequence. Other escapes are matched
# only so they are copied through intact (e.g. "\[" is not taken as a class start).
_CAPTURE_TARGET_RE = re.compile(
    r"(?P<target>(?:\\d|\[(?:\\.|[^\]\\])*\])(?:[*+?]|\{\d+(?:,\d*)?\})?\??)"
    r"|\\."
)


def _build_search_pattern(sub_pattern: str) -> str:
    """Return sub_pattern with every character class / digit escape wrapped in a capture group.

    The quantifier is kept inside the group: "[0-9]+" becomes "([0-9]+)", not
    "([0-9])+". A repeated group only keeps its last repetition, so the latter
    would capture "2" for "12".
    """
    return _CAPTURE_TARGET_RE.sub(
        lambda match: f"({match.group(0)})" if match.group("target") else match.group(0),
        sub_pattern,
    )


def _get_regex_from_opts(key: str, search_for_number: bool = True) -> Tuple[str, str]:
    """Get a regex from the options. As the user can change the regex,
    it is checked if the regex is valid.

    Args:
        key (str): name of the option in ``shared.opts`` that holds the regex
        search_for_number (bool, optional): If True, require the regex to contain
            a number pattern ("[0-9]+" or a digit escape followed by "+").
            Defaults to True.

    Returns:
        Tuple[str, str]: the search pattern (with capture groups, used to read
        values) and the sub pattern (the option value as-is, used to remove the
        token from the prompt).

    Raises:
        RuntimeError: if ``search_for_number`` is True and no number pattern is found.
    """
    sub_pattern = getattr(shared.opts, key)
    search_pattern = _build_search_pattern(sub_pattern)

    if search_for_number and not re.search(r"\[0-9\]\+|\\d\+", sub_pattern):
        raise RuntimeError(f'Can\'t find a number with the regex for {key}: "{sub_pattern}"')

    return search_pattern, sub_pattern


def _get_number_from_prompt(prompt: str, opts_key: str, default: int) -> Tuple[int, str]:
    """Extract a number marked by the regex in option ``opts_key`` from the prompt.

    Args:
        prompt (str): the prompt
        opts_key (str): option holding the regex (must contain a number pattern)
        default (int): value used when no number is given or it is smaller than 1

    Returns:
        Tuple[int, str]: the number and the prompt. When the regex matched, every
        occurrence of it is removed from the prompt.
    """
    search_pattern, sub_pattern = _get_regex_from_opts(opts_key)
    number_match = re.search(search_pattern, prompt)
    if number_match is None or not number_match.group(1):
        return default, prompt

    number = int(number_match.group(1))
    return (number if number >= 1 else default), re.sub(sub_pattern, "", prompt)


def _get_style_from_prompt(prompt: str) -> Tuple[List[str], str]:
    """Extracts the styles from the prompt if specified.

    Every style token is removed from the prompt except the last one, which is
    replaced by the prompt placeholder (the ``promptRegex`` option).

    Args:
        prompt (str): the prompt

    Returns:
        Tuple[List[str], str]: the styles and the prompt
    """
    search_pattern, sub_pattern = _get_regex_from_opts("styleRegex", False)
    styles = re.findall(search_pattern, prompt)
    if not styles:
        return styles, prompt

    _, prompt_regex = _get_regex_from_opts("promptRegex", False)
    last_index = len(styles) - 1
    for index in range(len(styles)):
        replacement = prompt_regex if index == last_index else ""
        # Callable replacement: the text is inserted literally, whereas a string
        # replacement would have its backslash escapes processed by re.sub.
        prompt = re.sub(sub_pattern, lambda _match, text=replacement: text, prompt, count=1)
        logger.debug_log(f"nr.: {index}, prompt: {prompt}", False)

    return styles, prompt


def _get_image_size_from_prompt(prompt: str) -> Tuple[int, int, str]:
    """Extracts the image size from the prompt if specified.

    Args:
        prompt (str): the prompt

    Returns:
        Tuple[int, int, str]: the width, height and the prompt.
        If the width and height are not specified, -1 is returned for both.

    Raises:
        RuntimeError: if the matched size can't be converted to two integers.
    """
    search_pattern, sub_pattern = _get_regex_from_opts("widthHeightRegex", False)
    size_match = re.search(search_pattern, prompt)
    if size_match is None:
        return -1, -1, prompt

    try:
        width, height = map(int, size_match.groups())
    except (TypeError, ValueError) as err:
        raise RuntimeError(f"Can't convert the image size to an integer: {size_match[0]}") from err

    return width, height, re.sub(sub_pattern, "", prompt)


def _split_positive_and_negative_prompt(prompt: str) -> Tuple[str, str]:
    """Splits the prompt into a positive and a negative prompt (the part after the
    ``negPromptRegex`` separator). The negative prompt is "" when no separator is found.

    Args:
        prompt (str): the prompt

    Returns:
        Tuple[str, str]: the positive and negative prompt
    """
    parts = re.split(getattr(shared.opts, "negPromptRegex"), prompt)
    neg_prompt = parts[1] if len(parts) > 1 else ""
    return parts[0], neg_prompt


def get_all_batch_params(p: Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img], checkpoints_as_string: str, prompts_as_string: str) -> List[BatchParams]:
    """Get all batch parameters from the input.

    Each checkpoint is paired with the prompt at the same position. Inline tokens
    (batch count, clip skip, style, image size, negative prompt) are parsed out of
    each prompt with the regexes configured in ``shared.opts``, and the prompt
    placeholder (``promptRegex``) is replaced by the prompt from the UI.

    Args:
        p (Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img]):
            the processing object; supplies the UI prompt, hires. fix prompt,
            negative prompt and the default batch count (``n_iter``)
        checkpoints_as_string (str): the checkpoints as string
        prompts_as_string (str): the prompts as string, separated by ";"

    Returns:
        List[BatchParams]: one entry per checkpoint, in input order

    Raises:
        RuntimeError: if the amounts of prompts and checkpoints differ, if there are
            none, if a checkpoint is unknown, if a number regex contains no number
            pattern, or if an image size can't be converted to integers.
    """
    utils = Utils()

    checkpoints: List[str] = utils.getCheckpointListFromInput(checkpoints_as_string)
    # Drop empty / whitespace-only entries; newlines inside a prompt are removed.
    prompts: List[str] = [
        prompt.replace("\n", "").strip()
        for prompt in utils.remove_index_from_string(prompts_as_string).split(";")
        if prompt.strip()
    ]

    if len(prompts) != len(checkpoints):
        logger.debug_log(f"len prompt: {len(prompts)}, len checkpoints{len(checkpoints)}")
        raise RuntimeError("amount of prompts don't match with amount of checkpoints")

    if not prompts:
        raise RuntimeError("can't run without a checkpoint and prompt")

    _, prompt_regex = _get_regex_from_opts("promptRegex", False)
    # getattr: a processing object without a hires. fix prompt (presumably Img2Img)
    # would otherwise raise AttributeError; "" mirrors an unset hires. fix prompt.
    ui_hr_prompt = getattr(p, "hr_prompt", "")

    all_batch_params: List[BatchParams] = []
    for checkpoint, prompt_spec in zip(checkpoints, prompts):
        if modules.sd_models.get_closet_checkpoint_match(checkpoint) is None:
            raise RuntimeError(f"Unknown checkpoint: {checkpoint}")

        # Each step strips its own token from the prompt before the next one runs.
        batch_count, prompt_spec = _get_number_from_prompt(prompt_spec, "batchCountRegex", p.n_iter)
        clip_skip, prompt_spec = _get_number_from_prompt(
            prompt_spec, "clipSkipRegex", shared.opts.data["CLIP_stop_at_last_layers"]
        )
        style, prompt_spec = _get_style_from_prompt(prompt_spec)
        width, height, prompt_spec = _get_image_size_from_prompt(prompt_spec)
        prompt_template, neg_prompt = _split_positive_and_negative_prompt(prompt_spec)

        batch_params = BatchParams(
            checkpoint=checkpoint,
            prompt=prompt_template.replace(prompt_regex, p.prompt),
            hr_prompt=prompt_template.replace(prompt_regex, ui_hr_prompt),
            # NOTE(review): concatenated without a separator; presumably the
            # template's negative part brings its own leading comma -- confirm.
            neg_prompt=p.negative_prompt + neg_prompt,
            style=style,
            batch_count=batch_count,
            clip_skip=clip_skip,
            width=width,
            height=height,
        )
        all_batch_params.append(batch_params)
        logger.debug_log(f"batch_params: {batch_params}", False)

    return all_batch_params