| | import streamlit as st |
| | from PIL import Image |
| | from transformers import ( |
| | BlipProcessor, |
| | BlipForConditionalGeneration, |
| | AutoTokenizer, |
| | AutoModelForCausalLM |
| | ) |
| | from gtts import gTTS |
| | import io |
| | import torch |
| |
|
| | |
| | |
| | |
@st.cache_resource
def load_image_model():
    """Load and cache the BLIP image-captioning processor and model pair."""
    model_id = "Salesforce/blip-image-captioning-base"
    processor = BlipProcessor.from_pretrained(model_id)
    model = BlipForConditionalGeneration.from_pretrained(model_id)
    return processor, model
| |
|
def stage1_process(uploaded_file):
    """Return a text caption describing the uploaded image (BLIP model)."""
    processor, model = load_image_model()
    image = Image.open(uploaded_file).convert("RGB")
    encoded = processor(images=image, return_tensors="pt")
    generated_ids = model.generate(**encoded)
    caption = processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption
| |
|
| | |
| | |
| | |
@st.cache_resource
def load_story_model():
    """Load and cache the GPT-2-medium tokenizer and language model."""
    checkpoint = "gpt2-medium"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return tokenizer, model
| |
|
def stage2_process(keyword):
    """Generate a short children's story themed on *keyword*.

    Builds a structured prompt, samples a continuation from GPT-2-medium,
    and returns only the text after the "Story begins:" marker.

    Args:
        keyword: Theme string (typically the image caption from stage 1).

    Returns:
        The generated story text, stripped of the prompt scaffolding.
    """
    tokenizer, model = load_story_model()

    prompt = f"""Write a children's story in 100-150 words with these elements:
- Theme: {keyword}
- Characters: Friendly animals
- Moral: Sharing is caring

Story begins: One sunny morning, a little rabbit named Cotton discovered"""

    inputs = tokenizer(prompt, return_tensors="pt", max_length=150, truncation=True)
    # FIX: pass attention_mask explicitly and set pad_token_id — GPT-2 has no
    # pad token, so generate() otherwise warns and padding behavior is undefined.
    # Also run under no_grad(): inference only, no gradient bookkeeping needed.
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=300,
            temperature=0.9,
            top_k=50,
            no_repeat_ngram_size=3,
            repetition_penalty=1.2,
            do_sample=True,
        )
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the continuation after the prompt's marker line.
    return full_text.split("Story begins:")[-1].strip()
| |
|
| | |
| | |
| | |
def stage3_process(text):
    """Convert story text to spoken MP3 audio via gTTS.

    Args:
        text: The story text to narrate.

    Returns:
        An in-memory ``io.BytesIO`` positioned at offset 0, or ``None`` when
        the cleaned text is shorter than 20 characters or synthesis fails
        (audio is best-effort and optional for the app).
    """
    try:
        # Flatten newlines and cap length; gTTS handles short inputs best.
        clean_text = text.strip().replace('\n', ' ')[:300]
        if len(clean_text) < 20:
            return None
        tts = gTTS(text=clean_text, lang='en')
        audio = io.BytesIO()
        tts.write_to_fp(audio)
        audio.seek(0)
        return audio
    except Exception:
        # FIX: narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. Failure still degrades to no audio.
        return None
| |
|
| | |
| | |
| | |
def main():
    """Streamlit UI: image upload -> caption -> story -> narration audio."""
    st.title("📖 Children's Story Generator")

    # Initialize persistent per-session state exactly once.
    # FIX: the guard key 'processing' was checked but never stored, so this
    # branch ran on EVERY rerun, wiping caption/story/audio and regenerating
    # a brand-new story after each widget interaction.
    if 'processing' not in st.session_state:
        st.session_state.update({
            'processing': True,
            'caption': None,
            'story': None,
            'audio': None
        })

    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])

    if uploaded_file:
        st.image(uploaded_file, width=300)

        # Stage 1: caption the image (computed once, then cached in state).
        if not st.session_state.caption:
            with st.spinner("Analyzing image..."):
                st.session_state.caption = stage1_process(uploaded_file)
        st.success(f"Detected Theme: {st.session_state.caption}")

        # Stage 2: expand the caption into a story (computed once).
        if not st.session_state.story:
            with st.spinner("Writing magical story..."):
                st.session_state.story = stage2_process(st.session_state.caption)

        if st.session_state.story:
            st.subheader("Generated Story")
            st.write(st.session_state.story)

            # Stage 3: synthesize narration (once; may stay None on failure,
            # in which case the audio widgets are simply not rendered).
            if not st.session_state.audio:
                with st.spinner("Generating audio..."):
                    st.session_state.audio = stage3_process(st.session_state.story)
            if st.session_state.audio:
                st.audio(st.session_state.audio, format="audio/mp3")
                st.download_button("Download Audio",
                                   st.session_state.audio.getvalue(),
                                   "story.mp3",
                                   mime="audio/mp3")


if __name__ == "__main__":
    main()