"""
MONAI WholeBody CT Segmentation - Hugging Face Space
Segments 104 anatomical structures from CT scans using MONAI's SegResNet model
"""

import os
import tempfile
import traceback

import numpy as np
import gradio as gr
import torch
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from huggingface_hub import hf_hub_download
from monai.networks.nets import SegResNet
from monai.transforms import (
    Compose,
    LoadImage,
    EnsureChannelFirst,
    Orientation,
    Spacing,
    ScaleIntensityRange,
    CropForeground,
    Activations,
    AsDiscrete,
)
from monai.inferers import sliding_window_inference
from labels import LABEL_NAMES, get_color_map, get_label_name, get_organ_categories
import trimesh
from skimage import measure

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_REPO = "MONAI/wholeBody_ct_segmentation"
SPATIAL_SIZE = (96, 96, 96)
PIXDIM = (3.0, 3.0, 3.0)  # Low-res model spacing
# Directory (relative to the working directory) scanned for bundled example scans.
EXAMPLES_DIR = "examples"

# Global model variable, populated lazily by load_model().
model = None


def load_model():
    """Download (on first use) and return the cached MONAI SegResNet model.

    The model is only published to the module-level cache after its weights
    have loaded successfully, so a failed load is retried on the next call
    instead of silently serving a randomly initialised network.

    Returns:
        The SegResNet in eval mode on ``DEVICE``.
    """
    global model
    if model is not None:
        return model

    print("Downloading model weights...")
    try:
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename="models/model_lowres.pt",
        )
    except Exception as e:
        print(f"Failed to download from HF, trying alternative: {e}")
        # Fallback: a different checkpoint file from the same HF repository.
        # NOTE(review): preprocessing still resamples to PIXDIM; confirm this
        # checkpoint was trained at that spacing.
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename="models/model.pt",
        )

    print(f"Loading model from {model_path}...")

    # SegResNet with 105 output channels (background + 104 classes)
    net = SegResNet(
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=32,
        in_channels=1,
        out_channels=105,
        dropout_prob=0.2,
    )

    # Checkpoints may be a raw state dict or wrapped under "state_dict".
    checkpoint = torch.load(model_path, map_location=DEVICE)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    net.load_state_dict(state_dict)
    net.to(DEVICE)
    net.eval()

    model = net
    print(f"Model loaded successfully on {DEVICE}")
    return model


def _get_spatial_transforms():
    """Load a NIfTI path, make it channel-first, reorient to RAS and resample to PIXDIM.

    Intensities are left untouched (raw scanner values, expected to be HU).
    """
    return Compose([
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        Orientation(axcodes="RAS"),
        Spacing(pixdim=PIXDIM, mode="bilinear"),
    ])


def _get_intensity_transform():
    """Clip intensities to [-1024, 1024] and rescale them linearly to [0, 1]."""
    return ScaleIntensityRange(
        a_min=-1024, a_max=1024, b_min=0.0, b_max=1.0, clip=True
    )


def get_preprocessing_transforms():
    """Get the full MONAI preprocessing pipeline (spatial resampling + intensity scaling)."""
    return Compose([_get_spatial_transforms(), _get_intensity_transform()])


def get_postprocessing_transforms():
    """Get MONAI postprocessing transforms (softmax, then argmax over the channel axis).

    MONAI array transforms treat dim 0 as the channel axis, so the input must be
    an unbatched ``[C, ...]`` tensor.
    """
    return Compose([
        Activations(softmax=True),
        AsDiscrete(argmax=True),
    ])


