| import streamlit as st |
| import sys |
| import os |
|
|
| |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
|
|
| import h5py |
| import torch |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import yaml |
| import os |
| import io |
|
|
| |
| from src.mobilenetv2_model import LandslideModel as MobileNetV2Model |
| from src.vgg16_model import LandslideModel as VGG16Model |
| from src.resnet34_model import LandslideModel as ResNet34Model |
| from src.efficientnetb0_model import LandslideModel as EfficientNetB0Model |
| from src.mitb1_model import LandslideModel as MiTB1Model |
| from src.inceptionv4_model import LandslideModel as InceptionV4Model |
| from src.densenet121_model import LandslideModel as DenseNet121Model |
| from src.deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel |
| from src.resnext50_32x4d_model import LandslideModel as ResNeXt50Model |
| from src.se_resnet50_model import LandslideModel as SEResNet50Model |
| from src.se_resnext50_32x4d_model import LandslideModel as SEResNeXt50Model |
| from src.segformer_model import LandslideModel as SegFormerB2Model |
| from src.inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model |
| from src.model_downloader import ModelDownloader |
|
|
| |
# Registry of selectable models. Each entry carries the display name ("name"),
# the model_type written into the config ("type") and, where it differs from the
# registry key, the key passed to ModelDownloader ("downloader_key").
AVAILABLE_MODELS = {
    "mobilenetv2": dict(name="MobileNetV2", type="mobilenet_v2"),
    "vgg16": dict(name="VGG16", type="vgg16"),
    "resnet34": dict(name="ResNet34", type="resnet34"),
    "efficientnetb0": dict(name="EfficientNetB0", type="efficientnet_b0"),
    "mitb1": dict(name="MiTB1", type="mitb1"),
    "inceptionv4": dict(name="InceptionV4", type="inception_v4"),
    "densenet121": dict(name="DenseNet121", type="densenet121"),
    "deeplabv3plus": dict(name="DeepLabV3Plus", type="deeplabv3plus"),
    "resnext50": dict(name="ResNeXt50", type="resnext50_32x4d", downloader_key="resnext50_32x4d"),
    "seresnet50": dict(name="SEResNet50", type="se_resnet50", downloader_key="se_resnet50"),
    "seresnext50": dict(name="SEResNeXt50", type="se_resnext50_32x4d", downloader_key="se_resnext50_32x4d"),
    "segformerb2": dict(name="SegFormerB2", type="segformer_b2", downloader_key="segformer"),
    "inceptionresnetv2": dict(name="InceptionResNetV2", type="inception_resnet_v2"),
}


def _describe_model(model_key, model_info):
    """Build the sidebar/description record for one registry entry.

    The downloader key falls back to the registry key when the entry does not
    override it.
    """
    return {
        "type": model_info["type"],
        "description": f"{model_info['name']} - A model for landslide detection and segmentation.",
        "name": model_info["name"],
        "downloader_key": model_info.get("downloader_key", model_key),
    }


# Same keys and order as AVAILABLE_MODELS, enriched with a description.
MODEL_DESCRIPTIONS = dict(
    (model_key, _describe_model(model_key, model_info))
    for model_key, model_info in AVAILABLE_MODELS.items()
)
|
|
| |
# Embedded run configuration. process_and_visualize copies it and overrides
# model_config.model_type per model; process_h5_file reads
# dataset_config.channels to select input bands.
config_str = """
model_config:
  model_type: "mobilenet_v2"
  in_channels: 14
  num_classes: 1
  encoder_weights: "imagenet"
  wce_weight: 0.5

dataset_config:
  num_classes: 1
  num_channels: 14
  channels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
  normalize: False

train_config:
  dataset_path: ""
  checkpoint_path: "checkpoints"
  seed: 42
  train_val_split: 0.8
  batch_size: 16
  num_epochs: 100
  lr: 0.001
  device: "cuda:0"
  save_config: True
  experiment_name: "mobilenet_v2"

logging_config:
  wandb_project: "l4s"
  wandb_entity: "Silvamillion"
"""

# SafeLoader: parse plain YAML data only, never arbitrary Python objects.
config = yaml.load(config_str, Loader=yaml.SafeLoader)
|
|
def _normalize_for_display(image_chw):
    """Min-max scale a (C, H, W) array into [0, 1] and return it as (H, W, C).

    A constant image (max == min) maps to all zeros instead of dividing by zero,
    which would otherwise fill the displayed image with NaNs.
    """
    image_hwc = image_chw.transpose(1, 2, 0)
    low, high = image_hwc.min(), image_hwc.max()
    if high == low:
        return np.zeros(image_hwc.shape, dtype=float)
    return (image_hwc - low) / (high - low)


