# NOTE(review): the original first lines ("Spaces:" / "Runtime error" x2) are
# Hugging Face Spaces page-scrape residue, not source code; preserved here as
# a comment so the module parses.
from threading import Thread

import gradio as gr
import spaces
import torch
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    WebRTC,
    WebRTCData,
    get_cloudflare_turn_credentials_async,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer

# Gemma 3 27B instruction-tuned checkpoint.
MODEL_ID = "google/gemma-3-27b-it"

# fp16 weights, placed across available devices by Accelerate's
# device_map="auto" (single GPU, multi-GPU, or CPU offload as needed).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
)
@spaces.GPU  # FIX: `spaces` was imported but never used; without this
# decorator a ZeroGPU Space runs the model with no GPU and crashes.
def generate(data: WebRTCData, history, system_prompt="", max_new_tokens=512):
    """Stream an assistant reply for the text typed into the WebRTC textbox.

    Parameters
    ----------
    data : WebRTCData
        FastRTC payload; only ``data.textbox`` (the typed message) is read.
    history : list[dict]
        Chat history in Gradio "messages" format; the new user turn is
        appended in place.
    system_prompt : str
        Optional system message prepended to the conversation.
    max_new_tokens : int
        Generation budget forwarded to ``model.generate``.

    Yields
    ------
    AdditionalOutputs
        Progressive chatbot states: first the user turn, then the growing
        assistant turn as tokens stream in.
    """
    text = data.textbox
    # FIX: ReplyOnPause can fire on an audio pause with no typed text;
    # previously an empty user message was sent to the model.
    if not text:
        return

    history.append({"role": "user", "content": text})
    yield AdditionalOutputs(history)

    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.extend(history)

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
    ).to(model.device)

    # Generation runs on a worker thread; the streamer lets this generator
    # yield partial text from the main thread as tokens arrive.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding: deterministic replies
    )
    Thread(target=model.generate, kwargs=gen_kwargs).start()

    new_message = {"role": "assistant", "content": ""}
    for token in streamer:
        new_message["content"] += token
        yield AdditionalOutputs(history + [new_message])
with gr.Blocks() as demo:
    # Transcript rendered in OpenAI-style "messages" format.
    chatbot = gr.Chatbot(type="messages")

    # Textbox variant of the WebRTC component: sends typed messages (and
    # audio) to the backend; Cloudflare TURN credentials for NAT traversal.
    webrtc = WebRTC(
        modality="audio",
        mode="send",
        variant="textbox",
        rtc_configuration=get_cloudflare_turn_credentials_async,
    )

    with gr.Accordion("Settings", open=False):
        system_prompt = gr.Textbox(
            "You are a helpful assistant.", label="System prompt"
        )
        max_new_tokens = gr.Slider(50, 1500, 700, label="Max new tokens")

    # Invoke `generate` whenever the user pauses; partial replies stream
    # back through AdditionalOutputs.
    webrtc.stream(
        ReplyOnPause(generate),
        inputs=[webrtc, chatbot, system_prompt, max_new_tokens],
        outputs=[chatbot],
        concurrency_limit=100,
    )
    # Each AdditionalOutputs payload replaces the chatbot state wholesale.
    webrtc.on_additional_outputs(
        lambda _prev, latest: latest, inputs=[chatbot], outputs=[chatbot]
    )

if __name__ == "__main__":
    # ssr_mode=False disables server-side rendering of the UI.
    demo.launch(ssr_mode=False)