# Hugging Face Spaces Streamlit chatbot application.
import os | |
import streamlit as st | |
from dotenv import load_dotenv | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import PromptTemplate | |
from langchain_huggingface import HuggingFaceEndpoint | |
# --- Constants ---
# Hugging Face model ID used for inference.
MODEL_ID = "google/gemma-3n-e4b"
# Prompt template ([INST] ... [/INST] instruction format; the trailing
# "AI:" cues the model to continue as the assistant).
PROMPT_TEMPLATE = """
[INST] {system_message}
ํ์ฌ ๋ํ:
{chat_history}
์ฌ์ฉ์: {user_text}
[/INST]
AI:
"""
# --- LLM and chain setup functions ---
def get_llm(max_new_tokens=128, temperature=0.1):
    """Create and return a language model backed by a Hugging Face endpoint.

    Args:
        max_new_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature; lower values produce
            more deterministic answers.

    Returns:
        HuggingFaceEndpoint: The configured language-model object.
    """
    # The API token is read from the environment (loaded earlier via dotenv).
    endpoint_config = {
        "repo_id": MODEL_ID,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "token": os.getenv("HF_TOKEN"),
    }
    return HuggingFaceEndpoint(**endpoint_config)
def get_chain(llm):
    """Build a conversational chain around the given language model.

    Args:
        llm (HuggingFaceEndpoint): The language model to use.

    Returns:
        RunnableSequence: Runnable LCEL chain: prompt -> llm -> string parser.
    """
    template = PromptTemplate.from_template(PROMPT_TEMPLATE)
    parser = StrOutputParser()
    return template | llm | parser
def generate_response(chain, system_message, chat_history, user_text):
    """Invoke the LLM chain and return the assistant's reply to the user.

    Args:
        chain (RunnableSequence): LLM chain used to produce the response.
        system_message (str): System message defining the AI's role.
        chat_history (list[dict]): Prior conversation turns, each a dict
            with "role" and "content" keys.
        user_text (str): The user's current input message.

    Returns:
        str: The generated assistant reply.
    """
    turns = [f"{turn['role']}: {turn['content']}" for turn in chat_history]
    history_str = "\n".join(turns)
    raw = chain.invoke(
        {
            "system_message": system_message,
            "chat_history": history_str,
            "user_text": user_text,
        }
    )
    # The endpoint may echo the prompt; keep only the text after the
    # last "AI:" marker (the whole output if the marker is absent).
    return raw.rpartition("AI:")[2].strip()
# --- UI rendering functions ---
def initialize_session_state():
    """Initialize the Streamlit session state.

    Sets default values the first time a session starts, and seeds the
    chat history with the starter message when the history is empty.
    """
    defaults = {
        "avatars": {"user": "๐ค", "assistant": "๐ค"},
        "chat_history": [],
        "max_response_length": 256,
        "system_message": "๋น์ ์ ์ธ๊ฐ ์ฌ์ฉ์์ ๋ํํ๋ ์น์ ํ AI์ ๋๋ค.",
        "starter_message": "์๋ ํ์ธ์! ์ค๋ ๋ฌด์์ ๋์๋๋ฆด๊น์?",
    }
    for name in defaults:
        if name not in st.session_state:
            st.session_state[name] = defaults[name]
    # An empty history gets the assistant's opening message.
    if not st.session_state.chat_history:
        opening = {"role": "assistant", "content": st.session_state.starter_message}
        st.session_state.chat_history = [opening]
def setup_sidebar():
    """Render the sidebar UI.

    Lets the user adjust the system message, the AI's opening message,
    the maximum response length, and the chat avatars; also offers a
    button that resets the chat history.
    """
    with st.sidebar:
        st.header("์์คํ ์ค์ ")
        system_msg = st.text_area(
            "์์คํ ๋ฉ์์ง", value=st.session_state.system_message
        )
        st.session_state.system_message = system_msg
        starter_msg = st.text_area(
            "์ฒซ ๋ฒ์งธ AI ๋ฉ์์ง", value=st.session_state.starter_message
        )
        st.session_state.starter_message = starter_msg
        max_len = st.number_input(
            "์ต๋ ์๋ต ๊ธธ์ด", value=st.session_state.max_response_length
        )
        st.session_state.max_response_length = max_len
        st.markdown("*์๋ฐํ ์ ํ:*")
        left_col, right_col = st.columns(2)
        with left_col:
            st.session_state.avatars["assistant"] = st.selectbox(
                "AI ์๋ฐํ", options=["๐ค", "๐ฌ", "๐ค"], index=0
            )
        with right_col:
            st.session_state.avatars["user"] = st.selectbox(
                "์ฌ์ฉ์ ์๋ฐํ", options=["๐ค", "๐ฑโโ๏ธ", "๐จ๐พ", "๐ฉ", "๐ง๐พ"], index=0
            )
        # Resetting replaces the history with just the opening message.
        if st.button("์ฑํ ๊ธฐ๋ก ์ด๊ธฐํ"):
            st.session_state.chat_history = [
                {"role": "assistant", "content": st.session_state.starter_message}
            ]
            st.rerun()
def display_chat_history():
    """Render every stored chat message on screen.

    System messages are hidden; user/assistant messages are shown with
    their configured avatars.
    """
    avatars = st.session_state.avatars
    for entry in st.session_state.chat_history:
        role = entry["role"]
        if role == "system":
            continue
        with st.chat_message(role, avatar=avatars.get(role)):
            st.markdown(entry["content"])
# --- Main application execution ---
def main():
    """Run the main Streamlit application."""
    load_dotenv()
    st.set_page_config(page_title="HuggingFace ChatBot", page_icon="๐ค")
    st.title("๊ฐ์ธ HuggingFace ์ฑ๋ด")
    st.markdown(
        f"*์ด๊ฒ์ HuggingFace transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ํ ์คํธ ์ ๋ ฅ์ ๋ํ ์๋ต์ ์์ฑํ๋ ๊ฐ๋จํ ์ฑ๋ด์ ๋๋ค. {MODEL_ID} ๋ชจ๋ธ์ ์ฌ์ฉํฉ๋๋ค.*"
    )

    initialize_session_state()
    setup_sidebar()

    # Show the conversation so far.
    display_chat_history()

    # Handle new user input.
    user_input = st.chat_input("์ฌ๊ธฐ์ ํ ์คํธ๋ฅผ ์ ๋ ฅํ์ธ์.")
    if user_input:
        avatars = st.session_state.avatars
        # Record and display the user's message.
        st.session_state.chat_history.append({"role": "user", "content": user_input})
        with st.chat_message("user", avatar=avatars["user"]):
            st.markdown(user_input)

        # Generate and display the AI's reply.
        with st.chat_message("assistant", avatar=avatars["assistant"]):
            with st.spinner("์๊ฐ ์ค..."):
                llm = get_llm(max_new_tokens=st.session_state.max_response_length)
                chain = get_chain(llm)
                response = generate_response(
                    chain,
                    st.session_state.system_message,
                    st.session_state.chat_history,
                    user_input,
                )
                st.session_state.chat_history.append(
                    {"role": "assistant", "content": response}
                )
                st.markdown(response)
# Standard script entry-point guard: run the app only when executed directly.
if __name__ == "__main__":
    main()