{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "5e83734d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m26.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "\n", "\n", "%pip install -q \"datasets<4.0.0\" transformers accelerate pillow tqdm numpy torch torchvision\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "1f26db57", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/makumar/Documents/.venv/lib/python3.14/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "✅ Config loaded\n" ] } ], "source": [ "import os, math, time, random\n", "from dataclasses import dataclass\n", "\n", "import numpy as np\n", "import torch\n", "from torch.utils.data import DataLoader\n", "from torch.optim import AdamW # use PyTorch AdamW, not transformers [web:34][web:36]\n", "from tqdm.auto import tqdm\n", "\n", "from datasets import load_dataset\n", "from transformers import (\n", " BlipProcessor,\n", " BlipForConditionalGeneration,\n", " get_cosine_schedule_with_warmup, # still valid in transformers optimization APIs [web:41][web:46]\n", ")\n", "\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "\n", "@dataclass\n", "class CFG:\n", " model_id: str = \"Salesforce/blip-image-captioning-base\"\n", " dataset_id: str = \"whyen-wang/coco_captions\" # COCO captions dataset: image + list of 5 captions [web:7]\n", "\n", " train_samples: int = 1000 # start small; increase to 10k–50k later\n", " val_samples: int = 200\n", " seed: int = 42\n", "\n", " image_size: int = 224\n", " max_target_len: int = 32\n", "\n", " batch_size: int = 4\n", " grad_accum: int = 8\n", " epochs: int = 1\n", "\n", " lr: float = 1e-5\n", " weight_decay: float = 0.01\n", " warmup_ratio: float = 0.03\n", " max_grad_norm: float = 1.0\n", "\n", " num_workers: int = 0 # safer on macOS\n", " log_every: int = 10\n", " save_every_steps: int = 100\n", "\n", " out_dir: str = \"./blip_coco_ft_mps\"\n", "\n", "cfg = CFG()\n", "os.makedirs(cfg.out_dir, exist_ok=True)\n", "print(\"✅ Config loaded\")\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "74fa92b3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ Device: mps\n" ] } ], "source": [ "def seed_all(seed: int):\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " torch.manual_seed(seed)\n", "\n", "seed_all(cfg.seed)\n", "\n", "if torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", "elif torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", "else:\n", " device = torch.device(\"cpu\")\n", "\n", "print(f\"✅ Device: {device}\")\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "46dced20", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Warning: You are sending unauthenticated requests to the HF Hub. 
Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n", "Downloading data: 100%|██████████| 19.3G/19.3G [28:19<00:00, 11.4MB/s] \n", "Downloading data: 100%|██████████| 816M/816M [01:08<00:00, 12.0MB/s] \n", "Generating train split: 118287 examples [00:02, 54322.81 examples/s]\n", "Generating validation split: 5000 examples [00:00, 55846.76 examples/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['image', 'captions'],\n", " num_rows: 118287\n", " })\n", " validation: Dataset({\n", " features: ['image', 'captions'],\n", " num_rows: 5000\n", " })\n", "})\n", "Example keys: dict_keys(['image', 'captions'])\n", "Captions per image: 5\n", "✅ Train: 1000, Val: 200\n" ] } ], "source": [ "import aiohttp\n", "import datasets\n", "\n", "# Use storage_options to increase the timeout from 5 minutes (300s) to 1 hour (3600s)\n", "ds = load_dataset(\n", " cfg.dataset_id, \n", " trust_remote_code=True,\n", " storage_options={'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=3600)}}\n", ")\n", "\n", "print(ds)\n", "print(\"Example keys:\", ds[\"train\"][0].keys())\n", "print(\"Captions per image:\", len(ds[\"train\"][0][\"captions\"]))\n", "\n", "train_split = \"train\"\n", "val_split = \"validation\" if \"validation\" in ds else (\"val\" if \"val\" in ds else \"train\")\n", "\n", "train_ds = ds[train_split].shuffle(seed=cfg.seed).select(\n", " range(min(cfg.train_samples, len(ds[train_split])))\n", ")\n", "val_ds = ds[val_split].shuffle(seed=cfg.seed + 1).select(\n", " range(min(cfg.val_samples, len(ds[val_split])))\n", ")\n", "\n", "print(f\"✅ Train: {len(train_ds)}, Val: {len(val_ds)}\")\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "681b5a5f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The image processor of type `BlipImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. \n", "Loading weights: 100%|██████████| 473/473 [00:00<00:00, 1923.98it/s, Materializing param=vision_model.post_layernorm.weight] \n", "The tied weights mapping and config for this model specifies to tie text_decoder.cls.predictions.bias to text_decoder.cls.predictions.decoder.bias, but both are present in the checkpoints, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning\n", "The tied weights mapping and config for this model specifies to tie text_decoder.bert.embeddings.word_embeddings.weight to text_decoder.cls.predictions.decoder.weight, but both are present in the checkpoints, so we will NOT tie them. 
You should update the config with `tie_word_embeddings=False` to silence this warning\n", "\u001b[1mBlipForConditionalGeneration LOAD REPORT\u001b[0m from: Salesforce/blip-image-captioning-base\n", "Key | Status | | \n", "------------------------------------------+------------+--+-\n", "text_decoder.bert.embeddings.position_ids | UNEXPECTED | | \n", "\n", "\u001b[3mNotes:\n", "- UNEXPECTED\u001b[3m\t:can be ignored when loading from different task/architecture; not ok if you expect identical arch.\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "✅ Gradient checkpointing enabled\n", "✅ Model loaded: Salesforce/blip-image-captioning-base\n" ] } ], "source": [ "processor = BlipProcessor.from_pretrained(cfg.model_id)\n", "model = BlipForConditionalGeneration.from_pretrained(cfg.model_id)\n", "\n", "# Force 224px images (lighter for Mac)\n", "try:\n", " processor.image_processor.size = {\"height\": cfg.image_size, \"width\": cfg.image_size}\n", "except Exception as e:\n", " print(f\"⚠️ Could not set image size: {e}\")\n", "\n", "# Memory helpers\n", "try:\n", " model.gradient_checkpointing_enable()\n", " print(\"✅ Gradient checkpointing enabled\")\n", "except Exception as e:\n", " print(f\"⚠️ Gradient checkpointing failed: {e}\")\n", "\n", "model.config.use_cache = False # must be False when using gradient checkpointing\n", "model.to(device)\n", "\n", "print(f\"✅ Model loaded: {cfg.model_id}\")\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "ae518a72", "metadata": {}, "outputs": [], "source": [ "def collate_fn(examples):\n", " images = [ex[\"image\"].convert(\"RGB\") for ex in examples]\n", " # pick one random caption per image\n", " captions = [random.choice(ex[\"captions\"]) for ex in examples]\n", "\n", " encoding = processor(\n", " images=images,\n", " text=captions,\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=cfg.max_target_len,\n", " return_tensors=\"pt\",\n", " )\n", "\n", " # Use the caption token ids as labels; mask padding so it is ignored by the loss\n", " labels = encoding[\"input_ids\"].clone()\n", " labels[labels == processor.tokenizer.pad_token_id] = -100\n", " encoding[\"labels\"] = labels\n", "\n", " return encoding\n", "\n", "\n", "train_loader = DataLoader(\n", " train_ds,\n", " batch_size=cfg.batch_size,\n", " shuffle=True,\n", " num_workers=cfg.num_workers,\n", " collate_fn=collate_fn,\n", " pin_memory=(device.type == \"cuda\"), # pinned memory is only supported on CUDA\n", ")\n", "\n", "val_loader = DataLoader(\n", " val_ds,\n", " batch_size=cfg.batch_size,\n", " shuffle=False,\n", " num_workers=cfg.num_workers,\n", " collate_fn=collate_fn,\n", " pin_memory=(device.type == \"cuda\"),\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "id": "becf6f22", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ Update steps: 32, Warmup: 0\n" ] } ], "source": [ "optimizer = AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)\n", "\n", "total_update_steps = math.ceil(len(train_loader) / cfg.grad_accum) * cfg.epochs\n", "warmup_steps = int(total_update_steps * cfg.warmup_ratio)\n", "\n", "scheduler = get_cosine_schedule_with_warmup(\n", " optimizer,\n", " num_warmup_steps=warmup_steps,\n", " num_training_steps=total_update_steps,\n", ")\n", "\n", "print(f\"✅ Update steps: {total_update_steps}, Warmup: {warmup_steps}\")\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "4134441d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ Checkpoint helpers ready\n" ] } ], "source": [ "def save_ckpt(step, epoch):\n", " \"\"\"\n", " Save model weights, processor, and training state to cfg.out_dir.\n", " Directory: 
out_dir/ckpt_step{step}_epoch{epoch}\n", " \"\"\"\n", " path = os.path.join(cfg.out_dir, f\"ckpt_step{step}_epoch{epoch}\")\n", " os.makedirs(path, exist_ok=True)\n", "\n", " # Save model weights + config in HF format\n", " model.save_pretrained(path)\n", " processor.save_pretrained(path)\n", "\n", " # Save optimizer/scheduler state, step, epoch\n", " torch.save(\n", " {\n", " \"step\": step,\n", " \"epoch\": epoch,\n", " \"optimizer\": optimizer.state_dict(),\n", " \"scheduler\": scheduler.state_dict(),\n", " \"cfg\": cfg.__dict__,\n", " },\n", " os.path.join(path, \"train_state.pt\"),\n", " )\n", "\n", " print(f\"✅ Checkpoint saved: {path}\")\n", "\n", "\n", "def load_ckpt(path):\n", " \"\"\"\n", " Load model + optimizer/scheduler from a checkpoint directory.\n", " \"\"\"\n", " # Load model weights\n", " loaded_model = BlipForConditionalGeneration.from_pretrained(path)\n", " model.load_state_dict(loaded_model.state_dict())\n", "\n", " # Load training state\n", " state = torch.load(os.path.join(path, \"train_state.pt\"), map_location=\"cpu\")\n", " optimizer.load_state_dict(state[\"optimizer\"])\n", " scheduler.load_state_dict(state[\"scheduler\"])\n", "\n", " print(f\"✅ Resumed from step {state['step']}, epoch {state['epoch']}\")\n", " return state[\"step\"], state[\"epoch\"]\n", "\n", "\n", "print(\"✅ Checkpoint helpers ready\")\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "c323b9bb", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1/1: 0%| | 0/250 [00:00