File size: 4,756 Bytes
cfd75dc
0845b6c
cfd75dc
0845b6c
cfd75dc
 
 
 
 
 
 
 
0845b6c
cfd75dc
 
40a9697
0845b6c
40a9697
0845b6c
 
 
 
 
 
cfd75dc
0845b6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfd75dc
 
 
40a9697
cfd75dc
 
 
0845b6c
 
cfd75dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40a9697
cfd75dc
40a9697
cfd75dc
40a9697
0845b6c
 
cfd75dc
 
 
 
 
0845b6c
 
 
 
cfd75dc
40a9697
cfd75dc
 
40a9697
cfd75dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0845b6c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import asyncio
from google import genai
from google.genai import types
from config import settings
import wave
import queue
import logging
import io
import time

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Shared Gemini API client used by generate_music to open realtime music
# sessions. The API key is read from the project settings and requests are
# pinned to the 'v1alpha' API version.
client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})

# User hashes whose session is currently being set up but is not yet
# registered in `sessions`; used to reject overlapping start requests.
_starting_users = set()


async def generate_music(user_hash: str, music_tone: str, receive_audio):
    """Open a realtime music session for a user and keep it running.

    Connects to the 'models/lyria-realtime-exp' model, starts
    ``receive_audio`` as a background task, sends the initial prompt and
    generation config, starts playback and finally registers the session in
    ``sessions``. Because the receiver runs inside an ``asyncio.TaskGroup``,
    this coroutine only returns once that task has finished.

    A call is skipped (with an info log) when the user already has a
    registered session or when an earlier call for the same user is still in
    its setup phase.

    Args:
        user_hash: Key under which the session is stored in ``sessions``.
        music_tone: Text of the initial weighted prompt.
        receive_audio: Coroutine function invoked as
            ``receive_audio(session, user_hash)`` to consume server messages.
    """
    if user_hash in sessions or user_hash in _starting_users:
        logger.info(f"Music generation already started for user hash {user_hash}, skipping new generation")
        return

    # Reserve the user synchronously. Registration in `sessions` happens only
    # after several awaits, so without this two overlapping calls would both
    # pass the check above and open two sessions for the same user.
    _starting_users.add(user_hash)
    registered = False
    try:
        async with (
            client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
            asyncio.TaskGroup() as tg,
        ):
            # Set up task to receive server messages.
            tg.create_task(receive_audio(session, user_hash))

            # Send initial prompts and config
            await session.set_weighted_prompts(
                prompts=[
                    types.WeightedPrompt(text=music_tone, weight=1.0),
                ]
            )
            await session.set_music_generation_config(
                config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
            )
            await session.play()
            logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
            sessions[user_hash] = {
                'session': session,
                'queue': queue.Queue()
            }
            # From here on `sessions` itself guards against duplicate starts.
            registered = True
            _starting_users.discard(user_hash)
    finally:
        # Setup failed before registration: release the reservation so the
        # user can retry. After registration it has already been released.
        if not registered:
            _starting_users.discard(user_hash)
        
async def change_music_tone(user_hash: str, new_tone):
    """Replace the active prompt of a user's music session.

    Looks up the user's entry in ``sessions`` and sends ``new_tone`` as the
    single weighted prompt (weight 1.0). If the user has no session, an error
    is logged and nothing is sent.

    Args:
        user_hash: Key of the session in ``sessions``.
        new_tone: Text of the new prompt.
    """
    logger.info(f"Changing music tone to {new_tone}")
    entry = sessions.get(user_hash, {})
    active_session = entry.get('session')
    if active_session:
        replacement_prompt = types.WeightedPrompt(text=new_tone, weight=1.0)
        await active_session.set_weighted_prompts(prompts=[replacement_prompt])
        return
    logger.error(f"No session found for user hash {user_hash}")
        

# PCM format parameters written into every WAV header built by update_audio.
SAMPLE_RATE = 48000  # Frames per second (48 kHz)
NUM_CHANNELS = 2  # Stereo
SAMPLE_WIDTH = 2  # 16-bit audio -> 2 bytes per sample