def run_inference(image_path: str, progress=gr.Progress()):
    """Run segmentation inference on a CT image.

    Args:
        image_path: Path to a NIfTI (.nii / .nii.gz) CT volume.
        progress: Gradio progress tracker.

    Returns:
        ``(ct_data, seg_data)``. Both arrays are on the same RAS / PIXDIM voxel grid,
        so they can be sliced with the same indices. ``ct_data`` holds the
        resampled, unscaled intensities. ``seg_data`` is a uint8 label map
        (0 = background).
    """
    progress(0.1, desc="Loading model...")
    net = load_model()

    progress(0.2, desc="Preprocessing image...")
    # Resample once and reuse the result for display. The raw file is generally in a
    # different orientation/spacing than the model output, which would misalign the overlay.
    resampled = _get_spatial_transforms()(image_path)
    ct_data = resampled[0].cpu().numpy()
    image = _get_intensity_transform()(resampled)
    image = image.unsqueeze(0).to(DEVICE)  # Add batch dimension

    progress(0.4, desc="Running segmentation (this may take a few minutes)...")
    with torch.no_grad():
        # Use sliding window inference for large volumes
        outputs = sliding_window_inference(
            image,
            roi_size=SPATIAL_SIZE,
            sw_batch_size=4,
            predictor=net,
            overlap=0.5,
        )

    progress(0.8, desc="Post-processing...")
    # Strip the batch dim first: softmax/argmax must run over the class channels (dim 0),
    # not over the size-1 batch axis. argmax keeps a singleton channel dim, dropped via [0].
    labels_t = get_postprocessing_transforms()(outputs[0])
    seg_data = labels_t[0].cpu().numpy().astype(np.uint8)

    progress(1.0, desc="Complete!")
    return ct_data, seg_data


def generate_3d_mesh(seg_data, step_size=2):
    """Generate a GLB mesh of all foreground structures using Marching Cubes.

    Args:
        seg_data: 3-D integer label map (0 = background).
        step_size: Marching-cubes step; values > 1 trade resolution for speed
            (important for CPU-only Spaces).

    Returns:
        Path to a temporary ``.glb`` file, or ``None`` if there is no foreground
        or mesh generation fails.
    """
    if seg_data is None or np.max(seg_data) == 0:
        return None

    try:
        # Union of all structures as a float volume; level=0.5 sits between 0 and 1.
        mask = (seg_data > 0).astype(np.float32)
        verts, faces, normals, _ = measure.marching_cubes(
            mask, level=0.5, step_size=step_size
        )
        mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)

        # Close the handle before exporting by name (re-opening an open temp file fails on Windows).
        with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmp:
            out_path = tmp.name
        mesh.export(out_path)
        return out_path
    except Exception as e:
        print(f"Error generating 3D mesh: {e}")
        return None


def create_slice_visualization(ct_data, seg_data, axis, slice_idx, alpha=0.5, show_overlay=True):
    """Create a matplotlib figure of one CT slice with an optional label overlay.

    Args:
        ct_data: 3-D CT volume.
        seg_data: 3-D label map on the same grid as ``ct_data``, or ``None``.
        axis: "Axial" (dim 2), "Coronal" (dim 1); anything else is Sagittal (dim 0).
        slice_idx: Slice index; clamped to the valid range.
        alpha: Overlay opacity for foreground labels.
        show_overlay: Whether to draw the segmentation.

    Returns:
        A ``matplotlib.figure.Figure``. It is built without pyplot, so no global
        figure registry grows with every render.
    """
    axis_dim = {"Axial": 2, "Coronal": 1}.get(axis, 0)
    slice_idx = max(0, min(slice_idx, ct_data.shape[axis_dim] - 1))
    ct_slice = np.take(ct_data, slice_idx, axis=axis_dim)
    seg_slice = np.take(seg_data, slice_idx, axis=axis_dim) if seg_data is not None else None

    fig = Figure(figsize=(8, 8))
    ax = fig.add_subplot(1, 1, 1)

    # Normalize CT for display
    ct_normalized = np.clip(ct_slice, -1024, 1024)
    ct_normalized = (ct_normalized - ct_normalized.min()) / (ct_normalized.max() - ct_normalized.min() + 1e-8)
    ax.imshow(ct_normalized.T, cmap='gray', origin='lower')

    if show_overlay and seg_slice is not None and np.any(seg_slice > 0):
        # Keep only RGB columns (works whether the palette has 3 or 4 channels),
        # then build alpha explicitly: `alpha` for labels, fully transparent for background.
        rgb = np.asarray(get_color_map(), dtype=float)[:, :3] / 255.0
        label_ids = seg_slice.astype(int)
        seg_rgba = np.zeros((*seg_slice.shape, 4))
        seg_rgba[..., :3] = rgb[label_ids]
        seg_rgba[..., 3] = np.where(label_ids > 0, alpha, 0.0)
        ax.imshow(seg_rgba.transpose(1, 0, 2), origin='lower')

    ax.axis('off')
    ax.set_title(f"{axis} View - Slice {slice_idx}")
    fig.tight_layout()
    return fig


