Spaces: Running on Zero
Working Nari labs code
app.py CHANGED
@@ -7,13 +7,15 @@ from dia.model import Dia
 from huggingface_hub import InferenceClient
 import numpy as np
 from transformers import set_seed
+import io, soundfile as sf
+
 
 # Hardcoded podcast subject
 PODCAST_SUBJECT = "The future of AI and its impact on society"
 
 # Initialize the inference client
 client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", provider="cerebras", token=os.getenv("HF_TOKEN"))
-model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="
+model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")
 
 # Queue for audio streaming
 audio_queue = queue.Queue()
@@ -36,15 +38,9 @@ Now go on, make 5 minutes of podcast.
 
 def split_podcast_into_chunks(podcast_text, chunk_size=3):
     lines = podcast_text.strip().split("\n")
-
-
-    for i in range(0, len(lines), chunk_size):
-        chunk = "\n".join(lines[i : i + chunk_size])
-        chunks.append(chunk)
+    return ["\n".join(lines[i : i + chunk_size]) for i in range(0, len(lines), chunk_size)]
 
-
-
-def postprocess_audio(output_audio_np, speed_factor: float=0.94):
+def postprocess_audio(output_audio_np, speed_factor: float=0.8):
     """Taken from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py"""
     # Get sample rate from the loaded DAC model
     output_sr = 44100
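The rewritten split_podcast_into_chunks above simply groups the generated script lines chunk_size at a time. A standalone sanity check (the five-line transcript here is made up for illustration) behaves like this:

def split_podcast_into_chunks(podcast_text, chunk_size=3):
    lines = podcast_text.strip().split("\n")
    return ["\n".join(lines[i : i + chunk_size]) for i in range(0, len(lines), chunk_size)]

# Five hypothetical dialogue lines -> one chunk of three lines, one chunk of two.
script = "\n".join(f"[S{1 + (i - 1) % 2}] line {i}" for i in range(1, 6))
print(split_podcast_into_chunks(script))
# ['[S1] line 1\n[S2] line 2\n[S1] line 3', '[S2] line 4\n[S1] line 5']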
@@ -98,6 +94,7 @@ def process_audio_chunks(podcast_text):
     chunks = split_podcast_into_chunks(podcast_text)
     sample_rate = 44100 # Modified from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py has 44100
     for chunk in chunks:
+        print(f"Processing chunk: {chunk}")
         if stop_signal.is_set():
             break
         set_seed(42)
@@ -117,26 +114,21 @@ def process_audio_chunks(podcast_text):
 def stream_audio_generator(podcast_text):
     """Creates a generator that yields audio chunks for streaming"""
     stop_signal.clear()
+    threading.Thread(target=process_audio_chunks, args=(podcast_text,)).start()
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # Yield the audio chunk with sample rate
-            print(chunk)
-            yield chunk
-
-    except Exception as e:
-        print(f"Error in streaming: {e}")
+    while True:
+        chunk = audio_queue.get()
+        if chunk is None:
+            break
+        sr, data = chunk  # the tuple you produced earlier
+
+        # Encode the numpy array into a WAV blob
+        buf = io.BytesIO()
+        sf.write(buf, data.astype(np.float32) / 32768.0, sr, format="wav")
+        buf.seek(0)
+        buffer = buf.getvalue()
+        print("PRINTING BUFFER:", buffer)
+        yield buffer  # <-- bytes, so the browser can play it
 
 
 def stop_generation():
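The new stream_audio_generator body relies on a contract with process_audio_chunks that this diff only implies: each item on audio_queue is a (sample_rate, int16 samples) tuple, and a trailing None acts as the stop sentinel. A self-contained sketch of that assumed contract, with a sine tone standing in for Dia's output, also round-trips the yielded WAV bytes through soundfile to confirm they decode as playable audio:

import io
import queue
import threading

import numpy as np
import soundfile as sf

audio_queue = queue.Queue()

def fake_producer():
    # Stand-in for process_audio_chunks: enqueue (sample_rate, int16 samples)
    # tuples, then a None sentinel when generation is finished.
    sr = 44100
    t = np.linspace(0, 1, sr, endpoint=False)
    tone = (0.2 * np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
    for _ in range(3):
        audio_queue.put((sr, tone))
    audio_queue.put(None)

def wav_blob_stream():
    # Mirrors the new stream_audio_generator loop: drain the queue and encode
    # each chunk as a standalone WAV blob.
    while True:
        chunk = audio_queue.get()
        if chunk is None:
            break
        sr, data = chunk
        buf = io.BytesIO()
        sf.write(buf, data.astype(np.float32) / 32768.0, sr, format="WAV")
        yield buf.getvalue()

threading.Thread(target=fake_producer).start()
for blob in wav_blob_stream():
    decoded, rate = sf.read(io.BytesIO(blob))
    print(len(blob), "bytes ->", decoded.shape, "samples at", rate, "Hz")

Note that the division by 32768.0 assumes int16-scaled samples; if the producer already puts float arrays in [-1, 1] on the queue, that scaling would make the output nearly silent.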
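The diff does not show how these WAV blobs reach the page, but the trailing comment ("bytes, so the browser can play it") suggests they feed a streaming gr.Audio output. A rough sketch of that wiring, assuming stream_audio_generator and stop_generation are imported from app.py and using a hypothetical textbox to hold the generated script (none of these component names come from the Space's actual code):

import gradio as gr

with gr.Blocks() as demo:
    # Hypothetical UI: the real layout of app.py is not part of this diff.
    script_box = gr.Textbox(label="Podcast script")
    start_btn = gr.Button("Generate audio")
    stop_btn = gr.Button("Stop")
    audio_out = gr.Audio(label="Podcast", streaming=True, autoplay=True)

    # stream_audio_generator yields WAV byte blobs; a streaming Audio output
    # forwards each blob to the browser as it arrives.
    start_btn.click(fn=stream_audio_generator, inputs=script_box, outputs=audio_out)
    stop_btn.click(fn=stop_generation, inputs=None, outputs=None)

demo.launch()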