brignt commited on
Commit
5d1455b
·
verified ·
1 Parent(s): 3aac496

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -52
app.py CHANGED
@@ -6,23 +6,33 @@ import numpy as np
6
  from sentence_transformers import SentenceTransformer
7
  from sklearn.metrics.pairwise import cosine_similarity
8
 
9
- # ===== 0) OpenAI API Key (Secrets) =====
10
- # Hugging Face Spaces에선 Settings -> Repository secrets -> OPENAI_API_KEY 등록
11
- openai.api_key = os.getenv("OPENAI_API_KEY")
12
 
13
- # ===== 1) 모델 & 데이터프레임 로드 =====
14
- # Example: jhgan/ko-sroberta-multitask
15
- model = SentenceTransformer("jhgan/ko-sroberta-multitask")
16
 
17
- # 임의 예시: 세브란스 정신의학챗봇 데이터 (URL)
18
- df = pd.read_csv("https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv")
19
  df = df.dropna()
20
- df["embedding"] = df["유저"].map(lambda x: model.encode(str(x)))
 
 
21
 
22
- # ===== 하이퍼파라미터 =====
23
- MAX_TURN = 5 # 소크라테스식 질문 최대 횟수
 
 
 
 
 
 
 
 
 
 
24
 
25
- # ===== 프롬프트 =====
26
  EMPATHY_PROMPT = """\
27
  당신은 친절한 정신의학과 전문의이며 심리상담 전문가입니다.
28
  사용자의 문장을 거의 그대로 요약하되, 끝에 '는군요.' 같은 공감 어미를 붙여 자연스럽게 응답하세요.
@@ -59,24 +69,10 @@ ADVICE_PROMPT = """\
59
  조언:
60
  """
61
 
62
- def set_openai_model():
63
- """
64
- 유저 요청대로 'gpt-4o' 모델명 반환
65
- (실제로는 존재하지 않을 가능성 큼)
66
- """
67
- return "gpt-4o"
68
-
69
- # ===== 함수들 =====
70
-
71
- def kb_search(user_input: str) -> str:
72
- """SentenceTransformer로 임베딩 후, df에서 가장 유사한 챗봇 답변 획득."""
73
- emb = model.encode(user_input)
74
- df["sim"] = df["embedding"].map(lambda e: cosine_similarity([emb],[e]).squeeze())
75
- idx = df["sim"].idxmax()
76
- return df.loc[idx, "챗봇"]
77
 
78
  def call_empathy(user_input: str) -> str:
79
- """EMPATHY 단계: 공감 요약."""
80
  prompt = EMPATHY_PROMPT.format(sentence=user_input)
81
  resp = openai.ChatCompletion.create(
82
  model=set_openai_model(),
@@ -90,7 +86,7 @@ def call_empathy(user_input: str) -> str:
90
  return resp.choices[0].message.content.strip()
91
 
92
  def call_socratic_question(context: str) -> str:
93
- """SQ 단계: 후속 질문 문장 생성."""
94
  prompt = f"{SOCRATIC_PROMPT}\n\n대화 힌트:\n{context}"
95
  resp = openai.ChatCompletion.create(
96
  model=set_openai_model(),
@@ -104,7 +100,7 @@ def call_socratic_question(context: str) -> str:
104
  return resp.choices[0].message.content.strip()
105
 
106
  def call_advice(hints: str) -> str:
107
- """ADVICE 단계: CBT 조언 생성."""
108
  final_prompt = ADVICE_PROMPT.format(hints=hints)
109
  resp = openai.ChatCompletion.create(
110
  model=set_openai_model(),
@@ -117,8 +113,8 @@ def call_advice(hints: str) -> str:
117
  )
118
  return resp.choices[0].message.content.strip()
119
 
 
120
  def predict(user_input: str, state: dict):
121
- """Gradio Callback: 소크라테스 CBT 챗봇 흐름 (EMPATHY→SQ→ADVICE)."""
122
  history = state.get("history", [])
123
  stage = state.get("stage", "EMPATHY")
124
  turn = state.get("turn", 0)
@@ -127,11 +123,20 @@ def predict(user_input: str, state: dict):
127
  # 1) 사용자 발화 기록
128
  history.append(("User", user_input))
129
 
130
- # 2) KB 검색hints
131
- kb_answer = kb_search(user_input)
 
 
 
 
 
 
 
 
 
132
  hints.append(f"[KB] {kb_answer}")
133
 
134
- # 3) 단계 분기
135
  if stage == "EMPATHY":
136
  empathic = call_empathy(user_input)
137
  history.append(("Chatbot", empathic))
@@ -141,7 +146,7 @@ def predict(user_input: str, state: dict):
141
  return history, {"history": history, "stage": stage, "turn": turn, "hints": hints}
142
 
143
  if stage == "SQ" and turn < MAX_TURN:
144
- # 전체 대화 + hints 합쳐 context
145
  context_text = "\n".join([f"{r}: {c}" for (r,c) in history]) + "\n" + "\n".join(hints)
146
  sq = call_socratic_question(context_text)
147
  history.append(("Chatbot", sq))
@@ -150,50 +155,50 @@ def predict(user_input: str, state: dict):
150
  return history, {"history": history, "stage": stage, "turn": turn, "hints": hints}
151
 
152
  # ADVICE 단계
153
- stage = "ADVICE"
154
  combined_hints = "\n".join(hints)
155
  advice = call_advice(combined_hints)
156
  history.append(("Chatbot", advice))
157
- stage = "END"
158
  return history, {"history":history, "stage":stage, "turn":turn, "hints":hints}
159
 
 
160
  def gradio_predict(user_input, chat_state):
161
- """Gradio에서 user_input, state를 받아 predict → (chatbot 출력, state 갱신)."""
162
  new_history, new_state = predict(user_input, chat_state)
163
-
164
- # display_history: list of (user, assistant)
165
  display_history = []
166
  for (role, txt) in new_history:
167
  if role == "User":
168
- display_history.append([txt, ""])
169
- else: # Chatbot
170
- if not display_history:
171
  display_history.append(["", txt])
172
- elif display_history[-1][1] == "":
173
- display_history[-1][1] = txt
174
  else:
175
- display_history.append(["", txt])
 
176
  return display_history, new_state
177
 
178
  def create_app():
179
- """Gradio Blocks UI 구성."""
180
  with gr.Blocks() as demo:
181
- gr.Markdown("## 🏥 소크라테스 CBT 챗봇 (GPT-4o)")
182
 
183
- chatbot = gr.Chatbot(label="Socratic CBT Chatbot")
184
  chat_state = gr.State({
185
  "history": [],
186
  "stage":"EMPATHY",
187
  "turn":0,
188
  "hints":[]
189
  })
190
- txt = gr.Textbox(show_label=False, placeholder="고민이나 궁금한 점을 입력하세요.")
191
 
192
- txt.submit(fn=gradio_predict, inputs=[txt, chat_state], outputs=[chatbot, chat_state], scroll_to_output=True)
 
193
  return demo
194
 
195
  app = create_app()
196
 
197
  if __name__ == "__main__":
198
- # Launch Gradio app
199
  app.launch(debug=True, share=True)
 
6
  from sentence_transformers import SentenceTransformer
7
  from sklearn.metrics.pairwise import cosine_similarity
8
 
9
+ # OpenAI API Key (Hugging Face Secrets)
10
+ openai.api_key = os.getenv("OPENAI_API_KEY", "")
 
11
 
12
+ # =============== 0) 모델 / df 준비 ===============
13
+ # SentenceTransformer
14
+ model = SentenceTransformer('jhgan/ko-sroberta-multitask')
15
 
16
+ # 정신의학챗봇 CSV 로드
17
+ df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv')
18
  df = df.dropna()
19
+ # Unnamed 컬럼 제거
20
+ if 'Unnamed: 3' in df.columns:
21
+ df = df.drop(columns=['Unnamed: 3'])
22
 
23
+ # 임베딩 필드
24
+ df['embedding'] = df['유저'].map(lambda x: model.encode(str(x)))
25
+
26
+ # ============== 1) 파라미터/프롬프트 ==============
27
+ MAX_TURN = 5 # 최대 소크라테스 질문 회수
28
+
29
+ def set_openai_model():
30
+ """
31
+ GPT-4 대신 'gpt-4o' (실제론 비존재 모델)
32
+ => 실제로는 'gpt-3.5-turbo' 등으로 교체 권장
33
+ """
34
+ return "gpt-4o"
35
 
 
36
  EMPATHY_PROMPT = """\
37
  당신은 친절한 정신의학과 전문의이며 심리상담 전문가입니다.
38
  사용자의 문장을 거의 그대로 요약하되, 끝에 '는군요.' 같은 공감 어미를 붙여 자연스럽게 응답하세요.
 
69
  조언:
