# my-staria-space / app.py
# Author: JLee0 — "Update chat interface logic" (commit da9a0af)
import json
import numpy as np
import faiss
import re
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from prompt import PROMPTS
# β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
def normalized_embedding(emb: np.ndarray) -> np.ndarray:
    """Return *emb* scaled to unit L2 norm.

    Fix: the original divided unconditionally, so an all-zero vector produced
    NaNs (0/0) that would silently poison the inner-product similarity search
    downstream. A zero vector is now returned as a (copied) zero vector.
    """
    norm = float(np.linalg.norm(emb))
    if norm == 0.0:
        # Nothing to normalize; return a copy to keep "new array" semantics.
        return np.array(emb, copy=True)
    return emb / norm
def load_embeddings(emb_path: str) -> dict[str, np.ndarray]:
    """Load a JSON mapping of key -> embedding list as float32 numpy arrays.

    Fix: the original called ``json.load(open(...))`` and never closed the
    file handle; use a context manager so the descriptor is released
    deterministically.
    """
    with open(emb_path, encoding="utf-8") as fh:
        raw = json.load(fh)
    return {key: np.array(vec, dtype="float32") for key, vec in raw.items()}
def build_faiss_index_from_embeddings(emb_map: dict[str, np.ndarray]):
    """Build an inner-product FAISS index over unit-normalized embeddings.

    Returns ``(index, keys)`` where ``keys[i]`` is the embedding key for
    FAISS row id ``i``. With unit vectors, inner product equals cosine
    similarity.
    """
    keys = list(emb_map)
    unit_rows = [normalized_embedding(emb_map[key]) for key in keys]
    matrix = np.stack(unit_rows).astype("float32")
    dim = matrix.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(matrix)
    return index, keys
def load_value_segments(value_path: str) -> dict[str, dict[str, str]]:
    """Load the segment-text JSON mapping (file key -> segment data).

    Fix: the original left the file handle open (``json.load(open(...))``);
    use a context manager so it is always closed.
    """
    with open(value_path, encoding="utf-8") as fh:
        return json.load(fh)
def generate_answer(tokenizer, model, system_prompt: str, query: str, context: str = "", conversation_history=None) -> str:
    """Assemble a Llama-3-style chat prompt and generate one assistant turn.

    The prompt is: system block, then any prior (user/assistant) turns, then
    the current user turn — with *context* appended under a "### μ™ΈλΆ€ 지식 ###"
    header when retrieval produced one — followed by an open assistant header
    for the model to complete. Special tokens are stripped from the output.

    NOTE(review): the answer is isolated by splitting the decoded text on the
    literal prompt string; this assumes the tokenizer round-trips the prompt
    byte-for-byte — confirm for this tokenizer.
    """
    bos = "<|begin_of_text|>"
    sys_hdr = "<|start_header_id|>system<|end_header_id|>"
    usr_hdr = "<|start_header_id|>user<|end_header_id|>"
    asst_hdr = "<|start_header_id|>assistant<|end_header_id|>"
    eot = "<|eot_id|>"

    parts = [f"{bos}\n{sys_hdr}\n{system_prompt}\n{eot}\n"]

    # Replay prior turns in order, tagging each with its role header.
    for msg in (conversation_history or []):
        header = usr_hdr if msg["role"] == "user" else asst_hdr
        parts.append(f"{header}\n{msg['content'].strip()}\n{eot}\n")

    user_block = f"{query}\n\n### μ™ΈλΆ€ 지식 ###\n{context}" if context else query
    parts.append(f"{usr_hdr}\n{user_block}\n{eot}\n")
    # Open assistant header with no EOT: the model completes from here.
    parts.append(f"{asst_hdr}\n")
    prompt = "".join(parts)

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(model.device)
    out = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.6,
        top_p=0.8,
    )
    decoded = tokenizer.decode(out[0], skip_special_tokens=False)

    # Keep only the continuation after the prompt, then drop special tokens.
    answer = decoded.split(prompt, 1)[-1]
    for special in (bos, sys_hdr, usr_hdr, asst_hdr, eot):
        answer = answer.replace(special, "")
    return answer.strip()
def post_process_answer(raw: str, prev_answer: str = "") -> str:
    """Extract a clean assistant answer from raw model output.

    Returns the text between the assistant header and the following EOT
    marker when present, otherwise the whole (token-stripped) raw text.
    Falls back to the Korean "no answer provided" message when the result is
    empty, degenerate (the word "assistant" repeated 4+ times), or a verbatim
    repeat of *prev_answer*.
    """
    fallback = "제곡된 닡변이 μ—†μŠ΅λ‹ˆλ‹€."
    if not raw:
        return fallback
    # Bug fix: the original pattern omitted the closing "|" of each special
    # token (e.g. "<|start_header_id>"), so it could never match the real
    # "<|start_header_id|>" markers and always fell through to raw.strip().
    m = re.search(
        r"<\|start_header_id\|>assistant<\|end_header_id\|>(.*?)<\|eot_id\|>",
        raw, re.DOTALL,
    )
    ans = m.group(1).strip() if m else raw.strip()
    # Remove any leftover <|...|> special tokens.
    ans = re.sub(r"<\|.*?\|>", "", ans).strip()
    # Degenerate generations tend to echo the role name repeatedly.
    if ans.lower().count("assistant") >= 4:
        return fallback
    if not ans or ans == prev_answer.strip():
        return fallback
    return ans
# Retrieval artifacts cached per (emb path, value path) so the embeddings,
# FAISS index, and segment texts are built once — the original reloaded and
# re-indexed the whole corpus on every single query.
_RETRIEVAL_CACHE: dict = {}


