Spaces:
Build error
Build error
| import gradio as gr | |
| import os | |
| import sys | |
| import numpy as np | |
| import numpy as np | |
| import torch.backends.cudnn as cudnn | |
| import torch.utils.data | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from mmcv.utils import Config | |
| sys.path.append('.') | |
| from image_forgery_detection import build_detector | |
| from image_forgery_detection import Compose | |
# Converts a tensor/ndarray to a PIL image; inference_api uses it on the predicted
# segmentation map before turning it back into a NumPy array for thresholding.
transform_pil = transforms.Compose([transforms.ToPILImage()])
def _build_pipeline_input(f_path):
    """Return the initial ``results`` dict that is fed to ``pipelines`` for one image file."""
    # NOTE(review): seg_map is the *string* 'None' (kept as-is); presumably a
    # placeholder because no ground-truth mask exists at inference — confirm.
    return dict(
        img_info=dict(filename=f_path, ann=dict(seg_map='None')),
        seg_fields=[],
        img_prefix=None,
        seg_prefix=None,
    )


def _binarize_mask(seg, threshold):
    """Map every value of *seg* to 255 if it is >= ``255 * threshold``, else to 0.

    The returned array keeps the dtype of *seg*.
    """
    thresh_int = 255 * threshold
    return np.where(seg >= thresh_int, 255, 0).astype(seg.dtype)


def inference_api(f_path):
    """Run forgery detection on a single image file.

    Relies on the module-level names ``pipelines``, ``model``, ``thresh`` and
    ``transform_pil``; the first three are assigned in this script's
    ``__main__`` section, so the function only works when the file is run as
    a script.

    Args:
        f_path: Path of the image to analyse (the Gradio input uses
            ``type="filepath"``).

    Returns:
        A tuple ``(score, mask)``: ``score`` is ``cls_pred[0]`` formatted with
        three decimals, ``mask`` is the first predicted segmentation map
        binarized to values 0/255 using ``thresh``.

    Raises:
        ValueError: if ``f_path`` is None (no image was supplied).
    """
    print(f_path)
    if f_path is None:
        # Presumably what Gradio passes when the user submits without an upload;
        # fail with a clear message instead of an obscure pipeline error.
        raise ValueError('No image provided: please upload an image to detect.')

    inputs = pipelines(_build_pipeline_input(f_path))
    img_metas = [inputs['img_metas'].data]
    # Some model configs also consume DCT volumes and quantization tables.
    use_dct = 'dct_vol' in inputs

    with torch.no_grad():
        # Add a leading dimension of size 1 (a batch of one image, matching the
        # single-element img_metas list).
        img = inputs['img'].data.unsqueeze(dim=0)
        if use_dct:
            dct_vol = inputs['dct_vol'].data.unsqueeze(dim=0)
            qtables = inputs['qtables'].data.unsqueeze(dim=0)
            cls_pred, seg_pred = model(img, dct_vol, qtables, img_metas,
                                       return_loss=False, rescale=True)
        else:
            cls_pred, seg_pred = model(img, img_metas,
                                       return_loss=False, rescale=True)

    seg = np.array(transform_pil(torch.from_numpy(seg_pred[0, 0])))
    return '{:.3f}'.format(cls_pred[0]), _binarize_mask(seg, thresh)
if __name__ == '__main__':
    # Script entry point: load config and weights, build the test pipeline and
    # launch the Gradio demo. model, pipelines and thresh are assigned at module
    # level here so inference_api can read them.
    model_path = './models/latest.pth'
    cfg = Config.fromfile('./models/config.py')
    # Mask threshold; inference_api scales it by 255 before comparing.
    thresh = 0.5

    # Fail fast, before building the network, if the weights are missing.
    # sys.exit(str) prints the message to stderr and exits with status 1
    # (the bare exit() builtin comes from the site module and may be absent).
    if not os.path.exists(model_path):
        sys.exit("%s not exist" % model_path)

    # Clear the pretrained-weights reference; presumably so building the detector
    # does not fetch separate backbone weights, since the checkpoint below
    # supplies all parameters (strict=True) — confirm.
    if hasattr(cfg.model.base_model, 'backbone'):
        cfg.model.base_model.backbone.pretrained = None
    else:
        cfg.model.base_model.pretrained = None
    model = build_detector(cfg.model)

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location='cpu')['state_dict']
    model.load_state_dict(checkpoint, strict=True)
    print("load %s finish" % (os.path.basename(model_path)))
    model.eval()

    pipelines = Compose(cfg.data.val[0].pipeline)
    iface = gr.Interface(
        inference_api,
        inputs=gr.components.Image(label="Upload image to detect", type="filepath"),
        outputs=[gr.components.Textbox(type="text", label="image forgery score"),
                 gr.components.Image(type="numpy", label="predict mask")],
        title="Forged? Or Not?",
    )
    # To expose the demo beyond localhost: iface.launch(server_name='0.0.0.0', share=True)
    iface.launch()