# first_chatbot / app.py
# fkt
# refactor code
# 0e3780a
import os
import re

import streamlit as st
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
# --- ์ƒ์ˆ˜ ์ •์˜ ---
# ์‚ฌ์šฉํ•  Hugging Face ๋ชจ๋ธ ID
MODEL_ID = "google/gemma-3n-e4b"
# ํ”„๋กฌํ”„ํŠธ ํ…œํ”Œ๋ฆฟ
PROMPT_TEMPLATE = """
[INST] {system_message}
ํ˜„์žฌ ๋Œ€ํ™”:
{chat_history}
์‚ฌ์šฉ์ž: {user_text}
[/INST]
AI:
"""
# --- LLM ๋ฐ ์ฒด์ธ ์„ค์ • ํ•จ์ˆ˜ ---
def get_llm(max_new_tokens=128, temperature=0.1):
    """Create the Hugging Face inference-endpoint LLM used by the chatbot.

    Args:
        max_new_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature; lower values produce more
            deterministic answers.

    Returns:
        HuggingFaceEndpoint: LLM client configured for ``MODEL_ID``.
    """
    return HuggingFaceEndpoint(
        repo_id=MODEL_ID,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        # ``huggingfacehub_api_token`` is the endpoint's declared credential
        # field. A bare ``token=`` kwarg is not a declared field, so LangChain
        # moves it into ``model_kwargs`` (forwarded as a generation parameter)
        # instead of using it to authenticate the client.
        huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    )
def get_chain(llm):
    """Build the LCEL pipeline that turns template variables into reply text.

    The pipeline fills ``PROMPT_TEMPLATE``, sends the resulting prompt to
    *llm*, and converts the model output into a plain string.

    Args:
        llm (HuggingFaceEndpoint): Language model that receives the prompt.

    Returns:
        RunnableSequence: The runnable ``prompt | llm | StrOutputParser()``.
    """
    template = PromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt_to_model = template | llm
    return prompt_to_model | StrOutputParser()
def generate_response(chain, system_message, chat_history, user_text):
    """Ask the LLM chain for a reply to the user's current message.

    Args:
        chain (RunnableSequence): Chain whose ``invoke`` method receives the
            template variables and returns the generated text.
        system_message (str): Instruction text defining the AI's role.
        chat_history (list[dict]): Earlier turns, each a dict with ``"role"``
            and ``"content"`` keys.
        user_text (str): The user's current message.

    Returns:
        str: The generated reply with surrounding whitespace removed. If the
        generated text contains ``AI:`` at the start of a line (the model
        repeated the prompt's cue or continued the dialogue), only the text
        after the last such marker is returned.
    """
    # Render earlier turns as one "role: content" line each.
    history_str = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in chat_history
    )
    response = chain.invoke({
        "system_message": system_message,
        "chat_history": history_str,
        "user_text": user_text,
    })
    # Treat "AI:" as a speaker marker only when it begins a line. Splitting on
    # every occurrence would discard the start of any answer that merely
    # mentions "AI:" mid-sentence.
    return re.split(r"(?m)^[ \t]*AI:", response)[-1].strip()
# --- UI ๋ Œ๋”๋ง ํ•จ์ˆ˜ ---
def initialize_session_state():
    """Populate ``st.session_state`` with default values on first use.

    Each default is written only when its key is absent, so values already in
    the session (e.g. edited by the user) survive reruns. Afterwards, an empty
    chat history is replaced by a one-message history holding the starter
    message as the assistant's opening turn.
    """
    initial_values = {
        "avatars": {"user": "๐Ÿ‘ค", "assistant": "๐Ÿค—"},
        "chat_history": [],
        "max_response_length": 256,
        "system_message": "๋‹น์‹ ์€ ์ธ๊ฐ„ ์‚ฌ์šฉ์ž์™€ ๋Œ€ํ™”ํ•˜๋Š” ์นœ์ ˆํ•œ AI์ž…๋‹ˆ๋‹ค.",
        "starter_message": "์•ˆ๋…•ํ•˜์„ธ์š”! ์˜ค๋Š˜ ๋ฌด์—‡์„ ๋„์™€๋“œ๋ฆด๊นŒ์š”?",
    }
    absent = [name for name in initial_values if name not in st.session_state]
    for name in absent:
        st.session_state[name] = initial_values[name]

    if st.session_state.chat_history:
        return
    opening_turn = {"role": "assistant", "content": st.session_state.starter_message}
    st.session_state.chat_history = [opening_turn]