def get_detected_structures(seg_data):
    """Return a bullet list of the labels present in ``seg_data`` (background excluded)."""
    present = [int(label) for label in np.unique(seg_data) if label > 0]
    structures = [f"• {get_label_name(label)} (Label {label})" for label in present]
    return "\n".join(structures) if structures else "No structures detected"


# Global state for current visualization.
# NOTE(review): module-level state is shared by every session of the app; concurrent
# users will overwrite each other's volumes. Consider gr.State if that matters.
current_ct_data = None
current_seg_data = None


def process_upload(file_path, progress=gr.Progress()):
    """Run segmentation on an uploaded file and populate all output components.

    Returns a 6-tuple matching the wired outputs:
    (figure, structures text, mesh path, axial/coronal/sagittal slider updates).
    """
    global current_ct_data, current_seg_data

    if file_path is None:
        return None, "Please upload a NIfTI file", None, gr.update(maximum=1), gr.update(maximum=1), gr.update(maximum=1)

    try:
        ct_data, seg_data = run_inference(file_path, progress)
        current_ct_data = ct_data
        current_seg_data = seg_data

        # Start each view at the middle slice
        mid_axial = ct_data.shape[2] // 2
        mid_coronal = ct_data.shape[1] // 2
        mid_sagittal = ct_data.shape[0] // 2

        fig = create_slice_visualization(ct_data, seg_data, "Axial", mid_axial)
        structures = get_detected_structures(seg_data)

        # Generate 3D mesh (this might take a few seconds)
        mesh_path = generate_3d_mesh(seg_data)

        return (
            fig,
            structures,
            mesh_path,
            gr.update(maximum=ct_data.shape[2] - 1, value=mid_axial),
            gr.update(maximum=ct_data.shape[1] - 1, value=mid_coronal),
            gr.update(maximum=ct_data.shape[0] - 1, value=mid_sagittal),
        )

    except Exception as e:
        traceback.print_exc()
        return None, f"Error processing file: {str(e)}", None, gr.update(), gr.update(), gr.update()


def update_visualization(axis, slice_idx, alpha, show_overlay):
    """Re-render the current volume; returns ``None`` until a scan has been processed."""
    if current_ct_data is None:
        return None

    return create_slice_visualization(
        current_ct_data, current_seg_data, axis, int(slice_idx), alpha, show_overlay
    )


def load_example(example_name):
    """Return the path of a bundled example next to this script, or ``None`` if missing."""
    example_dir = os.path.join(os.path.dirname(__file__), "examples")
    example_path = os.path.join(example_dir, example_name)
    if os.path.exists(example_path):
        return example_path
    return None


def _refresh_current_view(axis, alpha, overlay, axial_idx, coronal_idx, sagittal_idx):
    """Re-render the selected axis using that axis' own slider position."""
    slice_idx = {"Axial": axial_idx, "Coronal": coronal_idx}.get(axis, sagittal_idx)
    return update_visualization(axis, slice_idx, alpha, overlay)


def _list_example_files():
    """List bundled NIfTI examples (sorted); empty if the examples dir does not exist yet."""
    if not os.path.isdir(EXAMPLES_DIR):
        return []
    return [
        [os.path.join(EXAMPLES_DIR, f)]
        for f in sorted(os.listdir(EXAMPLES_DIR))
        if f.endswith((".nii", ".nii.gz"))
    ]


