| import os |
| os.environ["KERAS_BACKEND"] = "jax" |
|
|
| import gradio as gr |
| import matplotlib.pyplot as plt |
| import matplotlib.cm as cm |
| import keras |
| import keras_hub |
| import numpy as np |
| import jax |
| from keras import ops |
| from PIL import Image |
|
|
| |
# Module-level globals, populated once by initialize_models() at startup.
model = None                  # keras_hub ImageClassifier (Xception, ImageNet)
last_conv_layer_model = None  # input image -> last conv layer activations
classifier_model = None       # last conv activations -> class probabilities
|
|
def initialize_models():
    """Build the classifier and the two sub-models used for Grad-CAM.

    Populates the module-level globals:
      - model: pretrained Xception ImageNet classifier with softmax outputs.
      - last_conv_layer_model: maps an input image to the activations of
        the final convolutional layer.
      - classifier_model: replays the pooling + prediction head on top of
        those activations, so gradients can be taken with respect to them.
    """
    global model, last_conv_layer_model, classifier_model

    model = keras_hub.models.ImageClassifier.from_preset(
        "xception_41_imagenet",
        activation="softmax",
    )

    # Feature extractor: stop at Xception's last conv activation.
    conv_layer = model.backbone.get_layer("block14_sepconv2_act")
    last_conv_layer_model = keras.Model(model.inputs, conv_layer.output)

    # Classifier head: re-apply the pooling and prediction layers to the
    # conv-layer output tensor.
    head = conv_layer.output
    head = model.get_layer("pooler")(head)
    head = model.get_layer("predictions")(head)
    classifier_model = keras.Model(conv_layer.output, head)
|
|
def loss_fn(last_conv_layer_output):
    """Scalar score of the top predicted class, for use with jax.grad.

    Runs the classifier head on the given conv-layer activations and
    returns the probability of the highest-scoring class for the first
    sample in the batch.
    """
    preds = classifier_model(last_conv_layer_output)
    best_index = ops.argmax(preds[0])
    return preds[:, best_index][0]
|
|
| |
# Gradient of the top-class score with respect to the conv activations.
grad_fn = jax.grad(loss_fn)
|
|
def get_top_class_gradients(img_array):
    """Return (grads, activations) of the last conv layer for img_array.

    `grads` is d(top-class score)/d(activations), computed via jax.grad;
    `activations` is the last conv layer's output for the batch.
    """
    activations = last_conv_layer_model(img_array)
    return grad_fn(activations), activations
|
|
def generate_heatmap(image):
    """
    Generate a Grad-CAM class activation heatmap for an uploaded image.

    Args:
        image: PIL Image or numpy array (H, W, 3).

    Returns:
        tuple: (superimposed_img, prediction_text), where superimposed_img
        is a PIL Image of the heatmap blended over the input and
        prediction_text lists the top-5 ImageNet predictions. Returns
        (None, message) when no image was supplied.
    """
    if image is None:
        return None, "Please upload an image."

    # Normalize the input to an RGB numpy array. convert("RGB") also
    # handles grayscale/RGBA uploads that would otherwise break the
    # 3-channel blend below.
    if isinstance(image, Image.Image):
        img = np.array(image.convert("RGB"))
    else:
        img = image

    # Add a batch dimension for the model.
    img_array = np.expand_dims(img, axis=0)

    # Full-model prediction on the raw image (predict() applies the
    # model's own preprocessing).
    preds = model.predict(img_array, verbose=0)
    decoded_preds = keras_hub.utils.decode_imagenet_predictions(preds)

    prediction_text = "Top 5 Predictions:\n\n"
    for i, (description, score) in enumerate(decoded_preds[0][:5], 1):
        prediction_text += f"{i}. {description}: {score:.2%}\n"

    # Preprocess explicitly before the gradient pass: the Grad-CAM
    # sub-models bypass the ImageClassifier's built-in preprocessing.
    img_array = model.preprocessor(img_array)

    grads, last_conv_layer_output = get_top_class_gradients(img_array)
    grads = ops.convert_to_numpy(grads)
    last_conv_layer_output = ops.convert_to_numpy(last_conv_layer_output)

    # Grad-CAM: weight each activation channel by the mean gradient of
    # the top-class score for that channel (broadcast over H and W).
    pooled_grads = np.mean(grads, axis=(0, 1, 2))
    weighted = last_conv_layer_output[0] * pooled_grads

    heatmap = np.mean(weighted, axis=-1)

    # ReLU, then normalize to [0, 1]. Guard the division: an all-zero
    # heatmap (all non-positive contributions) would otherwise produce
    # NaNs via 0/0.
    heatmap = np.maximum(heatmap, 0)
    max_val = np.max(heatmap)
    if max_val > 0:
        heatmap /= max_val

    # Quantize to 0-255 for colormap lookup.
    heatmap = np.uint8(255 * heatmap)

    # Colorize with the jet colormap. matplotlib.cm.get_cmap() was
    # removed in matplotlib 3.9; plt.get_cmap() is the supported call.
    jet = plt.get_cmap("jet")
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]

    # Upsample the coarse conv-resolution heatmap to the input size.
    jet_heatmap = keras.utils.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = keras.utils.img_to_array(jet_heatmap)

    # Blend: 40% heatmap over the original image.
    superimposed_img = jet_heatmap * 0.4 + img
    superimposed_img = keras.utils.array_to_img(superimposed_img)

    return superimposed_img, prediction_text
|
|
| |
# Build the models once at import time so the first request is fast.
print("Initializing models... this may take a moment.")
initialize_models()
print("Models initialized!")
|
|
| |
# ---------------------------------------------------------------------------
# Gradio UI: two-column layout — input image + examples on the left,
# heatmap overlay + top-5 predictions on the right.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Class Activation Heatmap Visualizer") as demo:
    gr.Markdown(
        """
        # Class Activation Heatmap Visualizer

        Upload an image or choose one of the examples to see what parts of the image the neural network focuses on when making predictions.
        The heatmap shows which regions of the image are most important for the top predicted class.

        Code adapted from: https://deeplearningwithpython.io/chapters/chapter10_interpreting-what-convnets-learn/#visualizing-heatmaps-of-class-activation

        **Model:** Xception trained on ImageNet (1,000 classes)
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                height=400
            )
            submit_btn = gr.Button("Generate Heatmap", variant="primary", size="lg")

            # Bundled sample images (paths relative to the app's cwd).
            gr.Examples(
                examples=[
                    ["images/elephant.jpg"],
                    ["images/dog.jpg"],
                    ["images/F1_car.jpg"],
                    ["images/multiple_animals.jpg"],
                    ["images/osprey.jpeg"]
                ],
                inputs=input_image,
                label="Try an example:"
            )

            gr.Markdown(
                """
                ### How to interpret the heatmap:
                - **Red/Yellow regions**: Areas the model focuses on most for its prediction
                - **Blue/Purple regions**: Areas the model considers less important
                """
            )

        with gr.Column():
            output_image = gr.Image(
                label="Heatmap Visualization",
                type="pil",
                height=400
            )
            prediction_text = gr.Textbox(
                label="Predictions",
                lines=7,
                interactive=False
            )

    # Wire the button: one image in, (overlay image, prediction text) out.
    submit_btn.click(
        fn=generate_heatmap,
        inputs=input_image,
        outputs=[output_image, prediction_text]
    )
|
|
| |
# Launch the local Gradio server when run as a script (no public share link).
if __name__ == "__main__":
    demo.launch(share=False)
|
|