File size: 3,544 Bytes
388bd64
 
2ad549c
5294a30
 
 
a112332
 
0ea388d
f0ce8ff
a112332
388bd64
0ca4f2c
 
 
 
 
 
0b49799
0ca4f2c
 
 
 
 
388bd64
c4849b3
0ca4f2c
693e4a0
388bd64
 
0a2a76c
4ea7eed
a969550
4ea7eed
f9c806c
4ea7eed
 
943eb34
 
0ca4f2c
4ea7eed
4ee6b1d
 
388bd64
 
 
 
128e5fe
 
5294a30
 
 
128e5fe
 
 
 
 
 
 
 
c4849b3
7b9c657
 
 
 
 
 
 
 
388bd64
a969550
 
 
 
 
 
ff8f18e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4849b3
ff8f18e
128e5fe
ff8f18e
 
c4849b3
ff8f18e
f0ce8ff
cf8bcb4
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import streamlit as st
from transformers import pipeline
import torch
from gtts import gTTS
import io

# --- Page setup -------------------------------------------------------------
# st.set_page_config must be the first Streamlit call executed in the script.
st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
st.header("Turn Your Image📷 to a Short Audio Story🔊 for Children👶")
uploaded_file = st.file_uploader("Select an Image After the Models are Loaded...")

# function part

# Preload models once
@st.cache_resource
def load_models():
    """Build the captioning and story-generation pipelines exactly once.

    st.cache_resource keeps the returned dict alive across Streamlit reruns,
    so the (expensive) model downloads/initialisation happen a single time.
    """
    captioner = pipeline("image-to-text", "cnmoro/tiny-image-captioning")
    storyteller = pipeline("text-generation", "Qwen/Qwen2.5-0.5B-Instruct")
    return {"img_model": captioner, "story_model": storyteller}


models = load_models()


# img2text
def img2text(url):
    """Return a caption string for the image at *url* via the cached captioner."""
    predictions = models["img_model"](url)
    return predictions[0]["generated_text"]

# text2story
def text2story(text):
    """Expand an image caption *text* into a short (~100-word) story string."""
    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"Generate a brief 100-word story about: {text}"},
    ]
    result = models["story_model"](
        chat,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
    )
    # The pipeline echoes the whole conversation back; with two input
    # messages the assistant's reply is the final (third) entry.
    conversation = result[0]["generated_text"]
    return conversation[-1]["content"]

# text2audio
def text2audio(story_text):
    """Synthesize *story_text* to MP3 speech, returned as an in-memory buffer.

    Returns a dict with keys 'audio' (a rewound io.BytesIO of MP3 bytes)
    and 'sampling_rate'.
    """
    buffer = io.BytesIO()
    gTTS(text=story_text, lang='en', slow=False).write_to_fp(buffer)
    buffer.seek(0)
    # NOTE(review): gTTS emits MP3; the 16 kHz figure is an assumption
    # carried over from the original code — confirm before relying on it.
    return {'audio': buffer, 'sampling_rate': 16000}
    
# Initialize session state so pipeline results survive Streamlit reruns.
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = dict.fromkeys(('scenario', 'story', 'audio'))

if uploaded_file is not None:
    # Persist the upload to disk because img2text passes a file path
    # to the captioning pipeline.
    # NOTE(review): uploaded_file.name comes from the client browser —
    # consider sanitizing it before using it as a filesystem path.
    bytes_data = uploaded_file.getvalue()
    with open(uploaded_file.name, "wb") as file:
        file.write(bytes_data)
    st.image(uploaded_file, caption="Uploaded Image",
             use_container_width=True)

    # Run the slow model pipeline only for a newly uploaded file; reruns
    # with the same file reuse the results cached in session state.
    if st.session_state.get('current_file') != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name

        # Stage 1: Image to Text
        with st.spinner('Processing image...'):
            st.session_state.processed_data['scenario'] = img2text(uploaded_file.name)

        # Stage 2: Text to Story
        with st.spinner('Generating story...'):
            st.session_state.processed_data['story'] = text2story(
                st.session_state.processed_data['scenario']
            )

        # Stage 3: Story to Audio
        with st.spinner('Creating audio...'):
            st.session_state.processed_data['audio'] = text2audio(
                st.session_state.processed_data['story']
            )

    # Display the cached caption and story on every rerun.
    st.write("Caption:", st.session_state.processed_data['scenario'])
    st.write("Story:", st.session_state.processed_data['story'])

# Keep the playback button outside the file-processing block so it still
# works on reruns after the story has been generated.
if st.button("Play Audio of the Story Generated"):
    audio_data = st.session_state.processed_data.get('audio')
    if audio_data:
        # gTTS produced MP3 bytes; hand them to the player as raw bytes.
        st.audio(audio_data['audio'].getvalue(), format="audio/mp3")
    else:
        st.warning("Please generate a story first!")