| | """ |
| | Embedding Service for generating image embeddings |
| | """ |
| |
|
| | import os |
| | from typing import List, Dict, Any |
| | from PIL import Image |
| | import io |
| | import numpy as np |
| | import torch |
| | from transformers import CLIPProcessor, CLIPModel |
| |
|
| |
|
class ImageEmbeddingModel:
    """Generate unit-length image embeddings with a CLIP model.

    The model and processor are loaded once at construction time and reused
    for every call. Inference runs on ``"cuda"`` when
    ``torch.cuda.is_available()`` is true, otherwise on ``"cpu"``.
    """

    # File extensions (lower-case, with leading dot) that
    # get_embeddings_from_folder treats as images. Subclasses may override.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'})

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """Load the CLIP model and its matching processor.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the processor.
        """
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def generate_embedding(self, image_data: bytes) -> List[float]:
        """Decode encoded image bytes and return their embedding.

        Args:
            image_data: Encoded image file contents (any format PIL can open).

        Returns:
            The L2-normalised image embedding as a list of floats.

        Raises:
            PIL.UnidentifiedImageError: If PIL cannot identify the image data.
        """
        # convert() returns a new, independent Image, so the decoder-backed
        # original can be closed as soon as the RGB copy exists.
        with Image.open(io.BytesIO(image_data)) as raw_image:
            image = raw_image.convert("RGB")
        return self.generate_embedding_from_pil(image)

    def generate_embedding_from_pil(self, image: Image.Image) -> List[float]:
        """Return the embedding of an already-decoded PIL image.

        Args:
            image: PIL image to embed.

        Returns:
            The image embedding scaled to unit L2 norm, as a list of floats.
            An all-zero feature vector is returned unscaled rather than being
            divided by zero (which would produce NaNs).
        """
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)

        # A single image was passed, so take row 0 and move it to host memory.
        image_embedding = image_features.cpu().numpy()[0]
        norm = np.linalg.norm(image_embedding)
        if norm == 0:
            return image_embedding.tolist()
        return (image_embedding / norm).tolist()

    def get_embeddings_from_folder(self, folder_path: str) -> Dict[str, Any]:
        """Embed every image file found directly inside a folder.

        Only entries whose extension (compared case-insensitively) is in
        ``IMAGE_EXTENSIONS`` are processed; subfolders are not searched.
        A failure on one file is recorded and does not stop the others.

        Args:
            folder_path: Path to the folder containing images.

        Returns:
            A dict keyed by filename, in sorted filename order. Each value is
            either ``{"embedding": [...], "status": "success"}`` or
            ``{"error": "<message>", "status": "failed"}``. If ``folder_path``
            does not exist or is not a directory, the whole result is instead
            ``{"error": "<message>"}``.
        """
        if not os.path.exists(folder_path):
            return {"error": f"Folder {folder_path} does not exist"}
        # exists() is also true for regular files; os.listdir would then raise
        # NotADirectoryError instead of returning an error dict.
        if not os.path.isdir(folder_path):
            return {"error": f"{folder_path} is not a directory"}

        results: Dict[str, Any] = {}
        # Sorted so the result order does not depend on the filesystem.
        for filename in sorted(os.listdir(folder_path)):
            if os.path.splitext(filename)[1].lower() not in self.IMAGE_EXTENSIONS:
                continue

            file_path = os.path.join(folder_path, filename)
            try:
                with open(file_path, 'rb') as f:
                    image_data = f.read()
                embedding = self.generate_embedding(image_data)
            except Exception as e:  # best-effort per file: record and move on
                results[filename] = {
                    "error": str(e),
                    "status": "failed"
                }
            else:
                results[filename] = {
                    "embedding": embedding,
                    "status": "success"
                }

        return results