import os import streamlit as st from dotenv import load_dotenv from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import PromptTemplate from langchain_huggingface import HuggingFaceEndpoint # --- 상수 정의 --- # 사용할 Hugging Face 모델 ID MODEL_ID = "google/gemma-3n-e4b" # 프롬프트 템플릿 PROMPT_TEMPLATE = """ [INST] {system_message} 현재 대화: {chat_history} 사용자: {user_text} [/INST] AI: """ # --- LLM 및 체인 설정 함수 --- def get_llm(max_new_tokens=128, temperature=0.1): """ Hugging Face 추론을 위한 언어 모델(LLM)을 생성하고 반환합니다. Args: max_new_tokens (int): 생성할 최대 토큰 수입니다. temperature (float): 샘플링 온도로, 낮을수록 결정적인 답변을 생성합니다. Returns: HuggingFaceEndpoint: 설정된 언어 모델 객체입니다. """ return HuggingFaceEndpoint( repo_id=MODEL_ID, max_new_tokens=max_new_tokens, temperature=temperature, token=os.getenv("HF_TOKEN"), ) def get_chain(llm): """ 주어진 언어 모델(LLM)을 사용하여 대화 체인을 생성합니다. Args: llm (HuggingFaceEndpoint): 사용할 언어 모델입니다. Returns: RunnableSequence: LangChain 표현 언어(LCEL)로 구성된 실행 가능한 체인입니다. """ prompt = PromptTemplate.from_template(PROMPT_TEMPLATE) return prompt | llm | StrOutputParser() def generate_response(chain, system_message, chat_history, user_text): """ LLM 체인을 호출하여 사용자의 입력에 대한 응답을 생성합니다. Args: chain (RunnableSequence): 응답 생성을 위한 LLM 체인입니다. system_message (str): AI의 역할을 정의하는 시스템 메시지입니다. chat_history (list[dict]): 이전 대화 기록입니다. user_text (str): 사용자의 현재 입력 메시지입니다. Returns: str: 생성된 AI의 응답 메시지입니다. """ history_str = "\n".join( [f"{msg['role']}: {msg['content']}" for msg in chat_history] ) response = chain.invoke({ "system_message": system_message, "chat_history": history_str, "user_text": user_text, }) return response.split("AI:")[-1].strip() # --- UI 렌더링 함수 --- def initialize_session_state(): """ Streamlit 세션 상태를 초기화합니다. 세션이 처음 시작될 때 기본값을 설정합니다. """ defaults = { "avatars": {"user": "👤", "assistant": "🤗"}, "chat_history": [], "max_response_length": 256, "system_message": "당신은 인간 사용자와 대화하는 친절한 AI입니다.", "starter_message": "안녕하세요! 오늘 무엇을 도와드릴까요?", } for key, value in defaults.items(): if key not in st.session_state: st.session_state[key] = value if not st.session_state.chat_history: st.session_state.chat_history = [ {"role": "assistant", "content": st.session_state.starter_message} ] def setup_sidebar(): """ 사이드바 UI 구성 요소를 설정하고 렌더링합니다. 사용자는 이 사이드바에서 시스템 설정, AI 메시지, 모델 응답 길이 등을 조정할 수 있습니다. """ with st.sidebar: st.header("시스템 설정") st.session_state.system_message = st.text_area( "시스템 메시지", value=st.session_state.system_message ) st.session_state.starter_message = st.text_area( "첫 번째 AI 메시지", value=st.session_state.starter_message ) st.session_state.max_response_length = st.number_input( "최대 응답 길이", value=st.session_state.max_response_length ) st.markdown("*아바타 선택:*") col1, col2 = st.columns(2) with col1: st.session_state.avatars["assistant"] = st.selectbox( "AI 아바타", options=["🤗", "💬", "🤖"], index=0 ) with col2: st.session_state.avatars["user"] = st.selectbox( "사용자 아바타", options=["👤", "👱‍♂️", "👨🏾", "👩", "👧🏾"], index=0 ) if st.button("채팅 기록 초기화"): st.session_state.chat_history = [ {"role": "assistant", "content": st.session_state.starter_message} ] st.rerun() def display_chat_history(): """ 세션에 저장된 채팅 기록을 순회하며 화면에 메시지를 표시합니다. """ for message in st.session_state.chat_history: if message["role"] == "system": continue avatar = st.session_state.avatars.get(message["role"]) with st.chat_message(message["role"], avatar=avatar): st.markdown(message["content"]) # --- 메인 애플리케이션 실행 --- def main(): """ 메인 Streamlit 애플리케이션을 실행합니다. """ load_dotenv() st.set_page_config(page_title="HuggingFace ChatBot", page_icon="🤗") st.title("개인 HuggingFace 챗봇") st.markdown( f"*이것은 HuggingFace transformers 라이브러리를 사용하여 텍스트 입력에 대한 응답을 생성하는 간단한 챗봇입니다. {MODEL_ID} 모델을 사용합니다.*" ) initialize_session_state() setup_sidebar() # 채팅 기록 표시 display_chat_history() # 사용자 입력 처리 if user_input := st.chat_input("여기에 텍스트를 입력하세요."): # 사용자 메시지를 기록에 추가하고 화면에 표시 st.session_state.chat_history.append({"role": "user", "content": user_input}) with st.chat_message("user", avatar=st.session_state.avatars["user"]): st.markdown(user_input) # AI 응답 생성 및 표시 with st.chat_message("assistant", avatar=st.session_state.avatars["assistant"]): with st.spinner("생각 중..."): llm = get_llm(max_new_tokens=st.session_state.max_response_length) chain = get_chain(llm) response = generate_response( chain, st.session_state.system_message, st.session_state.chat_history, user_input, ) st.session_state.chat_history.append({"role": "assistant", "content": response}) st.markdown(response) if __name__ == "__main__": main()