def _prediction_to_npy_bytes(array):
    """Serialize ``array`` in NumPy's .npy format so ``np.load`` can read it back."""
    buffer = io.BytesIO()
    np.save(buffer, array)
    return buffer.getvalue()


def process_and_visualize(model_key, model_info, image_tensor, original_image, uploaded_file_name):
    """
    Run one model on one image and render the results in the Streamlit page.

    Builds the model class registered for ``model_key``, loads its weights via
    ``ModelDownloader``, and predicts a sigmoid probability map. It then shows
    input / probability / thresholded-overlay panels and offers the prediction
    as a real ``.npy`` download. Any exception is reported in the page (message
    plus traceback) rather than raised.

    Args:
        model_key: Key into ``AVAILABLE_MODELS`` (e.g. ``"resnet34"``).
        model_info: Entry from ``MODEL_DESCRIPTIONS`` for that key; must provide
            ``"name"`` and ``"type"`` and may provide ``"downloader_key"``.
        image_tensor: Model input; the callers in this app pass a (1, C, H, W)
            float tensor.
        original_image: The same image as a (C, H, W) numpy array, used for display.
        uploaded_file_name: Name of the source file; its stem prefixes the
            download file name.
    """
    import copy
    import traceback

    try:
        st.write(f"Using model: {model_info['name']}")

        # Deep copy: a shallow config.copy() shares the nested "model_config"
        # dict, so assigning model_type would mutate the module-level config.
        current_config = copy.deepcopy(config)
        current_config['model_config']['model_type'] = model_info['type']

        # Model classes are imported at module level as "<DisplayName>Model".
        model_class_name = AVAILABLE_MODELS[model_key]['name'].replace('-', '') + 'Model'
        model_class = globals().get(model_class_name)
        if model_class is None:
            raise ValueError(f"Model class {model_class_name} for '{model_key}' is not imported.")

        downloader = ModelDownloader()
        download_key = model_info.get('downloader_key', model_key)
        model_path = downloader.download_model(download_key)
        st.info(f"Using model from: {model_path}")

        model = model_class(current_config)
        # strict=False: checkpoint keys that do not match the model are skipped
        # instead of raising.
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
        model.eval()

        with torch.no_grad():
            prediction = torch.sigmoid(model(image_tensor)).cpu().numpy()
        probabilities = prediction.squeeze()

        st.header(f"Prediction Results - {model_info['name']}")
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        try:
            img_display = _normalize_for_display(original_image)

            # Only the first three channels are shown as the "input" picture.
            ax[0].imshow(img_display[:, :, :3])
            ax[0].set_title("Input Image")
            ax[0].axis('off')

            ax[1].imshow(probabilities, cmap='plasma')
            ax[1].set_title("Prediction Probability")
            ax[1].axis('off')

            ax[2].imshow(img_display[:, :, :3])
            ax[2].imshow(probabilities > 0.5, cmap='plasma', alpha=0.4)
            ax[2].set_title("Overlay (Threshold > 0.5)")
            ax[2].axis('off')

            st.pyplot(fig)
        finally:
            # Always release the figure, even if plotting fails part-way.
            plt.close(fig)

        st.write(f"Download the prediction as a .npy file for {model_info['name']}:")
        file_stem = os.path.splitext(uploaded_file_name)[0]
        st.download_button(
            label=f"Download Prediction - {model_info['name']}",
            # np.save output (with .npy header), not raw tobytes(), so the file
            # can be opened with np.load.
            data=_prediction_to_npy_bytes(probabilities),
            file_name=f"{file_stem}_{model_key}_prediction.npy",
            mime="application/octet-stream",
        )

    except Exception as e:
        st.error(f"Error with model {model_info['name']}: {str(e)}")
        st.error(traceback.format_exc())
|
|
| |
# Must be the first Streamlit command executed in the script.
st.set_page_config(page_title="DeepSlide: Landslide Detection", layout="wide")

st.title("DeepSlide: Landslide Detection")
st.markdown("""
## Instructions
1. **Model Selection**: Choose a single model from the sidebar or select "Run all models".
2. **Data Input**:
   - Try an example image from the dropdown, or
   - Upload your own .h5 files
3. **Results**: View predictions and download results as .npy files.
""")

st.sidebar.title("Model Selection")
model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])

# Bind both selection globals on every run. They are only meaningful in
# single-model mode, but later code reads them by name, so neither should be
# left undefined.
selected_model_key = None
selected_model_info = None
if model_option == "Select a single model":
    selected_model_key = st.sidebar.selectbox("Select Model", list(MODEL_DESCRIPTIONS.keys()))
    selected_model_info = MODEL_DESCRIPTIONS[selected_model_key]

    st.sidebar.markdown("### Model Details")
    st.sidebar.markdown(f"**Model Name:** {selected_model_info['name']}")
    st.sidebar.markdown(f"**Model Type:** {selected_model_info['type']}")
    st.sidebar.markdown(f"**Description:** {selected_model_info['description']}")