def setup_sidebar():
    """Render the sidebar controls and store their values in session state.

    From the sidebar the user can edit the system message, the first AI
    message, the maximum response length, and the assistant/user avatars.
    A button resets the chat history to just the starter message and reruns
    the script.
    """
    with st.sidebar:
        st.header("์‹œ์Šคํ…œ ์„ค์ •")
        st.session_state.system_message = st.text_area(
            "์‹œ์Šคํ…œ ๋ฉ”์‹œ์ง€", value=st.session_state.system_message
        )
        st.session_state.starter_message = st.text_area(
            "์ฒซ ๋ฒˆ์งธ AI ๋ฉ”์‹œ์ง€", value=st.session_state.starter_message
        )
        # This value is passed to the model as ``max_new_tokens``, which must
        # be a positive integer, so the widget rejects 0 and negative numbers.
        # ``max(1, ...)`` keeps the initial value inside that bound in case an
        # out-of-range number is already stored in the session.
        st.session_state.max_response_length = st.number_input(
            "์ตœ๋Œ€ ์‘๋‹ต ๊ธธ์ด",
            min_value=1,
            value=max(1, st.session_state.max_response_length),
        )
        st.markdown("*์•„๋ฐ”ํƒ€ ์„ ํƒ:*")
        col1, col2 = st.columns(2)
        with col1:
            st.session_state.avatars["assistant"] = st.selectbox(
                "AI ์•„๋ฐ”ํƒ€", options=["๐Ÿค—", "๐Ÿ’ฌ", "๐Ÿค–"], index=0
            )
        with col2:
            st.session_state.avatars["user"] = st.selectbox(
                "์‚ฌ์šฉ์ž ์•„๋ฐ”ํƒ€", options=["๐Ÿ‘ค", "๐Ÿ‘ฑโ€โ™‚๏ธ", "๐Ÿ‘จ๐Ÿพ", "๐Ÿ‘ฉ", "๐Ÿ‘ง๐Ÿพ"], index=0
            )
        if st.button("์ฑ„ํŒ… ๊ธฐ๋ก ์ดˆ๊ธฐํ™”"):
            # Start over with only the (possibly edited) starter message.
            st.session_state.chat_history = [
                {"role": "assistant", "content": st.session_state.starter_message}
            ]
            st.rerun()
def display_chat_history():
    """Draw the stored conversation, one chat bubble per message.

    Messages whose role is ``"system"`` are not shown. Each bubble uses the
    avatar stored for its role in ``st.session_state.avatars`` (``None`` when
    that role has no entry).
    """
    shown = (
        entry
        for entry in st.session_state.chat_history
        if entry["role"] != "system"
    )
    for entry in shown:
        speaker = entry["role"]
        face = st.session_state.avatars.get(speaker)
        with st.chat_message(speaker, avatar=face):
            st.markdown(entry["content"])
# --- ๋ฉ”์ธ ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜ ์‹คํ–‰ ---
def main():
    """Run the Streamlit chatbot application.

    Loads variables from a ``.env`` file (``get_llm`` reads ``HF_TOKEN`` from
    the environment), renders the page header and sidebar, and replays the
    stored conversation. When the user submits a message, it records and
    shows the message, asks the model for a reply, and records and shows
    that reply.
    """
    load_dotenv()
    st.set_page_config(page_title="HuggingFace ChatBot", page_icon="๐Ÿค—")
    st.title("๊ฐœ์ธ HuggingFace ์ฑ—๋ด‡")
    st.markdown(
        f"*์ด๊ฒƒ์€ HuggingFace transformers ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ…์ŠคํŠธ ์ž…๋ ฅ์— ๋Œ€ํ•œ ์‘๋‹ต์„ ์ƒ์„ฑํ•˜๋Š” ๊ฐ„๋‹จํ•œ ์ฑ—๋ด‡์ž…๋‹ˆ๋‹ค. {MODEL_ID} ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.*"
    )
    initialize_session_state()
    setup_sidebar()
    # Replay the conversation so far.
    display_chat_history()
    # Handle a newly submitted user message.
    if user_input := st.chat_input("์—ฌ๊ธฐ์— ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”."):
        # Record the user's message and show it immediately.
        st.session_state.chat_history.append({"role": "user", "content": user_input})
        with st.chat_message("user", avatar=st.session_state.avatars["user"]):
            st.markdown(user_input)
        # Generate, record, and show the AI reply.
        with st.chat_message("assistant", avatar=st.session_state.avatars["assistant"]):
            with st.spinner("์ƒ๊ฐ ์ค‘..."):
                llm = get_llm(max_new_tokens=st.session_state.max_response_length)
                chain = get_chain(llm)
                # Leave out the user turn appended above: the prompt already
                # carries it via {user_text}, so passing the full history
                # would show the model the current message twice.
                response = generate_response(
                    chain,
                    st.session_state.system_message,
                    st.session_state.chat_history[:-1],
                    user_input,
                )
            st.session_state.chat_history.append({"role": "assistant", "content": response})
            st.markdown(response)
# Start the app only when this module is executed as the main script.
if __name__ == "__main__":
    main()