m-ric HF Staff committed on
Commit
5da485d
·
1 Parent(s): b065765

Working Nari labs code

Browse files
Files changed (1) hide show
  1. app.py +20 -28
app.py CHANGED
@@ -7,13 +7,15 @@ from dia.model import Dia
7
  from huggingface_hub import InferenceClient
8
  import numpy as np
9
  from transformers import set_seed
 
 
10
 
11
  # Hardcoded podcast subject
12
  PODCAST_SUBJECT = "The future of AI and its impact on society"
13
 
14
  # Initialize the inference client
15
  client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", provider="cerebras", token=os.getenv("HF_TOKEN"))
16
- model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
17
 
18
  # Queue for audio streaming
19
  audio_queue = queue.Queue()
@@ -36,15 +38,9 @@ Now go on, make 5 minutes of podcast.
36
 
37
def split_podcast_into_chunks(podcast_text, chunk_size=3):
    """Split a podcast script into groups of consecutive lines.

    The script is stripped of surrounding whitespace, split on newlines,
    and regrouped into chunks of at most ``chunk_size`` lines, each chunk
    re-joined with newlines.
    """
    script_lines = podcast_text.strip().split("\n")
    return [
        "\n".join(script_lines[start : start + chunk_size])
        for start in range(0, len(script_lines), chunk_size)
    ]
46
-
47
- def postprocess_audio(output_audio_np, speed_factor: float=0.94):
48
  """Taken from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py"""
49
  # Get sample rate from the loaded DAC model
50
  output_sr = 44100
@@ -98,6 +94,7 @@ def process_audio_chunks(podcast_text):
98
  chunks = split_podcast_into_chunks(podcast_text)
99
  sample_rate = 44100 # Modified from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py has 44100
100
  for chunk in chunks:
 
101
  if stop_signal.is_set():
102
  break
103
  set_seed(42)
@@ -117,26 +114,21 @@ def process_audio_chunks(podcast_text):
117
def stream_audio_generator(podcast_text):
    """Creates a generator that yields audio chunks for streaming"""
    stop_signal.clear()

    # Run the chunk producer on a background thread; it feeds audio_queue
    # and pushes None when it is done.
    worker = threading.Thread(target=process_audio_chunks, args=(podcast_text,))
    worker.start()

    try:
        # iter(get, None) keeps pulling from the queue until the None
        # end-of-generation sentinel arrives.
        for item in iter(audio_queue.get, None):
            print(item)
            yield item
    except Exception as e:
        print(f"Error in streaming: {e}")
140
 
141
 
142
  def stop_generation():
 
7
  from huggingface_hub import InferenceClient
8
  import numpy as np
9
  from transformers import set_seed
10
+ import io, soundfile as sf
11
+
12
 
13
  # Hardcoded podcast subject
14
  PODCAST_SUBJECT = "The future of AI and its impact on society"
15
 
16
  # Initialize the inference client
17
  client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", provider="cerebras", token=os.getenv("HF_TOKEN"))
18
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")
19
 
20
  # Queue for audio streaming
21
  audio_queue = queue.Queue()
 
38
 
39
def split_podcast_into_chunks(podcast_text, chunk_size=3):
    """Break the podcast script into newline-joined chunks of ``chunk_size`` lines."""
    remaining = podcast_text.strip().split("\n")
    pieces = []
    index = 0
    while index < len(remaining):
        pieces.append("\n".join(remaining[index : index + chunk_size]))
        index += chunk_size
    return pieces
 
 
 
 
42
 
43
+ def postprocess_audio(output_audio_np, speed_factor: float=0.8):
 
 
44
  """Taken from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py"""
45
  # Get sample rate from the loaded DAC model
46
  output_sr = 44100
 
94
  chunks = split_podcast_into_chunks(podcast_text)
95
  sample_rate = 44100 # Modified from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py has 44100
96
  for chunk in chunks:
97
+ print(f"Processing chunk: {chunk}")
98
  if stop_signal.is_set():
99
  break
100
  set_seed(42)
 
114
def stream_audio_generator(podcast_text):
    """Yield WAV-encoded audio chunks for browser streaming.

    Starts `process_audio_chunks` on a background thread to fill the shared
    `audio_queue`, then drains the queue until the `None` sentinel arrives,
    encoding each (sample_rate, samples) tuple into a WAV byte blob.

    Args:
        podcast_text: Full podcast script passed to the producer thread.

    Yields:
        bytes: A complete WAV file per chunk, playable by the browser.
    """
    stop_signal.clear()

    # daemon=True so a lingering producer cannot keep the process alive
    # after the main program exits.
    threading.Thread(
        target=process_audio_chunks, args=(podcast_text,), daemon=True
    ).start()

    while True:
        chunk = audio_queue.get()
        if chunk is None:  # sentinel: producer finished (or was stopped)
            break
        sample_rate, data = chunk

        # Encode the numpy array into a WAV blob. The /32768.0 scaling
        # assumes int16-range samples — TODO confirm the producer's dtype.
        buf = io.BytesIO()
        sf.write(buf, data.astype(np.float32) / 32768.0, sample_rate, format="wav")
        # NOTE: dropped the previous debug print that dumped the entire
        # binary WAV blob to stdout on every chunk, and the unnecessary
        # seek(0) — getvalue() reads the whole buffer regardless of position.
        yield buf.getvalue()  # bytes, so the browser can play it
 
 
 
 
 
 
132
 
133
 
134
  def stop_generation():