| import streamlit as st |
| from PIL import Image |
| from transformers import ( |
| BlipProcessor, |
| BlipForConditionalGeneration, |
| AutoTokenizer, |
| AutoModelForCausalLM |
| ) |
| from gtts import gTTS |
| import io |
| import torch |
|
|
| |
| |
| |
@st.cache_resource
def load_image_model():
    """Return the BLIP image-captioning ``(processor, model)`` pair.

    Both objects are created with ``from_pretrained`` from the same
    checkpoint name. The function is wrapped in ``st.cache_resource``.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    blip_processor = BlipProcessor.from_pretrained(checkpoint)
    blip_model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    return blip_processor, blip_model
|
|
def stage1_process(uploaded_file):
    """Caption an uploaded image.

    Args:
        uploaded_file: Anything ``PIL.Image.open`` accepts (``main`` passes
            the object returned by ``st.file_uploader``).

    Returns:
        The decoded caption string, with special tokens removed.
    """
    blip_processor, blip_model = load_image_model()

    # Force three-channel RGB before handing the picture to the processor.
    rgb_image = Image.open(uploaded_file).convert("RGB")
    model_inputs = blip_processor(images=rgb_image, return_tensors="pt")

    # Only the first generated sequence is decoded.
    caption_ids = blip_model.generate(**model_inputs)[0]
    caption = blip_processor.decode(caption_ids, skip_special_tokens=True)
    return caption
|
|
| |
| |
| |
@st.cache_resource
def load_story_model():
    """Return the ``(tokenizer, model)`` pair used for story generation.

    Both objects are created with ``from_pretrained`` from the
    ``gpt2-medium`` checkpoint. The function is wrapped in
    ``st.cache_resource``.
    """
    checkpoint = "gpt2-medium"
    story_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    story_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return story_tokenizer, story_model
|
|
def stage2_process(keyword):
    """Generate a short children's story themed on *keyword*.

    Args:
        keyword: Theme text inserted into the prompt (``main`` passes the
            image caption).

    Returns:
        The text that follows the prompt's "Story begins:" marker, i.e. the
        fixed opening sentence plus the sampled continuation, stripped of
        surrounding whitespace. If the marker is missing from the decoded
        output, the whole decoded text is returned instead.
    """
    tokenizer, model = load_story_model()

    marker = "Story begins:"
    prompt = (
        "Write a children's story in 100-150 words with these elements:\n"
        f"- Theme: {keyword}\n"
        "- Characters: Friendly animals\n"
        "- Moral: Sharing is caring\n"
        "\n"
        f"{marker} One sunny morning, a little rabbit named Cotton discovered"
    )

    inputs = tokenizer(prompt, return_tensors="pt", max_length=150, truncation=True)
    outputs = model.generate(
        # Unpacking forwards attention_mask together with input_ids.
        **inputs,
        max_new_tokens=300,
        temperature=0.9,
        top_k=50,
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        do_sample=True,
        # GPT-2 defines no pad token; reuse EOS so generate() has an explicit one.
        pad_token_id=tokenizer.eos_token_id,
    )
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Cut at the FIRST marker: if the model repeats "Story begins:" later on,
    # splitting on the last occurrence would throw away the start of the story.
    _, found, story = full_text.partition(marker)
    return (story if found else full_text).strip()
|
|
| |
| |
| |
def stage3_process(text):
    """Convert story text to spoken audio with gTTS.

    Args:
        text: Story text. Surrounding whitespace is stripped, newlines become
            spaces, and only the first 300 characters are synthesized.

    Returns:
        An ``io.BytesIO`` rewound to position 0 containing the audio written
        by gTTS, or ``None`` when *text* is not a string, the cleaned text is
        shorter than 20 characters, or synthesis fails.
    """
    import logging

    # Non-string input (e.g. None) cannot be narrated; keep the best-effort contract.
    if not isinstance(text, str):
        return None

    clean_text = text.strip().replace('\n', ' ')[:300]
    if len(clean_text) < 20:
        return None

    try:
        tts = gTTS(text=clean_text, lang='en')
        audio = io.BytesIO()
        tts.write_to_fp(audio)
    except Exception:
        # Best effort: the caller treats None as "no audio". Catching Exception
        # (not a bare except) lets KeyboardInterrupt/SystemExit propagate, and
        # logging keeps the failure diagnosable instead of silent.
        logging.getLogger(__name__).exception("Text-to-speech synthesis failed")
        return None

    audio.seek(0)
    return audio
|
|
| |
| |
| |
def main():
    """Render the page: upload an image, caption it, write a story, narrate it.

    Stage results are stored in ``st.session_state`` under ``caption``,
    ``story`` and ``audio`` so that each stage runs once per uploaded file
    instead of on every script rerun. ``source_file`` remembers which upload
    those results belong to.
    """
    st.title("📖 Children's Story Generator")

    # Create each slot only when it is missing. The previous guard tested a
    # 'processing' key that was never written, so the state was wiped on every
    # rerun and all three models ran again after any widget interaction.
    for slot in ("caption", "story", "audio", "source_file"):
        if slot not in st.session_state:
            st.session_state[slot] = None

    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])
    if not uploaded_file:
        return

    # Stored results describe one specific image; drop them when the user
    # uploads a different file (identified here by its name and size).
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.source_file != file_key:
        st.session_state.update({
            'caption': None,
            'story': None,
            'audio': None,
            'source_file': file_key,
        })

    st.image(uploaded_file, width=300)

    # Stage 1: caption the image.
    if not st.session_state.caption:
        with st.spinner("Analyzing image..."):
            st.session_state.caption = stage1_process(uploaded_file)
    st.success(f"Detected Theme: {st.session_state.caption}")

    # Stage 2: write a story from the caption.
    if not st.session_state.story:
        with st.spinner("Writing magical story..."):
            st.session_state.story = stage2_process(st.session_state.caption)

    if not st.session_state.story:
        return

    st.subheader("Generated Story")
    st.write(st.session_state.story)

    # Stage 3: narrate the story. stage3_process returns None on failure, in
    # which case the audio widgets are simply not shown.
    if not st.session_state.audio:
        with st.spinner("Generating audio..."):
            st.session_state.audio = stage3_process(st.session_state.story)
    if st.session_state.audio:
        st.audio(st.session_state.audio, format="audio/mp3")
        st.download_button("Download Audio",
                           st.session_state.audio.getvalue(),
                           "story.mp3",
                           mime="audio/mp3")
|
|
# Script entry point: start the app only when this file is executed directly.
if __name__ == "__main__":
    main()