| | from IPython.display import display, JSON |
| | import matplotlib.pyplot as plt |
| | from speciesnet import DEFAULT_MODEL, SUPPORTED_MODELS, SpeciesNet |
| | import numpy as np |
| | import time |
| | import gradio as gr |
| | import json |
| | import cv2 |
| | import os |
| |
|
| |
|
| | |
| | |
| | |
# Report which SpeciesNet model variants are available, then load the
# default one once at import time so every request reuses the same instance.
print(f"Default SpeciesNet model: {DEFAULT_MODEL!s}")
print(f"Supported SpeciesNet models: {SUPPORTED_MODELS!s}")
model = SpeciesNet(DEFAULT_MODEL)
| |
|
| |
|
| |
|
| | |
| | |
| | |
def validate_predictions_structure(pred):
    """
    Validate the structure of a single prediction entry.

    Checks that ``pred`` carries the "filepath", "detections" and
    "classifications" keys, that every detection is a dictionary with a
    four-element "bbox" plus "conf" and "label", and that the
    classification "classes" and "scores" collections have equal length.

    Args:
        pred: One prediction dictionary.

    Returns:
        True when every check passes.

    Raises:
        ValueError: If any part of the structure is malformed. Wrongly
            typed values (e.g. a ``None`` bbox or a non-dict detection)
            are reported as ValueError as well, rather than leaking a
            TypeError from ``len()`` or ``in``.
    """
    if not isinstance(pred, dict):
        raise ValueError(" prediction block must be a dictionary")

    for key in ("filepath", "detections", "classifications"):
        if key not in pred:
            raise ValueError(f" Missing key '{key}' in prediction block")

    # --- detections -------------------------------------------------------
    detections = pred["detections"]
    if not isinstance(detections, list):
        raise ValueError(" detections must be a list")

    for det in detections:
        # A string entry would satisfy the `in` checks by substring match and
        # then fail on indexing, so require a real dictionary up front.
        if not isinstance(det, dict) or not all(
            k in det for k in ("bbox", "conf", "label")
        ):
            raise ValueError(" Each detection must contain bbox, conf, label")

        bbox = det["bbox"]
        if (
            isinstance(bbox, (str, bytes))
            or not hasattr(bbox, "__len__")
            or len(bbox) != 4
        ):
            raise ValueError(" bbox must be [x, y, w, h]")

    # --- classifications --------------------------------------------------
    cls = pred["classifications"]
    if not isinstance(cls, dict):
        raise ValueError(" classifications must be a dictionary")

    for key in ("classes", "scores"):
        if key not in cls:
            raise ValueError(f" classifications missing '{key}'")

    classes, scores = cls["classes"], cls["scores"]
    if not hasattr(classes, "__len__") or not hasattr(scores, "__len__"):
        raise ValueError(" classes and scores must be sequences")

    if len(classes) != len(scores):
        raise ValueError(" classes and scores length mismatch")

    return True
| |
|
| |
|
| |
|
def validate_model_output(predictions_dict):
    """
    Check the complete SpeciesNet output before it is visualized.

    The output must expose a top-level "predictions" list; each entry of
    that list is then passed to ``validate_predictions_structure``.
    Progress is reported on stdout.

    Raises:
        ValueError: If "predictions" is missing or is not a list, or if an
            individual entry fails its structural validation.
    """
    if "predictions" not in predictions_dict:
        raise ValueError(" Output missing top-level 'predictions' key")

    entries = predictions_dict["predictions"]
    if not isinstance(entries, list):
        raise ValueError(" 'predictions' must be a list")

    print(f" Total prediction entries: {len(entries)}")

    # Entries are numbered from 1 in the progress output.
    for position, entry in enumerate(entries, start=1):
        print(f"\n--- Checking prediction #{position} ---")
        validate_predictions_structure(entry)

    print("\n Output format validated successfully!\n")
