"""Streamlit app that turns an uploaded image into a narrated short story.

The flow has three stages: caption the image, expand the caption into a
story, then synthesise speech for that story. Stage results are kept in
``st.session_state`` so that Streamlit reruns (for example, clicking the
"Play Audio" button) do not regenerate them.
"""

import hashlib
import os
import tempfile

import streamlit as st
from transformers import pipeline


@st.cache_resource
def _load_pipeline(task, model):
    """Build a transformers pipeline once and reuse it on later calls.

    Constructing a pipeline loads the model, so the result is cached by
    ``st.cache_resource`` (keyed on ``task`` and ``model``) instead of being
    rebuilt every time a stage function runs.
    """
    return pipeline(task, model=model)


def _caption_uploaded_bytes(image_bytes):
    """Caption raw image bytes via ``img2text`` using a throwaway temp file.

    The bytes are written to a private temporary file rather than to a path
    derived from the client-supplied upload name, so an upload can never
    overwrite a file in the working directory. The temp file is removed once
    captioning finishes, whether it succeeds or raises.
    """
    # delete=False so the file can be closed and then reopened by path.
    handle = tempfile.NamedTemporaryFile(delete=False)
    try:
        with handle:
            handle.write(image_bytes)
        return img2text(handle.name)
    finally:
        os.remove(handle.name)


# Function definitions
def img2text(url):
    """Return a generated caption for the image at ``url``.

    Args:
        url: Image location handed straight to the image-to-text pipeline
            (``main`` passes a local file path).

    Returns:
        The ``generated_text`` of the first pipeline result.
    """
    image_to_text_model = _load_pipeline(
        "image-to-text", "Salesforce/blip-image-captioning-base"
    )
    return image_to_text_model(url)[0]["generated_text"]


def text2story(text):
    """Generate a short story that continues the prompt ``text``.

    Args:
        text: Prompt for the text-generation model (the image caption).

    Returns:
        The ``generated_text`` of the first pipeline result.
    """
    text_generation_model = _load_pipeline(
        "text-generation", "aspis/gpt2-genre-story-generation"
    )
    generated = text_generation_model(
        text,
        min_length=50,
        max_length=100,
        do_sample=True,
        early_stopping=True,
        top_p=0.6,
    )
    return generated[0]["generated_text"]


def text2audio(story_text):
    """Synthesise speech for ``story_text``.

    Args:
        story_text: Text to be spoken.

    Returns:
        The raw text-to-speech pipeline output; ``main`` reads its
        ``"audio"`` and ``"sampling_rate"`` entries.
    """
    text2audio_model = _load_pipeline("text-to-speech", "Matthijs/mms-tts-eng")
    return text2audio_model(story_text)


def main():
    """Render the page and run the image -> caption -> story -> audio flow."""
    st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
    st.header("Turn Your Image to Audio Story")
    uploaded_file = st.file_uploader("Select an Image...")

    if uploaded_file is None:
        return

    # Get file bytes and compute a hash of the content.
    bytes_data = uploaded_file.getvalue()
    file_hash = hashlib.sha256(bytes_data).hexdigest()

    # Reset session state only if the file content has changed; this prevents
    # regeneration after clicking "Play Audio" (which triggers a rerun).
    if st.session_state.get("last_uploaded_hash") != file_hash:
        st.session_state.scenario = None
        st.session_state.story = None
        st.session_state.audio_data = None
        st.session_state.last_uploaded_hash = file_hash

    st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)

    # Stage 1: Image to Text
    if st.session_state.scenario is None:
        st.text("Processing img2text...")
        st.session_state.scenario = _caption_uploaded_bytes(bytes_data)
    st.write(st.session_state.scenario)

    # Stage 2: Text to Story
    if st.session_state.story is None:
        st.text("Generating a story...")
        st.session_state.story = text2story(st.session_state.scenario)
    st.write(st.session_state.story)

    # Stage 3: Story to Audio data
    if st.session_state.audio_data is None:
        st.text("Generating audio data...")
        st.session_state.audio_data = text2audio(st.session_state.story)

    # Play Audio button - uses the stored audio_data.
    if st.button("Play Audio"):
        st.audio(
            st.session_state.audio_data["audio"],
            format="audio/wav",
            start_time=0,
            sample_rate=st.session_state.audio_data["sampling_rate"],
        )


if __name__ == "__main__":
    main()