# NOTE: Previous prompt-driven captioning implementation (BLIP image captioning
# with an optional OpenAI GPT-4 Vision fallback), kept commented out for reference.

# import gradio as gr
# from transformers import BlipProcessor, BlipForConditionalGeneration
# from PIL import Image
# import torch
# import requests
#
# # Load model & processor
# processor = BlipProcessor.from_pretrained(
#     "Salesforce/blip-image-captioning-base"
# )
# model = BlipForConditionalGeneration.from_pretrained(
#     "Salesforce/blip-image-captioning-base"
# )
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
#
# def caption_image(image, prompt="", openai_api_key=""):
#     if not prompt or not prompt.strip():
#         return "Please enter a prompt/question for the image."
#     image = image.convert("RGB")
#
#     # Use OpenAI API if key provided (unchanged)
#     if openai_api_key:
#         try:
#             import base64
#             from io import BytesIO
#             buffered = BytesIO()
#             image.save(buffered, format="PNG")
#             img_b64 = base64.b64encode(buffered.getvalue()).decode()
#             headers = {
#                 "Authorization": f"Bearer {openai_api_key}",
#                 "Content-Type": "application/json"
#             }
#             data = {
#                 "model": "gpt-4-vision-preview",
#                 "messages": [
#                     {
#                         "role": "user",
#                         "content": [
#                             {"type": "text", "text": prompt.strip()},
#                             {"type": "image_url", "image_url": f"data:image/png;base64,{img_b64}"}
#                         ]
#                     }
#                 ],
#                 "max_tokens": 100
#             }
#             resp = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=data)
#             if resp.status_code == 200:
#                 result = resp.json()
#                 return result["choices"][0]["message"]["content"].strip()
#             else:
#                 return f"OpenAI API error: {resp.status_code} {resp.text}"
#         except Exception as e:
#             return f"OpenAI API error: {e}"
#
#     # BLIP: always use prompt as instruction, no retry, fast settings
#     p = prompt.strip()
#     prompt_text = f"Question: {p} Answer:"
#     inputs = processor(images=image, text=prompt_text, return_tensors="pt").to(device)
#     # Speed up: reduce beams and max_length
#     gen_kwargs = {"max_length": 25, "num_beams": 1, "early_stopping": True}
#     output = model.generate(**inputs, **gen_kwargs)
#     caption = processor.decode(output[0], skip_special_tokens=True)
#     # Extract answer after 'Answer:' if present
#     idx = caption.lower().find("answer:")
#     if idx != -1:
#         ans = caption[idx + len("answer:"):].strip()
#         if ans:
#             return ans
#     # Otherwise, return the raw caption
#     return caption.strip()
#
# # Gradio UI: horizontal layout with image, prompt, button left; output right
# with gr.Blocks() as demo:
#     gr.Markdown(
#         "## 🖼️ Image Captioning (Prompt-driven)\n"
#         "Upload an image, enter a prompt, and click Submit. Output depends on both image and prompt."
#     )
#     with gr.Row():
#         with gr.Column(scale=2):
#             img = gr.Image(type="pil", label="Upload Image")
#             prompt = gr.Textbox(label="Prompt (ask a question)", placeholder="What is the color of the t-shirt?")
#             openai_api_key = gr.Textbox(label="OpenAI API Key (optional)", type="password", placeholder="sk-...", lines=1)
#             btn = gr.Button("Submit")
#         with gr.Column(scale=1):
#             out = gr.Textbox(label="Answer", lines=6)
#     btn.click(fn=caption_image, inputs=[img, prompt, openai_api_key], outputs=out)
#
# demo.launch()

import gradio as gr
import torch
from transformers import BlipProcessor, BlipForQuestionAnswering
from PIL import Image

# ---------------------------
# Load BLIP VQA model
# ---------------------------
MODEL_NAME = "Salesforce/blip-vqa-base"

processor = BlipProcessor.from_pretrained(MODEL_NAME)
model = BlipForQuestionAnswering.from_pretrained(MODEL_NAME)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()


# ---------------------------
# Inference function
# ---------------------------
def answer_image_question(image, question):
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please enter a question."

    image = image.convert("RGB")

    inputs = processor(
        images=image,
        text=question,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=10,  # VQA answers are short; a small cap keeps generation fast
            num_beams=1     # greedy decoding for speed
        )

    answer = processor.decode(output[0], skip_special_tokens=True)
    return answer


# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🖼️ Image Question Answering (BLIP VQA)")
    gr.Markdown(
        "Upload an image and ask a question about it, "
        'e.g. "What is the color of the shirt?"'
    )

    with gr.Row():
        with gr.Column():
            img = gr.Image(type="pil", label="Upload Image")
            question = gr.Textbox(
                label="Question",
                placeholder="What is the color of the shirt?"
            )
            btn = gr.Button("Submit")

        with gr.Column():
            answer = gr.Textbox(label="Answer", lines=3)

    btn.click(
        fn=answer_image_question,
        inputs=[img, question],
        outputs=answer
    )

demo.launch()
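
# ---------------------------
# Optional headless smoke test (a minimal sketch, not part of the app above).
# "sample.jpg" and the example question are hypothetical placeholders; run this
# instead of demo.launch() to check the model without starting the Gradio UI.
# ---------------------------
# if __name__ == "__main__":
#     test_image = Image.open("sample.jpg")  # hypothetical local image path
#     print(answer_image_question(test_image, "What is the color of the shirt?"))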