import logging
import os
import re

import gradio as gr

from gateway import request_generation
from utils import LATEX_DELIMS

# Backend configuration, supplied via environment variables.
openai_api_key = os.getenv("API_KEY")
openai_api_base = os.getenv("API_ENDPOINT")
MODEL = os.getenv("MODEL_NAME", "")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "1024"))
CONCURRENCY_LIMIT = int(os.getenv("CONCURRENCY_LIMIT", "20"))
QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", str(CONCURRENCY_LIMIT * 4)))

logging.basicConfig(level=logging.INFO)

def format_analysis_response(text):
    """Split raw gpt-oss output into its reasoning and final-answer parts.

    The model emits its chain of thought between an "analysis" marker and an
    "assistantfinal" marker; everything after "assistantfinal" is the answer
    meant for the user. If the markers are absent, return the text unchanged.
    """
    m = re.search(r"analysis(.*?)assistantfinal", text, re.DOTALL)
    if m:
        reasoning = m.group(1).strip()
        response = text.split("assistantfinal", 1)[-1].strip()
        return (
            f"**🤔 Analysis:**\n\n*{reasoning}*\n\n---\n\n"
            f"**💬 Response:**\n\n{response}"
        )
    return text.strip()
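
# Example (illustrative input, not captured from a real run): a raw stream such
# as "analysisUser wants a greeting.assistantfinalHello!" renders as an
# "🤔 Analysis" block ("*User wants a greeting.*") followed by a "💬 Response"
# block ("Hello!").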

def generate(message, history,
             system_prompt, temperature,
             frequency_penalty=0.0, presence_penalty=0.0,
             max_new_tokens=MAX_NEW_TOKENS):
    """Stream a chat completion from the gateway, yielding UI snapshots."""
    if not message.strip():
        yield "Please enter a prompt."
        return

    # Normalize history: Gradio "messages" mode yields dicts such as
    # {"role": "user", "content": "hi"}, while older tuple-style history
    # arrives as (user, assistant) pairs.
    msgs = []
    for h in history:
        if isinstance(h, dict):
            msgs.append(h)
        elif isinstance(h, (list, tuple)) and len(h) == 2:
            u, a = h
            if u: msgs.append({"role": "user", "content": u})
            if a: msgs.append({"role": "assistant", "content": a})

    logging.info(f"[User] {message}")
    logging.info(f"[System] {system_prompt} | Temp={temperature}")

    # `collected` accumulates the full response; `buffer` holds text since the
    # last UI flush so we can throttle how often the chat bubble re-renders.
    collected, buffer = "", ""
    yielded_once = False

    try:
        for delta in request_generation(
            api_key=openai_api_key, api_base=openai_api_base,
            message=message, system_prompt=system_prompt,
            model_name=MODEL, chat_history=msgs,
            temperature=temperature,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            max_new_tokens=max_new_tokens,
        ):
            if not delta:
                continue

            collected += delta
            buffer += delta

            # Flush the very first delta immediately so the user sees the
            # stream start without waiting for a full buffer.
            if not yielded_once:
                yield delta
                buffer = ""
                yielded_once = True
                continue

            # Afterwards, re-render only on a newline or roughly every 150
            # characters; each yield replaces the whole displayed message.
            if "\n" in buffer or len(buffer) > 150:
                yield collected
                buffer = ""

        final = format_analysis_response(collected)
        # Balance an unterminated inline-math "$" delimiter so the LaTeX
        # renderer does not swallow the tail of the message.
        if final.count("$") % 2:
            final += "$"
        yield final

    except Exception as e:
        logging.exception("Stream failed")
        yield f"❌ Error: {e}"

chatbot_ui = gr.ChatInterface(
    fn=generate,
    type="messages",
    chatbot=gr.Chatbot(
        label="OSS vLLM Chatbot",
        type="messages",
        scale=2,
        height=600,
        latex_delimiters=LATEX_DELIMS,
    ),
    stop_btn=True,
    additional_inputs=[
        gr.Textbox(label="System prompt", value="You are a helpful assistant.", lines=2),
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.7),
        gr.Slider(label="Frequency penalty", minimum=-2.0, maximum=2.0, step=0.1, value=0.0),
        gr.Slider(label="Presence penalty", minimum=-2.0, maximum=2.0, step=0.1, value=0.0),
        gr.Slider(label="Max new tokens", minimum=64, maximum=MAX_NEW_TOKENS, step=64, value=MAX_NEW_TOKENS),
    ],
    examples=[
        ["Explain the difference between supervised and unsupervised learning."],
        ["Summarize the plot of Inception in two sentences."],
        ["Show me the LaTeX for the quadratic formula."],
        ["What are advantages of AMD Instinct MI300X GPU?"],
        ["Derive the gradient of softmax cross-entropy loss."],
        ["Explain why ∂/∂x xⁿ = n·xⁿ⁻¹ holds."],
    ],
    title="GPT-OSS-120B on AMD MI300X",
    description="This Space is an alpha release demonstrating the gpt-oss-120b model running on AMD MI300X infrastructure. The Space is built under the Apache 2.0 License.",
)
if __name__ == "__main__":
    chatbot_ui.queue(max_size=QUEUE_SIZE,
                     default_concurrency_limit=CONCURRENCY_LIMIT).launch()
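
# Example launch against an OpenAI-compatible vLLM endpoint (values are
# illustrative, not the production configuration; assumes this file is saved
# as app.py):
#   API_KEY=EMPTY API_ENDPOINT=http://localhost:8000/v1 \
#   MODEL_NAME=openai/gpt-oss-120b python app.py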