|
|
|
|
| import sys, os, random, numpy as np, torch |
| sys.path.append("../") |
|
|
| from PIL import Image |
| import spaces |
| import gradio as gr |
| from gradio.themes import Soft |
| from huggingface_hub import hf_hub_download |
| from transformers import AutoModelForImageSegmentation |
| from torchvision import transforms |
|
|
| from pipeline import InstantCharacterFluxPipeline |
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Runtime constants.
# ---------------------------------------------------------------------------
MAX_SEED = np.iinfo(np.int32).max  # largest seed the UI will hand out (int32 max)
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): `dtype` is never used below — the pipeline is loaded with
# torch.bfloat16 explicitly; confirm whether this constant is still needed.
dtype = torch.float16 if device == "cuda" else torch.float32
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Model checkpoints. hf_hub_download fetches (and caches) single files from
# the Hugging Face Hub; the plain strings are repo ids resolved later by the
# respective from_pretrained calls.
# ---------------------------------------------------------------------------
ip_adapter_path = hf_hub_download("tencent/InstantCharacter",
                                  "instantcharacter_ip-adapter.bin")
base_model = "black-forest-labs/FLUX.1-dev"
image_encoder_path = "google/siglip-so400m-patch14-384"  # subject image encoder 1
image_encoder2_path = "facebook/dinov2-giant"            # subject image encoder 2
birefnet_path = "ZhengPeng7/BiRefNet"                    # background-matting model
makoto_style_path = hf_hub_download("InstantX/FLUX.1-dev-LoRA-Makoto-Shinkai",
                                    "Makoto_Shinkai_style.safetensors")
ghibli_style_path = hf_hub_download("InstantX/FLUX.1-dev-LoRA-Ghibli",
                                    "ghibli_style.safetensors")
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# InstantCharacter FLUX pipeline: load the base model, move it to the target
# device, then attach the subject IP-adapter with its two image encoders.
# ---------------------------------------------------------------------------
pipe = InstantCharacterFluxPipeline.from_pretrained(base_model,
                                                    torch_dtype=torch.bfloat16)
pipe.to(device)
pipe.init_adapter(
    image_encoder_path=image_encoder_path,
    image_encoder_2_path=image_encoder2_path,
    # nb_token: number of subject tokens the adapter injects — TODO confirm
    # against the pipeline's init_adapter documentation.
    subject_ipadapter_cfg=dict(subject_ip_adapter_path=ip_adapter_path,
                               nb_token=1024),
)
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# BiRefNet matting model plus its preprocessing: the model consumes a fixed
# 1024x1024 input normalized with ImageNet statistics.
# ---------------------------------------------------------------------------
birefnet = AutoModelForImageSegmentation.from_pretrained(birefnet_path,
                                                         trust_remote_code=True)
birefnet.to(device).eval()
birefnet_tf = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],   # ImageNet mean
                         [0.229, 0.224, 0.225]),  # ImageNet std
])
|
|
| |
| |
| |
def randomize_seed_fn(seed: int, randomize: bool) -> int:
    """Return a fresh random seed when *randomize* is set, else pass *seed* through."""
    if randomize:
        return random.randint(0, MAX_SEED)
    return seed
