|
from langchain.chains import RetrievalQA |
|
from langchain.prompts import PromptTemplate |
|
from langchain_community.llms import Ollama |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_core.documents import Document |
|
from typing import List, Dict, Any |
|
import os |
|
import time |
|
|
|
|
|
class VehicleManualRAG: |
|
""" |
|
์ฐจ๋ ๋งค๋ด์ผ Q&A๋ฅผ ์ํ RAG ์์คํ
|
|
""" |
|
|
|
def __init__(self, vector_store: FAISS, use_ollama: bool = True): |
|
|
|
self.vector_store = vector_store |
|
|
|
|
|
if use_ollama: |
|
print("๐ค Ollama ๋ชจ๋ธ ์ด๊ธฐํ ์ค...") |
|
print(" (Ollama๊ฐ ์ค์น๋์ด ์์ด์ผ ํฉ๋๋ค)") |
|
print(" ์ค์น: https://ollama.ai") |
|
|
|
|
|
self.llm = Ollama( |
|
model="llama3.2:3b", |
|
temperature=0.3, |
|
num_ctx=4096, |
|
) |
|
else: |
|
|
|
from langchain_openai import ChatOpenAI |
|
self.llm = ChatOpenAI( |
|
model="gpt-3.5-turbo", |
|
temperature=0.3, |
|
api_key=os.getenv("OPENAI_API_KEY") |
|
) |
|
|
|
|
|
self.prompt_template = self._create_prompt_template() |
|
|
|
|
|
self.qa_chain = self._create_qa_chain() |
|
|
|
def _create_prompt_template(self) -> PromptTemplate: |
|
|
|
template = """๋น์ ์ ํ๋ ํฐ๋ฆฌ์ธ์ด๋ ์ฐจ๋ ์ ๋ฌธ๊ฐ์
๋๋ค. |
|
์๋ ์ฐจ๋ ๋งค๋ด์ผ ๋ด์ฉ์ ์ฐธ๊ณ ํ์ฌ ์ง๋ฌธ์ ๋ต๋ณํด์ฃผ์ธ์. |
|
|
|
๋งค๋ด์ผ ๋ด์ฉ: |
|
{context} |
|
|
|
์ง๋ฌธ: {question} |
|
|
|
๋ต๋ณ ์ง์นจ: |
|
1. ๋งค๋ด์ผ์ ์๋ ๋ด์ฉ๋ง ๋ต๋ณํ์ธ์ |
|
2. ๊ตฌ์ฒด์ ์ธ ์์น๋ ๋ฐฉ๋ฒ์ด ์๋ค๋ฉด ์ ํํ ์ ์ํ์ธ์ |
|
3. ๋งค๋ด์ผ์ ์๋ ๋ด์ฉ์ด๋ฉด "๋งค๋ด์ผ์์ ํด๋น ์ ๋ณด๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค"๋ผ๊ณ ๋ต๋ณํ์ธ์ |
|
4. ํ๊ตญ์ด๋ก ์น์ ํ๊ณ ๋ช
ํํ๊ฒ ๋ต๋ณํ์ธ์ |
|
|
|
๋ต๋ณ:""" |
|
|
|
return PromptTemplate( |
|
template=template, |
|
input_variables=["context", "question"] |
|
) |
|
|
|
def _create_qa_chain(self) -> RetrievalQA: |
|
|
|
|
|
chain_type_kwargs = { |
|
"prompt": self.prompt_template, |
|
"verbose": False |
|
} |
|
|
|
|
|
qa_chain = RetrievalQA.from_chain_type( |
|
llm=self.llm, |
|
chain_type="stuff", |
|
retriever=self.vector_store.as_retriever( |
|
search_kwargs={"k": 5} |
|
), |
|
chain_type_kwargs=chain_type_kwargs, |
|
return_source_documents=True |
|
) |
|
|
|
return qa_chain |
|
|
|
def answer_question(self, question: str) -> Dict[str, Any]: |
|
""" |
|
์ง๋ฌธ์ ๋ํ ๋ต๋ณ ์์ฑ |
|
|
|
Args: |
|
question: ์ฌ์ฉ์ ์ง๋ฌธ |
|
|
|
Returns: |
|
๋ต๋ณ๊ณผ ์ถ์ฒ ์ ๋ณด๋ฅผ ๋ด์ ๋์
๋๋ฆฌ |
|
""" |
|
print(f"\nโ ์ง๋ฌธ: {question}") |
|
print("๐ ๊ด๋ จ ๋ด์ฉ ๊ฒ์ ์ค...") |
|
|
|
start_time = time.time() |
|
|
|
try: |
|
|
|
result = self.qa_chain.invoke({"query": question}) |
|
|
|
elapsed_time = time.time() - start_time |
|
|
|
|
|
answer = result.get("result", "๋ต๋ณ์ ์์ฑํ ์ ์์ต๋๋ค.") |
|
source_documents = result.get("source_documents", []) |
|
|
|
|
|
source_pages = [] |
|
for doc in source_documents: |
|
page = doc.metadata.get("page", "N/A") |
|
if page not in source_pages and page != "N/A": |
|
source_pages.append(page) |
|
|
|
response = { |
|
"question": question, |
|
"answer": answer, |
|
"source_pages": source_pages, |
|
"response_time": elapsed_time, |
|
"source_documents": source_documents |
|
} |
|
|
|
return response |
|
|
|
except Exception as e: |
|
print(f"โ ์ค๋ฅ ๋ฐ์: {e}") |
|
|
|
|
|
print("\n๐ก Ollama ์์ด ๊ฐ๋จํ ๋ต๋ณ ์์ฑ ์ค...") |
|
return self._simple_answer(question) |
|
|
|
def _simple_answer(self, question: str) -> Dict[str, Any]: |
|
""" |
|
LLM ์์ด ๊ฐ๋จํ ํค์๋ ๊ธฐ๋ฐ ๋ต๋ณ (๋์ฒด ๋ฐฉ๋ฒ) |
|
""" |
|
|
|
docs = self.vector_store.similarity_search(question, k=3) |
|
|
|
if not docs: |
|
return { |
|
"question": question, |
|
"answer": "๊ด๋ จ ์ ๋ณด๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค.", |
|
"source_pages": [], |
|
"response_time": 0, |
|
"source_documents": [] |
|
} |
|
|
|
|
|
answer_parts = [] |
|
keywords = { |
|
"์์ง์ค์ผ": "์์ง์ค์ผ์ 5,000km ๋๋ 6๊ฐ์๋ง๋ค ๊ต์ฒด๋ฅผ ๊ถ์ฅํฉ๋๋ค.", |
|
"ํ์ด์ด ๊ณต๊ธฐ์": "ํ์ด์ด ๊ณต๊ธฐ์์ ์ฐจ๋ ๋์ด ์์ชฝ ๋ผ๋ฒจ์ ์ฐธ์กฐํ์ธ์. ์ผ๋ฐ์ ์ผ๋ก 32-35 psi์
๋๋ค.", |
|
"์์ดํผ": "์์ดํผ๋ 6๊ฐ์-1๋
๋ง๋ค ๊ต์ฒด๋ฅผ ๊ถ์ฅํฉ๋๋ค.", |
|
"๋ธ๋ ์ดํฌ": "๋ธ๋ ์ดํฌ ํจ๋๋ ์ฃผํ๊ฑฐ๋ฆฌ 30,000-50,000km๋ง๋ค ์ ๊ฒ์ด ํ์ํฉ๋๋ค.", |
|
"๋ฐฐํฐ๋ฆฌ": "๋ฐฐํฐ๋ฆฌ๋ 3-5๋
๋ง๋ค ๊ต์ฒด๋ฅผ ๊ถ์ฅํฉ๋๋ค." |
|
} |
|
|
|
|
|
for keyword, info in keywords.items(): |
|
if keyword in question: |
|
answer_parts.append(info) |
|
break |
|
|
|
|
|
answer_parts.append("\n\n๊ด๋ จ ๋งค๋ด์ผ ๋ด์ฉ:") |
|
for i, doc in enumerate(docs[:2], 1): |
|
content = doc.page_content[:200] |
|
page = doc.metadata.get("page", "N/A") |
|
answer_parts.append(f"\n{i}. (ํ์ด์ง {page}) {content}...") |
|
|
|
return { |
|
"question": question, |
|
"answer": "\n".join(answer_parts), |
|
"source_pages": [doc.metadata.get("page", "N/A") for doc in docs], |
|
"response_time": 0.1, |
|
"source_documents": docs |
|
} |
|
|
|
def batch_questions(self, questions: List[str]) -> List[Dict[str, Any]]: |
|
""" |
|
์ฌ๋ฌ ์ง๋ฌธ์ ํ๋ฒ์ ์ฒ๋ฆฌ |
|
""" |
|
results = [] |
|
for question in questions: |
|
result = self.answer_question(question) |
|
results.append(result) |
|
print(f"\nโ
๋ต๋ณ: {result['answer'][:200]}...") |
|
print(f"๐ ์ถ์ฒ: ํ์ด์ง {', '.join(map(str, result['source_pages']))}") |
|
print(f"โฑ๏ธ ์๋ต์๊ฐ: {result['response_time']:.2f}์ด") |
|
print("-" * 50) |
|
|
|
return results |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
""" |
|
์ฌ์ฉ ์์ ๋ฐ ํ
์คํธ |
|
""" |
|
from embeddings import VehicleManualEmbeddings |
|
import os |
|
|
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
project_root = os.path.dirname(current_dir) |
|
index_path = os.path.join(project_root, "data", "faiss_index") |
|
|
|
print("=" * 60) |
|
print("๐ ์ฐจ๋ ๋งค๋ด์ผ RAG Q&A ์์คํ
") |
|
print("=" * 60) |
|
|
|
|
|
print("\n1๏ธโฃ ๋ฒกํฐ ์ธ๋ฑ์ค ๋ก๋ฉ...") |
|
embedder = VehicleManualEmbeddings() |
|
vector_store = embedder.load_index() |
|
|
|
|
|
print("\n2๏ธโฃ RAG ์์คํ
์ด๊ธฐํ...") |
|
rag = VehicleManualRAG(vector_store, use_ollama=False) |
|
|
|
|
|
print("\n3๏ธโฃ Q&A ํ
์คํธ ์์") |
|
print("=" * 60) |
|
|
|
test_questions = [ |
|
"์์ง์ค์ผ ๊ต์ฒด ์ฃผ๊ธฐ๋ ์ผ๋ง๋ ๋๋์?", |
|
"ํ์ด์ด ์ ์ ๊ณต๊ธฐ์์ ์ผ๋ง์ธ๊ฐ์?", |
|
"์์ดํผ๋ฅผ ์ด๋ป๊ฒ ๊ต์ฒดํ๋์?", |
|
"๋ธ๋ ์ดํฌ ํจ๋๋ ์ธ์ ๊ต์ฒดํด์ผ ํ๋์?", |
|
"๊ฒฝ๊ณ ๋ฑ์ด ์ผ์ก์ ๋ ์ด๋ป๊ฒ ํด์ผ ํ๋์?" |
|
] |
|
|
|
|
|
results = rag.batch_questions(test_questions[:3]) |
|
|
|
|
|
print("\n" + "=" * 60) |
|
print("๐ ํ
์คํธ ๊ฒฐ๊ณผ ์์ฝ") |
|
print("=" * 60) |
|
|
|
for result in results: |
|
print(f"\nQ: {result['question']}") |
|
print(f"A: {result['answer'][:100]}...") |
|
print(f"์ถ์ฒ: {len(result['source_pages'])}๊ฐ ํ์ด์ง") |
|
|
|
print("\nโ
RAG ์์คํ
ํ
์คํธ ์๋ฃ!") |
|
print("๐ก Tip: Ollama๋ฅผ ์ค์นํ๋ฉด ๋ ์ ํํ ๋ต๋ณ์ ์ป์ ์ ์์ต๋๋ค.") |
|
print(" ์ค์น: https://ollama.ai") |