| | import subprocess |
| | import sys |
# Force a clean reinstall of the pinned mmcv-full build before anything that
# depends on it is imported. Each pip command runs in the current interpreter
# (sys.executable) and raises CalledProcessError on a non-zero exit status.
print("Reinstalling mmcv")
for pip_args in (
    ("uninstall", "-y", "mmcv-full==1.3.17"),
    (
        "install",
        "mmcv-full==1.3.17",
        # -f adds an extra location where pip looks for matching archives.
        "-f",
        "https://download.openmmlab.com/mmcv/dist/cpu/torch1.10.0/index.html",
    ),
):
    subprocess.check_call([sys.executable, "-m", "pip", *pip_args])
print("mmcv install complete")
| |
|
| | |
| |
|
| | from gradio.outputs import Label |
| | from icevision.all import * |
| | from icevision.models.checkpoint import * |
| | import PIL |
| | import gradio as gr |
| | import os |
| |
|
| | |
# Restore the trained detector and the metadata stored alongside it.
checkpoint_path = "models/model_checkpoint.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)

# Pull the pieces needed for inference out of the checkpoint dictionary.
model, model_type, class_map, img_size = (
    checkpoint_and_model[key]
    for key in ("model", "model_type", "class_map", "img_size")
)

# Inference-time transforms: resize/pad to the checkpoint's image size, then
# normalize.
valid_tfms = tfms.A.Adapter(
    [
        *tfms.A.resize_and_pad(img_size),
        tfms.A.Normalize(),
    ]
)
| |
|
# Directory holding the bundled sample images. The trailing slash is kept so
# the generated paths look exactly like "sample_images/<name>".
SAMPLE_IMAGE_DIR = "sample_images/"

# (display_label, display_bbox, detection_threshold) settings applied, in
# order, to the sample images shown as clickable examples.
EXAMPLE_PRESETS = [
    (False, True, 0.5),
    (True, True, 0.5),
    (False, True, 0.7),
    (True, True, 0.7),
    (False, True, 0.5),
    (False, True, 0.5),
    (False, True, 0.6),
    (False, True, 0.6),
]


def list_sample_images(sample_dir):
    """Return one ``[path]`` row per regular file directly inside *sample_dir*.

    Names are sorted so the result does not depend on the order in which the
    filesystem happens to list entries. Subdirectories are skipped. A missing
    directory (or a path that is not a directory) yields an empty list instead
    of raising.
    """
    try:
        names = sorted(os.listdir(sample_dir))
    except (FileNotFoundError, NotADirectoryError):
        return []
    paths = (os.path.join(sample_dir, name) for name in names)
    return [[path] for path in paths if os.path.isfile(path)]


def build_examples(image_rows, presets):
    """Pair each image row with its preset, stopping at the shorter input.

    Each result row is ``[image_row, display_label, display_bbox, threshold]``,
    matching the order of the interface inputs. Using ``zip`` means fewer
    images than presets simply produces fewer examples rather than an
    IndexError; images beyond the last preset are not used.
    """
    return [[row, *preset] for row, preset in zip(image_rows, presets)]


example_images = list_sample_images(SAMPLE_IMAGE_DIR)
for (sample_path,) in example_images:
    print("Loading sample image:", os.path.basename(sample_path))

# Fall back to None (no examples) when no sample image is available, so the
# app can still start without the sample directory.
examples = build_examples(example_images, EXAMPLE_PRESETS) or None
| |
|
| |
|
def show_preds(input_image, display_label, display_bbox, detection_threshold):
    """Run detection on an image and return the rendered result and box count.

    Args:
        input_image: Pixel data handed to ``PIL.Image.fromarray`` and read as
            RGB (presumably the array supplied by the image input — confirm
            against the interface configuration).
        display_label: Forwarded to ``end2end_detect`` as ``display_label``.
        display_bbox: Forwarded to ``end2end_detect`` as ``display_bbox``.
        detection_threshold: Score threshold forwarded to ``end2end_detect``;
            a value equal to 0 is replaced by 0.5.

    Returns:
        A ``(image, count)`` tuple: the ``"img"`` entry of the prediction
        dictionary and the number of entries in its detected ``"bboxes"``.
    """
    # A threshold of exactly 0 is swapped for 0.5.
    threshold = 0.5 if detection_threshold == 0 else detection_threshold

    pil_image = PIL.Image.fromarray(input_image, "RGB")

    render_options = {
        "class_map": class_map,
        "detection_threshold": threshold,
        "display_label": display_label,
        "display_bbox": display_bbox,
        "return_img": True,
        "font_size": 16,
        "label_color": "#FF59D6",
    }
    result = model_type.end2end_detect(
        pil_image, valid_tfms, model, **render_options
    )

    annotated_image = result["img"]
    box_count = len(result["detection"]["bboxes"])
    return annotated_image, box_count
| |
|
| |
|
| | |
# Input widgets, in addition to the image itself.
# Checkbox controlling whether labels are drawn (unchecked by default).
display_chkbox_label = gr.inputs.Checkbox(default=False, label="Label")
# Checkbox controlling whether boxes are drawn (checked by default).
display_chkbox_box = gr.inputs.Checkbox(default=True, label="Box")
# Threshold slider covering 0..1 in steps of 0.1, starting at 0.5.
detection_threshold_slider = gr.inputs.Slider(
    label="Detection Threshold",
    minimum=0,
    maximum=1,
    step=0.1,
    default=0.5,
)

# Output widgets, in the order of the values returned by show_preds:
# the rendered image first, then the count.
outputs = [
    gr.outputs.Image(label="RetinaNet Inference", type="pil"),
    gr.outputs.Textbox(label="Microalgae Count", type="number"),
]
| |
|
# HTML snippet passed to the interface as its article text (a centered link).
article = (
    "<p style='text-align: center'>"
    "<a href='https://dicksonneoh.com/' target='_blank'>Blog post</a>"
    "</p>"
)

# Wire the prediction function to its widgets. The inputs list follows the
# parameter order of show_preds: image, label toggle, box toggle, threshold.
gr_interface = gr.Interface(
    title="Microalgae Detector with RetinaNet",
    description="This RetinaNet model counts microalgaes on a given image. Upload an image or click an example image below to use.",
    article=article,
    fn=show_preds,
    inputs=[
        "image",
        display_chkbox_label,
        display_chkbox_box,
        detection_threshold_slider,
    ],
    outputs=outputs,
    examples=examples,
)

# Start the app; runs as soon as this script is executed.
gr_interface.launch(inline=False, share=False, debug=True)