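"""Gradio Space app: RAG chat over the Hugging Face Transformers docs.

Embeds each user query with a fine-tuned sentence-transformers model,
retrieves matching chunks from a local Qdrant collection, and streams a
grounded answer from the DeepSeek chat API.
"""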
import os
import httpx
import gradio as gr
from openai import OpenAI
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

API_KEY = os.environ.get('DEEPSEEK_API_KEY')
BASE_URL = "https://api.deepseek.com"
QDRANT_PATH = "./qdrant_db"
COLLECTION_NAME = "huggingface_transformers_docs"
EMBEDDING_MODEL_ID = "fyerfyer/finetune-jina-transformers-v1"


class HFRAG:
    def __init__(self):
        # The embedding model must match the one used to build the index.
        self.embed_model = SentenceTransformer(EMBEDDING_MODEL_ID, trust_remote_code=True)

        # Local-mode Qdrant refuses to open a directory that still holds a
        # lock file from a previous (crashed) process, so clear it first.
        lock_file = os.path.join(QDRANT_PATH, ".lock")
        if os.path.exists(lock_file):
            try:
                os.remove(lock_file)
                print("Cleaned up stale lock file.")
            except OSError:
                pass

        if not os.path.exists(QDRANT_PATH):
            raise ValueError(f"Qdrant path not found: {QDRANT_PATH}.")
        self.db_client = QdrantClient(path=QDRANT_PATH)
        if not self.db_client.collection_exists(COLLECTION_NAME):
            raise ValueError(f"Collection '{COLLECTION_NAME}' not found in Qdrant DB.")
        print("Connected to Qdrant.")

        # trust_env=False keeps httpx from picking up proxy settings from
        # environment variables, which it would otherwise honor by default.
        self.llm_client = OpenAI(
            api_key=API_KEY,
            base_url=BASE_URL,
            http_client=httpx.Client(proxy=None, trust_env=False),
        )

    def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.40):
        query_vector = self.embed_model.encode(query).tolist()
        # Support both qdrant-client generations: older releases expose
        # `search`, newer ones replace it with `query_points`.
        if hasattr(self.db_client, 'search'):
            results = self.db_client.search(
                collection_name=COLLECTION_NAME,
                query_vector=query_vector,
                limit=top_k,
                score_threshold=score_threshold,
            )
        else:
            results = self.db_client.query_points(
                collection_name=COLLECTION_NAME,
                query=query_vector,
                limit=top_k,
                with_payload=True,
                score_threshold=score_threshold,
            ).points
        return results
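
    # A minimal sketch of exercising retrieval on its own (illustrative query;
    # assumes the local Qdrant directory and collection above already exist):
    #
    #   rag = HFRAG()
    #   for hit in rag.retrieve("How do I pad inputs?", top_k=3):
    #       print(f"{hit.score:.2f}", hit.payload["metadata"]["source"])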

    def format_context(self, search_results):
        # Each hit's payload is expected to carry the chunk text plus the
        # source file it came from:
        #   {"text": "...", "metadata": {"source": ".../model_doc.md"}}
        context_pieces = []
        sources_summary = []
        for idx, hit in enumerate(search_results, 1):
            raw_source = hit.payload['metadata']['source']
            filename = raw_source.split('/')[-1]
            text = hit.payload['text']
            score = hit.score
            sources_summary.append(f"`{filename}` (Score: {score:.2f})")
            piece = f"""<doc id="{idx}" source="{filename}">\n{text}\n</doc>"""
            context_pieces.append(piece)
        return "\n\n".join(context_pieces), sources_summary


# Module-level singleton so the model and DB are loaded only once per process.
rag_system = None


def initialize_system():
    global rag_system
    if rag_system is None:
        try:
            rag_system = HFRAG()
        except Exception as e:
            print(f"Error initializing: {e}")
            return None
    return rag_system


# ================= Gradio Logic =================
def predict(message, history):
    rag = initialize_system()
    if not rag:
        yield "❌ System initialization failed. Check logs."
        return
    if not API_KEY:
        yield "❌ Error: `DEEPSEEK_API_KEY` not set in Space secrets."
        return

    # 1. Retrieve
    yield "🔍 Retrieving relevant documents..."
    results = rag.retrieve(message)
    if not results:
        yield "⚠️ No relevant documents found in the knowledge base."
        return

    # 2. Format context
    context_str, sources_list = rag.format_context(results)

    # 3. Build Prompt
    system_prompt = """You are an expert AI assistant specializing in the Hugging Face Transformers library.
Your goal is to answer the user's question based ONLY on the provided "Retrieved Context".

GUIDELINES:
1. **Code First**: Prioritize showing Python code examples.
2. **Citation**: Cite source filenames like `[model_doc.md]`.
3. **Honesty**: If the answer isn't in the context, say you don't know.
4. **Format**: Use Markdown."""
    user_prompt = f"""### User Query\n{message}\n\n### Retrieved Context\n{context_str}"""

    # Show the matched sources up front; the streamed answer follows below.
    header = "**📚 Found relevant documents:**\n" + "\n".join([f"- {s}" for s in sources_list]) + "\n\n---\n\n"
    current_response = header
    yield current_response

    try:
        response = rag.llm_client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.1,
            stream=True,
        )
        # Accumulate streamed deltas and re-yield the full message so the
        # chat window updates incrementally.
        for chunk in response:
            if chunk.choices[0].delta.content:
                current_response += chunk.choices[0].delta.content
                yield current_response
    except Exception as e:
        yield current_response + f"\n\n❌ LLM API Error: {e}"


demo = gr.ChatInterface(
    fn=predict,
    title="🤗 Hugging Face RAG Expert",
    description="Ask me anything about Transformers! Powered by DeepSeek-V3 & fine-tuned embeddings.",
    examples=[
        "How do I implement padding?",
        "How do I use a BERT pipeline?",
        "How do I fine-tune a model using Trainer?",
        "What is the difference between padding and truncation?",
    ],
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()