| | |
| | |
| | |
| |
|
| | import os |
| | import torch |
| | from torchvision.utils import save_image |
| | import tempfile |
| | from templates import * |
| | from templates_cls import * |
| | from experiment_classifier import ClsModel |
| | from align import LandmarksDetector, image_align |
| | from cog import BasePredictor, Path, Input, BaseModel |
| |
|
| |
|
# Output schema for a single prediction result: one image file on disk.
# Predictor.predict returns a list of these.
class ModelOutput(BaseModel):
    image: Path  # path of the image file written for this result
| |
|
| |
|
class Predictor(BasePredictor):
    """Cog predictor that aligns a face photo and edits one facial attribute.

    The input image is aligned with a landmark detector, encoded by the
    autoencoder into a semantic and a stochastic latent, shifted along the
    classifier weight vector of the requested attribute, and rendered back
    to an image.
    """

    def setup(self):
        """Load the autoencoder, the latent classifier and the landmark detector.

        Checkpoints are read from fixed relative paths under ``checkpoints/``
        and the landmark model from ``shape_predictor_68_face_landmarks.dat``.
        """
        # Working directory that receives the aligned crop of each request.
        self.aligned_dir = "aligned"
        os.makedirs(self.aligned_dir, exist_ok=True)
        self.device = "cuda:0"

        # Autoencoder: weights are loaded on CPU first, then only the EMA
        # copy is switched to eval mode and moved to the target device.
        model_config = ffhq256_autoenc()
        self.model = LitModel(model_config)
        state = torch.load("checkpoints/ffhq256_autoenc/last.ckpt", map_location="cpu")
        self.model.load_state_dict(state["state_dict"], strict=False)
        self.model.ema_model.eval()
        self.model.ema_model.to(self.device)

        # Attribute classifier operating on the semantic latent.
        classifier_config = ffhq256_autoenc_cls()
        classifier_config.pretrain = None
        self.classifier = ClsModel(classifier_config)
        state_class = torch.load(
            "checkpoints/ffhq256_autoenc_cls/last.ckpt", map_location="cpu"
        )
        print("latent step:", state_class["global_step"])
        self.classifier.load_state_dict(state_class["state_dict"], strict=False)
        self.classifier.to(self.device)

        self.landmarks_detector = LandmarksDetector(
            "shape_predictor_68_face_landmarks.dat"
        )

    def predict(
        self,
        image: Path = Input(
            description="Input image for face manipulation. Image will be aligned and cropped, "
            "output aligned and manipulated images.",
        ),
        target_class: str = Input(
            default="Bangs",
            choices=[
                "5_o_Clock_Shadow",
                "Arched_Eyebrows",
                "Attractive",
                "Bags_Under_Eyes",
                "Bald",
                "Bangs",
                "Big_Lips",
                "Big_Nose",
                "Black_Hair",
                "Blond_Hair",
                "Blurry",
                "Brown_Hair",
                "Bushy_Eyebrows",
                "Chubby",
                "Double_Chin",
                "Eyeglasses",
                "Goatee",
                "Gray_Hair",
                "Heavy_Makeup",
                "High_Cheekbones",
                "Male",
                "Mouth_Slightly_Open",
                "Mustache",
                "Narrow_Eyes",
                "Beard",
                "Oval_Face",
                "Pale_Skin",
                "Pointy_Nose",
                "Receding_Hairline",
                "Rosy_Cheeks",
                "Sideburns",
                "Smiling",
                "Straight_Hair",
                "Wavy_Hair",
                "Wearing_Earrings",
                "Wearing_Hat",
                "Wearing_Lipstick",
                "Wearing_Necklace",
                "Wearing_Necktie",
                "Young",
            ],
            description="Choose manipulation direction.",
        ),
        manipulation_amplitude: float = Input(
            default=0.3,
            ge=-0.5,
            le=0.5,
            description="When set too strong it would result in artifact as it could dominate the original image information.",
        ),
        T_step: int = Input(
            default=100,
            choices=[50, 100, 125, 200, 250, 500],
            description="Number of step for generation.",
        ),
        T_inv: int = Input(default=200, choices=[50, 100, 125, 200, 250, 500]),
    ) -> List[ModelOutput]:
        """Align the input face, edit the chosen attribute and return both images.

        Args:
            image: Path of the input photo.
            target_class: Attribute name selecting the manipulation direction.
            manipulation_amplitude: Signed strength of the latent shift.
            T_step: Value passed as ``T`` to ``self.model.render``.
            T_inv: Value passed as ``T`` to ``self.model.encode_stochastic``.

        Returns:
            Two ``ModelOutput`` items: the aligned original, then the
            manipulated image.

        Raises:
            ValueError: If the landmark detector finds no face in ``image``.
        """
        img_size = 256
        aligned_path = f"{self.aligned_dir}/aligned.png"

        # The aligned directory is reused by every prediction. Drop the crop
        # left by an earlier request so that a failed detection can never
        # silently fall back to somebody else's face.
        if os.path.exists(aligned_path):
            os.remove(aligned_path)

        print("Aligning image...")
        face_found = False
        # Every detected face is written to the same file, so when several
        # faces are present the last one yielded by the detector is used.
        for face_landmarks in self.landmarks_detector.get_landmarks(str(image)):
            image_align(str(image), aligned_path, face_landmarks)
            face_found = True
        if not face_found:
            raise ValueError(
                "No face detected in the input image; cannot align it for manipulation."
            )

        data = ImageDataset(
            self.aligned_dir,
            image_size=img_size,
            exts=["jpg", "jpeg", "JPG", "png"],
            do_augment=False,
        )

        print("Encoding and Manipulating the aligned image...")
        cls_manipulation_amplitude = manipulation_amplitude
        interpreted_target_class = target_class
        # Some attributes are only known to the classifier in their negated
        # "No_<name>" form; in that case flip the sign of the shift instead.
        if (
            target_class not in CelebAttrDataset.id_to_cls
            and f"No_{target_class}" in CelebAttrDataset.id_to_cls
        ):
            cls_manipulation_amplitude = -manipulation_amplitude
            interpreted_target_class = f"No_{target_class}"

        # Decode the aligned image once and reuse it for both the model input
        # and the "original" output (augmentation is disabled above).
        original_img = data[0]["img"]
        batch = original_img[None].to(self.device)

        # Pure inference: no gradients are needed anywhere below.
        with torch.no_grad():
            semantic_latent = self.model.encode(batch)
            stochastic_latent = self.model.encode_stochastic(
                batch, semantic_latent, T=T_inv
            )

            # Unit-length direction taken from the classifier row of the class.
            cls_id = CelebAttrDataset.cls_to_id[interpreted_target_class]
            class_direction = self.classifier.classifier.weight[cls_id]
            normalized_class_direction = F.normalize(class_direction[None, :], dim=1)

            # Shift in the classifier's normalized latent space, then map back.
            # NOTE(review): 512 is presumably the semantic latent width - confirm.
            normalized_semantic_latent = self.classifier.normalize(semantic_latent)
            normalized_manipulation_amp = cls_manipulation_amplitude * math.sqrt(512)
            normalized_manipulated_semantic_latent = (
                normalized_semantic_latent
                + normalized_manipulation_amp * normalized_class_direction
            )
            manipulated_semantic_latent = self.classifier.denormalize(
                normalized_manipulated_semantic_latent
            )

            manipulated_img = self.model.render(
                stochastic_latent, manipulated_semantic_latent, T=T_step
            )[0]

        # Both result files live in one fresh temporary directory.
        out_dir = Path(tempfile.mkdtemp())
        original_path = out_dir / "original_aligned.png"
        manipulated_path = out_dir / "manipulated_img.png"

        save_image(convert2rgb(original_img), str(original_path))
        # The rendered image is saved without rescaling.
        # NOTE(review): presumably render() already returns values in [0, 1] - confirm.
        save_image(
            convert2rgb(manipulated_img, adjust_scale=False), str(manipulated_path)
        )

        return [
            ModelOutput(image=original_path),
            ModelOutput(image=manipulated_path),
        ]
| |
|
| |
|
def convert2rgb(img, adjust_scale=True):
    """Return a detached CPU copy of ``img``, optionally rescaled.

    Args:
        img: A tensor, or any data ``torch.tensor`` accepts (e.g. nested lists).
        adjust_scale: When true, apply ``(x + 1) / 2`` to the copy, which maps
            the range [-1, 1] onto [0, 1].

    Returns:
        A new CPU tensor; the input is never modified.
    """
    if isinstance(img, torch.Tensor):
        # torch.tensor(existing_tensor) emits a copy-construct UserWarning;
        # detach().clone() is the equivalent PyTorch recommends.
        convert_img = img.detach().clone()
    else:
        convert_img = torch.tensor(img)
    if adjust_scale:
        convert_img = (convert_img + 1) / 2
    return convert_img.cpu()
| |
|