Update app.py
app.py
CHANGED
@@ -6,23 +6,33 @@ import numpy as np
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity

-#
-
-openai.api_key = os.getenv("OPENAI_API_KEY")
+# OpenAI API key (Hugging Face Secrets)
+openai.api_key = os.getenv("OPENAI_API_KEY", "")

-#
-#
-model = SentenceTransformer(
+# =============== 0) Model / df setup ===============
+# SentenceTransformer
+model = SentenceTransformer('jhgan/ko-sroberta-multitask')

-#
-df = pd.read_csv(
+# Load the mental-health chatbot CSV
+df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv')
 df = df.dropna()
-
+# Drop the Unnamed column
+if 'Unnamed: 3' in df.columns:
+    df = df.drop(columns=['Unnamed: 3'])

-#
-
+# Embedding column
+df['embedding'] = df['유저'].map(lambda x: model.encode(str(x)))
+
+# ============== 1) Parameters / prompts ==============
+MAX_TURN = 5  # maximum number of Socratic questions
+
+def set_openai_model():
+    """
+    'gpt-4o' instead of GPT-4 (a model name that may not actually exist)
+    => in practice, replacing it with 'gpt-3.5-turbo' or similar is recommended
+    """
+    return "gpt-4o"

-# ===== Prompts =====
 EMPATHY_PROMPT = """\
 당신은 친절한 정신의학과 전문의이며 심리상담 전문가입니다.
 사용자의 문장을 거의 그대로 요약하되, 끝에 '는군요.' 같은 공감 어미를 붙여 자연스럽게 응답하세요.
@@ -59,24 +69,10 @@ ADVICE_PROMPT = """\
 조언:
 """

-def set_openai_model():
-    """
-    Return the 'gpt-4o' model name, as the user requested
-    (quite possibly it does not actually exist)
-    """
-    return "gpt-4o"
-
-# ===== Functions =====
-
-def kb_search(user_input: str) -> str:
-    """Embed with SentenceTransformer, then fetch the most similar chatbot answer from df."""
-    emb = model.encode(user_input)
-    df["sim"] = df["embedding"].map(lambda e: cosine_similarity([emb],[e]).squeeze())
-    idx = df["sim"].idxmax()
-    return df.loc[idx, "챗봇"]
+# ============== 2) OpenAI call functions ==============

 def call_empathy(user_input: str) -> str:
-    """
+    """ Generate an empathic summary """
     prompt = EMPATHY_PROMPT.format(sentence=user_input)
     resp = openai.ChatCompletion.create(
         model=set_openai_model(),
@@ -90,7 +86,7 @@ def call_empathy(user_input: str) -> str:
     return resp.choices[0].message.content.strip()

 def call_socratic_question(context: str) -> str:
-    """
+    """ Generate a one-sentence Socratic follow-up question """
     prompt = f"{SOCRATIC_PROMPT}\n\n대화 힌트:\n{context}"
     resp = openai.ChatCompletion.create(
         model=set_openai_model(),
@@ -104,7 +100,7 @@ def call_socratic_question(context: str) -> str:
     return resp.choices[0].message.content.strip()

 def call_advice(hints: str) -> str:
-    """
+    """ Final CBT advice """
     final_prompt = ADVICE_PROMPT.format(hints=hints)
     resp = openai.ChatCompletion.create(
         model=set_openai_model(),
@@ -117,8 +113,8 @@ def call_advice(hints: str) -> str:
     )
     return resp.choices[0].message.content.strip()

+# ============== 3) predict function: EMPATHY→SQ→ADVICE ==============
 def predict(user_input: str, state: dict):
-    """Gradio callback: Socratic CBT chatbot flow (EMPATHY→SQ→ADVICE)."""
     history = state.get("history", [])
     stage = state.get("stage", "EMPATHY")
     turn = state.get("turn", 0)
@@ -127,11 +123,20 @@ def predict(user_input: str, state: dict):
     # 1) Record the user utterance
     history.append(("User", user_input))

-    # 2)
-    kb_answer = kb_search(user_input)
+    # 2) Similarity computation → df['챗봇']
+    query_emb = model.encode(user_input)
+    df["sim"] = df["embedding"].map(lambda emb: cosine_similarity([query_emb],[emb]).squeeze())
+
+    # Prevent idxmax() errors: handle an empty df or all-NaN sim values
+    if df["sim"].count() == 0:
+        # fallback: e.g. just a "the knowledge base is empty" style message
+        kb_answer = "적합한 지식베이스 응답을 찾지 못했어요."
+    else:
+        kb_answer = df.loc[df["sim"].idxmax(), "챗봇"]
+
     hints.append(f"[KB] {kb_answer}")

-    # 3)
+    # 3) Branch by stage
     if stage == "EMPATHY":
         empathic = call_empathy(user_input)
         history.append(("Chatbot", empathic))
@@ -141,7 +146,7 @@ def predict(user_input: str, state: dict):
         return history, {"history": history, "stage": stage, "turn": turn, "hints": hints}

     if stage == "SQ" and turn < MAX_TURN:
-        # Full conversation + hints
+        # Full conversation + hints → Socratic question
         context_text = "\n".join([f"{r}: {c}" for (r,c) in history]) + "\n" + "\n".join(hints)
         sq = call_socratic_question(context_text)
         history.append(("Chatbot", sq))
@@ -150,50 +155,50 @@ def predict(user_input: str, state: dict):
         return history, {"history": history, "stage": stage, "turn": turn, "hints": hints}

     # ADVICE stage
-    stage = "
+    stage = "END"
     combined_hints = "\n".join(hints)
     advice = call_advice(combined_hints)
     history.append(("Chatbot", advice))
-
+
     return history, {"history":history, "stage":stage, "turn":turn, "hints":hints}

+# ============== 4) Gradio UI ==============
 def gradio_predict(user_input, chat_state):
-    """Gradio
+    """Gradio callback"""
     new_history, new_state = predict(user_input, chat_state)
-
-    #
+
+    # Gradio Chatbot expects list of [user, bot] pairs
     display_history = []
     for (role, txt) in new_history:
         if role == "User":
-            display_history.append([txt, ""])
-        else:
-            if
+            display_history.append([txt, ""])  # user in left
+        else:
+            if len(display_history)==0:
                 display_history.append(["", txt])
-            elif display_history[-1][1] == "":
-                display_history[-1][1] = txt
             else:
-                display_history
+                display_history[-1][1] = txt  # bot in right
+
     return display_history, new_state

 def create_app():
-    """Compose the Gradio Blocks UI."""
     with gr.Blocks() as demo:
-        gr.Markdown("##
+        gr.Markdown("## 다중턴 소크라테스 CBT 챗봇")

-        chatbot = gr.Chatbot(label="
+        chatbot = gr.Chatbot(label="CBT Chatbot")
         chat_state = gr.State({
             "history": [],
             "stage":"EMPATHY",
             "turn":0,
             "hints":[]
         })
-        txt = gr.Textbox(show_label=False, placeholder="
+        txt = gr.Textbox(show_label=False, placeholder="뭐든 물어보세요")

-
+        # submit
+        txt.submit(fn=gradio_predict, inputs=[txt, chat_state], outputs=[chatbot, chat_state])
     return demo

 app = create_app()

 if __name__ == "__main__":
-    #
+    # Actual deployment / run
     app.launch(debug=True, share=True)
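
The new set_openai_model() docstring itself flags that the hard-coded "gpt-4o" name may not be available and suggests "gpt-3.5-turbo" instead. A minimal sketch of that suggestion (not part of this commit), assuming the legacy openai<1.0 ChatCompletion interface that app.py already uses; the helper name chat_with_fallback and its error handling are illustrative only:

    import openai

    def chat_with_fallback(messages, primary="gpt-4o", fallback="gpt-3.5-turbo"):
        """Try the preferred model first; retry once with the fallback if the request is rejected."""
        try:
            resp = openai.ChatCompletion.create(model=primary, messages=messages)
        except openai.error.InvalidRequestError:
            # primary model name rejected by the API; retry with a model known to exist
            resp = openai.ChatCompletion.create(model=fallback, messages=messages)
        return resp.choices[0].message.content.strip()

    # usage sketch: advice = chat_with_fallback([{"role": "user", "content": final_prompt}])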