import gradio as gr
from core.context_manager import ContextManager
from core.make_pipeline import MakePipeline
from core.make_reply import generate_reply
from core.utils import load_config as load_full_config, save_config as save_full_config, load_llm_config
import re
# Paths of the files this UI reads and writes.
_CONFIG_PATH = "config.json"
_PROMPT_PATH = "assets/prompt/init.txt"

# Default LLM generation parameters, shared by the sliders and by load_config.
_LLM_DEFAULTS = {
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.05,
    "max_new_tokens": 96,
}

_CHAT_CSS = """
.chat-box { max-height: 500px; overflow-y: auto; padding: 10px; border: 1px solid #ccc; border-radius: 10px; }
.bubble-left { background-color: #f1f0f0; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: left; clear: both; }
.bubble-right { background-color: #d1e7ff; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: right; clear: both; text-align: right; }
.reset-btn-container { text-align: right; margin-bottom: 10px; }
"""

# Splits a message into "*action*" spans (group 1) and plain-text runs (group 2).
# The trailing `|\*` alternative keeps a lone, unmatched asterisk as plain text
# instead of silently dropping it.
_EMOTION_PATTERN = re.compile(r"\*(.+?)\*|([^\*]+|\*)")


def _parse_emotion_text(text: str) -> str:
    """Convert one chat message into an HTML fragment.

    ``*...*`` spans are rendered as gray text, and every non-empty line of the
    remaining plain text is terminated with ``<br>``. All message text is
    HTML-escaped first, because it comes from the user or the model.
    """
    from html import escape

    if not text:
        return ""
    segments = []
    for action, plain in _EMOTION_PATTERN.findall(text):
        if action:
            segments.append(f"<span style='color: gray;'>*{escape(action)}*</span><br>")
        elif plain:
            for line in plain.strip().splitlines():
                line = line.strip()
                if line:
                    segments.append(f"{escape(line)}<br>")
    return "\n".join(segments)


def _render_history_html(ctx: ContextManager) -> str:
    """Render the history returned by ``ctx.getHistory()`` as chat-bubble HTML.

    Entries with role "user" use the ``bubble-right`` class and entries with
    role "bot" use ``bubble-left``; any other role is skipped. Returns "" when
    there is nothing to display.
    """
    bubble_class = {"user": "bubble-right", "bot": "bubble-left"}
    bubbles = []
    for item in ctx.getHistory():
        css_class = bubble_class.get(item["role"])
        if css_class is None:
            continue
        bubbles.append(f"<div class='{css_class}'>{_parse_emotion_text(item['text'])}</div>")
    if not bubbles:
        return ""
    # The wrapper carries the `chat-box` class so the `.chat-box` CSS rule
    # applies; the HTML component itself only has elem_id="chat-box", which a
    # class selector does not match.
    return "<div class='chat-box'>" + "\n".join(bubbles) + "</div>"