# Create Gradio interface
with gr.Blocks(
    title="MONAI WholeBody CT Segmentation",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {max-width: 1200px !important}
    .output-image {min-height: 500px}
    """
) as demo:
    gr.Markdown("""
    # 🏥 MONAI WholeBody CT Segmentation

    **Automatic segmentation of 104 anatomical structures from CT scans**

    This application uses MONAI's pre-trained SegResNet model trained on the TotalSegmentator dataset.
    Upload a CT scan in NIfTI format (.nii or .nii.gz) to get started.

    > ⚡ **Note**: Processing may take 1-5 minutes depending on the CT volume size.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📤 Upload CT Scan")
            file_input = gr.File(
                label="Upload NIfTI file (.nii, .nii.gz)",
                file_types=[".nii", ".nii.gz", ".gz"],
                type="filepath"
            )

            # Example files (only rendered when the examples folder has scans)
            gr.Markdown("### 📁 Example Files")
            example_files = _list_example_files()
            if example_files:
                example_gallery = gr.Examples(
                    examples=example_files,
                    inputs=[file_input],
                    label="Click to load example"
                )

            process_btn = gr.Button("🔬 Run Segmentation", variant="primary", size="lg")

            # Visualization controls
            gr.Markdown("### 🎛️ Visualization Controls")
            view_axis = gr.Radio(
                choices=["Axial", "Coronal", "Sagittal"],
                value="Axial",
                label="View Axis"
            )

            with gr.Row():
                axial_slider = gr.Slider(0, 100, value=50, step=1, label="Axial Slice")
                coronal_slider = gr.Slider(0, 100, value=50, step=1, label="Coronal Slice")
                sagittal_slider = gr.Slider(0, 100, value=50, step=1, label="Sagittal Slice")

            alpha_slider = gr.Slider(0, 1, value=0.5, step=0.1, label="Overlay Opacity")
            show_overlay = gr.Checkbox(value=True, label="Show Segmentation Overlay")

        with gr.Column(scale=2):
            # Output section
            gr.Markdown("### 🖼️ Segmentation Result")
            output_image = gr.Plot(label="CT with Segmentation Overlay")

            gr.Markdown("### 📋 Detected Structures")
            structures_output = gr.Textbox(
                label="Anatomical Structures Found",
                lines=10,
                max_lines=20
            )

            # 3D Model Output
            gr.Markdown("### 🧊 3D View")
            model_3d_output = gr.Model3D(
                label="3D Segmentation Mesh",
                clear_color=[0.0, 0.0, 0.0, 0.0],
                camera_position=(90, 90, 3)
            )

    # Model info section
    with gr.Accordion("ℹ️ Model Information", open=False):
        gr.Markdown("""
        ### About the Model

        This model is based on **SegResNet** architecture from MONAI, trained on the **TotalSegmentator** dataset.

        **Capabilities:**
        - Segments 104 distinct anatomical structures
        - Works on whole-body CT scans
        - Uses 3.0mm isotropic spacing (low-resolution model for faster inference)

        **Segmented Structures include:**
        - **Major Organs**: Liver, Spleen, Kidneys, Pancreas, Gallbladder, Stomach, Bladder
        - **Cardiovascular**: Heart chambers, Aorta, Vena Cava, Portal Vein
        - **Respiratory**: Lung lobes, Trachea
        - **Skeletal**: Vertebrae (C1-L5), Ribs, Hip bones, Femur, Humerus, Scapula
        - **Muscles**: Gluteal muscles, Iliopsoas
        - And many more...

        **References:**
        - [MONAI Model Zoo](https://monai.io/model-zoo.html)
        - [TotalSegmentator Paper](https://pubs.rsna.org/doi/10.1148/ryai.230024)
        """)

    # Event handlers
    process_btn.click(
        fn=process_upload,
        inputs=[file_input],
        outputs=[output_image, structures_output, model_3d_output, axial_slider, coronal_slider, sagittal_slider]
    )

    # Update visualization when view controls change
    for control in [view_axis, alpha_slider, show_overlay]:
        control.change(
            fn=_refresh_current_view,
            inputs=[view_axis, alpha_slider, show_overlay, axial_slider, coronal_slider, sagittal_slider],
            outputs=[output_image]
        )

    # Slider listeners use .input (user interaction only): .change would also fire when
    # process_upload sets the sliders, triggering redundant renders that replace the
    # initial axial view with whichever axis happens to render last.
    axial_slider.input(
        fn=lambda s, alpha, overlay: update_visualization("Axial", s, alpha, overlay),
        inputs=[axial_slider, alpha_slider, show_overlay],
        outputs=[output_image]
    )
    coronal_slider.input(
        fn=lambda s, alpha, overlay: update_visualization("Coronal", s, alpha, overlay),
        inputs=[coronal_slider, alpha_slider, show_overlay],
        outputs=[output_image]
    )
    sagittal_slider.input(
        fn=lambda s, alpha, overlay: update_visualization("Sagittal", s, alpha, overlay),
        inputs=[sagittal_slider, alpha_slider, show_overlay],
        outputs=[output_image]
    )


if __name__ == "__main__":
    # Ensure examples directory exists
    os.makedirs("examples", exist_ok=True)

    # Launch the app
    demo.launch()