# palisade/src/rag_chain.py
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import Ollama
from langchain_community.vectorstores import FAISS
from typing import List, Dict, Any
import os
import time
class VehicleManualRAG:
"""
์ฐจ๋Ÿ‰ ๋งค๋‰ด์–ผ Q&A๋ฅผ ์œ„ํ•œ RAG ์‹œ์Šคํ…œ
"""
def __init__(self, vector_store: FAISS, use_ollama: bool = True):
self.vector_store = vector_store
        # LLM setup
        if use_ollama:
            print("🤖 Initializing Ollama model...")
            print("   (Ollama must be installed)")
            print("   Install: https://ollama.ai")
            # Ollama model (pick one that handles Korean reasonably well)
            self.llm = Ollama(
                model="llama3.2:3b",  # or "gemma2:2b", "mistral", etc.
                temperature=0.3,      # lower values give more consistent answers
                num_ctx=4096,         # context window size
            )
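            # Note: the chosen model must already be pulled locally, e.g. via
            # `ollama pull llama3.2:3b`; any locally available chat model works here.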
else:
            # Using OpenAI (requires an API key)
from langchain_openai import ChatOpenAI
self.llm = ChatOpenAI(
model="gpt-3.5-turbo",
temperature=0.3,
api_key=os.getenv("OPENAI_API_KEY")
)
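            # Assumes OPENAI_API_KEY is exported in the environment; the model
            # name is interchangeable with any chat model the key can access.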
# ํ”„๋กฌํ”„ํŠธ ํ…œํ”Œ๋ฆฟ ์„ค์ •
self.prompt_template = self._create_prompt_template()
# RAG ์ฒด์ธ ์ƒ์„ฑ
self.qa_chain = self._create_qa_chain()
def _create_prompt_template(self) -> PromptTemplate:
template = """๋‹น์‹ ์€ ํ˜„๋Œ€ ํŒฐ๋ฆฌ์„ธ์ด๋“œ ์ฐจ๋Ÿ‰ ์ „๋ฌธ๊ฐ€์ž…๋‹ˆ๋‹ค.
์•„๋ž˜ ์ฐจ๋Ÿ‰ ๋งค๋‰ด์–ผ ๋‚ด์šฉ์„ ์ฐธ๊ณ ํ•˜์—ฌ ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”.
๋งค๋‰ด์–ผ ๋‚ด์šฉ:
{context}
์งˆ๋ฌธ: {question}
๋‹ต๋ณ€ ์ง€์นจ:
1. ๋งค๋‰ด์–ผ์— ์žˆ๋Š” ๋‚ด์šฉ๋งŒ ๋‹ต๋ณ€ํ•˜์„ธ์š”
2. ๊ตฌ์ฒด์ ์ธ ์ˆ˜์น˜๋‚˜ ๋ฐฉ๋ฒ•์ด ์žˆ๋‹ค๋ฉด ์ •ํ™•ํžˆ ์ œ์‹œํ•˜์„ธ์š”
3. ๋งค๋‰ด์–ผ์— ์—†๋Š” ๋‚ด์šฉ์ด๋ฉด "๋งค๋‰ด์–ผ์—์„œ ํ•ด๋‹น ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค"๋ผ๊ณ  ๋‹ต๋ณ€ํ•˜์„ธ์š”
4. ํ•œ๊ตญ์–ด๋กœ ์นœ์ ˆํ•˜๊ณ  ๋ช…ํ™•ํ•˜๊ฒŒ ๋‹ต๋ณ€ํ•˜์„ธ์š”
๋‹ต๋ณ€:"""
return PromptTemplate(
template=template,
input_variables=["context", "question"]
)
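    # At query time the chain renders the template by filling both slots, roughly:
    #   prompt.format(context="<retrieved manual chunks>", question="<user question>")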
def _create_qa_chain(self) -> RetrievalQA:
# ์ฒด์ธ ํƒ€์ž… ์„ค์ •
chain_type_kwargs = {
"prompt": self.prompt_template,
"verbose": False # True๋กœ ํ•˜๋ฉด ์ค‘๊ฐ„ ๊ณผ์ • ์ถœ๋ ฅ
}
# RetrievalQA ์ฒด์ธ ์ƒ์„ฑ
qa_chain = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff", # ๋ชจ๋“  ๋ฌธ์„œ๋ฅผ ํ•œ๋ฒˆ์— ์ฒ˜๋ฆฌ
retriever=self.vector_store.as_retriever(
search_kwargs={"k": 5} # ์ƒ์œ„ 5๊ฐœ ์ฒญํฌ ๊ฒ€์ƒ‰
),
chain_type_kwargs=chain_type_kwargs,
return_source_documents=True # ์ถœ์ฒ˜ ๋ฌธ์„œ๋„ ๋ฐ˜ํ™˜
)
return qa_chain
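    # Tuning sketch: FAISS's retriever also supports MMR for more diverse context,
    # e.g. as_retriever(search_type="mmr", search_kwargs={"k": 5, "fetch_k": 20}).
    # "stuff" only works while the k chunks fit in the context window; for longer
    # contexts, "map_reduce" or "refine" are the usual chain_type alternatives.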
def answer_question(self, question: str) -> Dict[str, Any]:
"""
์งˆ๋ฌธ์— ๋Œ€ํ•œ ๋‹ต๋ณ€ ์ƒ์„ฑ
Args:
question: ์‚ฌ์šฉ์ž ์งˆ๋ฌธ
Returns:
๋‹ต๋ณ€๊ณผ ์ถœ์ฒ˜ ์ •๋ณด๋ฅผ ๋‹ด์€ ๋”•์…”๋„ˆ๋ฆฌ
"""
print(f"\nโ“ ์งˆ๋ฌธ: {question}")
print("๐Ÿ” ๊ด€๋ จ ๋‚ด์šฉ ๊ฒ€์ƒ‰ ์ค‘...")
start_time = time.time()
        try:
            # Run the RAG chain
            result = self.qa_chain.invoke({"query": question})
            elapsed_time = time.time() - start_time
            # Unpack the result
            answer = result.get("result", "답변을 생성할 수 없습니다.")  # "Unable to generate an answer."
            source_documents = result.get("source_documents", [])
            # Collect unique source pages (skip chunks without page metadata)
            source_pages = []
            for doc in source_documents:
                page = doc.metadata.get("page", "N/A")
                if page not in source_pages and page != "N/A":
                    source_pages.append(page)
response = {
"question": question,
"answer": answer,
"source_pages": source_pages,
"response_time": elapsed_time,
"source_documents": source_documents
}
return response
        except Exception as e:
            print(f"❌ Error: {e}")
            # Fallback when Ollama (or the OpenAI API) is unavailable
            print("\n💡 Generating a simple answer without an LLM...")
            return self._simple_answer(question)
def _simple_answer(self, question: str) -> Dict[str, Any]:
"""
LLM ์—†์ด ๊ฐ„๋‹จํ•œ ํ‚ค์›Œ๋“œ ๊ธฐ๋ฐ˜ ๋‹ต๋ณ€ (๋Œ€์ฒด ๋ฐฉ๋ฒ•)
"""
# ๊ด€๋ จ ๋ฌธ์„œ ๊ฒ€์ƒ‰
docs = self.vector_store.similarity_search(question, k=3)
if not docs:
return {
"question": question,
"answer": "๊ด€๋ จ ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.",
"source_pages": [],
"response_time": 0,
"source_documents": []
}
# ๊ฐ„๋‹จํ•œ ๊ทœ์น™ ๊ธฐ๋ฐ˜ ๋‹ต๋ณ€ ์ƒ์„ฑ
answer_parts = []
keywords = {
"์—”์ง„์˜ค์ผ": "์—”์ง„์˜ค์ผ์€ 5,000km ๋˜๋Š” 6๊ฐœ์›”๋งˆ๋‹ค ๊ต์ฒด๋ฅผ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค.",
"ํƒ€์ด์–ด ๊ณต๊ธฐ์••": "ํƒ€์ด์–ด ๊ณต๊ธฐ์••์€ ์ฐจ๋Ÿ‰ ๋„์–ด ์•ˆ์ชฝ ๋ผ๋ฒจ์„ ์ฐธ์กฐํ•˜์„ธ์š”. ์ผ๋ฐ˜์ ์œผ๋กœ 32-35 psi์ž…๋‹ˆ๋‹ค.",
"์™€์ดํผ": "์™€์ดํผ๋Š” 6๊ฐœ์›”-1๋…„๋งˆ๋‹ค ๊ต์ฒด๋ฅผ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค.",
"๋ธŒ๋ ˆ์ดํฌ": "๋ธŒ๋ ˆ์ดํฌ ํŒจ๋“œ๋Š” ์ฃผํ–‰๊ฑฐ๋ฆฌ 30,000-50,000km๋งˆ๋‹ค ์ ๊ฒ€์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.",
"๋ฐฐํ„ฐ๋ฆฌ": "๋ฐฐํ„ฐ๋ฆฌ๋Š” 3-5๋…„๋งˆ๋‹ค ๊ต์ฒด๋ฅผ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค."
}
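        # NOTE: these fallback answers are generic maintenance rules of thumb,
        # not values extracted from the Palisade manual; treat them as placeholders.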
# ํ‚ค์›Œ๋“œ ๋งค์นญ
for keyword, info in keywords.items():
if keyword in question:
answer_parts.append(info)
break
# ๊ฒ€์ƒ‰๋œ ๋‚ด์šฉ ์ถ”๊ฐ€
answer_parts.append("\n\n๊ด€๋ จ ๋งค๋‰ด์–ผ ๋‚ด์šฉ:")
for i, doc in enumerate(docs[:2], 1):
content = doc.page_content[:200]
page = doc.metadata.get("page", "N/A")
answer_parts.append(f"\n{i}. (ํŽ˜์ด์ง€ {page}) {content}...")
return {
"question": question,
"answer": "\n".join(answer_parts),
"source_pages": [doc.metadata.get("page", "N/A") for doc in docs],
"response_time": 0.1,
"source_documents": docs
}
def batch_questions(self, questions: List[str]) -> List[Dict[str, Any]]:
"""
์—ฌ๋Ÿฌ ์งˆ๋ฌธ์„ ํ•œ๋ฒˆ์— ์ฒ˜๋ฆฌ
"""
results = []
for question in questions:
result = self.answer_question(question)
results.append(result)
print(f"\nโœ… ๋‹ต๋ณ€: {result['answer'][:200]}...")
print(f"๐Ÿ“„ ์ถœ์ฒ˜: ํŽ˜์ด์ง€ {', '.join(map(str, result['source_pages']))}")
print(f"โฑ๏ธ ์‘๋‹ต์‹œ๊ฐ„: {result['response_time']:.2f}์ดˆ")
print("-" * 50)
return results
# ํ…Œ์ŠคํŠธ ์ฝ”๋“œ
if __name__ == "__main__":
"""
์‚ฌ์šฉ ์˜ˆ์‹œ ๋ฐ ํ…Œ์ŠคํŠธ
"""
from embeddings import VehicleManualEmbeddings
# ๊ฒฝ๋กœ ์„ค์ •
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
index_path = os.path.join(project_root, "data", "faiss_index")
print("=" * 60)
print("๐Ÿš— ์ฐจ๋Ÿ‰ ๋งค๋‰ด์–ผ RAG Q&A ์‹œ์Šคํ…œ")
print("=" * 60)
    # 1. Load the vector store
    print("\n1️⃣ Loading the vector index...")
embedder = VehicleManualEmbeddings()
vector_store = embedder.load_index()
    # 2. Initialize the RAG system
    print("\n2️⃣ Initializing the RAG system...")
    rag = VehicleManualRAG(vector_store, use_ollama=False)  # test without Ollama; falls back to keyword answers if OpenAI fails
    # 3. Test questions (Korean, matching the fallback keyword table)
    print("\n3️⃣ Starting the Q&A test")
print("=" * 60)
    test_questions = [
        "엔진오일 교체 주기는 얼마나 되나요?",      # engine-oil change interval?
        "타이어 적정 공기압은 얼마인가요?",        # recommended tire pressure?
        "와이퍼를 어떻게 교체하나요?",             # how do I replace the wipers?
        "브레이크 패드는 언제 교체해야 하나요?",    # when should brake pads be replaced?
        "경고등이 켜졌을 때 어떻게 해야 하나요?"    # what to do when a warning light comes on?
    ]
# ์งˆ๋ฌธ ์ฒ˜๋ฆฌ
results = rag.batch_questions(test_questions[:3]) # ์ฒ˜์Œ 3๊ฐœ๋งŒ
    # 4. Summarize results
    print("\n" + "=" * 60)
    print("📊 Test result summary")
    print("=" * 60)
for result in results:
print(f"\nQ: {result['question']}")
print(f"A: {result['answer'][:100]}...")
print(f"์ถœ์ฒ˜: {len(result['source_pages'])}๊ฐœ ํŽ˜์ด์ง€")
print("\nโœ… RAG ์‹œ์Šคํ…œ ํ…Œ์ŠคํŠธ ์™„๋ฃŒ!")
print("๐Ÿ’ก Tip: Ollama๋ฅผ ์„ค์น˜ํ•˜๋ฉด ๋” ์ •ํ™•ํ•œ ๋‹ต๋ณ€์„ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")
print(" ์„ค์น˜: https://ollama.ai")