| from __future__ import annotations
|
|
|
| import logging
|
|
|
| import torch
|
|
|
| from modules import (
|
| devices,
|
| errors,
|
| face_restoration,
|
| face_restoration_utils,
|
| modelloader,
|
| shared,
|
| )
|
|
|
# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Official CodeFormer v0.1.0 checkpoint URL, and the download name passed to
# modelloader when fetching it.
model_url: str = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_download_name: str = 'codeformer-v0.1.0.pth'

# Populated by setup_model(); stays None until then, or if setup fails.
codeformer: face_restoration.FaceRestoration | None = None
|
|
|
|
|
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
    """CodeFormer face restorer built on ``CommonFaceRestoration``.

    The network is loaded through spandrel (see ``load_net``). Faces are
    processed by handing a per-face callback to the base class's
    ``restore_with_helper`` (see ``restore``).
    """

    def name(self) -> str:
        """Return the display name of this restorer."""
        return "CodeFormer"

    def load_net(self) -> torch.nn.Module:
        """Load the CodeFormer checkpoint and return its torch module.

        ``modelloader.load_models`` is queried with this restorer's
        ``model_path`` (presumably set by the base class constructor — TODO
        confirm), the official release URL and ``.pth`` as the extension
        filter. The first path it yields is loaded with spandrel on
        ``devices.device_codeformer``. The architecture must be
        ``'CodeFormer'``, and the wrapped ``.model`` is returned.

        Raises:
            ValueError: if ``load_models`` yields no model path.
        """
        model_paths = modelloader.load_models(
            model_path=self.model_path,
            model_url=model_url,
            command_path=self.model_path,
            download_name=model_download_name,
            ext_filter=['.pth'],
        )
        for model_path in model_paths:
            # Only the first matching checkpoint is used.
            return modelloader.load_spandrel_model(
                model_path,
                device=devices.device_codeformer,
                expected_architecture='CodeFormer',
            ).model
        raise ValueError("No codeformer model found")

    def get_device(self):
        """Return the device CodeFormer runs on (``devices.device_codeformer``)."""
        return devices.device_codeformer

    def restore(self, np_image, w: float | None = None):
        """Restore faces in ``np_image`` using CodeFormer.

        Args:
            np_image: Image forwarded to ``restore_with_helper`` (presumably a
                numpy array, per the name — TODO confirm against the base
                class).
            w: Value passed as ``weight=`` to the network. When ``None``, it is
                read from ``shared.opts.code_former_weight``, falling back to
                0.5 if that option does not exist.

        Returns:
            Whatever ``self.restore_with_helper`` returns.
        """
        if w is None:
            w = getattr(shared.opts, "code_former_weight", 0.5)

        def restore_face(cropped_face_t):
            # ``self.net`` is presumably populated from ``load_net`` by the base
            # class before this callback runs. The assert documents that and
            # narrows the Optional type.
            assert self.net is not None
            # The network output is indexed and only element 0 is returned
            # (presumably the restored face tensor — TODO confirm).
            return self.net(cropped_face_t, weight=w, adain=True)[0]

        return self.restore_with_helper(np_image, restore_face)
|
|
|
|
|
def setup_model(dirname: str) -> None:
    """Create the CodeFormer restorer and register it in ``shared.face_restorers``.

    On success the instance is also stored in the module-level ``codeformer``
    global. Any exception raised while constructing or registering the
    restorer is passed to ``errors.report`` instead of being propagated.

    Args:
        dirname: Forwarded to the ``FaceRestorerCodeFormer`` constructor
            (presumably the model directory — TODO confirm against the base
            class).
    """
    global codeformer

    try:
        restorer = FaceRestorerCodeFormer(dirname)
        codeformer = restorer
        shared.face_restorers.append(restorer)
    except Exception:
        # Deliberately best-effort: report the failure and return normally.
        errors.report("Error setting up CodeFormer", exc_info=True)
|
|
|