"""
optimized_rag_proper.py
LLM์„ ์œ ์ง€ํ•˜๋ฉด์„œ ์‘๋‹ต์‹œ๊ฐ„์„ ๊ฐœ์„ ํ•˜๋Š” ์˜ฌ๋ฐ”๋ฅธ ๋ฐฉ๋ฒ•
๋ชฉํ‘œ: 1์ดˆ ์ด๋‚ด (ํ˜„์‹ค์  ๋ชฉํ‘œ)
"""
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from typing import List, Dict, Any
import os
import time
import asyncio
from concurrent.futures import ThreadPoolExecutor
class ProperlyOptimizedRAG:
    """
    The right way to optimize while keeping the LLM.
    """
    def __init__(self, vector_store: FAISS):
        self.vector_store = vector_store
        # 1. Use a faster model (gpt-3.5-turbo = stable version)
        self.llm = ChatOpenAI(
            model="gpt-3.5-turbo",  # stable model
            temperature=0.1,        # consistency
            max_tokens=300,         # cap answer length to improve speed
            api_key=os.getenv("OPENAI_API_KEY"),
            request_timeout=5
        )
        # 2. Optimized, short prompt
        self.prompt_template = self._create_minimal_prompt()
        # 3. Simple TTL cache (a plain dict; not an LRU, no size bound)
        self.cache = {}  # {question_hash: (answer, timestamp)}
        self.cache_ttl = 3600  # 1 hour
        # 4. Thread pool for offloading the vector search
        self.executor = ThreadPoolExecutor(max_workers=2)
    def _create_minimal_prompt(self) -> PromptTemplate:
        """
        Minimal prompt (keeps the token count down).
        """
        # A short, clear prompt = faster processing
        template = """매뉴얼: {context}
질문: {question}
매뉴얼 내용만으로 간단명료하게 답변:"""
        return PromptTemplate(
            template=template,
            input_variables=["context", "question"]
        )
    def answer_question(self, question: str) -> Dict[str, Any]:
        """
        Optimized answer generation (still uses the LLM).
        """
        start_time = time.time()
        # 1. Check the cache first (a few milliseconds)
        cache_key = hash(question)
        if cache_key in self.cache:
            cached_answer, cached_time = self.cache[cache_key]
            if time.time() - cached_time < self.cache_ttl:
                return {
                    "question": question,
                    "answer": cached_answer['answer'],
                    "source_pages": cached_answer['pages'],
                    "response_time": time.time() - start_time,  # cache hits return in milliseconds
                    "cached": True
                }
        # 2. Submit the vector search to the shared thread pool.
        #    (Do not wrap self.executor in a `with` block here: that would shut
        #    the pool down after the first call and break later requests.)
        future_search = self.executor.submit(
            self._fast_vector_search,
            question
        )
        # Wait for the search result; fall back instead of raising on timeout
        try:
            search_results = future_search.result(timeout=0.5)
        except Exception:
            search_results = []
        if not search_results:
            return self._fallback_response(question, start_time)
        # 3. Optimize the context (deduplicate, trim)
        context = self._optimize_context(search_results)
        # 4. Call the LLM (blocking invoke; see the optional streaming sketch
        #    after _fast_llm_call for a lower time-to-first-token variant)
        answer = self._fast_llm_call(question, context)
        # 5. Extract page info
        pages = list(set([
            doc.metadata.get('page', 0)
            for doc in search_results
        ]))[:3]
        response = {
            "question": question,
            "answer": answer,
            "source_pages": sorted(pages),
            "response_time": time.time() - start_time,
            "cached": False
        }
        # 6. Store in the cache
        self.cache[cache_key] = (
            {"answer": answer, "pages": pages},
            time.time()
        )
        return response
    def _fast_vector_search(self, question: str, k: int = 3) -> List[Document]:
        """
        Fast vector search (small k, plain similarity search instead of MMR).
        """
        try:
            # Plain similarity_search keeps the call simple; scores are not needed here
            docs = self.vector_store.similarity_search(
                question,
                k=k,        # retrieve only 3 chunks
                fetch_k=k   # fetch_k only matters when filtering; kept equal to k
            )
            return docs
        except Exception as e:
            print(f"Vector search error: {e}")
            return []
    def _optimize_context(self, docs: List[Document]) -> str:
        """
        Optimize the context (deduplicate, keep only the essentials).
        """
        seen_content = set()
        optimized = []
        for doc in docs:
            content = doc.page_content.strip()
            # Duplicate check using the first 50 characters
            content_hash = hash(content[:50])
            if content_hash in seen_content:
                continue
            seen_content.add(content_hash)
            # Skip chunks that are too short; truncate chunks that are too long
            if len(content) < 50:
                continue
            if len(content) > 500:
                content = content[:500]  # cap at 500 characters
            optimized.append(content)
        # Use at most 3 chunks (limits the token count)
        return "\n---\n".join(optimized[:3])
    def _fast_llm_call(self, question: str, context: str) -> str:
        """
        Fast LLM call.
        """
        try:
            # Build the prompt
            prompt = self.prompt_template.format(
                context=context,
                question=question
            )
            # Call the LLM (single blocking invoke)
            response = self.llm.invoke(prompt)
            # Extract the response text
            if hasattr(response, 'content'):
                return response.content
            else:
                return str(response)
        except Exception as e:
            print(f"LLM call error: {e}")
            return "답변 생성 중 오류가 발생했습니다."
def _fallback_response(self, question: str, start_time: float) -> Dict:
"""
ํด๋ฐฑ ์‘๋‹ต (๋ฒกํ„ฐ ๊ฒ€์ƒ‰ ์‹คํŒจ ์‹œ)
"""
return {
"question": question,
"answer": "ํ•ด๋‹น ์ •๋ณด๋ฅผ ๋งค๋‰ด์–ผ์—์„œ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๋‹ค๋ฅธ ํ‘œํ˜„์œผ๋กœ ์งˆ๋ฌธํ•ด ์ฃผ์‹œ๊ฑฐ๋‚˜, ์„œ๋น„์Šค์„ผํ„ฐ(1577-0001)๋กœ ๋ฌธ์˜ํ•ด ์ฃผ์„ธ์š”.",
"source_pages": [],
"response_time": time.time() - start_time,
"cached": False
}
    def batch_test(self, questions: List[str]) -> Dict[str, Any]:
        """
        Batch test and performance measurement.
        """
        results = []
        times = []
        print("\n" + "="*60)
        print("🚀 Optimized RAG performance test (LLM kept)")
        print("="*60)
        for i, question in enumerate(questions, 1):
            result = self.answer_question(question)
            results.append(result)
            times.append(result['response_time'])
            status = "📦 cache" if result.get('cached') else "🔍 search"
            print(f"\n[{i}] {question}")
            print(f"⏱️ {result['response_time']:.2f}s ({status})")
            print(f"📄 Pages: {result['source_pages']}")
            print(f"💬 {result['answer'][:150]}...")
        # Statistics
        avg_time = sum(times) / len(times)
        cached_count = sum(1 for r in results if r.get('cached'))
        print("\n" + "="*60)
        print("📊 Performance stats:")
        print(f"  Average response: {avg_time:.2f}s")
        print(f"  Min/max: {min(times):.2f}s / {max(times):.2f}s")
        print(f"  Cache hits: {cached_count}/{len(questions)}")
        print(f"  Target met: {'✅' if avg_time < 1.0 else '⚠️ missed the 1-second target'}")
        print("="*60)
return {
"average": avg_time,
"min": min(times),
"max": max(times),
"results": results
}
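    # --- Optional sketch (not part of the original design) ---
    # The dict cache above is unbounded and never drops expired entries
    # (they are only overwritten when the same question recurs).  A minimal
    # eviction helper, assuming the {key: (payload, timestamp)} layout used
    # in answer_question; call it periodically in a long-lived process.
    def _evict_expired(self) -> None:
        """Drop cache entries older than the TTL (hypothetical helper)."""
        now = time.time()
        expired_keys = [k for k, (_, ts) in self.cache.items() if now - ts >= self.cache_ttl]
        for k in expired_keys:
            del self.cache[k]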
# Extra: an async version (faster)
class AsyncOptimizedRAG:
    """
    Faster responses via asynchronous processing (experimental).
    """
def __init__(self, vector_store: FAISS):
self.vector_store = vector_store
self.llm = ChatOpenAI(
model="gpt-3.5-turbo-1106",
temperature=0.1,
max_tokens=300,
api_key=os.getenv("OPENAI_API_KEY")
)
self.cache = {}
    async def answer_question_async(self, question: str) -> Dict[str, Any]:
        """
        Asynchronous answer generation.
        """
        start_time = time.time()
        # Run the blocking vector search off the event loop
        docs = await self._async_vector_search(question)
        if not docs:
            return {
                "question": question,
                "answer": "정보를 찾을 수 없습니다.",
                "source_pages": [],
                "response_time": time.time() - start_time
            }
        # Build a compact context and call the LLM
        context = "\n".join([d.page_content[:300] for d in docs[:3]])
        answer = await self._async_llm_call(question, context)
return {
"question": question,
"answer": answer,
"source_pages": [d.metadata.get('page', 0) for d in docs],
"response_time": time.time() - start_time
}
async def _async_vector_search(self, question: str) -> List[Document]:
"""๋น„๋™๊ธฐ ๋ฒกํ„ฐ ๊ฒ€์ƒ‰"""
return await asyncio.to_thread(
self.vector_store.similarity_search,
question,
k=3
)
async def _async_llm_call(self, question: str, context: str) -> str:
"""๋น„๋™๊ธฐ LLM ํ˜ธ์ถœ"""
prompt = f"๋งค๋‰ด์–ผ: {context}\n์งˆ๋ฌธ: {question}\n๋‹ต๋ณ€:"
response = await asyncio.to_thread(
self.llm.invoke,
prompt
)
return response.content if hasattr(response, 'content') else str(response)
# Main test
if __name__ == "__main__":
    from embeddings import VehicleManualEmbeddings
    # Check the API key
    if not os.getenv("OPENAI_API_KEY"):
        print("Enter your OpenAI API key:")
        api_key = input("sk-... : ").strip()
        os.environ["OPENAI_API_KEY"] = api_key
    # Load the vector store
    print("Loading vector index...")
    embedder = VehicleManualEmbeddings()
    vector_store = embedder.load_index()
    # Optimized RAG system
    rag = ProperlyOptimizedRAG(vector_store)
    # Test questions
    test_questions = [
        "엔진오일 교체 주기는?",
        "타이어 공기압은 얼마?",
        "경고등이 켜졌을 때 어떻게 해야 하나요?",
        "눈길 주행 시 주의사항",
        "운전자 보조 시스템 설정"
    ]
    # First run (cold start)
    print("\n### Run 1 (no cache) ###")
    stats1 = rag.batch_test(test_questions)
    # Second run (cache hits)
    print("\n### Run 2 (using the cache) ###")
    stats2 = rag.batch_test(test_questions[:3])
    # Compute the improvement
    improvement = ((stats1['average'] - stats2['average']) / stats1['average']) * 100
    print(f"\n🎯 Cache effect: {improvement:.1f}% faster")