| from __future__ import annotations |
| from fal_serverless import isolated, cached |
|
|
| from pathlib import Path |
| import base64 |
| import io |
|
|
# Package names handed to `@isolated(requirements=...)` on `generate`.
requirements = [
    "controlnet-aux",
    "diffusers",
    "torch",
    "mediapipe",
    "transformers",
    "accelerate",
    "xformers",
]
|
|
|
|
def get_image_from_url_as_bytes(url: str, timeout: float = 30.0) -> bytes:
    """Download the resource at ``url`` and return its raw body.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without
            a timeout ``requests`` waits forever, so one stalled host
            would block the caller indefinitely.

    Returns:
        The undecoded response body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # Function-local import, matching how the other helpers in this file
    # pull in their third-party dependencies.
    import requests

    response = requests.get(url, timeout=timeout)
    # Surface 4xx/5xx responses as exceptions instead of returning an
    # error page as if it were image data.
    response.raise_for_status()
    return response.content
|
|
def read_image_bytes(file_path):
    """Return the complete binary contents of the file at ``file_path``."""
    with open(file_path, "rb") as handle:
        return handle.read()
|
|
@cached
def load_model():
    """Build the Canny-ControlNet Stable Diffusion pipeline on ``cuda:0``.

    Loads the ``lllyasviel/sd-controlnet-canny`` ControlNet and the
    ``peterwilli/deliberate-2`` base model in float16, moves the pipeline
    to the first CUDA device and switches the UNet and ControlNet to the
    channels-last memory format.

    NOTE(review): ``@cached`` comes from ``fal_serverless``; presumably it
    keeps the pipeline from being rebuilt on every call. Confirm against
    that library's documentation.

    Returns:
        The ready-to-call ``StableDiffusionControlNetPipeline``.
    """
    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

    half_precision = torch.float16

    canny_controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny", torch_dtype=half_precision
    )
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        "peterwilli/deliberate-2",
        controlnet=canny_controlnet,
        torch_dtype=half_precision,
    ).to("cuda:0")

    # Same order as before: UNet first, then the ControlNet.
    for network in (pipeline.unet, pipeline.controlnet):
        network.to(memory_format=torch.channels_last)

    return pipeline
|
|
|
|
def resize_image(input_image, resolution):
    """Rescale an image so its shorter side is about ``resolution`` pixels.

    The aspect ratio is kept (up to rounding) and each output side is
    snapped to the nearest multiple of 64, with a floor of 64 so a small
    ``resolution`` can never produce a zero-sized image.

    Args:
        input_image: Array whose first two axes are height and width.
            Both 2-D and 3-D arrays are accepted.
        resolution: Target length of the shorter side, in pixels.

    Returns:
        The resized array as produced by ``cv2.resize``.

    Raises:
        ValueError: If ``resolution`` is not positive.
    """
    import cv2
    import numpy as np

    if resolution <= 0:
        raise ValueError(f"resolution must be positive, got {resolution!r}")

    # Only the two leading axes matter. Unpacking exactly three values
    # used to reject 2-D (grayscale) input and bound an unused channel
    # count.
    height, width = input_image.shape[:2]
    scale = float(resolution) / min(height, width)

    def snap_to_64(side):
        # Nearest multiple of 64, but never 0: rounding alone collapses
        # anything below half a step (32 px) to nothing.
        return max(64, int(np.round(side * scale / 64.0)) * 64)

    target_size = (snap_to_64(width), snap_to_64(height))  # cv2 wants (W, H)

    # Lanczos when enlarging, area averaging when shrinking.
    interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    return cv2.resize(input_image, target_size, interpolation=interpolation)
|
|
@isolated(
    requirements=requirements,
    machine_type="GPU",
    keep_alive=30,
    serve=True,
)
def generate(
    image_url: str, prompt: str, num_samples: int, num_steps: int, gcs=False
) -> str:
    """Generate images from a prompt, guided by the Canny edges of a photo.

    The image at ``image_url`` is downloaded, converted to RGB, resized to
    a 512-pixel short side and run through a Canny edge detector. The edge
    map conditions the ControlNet pipeline, which produces ``num_samples``
    images. Every image is written to
    ``/data/cn-results/<uuid>/res_<i>.png``; only the first one is
    returned.

    Args:
        image_url: Address of the conditioning image.
        prompt: Text prompt passed to the pipeline.
        num_samples: Number of images to generate; must be at least 1.
        num_steps: Number of inference steps.
        gcs: Currently unused; kept so existing callers keep working.

    Returns:
        The first generated PNG, base64-encoded as a UTF-8 string.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    import uuid
    from base64 import b64encode

    import numpy as np
    from controlnet_aux import CannyDetector
    from PIL import Image

    # Fail before downloading anything or touching the GPU; previously
    # this only surfaced as an IndexError after all the work was done.
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples!r}")

    image_bytes = get_image_from_url_as_bytes(image_url)
    pipe = load_model()

    # convert() yields an independent RGB copy, so the decoder handle can
    # be released straight away.
    with Image.open(io.BytesIO(image_bytes)) as source_image:
        rgb_image = source_image.convert("RGB")

    resized = resize_image(np.asarray(rgb_image), 512)
    canny = CannyDetector()
    detected_map = canny(resized, 100, 200)
    control_image = Image.fromarray(detected_map)

    negative_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
    results = pipe(
        prompt=prompt,
        image=control_image,
        negative_prompt=negative_prompt,
        num_inference_steps=num_steps,
        num_images_per_prompt=num_samples,
    ).images

    out_dir = Path(f"/data/cn-results/{uuid.uuid4()}")
    out_dir.mkdir(parents=True, exist_ok=True)

    saved_paths = []
    for i, res in enumerate(results):
        target = out_dir / f"res_{i}.png"
        res.save(target)
        saved_paths.append(target)

    # Return the first sample by its known path. Listing the directory
    # gives entries in arbitrary order, so the old code could return any
    # of the samples, and it read every file back just to use one.
    first_image = read_image_bytes(saved_paths[0])
    return b64encode(first_image).decode("utf-8")