Mokhtar committed on
Commit e4721a6 · 1 Parent(s): b4aa453

Deploying backend code

Dockerfile ADDED
@@ -0,0 +1,29 @@
+ # Use Python 3.9 as the base image
+ FROM python:3.9
+
+ # Set the working directory inside the container
+ WORKDIR /code
+
+ # Copy requirements and install dependencies
+ # We use --no-cache-dir to keep the image small
+ COPY ./requirements.txt /code/requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ # Install python-multipart (required for UploadFile in FastAPI) in case it is missing from requirements.txt
+ RUN pip install python-multipart
+
+ # Copy the entire project into the container
+ COPY . /code
+
+ # Create a directory for the Hugging Face cache to avoid permission errors
+ RUN mkdir -p /code/cache
+ ENV TRANSFORMERS_CACHE=/code/cache
+ ENV TORCH_HOME=/code/cache
+ RUN chmod -R 777 /code/cache
+
+ # Expose the default Hugging Face Spaces port
+ EXPOSE 7860
+
+ # Command to run the application
+ # We point to src.api.app:app because app.py is inside src/api/
+ CMD ["uvicorn", "src.api.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,58 @@
- ---
- title: Captioning
- emoji: 🐨
- colorFrom: gray
- colorTo: red
- sdk: docker
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Image Captioning with SOTA Models
+
+ This project provides a unified API for image captioning using various state-of-the-art (SOTA) models, as well as a custom ResNet+GPT2 implementation.
+
+ ## Supported Models
+
+ 1. **BLIP (Bootstrapping Language-Image Pre-training)**
+    * Model: `Salesforce/blip-image-captioning-large`
+    * Status: **Default** (best performance)
+    * Description: Produces highly accurate and detailed captions.
+
+ 2. **ViT-GPT2**
+    * Model: `nlpconnect/vit-gpt2-image-captioning`
+    * Status: Available
+    * Description: Uses a Vision Transformer (ViT) encoder with a GPT-2 decoder.
+
+ 3. **ResNet50 + GPT-2 (Custom)**
+    * Model: Custom implementation trained from scratch.
+    * Status: Legacy / experimental
+    * Description: Good for learning purposes or custom datasets.
+
+ ## Installation
+
+ 1. Clone the repository.
+ 2. Install dependencies:
+    ```bash
+    pip install -r requirements.txt
+    ```
+
+ ## Configuration
+
+ Edit `config/config.py` to select the model:
+
+ ```python
+ class Config:
+     # ...
+     MODEL_TYPE = "blip"  # Options: "blip", "vit_gpt2", "resnet_gpt2"
+ ```
+
+ ## Running the API
+
+ Start the FastAPI server:
+
+ ```bash
+ python main.py --mode api
+ ```
+
+ Open your browser at `http://localhost:8001` to use the drag-and-drop interface.
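+
+ ### Example Request
+
+ A minimal `curl` sketch for the `/predict` endpoint (assuming the server is running locally on port 8001 and `image.jpg` exists in the current directory):
+
+ ```bash
+ curl -X POST "http://localhost:8001/predict" -F "file=@image.jpg"
+ ```
+
+ The response is a JSON object of the form `{"caption": "..."}`.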
+
+ ## Training (ResNet+GPT2 only)
+
+ To train the custom model:
+
+ 1. Set `MODEL_TYPE = "resnet_gpt2"` in `config/config.py`.
+ 2. Run:
+    ```bash
+    python main.py --mode train
+    ```
config/config.py ADDED
@@ -0,0 +1,39 @@
+ import os
+ import torch
+
+ class Config:
+     # Paths
+     BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+     # We do not need training data paths for the inference API
+     DATA_DIR = None
+     CAPTIONS_FILE = None
+
+     # Model saving/loading directory
+     MODEL_SAVE_DIR = os.path.join(BASE_DIR, 'models')
+     LOG_DIR = os.path.join(BASE_DIR, 'logs')
+
+     os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
+     os.makedirs(LOG_DIR, exist_ok=True)
+
+     # Device: use CUDA when available, otherwise fall back to CPU
+     # (the Hugging Face free tier is CPU-only)
+     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Hyperparameters (kept for reference, mostly unused in inference)
+     BATCH_SIZE = 1
+     LEARNING_RATE = 2e-5
+     NUM_EPOCHS = 10
+     NUM_WORKERS = 2
+
+     # Model Config
+     # Set this to "blip" or "vit_gpt2" for deployment so that no custom weights are needed
+     MODEL_TYPE = "blip"
+     ENCODER_MODEL = "resnet50"
+     DECODER_MODEL = "gpt2"
+     EMBED_DIM = 768
+     MAX_SEQ_LEN = 40
+
+     # Image Config
+     IMAGE_SIZE = (224, 224)
+
+ config = Config()
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torch
+ torchvision
+ fastapi
+ uvicorn
+ pillow
+ pandas
+ spacy
+ tqdm
+ matplotlib
+ gTTS
+ transformers
+ python-multipart
+ requests
src/api/app.py ADDED
@@ -0,0 +1,71 @@
+ import io
+ import os
+ import torch
+ from fastapi import FastAPI, UploadFile, File, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from PIL import Image
+ from transformers import GPT2Tokenizer
+
+ # Adjust imports based on your folder structure
+ from config.config import config
+ from src.models.model import get_model
+ from src.preprocessing.transforms import get_transforms
+
+ app = FastAPI(title="Object Captioning LLM API")
+
+ # --- CORS CONFIGURATION (Crucial for Vercel) ---
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # For production, replace "*" with your Vercel URL
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Load Model
+ print(f"Loading model: {config.MODEL_TYPE} on {config.DEVICE}...")
+ device = config.DEVICE
+ model = get_model(config).to(device)
+
+ # Legacy support for ResNetGPT2
+ if config.MODEL_TYPE == "resnet_gpt2":
+     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+     # Use a relative path or ensure this file is uploaded to the Docker container
+     model_path = os.path.join(config.MODEL_SAVE_DIR, "best_model_llm.pth")
+     if os.path.exists(model_path):
+         model.load_state_dict(torch.load(model_path, map_location=device))
+         print("Loaded trained custom model.")
+     else:
+         print("Warning: No trained model found for ResNetGPT2. Using random weights.")
+ else:
+     tokenizer = None
+
+ model.eval()
+ transform = get_transforms(train=False)
+
+ @app.get("/")
+ def home():
+     return {"message": "Image Captioning API is running. Send POST requests to /predict"}
+
+ @app.post("/predict")
+ async def predict(file: UploadFile = File(...)):
+     try:
+         # Read Image
+         image_data = await file.read()
+         image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+         # Generate Caption
+         if config.MODEL_TYPE == "resnet_gpt2":
+             img_tensor = transform(image).to(device)
+             # Ensure generate_caption handles the tensor/tokenizer correctly
+             caption = model.generate_caption(img_tensor, tokenizer)
+         else:
+             # SOTA models (BLIP/ViT) take the PIL image directly
+             caption = model.generate_caption(image)
+
+         return {"caption": caption}
+     except Exception as e:
+         print(f"Error during prediction: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
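+
+ # Minimal client sketch (illustrative, not part of the app; assumes the container
+ # is reachable on port 7860 and an image.jpg exists locally):
+ #   import requests
+ #   with open("image.jpg", "rb") as f:
+ #       r = requests.post("http://localhost:7860/predict", files={"file": f})
+ #   print(r.json()["caption"])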
src/data/dataset.py ADDED
@@ -0,0 +1,65 @@
+ import os
+ import pandas as pd
+ from PIL import Image
+ import torch
+ from torch.utils.data import Dataset
+ from transformers import GPT2Tokenizer
+
+ class CaptionDataset(Dataset):
+     def __init__(self, root_dir, captions_file, transform=None, max_length=40):
+         self.root_dir = root_dir
+         self.transform = transform
+         self.max_length = max_length
+
+         # Load captions
+         # Format: image,caption (CSV)
+         self.df = pd.read_csv(captions_file, delimiter=',')
+
+         # Rename columns to the expected internal names
+         # (the file has 'image' and 'caption' columns based on inspection)
+         self.df.rename(columns={'image': 'image_name', 'caption': 'comment'}, inplace=True)
+
+         self.df['image_name'] = self.df['image_name'].str.strip()
+         self.df['comment'] = self.df['comment'].str.strip()
+         self.df = self.df.dropna()
+
+         self.captions = self.df['comment'].tolist()
+         self.images = self.df['image_name'].tolist()
+
+         # Initialize the tokenizer
+         self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+         # GPT-2 has no pad token, so we use eos_token as pad_token
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+     def __len__(self):
+         return len(self.captions)
+
+     def __getitem__(self, idx):
+         caption = self.captions[idx]
+         img_name = self.images[idx]
+         img_path = os.path.join(self.root_dir, img_name)
+
+         try:
+             image = Image.open(img_path).convert("RGB")
+         except Exception:
+             # Fallback for missing or unreadable images: return the next item
+             return self.__getitem__((idx + 1) % len(self))
+
+         if self.transform:
+             image = self.transform(image)
+
+         # Tokenize the caption
+         # A special prefix could prompt the model, but for direct captioning the
+         # format is simply: [Image Feature] -> Caption
+         encoding = self.tokenizer(
+             caption,
+             truncation=True,
+             padding='max_length',
+             max_length=self.max_length,
+             return_tensors='pt'
+         )
+
+         input_ids = encoding['input_ids'].squeeze()
+         attention_mask = encoding['attention_mask'].squeeze()
+
+         return image, input_ids, attention_mask
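+
+ # Usage sketch (illustrative; the data paths are assumptions, not repo files):
+ #   from torch.utils.data import DataLoader
+ #   from src.preprocessing.transforms import get_transforms
+ #   ds = CaptionDataset("data/images", "data/captions.csv", transform=get_transforms(train=True))
+ #   loader = DataLoader(ds, batch_size=32, shuffle=True)
+ #   images, input_ids, attention_mask = next(iter(loader))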
src/models/model.py ADDED
@@ -0,0 +1,118 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+ from transformers import (
+     GPT2LMHeadModel,
+     VisionEncoderDecoderModel,
+     ViTImageProcessor,
+     AutoTokenizer,
+     BlipProcessor,
+     BlipForConditionalGeneration
+ )
+
+ # -----------------------------------------------------------------------------
+ # 1. Custom ResNet + GPT-2 (training from scratch)
+ # -----------------------------------------------------------------------------
+ class ResNetEncoder(nn.Module):
+     def __init__(self, embed_dim=768):
+         super(ResNetEncoder, self).__init__()
+         resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
+         # Drop the final classification layer, keeping the 2048-d pooled features
+         modules = list(resnet.children())[:-1]
+         self.resnet = nn.Sequential(*modules)
+         # Freeze the backbone; only the projection is trained
+         for param in self.resnet.parameters():
+             param.requires_grad = False
+         self.projection = nn.Linear(2048, embed_dim)
+         self.bn = nn.BatchNorm1d(embed_dim, momentum=0.01)
+
+     def forward(self, images):
+         features = self.resnet(images)
+         features = features.view(features.size(0), -1)
+         features = self.projection(features)
+         features = self.bn(features)
+         return features
+
+ class ResNetGPT2(nn.Module):
+     def __init__(self, max_seq_len=40):
+         super(ResNetGPT2, self).__init__()
+         self.encoder = ResNetEncoder(embed_dim=768)
+         self.gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
+         self.max_seq_len = max_seq_len
+
+     def forward(self, images, input_ids, attention_mask):
+         # Prepend the image embedding as a single pseudo-token
+         image_embeds = self.encoder(images)
+         token_embeds = self.gpt2.transformer.wte(input_ids)
+         inputs_embeds = torch.cat((image_embeds.unsqueeze(1), token_embeds), dim=1)
+         # Extend the attention mask to cover the image token
+         batch_size = images.shape[0]
+         ones = torch.ones(batch_size, 1).to(images.device)
+         attention_mask = torch.cat((ones, attention_mask), dim=1)
+         outputs = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
+         return outputs.logits
+
+     def generate_caption(self, image, tokenizer, max_length=20, temperature=1.0):
+         # Greedy decoding; note that temperature has no effect on argmax
+         self.eval()
+         with torch.no_grad():
+             image_embed = self.encoder(image.unsqueeze(0))
+             inputs_embeds = image_embed.unsqueeze(1)
+             generated_tokens = []
+             for _ in range(max_length):
+                 outputs = self.gpt2(inputs_embeds=inputs_embeds)
+                 logits = outputs.logits[:, -1, :] / temperature
+                 next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
+                 if next_token.item() == tokenizer.eos_token_id:
+                     break
+                 generated_tokens.append(next_token.item())
+                 next_token_embed = self.gpt2.transformer.wte(next_token)
+                 inputs_embeds = torch.cat((inputs_embeds, next_token_embed), dim=1)
+             return tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+ # -----------------------------------------------------------------------------
+ # 2. ViT + GPT-2 (pre-trained SOTA 1)
+ # -----------------------------------------------------------------------------
+ class ViTGPT2Captioner(nn.Module):
+     def __init__(self):
+         super().__init__()
+         print("Loading ViT-GPT2 model...")
+         self.model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+         self.feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+         self.tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+
+     def generate_caption(self, image, **kwargs):
+         self.eval()
+         with torch.no_grad():
+             pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+             pixel_values = pixel_values.to(self.model.device)
+             output_ids = self.model.generate(pixel_values, max_length=20, num_beams=4)
+             preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+             return preds[0].strip()
+
+ # -----------------------------------------------------------------------------
+ # 3. BLIP (pre-trained SOTA 2 - best)
+ # -----------------------------------------------------------------------------
+ class BLIPCaptioner(nn.Module):
+     def __init__(self):
+         super().__init__()
+         print("Loading BLIP model (Salesforce/blip-image-captioning-large)...")
+         self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+         self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+
+     def generate_caption(self, image, **kwargs):
+         self.eval()
+         with torch.no_grad():
+             inputs = self.processor(images=image, return_tensors="pt").to(self.model.device)
+             output_ids = self.model.generate(**inputs, max_length=50, num_beams=5, repetition_penalty=1.2, min_length=5)
+             caption = self.processor.decode(output_ids[0], skip_special_tokens=True)
+             return caption
+
+ # -----------------------------------------------------------------------------
+ # Factory
+ # -----------------------------------------------------------------------------
+ def get_model(config):
+     if config.MODEL_TYPE == "resnet_gpt2":
+         return ResNetGPT2(max_seq_len=config.MAX_SEQ_LEN)
+     elif config.MODEL_TYPE == "vit_gpt2":
+         return ViTGPT2Captioner()
+     elif config.MODEL_TYPE == "blip":
+         return BLIPCaptioner()
+     else:
+         raise ValueError(f"Unknown model type: {config.MODEL_TYPE}")
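+
+ # Quick smoke-test sketch (illustrative; "sample.jpg" is an assumed local file):
+ #   from PIL import Image
+ #   from config.config import config
+ #   model = get_model(config)
+ #   print(model.generate_caption(Image.open("sample.jpg").convert("RGB")))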
src/preprocessing/transforms.py ADDED
@@ -0,0 +1,17 @@
+ import torchvision.transforms as transforms
+
+ def get_transforms(image_size=(224, 224), train=True):
+     if train:
+         return transforms.Compose([
+             transforms.Resize(image_size),
+             transforms.RandomHorizontalFlip(),
+             transforms.ColorJitter(brightness=0.1, contrast=0.1),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
+         ])
+     else:
+         return transforms.Compose([
+             transforms.Resize(image_size),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
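+
+ # Example (illustrative; assumes a local PIL image at "x.jpg"):
+ #   from PIL import Image
+ #   tensor = get_transforms(train=False)(Image.open("x.jpg").convert("RGB"))
+ #   # tensor has shape [3, 224, 224], normalized with ImageNet statistics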
src/training/trainer.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from tqdm import tqdm
+ import os
+
+ def train_model(model, train_loader, val_loader, config, tokenizer):
+     # Note: val_loader is currently unused; checkpointing is based on training loss
+     model = model.to(config.DEVICE)
+     optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
+     criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
+
+     best_loss = float('inf')
+
+     for epoch in range(config.NUM_EPOCHS):
+         model.train()
+         train_loss = 0
+         loop = tqdm(train_loader, total=len(train_loader), leave=True)
+
+         for images, input_ids, attention_mask in loop:
+             images = images.to(config.DEVICE)
+             input_ids = input_ids.to(config.DEVICE)
+             attention_mask = attention_mask.to(config.DEVICE)
+
+             optimizer.zero_grad()
+
+             # Forward pass
+             # Logits: [batch, seq_len + 1, vocab_size] (the +1 is the image pseudo-token)
+             logits = model(images, input_ids, attention_mask)
+
+             # Shift logits and labels for next-token prediction: the model's output
+             # at position i is the prediction for the token at position i + 1.
+             # Input sequence to the model: [Image, T1, T2, T3, ...]
+             # Output logits:               [P1,    P2, P3, P4, ...]
+             # Targets:                     [T1,    T2, T3, T4, ...]
+
+             # Discard the last logit because it has no target
+             shift_logits = logits[:, :-1, :].contiguous()
+             # Use input_ids directly as targets
+             shift_labels = input_ids.contiguous()
+
+             loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+             loss.backward()
+             optimizer.step()
+
+             train_loss += loss.item()
+             loop.set_description(f"Epoch [{epoch+1}/{config.NUM_EPOCHS}]")
+             loop.set_postfix(loss=loss.item())
+
+         avg_train_loss = train_loss / len(train_loader)
+         print(f"Epoch {epoch+1} Loss: {avg_train_loss:.4f}")
+
+         # Save a checkpoint whenever the training loss improves
+         if avg_train_loss < best_loss:
+             best_loss = avg_train_loss
+             torch.save(model.state_dict(), os.path.join(config.MODEL_SAVE_DIR, "best_model_llm.pth"))
+             print("Saved Best Model!")
+
+     torch.save(model.state_dict(), os.path.join(config.MODEL_SAVE_DIR, "last_checkpoint_llm.pth"))
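+
+ # Orchestration sketch (illustrative; assumes data paths are set in config/config.py):
+ #   from torch.utils.data import DataLoader
+ #   from config.config import config
+ #   from src.data.dataset import CaptionDataset
+ #   from src.models.model import get_model
+ #   from src.preprocessing.transforms import get_transforms
+ #   ds = CaptionDataset(config.DATA_DIR, config.CAPTIONS_FILE, transform=get_transforms(train=True))
+ #   loader = DataLoader(ds, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS)
+ #   model = get_model(config)
+ #   train_model(model, loader, None, config, ds.tokenizer)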