def answer_query(
    query: str,
    emb_key_path: str,
    value_text_path: str,
    tokenizer,
    model,
    system_prompt: str,
    rag_model,
    conversation_history=None,
    threshold: float = 0.65
) -> str:
    """Answer *query* with optional RAG context retrieved via FAISS.

    Retrieves the single nearest segment; if its cosine similarity clears
    *threshold* the segment text is injected as external knowledge, otherwise
    the model answers unaided. Returns the post-processed answer string.
    """
    cache_key = (emb_key_path, value_text_path)
    if cache_key not in _RETRIEVAL_CACHE:
        emb_map = load_embeddings(emb_key_path)
        index, keys = build_faiss_index_from_embeddings(emb_map)
        value_map = load_value_segments(value_text_path)
        _RETRIEVAL_CACHE[cache_key] = (index, keys, value_map)
    index, keys, value_map = _RETRIEVAL_CACHE[cache_key]

    # Normalize the query embedding so inner product == cosine similarity.
    q_emb = rag_model.encode(query, convert_to_tensor=True).cpu().numpy().squeeze()
    q_norm = normalized_embedding(q_emb).astype("float32").reshape(1, -1)
    D, I = index.search(q_norm, 1)
    score, idx = float(D[0, 0]), int(I[0, 0])

    if score >= threshold:
        # Keys are "<file>_<segment-id>"; split on the LAST underscore since
        # file names may themselves contain underscores.
        full_key = keys[idx]
        file_key, seg_id = full_key.rsplit("_", 1)
        context = value_map[file_key]["segments"].get(seg_id, "")
        print(f"βœ… μœ μ‚¬λ„: {score:.4f}, context 쀀비됨 β†’ '{context[:30]}…'")
    else:
        context = ""
        print(f"❌ μœ μ‚¬λ„ {score:.4f} < {threshold} β†’ μ™ΈλΆ€ 지식 λ―Έμ‚¬μš©")

    raw = generate_answer(
        tokenizer=tokenizer,
        model=model,
        system_prompt=system_prompt,
        query=query,
        context=context,
        conversation_history=conversation_history
    )

    # Fix: wire up post_process_answer's repeat-guard — the original never
    # passed prev_answer, so verbatim repeats of the last turn slipped through.
    prev_answer = ""
    for msg in reversed(conversation_history or []):
        if msg.get("role") == "assistant":
            prev_answer = msg.get("content", "")
            break

    answer_text = post_process_answer(raw, prev_answer)
    print(f"\nβœ… Answer: {answer_text}\n")
    return answer_text
# β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
# Precomputed RAG artifacts: key embeddings and the segment texts they map to.
EMB_KEY_PATH = "staria_keys_embed.json"
VALUE_TEXT_PATH = "staria_values.json"
MODEL_ID = "JLee0/staria-pdf-chatbot-lora"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right")
# No dedicated pad token on this tokenizer; reuse EOS for padding/truncation.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # Fix: route quantization through BitsAndBytesConfig — the bare
    # ``load_in_8bit=True`` kwarg is deprecated in recent transformers, and
    # BitsAndBytesConfig was already imported but unused. (Use
    # load_in_4bit=True here for 4-bit instead.)
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto"
)

# Sentence embedder used only for retrieval (query -> embedding).
rag_embedder = SentenceTransformer("JLee0/rag-embedder-staria-10epochs")
SYSTEM_PROMPT = PROMPTS["staria_after"]
def chat(query, history):
    """Gradio callback: flatten (user, assistant) history pairs into role
    dicts and delegate to answer_query with the module-level resources."""
    pairs = history or []
    conversation = [
        {"role": role, "content": text}
        for user_msg, bot_msg in pairs
        for role, text in (("user", user_msg), ("assistant", bot_msg))
    ]
    return answer_query(
        query=query,
        emb_key_path=EMB_KEY_PATH,
        value_text_path=VALUE_TEXT_PATH,
        tokenizer=tokenizer,
        model=model,
        system_prompt=SYSTEM_PROMPT,
        rag_model=rag_embedder,
        conversation_history=conversation,
        threshold=0.65
    )
# Gradio chat UI wired to the chat() callback; example questions are drawn
# from the Hyundai Staria owner's-manual domain.
demo = gr.ChatInterface(
    fn=chat,
    examples=[
        ["μ—”μ§„ 였일 ꡐ체 μ‹œ μ£Όμ˜ν•΄μ•Ό ν•  사항은 λ¬΄μ—‡μΈκ°€μš”?"],
        ["빌트인 μΊ  λ°μ΄ν„°λŠ” μ–΄λ–»κ²Œ μ²˜λ¦¬ν•˜λ©΄ 돼?"],
        ["와셔앑 λΆ„μΆœ κΈ°λŠ₯을 μ‚¬μš©ν•œ ν›„ μŠ€μœ„μΉ˜λ₯Ό λ‹€μ‹œ μ›μœ„μΉ˜λ‘œ λŒλ €μ•Ό ν•˜λ‚˜μš”?"],
        ["μ°¨λŸ‰ μ‹œλ™μ„ 꺼도 에어컨 섀정이 μœ μ§€λ˜λ‚˜μš”?"]
    ],
    title="ν˜„λŒ€ μŠ€νƒ€λ¦¬μ•„ Q&A 챗봇",
    description="챗봇에 μ˜€μ‹  것을 ν™˜μ˜ν•©λ‹ˆλ‹€! μ§ˆλ¬Έμ„ μž…λ ₯ν•΄ μ£Όμ„Έμš”."
)

# Launch only when run as a script (Spaces runs this module directly).
if __name__ == "__main__":
    demo.launch()