async def receive_audio(session, user_hash):
    """Forward audio chunks received from a music session to the user's queue.

    Iterates over ``session.receive()`` and, for every message that carries
    audio chunks, puts each chunk's raw PCM bytes (in order) on the queue
    stored under ``sessions[user_hash]['queue']``. The loop ends when
    receiving raises an exception, which is logged with its traceback.

    Args:
        session: Live music session exposing an async ``receive()`` iterator.
        user_hash: Key of the target entry in ``sessions``.
    """
    while True:
        try:
            async for message in session.receive():
                content = message.server_content
                if content and content.audio_chunks:
                    entry = sessions.get(user_hash)
                    if entry is None:
                        # The session is not registered (yet, or any more).
                        # Drop the audio instead of raising KeyError, which
                        # would silently terminate this receiver for good.
                        logger.debug(f"No registered session for user hash {user_hash}, dropping audio")
                    else:
                        audio_queue = entry['queue']
                        # Forward every chunk of the message, not only the
                        # first one; chunk.data is already bytes (raw PCM).
                        for chunk in content.audio_chunks:
                            await asyncio.to_thread(audio_queue.put, chunk.data)
                # Yield to the event loop so other tasks can run.
                await asyncio.sleep(0)
        except Exception as e:
            # logger.exception keeps the traceback of the failure.
            logger.exception(f"Error in receive_audio: {e}")
            break

# Registry of music sessions keyed by user hash. Each value is a dict with
# 'session' (the live music session object) and 'queue' (a queue.Queue of raw
# PCM byte chunks, filled by receive_audio and drained by update_audio).
sessions = {}

async def start_music_generation(user_hash: str, music_tone: str):
    """Start music generation for a user by awaiting generate_music.

    No thread is created here: generate_music is awaited directly, with the
    module-level receive_audio coroutine function passed as the receiver, so
    this coroutine completes only when generate_music returns.

    Args:
        user_hash: Key identifying the user's session.
        music_tone: Text of the initial music prompt.
    """
    await generate_music(user_hash, music_tone, receive_audio)
    
async def cleanup_music_session(user_hash: str):
    """Stop and close a user's music session and remove it from ``sessions``.

    Does nothing when the user has no registered session. The registry entry
    is removed even if stopping or closing raises, so a failed cleanup cannot
    leave a dead session behind that would block every later start for this
    user. Exceptions from ``stop()`` / ``close()`` still propagate.

    Args:
        user_hash: Key of the session in ``sessions``.
    """
    if user_hash not in sessions:
        return
    logger.info(f"Cleaning up music session for user hash {user_hash}")
    session = sessions[user_hash]['session']
    try:
        await session.stop()
    finally:
        try:
            # Always attempt to close, even when stop() failed.
            await session.close()
        finally:
            # pop() with a default tolerates the entry having been removed
            # by someone else while we were awaiting above.
            sessions.pop(user_hash, None)
    

def update_audio(user_hash):
    """Continuously stream a user's queued audio as standalone WAV chunks.

    Generator that polls ``sessions`` for the user's entry, takes raw PCM
    chunks from its queue and yields each one wrapped in its own WAV
    container (NUM_CHANNELS / SAMPLE_WIDTH / SAMPLE_RATE). It never finishes
    on its own, except that an empty ``user_hash`` yields nothing.

    The queue is read with a timeout and the session entry is looked up again
    on every iteration, so the generator follows the user to a new queue when
    the session is cleaned up and restarted instead of blocking forever on
    the old one.

    Args:
        user_hash: Key of the session in ``sessions``; "" disables streaming.

    Yields:
        bytes: A complete WAV file containing one PCM chunk.
    """
    if user_hash == "":
        return

    logger.info(f"Starting audio update loop for user hash: {user_hash}")
    poll_interval = 0.5  # seconds between checks for a (new) session / audio
    # Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
    bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
    while True:
        # Single lookup: the entry may be removed by another thread between
        # a membership test and a subsequent index.
        entry = sessions.get(user_hash)
        if entry is None:
            time.sleep(poll_interval)
            continue
        try:
            pcm_data = entry['queue'].get(timeout=poll_interval)  # raw PCM audio bytes
        except queue.Empty:
            # Nothing arrived; loop to re-check which session is current.
            continue

        if not isinstance(pcm_data, bytes):
            logger.warning(f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping.")
            continue

        # Lyria provides stereo, 16-bit PCM at 48kHz. A length that is not a
        # multiple of the frame size might indicate an incomplete chunk; it
        # is reported but still forwarded.
        if len(pcm_data) % bytes_per_frame != 0:
            logger.warning(
                f"Received PCM data with length {len(pcm_data)}, which is not a multiple of "
                f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
            )

        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, 'wb') as wf:
            wf.setnchannels(NUM_CHANNELS)
            wf.setsampwidth(SAMPLE_WIDTH)  # Corresponds to 16-bit audio
            wf.setframerate(SAMPLE_RATE)
            wf.writeframes(pcm_data)
        yield wav_buffer.getvalue()