import sys
import gradio as gr
from model import load_model, predict_mask
from utils import (
    preprocess_image_pil,
    postprocess_mask,
    overlay_mask,
    generate_gradcam,
    overlay_gradcam,
)
from PIL import Image, ImageDraw
import os
import logging
import torch

# === Logging Setup ===
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# === Model Setup ===
MODEL_PATH = "best_model_large_data.pth"


def initialize_model():
    """Load the segmentation model stored at MODEL_PATH onto the CPU.

    Returns:
        The object returned by ``model.load_model(MODEL_PATH, device="cpu")``.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
        Exception: any error from ``load_model`` is logged at CRITICAL level
            and re-raised unchanged.
    """
    try:
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
        logger.info("⏳ Loading model...")
        model = load_model(MODEL_PATH, device="cpu")
        logger.info("✅ Model loaded successfully")
        return model
    except Exception as e:
        logger.critical(f"❌ Model load failed: {str(e)}")
        raise


class DummyModel:
    """Placeholder installed when the real model fails to load.

    Any call raises, so segmentation requests fail with a clear message
    while the UI itself still starts.
    """

    def __call__(self, *args, **kwargs):
        raise RuntimeError("Model failed to load")


# Initialize model with error fallback: keep the app launchable even if the
# checkpoint is missing or broken; requests then fail inside segment_brain.
try:
    model = initialize_model()
except Exception as e:
    logger.error(f"🚨 Failed to initialize model: {e}")
    model = DummyModel()


# Severity label -> emoji prefix used in the stats summary.
SEVERITY_EMOJI = {
    "High": "🔴",
    "Medium": "🟠",
    "Low": "🟡",
    "Very Low (but tumor-like region exists)": "🟢",
    "No detectable tumor": "✅",
}

# Severity label -> clinical recommendation. Looked up by exact key so that
# "Very Low (...)" is not mistaken for "Low" (a substring match would be).
SEVERITY_RECOMMENDATION = {
    "High": "🔴 Significant tumor-like region detected. Please consult a radiologist immediately.",
    "Medium": "🟠 Suspicious tumor region detected. Radiologist review recommended.",
    "Low": "🟡 Possible small abnormality detected. Review advised.",
    "Very Low (but tumor-like region exists)": "🟢 Minor tumor-like pattern seen. May be early stage or low confidence — recommend radiologist review.",
    "No detectable tumor": "✅ No visible tumor region detected.",
}


def _classify_severity(tumor_area, tumor_mean):
    """Return the severity label for a mask.

    Args:
        tumor_area: number of mask elements whose value exceeds 0.5.
        tumor_mean: sum of mask values divided by the number of elements.

    Returns:
        One of the keys of SEVERITY_EMOJI / SEVERITY_RECOMMENDATION.
    """
    if tumor_area > 200 and tumor_mean > 0.15:
        return "High"
    if tumor_area > 200 and tumor_mean > 0.05:
        return "Medium"
    if tumor_area > 100 and tumor_mean > 0.01:
        return "Low"
    if tumor_area > 50:
        return "Very Low (but tumor-like region exists)"
    return "No detectable tumor"


def _error_panel(message):
    """Return a 256x256 black RGB image with *message* drawn in red."""
    error_img = Image.new("RGB", (256, 256), color="black")
    draw = ImageDraw.Draw(error_img)
    draw.text((10, 10), message, fill="red")
    return error_img


def segment_brain(image):
    """Process MRI image with Grad-CAM and calculate tumor density with severity + recommendation.

    Args:
        image: the uploaded PIL image from ``gr.Image(type="pil")``, or None
            when the upload has been cleared.

    Returns:
        A 4-item list matching the UI outputs:
        ``[input image, Grad-CAM overlay, stats text, recommendation text]``.
        On failure the overlay is a red "PROCESSING FAILED" panel and both
        text fields are "Error"; when *image* is None all outputs are reset.
    """
    if image is None:
        # `.change` also fires when the user clears the upload; reset the
        # outputs instead of reporting a spurious processing failure.
        return [None, None, "", ""]

    try:
        logger.info("Starting new segmentation")

        # Step 1: Preprocess image
        image_tensor, orig_image = preprocess_image_pil(image)
        if image_tensor is None:
            raise ValueError("Preprocessing returned None")
        # Grad-CAM below needs gradients with respect to the input.
        image_tensor.requires_grad_(True)

        # Step 2: Predict tumor mask
        logger.info("Running segmentation...")
        mask_tensor = predict_mask(model, image_tensor)

        # Step 3: Calculate density metrics
        num_elements = torch.numel(mask_tensor)
        if num_elements == 0:
            raise ValueError("Model returned an empty mask")
        tumor_density = torch.sum(mask_tensor).item()
        tumor_area = torch.sum(mask_tensor > 0.5).item()
        tumor_mean = tumor_density / num_elements

        # Step 4: Severity from both area and mean, then exact-key lookups.
        severity = _classify_severity(tumor_area, tumor_mean)
        stats_text = (
            f"{SEVERITY_EMOJI[severity]} Density: {tumor_density:.2f}, "
            f"Area: {int(tumor_area)}, Mean: {tumor_mean:.4f}, "
            f"Severity: {severity}"
        )
        recommendation = SEVERITY_RECOMMENDATION[severity]

        # Step 5: Generate Grad-CAM visualization
        gradcam_img = generate_gradcam(model, image_tensor, orig_image.size)
        gradcam_overlay = overlay_gradcam(orig_image, gradcam_img)

        return [image, gradcam_overlay, stats_text, recommendation]
    except Exception as e:
        # UI boundary: log with traceback, show an error panel instead of crashing.
        logger.exception("Segmentation Error: %s", e)
        return [image, _error_panel("PROCESSING FAILED"), "Error", "Error"]


with gr.Blocks(analytics_enabled=False) as demo:
    gr.Markdown("## 🧠 Brain Tumor Grad-CAM with Severity Analysis")
    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Brain MRI")
    with gr.Row():
        original_display = gr.Image(type="pil", label="Input Image")
        gradcam_overlay_output = gr.Image(type="pil", label="Grad-CAM Overlay")
    with gr.Row():
        density_label = gr.Textbox(label="Tumor Density Stats", interactive=False)
        recommendation_label = gr.Textbox(label="Clinical Recommendation", interactive=False)

    input_image.change(
        fn=segment_brain,
        inputs=input_image,
        outputs=[original_display, gradcam_overlay_output, density_label, recommendation_label],
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=False,
    )