| ''' |
| This script contains the CaptionGenerator class, |
| which is used to generate Instagram captions for images |
| using BLIP and Gemini models. |
| ''' |
| import google.generativeai as genai |
| import streamlit as st |
| import app_utils as utils |
| from transformers import AutoProcessor, Blip2ForConditionalGeneration |
|
|
| genai.configure(api_key=utils.get_gemini_api_key()) |
|
|
| class CaptionGenerator: |
| """ |
| Class for generating Instagram captions for images using BLIP and Gemini models. |
| The model from Hugging Face is used to generate the initial caption for the image, |
| which is then used as a prompt for the Gemini model to generate five distinct and |
| engaging captions for the image. |
| |
| Attributes: |
| - google_api_key (str): Google API key for accessing the Generative AI API. |
| - gemini_model (GenerativeModel): Gemini model for generating Instagram captions. |
| - blip_processor (AutoProcessor): BLIP model processor for image captioning. |
| - blip_model (BlipForConditionalGeneration): BLIP model for image captioning. |
| |
| Methods: |
| - process_image(image_data): Resize and prepare the image for caption generation. |
| - predict(image_data): Generate five Instagram captions for the provided image. |
| """ |
| def __init__(self): |
| self.gemini = genai.GenerativeModel('gemini-pro') |
| self.processor = None |
| self.model = None |
|
|
| def image_2_text(self, image_data, processor, model): |
| """ |
| Generate a caption for the provided image using the BLIP-2 model. |
| :param image_data: PIL.Image - The image for which the caption is to be generated. |
| :return: description - The description generated for the image. |
| """ |
| try: |
| inputs = processor(images=image_data, return_tensors="pt") |
| generated_ids = model.generate(**inputs, max_length=100) |
| description = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() |
| return description |
|
|
| except Exception as e: |
| st.error(f"Error occurred during image captioning: {e}") |
|
|
|
|
| def text_2_caption(self, image_description): |
| """ |
| Generate five Instagram captions for the provided image. |
| The image is first processed before generating captions. |
| |
| :param image_description: str - The description of the image for which captions are to be generated. |
| :return: str - Five Instagram captions formatted as specified. |
| """ |
| prompt = ( |
| f"This caption was generated with a deep learning model." |
| f"Your task is to enhance the caption to make it more engaging:" |
| f"Given this provided photo description, generate five distinct " |
| f"fun and engaging Instagram captions. Each caption must include " |
| f"at least one emoji and one hashtag. The captions should be " |
| f"formatted with a preceding 'Caption #', followed by the " |
| f"caption text. Ensure each caption is separated by a blank " |
| f"line for readability." |
| f"Original Caption: {image_description}" |
| f"Please format your response as follows: \n" |
| f"**Caption 1**: [caption text]\n" |
| f"**Caption 2**: [caption text]\n" |
| f"**Caption 3**: [caption text]\n" |
| f"**Caption 4**: [caption text]\n" |
| f"**Caption 5**: [caption text]\n" |
| ) |
|
|
| try: |
| response = self.gemini.generate_content(prompt) |
| caption = response.parts[0].text |
| caption_list = response.parts[0].text.split("\n") |
| return caption, caption_list, image_description |
|
|
| except Exception as e: |
| st.error(f"Unable to connect to Gemini API: {e}") |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| def caption_2_hashtag(self, caption): |
| """ |
| Generate additional hashtags based on the content of the caption. |
| |
| :param caption: str - The caption for which hashtags are to be generated. |
| :return: str - Additional hashtags based on the content of the caption. |
| """ |
| |
| prompt = (f"Given the provided caption, generate relevant hashtags to increase engagement," |
| f"and are related to the caption content. Original Image Description: {caption}," |
| f"Please format your response as follows:\n" |
| f'[hashtags separated by commas]' |
| f" \n") |
|
|
| try: |
| response = self.gemini.generate_content(prompt) |
| hashtags = response.parts[0].text |
| return hashtags |
|
|
| except Exception as e: |
| st.error(f"Error occurred with Gemini API: {e}") |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|