import os

import httpx
import gradio as gr
from openai import OpenAI
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from fastembed import SparseTextEmbedding

API_KEY = os.environ.get("DEEPSEEK_API_KEY")
BASE_URL = "https://api.deepseek.com"
QDRANT_PATH = "./qdrant_db"
COLLECTION_NAME = "huggingface_transformers_docs"
EMBEDDING_MODEL_ID = "fyerfyer/finetune-jina-transformers-v1"
SPARSE_MODEL_ID = "prithivida/Splade_PP_en_v1"


class HFRAG:
    def __init__(self):
        # Load the dense and sparse embedding models used for hybrid retrieval.
        self.dense_model = SentenceTransformer(EMBEDDING_MODEL_ID, trust_remote_code=True)
        self.sparse_model = SparseTextEmbedding(model_name=SPARSE_MODEL_ID)

        # Remove a stale lock file left behind by a previous (crashed) process.
        lock_file = os.path.join(QDRANT_PATH, ".lock")
        if os.path.exists(lock_file):
            try:
                os.remove(lock_file)
                print("Cleaned up stale lock file.")
            except OSError:
                pass

        if not os.path.exists(QDRANT_PATH):
            raise ValueError(f"Qdrant path not found: {QDRANT_PATH}.")

        self.db_client = QdrantClient(path=QDRANT_PATH)
        if not self.db_client.collection_exists(COLLECTION_NAME):
            raise ValueError(f"Collection '{COLLECTION_NAME}' not found in Qdrant DB.")
        print("Connected to Qdrant.")

        # DeepSeek exposes an OpenAI-compatible API; disable proxy/env settings
        # so the Space's network configuration does not interfere.
        self.llm_client = OpenAI(
            api_key=API_KEY,
            base_url=BASE_URL,
            http_client=httpx.Client(proxy=None, trust_env=False),
        )

    def retrieve(self, query: str, top_k: int = 5):
        # Generate the dense query vector
        query_dense_vec = self.dense_model.encode(query).tolist()

        # Generate the sparse (SPLADE) query vector
        query_sparse_gen = list(self.sparse_model.embed([query]))[0]
        query_sparse_vec = models.SparseVector(
            indices=query_sparse_gen.indices.tolist(),
            values=query_sparse_gen.values.tolist(),
        )

        # Prefetch candidates from the dense index
        prefetch_dense = models.Prefetch(
            query=query_dense_vec,
            using="text-dense",
            limit=20,
        )

        # Prefetch candidates from the sparse index
        prefetch_sparse = models.Prefetch(
            query=query_sparse_vec,
            using="text-sparse",
            limit=20,
        )

        # Hybrid search: fuse both candidate lists with Reciprocal Rank Fusion
        results = self.db_client.query_points(
            collection_name=COLLECTION_NAME,
            prefetch=[prefetch_dense, prefetch_sparse],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=top_k,
            with_payload=True,
        ).points
        return results

    def format_context(self, search_results):
        context_pieces = []
        sources_summary = []
        for idx, hit in enumerate(search_results, 1):
            raw_source = hit.payload.get("source", "unknown")
            filename = raw_source.split("/")[-1] if "/" in raw_source else raw_source
            text = hit.payload["text"]
            score = hit.score
            sources_summary.append(f"`{filename}` (Score: {score:.2f})")
            # Wrap each chunk with its source filename so the LLM can cite it,
            # as the system prompt requires.
            piece = f'<doc index="{idx}" source="{filename}">\n{text}\n</doc>'
            context_pieces.append(piece)
        return "\n\n".join(context_pieces), sources_summary


rag_system = None


def initialize_system():
    global rag_system
    if rag_system is None:
        try:
            rag_system = HFRAG()
        except Exception as e:
            print(f"Error initializing: {e}")
            return None
    return rag_system


# ================= Gradio Logic =================
def predict(message, history):
    rag = initialize_system()
    if not rag:
        yield "āŒ System initialization failed. Check logs."
        return
    if not API_KEY:
        yield "āŒ Error: `DEEPSEEK_API_KEY` not set in Space secrets."
        return

    # 1. Retrieve
    yield "šŸ” Retrieving relevant documents..."
    results = rag.retrieve(message)
    if not results:
        yield "āš ļø No relevant documents found in the knowledge base."
        return

    # 2. Format context
    context_str, sources_list = rag.format_context(results)

    # 3. Build the prompt
    system_prompt = """You are an expert AI assistant specializing in the Hugging Face Transformers library.
Your goal is to answer the user's question based ONLY on the provided "Retrieved Context".

GUIDELINES:
1. **Code First**: Prioritize showing Python code examples.
2. **Citation**: Cite source filenames like `[model_doc.md]`.
3. **Honesty**: If the answer isn't in the context, say you don't know.
4. **Format**: Use Markdown."""

    user_prompt = f"### User Query\n{message}\n\n### Retrieved Context\n{context_str}"

    header = (
        "**šŸ“š Found relevant documents:**\n"
        + "\n".join([f"- {s}" for s in sources_list])
        + "\n\n---\n\n"
    )
    current_response = header
    yield current_response

    # 4. Generate a streamed answer with DeepSeek
    try:
        response = rag.llm_client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.1,
            stream=True,
        )
        for chunk in response:
            # Some streamed chunks (e.g. the final usage chunk) carry no choices.
            if chunk.choices and chunk.choices[0].delta.content:
                current_response += chunk.choices[0].delta.content
                yield current_response
    except Exception as e:
        yield current_response + f"\n\nāŒ LLM API Error: {e}"


demo = gr.ChatInterface(
    fn=predict,
    title="šŸ¤— Hugging Face RAG Expert",
    description="Ask me anything about Transformers! Powered by DeepSeek-V3 & Finetuned Embeddings.",
    examples=[
        "How to implement padding?",
        "How to use BERT pipeline?",
        "How to fine-tune a model using Trainer?",
        "What is the difference between padding and truncation?",
    ],
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()
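
# ---------------------------------------------------------------------------
# Note: this app only reads the index. The retrieval code above assumes a
# collection built offline with two named vectors, "text-dense" and
# "text-sparse", and payload keys "text" and "source". A minimal indexing
# sketch under those assumptions (hypothetical `chunks` list of
# {"text": ..., "source": ...} dicts; `dense_model` / `sparse_model` loaded
# from the same model IDs as above) could look like:
#
#   client = QdrantClient(path=QDRANT_PATH)
#   client.create_collection(
#       collection_name=COLLECTION_NAME,
#       vectors_config={
#           "text-dense": models.VectorParams(
#               size=dense_model.get_sentence_embedding_dimension(),
#               distance=models.Distance.COSINE,
#           )
#       },
#       sparse_vectors_config={"text-sparse": models.SparseVectorParams()},
#   )
#   for i, chunk in enumerate(chunks):
#       sparse = list(sparse_model.embed([chunk["text"]]))[0]
#       client.upsert(
#           collection_name=COLLECTION_NAME,
#           points=[
#               models.PointStruct(
#                   id=i,
#                   vector={
#                       "text-dense": dense_model.encode(chunk["text"]).tolist(),
#                       "text-sparse": models.SparseVector(
#                           indices=sparse.indices.tolist(),
#                           values=sparse.values.tolist(),
#                       ),
#                   },
#                   payload={"text": chunk["text"], "source": chunk["source"]},
#               )
#           ],
#       )
# ---------------------------------------------------------------------------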