# app.py — Streamlit app: turn an uploaded image into a short children's
# audio story (caption -> story -> speech).
import streamlit as st
from transformers import pipeline
import torch
from gtts import gTTS
import io
# Configure the page (must run before any other Streamlit UI call).
st.set_page_config(page_title="Your Image to Audio Story",
page_icon="🦜")
st.header("Turn Your Image📷 to a Short Audio Story🔊 for Children👶")
# Uploaded file handle; None until the user selects a file.
uploaded_file = st.file_uploader("Select an Image After the Models are Loaded...")
# function part
# ---- model loading: pipelines are created once and cached by Streamlit ----
@st.cache_resource
def load_models():
    """Build and cache the captioning and story-generation pipelines."""
    captioner = pipeline("image-to-text", "cnmoro/tiny-image-captioning")
    storyteller = pipeline("text-generation", "Qwen/Qwen2.5-0.5B-Instruct")
    return {"img_model": captioner, "story_model": storyteller}

models = load_models()
# img2text
def img2text(url):
    """Return the generated caption for the image at *url* (a file path)."""
    caption_result = models["img_model"](url)
    return caption_result[0]["generated_text"]
# text2story
def text2story(text):
    """Generate a brief (~100-word) story from an image caption.

    Parameters
    ----------
    text : str
        The image caption used to seed the story.

    Returns
    -------
    str
        The generated story text.
    """
    prompt = f"Generate a brief 100-word story about: {text}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    response = models["story_model"](
        messages,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7)[0]["generated_text"]
    # The chat pipeline echoes the full conversation with the assistant's
    # reply appended last. Index from the end instead of hard-coding
    # position 2 so this survives changes to the number of input messages.
    story_text = response[-1]["content"]
    return story_text
# text2audio
def text2audio(story_text):
    """Synthesize *story_text* to speech, returned as an in-memory MP3 buffer."""
    speech = gTTS(text=story_text, lang='en', slow=False)
    # Write the synthesized audio into an in-memory buffer instead of a file.
    buffer = io.BytesIO()
    speech.write_to_fp(buffer)
    buffer.seek(0)
    # Dictionary shape kept for compatibility with downstream consumers.
    return {'audio': buffer, 'sampling_rate': 16000}
# Initialize session state on first run: one slot per pipeline stage,
# all starting out as None.
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = dict.fromkeys(('scenario', 'story', 'audio'))
if uploaded_file is not None:
    # Persist the upload to disk: the captioning pipeline is called with a
    # file path, not the Streamlit UploadedFile object.
    bytes_data = uploaded_file.getvalue()
    with open(uploaded_file.name, "wb") as file:
        file.write(bytes_data)
    st.image(uploaded_file, caption="Uploaded Image",
             use_container_width=True)

    # Only run the pipelines when the file is new — Streamlit reruns the
    # whole script on every interaction, so this guard avoids recomputing.
    # NOTE(review): keying on the file *name* means a different image
    # uploaded under the same name is NOT reprocessed — confirm acceptable.
    if st.session_state.get('current_file') != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name

        # Stage 1: Image to Text
        with st.spinner('Processing image...'):
            st.session_state.processed_data['scenario'] = img2text(uploaded_file.name)

        # Stage 2: Text to Story
        with st.spinner('Generating story...'):
            st.session_state.processed_data['story'] = text2story(
                st.session_state.processed_data['scenario']
            )

        # Stage 3: Story to Audio
        with st.spinner('Creating audio...'):
            st.session_state.processed_data['audio'] = text2audio(
                st.session_state.processed_data['story']
            )

    # Display results (cached in session state across reruns).
    st.write("Caption:", st.session_state.processed_data['scenario'])
    st.write("Story:", st.session_state.processed_data['story'])
# Keep audio button OUTSIDE file processing block so it survives reruns.
if st.button("Play Audio of the Story Generated"):
    audio_payload = st.session_state.processed_data.get('audio')
    if not audio_payload:
        st.warning("Please generate a story first!")
    else:
        # gTTS produces MP3 output, so tag the byte stream accordingly.
        st.audio(audio_payload['audio'].getvalue(), format="audio/mp3")