| |
|
| |
|
| |
|
| | |
| | |
| | |
def draw_predictions(image_path, predictions_dict):
    """
    Draw detection boxes and labels onto the image at ``image_path``.

    Only predictions whose first classification string contains "mammalia"
    or "aves" are drawn; all others are skipped. Each detection's bbox is
    read as ``[x, y, w, h]`` and scaled by the image width/height, i.e. the
    values are treated as fractions of the image size. Above every box a
    filled label shows the classification (last ``;``-separated field of the
    first class, with its score when one is available) and the detection
    label with its confidence.

    Args:
        image_path: Path of the image file to annotate.
        predictions_dict: Model output holding a "predictions" list.

    Returns:
        The annotated image converted from BGR to RGB channel order.

    Raises:
        ValueError: If the image cannot be loaded.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not load image: {image_path}")

    img_h, img_w, _ = img.shape

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    font_thickness = 2
    box_color = (0, 255, 0)
    text_color = (0, 0, 0)
    line_gap = 5  # vertical spacing added below each text line
    label_pad = 10  # extra height/width of the filled label background

    for pred in predictions_dict.get("predictions", []):
        classifications = pred.get("classifications", {})
        classes = classifications.get("classes", [])
        scores = classifications.get("scores", [])

        if not classes:
            continue

        # Restrict drawing to mammals and birds.
        taxon = classes[0].lower()
        if "mammalia" not in taxon and "aves" not in taxon:
            continue

        # The classification text is the same for every detection of this
        # prediction, so build it once. A missing score no longer raises.
        top_class_name = classes[0].split(";")[-1]
        if not top_class_name:
            classification_text = ""
        elif scores and scores[0] is not None:
            classification_text = f"{top_class_name} ({scores[0]:.2f})"
        else:
            classification_text = top_class_name

        for det in pred.get("detections", []):
            label = det["label"]
            conf = det["conf"]
            x, y, w, h = det["bbox"]

            x1 = int(x * img_w)
            y1 = int(y * img_h)
            x2 = int((x + w) * img_w)
            y2 = int((y + h) * img_h)

            cv2.rectangle(img, (x1, y1), (x2, y2), box_color, 3)

            text_lines = [f"{label} ({conf:.2f})"]
            if classification_text:
                text_lines.insert(0, classification_text)

            # Measure each line once; reused for background and placement.
            sizes = [
                cv2.getTextSize(line, font, font_scale, font_thickness)[0]
                for line in text_lines
            ]
            total_text_height = sum(text_h + line_gap for _, text_h in sizes)
            max_text_width = max(text_w for text_w, _ in sizes)

            # The label normally sits directly above the box. When the box
            # is too close to the top edge, slide the whole label down so
            # the text stays inside the image instead of being drawn at a
            # negative y coordinate.
            label_bottom = max(y1, total_text_height + label_pad)
            label_top = label_bottom - total_text_height - label_pad

            cv2.rectangle(
                img,
                (x1, label_top),
                (x1 + max_text_width + label_pad, label_bottom),
                box_color,
                -1,
            )

            # putText anchors text at its bottom-left corner, so stack the
            # lines from the bottom of the label upwards.
            y_text = label_bottom - line_gap
            for line, (_, text_h) in zip(reversed(text_lines), reversed(sizes)):
                cv2.putText(
                    img,
                    line,
                    (x1 + 5, y_text),
                    font,
                    font_scale,
                    text_color,
                    font_thickness,
                    cv2.LINE_AA,
                )
                y_text -= text_h + line_gap

    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
| |
|
| |
|
| |
|
| | |
| | |
| | |
def inference(image):
    """
    Run SpeciesNet on an uploaded image and return the annotated result.

    The image is written to ``temp_image.jpg``, passed to the module-level
    ``model``, the output is validated, saved to ``last_output.json`` and
    finally drawn onto the image.

    Args:
        image: The uploaded image as a PIL image, or ``None`` when nothing
            was uploaded.

    Returns:
        A tuple ``(annotated_image, pretty_json)``: the annotated RGB image
        and the model output serialised as an indented JSON string.

    Raises:
        ValueError: If no image was provided, or if the model output fails
            validation.
    """
    # Submitting the form without an upload passes None; fail with a clear
    # message instead of an AttributeError on `.save`.
    if image is None:
        raise ValueError("No image provided. Please upload an image first.")

    # JPEG cannot store alpha or palette modes (e.g. RGBA/P from a PNG
    # upload), so normalise those to RGB before saving.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    filepath = "temp_image.jpg"
    image.save(filepath)

    # perf_counter is monotonic, unlike time.time, so the measured duration
    # cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    predictions_dict = model.predict(
        instances_dict={
            "instances": [
                {
                    "filepath": filepath,
                }
            ]
        }
    )
    end = time.perf_counter()

    print(f"\n⏱ Inference Time: {end - start:.2f} sec")

    validate_model_output(predictions_dict)

    # Serialise once and reuse the text for both the file and the UI.
    pretty_json = json.dumps(predictions_dict, indent=4)
    with open("last_output.json", "w", encoding="utf-8") as f:
        f.write(pretty_json)

    print(" Saved JSON to last_output.json\n")

    annotated_image = draw_predictions(filepath, predictions_dict)

    return annotated_image, pretty_json
| |
|
| |
|
| |
|
| | |
| | |
| | |
# Gradio UI: one PIL image in; the annotated image and raw model JSON out.
_image_input = gr.Image(type="pil")
_result_outputs = [
    gr.Image(label="Detection + Classification Output"),
    gr.JSON(label="Raw Model Output"),
]

iface = gr.Interface(
    fn=inference,
    inputs=_image_input,
    outputs=_result_outputs,
    title=" SpeciesNet Wildlife Detector + Classifier",
    description="Upload a wildlife camera image.",
)

# Start the web server for the interface.
iface.launch()
| |
|