Spaces: Sleeping
Mokhtar committed on
Commit · e4721a6
Parent(s): b4aa453
Deploying backend code
Browse files
- Dockerfile +29 -0
- README.md +58 -11
- config/config.py +39 -0
- requirements.txt +13 -0
- src/api/app.py +71 -0
- src/data/dataset.py +65 -0
- src/models/model.py +118 -0
- src/preprocessing/transforms.py +17 -0
- src/training/trainer.py +60 -0
Dockerfile
ADDED
@@ -0,0 +1,29 @@
# Use Python 3.9 as base image
FROM python:3.9

# Set the working directory inside the container
WORKDIR /code

# Copy requirements and install dependencies
# We use --no-cache-dir to keep the image small
COPY ./requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Install python-multipart (required for UploadFile in FastAPI) if not in requirements.txt
RUN pip install python-multipart

# Copy the entire project into the container
COPY . /code

# Create a directory for Hugging Face cache to avoid permission errors
RUN mkdir -p /code/cache
ENV TRANSFORMERS_CACHE=/code/cache
ENV TORCH_HOME=/code/cache
RUN chmod -R 777 /code/cache

# Expose the default Hugging Face port
EXPOSE 7860

# Command to run the application
# We point to src.api.app:app because app.py is inside src/api/
CMD ["uvicorn", "src.api.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -1,11 +1,58 @@
# Image Captioning with SOTA Models

This project provides a unified API for Image Captioning using various State-of-the-Art (SOTA) models as well as a custom ResNet+GPT2 implementation.

## Supported Models

1. **BLIP (Bootstrapping Language-Image Pre-training)**
   * Model: `Salesforce/blip-image-captioning-large`
   * Status: **Default** (Best Performance)
   * Description: Produces highly accurate and detailed captions.

2. **ViT-GPT2**
   * Model: `nlpconnect/vit-gpt2-image-captioning`
   * Status: Available
   * Description: Uses a Vision Transformer (ViT) encoder and a GPT-2 decoder.

3. **ResNet50 + GPT-2 (Custom)**
   * Model: Custom implementation trained from scratch.
   * Status: Legacy / Experimental
   * Description: Good for learning purposes or custom datasets.

## Installation

1. Clone the repository.
2. Install dependencies:

```bash
pip install -r requirements.txt
```

## Configuration

Edit `config/config.py` to select the model:

```python
class Config:
    # ...
    MODEL_TYPE = "blip"  # Options: "blip", "vit_gpt2", "resnet_gpt2"
```

## Running the API

Start the FastAPI server:

```bash
python main.py --mode api
```

Open your browser at `http://localhost:8001` to use the drag-and-drop interface.

## Training (ResNet+GPT2 only)

To train the custom model:

1. Set `MODEL_TYPE = "resnet_gpt2"` in config.
2. Run:

```bash
python main.py --mode train
```
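For reference, a minimal client-side sketch of calling the `/predict` endpoint the README describes. The host/port (`localhost:8001` for a local run; the Docker image listens on 7860) and the `photo.jpg` filename are illustrative assumptions, not part of the repository:

```python
# Hypothetical client call against a running instance of this API.
import requests

API_URL = "http://localhost:8001/predict"  # assumption: local run; adjust to your deployment

with open("photo.jpg", "rb") as f:  # "photo.jpg" is a placeholder image path
    # /predict expects a multipart upload under the field name "file"
    response = requests.post(API_URL, files={"file": f})

response.raise_for_status()
print(response.json()["caption"])
```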
config/config.py
ADDED
@@ -0,0 +1,39 @@
import os
import torch

class Config:
    # Paths
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # We do not need training data paths for the inference API
    DATA_DIR = None
    CAPTIONS_FILE = None

    # Model saving/loading directory
    MODEL_SAVE_DIR = os.path.join(BASE_DIR, 'models')
    LOG_DIR = os.path.join(BASE_DIR, 'logs')

    os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
    os.makedirs(LOG_DIR, exist_ok=True)

    # Device: Force CPU if CUDA is not available (Hugging Face Free Tier is CPU)
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hyperparameters (kept for reference, mostly unused in inference)
    BATCH_SIZE = 1
    LEARNING_RATE = 2e-5
    NUM_EPOCHS = 10
    NUM_WORKERS = 2

    # Model Config
    # Change this to "blip" or "vit_gpt2" for your deployment to ensure no custom weights are needed
    MODEL_TYPE = "blip"
    ENCODER_MODEL = "resnet50"
    DECODER_MODEL = "gpt2"
    EMBED_DIM = 768
    MAX_SEQ_LEN = 40

    # Image Config
    IMAGE_SIZE = (224, 224)

config = Config()
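A small usage sketch for the config module above; the print statements are purely illustrative and not part of the repository:

```python
# Other modules import the pre-built `config` instance rather than the Config class.
from config.config import config

print(config.MODEL_TYPE)      # "blip" unless edited
print(config.DEVICE)          # cuda if available, otherwise cpu
print(config.MODEL_SAVE_DIR)  # <repo>/models, created at import time
print(config.IMAGE_SIZE)      # (224, 224)
```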
requirements.txt
ADDED
@@ -0,0 +1,13 @@
torch
torchvision
fastapi
uvicorn
pillow
pandas
spacy
tqdm
matplotlib
gTTS
transformers
python-multipart
requests
src/api/app.py
ADDED
@@ -0,0 +1,71 @@
import io
import os
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import GPT2Tokenizer

# Adjust imports based on your folder structure
from config.config import config
from src.models.model import get_model
from src.preprocessing.transforms import get_transforms

app = FastAPI(title="Object Captioning LLM API")

# --- CORS CONFIGURATION (Crucial for Vercel) ---
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, replace "*" with your Vercel URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load Model
print(f"Loading model: {config.MODEL_TYPE} on {config.DEVICE}...")
device = config.DEVICE
model = get_model(config).to(device)

# Legacy support for ResNetGPT2
if config.MODEL_TYPE == "resnet_gpt2":
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # Use a relative path or ensure this file is uploaded to the Docker container
    model_path = os.path.join(config.MODEL_SAVE_DIR, "best_model_llm.pth")
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Loaded trained custom model.")
    else:
        print("Warning: No trained model found for ResNetGPT2. Using random weights.")
else:
    tokenizer = None

model.eval()
transform = get_transforms(train=False)

@app.get("/")
def home():
    return {"message": "Image Captioning API is running. Send POST requests to /predict"}

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    try:
        # Read Image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Generate Caption
        if config.MODEL_TYPE == "resnet_gpt2":
            img_tensor = transform(image).to(device)
            # Ensure generate_caption handles the tensor/tokenizer correctly
            caption = model.generate_caption(img_tensor, tokenizer)
        else:
            # SOTA models (BLIP/ViT) take the PIL image directly
            caption = model.generate_caption(image)

        return {
            "caption": caption
        }
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=str(e))
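A possible in-process smoke test for the endpoint above, using FastAPI's TestClient. This is a sketch: it assumes `httpx` is installed (TestClient depends on it and it is not in requirements.txt), `sample.jpg` is a placeholder image, and importing the app triggers the model download/load:

```python
# Sketch of an in-process test; not part of the deployed code.
from fastapi.testclient import TestClient

from src.api.app import app  # importing this loads the captioning model

client = TestClient(app)

# Root health check
assert client.get("/").status_code == 200

# Caption a local image ("sample.jpg" is a placeholder path)
with open("sample.jpg", "rb") as f:
    resp = client.post("/predict", files={"file": ("sample.jpg", f, "image/jpeg")})

print(resp.status_code, resp.json())
```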
src/data/dataset.py
ADDED
@@ -0,0 +1,65 @@
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer

class CaptionDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform=None, max_length=40):
        self.root_dir = root_dir
        self.transform = transform
        self.max_length = max_length

        # Load captions
        # Format: image,caption (csv)
        self.df = pd.read_csv(captions_file, delimiter=',')

        # Rename columns to match expected internal names if necessary, or just use them directly
        # The file has 'image' and 'caption' columns based on inspection
        self.df.rename(columns={'image': 'image_name', 'caption': 'comment'}, inplace=True)

        self.df['image_name'] = self.df['image_name'].str.strip()
        self.df['comment'] = self.df['comment'].str.strip()
        self.df = self.df.dropna()

        self.captions = self.df['comment'].tolist()
        self.images = self.df['image_name'].tolist()

        # Initialize Tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        # GPT2 doesn't have a pad token, so we use eos_token as pad_token
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        caption = self.captions[idx]
        img_name = self.images[idx]
        img_path = os.path.join(self.root_dir, img_name)

        try:
            image = Image.open(img_path).convert("RGB")
        except Exception:
            # Fallback for missing images or errors, return next item
            return self.__getitem__((idx + 1) % len(self))

        if self.transform:
            image = self.transform(image)

        # Tokenize caption
        # We add a special prefix to prompt the model if desired, but for direct captioning:
        # Format: [Image Feature] -> Caption
        encoding = self.tokenizer(
            caption,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()

        return image, input_ids, attention_mask
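A usage sketch for `CaptionDataset`. The `data/images` folder and `data/captions.csv` paths are assumptions about a Flickr-style layout with `image,caption` rows, not files shipped in this repo:

```python
# Illustrative wiring of the dataset with the project's transforms and a DataLoader.
from torch.utils.data import DataLoader

from src.data.dataset import CaptionDataset
from src.preprocessing.transforms import get_transforms

dataset = CaptionDataset(
    root_dir="data/images",             # assumed image folder
    captions_file="data/captions.csv",  # assumed CSV with image,caption columns
    transform=get_transforms(train=True),
    max_length=40,
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)

images, input_ids, attention_mask = next(iter(loader))
print(images.shape)          # torch.Size([8, 3, 224, 224])
print(input_ids.shape)       # torch.Size([8, 40])
print(attention_mask.shape)  # torch.Size([8, 40])
```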
src/models/model.py
ADDED
@@ -0,0 +1,118 @@
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import (
    GPT2LMHeadModel,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    BlipProcessor,
    BlipForConditionalGeneration
)

# -----------------------------------------------------------------------------
# 1. Custom ResNet + GPT-2 (Training from Scratch)
# -----------------------------------------------------------------------------
class ResNetEncoder(nn.Module):
    def __init__(self, embed_dim=768):
        super(ResNetEncoder, self).__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        for param in self.resnet.parameters():
            param.requires_grad = False
        self.projection = nn.Linear(2048, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.projection(features)
        features = self.bn(features)
        return features

class ResNetGPT2(nn.Module):
    def __init__(self, max_seq_len=40):
        super(ResNetGPT2, self).__init__()
        self.encoder = ResNetEncoder(embed_dim=768)
        self.gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
        self.max_seq_len = max_seq_len

    def forward(self, images, input_ids, attention_mask):
        image_embeds = self.encoder(images)
        token_embeds = self.gpt2.transformer.wte(input_ids)
        inputs_embeds = torch.cat((image_embeds.unsqueeze(1), token_embeds), dim=1)
        batch_size = images.shape[0]
        ones = torch.ones(batch_size, 1).to(images.device)
        attention_mask = torch.cat((ones, attention_mask), dim=1)
        outputs = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        return outputs.logits

    def generate_caption(self, image, tokenizer, max_length=20, temperature=1.0):
        self.eval()
        with torch.no_grad():
            image_embed = self.encoder(image.unsqueeze(0))
            inputs_embeds = image_embed.unsqueeze(1)
            generated_tokens = []
            for _ in range(max_length):
                outputs = self.gpt2(inputs_embeds=inputs_embeds)
                logits = outputs.logits[:, -1, :] / temperature
                next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
                if next_token.item() == tokenizer.eos_token_id:
                    break
                generated_tokens.append(next_token.item())
                next_token_embed = self.gpt2.transformer.wte(next_token)
                inputs_embeds = torch.cat((inputs_embeds, next_token_embed), dim=1)
        return tokenizer.decode(generated_tokens, skip_special_tokens=True)

# -----------------------------------------------------------------------------
# 2. ViT + GPT-2 (Pre-trained SOTA 1)
# -----------------------------------------------------------------------------
class ViTGPT2Captioner(nn.Module):
    def __init__(self):
        super().__init__()
        print("Loading ViT-GPT2 model...")
        self.model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
        self.feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
        self.tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

    def generate_caption(self, image, **kwargs):
        self.eval()
        with torch.no_grad():
            pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.model.device)
            output_ids = self.model.generate(pixel_values, max_length=20, num_beams=4)
            preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return preds[0].strip()

# -----------------------------------------------------------------------------
# 3. BLIP (Pre-trained SOTA 2 - Best)
# -----------------------------------------------------------------------------
class BLIPCaptioner(nn.Module):
    def __init__(self):
        super().__init__()
        print("Loading BLIP model (Salesforce/blip-image-captioning-large)...")
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

    def generate_caption(self, image, **kwargs):
        self.eval()
        with torch.no_grad():
            inputs = self.processor(images=image, return_tensors="pt").to(self.model.device)
            output_ids = self.model.generate(**inputs, max_length=50, num_beams=5, repetition_penalty=1.2, min_length=5)
            caption = self.processor.decode(output_ids[0], skip_special_tokens=True)
        return caption

# -----------------------------------------------------------------------------
# Factory
# -----------------------------------------------------------------------------
def get_model(config):
    if config.MODEL_TYPE == "resnet_gpt2":
        return ResNetGPT2()
    elif config.MODEL_TYPE == "vit_gpt2":
        return ViTGPT2Captioner()
    elif config.MODEL_TYPE == "blip":
        return BLIPCaptioner()
    else:
        raise ValueError(f"Unknown model type: {config.MODEL_TYPE}")
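A usage sketch for the factory above with the default BLIP configuration; `cat.jpg` is a placeholder image path, and the first call downloads the BLIP weights:

```python
# Illustrative use of get_model(); mirrors what src/api/app.py does at startup.
from PIL import Image

from config.config import config
from src.models.model import get_model

model = get_model(config).to(config.DEVICE)  # BLIPCaptioner when MODEL_TYPE == "blip"
model.eval()

image = Image.open("cat.jpg").convert("RGB")  # placeholder image path
print(model.generate_caption(image))
```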
src/preprocessing/transforms.py
ADDED
@@ -0,0 +1,17 @@
import torchvision.transforms as transforms

def get_transforms(image_size=(224, 224), train=True):
    if train:
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
        ])
    else:
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
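A quick sanity-check sketch for the transforms (`example.jpg` is a placeholder):

```python
# The inference transform should yield a normalized 3x224x224 float tensor.
from PIL import Image

from src.preprocessing.transforms import get_transforms

transform = get_transforms(train=False)
img = Image.open("example.jpg").convert("RGB")  # placeholder image
tensor = transform(img)
print(tensor.shape, tensor.dtype)  # torch.Size([3, 224, 224]) torch.float32
```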
src/training/trainer.py
ADDED
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os

def train_model(model, train_loader, val_loader, config, tokenizer):
    model = model.to(config.DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    best_loss = float('inf')

    for epoch in range(config.NUM_EPOCHS):
        model.train()
        train_loss = 0
        loop = tqdm(train_loader, total=len(train_loader), leave=True)

        for images, input_ids, attention_mask in loop:
            images = images.to(config.DEVICE)
            input_ids = input_ids.to(config.DEVICE)
            attention_mask = attention_mask.to(config.DEVICE)

            optimizer.zero_grad()

            # Forward pass
            # Logits: [batch, seq_len+1, vocab_size]
            logits = model(images, input_ids, attention_mask)

            # Shift logits and labels for next-token prediction
            # We want to predict input_ids based on previous context
            # The model output at index i corresponds to the prediction for token i+1
            # Input sequence to model: [Image, T1, T2, T3, ...]
            # Output logits: [P1, P2, P3, P4, ...]
            # Targets: [T1, T2, T3, T4, ...]

            # We discard the last logit because we don't have a target for it
            shift_logits = logits[:, :-1, :].contiguous()
            # We use input_ids as targets
            shift_labels = input_ids.contiguous()

            loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            loop.set_description(f"Epoch [{epoch+1}/{config.NUM_EPOCHS}]")
            loop.set_postfix(loss=loss.item())

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch+1} Loss: {avg_train_loss:.4f}")

        # Save checkpoint
        if avg_train_loss < best_loss:
            best_loss = avg_train_loss
            torch.save(model.state_dict(), os.path.join(config.MODEL_SAVE_DIR, "best_model_llm.pth"))
            print("Saved Best Model!")

        torch.save(model.state_dict(), os.path.join(config.MODEL_SAVE_DIR, "last_checkpoint_llm.pth"))