| | import streamlit as st |
| | from PIL import Image |
| | import torch |
| | from transformers import ( |
| | ViTFeatureExtractor, |
| | ViTForImageClassification, |
| | pipeline, |
| | AutoTokenizer, |
| | AutoModelForSeq2SeqLM |
| | ) |
| | from diffusers import StableDiffusionPipeline |
| |
|
| | |
@st.cache_resource
def load_models():
    """Load and cache every model used by the app.

    Decorated with ``st.cache_resource`` so the (large) model weights are
    downloaded and instantiated only once per Streamlit server process.

    Returns:
        Tuple of (age_model, age_transforms, gender_model, gender_transforms,
        emotion_model, emotion_transforms, object_detector, action_model,
        action_transforms, prompt_enhancer, pipe).
    """
    # Face-attribute classifiers (ViT backbones) and their preprocessors.
    age_model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
    age_transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')

    gender_model = ViTForImageClassification.from_pretrained('rizvandwiki/gender-classification-2')
    gender_transforms = ViTFeatureExtractor.from_pretrained('rizvandwiki/gender-classification-2')

    emotion_model = ViTForImageClassification.from_pretrained('dima806/facial_emotions_image_detection')
    emotion_transforms = ViTFeatureExtractor.from_pretrained('dima806/facial_emotions_image_detection')

    # DETR object detector wrapped in a ready-made HF pipeline.
    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")

    action_model = ViTForImageClassification.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    action_transforms = ViTFeatureExtractor.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')

    # Seq2seq prompt enhancer; repetition_penalty discourages degenerate
    # repeated phrases in the generated prompt.
    prompt_enhancer_tokenizer = AutoTokenizer.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer_model = AutoModelForSeq2SeqLM.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer = pipeline('text2text-generation',
                               model=prompt_enhancer_model,
                               tokenizer=prompt_enhancer_tokenizer,
                               repetition_penalty=1.2,
                               device="cpu")

    # BUG FIX: the original unconditionally requested torch.float16, but fp16
    # inference is unsupported/broken on CPU for most ops, so the app crashed
    # (or produced garbage) on machines without CUDA. Use fp16 only when a
    # GPU is available, and move the pipeline onto it in that case.
    sd_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained("nota-ai/bk-sdm-tiny", torch_dtype=sd_dtype)
    if torch.cuda.is_available():
        pipe = pipe.to("cuda")  # fp16 weights must live on the GPU

    return (age_model, age_transforms, gender_model, gender_transforms,
            emotion_model, emotion_transforms, object_detector,
            action_model, action_transforms, prompt_enhancer, pipe)
| |
|
# Load all models once (cached across reruns by st.cache_resource) and
# unpack them into module-level globals used by the helper functions below.
models = load_models()
(age_model, age_transforms, gender_model, gender_transforms,
 emotion_model, emotion_transforms, object_detector,
 action_model, action_transforms, prompt_enhancer, pipe) = models
| |
|
def predict(image, model, transforms):
    """Classify a single image with a ViT-style classifier.

    Args:
        image: PIL image; converted to RGB if it has a different mode.
        model: classifier whose forward output exposes a ``.logits`` tensor.
        transforms: feature extractor / image processor producing the
            model's keyword inputs from a list of images.

    Returns:
        int: index of the highest-probability class. Map it to a label
        string with ``model.config.id2label``.
    """
    # ViT preprocessors expect 3-channel input; uploaded PNGs may be RGBA/L.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # IMPROVEMENT: inference only — no_grad() skips building the autograd
    # graph, cutting memory use and speeding up the forward pass.
    with torch.no_grad():
        inputs = transforms(images=[image], return_tensors='pt')
        output = model(**inputs)
    proba = output.logits.softmax(1)
    return proba.argmax(1).item()
| |
|
def detect_attributes(image):
    """Run all four classifiers plus the object detector on *image*.

    Returns a dict with human-readable values for 'age', 'gender',
    'emotion', 'action', and a list of detected object labels under
    'objects'.
    """
    classifiers = (
        ('age', age_model, age_transforms),
        ('gender', gender_model, gender_transforms),
        ('emotion', emotion_model, emotion_transforms),
        ('action', action_model, action_transforms),
    )

    attributes = {}
    for attr_name, classifier, preprocessor in classifiers:
        class_idx = predict(image, classifier, preprocessor)
        attributes[attr_name] = classifier.config.id2label[class_idx]

    detections = object_detector(image)
    attributes['objects'] = [detection['label'] for detection in detections]
    return attributes
| |
|
def generate_prompt(attributes):
    """Compose a natural-language image prompt from detected attributes."""
    pieces = [
        f"A {attributes['age']} year old {attributes['gender']} person "
        f"feeling {attributes['emotion']} ",
        f"while {attributes['action']}. ",
    ]
    detected_objects = attributes['objects']
    if detected_objects:
        pieces.append(f"Image has {', '.join(detected_objects)}. ")
    return ''.join(pieces)
| |
|
def enhance_prompt(prompt):
    """Expand *prompt* into a richer description via the text2text model.

    The model was trained with an "enhance prompt: " task prefix, which is
    prepended before generation.
    """
    task_input = "enhance prompt: " + prompt
    outputs = prompt_enhancer(task_input, max_length=256)
    return outputs[0]['generated_text']
| |
|
@st.cache_data
def generate_image(prompt):
    """Run Stable Diffusion on *prompt* and return the generated PIL image.

    Cached with ``st.cache_data`` so repeating the same prompt skips the
    expensive diffusion run.
    """
    with torch.no_grad():
        result = pipe(prompt, num_inference_steps=50)
    return result.images[0]
| |
|
# ---------------------------------------------------------------------------
# Streamlit UI: upload an image, detect attributes, build and enhance a
# prompt, then generate a new image from it.
# ---------------------------------------------------------------------------
st.title("Image Attribute Detection and Image Generation")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Everything below runs only when the user clicks the button.
    if st.button('Analyze Image'):
        # Step 1: classify age/gender/emotion/action and detect objects.
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)

        st.write("Detected Attributes:")
        for key, value in attributes.items():
            st.write(f"{key.capitalize()}: {value}")

        # Step 2: turn the attributes into a prompt, then enrich it with
        # the text2text prompt-enhancer model.
        with st.spinner('Generating prompt...'):
            initial_prompt = generate_prompt(attributes)
            enhanced_prompt = enhance_prompt(initial_prompt)

        st.write("Initial Prompt:")
        st.write(initial_prompt)
        st.write("Enhanced Prompt:")
        st.write(enhanced_prompt)

        # Step 3: run Stable Diffusion on the enhanced prompt.
        with st.spinner('Generating image...'):
            generated_image = generate_image(enhanced_prompt)
        st.image(generated_image, caption='Generated Image', use_column_width=True)