| | import streamlit as st |
| | import streamlit_analytics |
| |
|
| | import torch |
| | import torchvision.transforms as transforms |
| | from transformers import ViTModel, ViTConfig |
| | from PIL import Image |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import io |
| |
|
# Begin recording page-view/interaction analytics; paired with
# streamlit_analytics.stop_tracking() at the bottom of this script.
streamlit_analytics.start_tracking()

# Must run before any other Streamlit UI command on each rerun.
st.set_page_config(page_title="ViewViz", layout="wide")
| |
|
| | |
| | st.markdown(""" |
| | <style> |
| | .stApp { |
| | background-color: #2b3d4f; |
| | color: #ffffff; |
| | } |
| | .stButton>button { |
| | color: #2b3d4f; |
| | background-color: #4fd1c5; |
| | border-radius: 5px; |
| | } |
| | .stSlider>div>div>div>div { |
| | background-color: #4fd1c5; |
| | } |
| | </style> |
| | """, unsafe_allow_html=True) |
| |
|
| | |
# Inference device: CUDA is used only when explicitly opted in via USE_GPU
# AND a GPU is actually available; otherwise everything runs on CPU.
USE_GPU = False
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')
| |
|
| | |
# Display name -> matplotlib colormap callable, offered to the user for
# colorizing the attention heatmap. Each value maps a normalized [0, 1]
# array to RGBA floats.
COLOR_SCHEMES = {
    'Plasma': plt.cm.plasma,
    'Viridis': plt.cm.viridis,
    'Magma': plt.cm.magma,
    'Inferno': plt.cm.inferno,
    'Cividis': plt.cm.cividis,
    'Spectral': plt.cm.Spectral,
    'Coolwarm': plt.cm.coolwarm
}
| |
|
| | |
@st.cache_resource
def load_model():
    """Load the ViT backbone once; cached across Streamlit reruns.

    Attention outputs are enabled in the config (eager attention
    implementation so per-layer attention tensors are returned), and the
    model is put in eval mode and moved to the module-level device.
    """
    checkpoint = 'google/vit-base-patch16-384'
    cfg = ViTConfig.from_pretrained(
        checkpoint, output_attentions=True, attn_implementation="eager"
    )
    vit = ViTModel.from_pretrained(checkpoint, config=cfg).to(device)
    vit.eval()
    return vit
| |
|
# Instantiate (or fetch from Streamlit's resource cache) the ViT model.
model = load_model()
| |
|
| | |
# Input pipeline: resize to the 384x384 resolution this ViT checkpoint was
# trained at, convert to a float tensor, and normalize channels.
# NOTE(review): these are the ImageNet mean/std; the HF image processor for
# google/vit-* checkpoints typically normalizes with mean=std=0.5 — confirm
# which statistics this checkpoint expects.
preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
| |
|
def get_attention_map(img):
    """Compute an attention-rollout heatmap for a PIL image.

    Runs the ViT model, averages attention over heads within each layer,
    adds the identity matrix to account for residual connections,
    row-normalizes, then multiplies the per-layer matrices together
    ("attention rollout"). The CLS-token row of the final product, with the
    CLS column dropped, is reshaped into the patch grid.

    Args:
        img: PIL image; resized and normalized by the module-level
            `preprocess` pipeline.

    Returns:
        2-D numpy array (grid_size x grid_size) of unnormalized attention
        weights, one per image patch.
    """
    input_tensor = preprocess(img).unsqueeze(0).to(device)

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        outputs = model(input_tensor, output_attentions=True)

    # Stack per-layer attentions, drop the batch dim, average over heads:
    # (layers, heads, tokens, tokens) -> (layers, tokens, tokens).
    att_mat = torch.stack(outputs.attentions).squeeze(1)
    att_mat = torch.mean(att_mat, dim=1)

    # Add identity for the residual connection and re-normalize each row.
    # Creating the eye directly on `device` avoids a host alloc + copy.
    residual_att = torch.eye(att_mat.size(-1), device=device)
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

    # Rollout: keep only the running product. The original code materialized
    # every intermediate joint-attention matrix, but only the last is used.
    rollout = aug_att_mat[0]
    for layer_att in aug_att_mat[1:]:
        rollout = torch.matmul(layer_att, rollout)

    # Row 0 is the CLS token; drop its self-attention entry (column 0) and
    # fold the remaining patch attentions back into the 2-D patch grid.
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    mask = rollout[0, 1:].reshape(grid_size, grid_size).cpu().numpy()

    return mask
| |
|
def overlay_attention_map(image, attention_map, overlay_strength, color_scheme):
    """Blend a colorized attention heatmap over an image.

    Args:
        image: PIL image (any mode convertible to RGBA).
        attention_map: 2-D float array of attention weights; upsampled to
            the image size with bicubic interpolation.
        overlay_strength: blend factor in [0, 1]; 0 shows only the original
            image, 1 shows only the heatmap.
        color_scheme: callable mapping a normalized 2-D array to RGBA floats
            (e.g. a matplotlib colormap).

    Returns:
        PIL RGBA image with the heatmap linearly blended in.
    """
    # Upsample the raw attention grid to the image resolution.
    attention_map = Image.fromarray(attention_map).resize(image.size, Image.BICUBIC)
    attention_map = np.array(attention_map)

    # Min-max normalize to [0, 1]. Guard against a constant map, which would
    # otherwise divide by zero and fill the heatmap with NaNs.
    span = attention_map.max() - attention_map.min()
    if span > 0:
        attention_map = (attention_map - attention_map.min()) / span
    else:
        attention_map = np.zeros_like(attention_map)

    # Map normalized weights to RGBA colors.
    attention_map_color = color_scheme(attention_map)

    # Work in [0, 1] float RGBA so both layers blend linearly (alpha too).
    image_rgba = image.convert("RGBA")
    image_array = np.array(image_rgba) / 255.0

    overlayed_image = image_array * (1 - overlay_strength) + attention_map_color * overlay_strength

    return Image.fromarray((overlayed_image * 255).astype(np.uint8))
| |
|
| | st.title("ViewViz") |
| |
|
| | uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) |
| |
|
if uploaded_file is not None:
    # Force RGB so alpha-bearing PNGs match the 3-channel preprocess pipeline.
    image = Image.open(uploaded_file).convert('RGB')

    st.success("Starting Prediction Process...")
    # NOTE(review): this runs on every Streamlit rerun (e.g. each slider
    # move), recomputing the forward pass — consider st.cache_data keyed on
    # the upload; confirm before changing.
    attention_map = get_attention_map(image)

    # Overlay controls side by side: blend strength and colormap choice.
    col1, col2 = st.columns(2)

    with col1:
        # Slider is 0-100 (%); converted to a 0.0-1.0 blend factor.
        overlay_strength = st.slider("Heatmap Overlay Percentage", 0, 100, 50) / 100.0

    with col2:
        color_scheme_name = st.selectbox("Choose Heatmap Color Scheme", list(COLOR_SCHEMES.keys()))

    color_scheme = COLOR_SCHEMES[color_scheme_name]

    overlayed_image = overlay_attention_map(image, attention_map, overlay_strength, color_scheme)

    st.image(overlayed_image, caption='Image with Heatmap Overlay', use_column_width=True)

    # Offer the blended result as a PNG download via an in-memory buffer.
    buf = io.BytesIO()
    overlayed_image.save(buf, format="PNG")
    btn = st.download_button(
        label="Download Image with Attention Map",
        data=buf.getvalue(),
        file_name="attention_map_overlay.png",
        mime="image/png"
    )
| |
|
# Finish the analytics session started at the top of the script.
streamlit_analytics.stop_tracking()