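"""Gradio Q&A chatbot for the Hyundai Staria PDF manual.

Retrieval-augmented generation (RAG): the user query is embedded with a
SentenceTransformer, matched against pre-computed key embeddings in a FAISS
index, and the best-matching manual segment (when similar enough) is injected
as extra context into a Llama-3-format LoRA chat model.
"""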
import json
import numpy as np
import faiss
import re
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from prompt import PROMPTS

# ──────────────────────────────────────────────────────────────────────
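# Retrieval helpers: embedding normalization, FAISS index construction over
# the pre-computed key embeddings, and loading of the segment texts
# ("values") that those keys point to.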
def normalized_embedding(emb: np.ndarray) -> np.ndarray:
    return emb / np.linalg.norm(emb)

def load_embeddings(emb_path: str) -> dict[str, np.ndarray]:
    with open(emb_path, encoding="utf-8") as f:
        raw = json.load(f)
    return {k: np.array(v, dtype="float32") for k, v in raw.items()}

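# Because every stored vector (and the query vector) is L2-normalized, the
# inner-product search of IndexFlatIP is equivalent to cosine similarity.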
def build_faiss_index_from_embeddings(emb_map: dict[str, np.ndarray]):
    keys = list(emb_map.keys())
    matrix = np.stack([normalized_embedding(emb_map[k]) for k in keys]).astype("float32")
    index = faiss.IndexFlatIP(matrix.shape[1])
    index.add(matrix)
    return index, keys

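# The values file appears to map file_key -> {"segments": {seg_id: segment_text}};
# see how answer_query splits each FAISS key into file_key and seg_id below.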
def load_value_segments(value_path: str) -> dict[str, dict[str, str]]:
    with open(value_path, encoding="utf-8") as f:
        return json.load(f)

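# Builds a Llama-3-style chat prompt by hand (<|start_header_id|>/<|eot_id|>
# special tokens): system prompt, prior turns, then the user turn with the
# retrieved context (if any) appended under a "### 외부 지식 ###"
# ("external knowledge") header.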
def generate_answer(tokenizer, model, system_prompt: str, query: str, context: str = "", conversation_history=None) -> str:
    B = "<|begin_of_text|>"
    SS = "<|start_header_id|>system<|end_header_id|>"
    SU = "<|start_header_id|>user<|end_header_id|>"
    SA = "<|start_header_id|>assistant<|end_header_id|>"
    EOT = "<|eot_id|>"
    system_block = f"{B}\n{SS}\n{system_prompt}\n{EOT}\n"
conv_text = "" | |
if conversation_history: | |
for msg in conversation_history: | |
role = msg["role"] | |
content = msg["content"].strip() | |
tag = SU if role=="user" else SA | |
conv_text += f"{tag}\n{content}\n{EOT}\n" | |
if context: | |
user_block = f"{query}\n\n### μΈλΆ μ§μ ###\n{context}" | |
else: | |
user_block = query | |
conv_text += f"{SU}\n{user_block}\n{EOT}\n" | |
conv_text += f"{SA}\n" | |
    prompt = system_block + conv_text
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(model.device)
    out = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.6,
        top_p=0.8,
    )
    decoded = tokenizer.decode(out[0], skip_special_tokens=False)
    answer = decoded.split(prompt, 1)[-1]
    for tok in [B, SS, SU, SA, EOT]:
        answer = answer.replace(tok, "")
    return answer.strip()

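# Pulls the assistant turn out of the raw generation, strips any leftover
# special tokens, and falls back to a fixed "no answer" message when the
# output is empty, degenerate (many repeated "assistant" headers), or merely
# repeats the previous answer.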
def post_process_answer(raw: str, prev_answer: str = "") -> str:
    no_answer = "제공된 답변이 없습니다."  # "No answer was provided."
    if not raw:
        return no_answer
    m = re.search(
        r"<\|start_header_id\|>assistant<\|end_header_id\|>(.*?)<\|eot_id\|>",
        raw, re.DOTALL
    )
    if m:
        ans = m.group(1).strip()
    else:
        ans = raw.strip()
    ans = re.sub(r"<\|.*?\|>", "", ans).strip()
    if ans.lower().count("assistant") >= 4:
        return no_answer
    if not ans or ans == prev_answer.strip():
        return no_answer
    return ans

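# End-to-end RAG pipeline for a single query: load the key embeddings and
# segment texts, rebuild the FAISS index (note: this happens on every call and
# could be cached at startup), run a top-1 similarity search, and attach the
# matched segment as context only when the score clears the threshold.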
def answer_query(
    query: str,
    emb_key_path: str,
    value_text_path: str,
    tokenizer,
    model,
    system_prompt: str,
    rag_model,
    conversation_history=None,
    threshold: float = 0.65,
) -> str:
    emb_map = load_embeddings(emb_key_path)
    index, keys = build_faiss_index_from_embeddings(emb_map)
    value_map = load_value_segments(value_text_path)
    q_emb = rag_model.encode(query, convert_to_tensor=True).cpu().numpy().squeeze()
    q_norm = normalized_embedding(q_emb).astype("float32").reshape(1, -1)
    D, I = index.search(q_norm, 1)
    score, idx = float(D[0, 0]), int(I[0, 0])
    if score >= threshold:
        full_key = keys[idx]
        file_key, seg_id = full_key.rsplit("_", 1)
        context = value_map[file_key]["segments"].get(seg_id, "")
        print(f"Similarity {score:.4f} >= {threshold}: context ready -> '{context[:30]}...'")
    else:
        context = ""
        print(f"Similarity {score:.4f} < {threshold}: external knowledge not used")
    raw = generate_answer(
        tokenizer=tokenizer,
        model=model,
        system_prompt=system_prompt,
        query=query,
        context=context,
        conversation_history=conversation_history,
    )
    answer_text = post_process_answer(raw)
    print(f"\nAnswer: {answer_text}\n")
    return answer_text

# ──────────────────────────────────────────────────────────────────────
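# Global setup: the tokenizer, the 8-bit quantized chat model, and the
# retrieval embedder are loaded once at import time.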
EMB_KEY_PATH = "staria_keys_embed.json"
VALUE_TEXT_PATH = "staria_values.json"
MODEL_ID = "JLee0/staria-pdf-chatbot-lora"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # or 4-bit (load_in_4bit=True)
    device_map="auto",
)
rag_embedder = SentenceTransformer("JLee0/rag-embedder-staria-10epochs")
SYSTEM_PROMPT = PROMPTS["staria_after"]

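# The chat history arrives from gr.ChatInterface as (user, assistant) pairs;
# convert them into the role/content dicts that generate_answer expects.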
def chat(query, history):
    conv = []
    for u, a in history or []:
        conv.append({"role": "user", "content": u})
        conv.append({"role": "assistant", "content": a})
    return answer_query(
        query=query,
        emb_key_path=EMB_KEY_PATH,
        value_text_path=VALUE_TEXT_PATH,
        tokenizer=tokenizer,
        model=model,
        system_prompt=SYSTEM_PROMPT,
        rag_model=rag_embedder,
        conversation_history=conv,
        threshold=0.65,
    )

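# Gradio chat UI; the example prompts below are Korean questions about the
# vehicle manual (engine oil change, built-in cam data, washer-fluid switch,
# A/C settings), translated in the inline comments.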
demo = gr.ChatInterface(
    fn=chat,
    examples=[
        ["엔진 오일 교체 시 주의해야 할 사항은 무엇인가요?"],  # What should I watch out for when changing the engine oil?
        ["빌트인 캠 데이터는 어떻게 처리하면 돼?"],  # How should I handle the built-in cam data?
        ["워셔액 분출 기능을 사용한 후 스위치를 다시 원위치로 돌려야 하나요?"],  # After spraying washer fluid, do I return the switch to its original position?
        ["차량 시동을 꺼도 에어컨 설정이 유지되나요?"],  # Are the A/C settings kept after the ignition is turned off?
    ],
    title="현대 스타리아 Q&A 챗봇",  # Hyundai Staria Q&A chatbot
    description="챗봇에 오신 것을 환영합니다! 질문을 입력해 주세요.",  # Welcome to the chatbot! Please enter your question.
)
if __name__ == "__main__":
    demo.launch()