# Assignment1_2 / app.py
# Last change by namuisam: "Update app.py" (commit bfa949b, verified)
import hashlib
import os
import tempfile

import streamlit as st
from transformers import pipeline
# Function definitions
@st.cache_resource
def _load_image_to_text_model():
    """Build the BLIP captioning pipeline once; reused across reruns and sessions."""
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


def img2text(url):
    """Return a caption for the given image.

    Args:
        url: Image input for the transformers image-to-text pipeline; ``main``
            passes a local file path.

    Returns:
        The ``generated_text`` of the first caption the pipeline produces.
    """
    # The pipeline is cached so each new upload does not reload the model.
    image_to_text_model = _load_image_to_text_model()
    return image_to_text_model(url)[0]["generated_text"]
@st.cache_resource
def _load_story_model():
    """Build the GPT-2 story-generation pipeline once; reused across reruns and sessions."""
    return pipeline("text-generation", model="aspis/gpt2-genre-story-generation")


def text2story(text):
    """Generate a short sampled story that continues ``text``.

    Args:
        text: Prompt to continue (``main`` passes the image caption).

    Returns:
        The ``generated_text`` of the first generated sequence.
    """
    # The pipeline is cached so repeated calls do not reload the model.
    text_generation_model = _load_story_model()
    story_text = text_generation_model(
        text,
        # For decoder-only models these limits count prompt tokens as well
        # as newly generated ones.
        min_length=50,
        max_length=100,
        do_sample=True,
        # NOTE(review): early_stopping is only consulted by beam search; with
        # single-beam sampling it has no effect.
        early_stopping=True,
        top_p=0.6,
    )[0]["generated_text"]
    return story_text
@st.cache_resource
def _load_tts_model():
    """Build the MMS English text-to-speech pipeline once; reused across reruns and sessions."""
    return pipeline("text-to-speech", model="Matthijs/mms-tts-eng")


def text2audio(story_text):
    """Synthesize speech for ``story_text``.

    Args:
        story_text: Text to read aloud.

    Returns:
        The pipeline output; ``main`` reads its ``"audio"`` waveform and
        ``"sampling_rate"`` entries.
    """
    # The pipeline is cached so repeated calls do not reload the model.
    text2audio_model = _load_tts_model()
    return text2audio_model(story_text)
def _caption_uploaded_image(bytes_data, filename):
    """Caption raw image bytes via a private temporary file.

    The uploaded filename is used only for its extension. The data is never
    written under a user-controlled path, so an upload named e.g. ``app.py``
    cannot overwrite files in the working directory.
    """
    suffix = os.path.splitext(filename)[1]
    # delete=False lets the pipeline reopen the file by name (an open
    # NamedTemporaryFile cannot be reopened on Windows); removed in finally.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(bytes_data)
        tmp_path = tmp.name
    try:
        return img2text(tmp_path)
    finally:
        os.remove(tmp_path)


def main():
    """Render the image -> caption -> story -> audio Streamlit page.

    Stage results are stored in ``st.session_state`` and recomputed only when
    the uploaded file's content changes, so reruns (such as clicking
    "Play Audio") reuse the existing caption, story and audio.
    """
    st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
    st.header("Turn Your Image to Audio Story")
    uploaded_file = st.file_uploader("Select an Image...")
    if uploaded_file is None:
        return

    # Get file bytes and compute a content hash.
    bytes_data = uploaded_file.getvalue()
    file_hash = hashlib.sha256(bytes_data).hexdigest()
    # Reset the stage results only if the file content changed; this prevents
    # regeneration after clicking "Play Audio".
    if st.session_state.get("last_uploaded_hash") != file_hash:
        st.session_state.scenario = None
        st.session_state.story = None
        st.session_state.audio_data = None
        st.session_state.last_uploaded_hash = file_hash

    st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)

    # Stage 1: Image to Text (the image is only touched on disk when needed).
    if st.session_state.scenario is None:
        with st.spinner("Processing img2text..."):
            st.session_state.scenario = _caption_uploaded_image(
                bytes_data, uploaded_file.name
            )
    st.write(st.session_state.scenario)

    # Stage 2: Text to Story
    if st.session_state.story is None:
        with st.spinner("Generating a story..."):
            st.session_state.story = text2story(st.session_state.scenario)
    st.write(st.session_state.story)

    # Stage 3: Story to Audio data
    if st.session_state.audio_data is None:
        with st.spinner("Generating audio data..."):
            st.session_state.audio_data = text2audio(st.session_state.story)

    # Play Audio button - uses the stored audio_data.
    if st.button("Play Audio"):
        st.audio(
            st.session_state.audio_data["audio"],
            format="audio/wav",
            start_time=0,
            sample_rate=st.session_state.audio_data["sampling_rate"],
        )
# Entry point: render the page when this file is executed as the main script.
if __name__ == "__main__":
    main()