Spaces:
Running
Running
| """ | |
| step1_load_model.py | |
| ==================== | |
| STEP 1 β Load the BLIP model and processor. | |
| Responsibilities: | |
| - Detect the best available device (MPS / CUDA / CPU). | |
| - Load base BLIP weights via project's get_blip_model(). | |
| - Optionally patch in fine-tuned weights from outputs/blip/best/. | |
| - Disable gradient checkpointing (required for backward hooks). | |
| - Return a ready-to-use (model, processor, device) triplet. | |
| This module is intentionally tiny and self-contained so it can be | |
| called independently from a notebook, a Streamlit app, or a HuggingFace Space. | |
| """ | |
import os
import sys

import torch

# -- project path resolution --------------------------------------------------
# Make the project root importable no matter where this script is launched
# from (notebook, Streamlit app, HF Space).
# Assumed layout: <project_root>/<task_dir>/<this_file> -- TODO confirm.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_TASK_DIR = os.path.dirname(_THIS_DIR)
_PROJECT_ROOT = os.path.dirname(_TASK_DIR)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

# These imports intentionally come AFTER the sys.path patch above -- they
# resolve against the project root, not the current working directory.
from config import CFG  # noqa: E402
from models.blip_tuner import get_blip_model  # noqa: E402

# Directory where the fine-tuned BLIP checkpoint is expected to live.
FINETUNED_PATH = os.path.join(_PROJECT_ROOT, "outputs", "blip", "best")
# -----------------------------------------------------------------------------
def get_device() -> torch.device:
    """Pick the fastest backend available, in priority order: MPS, CUDA, CPU."""
    probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend, is_available in probes:
        if is_available():
            return torch.device(backend)
    return torch.device("cpu")
def load_model(use_finetuned: bool = True, verbose: bool = True):
    """
    Load the BLIP model and processor.

    Args:
        use_finetuned: If True, patches base weights with our fine-tuned
            COCO checkpoint from outputs/blip/best/.
        verbose: Print loading progress.

    Returns:
        model     -- BlipForConditionalGeneration, eval mode.
        processor -- BlipProcessor.
        device    -- torch.device.
    """
    # Local import keeps transformers off the import path for callers that
    # only need get_device().
    from transformers import BlipForConditionalGeneration

    device = get_device()
    cfg = CFG.load_for_model("blip")
    model, processor = get_blip_model(cfg, device)

    # Optionally patch in fine-tuned weights.
    if use_finetuned and os.path.isdir(FINETUNED_PATH):
        if verbose:
            print(f"Loading fine-tuned weights from {FINETUNED_PATH} ...")
        ft = BlipForConditionalGeneration.from_pretrained(FINETUNED_PATH)
        # strict=False tolerates architecture drift between base and
        # checkpoint, but silently discarding the result can hide a broken
        # checkpoint -- report any key mismatches instead of ignoring them.
        result = model.load_state_dict(ft.state_dict(), strict=False)
        if verbose and (result.missing_keys or result.unexpected_keys):
            print(
                f"Note: {len(result.missing_keys)} missing / "
                f"{len(result.unexpected_keys)} unexpected keys while "
                f"patching fine-tuned weights"
            )
        # Drop the duplicate weight copy before moving the model to the
        # (possibly memory-constrained) accelerator.
        del ft
        model.to(device)
        if verbose:
            print("Fine-tuned weights loaded")
    elif verbose:
        print("Fine-tuned weights not found (or disabled), using base model")

    # Disable gradient checkpointing -- incompatible with backward hooks.
    try:
        model.gradient_checkpointing_disable()
    except Exception:
        pass  # best-effort: older transformers versions may lack this method
    model.config.use_cache = False
    model.eval()
    if verbose:
        print(f"Model ready on device: {device}")
    return model, processor, device