# gemma-3n / app.py
# Author: matthartman
# Commit 294ffc2 (verified): "Upload app.py with huggingface_hub"
import gradio as gr
import spaces
import torch
from fastrtc import AdditionalOutputs, ReplyOnPause, WebRTC, WebRTCData, get_cloudflare_turn_credentials_async
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.streamers import TextIteratorStreamer
MODEL_ID = "google/gemma-3-27b-it"

# Tokenizer and model are loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Gemma checkpoints are published in bfloat16 and are reported to overflow the
# float16 range (inf/NaN activations), so load the weights in bfloat16 instead.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
@spaces.GPU(duration=120)
def generate(data: WebRTCData, history, system_prompt="", max_new_tokens=512):
    """Stream a chat reply to the text submitted through the WebRTC textbox.

    Args:
        data: Payload from the WebRTC component. Only ``data.textbox`` is read
            here; any captured audio is ignored by this handler.
        history: Current chatbot transcript as a list of
            ``{"role": ..., "content": ...}`` dicts. It is not mutated; updated
            copies are yielded instead.
        system_prompt: Optional system message, prepended only when non-empty.
        max_new_tokens: Upper bound on generated tokens. Coerced to ``int`` in
            case the settings slider delivers a float.

    Yields:
        ``AdditionalOutputs`` wrapping the transcript: first with the new user
        turn appended, then once per streamed text chunk with the partial
        assistant reply appended.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread is
        re-raised here after the stream has been closed, instead of leaving
        this generator blocked on the streamer forever.
    """
    text = data.textbox
    # Don't send an empty user turn to the model when the textbox is blank.
    if not text or not text.strip():
        return

    # Build a new list rather than appending to the caller's transcript.
    history = [*history, {"role": "user", "content": text}]
    yield AdditionalOutputs(history)

    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.extend(history)
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=input_ids,
        # A single unpadded sequence: every position is attended to.
        attention_mask=torch.ones_like(input_ids),
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=False,
    )

    errors = []

    def _run_generation():
        """Run generation in the background, closing the stream if it fails."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # forwarded to the consumer below
            errors.append(exc)
            # Emit the end-of-stream signal so the loop over `streamer`
            # terminates instead of waiting for tokens that never come.
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    new_message = {"role": "assistant", "content": ""}
    for token in streamer:
        new_message["content"] += token
        yield AdditionalOutputs(history + [new_message])

    worker.join()
    if errors:
        raise errors[0]
with gr.Blocks() as demo:
    # Chat transcript; it is both displayed and passed to `generate` as history.
    chatbot = gr.Chatbot(type="messages")

    # WebRTC input configured for audio with the textbox variant, in send mode;
    # the RTC configuration is fetched via Cloudflare TURN credentials.
    webrtc = WebRTC(
        rtc_configuration=get_cloudflare_turn_credentials_async,
        variant="textbox",
        mode="send",
        modality="audio",
    )

    with gr.Accordion("Settings", open=False):
        system_prompt = gr.Textbox(
            value="You are a helpful assistant.",
            label="System prompt",
        )
        max_new_tokens = gr.Slider(
            minimum=50,
            maximum=1500,
            value=700,
            label="Max new tokens",
        )

    # Wire `generate` (wrapped in fastrtc's ReplyOnPause) to the WebRTC stream.
    stream_inputs = [webrtc, chatbot, system_prompt, max_new_tokens]
    webrtc.stream(
        ReplyOnPause(generate),
        inputs=stream_inputs,
        outputs=[chatbot],
        concurrency_limit=100,
    )

    # Each AdditionalOutputs value yielded by `generate` replaces the chatbot
    # contents outright; the previous value is ignored.
    webrtc.on_additional_outputs(
        lambda old, new: new,
        inputs=[chatbot],
        outputs=[chatbot],
    )

if __name__ == "__main__":
    # Launch the app with server-side rendering disabled.
    demo.launch(ssr_mode=False)