70
  """
71
 
72
+ # ============== 2) OpenAI 호출 함수들 ==============
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  def call_empathy(user_input: str) -> str:
75
+ """ 공감 요약 생성 """
76
  prompt = EMPATHY_PROMPT.format(sentence=user_input)
77
  resp = openai.ChatCompletion.create(
78
  model=set_openai_model(),
 
86
  return resp.choices[0].message.content.strip()
87
 
88
  def call_socratic_question(context: str) -> str:
89
+ """ 소크라테스 후속질문 1문장 생성 """
90
  prompt = f"{SOCRATIC_PROMPT}\n\n대화 힌트:\n{context}"
91
  resp = openai.ChatCompletion.create(
92
  model=set_openai_model(),
 
100
  return resp.choices[0].message.content.strip()
101
 
102
  def call_advice(hints: str) -> str:
103
+ """ 최종 CBT 조언 """
104
  final_prompt = ADVICE_PROMPT.format(hints=hints)
105
  resp = openai.ChatCompletion.create(
106
  model=set_openai_model(),
 
113
  )
114
  return resp.choices[0].message.content.strip()
115
 
116
+ # ============== 3) predict 함수: EMPATHY→SQ→ADVICE ==============
117
  def predict(user_input: str, state: dict):
 
118
  history = state.get("history", [])
119
  stage = state.get("stage", "EMPATHY")
120
  turn = state.get("turn", 0)
 
123
  # 1) 사용자 발화 기록
124
  history.append(("User", user_input))
125
 
126
+ # 2) 유사도 계산df['챗봇']
127
+ query_emb = model.encode(user_input)
128
+ df["sim"] = df["embedding"].map(lambda emb: cosine_similarity([query_emb],[emb]).squeeze())
129
+
130
+ # idxmax() 에러 방지: df가 비었거나 sim이 NaN인 경우 처리
131
+ if df["sim"].count() == 0:
132
+ # fallback: 그냥 "지식베이스가 비어 있습니다" 등
133
+ kb_answer = "적합한 지식베이스 응답을 찾지 못했어요."
134
+ else:
135
+ kb_answer = df.loc[df["sim"].idxmax(), "챗봇"]
136
+
137
  hints.append(f"[KB] {kb_answer}")
138
 
139
+ # 3) 단계별 분기
140
  if stage == "EMPATHY":
141
  empathic = call_empathy(user_input)
142
  history.append(("Chatbot", empathic))
 
146
  return history, {"history": history, "stage": stage, "turn": turn, "hints": hints}
147
 
148
  if stage == "SQ" and turn < MAX_TURN:
149
+ # 전체 대화 + hints 소크라테스 질문
150
  context_text = "\n".join([f"{r}: {c}" for (r,c) in history]) + "\n" + "\n".join(hints)
151
  sq = call_socratic_question(context_text)
152
  history.append(("Chatbot", sq))
 
155
  return history, {"history": history, "stage": stage, "turn": turn, "hints": hints}
156
 
157
  # ADVICE 단계
158
+ stage = "END"
159
  combined_hints = "\n".join(hints)
160
  advice = call_advice(combined_hints)
161
  history.append(("Chatbot", advice))
162
+
163
  return history, {"history":history, "stage":stage, "turn":turn, "hints":hints}
164
 
165
+ # ============== 4) Gradio UI ==============
166
  def gradio_predict(user_input, chat_state):
167
+ """Gradio callback"""
168
  new_history, new_state = predict(user_input, chat_state)
169
+
170
+ # Gradio Chatbot expects list of [user, bot] pairs
171
  display_history = []
172
  for (role, txt) in new_history:
173
  if role == "User":
174
+ display_history.append([txt, ""]) # user in left
175
+ else:
176
+ if len(display_history)==0:
177
  display_history.append(["", txt])
 
 
178
  else:
179
+ display_history[-1][1] = txt # bot in right
180
+
181
  return display_history, new_state
182
 
183
  def create_app():
 
184
  with gr.Blocks() as demo:
185
+ gr.Markdown("## 다중턴 소크라테스 CBT 챗봇")
186
 
187
+ chatbot = gr.Chatbot(label="CBT Chatbot")
188
  chat_state = gr.State({
189
  "history": [],
190
  "stage":"EMPATHY",
191
  "turn":0,
192
  "hints":[]
193
  })
194
+ txt = gr.Textbox(show_label=False, placeholder="뭐든 물어보세요")
195
 
196
+ # submit
197
+ txt.submit(fn=gradio_predict, inputs=[txt, chat_state], outputs=[chatbot, chat_state])
198
  return demo
199
 
200
  app = create_app()
201
 
202
  if __name__ == "__main__":
203
+ # 실제 배포/실행
204
  app.launch(debug=True, share=True)