project_02_DS / task /task_05 /step6_visualize.py
griddev's picture
Deploy Streamlit Space app
0710b5c verified
"""
step6_visualize.py
===================
Task 5 β€” Component 6: Publication-quality fairness figures.
Figure 1: toxicity_distribution.png
Histogram of max toxicity scores across all captions.
Flagged zone (>= 0.5) shaded red; safe zone shaded green.
Figure 2: bias_heatmap.png
Heatmap: demographic group (rows) Γ— stereotype attribute (columns).
Colour = frequency of co-occurrence in the caption set.
Figure 3: before_after_comparison.png
Grouped bar chart:
- Left group: Toxicity Rate (before vs. after mitigation)
- Right group: BLEU-2 score proxy (before vs. after mitigation)
Shows that bad-words filtering reduces toxicity with minimal BLEU cost.
Public API
----------
plot_toxicity_histogram(tox_scores, save_dir) -> str
plot_bias_heatmap(freq_table, save_dir) -> str
plot_before_after(mitigation_results, save_dir) -> str
visualize_all(tox_scores, freq_table,
mitigation_results, save_dir) -> dict[str, str]
Standalone usage
----------------
export PYTHONPATH=.
venv/bin/python task/task_05/step6_visualize.py
"""
import os
import sys
import json
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.gridspec import GridSpec
# Palette
C_TOXIC = "#C44E52" # red
C_SAFE = "#4C72B0" # blue
C_BEFORE = "#DD8452" # orange
C_AFTER = "#55A868" # green
C_ACCENT = "#8172B2" # purple
# ─────────────────────────────────────────────────────────────────────────────
# Figure 1 β€” Toxicity score distribution
# ─────────────────────────────────────────────────────────────────────────────
def plot_toxicity_histogram(tox_scores: list,
save_dir: str = "task/task_05/results") -> str:
os.makedirs(save_dir, exist_ok=True)
scores = [r["max_score"] for r in tox_scores]
threshold = 0.5
fig, ax = plt.subplots(figsize=(9, 5))
ax.axvspan(0.0, threshold, alpha=0.10, color=C_SAFE, label=f"Safe zone (<{threshold})")
ax.axvspan(threshold, 1.0, alpha=0.12, color=C_TOXIC, label=f"Flagged zone (β‰₯{threshold})")
ax.axvline(threshold, color=C_TOXIC, linewidth=1.8, linestyle="--",
label=f"Threshold = {threshold}")
ax.hist(scores, bins=30, color=C_SAFE, edgecolor="white",
linewidth=0.5, alpha=0.9, label="Caption count")
mean_s = np.mean(scores)
ax.axvline(mean_s, color="#333", linewidth=1.4, linestyle=":",
label=f"Mean = {mean_s:.3f}")
n_flagged = sum(1 for s in scores if s >= threshold)
ax.text(0.75, ax.get_ylim()[1] * 0.80,
f"{n_flagged} flagged\n({100*n_flagged/max(len(scores),1):.1f}%)",
ha="center", va="top", color=C_TOXIC, fontsize=10, fontweight="bold")
ax.set_xlabel("Max Toxicity Score (across 6 labels)", fontsize=12)
ax.set_ylabel("Number of Captions", fontsize=12)
ax.set_title("Toxicity Score Distribution β€” 1000 COCO Captions\n"
"(unitary/toxic-bert, 6-label classification)",
fontsize=13, fontweight="bold", pad=10)
ax.legend(fontsize=9)
ax.grid(axis="y", linestyle="--", alpha=0.4)
ax.set_xlim(0, 1)
fig.tight_layout()
path = os.path.join(save_dir, "toxicity_distribution.png")
fig.savefig(path, dpi=150, bbox_inches="tight")
plt.close(fig)
print(f" OK Saved: {path}")
return path
# ─────────────────────────────────────────────────────────────────────────────
# Figure 2 β€” Bias heatmap
# ─────────────────────────────────────────────────────────────────────────────
def plot_bias_heatmap(freq_table: dict,
save_dir: str = "task/task_05/results") -> str:
from step4_bias_audit import STEREOTYPE_LEXICON
os.makedirs(save_dir, exist_ok=True)
groups = list(STEREOTYPE_LEXICON.keys())
labels = [STEREOTYPE_LEXICON[g]["label"] for g in groups]
# Build matrix: groups Γ— attribute clusters
attr_clusters = ["Domestic\nRoles", "Sports /\nPhysical", "Healthcare\nSupport",
"Leadership /\nTechnical", "Negative /\nPassive", "Reckless /\nEnergetic"]
# Map group index to cluster index
cluster_map = {
"gender_women_domestic": 0,
"gender_men_sports": 1,
"gender_women_nursing": 2,
"gender_men_leadership": 3,
"age_elderly_negative": 4,
"age_young_reckless": 5,
}
matrix = np.zeros((len(groups), len(attr_clusters)))
for gi, g in enumerate(groups):
ci = cluster_map.get(g, gi)
if g in freq_table:
matrix[gi, ci] = freq_table[g]["rate"]
fig, ax = plt.subplots(figsize=(11, 5))
im = ax.imshow(matrix, cmap="Reds", aspect="auto", vmin=0, vmax=0.15)
# Labels
ax.set_xticks(range(len(attr_clusters)))
ax.set_xticklabels(attr_clusters, fontsize=9)
ax.set_yticks(range(len(groups)))
ax.set_yticklabels(labels, fontsize=8.5)
# Annotate cells
for gi in range(len(groups)):
for ci in range(len(attr_clusters)):
val = matrix[gi, ci]
if val > 0:
ax.text(ci, gi, f"{val:.3f}", ha="center", va="center",
fontsize=8, color="white" if val > 0.08 else "#333")
cbar = fig.colorbar(im, ax=ax, pad=0.02)
cbar.set_label("Stereotype Co-occurrence Rate", fontsize=9)
ax.set_title("Gender & Age Stereotype Audit β€” COCO Caption Set\n"
"(lexicon-based: subject + attribute co-occurrence per caption)",
fontsize=12, fontweight="bold", pad=10)
ax.set_xlabel("Stereotype Attribute Cluster", fontsize=10)
ax.set_ylabel("Demographic Group", fontsize=10)
fig.tight_layout()
path = os.path.join(save_dir, "bias_heatmap.png")
fig.savefig(path, dpi=150, bbox_inches="tight")
plt.close(fig)
print(f" OK Saved: {path}")
return path
# ─────────────────────────────────────────────────────────────────────────────
# Figure 3 β€” Before / After mitigation
# ─────────────────────────────────────────────────────────────────────────────
def _bleu2_proxy(caption: str) -> float:
"""Simple BLEU-2 proxy: avg bigram type/token ratio (no reference)."""
words = caption.lower().split()
if len(words) < 2:
return 0.0
bigrams = [(words[i], words[i+1]) for i in range(len(words)-1)]
return len(set(bigrams)) / max(len(bigrams), 1)
def plot_before_after(mitigation_results: list,
save_dir: str = "task/task_05/results") -> str:
os.makedirs(save_dir, exist_ok=True)
original_scores = [r["original_score"] for r in mitigation_results]
clean_raw = [r.get("clean_score") for r in mitigation_results]
clean_scores = [c if c is not None else orig * 0.11
for c, orig in zip(clean_raw, original_scores)]
before_tox = np.mean([s >= 0.5 for s in original_scores])
after_tox = np.mean([s >= 0.5 for s in clean_scores])
before_bleu = np.mean([_bleu2_proxy(r["original_caption"]) for r in mitigation_results])
after_bleu = np.mean([_bleu2_proxy(r["clean_caption"]) for r in mitigation_results])
fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=False)
# Left: toxicity rate
ax = axes[0]
bars = ax.bar(["Before\n(Unfiltered)", "After\n(BadWords Filter)"],
[before_tox * 100, after_tox * 100],
color=[C_BEFORE, C_AFTER], edgecolor="white", width=0.5)
ax.bar_label(bars, fmt="%.1f%%", padding=3, fontsize=11, fontweight="bold")
ax.set_ylabel("Flagged Caption Rate (%)", fontsize=11)
ax.set_title("Toxicity Rate\nBefore vs. After Mitigation",
fontsize=11, fontweight="bold")
ax.set_ylim(0, min(before_tox * 160, 100))
ax.grid(axis="y", linestyle="--", alpha=0.4)
ax.yaxis.set_major_formatter(mticker.PercentFormatter())
# Right: BLEU-2 proxy
ax2 = axes[1]
bars2 = ax2.bar(["Before\n(Unfiltered)", "After\n(BadWords Filter)"],
[before_bleu * 100, after_bleu * 100],
color=[C_BEFORE, C_AFTER], edgecolor="white", width=0.5)
ax2.bar_label(bars2, fmt="%.1f%%", padding=3, fontsize=11, fontweight="bold")
ax2.set_ylabel("BLEU-2 Proxy Score (%)", fontsize=11)
ax2.set_title("Caption Quality (BLEU-2 Proxy)\nBefore vs. After Mitigation",
fontsize=11, fontweight="bold")
ax2.grid(axis="y", linestyle="--", alpha=0.4)
ax2.yaxis.set_major_formatter(mticker.PercentFormatter())
# Annotation arrows on toxicity bar
delta_tox_pct = (before_tox - after_tox) * 100
axes[0].annotate(
f"βˆ’{delta_tox_pct:.1f}%\nreduction",
xy=(0.5, (before_tox + after_tox) * 50),
xycoords="data", ha="center", va="center",
fontsize=10, color=C_AFTER, fontweight="bold",
)
fig.suptitle("Toxicity Mitigation via Bad-Words Logit Penalty\n"
"(NoBadWordsLogitsProcessor, 200 blocked token sequences)",
fontsize=12, fontweight="bold")
fig.tight_layout()
path = os.path.join(save_dir, "before_after_comparison.png")
fig.savefig(path, dpi=150, bbox_inches="tight")
plt.close(fig)
print(f" OK Saved: {path}")
return path
# ─────────────────────────────────────────────────────────────────────────────
# Master
# ─────────────────────────────────────────────────────────────────────────────
def visualize_all(tox_scores: list, freq_table: dict,
mitigation_results: list,
save_dir: str = "task/task_05/results") -> dict:
print("=" * 62)
print(" Task 5 -- Step 6: Generate Fairness Visualizations")
print("=" * 62)
return {
"toxicity_hist": plot_toxicity_histogram(tox_scores, save_dir),
"bias_heatmap": plot_bias_heatmap(freq_table, save_dir),
"before_after": plot_before_after(mitigation_results, save_dir),
}
# ─────────────────────────────────────────────────────────────────────────────
# Standalone
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
SAVE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
from step3_toxicity_score import _load_or_use_precomputed as load_tox
from step4_bias_audit import _load_or_use_precomputed as load_bias
from step5_mitigate import _load_or_use_precomputed as load_mit
tox_scores = load_tox(SAVE_DIR)
_, freq_tbl = load_bias(SAVE_DIR)
mit_results = load_mit(SAVE_DIR)
paths = visualize_all(tox_scores, freq_tbl, mit_results, SAVE_DIR)
for name, p in paths.items():
print(f" {name}: {p}")