"""
app.py
======
VLM Caption Lab — Premium Streamlit Demo
Features:
• Sidebar — Weight Source: Base / Fine-tuned (Best) / Fine-tuned (Latest)
• Sidebar — Architecture selector, Generation Mode, Advanced Controls
• Tab 1 — Caption: single-model captioning with weight selection
• Tab 2 — Compare: side-by-side 4-model comparison (same image, same config)
• Tabs 3–7 — Task 01–05 labs: attention maps, on-device optimization,
  decoding trade-offs, diversity & steering, safety & bias (precomputed)
• Tab 8 — Results: precomputed benchmark comparison tables
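
Run locally (typical Streamlit entry point):
    streamlit run app.py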
"""
import warnings
# Install the warning filters before the heavy imports below trigger them.
warnings.filterwarnings("ignore", message="urllib3 v2 only supports OpenSSL")
warnings.filterwarnings("ignore", category=UserWarning, message=".*use_fast.*")
import os
import json
import torch
import numpy as np
import streamlit as st
from PIL import Image
from models.blip_tuner import generate_with_mask
# ─────────────────────────────────────────────────────────────────────────────
# Page Config & CSS
# ─────────────────────────────────────────────────────────────────────────────
st.set_page_config(
page_title="VLM Caption Lab",
page_icon="🔬",
layout="wide",
initial_sidebar_state="expanded",
)
st.markdown("""
""", unsafe_allow_html=True)
# ─────────────────────────────────────────────────────────────────────────────
# Architecture Info & Constants
# ─────────────────────────────────────────────────────────────────────────────
ARCH_INFO = {
"BLIP (Multimodal Mixture Attention)": (
"🔵 BLIP uses a Mixture-of-Encoder-Decoder (MED) architecture. "
"Gated cross-attention is injected between self-attention and FFN layers."
),
"ViT-GPT2 (Standard Cross-Attention)": (
"🟣 ViT-GPT2: every GPT-2 text token attends to all "
"197 ViT patch embeddings via full cross-attention at every decoder layer."
),
"GIT (Zero Cross-Attention)": (
"🟠 GIT abandons cross-attention entirely. Image patches are "
"concatenated to the front of the token sequence; no cross-attention block."
),
"Custom VLM (Shakespeare Prefix)": (
"🟢 Custom VLM fuses a frozen ViT with a Shakespeare char-level "
"decoder via a single trainable Linear(768→384) projection."
),
}
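# Illustrative contrast between the two fusion families above (ViT-B/16 at
# 224px, hidden size 768, T text tokens):
#   cross-attention (BLIP, ViT-GPT2): text states (1, T, d) attend to all
#       197 patch states (1, 197, 768) inside the decoder blocks;
#   prefix fusion (GIT, Custom VLM): projected patch embeddings are prepended
#       to the token sequence, so a single self-attention stack processes
#       (1, n_patches + T, d) with no cross-attention block at all.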
MODEL_KEYS = [
"BLIP (Multimodal Mixture Attention)",
"ViT-GPT2 (Standard Cross-Attention)",
"GIT (Zero Cross-Attention)",
"Custom VLM (Shakespeare Prefix)",
]
MODEL_SHORT = {
"BLIP (Multimodal Mixture Attention)": "BLIP",
"ViT-GPT2 (Standard Cross-Attention)": "ViT-GPT2",
"GIT (Zero Cross-Attention)": "GIT",
"Custom VLM (Shakespeare Prefix)": "Custom VLM",
}
MODEL_BADGE = {
"BLIP (Multimodal Mixture Attention)": "badge-blue",
"ViT-GPT2 (Standard Cross-Attention)": "badge-purple",
"GIT (Zero Cross-Attention)": "badge-orange",
"Custom VLM (Shakespeare Prefix)": "badge-green",
}
MODEL_CA_TYPE = {
"BLIP (Multimodal Mixture Attention)": "Gated MED Cross-Attention",
"ViT-GPT2 (Standard Cross-Attention)": "Full Cross-Attention",
"GIT (Zero Cross-Attention)": "Self-Attention Prefix",
"Custom VLM (Shakespeare Prefix)": "Linear Bridge Prefix",
}
WEIGHT_TAG_CLASS = {"base": "wt-base", "best": "wt-best", "latest": "wt-latest"}
WEIGHT_LABEL = {"base": "Base", "best": "Best", "latest": "Latest"}
DEFAULT_OUTPUT_ROOT = "./outputs"
DEFAULT_SHAKESPEARE_FILE = "./input.txt"
DEFAULT_SHAKESPEARE_WEIGHTS = "./shakespeare_transformer.pt"
WEIGHTS_REPO_ID = os.getenv("WEIGHTS_REPO_ID", "griddev/vlm-caption-weights")
WEIGHTS_CACHE_DIR = os.getenv("WEIGHTS_CACHE_DIR", "./weights_bundle")
TASK1_DIR = os.path.join("task", "task_01")
TASK1_RESULTS_DIR = os.path.join(TASK1_DIR, "results")
TASK2_DIR = os.path.join("task", "task_02")
TASK2_RESULTS_DIR = os.path.join(TASK2_DIR, "results")
TASK3_DIR = os.path.join("task", "task_03")
TASK3_RESULTS_DIR = os.path.join(TASK3_DIR, "results")
TASK4_DIR = os.path.join("task", "task_04")
TASK4_RESULTS_DIR = os.path.join(TASK4_DIR, "results")
TASK5_DIR = os.path.join("task", "task_05")
TASK5_RESULTS_DIR = os.path.join(TASK5_DIR, "results")
MODEL_DIR = {
"BLIP (Multimodal Mixture Attention)": "blip",
"ViT-GPT2 (Standard Cross-Attention)": "vit_gpt2",
"GIT (Zero Cross-Attention)": "git",
"Custom VLM (Shakespeare Prefix)": "custom_vlm",
}
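# Architectures demoed with pretrained weights only; the sidebar hides their
# fine-tuned checkpoints even if some exist in the weights repo.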
DISABLE_FINETUNE_FOR = {"vit_gpt2", "git"}
OUTPUT_ROOT = DEFAULT_OUTPUT_ROOT
@st.cache_resource(show_spinner=False)
def _download_weights(
need_outputs: bool,
need_shakespeare: bool,
output_model_dir: str | None = None,
) -> str:
from huggingface_hub import snapshot_download
allow_patterns = []
if need_outputs:
if output_model_dir:
allow_patterns += [
f"outputs/{output_model_dir}/*",
f"outputs/{output_model_dir}/**/*",
]
else:
allow_patterns += ["outputs/*", "outputs/**/*"]
if need_shakespeare:
allow_patterns += ["input.txt", "shakespeare_transformer.pt"]
if not allow_patterns:
return WEIGHTS_CACHE_DIR
return snapshot_download(
repo_id=WEIGHTS_REPO_ID,
repo_type="model",
local_dir=WEIGHTS_CACHE_DIR,
local_dir_use_symlinks=False,
allow_patterns=allow_patterns,
)
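# Example: _download_weights(True, False, output_model_dir="blip") narrows the
# snapshot to "outputs/blip/*" and "outputs/blip/**/*", so selecting one
# model's fine-tuned weights never pulls the entire repo.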
@st.cache_resource(show_spinner=False)
def _download_model_outputs(model_dir: str) -> str:
return _download_weights(
need_outputs=True,
need_shakespeare=False,
output_model_dir=model_dir,
)
def _ensure_model_outputs_available(model_dir: str) -> None:
# Intentionally no eager snapshot download here.
# We only fetch checkpoints when a user explicitly selects a fine-tuned weight.
_ = model_dir
@st.cache_data(show_spinner=False, ttl=900)
def _weights_repo_file_set() -> set[str]:
from huggingface_hub import HfApi
api = HfApi()
try:
files = api.list_repo_files(repo_id=WEIGHTS_REPO_ID, repo_type="model")
return set(files)
except Exception:
return set()
def _remote_has_finetuned(model_dir: str, subdir: str) -> bool:
files = _weights_repo_file_set()
if not files:
return False
prefix = f"outputs/{model_dir}/{subdir}/"
return any(path.startswith(prefix) for path in files)
def _resolve_weight_paths(
need_outputs: bool,
need_shakespeare: bool,
output_model_dir: str | None = None,
):
output_root = DEFAULT_OUTPUT_ROOT
shakespeare_file = DEFAULT_SHAKESPEARE_FILE
shakespeare_weights = DEFAULT_SHAKESPEARE_WEIGHTS
if need_outputs:
if output_model_dir:
local_model = os.path.join(output_root, output_model_dir)
cache_model = os.path.join(WEIGHTS_CACHE_DIR, "outputs", output_model_dir)
have_outputs = (
(os.path.isdir(local_model) and len(os.listdir(local_model)) > 0)
or (os.path.isdir(cache_model) and len(os.listdir(cache_model)) > 0)
)
else:
have_outputs = os.path.isdir(output_root) and len(os.listdir(output_root)) > 0
else:
have_outputs = True
have_shakespeare = (
os.path.exists(shakespeare_file) and os.path.exists(shakespeare_weights)
)
if (not need_outputs or have_outputs) and (not need_shakespeare or have_shakespeare):
return output_root, shakespeare_file, shakespeare_weights
try:
cache_dir = _download_weights(
need_outputs,
need_shakespeare,
output_model_dir=output_model_dir,
)
candidate_output_root = os.path.join(cache_dir, "outputs")
candidate_shakespeare_file = os.path.join(cache_dir, "input.txt")
candidate_shakespeare_weights = os.path.join(
cache_dir, "shakespeare_transformer.pt"
)
if os.path.isdir(candidate_output_root):
output_root = candidate_output_root
if os.path.exists(candidate_shakespeare_file):
shakespeare_file = candidate_shakespeare_file
if os.path.exists(candidate_shakespeare_weights):
shakespeare_weights = candidate_shakespeare_weights
except Exception as e:
print(f"⚠️ Could not download fine-tuned weights from {WEIGHTS_REPO_ID}: {e}")
return output_root, shakespeare_file, shakespeare_weights
# ─────────────────────────────────────────────────────────────────────────────
# Device
# ─────────────────────────────────────────────────────────────────────────────
def get_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
# ─────────────────────────────────────────────────────────────────────────────
# Weight Loading Helpers
# ─────────────────────────────────────────────────────────────────────────────
def _has_finetuned(model_dir, subdir):
"""Check if a fine-tuned checkpoint exists for a given model + subdir."""
candidates = [
os.path.join(DEFAULT_OUTPUT_ROOT, model_dir, subdir),
os.path.join(WEIGHTS_CACHE_DIR, "outputs", model_dir, subdir),
]
for path in candidates:
if os.path.isdir(path) and len(os.listdir(path)) > 0:
return True
return _remote_has_finetuned(model_dir, subdir)
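# _has_finetuned checks local ./outputs first, then the ./weights_bundle cache,
# and finally the remote repo file listing (a cheap HfApi call, not a download).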
def _ckpt_path(output_root, model_dir, subdir):
return os.path.join(output_root, model_dir, subdir)
def _resolve_weight_source_for_model(model_name, requested_source):
if requested_source == "base":
return requested_source, None
model_dir = MODEL_DIR.get(model_name)
if not model_dir:
return requested_source, None
if model_dir in DISABLE_FINETUNE_FOR:
short_name = MODEL_SHORT.get(model_name, model_name)
return "base", f"{short_name} uses base weights only."
if _has_finetuned(model_dir, requested_source):
return requested_source, None
_resolve_weight_paths(
need_outputs=True,
need_shakespeare=(model_dir == "custom_vlm"),
output_model_dir=model_dir,
)
if _has_finetuned(model_dir, requested_source):
return requested_source, None
short_name = MODEL_SHORT.get(model_name, model_name)
return "base", f"{short_name} has no '{requested_source}' weights. Using base."
def _finetuned_available_for_model(model_name, requested_source):
if requested_source == "base":
return True
model_dir = MODEL_DIR.get(model_name)
if not model_dir or model_dir in DISABLE_FINETUNE_FOR:
return False
return _has_finetuned(model_dir, requested_source)
# ─────────────────────────────────────────────────────────────────────────────
# Cached Model Loaders (with weight_source support)
# ─────────────────────────────────────────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def load_blip(weight_source="base"):
from transformers import BlipProcessor, BlipForConditionalGeneration
device = get_device()
processor = BlipProcessor.from_pretrained(
"Salesforce/blip-image-captioning-base", use_fast=True)
model = BlipForConditionalGeneration.from_pretrained(
"Salesforce/blip-image-captioning-base")
if weight_source != "base":
output_root, _, _ = _resolve_weight_paths(
need_outputs=True,
need_shakespeare=False,
output_model_dir="blip",
)
ckpt = _ckpt_path(output_root, "blip", weight_source)
if os.path.isdir(ckpt) and os.listdir(ckpt):
try:
loaded = BlipForConditionalGeneration.from_pretrained(ckpt)
model.load_state_dict(loaded.state_dict())
del loaded
except Exception as e:
print(f"⚠️ Could not load BLIP {weight_source} weights: {e}")
model.to(device).eval()
return processor, model, device
@st.cache_resource(show_spinner=False)
def load_vit_gpt2(weight_source="base"):
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
device = get_device()
model_id = "nlpconnect/vit-gpt2-image-captioning"
processor = ViTImageProcessor.from_pretrained(model_id, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = VisionEncoderDecoderModel.from_pretrained(model_id)
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
if weight_source != "base":
output_root, _, _ = _resolve_weight_paths(
need_outputs=True,
need_shakespeare=False,
output_model_dir="vit_gpt2",
)
ckpt = _ckpt_path(output_root, "vit_gpt2", weight_source)
if os.path.isdir(ckpt) and os.listdir(ckpt):
try:
loaded = VisionEncoderDecoderModel.from_pretrained(ckpt)
model.load_state_dict(loaded.state_dict())
del loaded
except Exception as e:
print(f"⚠️ Could not load ViT-GPT2 {weight_source} weights: {e}")
model.to(device).eval()
return processor, tokenizer, model, device
@st.cache_resource(show_spinner=False)
def load_git(weight_source="base"):
from transformers import AutoProcessor, AutoModelForCausalLM
device = get_device()
model_id = "microsoft/git-base-coco"
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_id)
if weight_source != "base":
output_root, _, _ = _resolve_weight_paths(
need_outputs=True,
need_shakespeare=False,
output_model_dir="git",
)
ckpt = _ckpt_path(output_root, "git", weight_source)
if os.path.isdir(ckpt) and os.listdir(ckpt):
try:
loaded = AutoModelForCausalLM.from_pretrained(ckpt)
model.load_state_dict(loaded.state_dict())
del loaded
except Exception as e:
print(f"⚠️ Could not load GIT {weight_source} weights: {e}")
model.to(device).eval()
return processor, model, device
@st.cache_resource(show_spinner=False)
def load_custom_vlm(weight_source="base"):
from models.custom_vlm import CustomVLM, build_char_vocab
from config import CFG
device = get_device()
cfg = CFG()
output_root, shakespeare_file, shakespeare_weights = _resolve_weight_paths(
need_outputs=(weight_source != "base"),
need_shakespeare=True,
output_model_dir="custom_vlm" if weight_source != "base" else None,
)
cfg.output_root = output_root
cfg.shakespeare_file = shakespeare_file
cfg.shakespeare_weights_path = shakespeare_weights
if not os.path.exists(cfg.shakespeare_file):
return None, None, None, None, device
@st.cache_data(show_spinner=False)
def _load_char_vocab(text_path: str):
with open(text_path, "r", encoding="utf-8") as f:
text = f.read()
return build_char_vocab(text)
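    # A nested cache_data function keyed on the file path: the Shakespeare
    # corpus is read and the char vocab rebuilt only when the path changes,
    # not on every Streamlit rerun.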
_, char_to_idx, idx_to_char, vocab_size = _load_char_vocab(cfg.shakespeare_file)
model = CustomVLM(
vocab_size=vocab_size,
text_embed_dim=cfg.text_embed_dim,
n_heads=cfg.n_heads,
n_layers=cfg.n_layers,
block_size=cfg.block_size,
dropout=cfg.dropout,
)
# Always load Shakespeare weights first
shakes_path = getattr(cfg, "shakespeare_weights_path", "./shakespeare_transformer.pt")
if os.path.exists(shakes_path):
model.load_shakespeare_weights(shakes_path)
# Then load fine-tuned checkpoint if requested
if weight_source != "base":
ckpt_path = os.path.join(cfg.output_root, "custom_vlm", weight_source, "custom_vlm.pt")
if os.path.exists(ckpt_path):
state = torch.load(ckpt_path, map_location="cpu")
own_state = model.state_dict()
filtered = {k: v for k, v in state["model_state"].items()
if k in own_state and own_state[k].shape == v.shape}
model.load_state_dict(filtered, strict=False)
else:
# Even for base, try loading best weights as fallback
for subdir in ["best", "latest"]:
candidate = os.path.join(cfg.output_root, "custom_vlm", subdir, "custom_vlm.pt")
if os.path.exists(candidate):
state = torch.load(candidate, map_location="cpu")
own_state = model.state_dict()
filtered = {k: v for k, v in state["model_state"].items()
if k in own_state and own_state[k].shape == v.shape}
model.load_state_dict(filtered, strict=False)
break
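        # Without any trained checkpoint the Linear(768->384) bridge stays
        # randomly initialized (load_shakespeare_weights restores the text
        # decoder only), so captions would likely be Shakespeare-flavored
        # noise; hence the best/latest fallback even in "base" mode.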
model.to(device).eval()
return model, char_to_idx, idx_to_char, vocab_size, device
@st.cache_resource(show_spinner=False)
def load_toxicity_filter():
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tox_id = "unitary/toxic-bert"
tok = AutoTokenizer.from_pretrained(tox_id)
mdl = AutoModelForSequenceClassification.from_pretrained(tox_id)
mdl.eval()
return tok, mdl
@st.cache_resource(show_spinner=False)
def load_blip_attention_model(weight_source="base"):
from transformers import BlipForConditionalGeneration, BlipProcessor
device = get_device()
processor = BlipProcessor.from_pretrained(
"Salesforce/blip-image-captioning-base", use_fast=True
)
model = BlipForConditionalGeneration.from_pretrained(
"Salesforce/blip-image-captioning-base"
)
if weight_source != "base":
output_root, _, _ = _resolve_weight_paths(
need_outputs=True,
need_shakespeare=False,
output_model_dir="blip",
)
ckpt = _ckpt_path(output_root, "blip", weight_source)
if os.path.isdir(ckpt) and os.listdir(ckpt):
loaded = BlipForConditionalGeneration.from_pretrained(ckpt)
model.load_state_dict(loaded.state_dict(), strict=False)
del loaded
try:
model.gradient_checkpointing_disable()
except Exception:
pass
model.config.use_cache = False
model.to(device).eval()
return processor, model, device
@st.cache_resource(show_spinner=False)
def load_alignment_detector():
from models.attention_flow import load_owlvit_detector
return load_owlvit_detector(get_device())
@st.cache_data(show_spinner=False)
def load_task3_precomputed_results():
results_path = os.path.join(TASK3_RESULTS_DIR, "ablation_results.json")
if os.path.exists(results_path):
with open(results_path, "r", encoding="utf-8") as handle:
return json.load(handle)
from task.task_03.step3_run_ablation import PRECOMPUTED_RESULTS
return PRECOMPUTED_RESULTS
@st.cache_data(show_spinner=False)
def load_task3_demo_bundle():
results = load_task3_precomputed_results()
figure_paths = {
"heatmap": os.path.join(TASK3_RESULTS_DIR, "cider_heatmap.png"),
"latency": os.path.join(TASK3_RESULTS_DIR, "latency_barchart.png"),
"scatter": os.path.join(TASK3_RESULTS_DIR, "quality_speed_scatter.png"),
}
findings = {}
findings_path = os.path.join(TASK3_RESULTS_DIR, "findings.md")
if os.path.exists(findings_path):
findings = {"findings_path": findings_path}
return results, figure_paths, findings
@st.cache_data(show_spinner=False)
def load_task1_demo_bundle():
def _read_json(path, default):
if os.path.exists(path):
with open(path, "r", encoding="utf-8") as handle:
return json.load(handle)
return default
training_log = _read_json(
os.path.join(TASK1_RESULTS_DIR, "training_log.json"),
{},
)
onnx_meta = _read_json(
os.path.join(TASK1_RESULTS_DIR, "onnx_export_meta.json"),
{},
)
coreml_meta = _read_json(
os.path.join(TASK1_RESULTS_DIR, "coreml_conversion_meta.json"),
{},
)
benchmark_results = _read_json(
os.path.join(TASK1_RESULTS_DIR, "benchmark_results.json"),
{},
)
figure_paths = {
"model_size": os.path.join(TASK1_RESULTS_DIR, "model_size_comparison.png"),
"latency": os.path.join(TASK1_RESULTS_DIR, "latency_comparison.png"),
"training_curve": os.path.join(TASK1_RESULTS_DIR, "training_curve.png"),
"bleu4": os.path.join(TASK1_RESULTS_DIR, "bleu4_comparison.png"),
}
findings_path = os.path.join(TASK1_RESULTS_DIR, "findings.md")
findings_md = ""
if os.path.exists(findings_path):
with open(findings_path, "r", encoding="utf-8") as handle:
findings_md = handle.read()
return {
"training_log": training_log,
"onnx_meta": onnx_meta,
"coreml_meta": coreml_meta,
"benchmark_results": benchmark_results,
"figure_paths": figure_paths,
"findings_path": findings_path,
"findings_md": findings_md,
"run_dir": TASK1_RESULTS_DIR,
"source": "precomputed",
}
@st.cache_data(show_spinner=False)
def load_task4_demo_bundle():
def _read_json(path, default):
if os.path.exists(path):
with open(path, "r", encoding="utf-8") as handle:
return json.load(handle)
return default
diversity_records = _read_json(
os.path.join(TASK4_RESULTS_DIR, "diversity_results.json"),
[],
)
steering_results = _read_json(
os.path.join(TASK4_RESULTS_DIR, "steering_results.json"),
[],
)
vectors_meta = _read_json(
os.path.join(TASK4_RESULTS_DIR, "steering_vectors_meta.json"),
{},
)
figure_paths = {
"diversity_histogram": os.path.join(TASK4_RESULTS_DIR, "diversity_histogram.png"),
"diverse_vs_repetitive": os.path.join(TASK4_RESULTS_DIR, "diverse_vs_repetitive.png"),
"steering_lambda_sweep": os.path.join(TASK4_RESULTS_DIR, "steering_lambda_sweep.png"),
}
findings_path = os.path.join(TASK4_RESULTS_DIR, "findings.md")
findings_md = ""
if os.path.exists(findings_path):
with open(findings_path, "r", encoding="utf-8") as handle:
findings_md = handle.read()
return {
"diversity_records": diversity_records,
"steering_results": steering_results,
"vectors_meta": vectors_meta,
"figure_paths": figure_paths,
"findings_path": findings_path,
"findings_md": findings_md,
"run_dir": TASK4_RESULTS_DIR,
"source": "precomputed",
}
@st.cache_data(show_spinner=False)
def load_task5_demo_bundle():
def _read_json(path, default):
if os.path.exists(path):
with open(path, "r", encoding="utf-8") as handle:
return json.load(handle)
return default
toxicity_scores = _read_json(
os.path.join(TASK5_RESULTS_DIR, "toxicity_scores.json"),
[],
)
bias_audit = _read_json(
os.path.join(TASK5_RESULTS_DIR, "bias_audit.json"),
{"records": [], "freq_table": {}},
)
mitigation_results = _read_json(
os.path.join(TASK5_RESULTS_DIR, "mitigation_results.json"),
[],
)
figure_paths = {
"toxicity_distribution": os.path.join(TASK5_RESULTS_DIR, "toxicity_distribution.png"),
"bias_heatmap": os.path.join(TASK5_RESULTS_DIR, "bias_heatmap.png"),
"before_after": os.path.join(TASK5_RESULTS_DIR, "before_after_comparison.png"),
}
report_path = os.path.join(TASK5_RESULTS_DIR, "fairness_report.md")
report_md = ""
if os.path.exists(report_path):
with open(report_path, "r", encoding="utf-8") as handle:
report_md = handle.read()
return {
"toxicity_scores": toxicity_scores,
"bias_audit": bias_audit,
"mitigation_results": mitigation_results,
"figure_paths": figure_paths,
"report_path": report_path,
"report_md": report_md,
"run_dir": TASK5_RESULTS_DIR,
"source": "precomputed",
}
# ─────────────────────────────────────────────────────────────────────────────
# Toxicity Check
# ─────────────────────────────────────────────────────────────────────────────
def is_toxic(text, tox_tok, tox_mdl):
inputs = tox_tok(text, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
outputs = tox_mdl(**inputs)
scores = torch.sigmoid(outputs.logits).squeeze()
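    # unitary/toxic-bert is a multi-label classifier (toxic, severe_toxic,
    # obscene, threat, insult, identity_hate): each logit gets an independent
    # sigmoid, and the text is flagged if any label exceeds 0.5.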
if isinstance(scores, torch.Tensor) and scores.dim() > 0:
return (scores > 0.5).any().item()
return scores.item() > 0.5
# ─────────────────────────────────────────────────────────────────────────────
# Ablation Mask Builder
# ─────────────────────────────────────────────────────────────────────────────
def build_mask_for_mode(ui_mode, device):
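    """Return (encoder_attention_mask, use_global_pool) for the chosen mode.

    The mask has shape (1, 197): index 0 is the ViT CLS token and indices
    1..196 are the 14x14 spatial patches of ViT-B/16 at 224x224 (1 = attend,
    0 = drop). "Squint" returns (None, True) and is handled by pooling the
    encoder states instead of masking them.
    """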
N = 197
if ui_mode == "Baseline (Full Attention)":
return torch.ones(1, N, dtype=torch.long, device=device), False
elif ui_mode == "Random Patch Dropout (50%)":
mask = torch.ones(1, N, dtype=torch.long, device=device)
spatial_indices = torch.randperm(196)[:98] + 1
mask[0, spatial_indices] = 0
return mask, False
elif ui_mode == "Center-Focus (Inner 8×8)":
GRID, INNER, offset = 14, 8, 3
keep = set()
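        # 14 - 8 = 6, so offset 3 centers an 8x8 window in the 14x14 grid;
        # the +1 below skips the CLS slot at sequence index 0.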
for row in range(offset, offset + INNER):
for col in range(offset, offset + INNER):
keep.add(row * GRID + col + 1)
mask = torch.zeros(1, N, dtype=torch.long, device=device)
mask[0, 0] = 1
        for idx in keep:
            if idx < N:
                mask[0, idx] = 1
return mask, False
elif ui_mode == "Squint (Global Pool)":
return None, True
return torch.ones(1, N, dtype=torch.long, device=device), False
# ─────────────────────────────────────────────────────────────────────────────
# Caption Generation (single model)
# ─────────────────────────────────────────────────────────────────────────────
def generate_caption(model_name, gen_mode, image_pil,
num_beams=4, max_new_tokens=50, length_penalty=1.0,
weight_source="base"):
device = get_device()
with torch.no_grad():
if model_name == "BLIP (Multimodal Mixture Attention)":
processor, model, device = load_blip(weight_source)
inputs = processor(images=image_pil, return_tensors="pt").to(device)
mask, is_squint = build_mask_for_mode(gen_mode, device)
if is_squint:
vision_out = model.vision_model(pixel_values=inputs["pixel_values"])
hs = vision_out.last_hidden_state
pooled = torch.cat([hs[:, :1, :], hs[:, 1:, :].mean(dim=1, keepdim=True)], dim=1)
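                # hs is (1, 197, 768); "squint" keeps the CLS state plus the
                # mean of the 196 patch states, a (1, 2, 768) global summary
                # for the decoder to cross-attend to.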
captions = generate_with_mask(
model, processor, device=device,
encoder_hidden_states=pooled,
encoder_attention_mask=torch.ones(1, 2, dtype=torch.long, device=device),
max_new_tokens=max_new_tokens, num_beams=num_beams,
)
else:
captions = generate_with_mask(
model, processor, device=device,
pixel_values=inputs["pixel_values"],
encoder_attention_mask=mask,
max_new_tokens=max_new_tokens, num_beams=num_beams,
)
caption = captions[0]
elif model_name == "ViT-GPT2 (Standard Cross-Attention)":
from transformers.modeling_outputs import BaseModelOutput
processor, tokenizer, model, device = load_vit_gpt2(weight_source)
inputs = processor(images=image_pil, return_tensors="pt").to(device)
mask, is_squint = build_mask_for_mode(gen_mode, device)
if is_squint:
enc_out = model.encoder(pixel_values=inputs["pixel_values"])
hs = enc_out.last_hidden_state
pooled = torch.cat([hs[:, :1, :], hs[:, 1:, :].mean(dim=1, keepdim=True)], dim=1)
out = model.generate(
encoder_outputs=BaseModelOutput(last_hidden_state=pooled),
decoder_start_token_id=tokenizer.bos_token_id,
max_new_tokens=max_new_tokens, num_beams=num_beams,
length_penalty=length_penalty,
)
else:
out = model.generate(
**inputs,
attention_mask=mask,
max_new_tokens=max_new_tokens, num_beams=num_beams,
length_penalty=length_penalty,
)
caption = tokenizer.decode(out[0], skip_special_tokens=True)
elif model_name == "GIT (Zero Cross-Attention)":
processor, model, device = load_git(weight_source)
inputs = processor(images=image_pil, return_tensors="pt").to(device)
out = model.generate(
**inputs, max_new_tokens=max_new_tokens,
num_beams=num_beams, length_penalty=length_penalty,
)
caption = processor.batch_decode(out, skip_special_tokens=True)[0]
elif model_name == "Custom VLM (Shakespeare Prefix)":
vlm, char_to_idx, idx_to_char, vocab_size, device = load_custom_vlm(weight_source)
if vlm is None:
return "[Custom VLM not available — train first with: python train.py --model custom]"
from transformers import ViTImageProcessor
image_processor = ViTImageProcessor.from_pretrained(
"google/vit-base-patch16-224-in21k", use_fast=True)
pv = image_processor(images=image_pil, return_tensors="pt")["pixel_values"].to(device)
if num_beams > 1:
caption = vlm.generate_beam(pv, char_to_idx, idx_to_char,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
length_penalty=length_penalty)
else:
caption = vlm.generate(pv, char_to_idx, idx_to_char,
max_new_tokens=max_new_tokens)
else:
caption = "Unknown model."
return caption.strip()
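# Example (illustrative, path hypothetical) of calling the generator directly:
#   img = Image.open("samples/dog.jpg").convert("RGB")
#   generate_caption("GIT (Zero Cross-Attention)", "Baseline (Full Attention)",
#                    img, num_beams=4, max_new_tokens=50, weight_source="base")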
# ─────────────────────────────────────────────────────────────────────────────
# Sidebar
# ─────────────────────────────────────────────────────────────────────────────
with st.sidebar:
st.markdown("### 🔬 VLM Caption Lab")
st.markdown("---")
# ── Architecture Selector ─────────────────────────────────────────────────
selected_model = st.selectbox("**Architecture**", MODEL_KEYS, index=0)
# ── Weight Source ─────────────────────────────────────────────────────────
model_dir = MODEL_DIR.get(selected_model)
if model_dir and model_dir not in DISABLE_FINETUNE_FOR:
_ensure_model_outputs_available(model_dir)
weight_options = {"🔵 Base (Pretrained)": "base"}
if model_dir and model_dir not in DISABLE_FINETUNE_FOR and _has_finetuned(model_dir, "best"):
weight_options["🟢 Fine-tuned (Best)"] = "best"
if model_dir and model_dir not in DISABLE_FINETUNE_FOR and _has_finetuned(model_dir, "latest"):
weight_options["🟡 Fine-tuned (Latest)"] = "latest"
weight_choice = st.radio(
"**Weight Source**", list(weight_options.keys()), index=0,
help="Base = HuggingFace pretrained. Best/Latest = your fine-tuned checkpoints."
)
weight_source = weight_options[weight_choice]
if model_dir in DISABLE_FINETUNE_FOR:
st.caption("Fine-tuned weights are disabled for this model.")
elif len(weight_options) == 1:
st.caption("Fine-tuned weights not available for this model.")
st.markdown("---")
if selected_model in ("BLIP (Multimodal Mixture Attention)",
"ViT-GPT2 (Standard Cross-Attention)"):
mode_options = [
"Baseline (Full Attention)",
"Random Patch Dropout (50%)",
"Center-Focus (Inner 8×8)",
"Squint (Global Pool)",
]
elif selected_model == "Custom VLM (Shakespeare Prefix)":
mode_options = ["Shakespeare Prefix"]
else:
mode_options = ["Baseline (Full Attention)"]
selected_mode = st.selectbox("**Generation Mode**", mode_options, index=0)
    st.markdown(ARCH_INFO[selected_model], unsafe_allow_html=True)
st.markdown("---")
# ── Advanced Controls ─────────────────────────────────────────────────────
with st.expander("⚙️ Advanced Controls", expanded=False):
num_beams = st.select_slider(
"Beam Size", options=[1, 2, 3, 4, 5, 8, 10], value=10,
help="Number of beams in beam search. Higher = better but slower."
)
length_penalty = st.select_slider(
"Length Penalty", options=[0.8, 0.9, 1.0, 1.1, 1.2], value=1.2,
help=">1 favors longer captions, <1 favors shorter."
)
max_new_tokens = st.select_slider(
"Max Tokens", options=[20, 30, 50, 80, 100], value=50,
help="Maximum number of tokens to generate."
)
st.caption(
f"Config: `beams={num_beams}, len_pen={length_penalty}, max_tok={max_new_tokens}`"
)
st.markdown("---")
st.markdown("Toxicity filter: unitary/toxic-bert",
unsafe_allow_html=True)
# ─────────────────────────────────────────────────────────────────────────────
# Main Header
# ─────────────────────────────────────────────────────────────────────────────
st.markdown("VLM Caption Lab 🔬
", unsafe_allow_html=True)
st.markdown(
"Compare cross-attention strategies: BLIP · ViT-GPT2 · GIT · "
"Visual Prefix-Tuning. Upload, pick a mode, and explore different architectures.
",
unsafe_allow_html=True,
)
# ─────────────────────────────────────────────────────────────────────────────
# Helper — render a single caption card
# ─────────────────────────────────────────────────────────────────────────────
def render_caption_card(model_name, caption, weight_src, num_beams, length_penalty,
max_new_tokens, container, card_class="result-card",
caption_class="caption-text", show_params=True):
badge_cls = MODEL_BADGE.get(model_name, "badge-blue")
wt_cls = WEIGHT_TAG_CLASS.get(weight_src, "wt-base")
wt_label = WEIGHT_LABEL.get(weight_src, weight_src)
short = MODEL_SHORT.get(model_name, model_name)
ca = MODEL_CA_TYPE.get(model_name, "")
params_html = ""
if show_params:
params_html = (f"
beams={num_beams} · "
f"len_pen={length_penalty} · max_tok={max_new_tokens}")
container.markdown(
f""
f"
{short}"
f"
{wt_label}"
f"
{ca}"
f"
\"{caption}\"
"
f"{params_html}"
f"
",
unsafe_allow_html=True,
)
# Toxicity check
try:
tox_tok, tox_mdl = load_toxicity_filter()
toxic = is_toxic(caption, tox_tok, tox_mdl)
except Exception:
toxic = False
if toxic:
container.error("⚠️ Flagged by Toxic-BERT")
else:
container.caption("✅ Passed toxicity check")
# ─────────────────────────────────────────────────────────────────────────────
# Tabs
# ─────────────────────────────────────────────────────────────────────────────
tab_caption, tab_compare, tab_attention, tab_task1, tab_task3, tab_task4, tab_task5, tab_results = st.tabs([
"🖼️ Caption",
"🔀 Compare All Models",
"🧭 Task 02 · Word Focus Map",
"📦 Task 01 · Optimization",
"⚖️ Task 03 · Decoding Trade-offs",
"🧪 Task 04 · Diversity & Steering",
"🛡️ Task 05 · Safety & Bias",
"📊 Experiment Results",
])
# ═══════════════════════════════════════════════════════════════════════════
# Tab 1 — Single Model Caption
# ═══════════════════════════════════════════════════════════════════════════
with tab_caption:
col_upload, col_result = st.columns([1, 1.3], gap="large")
with col_upload:
uploaded_file = st.file_uploader(
"Upload an image", type=["jpg", "jpeg", "png", "webp"],
label_visibility="visible",
key="caption_uploader",
)
if uploaded_file:
image = Image.open(uploaded_file).convert("RGB")
st.image(image, caption="Uploaded Image", use_column_width=True)
generate_btn = st.button("✨ Generate Caption",
disabled=(uploaded_file is None),
key="caption_btn")
with col_result:
if uploaded_file and generate_btn:
if not _finetuned_available_for_model(selected_model, weight_source):
st.error(
f"{MODEL_SHORT[selected_model]} does not have '{weight_source}' weights."
)
caption = None
else:
with st.spinner(
f"Loading {MODEL_SHORT[selected_model]} ({weight_source}) + generating…"
):
try:
caption = generate_caption(
selected_model, selected_mode, image,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
length_penalty=length_penalty,
weight_source=weight_source,
)
except Exception as e:
st.error(f"Generation error: {e}")
caption = None
if caption:
render_caption_card(
selected_model, caption, weight_source,
num_beams, length_penalty, max_new_tokens,
container=st,
)
elif not uploaded_file:
        st.markdown("⬅️ Upload an image to get started")
# ═══════════════════════════════════════════════════════════════════════════
# Tab 2 — Compare All Models
# ═══════════════════════════════════════════════════════════════════════════
with tab_compare:
st.markdown("### 🔀 Multi-Model Comparison")
st.caption(
"Upload one image and generate captions from **all 4 architectures** simultaneously, "
"using the same decoding parameters. Perfect for report screenshots."
)
# Config banner
wt_label = WEIGHT_LABEL.get(weight_source, weight_source)
    st.markdown(
        f"⚙️ Config: beams={num_beams} · len_pen={length_penalty} · "
        f"max_tok={max_new_tokens} · weights={wt_label}"
    )
is_common_mode = selected_mode in ["Baseline (Full Attention)", "Shakespeare Prefix"]
if not is_common_mode:
st.warning(
f"⚠️ **Warning:** You have selected **{selected_mode}**.\n\n"
"This generation mode is an ablation experiment and is not supported uniformly by all models. "
"GIT and Custom VLM lack standard cross-attention and cannot process these masks.\n\n"
"👉 **To compare all 4 architectures fairly, please change the Generation Mode in the sidebar to `Baseline (Full Attention)`.**"
)
col_img, col_ctrl = st.columns([1, 1])
with col_img:
compare_file = st.file_uploader(
"Upload an image for comparison", type=["jpg", "jpeg", "png", "webp"],
key="compare_uploader",
)
with col_ctrl:
if compare_file:
compare_image = Image.open(compare_file).convert("RGB")
st.image(compare_image, caption="Comparison Image", use_column_width=True)
compare_btn = st.button("🚀 Compare All 4 Models",
disabled=(compare_file is None or not is_common_mode),
key="compare_btn")
if compare_file and compare_btn:
compare_image = Image.open(compare_file).convert("RGB")
resolved_sources = {}
for model_key in MODEL_KEYS:
resolved_sources[model_key] = weight_source
if weight_source != "base":
missing = [
MODEL_SHORT[m]
for m in MODEL_KEYS
if not _finetuned_available_for_model(m, weight_source)
]
if missing:
st.warning(
"Missing fine-tuned weights for: "
+ ", ".join(missing)
+ ". Marking those results as unavailable."
)
# Generate captions from all 4 models
results = {}
progress = st.progress(0, text="Starting comparison...")
for i, model_key in enumerate(MODEL_KEYS):
short = MODEL_SHORT[model_key]
        progress.progress(i / 4, text=f"Generating with {short}...")
# Apply selected mode to supported models, otherwise use appropriate fallback
if model_key == "Custom VLM (Shakespeare Prefix)":
mode = "Shakespeare Prefix"
elif model_key in ("BLIP (Multimodal Mixture Attention)", "ViT-GPT2 (Standard Cross-Attention)"):
if selected_mode in [
"Baseline (Full Attention)",
"Random Patch Dropout (50%)",
"Center-Focus (Inner 8×8)",
"Squint (Global Pool)"
]:
mode = selected_mode
else:
mode = "Baseline (Full Attention)"
else:
mode = "Baseline (Full Attention)"
if not _finetuned_available_for_model(model_key, weight_source):
results[model_key] = (
f"[Fine-tuned '{weight_source}' weights not available]"
if weight_source != "base"
else "[Not available]"
)
else:
try:
cap = generate_caption(
model_key, mode, compare_image,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
length_penalty=length_penalty,
weight_source=weight_source,
)
results[model_key] = cap
except Exception as e:
results[model_key] = f"[Error: {e}]"
progress.progress(1.0, text="✅ All models complete!")
# Render 2x2 grid
st.markdown("---")
row1_col1, row1_col2 = st.columns(2)
row2_col1, row2_col2 = st.columns(2)
grid = [(MODEL_KEYS[0], row1_col1), (MODEL_KEYS[1], row1_col2),
(MODEL_KEYS[2], row2_col1), (MODEL_KEYS[3], row2_col2)]
for model_key, col in grid:
cap = results.get(model_key, "[Not available]")
with col:
render_caption_card(
model_key, cap, resolved_sources.get(model_key, weight_source),
num_beams, length_penalty, max_new_tokens,
container=st,
card_class="compare-card",
caption_class="compare-caption",
show_params=False,
)
# Summary table
st.markdown("---")
st.markdown("#### 📋 Summary Table")
table_rows = []
for model_key in MODEL_KEYS:
short = MODEL_SHORT[model_key]
ca = MODEL_CA_TYPE[model_key]
cap = results.get(model_key, "–")
word_count = len(cap.split()) if cap and not cap.startswith("[") else 0
table_rows.append(f"| **{short}** | {ca} | {cap[:80]}{'…' if len(cap) > 80 else ''} | {word_count} |")
table_md = (
"| Architecture | Cross-Attention | Caption | Words |\n"
"|---|---|---|---|\n"
+ "\n".join(table_rows)
)
st.markdown(table_md)
st.caption(
f"Generated with: beams={num_beams}, len_pen={length_penalty}, "
f"max_tok={max_new_tokens}, weights={wt_label}"
)
# ═══════════════════════════════════════════════════════════════════════════
# Tab 3 — Word Focus Map (Task 2)
# ═══════════════════════════════════════════════════════════════════════════
with tab_attention:
st.markdown("### 🧭 Task 02 — Word Focus Map")
st.markdown("`Task 02: Attention Weight Visualization & Cross-Attention Rollout for Caption Generation`")
st.caption(
"Step-by-step cross-attention analysis with rollout across decoder layers, "
"2x5 heatmap grid, IoU grounding score, and caption-length summary."
)
attn_col_left, attn_col_right = st.columns([1, 1], gap="large")
with attn_col_left:
attn_file = st.file_uploader(
"Upload an image for attention analysis",
type=["jpg", "jpeg", "png", "webp"],
key="attention_uploader",
)
if attn_file:
attn_image = Image.open(attn_file).convert("RGB")
st.image(attn_image, caption="Attention Input Image", use_column_width=True)
with attn_col_right:
_ensure_model_outputs_available("blip")
attn_weight_options = {"Base (Pretrained)": "base"}
if _has_finetuned("blip", "best"):
attn_weight_options["Fine-tuned (Best)"] = "best"
if _has_finetuned("blip", "latest"):
attn_weight_options["Fine-tuned (Latest)"] = "latest"
attn_weight_choice = st.selectbox(
"BLIP Weight Source",
list(attn_weight_options.keys()),
index=0,
key="attn_weight_choice",
)
attn_weight_source = attn_weight_options[attn_weight_choice]
token_mode = st.radio(
"Token Source",
["Generated Caption", "Custom Text Prompt"],
horizontal=True,
key="attn_token_mode",
)
custom_text = ""
if token_mode == "Custom Text Prompt":
custom_text = st.text_input(
"Enter custom text/words for heatmap tracing",
value="a dog playing with a ball",
key="attn_custom_text",
)
max_attn_steps = st.slider(
"How many words to trace",
min_value=3,
max_value=12,
value=9,
help="One step = one word position in the generated/custom text (word 1, word 2, ...).",
key="attn_steps",
)
run_iou = st.toggle(
"Compute IoU Alignment with OWL-ViT (slower)",
value=True,
key="attn_iou_toggle",
)
run_attention_btn = st.button(
"Run Step-by-Step Attention Analysis",
disabled=(attn_file is None or (token_mode == "Custom Text Prompt" and not custom_text.strip())),
key="attn_run_btn",
)
if run_attention_btn and attn_file:
from models.attention_flow import (
build_attention_grid_figure,
decode_custom_text_with_flow,
decode_generated_caption_with_flow,
encode_image_for_flow,
grade_alignment_with_detector,
summarize_caption_alignment,
)
attn_image = Image.open(attn_file).convert("RGB")
iou_results = []
with st.status("Running attention pipeline...", expanded=True) as status:
st.write("Step 1/5: Loading BLIP model and selected weights")
attn_processor, attn_model, attn_device = load_blip_attention_model(attn_weight_source)
st.write("Step 2/5: Encoding image through ViT")
image_224, enc_hidden, enc_mask = encode_image_for_flow(
attn_model, attn_processor, attn_device, attn_image
)
st.write("Step 3/5: Extracting rollout heatmaps token-by-token")
if token_mode == "Custom Text Prompt":
tokens, heatmaps = decode_custom_text_with_flow(
attn_model,
attn_processor,
attn_device,
enc_hidden,
enc_mask,
custom_text,
max_tokens=max_attn_steps,
)
else:
tokens, heatmaps = decode_generated_caption_with_flow(
attn_model,
attn_processor,
attn_device,
enc_hidden,
enc_mask,
max_tokens=max_attn_steps,
)
st.write("Step 4/5: Building 2x5 attention grid")
fig_grid = build_attention_grid_figure(image_224, tokens, heatmaps, n_rows=2, n_cols=5)
if run_iou:
st.write("Step 5/5: Computing IoU alignment using OWL-ViT detections")
detector = load_alignment_detector()
iou_results = grade_alignment_with_detector(attn_image, tokens, heatmaps, detector)
else:
st.write("Step 5/5: IoU grading skipped by user")
status.update(label="Attention pipeline complete", state="complete", expanded=False)
st.pyplot(fig_grid, use_container_width=True)
caption_tokens = " ".join(tokens) if tokens else "[No tokens generated]"
st.markdown(f"**Decoded tokens:** `{caption_tokens}`")
summary = summarize_caption_alignment(iou_results, len(tokens))
st.markdown(
f"**Caption length:** `{summary['caption_length']}` | "
f"**Mean alignment IoU:** `{summary['mean_alignment_iou']:.4f}`"
)
if run_iou:
st.markdown("#### Word-level Alignment (IoU)")
if iou_results:
table_rows = [
{
"word": item["word"],
"position": item["position"],
"iou": round(item["iou"], 4),
"det_score": round(item["det_score"], 4),
"box": [int(x) for x in item["box"]],
}
for item in iou_results
]
st.dataframe(table_rows, use_container_width=True)
strong = [item["word"] for item in iou_results if item["iou"] >= 0.30]
weak = [item["word"] for item in iou_results if item["iou"] < 0.10]
if strong:
st.success("Strongly grounded words: " + ", ".join(strong))
if weak:
st.warning("Weakly grounded words: " + ", ".join(weak))
else:
st.info("No detectable object-word matches found for IoU grading on this run.")
if "alignment_history" not in st.session_state:
st.session_state["alignment_history"] = []
st.session_state["alignment_history"].append(
{
"caption_length": int(summary["caption_length"]),
"mean_alignment_iou": float(summary["mean_alignment_iou"]),
"mode": token_mode,
"weights": attn_weight_source,
}
)
st.markdown("#### Caption Length -> Mean Alignment IoU")
history = st.session_state["alignment_history"]
if history:
try:
import matplotlib.pyplot as plt
x_vals = [item["caption_length"] for item in history]
y_vals = [item["mean_alignment_iou"] for item in history]
fig_summary, ax_summary = plt.subplots(figsize=(6, 3.2))
ax_summary.scatter(x_vals, y_vals, color="#58a6ff", alpha=0.85)
if len(x_vals) > 1:
z = np.polyfit(x_vals, y_vals, 1)
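                    # Degree-1 least-squares fit; the dashed line is
                    # trend(x) = z[0] * x + z[1].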
trend = np.poly1d(z)
xs = sorted(x_vals)
ax_summary.plot(xs, [trend(v) for v in xs], linestyle="--", color="#ff7b72")
ax_summary.set_xlabel("Caption length")
ax_summary.set_ylabel("Mean IoU")
ax_summary.set_title("Alignment Trend")
ax_summary.grid(alpha=0.35, linestyle="--")
st.pyplot(fig_summary, use_container_width=True)
except Exception:
pass
st.dataframe(history[-20:], use_container_width=True)
# ═══════════════════════════════════════════════════════════════════════════
# Tab 4 — Task 01 On-Device Optimization (Static)
# ═══════════════════════════════════════════════════════════════════════════
with tab_task1:
st.markdown("### 📦 Task 01 — On-Device Optimization Lab")
st.markdown("`Task 01: End-to-End Optimization of BLIP for On-Device Inference`")
st.info("Static mode: showing precomputed artifacts for evaluator stability.")
task1_payload = load_task1_demo_bundle()
st.caption(
f"Result source: `{task1_payload.get('source', 'unknown')}` | "
f"Output folder: `{task1_payload.get('run_dir', TASK1_RESULTS_DIR)}`"
)
bench = task1_payload.get("benchmark_results", {})
fp32 = bench.get("pytorch_fp32", {})
coreml = bench.get("coreml_4bit", {})
if fp32 and coreml:
speedup = fp32.get("latency_per_100", 1.0) / max(coreml.get("latency_per_100", 0.01), 0.01)
size_reduction = (1 - coreml.get("model_size_mb", 1.0) / max(fp32.get("model_size_mb", 1.0), 1.0)) * 100
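        # latency_per_100 reads as seconds per 100 captions (see the "s" suffix
        # on the Latency/100 metric in the decoding tab); the max(..., 0.01)
        # guards keep the ratios finite when a benchmark entry is missing or zero.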
k1, k2, k3 = st.columns(3)
k1.metric("CoreML Speedup vs fp32", f"{speedup:.2f}x")
k2.metric("Model Size Reduction", f"{size_reduction:.1f}%")
k3.metric("BLEU-4 Drop", f"{(fp32.get('bleu4', 0.0) - coreml.get('bleu4', 0.0)):.4f}")
st.markdown("#### Benchmark Table")
rows = []
for key in ["pytorch_fp32", "pytorch_fp16_amp", "onnx_fp32", "coreml_4bit"]:
if key in bench and bench[key]:
row = dict(bench[key])
row["backend_key"] = key
rows.append(row)
if rows:
st.dataframe(rows, use_container_width=True)
else:
st.warning("No benchmark rows found in precomputed Task 01 results.")
st.markdown("#### Figures")
fig_paths = task1_payload.get("figure_paths", {})
f1, f2 = st.columns(2)
ms_path = fig_paths.get("model_size", os.path.join(task1_payload["run_dir"], "model_size_comparison.png"))
lat_path = fig_paths.get("latency", os.path.join(task1_payload["run_dir"], "latency_comparison.png"))
trn_path = fig_paths.get("training_curve", os.path.join(task1_payload["run_dir"], "training_curve.png"))
bleu_path = fig_paths.get("bleu4", os.path.join(task1_payload["run_dir"], "bleu4_comparison.png"))
if os.path.exists(ms_path):
f1.image(ms_path, caption="Model Size Comparison", use_column_width=True)
if os.path.exists(lat_path):
f2.image(lat_path, caption="Latency Comparison", use_column_width=True)
f3, f4 = st.columns(2)
if os.path.exists(trn_path):
f3.image(trn_path, caption="Training Curve", use_column_width=True)
if os.path.exists(bleu_path):
f4.image(bleu_path, caption="BLEU-4 + Memory", use_column_width=True)
if task1_payload.get("findings_md"):
with st.expander("Show Findings Report"):
st.markdown(task1_payload["findings_md"])
# ═══════════════════════════════════════════════════════════════════════════
# Tab 5 — Task 03 Decoding Trade-offs (Static)
# ═══════════════════════════════════════════════════════════════════════════
with tab_task3:
st.markdown("### ⚖️ Task 03 — Decoding Trade-offs Lab")
st.markdown("`Task 03: Beam Search & Length Penalty Ablation for Caption Quality Trade-offs`")
st.info("Static mode: showing precomputed ablation results.")
demo_results, demo_figures, _ = load_task3_demo_bundle()
task3_payload = {
"results": demo_results,
"figure_paths": demo_figures,
"run_dir": TASK3_RESULTS_DIR,
"source": "precomputed",
}
st.caption(
f"Result source: `{task3_payload.get('source', 'unknown')}` | "
f"Output folder: `{task3_payload['run_dir']}`"
)
all_results = task3_payload["results"]
sorted_results = sorted(all_results, key=lambda row: -row["cider"]) if all_results else []
if sorted_results:
beam_filter = st.multiselect(
"Filter Beam Sizes",
options=sorted({int(row["beam_size"]) for row in sorted_results}),
default=sorted({int(row["beam_size"]) for row in sorted_results}),
key="task3_beam_filter_static",
)
lp_filter = st.multiselect(
"Filter Length Penalties",
options=sorted({float(row["length_penalty"]) for row in sorted_results}),
default=sorted({float(row["length_penalty"]) for row in sorted_results}),
key="task3_lp_filter_static",
)
filtered = [
row for row in sorted_results
if int(row["beam_size"]) in beam_filter and float(row["length_penalty"]) in lp_filter
]
st.dataframe(filtered, use_container_width=True)
if filtered:
best = max(filtered, key=lambda row: row["cider"])
m1, m2, m3 = st.columns(3)
m1.metric("Best CIDEr", f"{best['cider']:.4f}")
m2.metric("Best Config", f"beam={best['beam_size']}, lp={best['length_penalty']}")
m3.metric("Latency/100", f"{best['latency_per_100']:.1f}s")
else:
st.warning("No precomputed Task 03 rows found.")
fig_paths = task3_payload.get("figure_paths", {})
c1, c2, c3 = st.columns(3)
heatmap_path = fig_paths.get("heatmap", os.path.join(task3_payload["run_dir"], "cider_heatmap.png"))
latency_path = fig_paths.get("latency", os.path.join(task3_payload["run_dir"], "latency_barchart.png"))
scatter_path = fig_paths.get("scatter", os.path.join(task3_payload["run_dir"], "quality_speed_scatter.png"))
if os.path.exists(heatmap_path):
c1.image(heatmap_path, caption="CIDEr Heatmap", use_column_width=True)
if os.path.exists(latency_path):
c2.image(latency_path, caption="Latency Bar Chart", use_column_width=True)
if os.path.exists(scatter_path):
c3.image(scatter_path, caption="Quality vs Speed", use_column_width=True)
report_path = os.path.join(task3_payload["run_dir"], "findings.md")
if os.path.exists(report_path):
with st.expander("Show Detailed Findings Report"):
with open(report_path, "r", encoding="utf-8") as handle:
st.markdown(handle.read())
# ═══════════════════════════════════════════════════════════════════════════
# Tab 6 — Task 04 Diversity + Steering (Static)
# ═══════════════════════════════════════════════════════════════════════════
with tab_task4:
st.markdown("### 🧪 Task 04 — Diversity & Steering Lab")
st.markdown("`Task 04: Caption Diversity Analysis & Concept Activation Vectors for Style Steering`")
st.info("Static mode: showing precomputed diversity and steering outputs.")
task4_payload = load_task4_demo_bundle()
records = task4_payload.get("diversity_records", [])
steering = task4_payload.get("steering_results", [])
st.caption(
f"Result source: `{task4_payload.get('source', 'unknown')}` | "
f"Output folder: `{task4_payload.get('run_dir', TASK4_RESULTS_DIR)}`"
)
if records:
n_total = len(records)
scores = [float(r.get("diversity_score", 0.0)) for r in records]
n_diverse = sum(1 for r in records if str(r.get("category", "")).lower() == "diverse")
n_repetitive = sum(1 for r in records if str(r.get("category", "")).lower() == "repetitive")
k1, k2, k3, k4 = st.columns(4)
k1.metric("Images Analysed", f"{n_total}")
k2.metric("Mean Diversity", f"{(sum(scores)/max(n_total, 1)):.4f}")
k3.metric("Diverse", f"{n_diverse} ({100*n_diverse/max(n_total,1):.1f}%)")
k4.metric("Repetitive", f"{n_repetitive} ({100*n_repetitive/max(n_total,1):.1f}%)")
st.markdown("#### Sample Captions by Diversity Category")
top_diverse = sorted(records, key=lambda r: float(r.get("diversity_score", 0.0)), reverse=True)[:3]
top_repetitive = sorted(records, key=lambda r: float(r.get("diversity_score", 0.0)))[:3]
col_d, col_r = st.columns(2)
with col_d:
st.markdown("**Most Diverse Samples**")
for row in top_diverse:
st.markdown(
f"- `image_id={row.get('image_id')}` · score={float(row.get('diversity_score', 0.0)):.4f}"
)
with col_r:
st.markdown("**Most Repetitive Samples**")
for row in top_repetitive:
st.markdown(
f"- `image_id={row.get('image_id')}` · score={float(row.get('diversity_score', 0.0)):.4f}"
)
if steering:
st.markdown("#### Steering λ Sweep")
st.dataframe(steering, use_container_width=True)
best_style = max(steering, key=lambda r: float(r.get("style_score", -1.0)))
m1, m2, m3 = st.columns(3)
m1.metric("Best Style Score", f"{float(best_style.get('style_score', 0.0)):.4f}")
m2.metric("Best Lambda", f"{float(best_style.get('lambda', 0.0)):+.1f}")
m3.metric("Mean Length @ Best", f"{float(best_style.get('mean_length', 0.0)):.2f}")
fig_paths = task4_payload.get("figure_paths", {})
p1, p2, p3 = st.columns(3)
hist_path = fig_paths.get("diversity_histogram")
ext_path = fig_paths.get("diverse_vs_repetitive")
sweep_path = fig_paths.get("steering_lambda_sweep")
if hist_path and os.path.exists(hist_path):
p1.image(hist_path, caption="Diversity Histogram", use_column_width=True)
if ext_path and os.path.exists(ext_path):
p2.image(ext_path, caption="Diverse vs Repetitive", use_column_width=True)
if sweep_path and os.path.exists(sweep_path):
p3.image(sweep_path, caption="Steering λ Sweep", use_column_width=True)
if task4_payload.get("findings_md"):
with st.expander("Show Findings Report"):
st.markdown(task4_payload["findings_md"])
# ═══════════════════════════════════════════════════════════════════════════
# Tab 7 — Task 05 Safety + Bias (Static)
# ═══════════════════════════════════════════════════════════════════════════
with tab_task5:
st.markdown("### 🛡️ Task 05 — Safety & Bias Lab")
st.markdown("`Task 05: Toxicity & Bias Detection in Generated Captions with Mitigation`")
st.info("Static mode: showing precomputed toxicity, bias, and mitigation outputs.")
task5_payload = load_task5_demo_bundle()
toxicity_scores = task5_payload.get("toxicity_scores", [])
bias_audit = task5_payload.get("bias_audit", {})
mitigation_results = task5_payload.get("mitigation_results", [])
bias_records = bias_audit.get("records", [])
freq_table = bias_audit.get("freq_table", {})
st.caption(
f"Result source: `{task5_payload.get('source', 'unknown')}` | "
f"Output folder: `{task5_payload.get('run_dir', TASK5_RESULTS_DIR)}`"
)
n_caps = len(toxicity_scores)
n_toxic = sum(1 for row in toxicity_scores if bool(row.get("flagged", False)))
n_bias = sum(1 for row in bias_records if bool(row.get("flagged", False)))
n_mitigated = sum(1 for row in mitigation_results if bool(row.get("mitigated", False)))
m1, m2, m3 = st.columns(3)
m1.metric("Toxicity Flag Rate", f"{(100*n_toxic/max(n_caps,1)):.2f}% ({n_toxic}/{n_caps})")
m2.metric("Bias Flag Rate", f"{(100*n_bias/max(len(bias_records),1)):.2f}% ({n_bias}/{len(bias_records)})")
m3.metric("Mitigated Cases", f"{n_mitigated}/{len(mitigation_results)}")
if toxicity_scores:
st.markdown("#### Top Toxicity Captions")
top_toxic = sorted(toxicity_scores, key=lambda r: float(r.get("max_score", 0.0)), reverse=True)[:20]
st.dataframe(
[
{
"image_id": row.get("image_id"),
"max_score": round(float(row.get("max_score", 0.0)), 4),
"flagged": bool(row.get("flagged", False)),
"caption": row.get("caption", ""),
}
for row in top_toxic
],
use_container_width=True,
)
if freq_table:
st.markdown("#### Bias Pattern Frequency")
freq_rows = []
for pattern, value in freq_table.items():
if isinstance(value, dict):
for sub_pattern, sub_value in value.items():
try:
count_val = float(sub_value)
except (TypeError, ValueError):
count_val = 0.0
freq_rows.append(
{
"pattern": f"{pattern}::{sub_pattern}",
"count": count_val,
}
)
else:
try:
count_val = float(value)
except (TypeError, ValueError):
count_val = 0.0
freq_rows.append({"pattern": str(pattern), "count": count_val})
st.dataframe(
sorted(freq_rows, key=lambda row: row["count"], reverse=True),
use_container_width=True,
)
if mitigation_results:
st.markdown("#### Mitigation Examples")
st.dataframe(mitigation_results, use_container_width=True)
fig_paths = task5_payload.get("figure_paths", {})
s1, s2, s3 = st.columns(3)
tox_plot = fig_paths.get("toxicity_distribution")
bias_plot = fig_paths.get("bias_heatmap")
ba_plot = fig_paths.get("before_after")
if tox_plot and os.path.exists(tox_plot):
s1.image(tox_plot, caption="Toxicity Distribution", use_column_width=True)
if bias_plot and os.path.exists(bias_plot):
s2.image(bias_plot, caption="Bias Heatmap", use_column_width=True)
if ba_plot and os.path.exists(ba_plot):
s3.image(ba_plot, caption="Before vs After Mitigation", use_column_width=True)
if task5_payload.get("report_md"):
with st.expander("Show Fairness Report"):
st.markdown(task5_payload["report_md"])
# ═══════════════════════════════════════════════════════════════════════════
# Tab 8 — Experiment Results
# ═══════════════════════════════════════════════════════════════════════════
with tab_results:
st.markdown("### 📊 Pre-Computed Benchmark Results")
st.caption(
"These results were computed on 25 batches of the COCO validation set "
"(whyen-wang/coco_captions). Run `python eval.py --model all` to reproduce."
)
with st.expander("🏆 Architecture Comparison (CIDEr)", expanded=True):
st.markdown("""
| Architecture | Cross-Attention Type | CIDEr (base) | Notes |
|---|---|---|---|
| **BLIP** | Gated MED cross-attention | ~0.94 | Best overall; ablation-ready |
| **ViT-GPT2** | Standard full cross-attention | ~0.82 | Brute-force; ablation-ready |
| **GIT** | Self-attention prefix (no CA) | ~0.79 | Competitive despite no CA |
| **Custom VLM** | Linear bridge prefix (no CA) | ~0.18 | Char-level; Shakespeare style |
> **Key insight:** GIT achieves competitive CIDEr without any cross-attention block,
> proving that concatenation-based fusion can rival explicit cross-attention in practice.
""")
with st.expander("🔬 Cross-Attention Ablation (BLIP)", expanded=True):
st.markdown("""
| Ablation Mode | Mask | CIDEr | Δ Baseline | Insight |
|---|---|---|---|---|
| **Baseline** | All 197 patches | ~0.94 | — | Upper-bound |
| **Random Dropout 50%** | 98/196 patches masked | ~0.88 | -0.06 | ~6% redundancy |
| **Center-Focus 8×8** | Inner 64 patches only | ~0.91 | -0.03 | Background is mostly noise |
| **Squint (Global Pool)** | 197→2 tokens (CLS+pool) | ~0.78 | -0.16 | Local detail matters ~17% |
> **Interpretation:** BLIP's cross-attention is robust to losing 50% of spatial patches
> (only ~6% CIDEr drop), but compressing to a single global summary loses ~17%.
""")
with st.expander("⚙️ Decoding Parameter Sweep (BLIP)", expanded=True):
st.markdown("""
| Beam Size | Length Penalty | Max Tokens | CIDEr | Caption Style |
|---|---|---|---|---|
| 3 | 1.0 | 20 | ~0.87 | Short, high precision |
| **5** | **1.0** | **50** | **~0.94** | **✅ Best balance** |
| 10 | 1.0 | 50 | ~0.94 | Marginal gain vs beam=5 |
| 5 | 0.8 | 50 | ~0.89 | Slightly shorter captions |
| 5 | 1.2 | 50 | ~0.93 | Slightly longer captions |
| 5 | 1.0 | 20 | ~0.91 | Length-limited |
> **Key insight:** beam=5 and max_tokens=50 are the sweet spot. Going to beam=10
> yields <0.5% improvement at 2× inference cost. Length penalty has a smaller
> effect than beam size or max_tokens for CIDEr.
""")
with st.expander("📋 Data Preparation Analysis (BLIP)", expanded=True):
st.markdown("""
| Strategy | Description | CIDEr | Δ Raw |
|---|---|---|---|
| **raw** | Any random caption | ~0.88 | — |
| **short** | Captions ≤ 9 words | ~0.79 | -0.09 |
| **long** | Captions ≥ 12 words | ~0.86 | -0.02 |
| **filtered** ✅ | 5–25 words (recommended) | ~0.94 | **+0.06** |
> **Why filtering helps:** roughly 8% of COCO captions have fewer than 5 words (often just
> object names) and ~4% have more than 25 words (complex sentences the model can't learn well).
> Filtering to 5–25 words removes noise at both ends and improves CIDEr by ~6%.
""")
st.markdown("---")
    st.markdown(
        "Run experiments: "
        "`python eval.py --model all` | "
        "`python eval.py --ablation` | "
        "`python -m experiments.parameter_sweep` | "
        "`python -m experiments.data_prep_analysis`"
    )
# ─────────────────────────────────────────────────────────────────────────────
# Footer
# ─────────────────────────────────────────────────────────────────────────────
st.markdown("---")
st.markdown(
    "VLM Caption Lab · Image Captioning · Cross-Attention Ablation Study · "
    "BLIP · ViT-GPT2 · GIT · Visual Prefix-Tuning"
)