import gradio as gr
import spaces
import torch
from fastrtc import AdditionalOutputs, ReplyOnPause, WebRTC, WebRTCData, get_cloudflare_turn_credentials_async
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.streamers import TextIteratorStreamer

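# Load the Gemma 3 27B instruction-tuned model in half precision,
# sharded across the available GPUs by device_map="auto".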
MODEL_ID = "google/gemma-3-27b-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
)

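# Reserve ZeroGPU hardware for up to 120 seconds per call.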
@spaces.GPU(time_limit=120)
def generate(data: WebRTCData, history, system_prompt="", max_new_tokens=512):
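    # The "textbox" variant of the WebRTC component delivers the typed message on data.textbox.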
    text = data.textbox
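    # Append the user turn and echo it to the chatbot before generation starts.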
    history.append({"role": "user", "content": text})
    yield AdditionalOutputs(history)

    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.extend(history)

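    # Render the conversation with the model's chat template and move it to the model's device.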
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

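    # Run generation on a background thread and stream decoded tokens as they arrive.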
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,  # input_ids and attention_mask
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    Thread(target=model.generate, kwargs=gen_kwargs).start()

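    # Accumulate the assistant reply token by token and re-yield the growing history.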
    new_message = {"role": "assistant", "content": ""}
    for token in streamer:
        new_message["content"] += token
        yield AdditionalOutputs(history + [new_message])


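# UI: a chatbot, a WebRTC-backed textbox (text input with a mic button), and generation settings.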
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
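    # Cloudflare TURN credentials are fetched per connection for NAT traversal.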
    webrtc = WebRTC(
        modality="audio",
        mode="send",
        variant="textbox",
        rtc_configuration=get_cloudflare_turn_credentials_async,
    )
    with gr.Accordion("Settings", open=False):
        system_prompt = gr.Textbox(
            "You are a helpful assistant.", label="System prompt"
        )
        max_new_tokens = gr.Slider(50, 1500, 700, label="Max new tokens")

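    # ReplyOnPause invokes generate on each submission; the chatbot state and settings ride along as inputs.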
    webrtc.stream(
        ReplyOnPause(generate),
        inputs=[webrtc, chatbot, system_prompt, max_new_tokens],
        outputs=[chatbot],
        concurrency_limit=100,
    )
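    # AdditionalOutputs yielded by the handler replace the chatbot's messages.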
    webrtc.on_additional_outputs(
        lambda old, new: new, inputs=[chatbot], outputs=[chatbot]
    )

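# ssr_mode=False turns off Gradio's server-side rendering for this demo.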
if __name__ == "__main__":
    demo.launch(ssr_mode=False)