|
|
def _infer_matting(img_pil):
    """Run BiRefNet on a PIL image and return its foreground matte.

    The matte is uint8 in [0, 255] at the model's fixed working resolution
    (1024x1024, set by the preprocessing transform), not the input size.
    """
    batch = birefnet_tf(img_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = birefnet(batch)[-1].sigmoid()
    matte = pred.cpu()[0, 0].numpy()
    return (matte * 255).astype(np.uint8)
|
|
| def _bbox_from_mask(mask, th=128): |
| ys, xs = np.where(mask >= th) |
| if not len(xs): |
| return [0, 0, mask.shape[1]-1, mask.shape[0]-1] |
| return [xs.min(), ys.min(), xs.max(), ys.max()] |
|
|
| def _pad_square(arr, pad_val=255): |
| h, w = arr.shape[:2] |
| if h == w: |
| return arr |
| diff = abs(h - w) |
| pad_1 = diff // 2 |
| pad_2 = diff - pad_1 |
| if h > w: |
| pad = ((0, 0), (pad_1, pad_2), (0, 0)) |
| else: |
| pad = ((pad_1, pad_2), (0, 0), (0, 0)) |
| return np.pad(arr, pad, constant_values=pad_val) |
|
|
def remove_bkg(img_pil: Image.Image) -> Image.Image:
    """Isolate the subject: white out the background, crop, pad to a square.

    Bug fix: `_infer_matting` produces its matte at the model's fixed
    1024x1024 working resolution (its transform resizes the input), while
    the original code broadcast that matte directly against the image at
    its *original* resolution — crashing (or silently misaligning) for any
    non-1024x1024 upload.  The matte is now resized back to the image size
    first.  The input is also normalized to RGB so RGBA/greyscale uploads
    don't break the per-channel math.

    Args:
        img_pil: source image containing the subject.

    Returns:
        RGB PIL image: subject on white, cropped to its bounding box and
        padded to a square canvas.
    """
    img_pil = img_pil.convert("RGB")
    mask = _infer_matting(img_pil)
    # Bring the matte back to the image's own resolution before masking.
    if mask.shape[:2] != (img_pil.height, img_pil.width):
        mask = np.array(Image.fromarray(mask).resize(img_pil.size, Image.BILINEAR))
    x1, y1, x2, y2 = _bbox_from_mask(mask)
    mask_bin = (mask >= 128).astype(np.uint8)[..., None]
    img_np = np.array(img_pil)
    # Foreground keeps its pixels; background becomes pure white.
    obj = mask_bin * img_np + (1 - mask_bin) * 255
    crop = obj[y1:y2 + 1, x1:x2 + 1]
    return Image.fromarray(_pad_square(crop).astype(np.uint8))
|
|
def get_example():
    """Rows for the gr.Examples widget: [image path, prompt, scale, style]."""
    samples = [
        ("./assets/girl.jpg", "A girl is playing a guitar in street"),
        ("./assets/boy.jpg", "A boy is riding a bike in snow"),
    ]
    return [[path, text, 0.9, "Makoto Shinkai style"] for path, text in samples]
|
|
@spaces.GPU
def create_image(input_image, prompt, scale,
                 guidance_scale, num_inference_steps,
                 seed, style_mode):
    """Run the InstantCharacter pipeline on a subject image.

    Args:
        input_image: PIL image of the character; its background is removed
            before conditioning.
        prompt: text prompt describing the target scene.
        scale: subject (IP-adapter) strength.
        guidance_scale: classifier-free guidance weight.
        num_inference_steps: number of diffusion steps.
        seed: RNG seed for reproducible sampling.
        style_mode: "None", "Makoto Shinkai style" or "Ghibli style".
            NOTE: the Gradio dropdown supplies the *string* "None".

    Returns:
        List of generated PIL images.
    """
    input_image = remove_bkg(input_image)
    # torch.manual_seed seeds the global RNG and returns the default
    # generator, which is passed to the pipeline for reproducibility.
    gen = torch.manual_seed(seed)

    # Arguments shared by the plain and LoRA-styled generation paths.
    common = dict(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        width=1024, height=1024,
        subject_image=input_image, subject_scale=scale,
        generator=gen,
    )

    # Bug fix: the dropdown yields the string "None", never Python None, so
    # the original `style_mode is None` check made this branch unreachable
    # from the UI.  Accept both.
    if style_mode is None or style_mode == "None":
        imgs = pipe(**common).images
    else:
        lora_path, trigger = (
            (makoto_style_path, "Makoto Shinkai style")
            if style_mode == "Makoto Shinkai style"
            else (ghibli_style_path, "ghibli style")
        )
        imgs = pipe.with_style_lora(
            lora_file_path=lora_path, trigger=trigger, **common).images
    return imgs
|
|
def run_for_examples(src, p, s, st):
    """Adapter for gr.Examples: runs create_image with fixed advanced settings."""
    return create_image(
        src, p, s,
        guidance_scale=3.5,
        num_inference_steps=28,
        seed=123456,
        style_mode=st,
    )
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# UI theming: Gradio Soft theme plus custom CSS (gradient background,
# frosted-glass cards, hidden footer).
# ---------------------------------------------------------------------------
theme = Soft(primary_hue="pink",
             font=[gr.themes.GoogleFont("Inter")])

css = """
body{
background:#141e30;
background:linear-gradient(135deg,#141e30,#243b55);
}
#title{
text-align:center;
font-size:2.2rem;
font-weight:700;
color:#ffffff;
padding:20px 0 6px;
}
.card{
border-radius:18px;
background:#ffffff0d;
padding:18px 22px;
backdrop-filter:blur(6px);
}
.gr-image,.gr-video{border-radius:14px}
.gr-image:hover{box-shadow:0 0 0 4px #ec4899}
footer{visibility:hidden}
"""
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Gradio UI layout and event wiring.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css, theme=theme) as demo:

    gr.Markdown("<div id='title'>InstantCharacter PLUS</div>")
    gr.Markdown(
        "<b>Official 🤗 Gradio demo of "
        "<a href='https://instantcharacter.github.io/' target='_blank'>InstantCharacter</a></b>"
    )

    with gr.Tabs():
        with gr.TabItem("Generate"):
            with gr.Row(equal_height=True):

                # Left card: source image, prompt and generation controls.
                with gr.Column(elem_classes="card"):
                    image_pil = gr.Image(label="Source Image",
                                         type="pil", height=380)
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="A character is riding a bike in snow",
                        lines=2,
                    )
                    scale = gr.Slider(0, 1.5, 1.0, step=0.01, label="Scale")
                    # NOTE(review): this dropdown yields the *string* "None",
                    # not Python None — downstream code must treat them alike.
                    style_mode = gr.Dropdown(
                        ["None", "Makoto Shinkai style", "Ghibli style"],
                        label="Style",
                        value="Makoto Shinkai style",
                    )

                    with gr.Accordion("⚙️ Advanced Options", open=False):
                        guidance_scale = gr.Slider(
                            1, 7, 3.5, step=0.01, label="Guidance scale"
                        )
                        num_inference_steps = gr.Slider(
                            5, 50, 28, step=1, label="# Inference steps"
                        )
                        seed = gr.Number(123456, label="Seed", precision=0)
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed", value=True
                        )

                    generate_btn = gr.Button(
                        "🚀 Generate",
                        variant="primary",
                        size="lg",
                        elem_classes="contrast",
                    )

                # Right card: output gallery.
                with gr.Column(elem_classes="card"):
                    generated_image = gr.Gallery(
                        label="Generated Image",
                        show_label=True,
                        height="auto",
                        columns=[1],
                    )

    # Two-step click chain: first (optionally) re-roll the seed so the value
    # shown in the UI matches the one actually used, then run generation.
    generate_btn.click(
        randomize_seed_fn,
        [seed, randomize_seed],
        seed,
        queue=False,
    ).then(
        create_image,
        [
            image_pil,
            prompt,
            scale,
            guidance_scale,
            num_inference_steps,
            seed,
            style_mode,
        ],
        generated_image,
    )

    gr.Markdown("### 🔥 Quick Examples")
    gr.Examples(
        examples=get_example(),
        inputs=[image_pil, prompt, scale, style_mode],
        outputs=generated_image,
        fn=run_for_examples,
        cache_examples=True,
    )
|
|
| |
| |
| |
if __name__ == "__main__":
    # Cap the request queue at 10 pending jobs; api_open=False keeps the
    # queue's REST endpoint closed to direct outside calls.
    demo.queue(max_size=10, api_open=False).launch()
|
|