def create_interface(ctx: ContextManager, makePipeline: MakePipeline):
    """Build the Gradio app: chat tab, model-settings tab and prompt-editing tab.

    Args:
        ctx: Conversation context. It is the initial value of a ``gr.State``
            that the chat and prompt tabs pass to their handlers.
        makePipeline: Generation pipeline passed to ``generate_reply`` and
            reconfigured via its ``update_config`` method from the settings tab.

    Returns:
        The assembled ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks(css=_CHAT_CSS) as demo:
        with gr.Tabs():
            # ---- 1. Chat tab ----
            with gr.TabItem("💬 탄지로와 대화하기"):
                with gr.Column():
                    with gr.Row():
                        gr.Markdown("### 탄지로와 대화하기")
                        # NOTE(review): Gradio documents `scale` as an integer;
                        # the fractional value is kept as-is — confirm the intended layout.
                        reset_btn = gr.Button("🔄 대화 초기화", elem_classes="reset-btn-container", scale=0.25)
                    chat_output = gr.HTML(elem_id="chat-box")
                    user_input = gr.Textbox(label="메시지 입력", placeholder="탄지로에게 말을 걸어보세요")
                    state = gr.State(ctx)

                def render_chat(ctx: ContextManager):
                    """Return a component update showing the whole history of *ctx*."""
                    return gr.update(value=_render_history_html(ctx))

                def on_submit(user_msg: str, ctx: ContextManager):
                    """Send one user message and stream the chat view.

                    Yields ``(chat_update, "", ctx)``: first with the user's
                    message shown, then after the reply has been generated.
                    Blank messages are ignored (history is left untouched).
                    """
                    if not user_msg or not user_msg.strip():
                        yield render_chat(ctx), "", ctx
                        return
                    ctx.addHistory("user", user_msg)
                    # Show the user's message immediately, before generation finishes.
                    yield render_chat(ctx), "", ctx
                    # The return value is unused: the re-render below presumably relies on
                    # generate_reply appending the bot reply to ctx's history — confirm.
                    generate_reply(ctx, makePipeline, user_msg)
                    yield render_chat(ctx), "", ctx

                def reset_chat(ctx: ContextManager):
                    """Clear the history, the chat view and the input box."""
                    ctx.clearHistory()
                    return gr.update(value=""), "", ctx

                user_input.submit(on_submit, inputs=[user_input, state], outputs=[chat_output, user_input, state], queue=True)
                reset_btn.click(reset_chat, inputs=[state], outputs=[chat_output, user_input, state])

            # ---- 2. Model settings tab ----
            with gr.TabItem("⚙️ 모델 설정"):
                gr.Markdown("### LLM 파라미터 설정")
                with gr.Row():
                    temperature = gr.Slider(0.0, 1.5, value=_LLM_DEFAULTS["temperature"], step=0.05, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=_LLM_DEFAULTS["top_p"], step=0.05, label="Top-p")
                    repetition_penalty = gr.Slider(0.8, 2.0, value=_LLM_DEFAULTS["repetition_penalty"], step=0.01, label="Repetition Penalty")
                with gr.Row():
                    max_tokens = gr.Slider(16, 2048, value=_LLM_DEFAULTS["max_new_tokens"], step=8, label="Max New Tokens")
                    apply_btn = gr.Button("✅ 설정 적용")

                # All three handlers take (temp, topp, repeat, max_tok) in the same order.
                def update_config(temp, topp, repeat, max_tok):
                    """Apply the slider values to the running pipeline."""
                    makePipeline.update_config({
                        "temperature": temp,
                        "top_p": topp,
                        # Slider values may arrive as floats; token counts must be ints.
                        "max_new_tokens": int(max_tok),
                        "repetition_penalty": repeat,
                    })
                    return gr.update(value="✅ 설정 적용 완료")

                # Config load / export buttons.
                with gr.Row():
                    load_btn = gr.Button("📂 설정 불러오기")
                    save_btn = gr.Button("💾 설정 내보내기")

                def load_config():
                    """Read the LLM section of the config file into the sliders (does not apply it)."""
                    llm_cfg = load_llm_config(_CONFIG_PATH)
                    return (
                        llm_cfg.get("temperature", _LLM_DEFAULTS["temperature"]),
                        llm_cfg.get("top_p", _LLM_DEFAULTS["top_p"]),
                        llm_cfg.get("repetition_penalty", _LLM_DEFAULTS["repetition_penalty"]),
                        llm_cfg.get("max_new_tokens", _LLM_DEFAULTS["max_new_tokens"]),
                        "📂 설정 불러오기 완료",
                    )

                def save_config(temp, topp, repeat, max_tok):
                    """Replace only the "llm" section of the config file, keeping other sections."""
                    config = load_full_config(_CONFIG_PATH)
                    config["llm"] = {
                        "temperature": temp,
                        "top_p": topp,
                        "repetition_penalty": repeat,
                        "max_new_tokens": int(max_tok),
                    }
                    save_full_config(config, path=_CONFIG_PATH)
                    return gr.update(value="💾 설정 저장 완료")

                # Status line at the bottom of the tab.
                status = gr.Textbox(label="", interactive=False)

                config_sliders = [temperature, top_p, repetition_penalty, max_tokens]
                apply_btn.click(update_config, inputs=config_sliders, outputs=[status])
                load_btn.click(load_config, inputs=None, outputs=config_sliders + [status])
                save_btn.click(save_config, inputs=config_sliders, outputs=[status])

            # ---- 3. Prompt editing tab ----
            with gr.TabItem("📝 프롬프트 설정"):
                gr.Markdown("### 사용자 및 캐릭터 이름 설정")
                with gr.Row():
                    user_name = gr.Textbox(label="👤 사용자 이름")
                    bot_name = gr.Textbox(label="🤖 캐릭터 이름")
                name_status = gr.Textbox(label="", interactive=False)
                with gr.Row():
                    load_name_btn = gr.Button("📂 이름 불러오기")
                    save_name_btn = gr.Button("💾 이름 저장하기")

                def load_names(ctx):
                    """Read the names from the "cha" config section and set them on *ctx*."""
                    cha_cfg = load_full_config(_CONFIG_PATH).get("cha", {})
                    user = cha_cfg.get("user_name", "user")
                    bot = cha_cfg.get("bot_name", "Tanjiro")
                    ctx.setUserName(user)
                    ctx.setBotName(bot)
                    return user, bot, "📂 이름 불러오기 완료"

                def save_names(user, bot, ctx):
                    """Write the names to the "cha" config section and set them on *ctx*."""
                    config = load_full_config(_CONFIG_PATH)
                    config["cha"] = {
                        "user_name": user,
                        "bot_name": bot,
                    }
                    save_full_config(config, path=_CONFIG_PATH)
                    ctx.setUserName(user)
                    ctx.setBotName(bot)
                    return "💾 이름 저장 완료!"

                load_name_btn.click(fn=load_names, inputs=[state], outputs=[user_name, bot_name, name_status])
                save_name_btn.click(save_names, inputs=[user_name, bot_name, state], outputs=[name_status])
                # Load the names once when the page opens.
                demo.load(fn=load_names, inputs=[state], outputs=[user_name, bot_name, name_status])

                gr.Markdown("### 캐릭터 및 세계관 프롬프트 편집")
                prompt_editor = gr.Textbox(
                    lines=20,
                    label="텍스트 (init.txt)",
                    placeholder="!! 반드시 불러오기를 먼저 하세요 !!",
                    interactive=True,
                )
                with gr.Row():
                    gr.Markdown("#### !! 반드시 불러오기를 먼저 하세요 !!")
                with gr.Row():
                    load_prompt_btn = gr.Button("📂 현재 프롬프트 불러오기")
                    save_prompt_btn = gr.Button("💾 작성한 프롬프트로 교체")
                # Dedicated status line; previously the save message replaced the button's label.
                prompt_status = gr.Textbox(label="", interactive=False)

                def load_prompt():
                    """Return the prompt file's contents, or "" if it does not exist."""
                    try:
                        with open(_PROMPT_PATH, "r", encoding="utf-8") as f:
                            return f.read()
                    except FileNotFoundError:
                        return ""

                def save_prompt(text):
                    """Overwrite the prompt file with *text*; refuse to write an empty prompt."""
                    # Saving before loading would submit an empty editor and wipe init.txt.
                    if not text or not text.strip():
                        return "⚠️ 빈 프롬프트는 저장하지 않습니다. 먼저 불러오기를 하세요."
                    with open(_PROMPT_PATH, "w", encoding="utf-8") as f:
                        f.write(text)
                    return "💾 저장 완료!"

                load_prompt_btn.click(load_prompt, inputs=None, outputs=prompt_editor)
                save_prompt_btn.click(save_prompt, inputs=[prompt_editor], outputs=[prompt_status])
    return demo