# project_02_DS/task/task_02/step1_load_model.py
# Deployed as a Streamlit Space app (commit 0710b5c, verified — griddev)
"""
step1_load_model.py
====================
STEP 1 β€” Load the BLIP model and processor.
Responsibilities:
- Detect the best available device (MPS / CUDA / CPU).
- Load base BLIP weights via project's get_blip_model().
- Optionally patch in fine-tuned weights from outputs/blip/best/.
- Disable gradient checkpointing (required for backward hooks).
- Return a ready-to-use (model, processor, device) triplet.
This module is intentionally tiny and self-contained so it can be
called independently from a notebook, a Streamlit app, or a HuggingFace Space.
"""
import os
import sys
import torch
# ── project path resolution ─────────────────────────────────────────────────
# Walk two directories up from this file (task_02/ -> task/ -> project root)
# so the project-local packages below are importable regardless of CWD.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_TASK_DIR = os.path.dirname(_THIS_DIR)
_PROJECT_ROOT = os.path.dirname(_TASK_DIR)
if _PROJECT_ROOT not in sys.path:
    # Prepend so the project's own modules win over同-named installed packages.
    sys.path.insert(0, _PROJECT_ROOT)
# NOTE: these imports rely on the sys.path insert above — keep the order.
from config import CFG
from models.blip_tuner import get_blip_model
# Directory holding the fine-tuned BLIP checkpoint (may be absent; load_model
# falls back to base weights when it is).
FINETUNED_PATH = os.path.join(_PROJECT_ROOT, "outputs", "blip", "best")
# ────────────────────────────────────────────────────────────────────────────
def get_device() -> torch.device:
    """Pick the fastest backend available, preferring MPS, then CUDA, then CPU."""
    probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend, is_up in probes:
        if is_up():
            return torch.device(backend)
    return torch.device("cpu")
def load_model(use_finetuned: bool = True, verbose: bool = True):
    """
    Build the BLIP captioning model together with its processor.

    Args:
        use_finetuned: When True and a checkpoint directory exists under
            outputs/blip/best/, overlay those fine-tuned COCO weights onto
            the base model.
        verbose: Emit progress messages to stdout.

    Returns:
        model – BlipForConditionalGeneration, already in eval mode.
        processor – BlipProcessor.
        device – torch.device the model was moved to.
    """
    from transformers import BlipForConditionalGeneration

    device = get_device()
    blip_cfg = CFG.load_for_model("blip")
    model, processor = get_blip_model(blip_cfg, device)

    # Overlay the fine-tuned checkpoint when requested and present on disk.
    have_checkpoint = use_finetuned and os.path.isdir(FINETUNED_PATH)
    if have_checkpoint:
        if verbose:
            print(f"πŸ”„ Loading fine-tuned weights from {FINETUNED_PATH} …")
        checkpoint = BlipForConditionalGeneration.from_pretrained(FINETUNED_PATH)
        # strict=False tolerates keys the checkpoint does not cover.
        model.load_state_dict(checkpoint.state_dict(), strict=False)
        model.to(device)
        if verbose:
            print("βœ… Fine-tuned weights loaded")
    elif verbose:
        print("⚠️ Fine-tuned weights not found (or disabled), using base model")

    # Gradient checkpointing conflicts with backward hooks — switch it off.
    try:
        model.gradient_checkpointing_disable()
    except Exception:
        pass  # best-effort: some model classes may not expose this method
    model.config.use_cache = False

    model.eval()
    if verbose:
        print(f"βœ… Model ready on device: {device}")
    return model, processor, device