| from io import BytesIO |
|
|
| import torch |
| from data.dataAccessor import update_db |
| from data.task import ModelType, Task, TaskType |
| from pipelines.inpainter import InPainter |
| from pipelines.prompt_modifier import PromptModifier |
| from pipelines.remove_background import RemoveBackground |
| from pipelines.upscaler import Upscaler |
| from util.cache import clear_cuda |
| from util.commons import ( |
| add_code_names, |
| construct_default_s3_url, |
| upload_image, |
| upload_images, |
| ) |
| from util.slack import Slack |
|
|
# Global torch performance switches: let cuDNN benchmark/autotune algorithms
# and allow TF32 for CUDA matmuls (faster, slightly lower float32 precision).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# How many prompt variants inpaint() builds per request; it is also passed to
# PromptModifier below and used to size the negative-prompt list in inpaint().
num_return_sequences = 4
# NOTE(review): not read by any code visible in this file — confirm it is used
# elsewhere before removing.
auto_mode = False

slack = Slack()

# Shared pipeline instances, constructed in this order at import time; their
# .load() methods are invoked later by model_fn().
prompt_modifier, upscaler, inpainter = (
    PromptModifier(num_of_sequences=num_return_sequences),
    Upscaler(),
    InPainter(),
)
|
|
|
|
@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Strip the background from the task's image and upload the result.

    The image at ``task.get_imageUrl()`` is passed through
    ``RemoveBackground.remove``. The output is uploaded under the key
    ``crecoAI/<taskId>_rmbg.png``.

    Args:
        task: The request being served.

    Returns:
        ``{"generated_image_url": <default S3 URL of the uploaded key>}``.
    """
    # NOTE(review): a new RemoveBackground is constructed on every call and
    # model_fn() never preloads it; if its constructor loads weights, a shared
    # instance would avoid reloading per request — confirm before changing.
    cutout = RemoveBackground().remove(task.get_imageUrl())

    s3_key = f"crecoAI/{task.get_taskId()}_rmbg.png"
    upload_image(cutout, s3_key)

    return {"generated_image_url": construct_default_s3_url(s3_key)}
|
|
|
|
@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of the task's image and upload the results.

    The prompt is first passed through ``add_code_names``. When prompt
    engineering is enabled on the task, ``prompt_modifier.modify`` produces the
    prompt list. Otherwise the prompt is repeated ``num_return_sequences``
    times. The prompts, image/mask URLs, size, seed and negative prompt are
    handed to ``inpainter.process``. The resulting images are uploaded with
    the ``"_inpaint"`` suffix.

    ``clear_cuda()`` runs in a ``finally`` clause, so it runs even when
    generation or upload fails. It is defined in util.cache and presumably
    releases cached GPU memory — confirm there.

    Args:
        task: The request being served.

    Returns:
        ``{"modified_prompts": <prompts used>,
        "generated_image_urls": <result of upload_images>}``.
    """
    try:
        base_prompt = add_code_names(task.get_prompt())
        if task.is_prompt_engineering():
            prompts = prompt_modifier.modify(base_prompt)
        else:
            prompts = [base_prompt] * num_return_sequences

        print({"prompts": prompts})

        images = inpainter.process(
            prompt=prompts,
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            width=task.get_width(),
            height=task.get_height(),
            seed=task.get_seed(),
            # NOTE(review): sized by num_return_sequences, not len(prompts);
            # assumes prompt_modifier.modify returns that many prompts.
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )
        generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
    finally:
        # Previously skipped whenever any step above raised, leaving GPU state
        # from a failed request behind for the next one.
        clear_cuda()

    return {"modified_prompts": prompts, "generated_image_urls": generated_image_urls}
|
|
|
|
@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task's image and upload the result.

    ``upscaler.upscale_anime`` is used when the task's model type is
    ``ModelType.ANIME``. Every other model type uses ``upscaler.upscale``.
    The returned data is wrapped in ``BytesIO`` and uploaded under
    ``crecoAI/<taskId>_upscale.png``.

    Args:
        task: The request being served.

    Returns:
        ``{"generated_image_url": <default S3 URL of the uploaded key>}``.
    """
    s3_key = f"crecoAI/{task.get_taskId()}_upscale.png"

    is_anime = task.get_modelType() == ModelType.ANIME
    print("Using Anime model" if is_anime else "Using Real model")
    run_upscaler = upscaler.upscale_anime if is_anime else upscaler.upscale
    upscaled = run_upscaler(task.get_imageUrl())

    # The upscaler output is treated as raw bytes here (wrapped in BytesIO).
    upload_image(BytesIO(upscaled), s3_key)
    return {"generated_image_url": construct_default_s3_url(s3_key)}
|
|
|
|
def model_fn(model_dir):
    """Load the shared pipelines once, before any prediction is served.

    Calls ``.load()`` on ``prompt_modifier``, ``upscaler`` and ``inpainter``
    in that order. The function/argument names match the SageMaker inference
    toolkit hook convention — presumably that is the caller; confirm.

    Args:
        model_dir: Accepted for the hook signature; not used here.

    Returns:
        None. ``predict_fn`` therefore receives ``None`` as ``pipe``.
    """
    print("Logs: model loaded .... starts")

    # NOTE(review): RemoveBackground is not preloaded here; remove_bg builds
    # its own instance per request.
    for pipeline in (prompt_modifier, upscaler, inpainter):
        pipeline.load()

    print("Logs: model loaded ....")
    return None
|
|
|
|
def predict_fn(data, pipe):
    """Build a Task from the request payload and dispatch it by task type.

    Args:
        data: Raw request payload, passed straight to ``Task``.
        pipe: Whatever ``model_fn`` returned (``None``); not used here.

    Returns:
        The handler's result for REMOVE_BG, INPAINT or UPSCALE_IMAGE tasks.
        Returns ``None`` if routing or the handler raises. In that case the
        error is printed and reported through ``slack.error_alert``.
        Task construction happens outside the guard, so its errors propagate.
    """
    task = Task(data)
    print("task is ", data)

    try:
        task_type = task.get_type()

        # Ordered (type, handler) pairs compared with ==, mirroring an if/elif
        # chain; a dict lookup would rely on hashing instead of equality.
        routes = (
            (TaskType.REMOVE_BG, remove_bg),
            (TaskType.INPAINT, inpaint),
            (TaskType.UPSCALE_IMAGE, upscale_image),
        )
        for expected_type, handler in routes:
            if task_type == expected_type:
                return handler(task)

        raise Exception("Invalid task type")
    except Exception as err:
        print(f"Error: {err}")
        slack.error_alert(task, err)
        return None
|
|