|
|
| |
st.header("Upload Data")

# NOTE(review): initialized here but not read anywhere else in this file; confirm before removing.
if 'upload_errors' not in st.session_state:
    st.session_state.upload_errors = []

st.subheader("Try Example Images")
# Example .h5 files live in an "examples" folder in the parent of this script's directory.
examples_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "examples")
example_files = []

try:
    if os.path.isdir(examples_dir):
        example_files = sorted(f for f in os.listdir(examples_dir) if f.endswith('.h5'))
except OSError as err:
    # Listing examples is best effort: report the problem and continue without examples.
    st.warning(f"Could not list example files in {examples_dir}: {err}")

if example_files:
    selected_example = st.selectbox(
        "Select an example image to test:",
        options=["None"] + example_files,
        help="Choose an example .h5 file to quickly test the models",
    )
else:
    st.info("No example files found")
    # "None" is the sentinel the processing step checks for "no example chosen".
    selected_example = "None"

st.subheader("Upload Your Own Files")
uploaded_files = st.file_uploader(
    "Choose .h5 files...",
    type="h5",
    accept_multiple_files=True,
    help="Upload your .h5 files here. Maximum file size is 200MB.",
)
|
|
def _h5_to_image(data, channels, num_bands=14):
    """
    Select the 1-based ``channels`` from raw 'img' data as an (H, W, len(channels)) array.

    Band-first (num_bands, H, W) and band-last (H, W, num_bands) layouts are
    recognized. Any other 3-D shape is treated as band-first, and bands beyond
    its first axis are left at zero. The spatial size is taken from the data
    rather than assumed. Returns None, after reporting an error, when ``data``
    is not 3-D.
    """
    if data.ndim != 3:
        st.error(f"Data has {data.ndim} dimensions, expected 3.")
        return None

    if data.shape[0] == num_bands:
        band_first = data
    elif data.shape[2] == num_bands:
        band_first = data.transpose(2, 0, 1)
    else:
        st.warning(f"Unexpected data shape: {data.shape}. Assuming (C, H, W).")
        band_first = data

    height, width = band_first.shape[1:]
    image = np.zeros((height, width, len(channels)))
    for i, band in enumerate(channels):
        if band - 1 < band_first.shape[0]:
            image[:, :, i] = band_first[band - 1]
    return image


def process_h5_file(file_path, file_name):
    """
    Load the 'img' dataset from an .h5 source and run the selected model(s) on it.

    Args:
        file_path: Anything ``h5py.File`` can open: a filesystem path, or a
            binary file-like object such as ``io.BytesIO``.
        file_name: Display name used in messages and in download file names.

    Errors are reported in the page (message plus traceback) rather than raised.
    """
    import traceback

    try:
        with h5py.File(file_path, 'r') as hdf:
            if 'img' not in hdf:
                st.error(f"Error: 'img' dataset not found in {file_name}")
                return
            data = np.array(hdf.get('img'))

        # Replace NaNs with a tiny non-zero value before feeding the model.
        data[np.isnan(data)] = 0.000001

        image = _h5_to_image(data, config["dataset_config"]["channels"])
        if image is None:
            return

        image_display = image.transpose(2, 0, 1)  # (C, H, W)
        image_tensor = torch.from_numpy(image_display).unsqueeze(0).float()  # (1, C, H, W)

        if model_option == "Select a single model":
            targets = [(selected_model_key, selected_model_info)]
        else:
            targets = list(MODEL_DESCRIPTIONS.items())
        for model_key, model_info in targets:
            process_and_visualize(model_key, model_info, image_tensor, image_display, file_name)

    except Exception as e:
        st.error(f"Error processing file {file_name}: {str(e)}")
        st.error(traceback.format_exc())
|
|
| |
# Run the chosen example first, then every uploaded file, through the same
# loader/inference path (process_h5_file reports its own errors).
if selected_example != "None":
    st.write(f"Processing example: {selected_example}")
    example_path = os.path.join(examples_dir, selected_example)
    with st.spinner(f'Processing {selected_example}...'):
        process_h5_file(example_path, selected_example)

if uploaded_files:
    for uploaded_file in uploaded_files:
        st.write(f"Processing file: {uploaded_file.name}")
        st.write(f"File size: {uploaded_file.size} bytes")

        with st.spinner('Processing...'):
            # h5py reads binary file-like objects, so the in-memory upload can
            # reuse process_h5_file instead of duplicating its parsing logic.
            process_h5_file(io.BytesIO(uploaded_file.getvalue()), uploaded_file.name)

if selected_example != "None" or uploaded_files:
    st.success('✅ Processing completed!')