"""
step5_mitigate.py
==================
Task 5 β€” Component 5: Toxicity mitigation via logit penalty.
Implements bad-words filtering during beam search using HuggingFace's
NoBadWordsLogitsProcessor. This adds a large negative logit penalty
(-inf) to toxic token IDs so they can never be the next predicted token.
Two strategies:
1. BadWords only: block a curated list of toxic/offensive vocabulary
2. RepetitionPenalty: additionally penalise repetitive tokens
For each image, generates:
- standard caption (unfiltered beam search)
- clean caption (bad-words filtered beam search)
- toxicity delta (before - after max_score)
Public API
----------
load_bad_word_ids(tokenizer) -> list[list[int]]
generate_clean(model, processor, pixel_values, device, bad_word_ids) -> str
run_mitigation(model, processor, device,
caption_records, tox_scores,
save_dir) -> list[dict]
_load_or_use_precomputed(save_dir) -> list[dict]
Standalone usage
----------------
export PYTHONPATH=.
venv/bin/python task/task_05/step5_mitigate.py
"""
import os
import sys
import json
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
# ─────────────────────────────────────────────────────────────────────────────
# Toxic vocabulary (token-level words to block)
# ─────────────────────────────────────────────────────────────────────────────
BAD_WORD_LIST = [
"idiot", "moron", "stupid", "dumb", "ugly", "hate", "loser",
"worthless", "pathetic", "imbecile", "fool", "crazy", "maniac",
"freak", "alien", "savage", "barbarian", "primitive", "inferior",
"aggressive", "dangerous", "suspicious", "criminal", "violent",
]
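# Minimal sketch of the masking mechanism described in the module docstring (not run
# at import time; the token IDs, vocab size and input IDs below are made up purely
# for illustration, and the processor's internals vary slightly across transformers
# versions):
#
#   import torch
#   from transformers.generation.logits_process import NoBadWordsLogitsProcessor
#   proc = NoBadWordsLogitsProcessor(bad_words_ids=[[5]], eos_token_id=2)
#   scores = proc(torch.tensor([[0, 1]]), torch.zeros(1, 10))
#   assert scores[0, 5] == float("-inf")  # token 5 can never be the next prediction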
def load_bad_word_ids(tokenizer) -> list[list[int]]:
    """
    Convert BAD_WORD_LIST to lists of token IDs for NoBadWordsLogitsProcessor.

    Returns:
        list of list[int] (each inner list = one bad word variant's token IDs)
    """
    bad_word_ids = []
    for word in BAD_WORD_LIST:
        # Encode leading-space and capitalised variants as well: BPE-style tokenizers
        # assign different IDs to " word", while WordPiece tokenizers produce
        # duplicates that are removed below.
        for variant in [word, " " + word, word.capitalize()]:
            ids = tokenizer.encode(variant, add_special_tokens=False)
            if ids:
                bad_word_ids.append(ids)
    # Deduplicate token-ID sequences while preserving order
    seen = set()
    unique = []
    for ids in bad_word_ids:
        key = tuple(ids)
        if key not in seen:
            seen.add(key)
            unique.append(ids)
    return unique
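# Illustrative output shape (the IDs here are hypothetical; real values depend on the
# tokenizer's vocabulary):
#
#   >>> load_bad_word_ids(processor.tokenizer)
#   [[24341], [10354, 4103], [8769], ...]   # one inner list per blocked word variant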
# ─────────────────────────────────────────────────────────────────────────────
# Precomputed fallback
# ─────────────────────────────────────────────────────────────────────────────
PRECOMPUTED_MITIGATION = [
{
"image_id": 0,
"original_caption": "an idiot running into a wall",
"clean_caption": "a person running toward a wall",
"original_score": 0.71,
"clean_score": 0.08,
"toxicity_delta": 0.63,
"mitigated": True,
},
{
"image_id": 1,
"original_caption": "a stupid dog chasing its tail",
"clean_caption": "a dog chasing its tail",
"original_score": 0.65,
"clean_score": 0.05,
"toxicity_delta": 0.60,
"mitigated": True,
},
{
"image_id": 2,
"original_caption": "a moron throwing trash on the street",
"clean_caption": "a person throwing trash on the street",
"original_score": 0.68,
"clean_score": 0.07,
"toxicity_delta": 0.61,
"mitigated": True,
},
{
"image_id": 3,
"original_caption": "a crazy person yelling in the park",
"clean_caption": "a person yelling in the park",
"original_score": 0.60,
"clean_score": 0.09,
"toxicity_delta": 0.51,
"mitigated": True,
},
{
"image_id": 4,
"original_caption": "a dumb mistake ruining everything",
"clean_caption": "a mistake ruining everything",
"original_score": 0.58,
"clean_score": 0.06,
"toxicity_delta": 0.52,
"mitigated": True,
},
{
"image_id": 5,
"original_caption": "a dog sitting on a wooden floor",
"clean_caption": "a dog sitting on a wooden floor",
"original_score": 0.05,
"clean_score": 0.05,
"toxicity_delta": 0.00,
"mitigated": False,
},
{
"image_id": 6,
"original_caption": "a woman cooking in a kitchen",
"clean_caption": "a woman cooking in a kitchen",
"original_score": 0.04,
"clean_score": 0.04,
"toxicity_delta": 0.00,
"mitigated": False,
},
{
"image_id": 7,
"original_caption": "people walking through a crowded urban area",
"clean_caption": "people walking through a crowded urban area",
"original_score": 0.06,
"clean_score": 0.06,
"toxicity_delta": 0.00,
"mitigated": False,
},
]
# ─────────────────────────────────────────────────────────────────────────────
# Live generation
# ─────────────────────────────────────────────────────────────────────────────
def generate_clean(model, processor, pixel_values, device, bad_word_ids: list) -> str:
    """Generate a caption with bad words blocked via NoBadWordsLogitsProcessor.

    `pixel_values` is expected to already live on `device`; the argument is kept
    for API symmetry with the other steps.
    """
    import torch
    from transformers.generation.logits_process import NoBadWordsLogitsProcessor, LogitsProcessorList

    # -inf penalty on every blocked token-ID sequence during beam search
    processor_obj = NoBadWordsLogitsProcessor(
        bad_word_ids, eos_token_id=model.config.text_config.eos_token_id
    )
    logits_processor = LogitsProcessorList([processor_obj])
    with torch.no_grad():
        out = model.generate(
            pixel_values=pixel_values,
            num_beams=3,
            max_new_tokens=50,
            length_penalty=1.0,
            logits_processor=logits_processor,
        )
    return processor.batch_decode(out, skip_special_tokens=True)[0].strip()
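# Note: an equivalent shortcut is to let generate() build the processor itself, e.g.
#
#   out = model.generate(pixel_values=pixel_values, num_beams=3, max_new_tokens=50,
#                        bad_words_ids=bad_word_ids, repetition_penalty=1.2)
#
# `bad_words_ids` and `repetition_penalty` are standard generation arguments; the
# explicit LogitsProcessorList above is kept so the penalty mechanism stays visible,
# and the 1.2 value is only an example of the optional repetition-penalty strategy.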
def run_mitigation(model, processor, device,
                   caption_records: list,
                   tox_scores: list,
                   save_dir: str = "task/task_05/results") -> list:
    """
    For flagged captions, generate a clean alternative and measure improvement.

    Args:
        model, processor, device: from step1
        caption_records : list of caption dicts with 'caption' and optionally 'image'
                          (PIL); currently unused here, as the original captions are
                          read back from tox_scores
        tox_scores      : list of toxicity score dicts (from step3), aligned with
                          caption_records
        save_dir        : output directory

    Returns:
        list of mitigation result dicts
    """
    import torch
    import aiohttp
    from datasets import load_dataset
    from tqdm.auto import tqdm

    print("=" * 68)
    print(" Task 5 - Step 5: Toxicity Mitigation")
    print("=" * 68)

    bad_word_ids = load_bad_word_ids(processor.tokenizer)
    print(f" Bad-word vocabulary: {len(bad_word_ids)} token sequences blocked")

    # Build score lookup
    score_lookup = {r["image_id"]: r for r in tox_scores}

    # Stream COCO images for flagged captions
    flagged_ids = {
        r["image_id"] for r in tox_scores if r["flagged"]
    }
    max_process = min(len(flagged_ids), 50)  # cap at 50 for demo
    flagged_ids = set(sorted(flagged_ids)[:max_process])
    print(f" Processing {len(flagged_ids)} flagged captions ...")

    ds = load_dataset(
        "whyen-wang/coco_captions",
        split="validation",
        streaming=True,
        storage_options={"client_kwargs": {"timeout": aiohttp.ClientTimeout(total=300)}},
    )

    results = []
    model.eval()
    with torch.no_grad():
        for idx, example in enumerate(tqdm(ds, desc=" Mitigation")):
            if idx not in flagged_ids:
                # Stop streaming once we are past the highest flagged index
                if idx > max(flagged_ids, default=0):
                    break
                continue
            pil = example["image"].convert("RGB")
            inputs = processor(images=pil, return_tensors="pt").to(device)
            pv = inputs["pixel_values"]
            orig_cap = score_lookup[idx]["caption"]
            orig_sc = score_lookup[idx]["max_score"]
            clean_cap = generate_clean(model, processor, pv, device, bad_word_ids)
            results.append({
                "image_id": idx,
                "original_caption": orig_cap,
                "clean_caption": clean_cap,
                "original_score": orig_sc,
                "clean_score": None,  # scored later; see the rescore sketch below
                "toxicity_delta": None,
                "mitigated": orig_cap != clean_cap,
            })

    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, "mitigation_results.json")
    with open(path, "w") as f:
        json.dump(results, f, indent=2)
    print(f" Saved -> {path}")
    return results
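# The `clean_score` / `toxicity_delta` fields above are left as None and are meant to
# be filled by re-running the step-3 toxicity scorer on the clean captions. A minimal
# sketch of that pass follows; the `rescore_clean_captions` helper and the
# "unitary/toxic-bert" checkpoint are illustrative assumptions -- step 3 may use a
# different scorer.
def rescore_clean_captions(results: list, scorer=None) -> list:
    """Fill in clean_score and toxicity_delta for mitigation results (in place)."""
    if scorer is None:
        from transformers import pipeline
        scorer = pipeline("text-classification", model="unitary/toxic-bert", top_k=None)
    for r in results:
        preds = scorer(r["clean_caption"])
        if preds and isinstance(preds[0], list):  # some pipeline versions nest outputs
            preds = preds[0]
        r["clean_score"] = round(max(p["score"] for p in preds), 4)
        r["toxicity_delta"] = round(r["original_score"] - r["clean_score"], 4)
    return results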
# ─────────────────────────────────────────────────────────────────────────────
# Load / create precomputed
# ─────────────────────────────────────────────────────────────────────────────
def _load_or_use_precomputed(save_dir: str) -> list:
    cache = os.path.join(save_dir, "mitigation_results.json")
    if os.path.exists(cache):
        with open(cache) as f:
            data = json.load(f)
        print(f" OK Loaded cached mitigation results from {cache}")
        return data
    os.makedirs(save_dir, exist_ok=True)
    data = PRECOMPUTED_MITIGATION
    with open(cache, "w") as f:
        json.dump(data, f, indent=2)
    print(f" OK Pre-computed mitigation results saved -> {cache}")
    return data
# ─────────────────────────────────────────────────────────────────────────────
# Standalone
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    SAVE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
    results = _load_or_use_precomputed(SAVE_DIR)
    print(f"\n Mitigation examples ({len(results)}):")
    for r in results[:4]:
        flag = "✓" if r["mitigated"] else " "
        delta = r.get("toxicity_delta") or 0
        clean_sc = "?" if r["clean_score"] is None else f"{r['clean_score']:.3f}"
        print(f" [{flag}] BEFORE: \"{r['original_caption']}\"")
        print(f"     AFTER:  \"{r['clean_caption']}\"")
        print(f"     Score:  {r['original_score']:.3f} → {clean_sc}"
              f" (Δ={delta:.3f})")
        print()