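"""Gradio chat demo that streams responses from an OpenAI-compatible LLM endpoint.

Each request is authenticated with a token fetched via AUTH_CMD and rerouted to
the selected model's host and path through an httpx request hook.
"""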
from functools import partial
import gradio as gr
import httpx
import subprocess
import os
from openai import OpenAI
from const import (
    LLM_BASE_URL,
    AUTH_CMD,
    SYSTEM_PROMPTS,
    EXAMPLES,
    CSS,
    HEADER,
    FOOTER,
    PLACEHOLDER,
    ModelInfo,
    MODELS,
)
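
# NOTE: the shapes below are assumptions about const.py, inferred from how this
# file uses the imports (const.py itself is not shown here):
#   ModelInfo       - exposes .name, .host, and .endpoint
#   MODELS          - dict mapping a model display name to a ModelInfo
#   SYSTEM_PROMPTS  - dict mapping a prompt label to its prompt text
#   LLM_BASE_URL    - base URL the OpenAI client targets before the hook rewrites the path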


def get_token() -> str:
    """Obtain a bearer token by running the external AUTH_CMD command."""
    try:
        t = (
            subprocess.run(
                AUTH_CMD,
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
                env=os.environ.copy(),
            )
            .stdout.decode("utf-8")
            .strip()
        )
        assert t, "Failed to get auth token"
        return t
    except Exception:
        raise ValueError("Failed to get auth token")


def get_headers(host: str) -> dict:
    return {
        "Authorization": f"Bearer {get_token()}",
        "Host": host,
        "Accept": "application/json",
        "Content-Type": "application/json",
    }


def proxy(request: httpx.Request, model_info: ModelInfo) -> httpx.Request:
    """httpx request hook: rewrite the path to the model's endpoint and attach auth headers."""
    request.url = request.url.copy_with(path=model_info.endpoint)
    request.headers.update(get_headers(host=model_info.host))
    return request


def call_llm(
    message: str,
    history: list[tuple[str, str]],
    model_name: str,
    system_prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream a chat completion for the latest message plus the Gradio history."""
    # Each API call is stateless, so the system prompt is prepended on every
    # request, not only on the first turn.
    system_prompt_text = SYSTEM_PROMPTS[system_prompt]
    history_openai_format = [{"role": "system", "content": system_prompt_text}]
    # Gradio's ChatInterface passes history as (user, assistant) pairs.
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    model_info = MODELS[model_name]
    # A fresh client per call; the request hook reroutes the request to the
    # selected model's endpoint and injects the auth headers.
    client = OpenAI(
        api_key="",
        base_url=LLM_BASE_URL,
        http_client=httpx.Client(
            event_hooks={
                "request": [partial(proxy, model_info=model_info)],
            },
            verify=False,
        ),
    )

    stream = client.chat.completions.create(
        model=f"/data/cyberagent/{model_info.name}",
        messages=history_openai_format,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        n=1,
        stream=True,
        extra_body={"repetition_penalty": 1.1},
    )

    # Accumulate streamed deltas and yield the growing text so the chat
    # window updates incrementally.
    response = ""
    for chunk in stream:
        content = chunk.choices[0].delta.content or ""
        response += content
        yield response


def run():
    """Build the Gradio Blocks app and launch it."""
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        scale=1,
        show_copy_button=True,
        height="70%",
        layout="panel",
    )
    with gr.Blocks(fill_height=True) as demo:
        gr.Markdown(HEADER)
        gr.ChatInterface(
            fn=call_llm,
            stop_btn="Stop Generation",
            examples=EXAMPLES,
            cache_examples=False,
            multimodal=False,
            chatbot=chatbot,
            additional_inputs_accordion=gr.Accordion(
                label="Parameters", open=False, render=False
            ),
            additional_inputs=[
                gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    label="Model",
                    visible=False,
                    render=False,
                ),
                gr.Dropdown(
                    choices=list(SYSTEM_PROMPTS.keys()),
                    value=list(SYSTEM_PROMPTS.keys())[0],
                    label="System Prompt",
                    visible=False,
                    render=False,
                ),
                gr.Slider(
                    minimum=1,
                    maximum=4096,
                    step=1,
                    value=1024,
                    label="Max tokens",
                    visible=True,
                    render=False,
                ),
                gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.3,
                    label="Temperature",
                    visible=True,
                    render=False,
                ),
                gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=1.0,
                    label="Top-p",
                    visible=True,
                    render=False,
                ),
            ],
            analytics_enabled=False,
        )
        gr.Markdown(FOOTER)

    demo.queue(max_size=256, api_open=False)
    demo.launch(share=False, quiet=True)


if __name__ == "__main__":
    run()