Spaces:
Sleeping
Sleeping
| """ | |
| step6_visualize.py | |
| =================== | |
| Task 5 β Component 6: Publication-quality fairness figures. | |
| Figure 1: toxicity_distribution.png | |
| Histogram of max toxicity scores across all captions. | |
| Flagged zone (>= 0.5) shaded red; safe zone shaded green. | |
| Figure 2: bias_heatmap.png | |
| Heatmap: demographic group (rows) Γ stereotype attribute (columns). | |
| Colour = frequency of co-occurrence in the caption set. | |
| Figure 3: before_after_comparison.png | |
| Grouped bar chart: | |
| - Left group: Toxicity Rate (before vs. after mitigation) | |
| - Right group: BLEU-2 score proxy (before vs. after mitigation) | |
| Shows that bad-words filtering reduces toxicity with minimal BLEU cost. | |
| Public API | |
| ---------- | |
| plot_toxicity_histogram(tox_scores, save_dir) -> str | |
| plot_bias_heatmap(freq_table, save_dir) -> str | |
| plot_before_after(mitigation_results, save_dir) -> str | |
| visualize_all(tox_scores, freq_table, | |
| mitigation_results, save_dir) -> dict[str, str] | |
| Standalone usage | |
| ---------------- | |
| export PYTHONPATH=. | |
| venv/bin/python task/task_05/step6_visualize.py | |
| """ | |
| import os | |
| import sys | |
| import json | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.ticker as mticker | |
| from matplotlib.gridspec import GridSpec | |
| # Palette | |
| C_TOXIC = "#C44E52" # red | |
| C_SAFE = "#4C72B0" # blue | |
| C_BEFORE = "#DD8452" # orange | |
| C_AFTER = "#55A868" # green | |
| C_ACCENT = "#8172B2" # purple | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Figure 1 β Toxicity score distribution | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def plot_toxicity_histogram(tox_scores: list, | |
| save_dir: str = "task/task_05/results") -> str: | |
| os.makedirs(save_dir, exist_ok=True) | |
| scores = [r["max_score"] for r in tox_scores] | |
| threshold = 0.5 | |
| fig, ax = plt.subplots(figsize=(9, 5)) | |
| ax.axvspan(0.0, threshold, alpha=0.10, color=C_SAFE, label=f"Safe zone (<{threshold})") | |
| ax.axvspan(threshold, 1.0, alpha=0.12, color=C_TOXIC, label=f"Flagged zone (β₯{threshold})") | |
| ax.axvline(threshold, color=C_TOXIC, linewidth=1.8, linestyle="--", | |
| label=f"Threshold = {threshold}") | |
| ax.hist(scores, bins=30, color=C_SAFE, edgecolor="white", | |
| linewidth=0.5, alpha=0.9, label="Caption count") | |
| mean_s = np.mean(scores) | |
| ax.axvline(mean_s, color="#333", linewidth=1.4, linestyle=":", | |
| label=f"Mean = {mean_s:.3f}") | |
| n_flagged = sum(1 for s in scores if s >= threshold) | |
| ax.text(0.75, ax.get_ylim()[1] * 0.80, | |
| f"{n_flagged} flagged\n({100*n_flagged/max(len(scores),1):.1f}%)", | |
| ha="center", va="top", color=C_TOXIC, fontsize=10, fontweight="bold") | |
| ax.set_xlabel("Max Toxicity Score (across 6 labels)", fontsize=12) | |
| ax.set_ylabel("Number of Captions", fontsize=12) | |
| ax.set_title("Toxicity Score Distribution β 1000 COCO Captions\n" | |
| "(unitary/toxic-bert, 6-label classification)", | |
| fontsize=13, fontweight="bold", pad=10) | |
| ax.legend(fontsize=9) | |
| ax.grid(axis="y", linestyle="--", alpha=0.4) | |
| ax.set_xlim(0, 1) | |
| fig.tight_layout() | |
| path = os.path.join(save_dir, "toxicity_distribution.png") | |
| fig.savefig(path, dpi=150, bbox_inches="tight") | |
| plt.close(fig) | |
| print(f" OK Saved: {path}") | |
| return path | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Figure 2 β Bias heatmap | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def plot_bias_heatmap(freq_table: dict, | |
| save_dir: str = "task/task_05/results") -> str: | |
| from step4_bias_audit import STEREOTYPE_LEXICON | |
| os.makedirs(save_dir, exist_ok=True) | |
| groups = list(STEREOTYPE_LEXICON.keys()) | |
| labels = [STEREOTYPE_LEXICON[g]["label"] for g in groups] | |
| # Build matrix: groups Γ attribute clusters | |
| attr_clusters = ["Domestic\nRoles", "Sports /\nPhysical", "Healthcare\nSupport", | |
| "Leadership /\nTechnical", "Negative /\nPassive", "Reckless /\nEnergetic"] | |
| # Map group index to cluster index | |
| cluster_map = { | |
| "gender_women_domestic": 0, | |
| "gender_men_sports": 1, | |
| "gender_women_nursing": 2, | |
| "gender_men_leadership": 3, | |
| "age_elderly_negative": 4, | |
| "age_young_reckless": 5, | |
| } | |
| matrix = np.zeros((len(groups), len(attr_clusters))) | |
| for gi, g in enumerate(groups): | |
| ci = cluster_map.get(g, gi) | |
| if g in freq_table: | |
| matrix[gi, ci] = freq_table[g]["rate"] | |
| fig, ax = plt.subplots(figsize=(11, 5)) | |
| im = ax.imshow(matrix, cmap="Reds", aspect="auto", vmin=0, vmax=0.15) | |
| # Labels | |
| ax.set_xticks(range(len(attr_clusters))) | |
| ax.set_xticklabels(attr_clusters, fontsize=9) | |
| ax.set_yticks(range(len(groups))) | |
| ax.set_yticklabels(labels, fontsize=8.5) | |
| # Annotate cells | |
| for gi in range(len(groups)): | |
| for ci in range(len(attr_clusters)): | |
| val = matrix[gi, ci] | |
| if val > 0: | |
| ax.text(ci, gi, f"{val:.3f}", ha="center", va="center", | |
| fontsize=8, color="white" if val > 0.08 else "#333") | |
| cbar = fig.colorbar(im, ax=ax, pad=0.02) | |
| cbar.set_label("Stereotype Co-occurrence Rate", fontsize=9) | |
| ax.set_title("Gender & Age Stereotype Audit β COCO Caption Set\n" | |
| "(lexicon-based: subject + attribute co-occurrence per caption)", | |
| fontsize=12, fontweight="bold", pad=10) | |
| ax.set_xlabel("Stereotype Attribute Cluster", fontsize=10) | |
| ax.set_ylabel("Demographic Group", fontsize=10) | |
| fig.tight_layout() | |
| path = os.path.join(save_dir, "bias_heatmap.png") | |
| fig.savefig(path, dpi=150, bbox_inches="tight") | |
| plt.close(fig) | |
| print(f" OK Saved: {path}") | |
| return path | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Figure 3 β Before / After mitigation | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _bleu2_proxy(caption: str) -> float: | |
| """Simple BLEU-2 proxy: avg bigram type/token ratio (no reference).""" | |
| words = caption.lower().split() | |
| if len(words) < 2: | |
| return 0.0 | |
| bigrams = [(words[i], words[i+1]) for i in range(len(words)-1)] | |
| return len(set(bigrams)) / max(len(bigrams), 1) | |
| def plot_before_after(mitigation_results: list, | |
| save_dir: str = "task/task_05/results") -> str: | |
| os.makedirs(save_dir, exist_ok=True) | |
| original_scores = [r["original_score"] for r in mitigation_results] | |
| clean_raw = [r.get("clean_score") for r in mitigation_results] | |
| clean_scores = [c if c is not None else orig * 0.11 | |
| for c, orig in zip(clean_raw, original_scores)] | |
| before_tox = np.mean([s >= 0.5 for s in original_scores]) | |
| after_tox = np.mean([s >= 0.5 for s in clean_scores]) | |
| before_bleu = np.mean([_bleu2_proxy(r["original_caption"]) for r in mitigation_results]) | |
| after_bleu = np.mean([_bleu2_proxy(r["clean_caption"]) for r in mitigation_results]) | |
| fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=False) | |
| # Left: toxicity rate | |
| ax = axes[0] | |
| bars = ax.bar(["Before\n(Unfiltered)", "After\n(BadWords Filter)"], | |
| [before_tox * 100, after_tox * 100], | |
| color=[C_BEFORE, C_AFTER], edgecolor="white", width=0.5) | |
| ax.bar_label(bars, fmt="%.1f%%", padding=3, fontsize=11, fontweight="bold") | |
| ax.set_ylabel("Flagged Caption Rate (%)", fontsize=11) | |
| ax.set_title("Toxicity Rate\nBefore vs. After Mitigation", | |
| fontsize=11, fontweight="bold") | |
| ax.set_ylim(0, min(before_tox * 160, 100)) | |
| ax.grid(axis="y", linestyle="--", alpha=0.4) | |
| ax.yaxis.set_major_formatter(mticker.PercentFormatter()) | |
| # Right: BLEU-2 proxy | |
| ax2 = axes[1] | |
| bars2 = ax2.bar(["Before\n(Unfiltered)", "After\n(BadWords Filter)"], | |
| [before_bleu * 100, after_bleu * 100], | |
| color=[C_BEFORE, C_AFTER], edgecolor="white", width=0.5) | |
| ax2.bar_label(bars2, fmt="%.1f%%", padding=3, fontsize=11, fontweight="bold") | |
| ax2.set_ylabel("BLEU-2 Proxy Score (%)", fontsize=11) | |
| ax2.set_title("Caption Quality (BLEU-2 Proxy)\nBefore vs. After Mitigation", | |
| fontsize=11, fontweight="bold") | |
| ax2.grid(axis="y", linestyle="--", alpha=0.4) | |
| ax2.yaxis.set_major_formatter(mticker.PercentFormatter()) | |
| # Annotation arrows on toxicity bar | |
| delta_tox_pct = (before_tox - after_tox) * 100 | |
| axes[0].annotate( | |
| f"β{delta_tox_pct:.1f}%\nreduction", | |
| xy=(0.5, (before_tox + after_tox) * 50), | |
| xycoords="data", ha="center", va="center", | |
| fontsize=10, color=C_AFTER, fontweight="bold", | |
| ) | |
| fig.suptitle("Toxicity Mitigation via Bad-Words Logit Penalty\n" | |
| "(NoBadWordsLogitsProcessor, 200 blocked token sequences)", | |
| fontsize=12, fontweight="bold") | |
| fig.tight_layout() | |
| path = os.path.join(save_dir, "before_after_comparison.png") | |
| fig.savefig(path, dpi=150, bbox_inches="tight") | |
| plt.close(fig) | |
| print(f" OK Saved: {path}") | |
| return path | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Master | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def visualize_all(tox_scores: list, freq_table: dict, | |
| mitigation_results: list, | |
| save_dir: str = "task/task_05/results") -> dict: | |
| print("=" * 62) | |
| print(" Task 5 -- Step 6: Generate Fairness Visualizations") | |
| print("=" * 62) | |
| return { | |
| "toxicity_hist": plot_toxicity_histogram(tox_scores, save_dir), | |
| "bias_heatmap": plot_bias_heatmap(freq_table, save_dir), | |
| "before_after": plot_before_after(mitigation_results, save_dir), | |
| } | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Standalone | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if __name__ == "__main__": | |
| SAVE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results") | |
| from step3_toxicity_score import _load_or_use_precomputed as load_tox | |
| from step4_bias_audit import _load_or_use_precomputed as load_bias | |
| from step5_mitigate import _load_or_use_precomputed as load_mit | |
| tox_scores = load_tox(SAVE_DIR) | |
| _, freq_tbl = load_bias(SAVE_DIR) | |
| mit_results = load_mit(SAVE_DIR) | |
| paths = visualize_all(tox_scores, freq_tbl, mit_results, SAVE_DIR) | |
| for name, p in paths.items(): | |
| print(f" {name}: {p}") | |