| from pprint import pprint |
| from time import perf_counter |
| from traceback import print_exc |
| from typing import Any |
|
|
| from app_settings import Settings |
| from backend.image_saver import ImageSaver |
| from backend.lcm_text_to_image import LCMTextToImage |
| from backend.models.lcmdiffusion_setting import DiffusionTask |
| from backend.utils import get_blank_image |
| from models.interface_types import InterfaceType |
|
|
|
|
| class Context: |
| def __init__( |
| self, |
| interface_type: InterfaceType, |
| device="cpu", |
| ): |
| self.interface_type = interface_type.value |
| self.lcm_text_to_image = LCMTextToImage(device) |
| self._latency = 0 |
| self._error = "" |
|
|
| @property |
| def latency(self): |
| return self._latency |
|
|
| @property |
| def error(self): |
| return self._error |
|
|
| def generate_text_to_image( |
| self, |
| settings: Settings, |
| reshape: bool = False, |
| device: str = "cpu", |
| save_config=True, |
| ) -> Any: |
| try: |
| self._error = "" |
| tick = perf_counter() |
| from state import get_settings |
|
|
| if ( |
| settings.lcm_diffusion_setting.diffusion_task |
| == DiffusionTask.text_to_image.value |
| ): |
| settings.lcm_diffusion_setting.init_image = None |
|
|
| if save_config: |
| get_settings().save() |
|
|
| pprint(settings.lcm_diffusion_setting.model_dump()) |
| if not settings.lcm_diffusion_setting.lcm_lora: |
| return None |
| self.lcm_text_to_image.init( |
| device, |
| settings.lcm_diffusion_setting, |
| ) |
|
|
| images = self.lcm_text_to_image.generate( |
| settings.lcm_diffusion_setting, |
| reshape, |
| ) |
|
|
| elapsed = perf_counter() - tick |
| self._latency = elapsed |
| print(f"Latency : {elapsed:.2f} seconds") |
| if settings.lcm_diffusion_setting.controlnet: |
| if settings.lcm_diffusion_setting.controlnet.enabled: |
| images.append( |
| settings.lcm_diffusion_setting.controlnet._control_image |
| ) |
|
|
| if settings.lcm_diffusion_setting.use_safety_checker: |
| print("Safety Checker is enabled") |
| from state import get_safety_checker |
|
|
| safety_checker = get_safety_checker() |
| blank_image = get_blank_image( |
| settings.lcm_diffusion_setting.image_width, |
| settings.lcm_diffusion_setting.image_height, |
| ) |
| for idx, image in enumerate(images): |
| if not safety_checker.is_safe(image): |
| images[idx] = blank_image |
| except Exception as exception: |
| print(f"Error in generating images: {exception}") |
| self._error = str(exception) |
| print_exc() |
| return None |
| return images |
|
|
| def save_images( |
| self, |
| images: Any, |
| settings: Settings, |
| ) -> list[str]: |
| saved_images = [] |
| if images and settings.generated_images.save_image: |
| saved_images = ImageSaver.save_images( |
| settings.generated_images.path, |
| images=images, |
| lcm_diffusion_setting=settings.lcm_diffusion_setting, |
| format=settings.generated_images.format, |
| jpeg_quality=settings.generated_images.save_image_quality, |
| ) |
| return saved_images |
|
|