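# Gradio Space: RAG assistant for the Hugging Face Transformers documentation.
# The user query is embedded with a fine-tuned sentence-transformers model,
# matching chunks are retrieved from a local Qdrant index, and the answer is
# streamed from DeepSeek via its OpenAI-compatible API.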
import os

import httpx
import gradio as gr
from openai import OpenAI
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

API_KEY = os.environ.get('DEEPSEEK_API_KEY')
BASE_URL = "https://api.deepseek.com"
QDRANT_PATH = "./qdrant_db"
COLLECTION_NAME = "huggingface_transformers_docs"
EMBEDDING_MODEL_ID = "fyerfyer/finetune-jina-transformers-v1"

class HFRAG:
    def __init__(self):
        self.embed_model = SentenceTransformer(EMBEDDING_MODEL_ID, trust_remote_code=True)

        # Remove a stale Qdrant lock file left behind by a previous run.
        lock_file = os.path.join(QDRANT_PATH, ".lock")
        if os.path.exists(lock_file):
            try:
                os.remove(lock_file)
                print("Cleaned up stale lock file.")
            except OSError:
                pass

        if not os.path.exists(QDRANT_PATH):
            raise ValueError(f"Qdrant path not found: {QDRANT_PATH}.")
        self.db_client = QdrantClient(path=QDRANT_PATH)
        if not self.db_client.collection_exists(COLLECTION_NAME):
            raise ValueError(f"Collection '{COLLECTION_NAME}' not found in Qdrant DB.")
        print("Connected to Qdrant.")

        # DeepSeek exposes an OpenAI-compatible API; bypass any ambient proxy settings.
        self.llm_client = OpenAI(
            api_key=API_KEY,
            base_url=BASE_URL,
            http_client=httpx.Client(proxy=None, trust_env=False)
        )

    def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.40):
        query_vector = self.embed_model.encode(query).tolist()
        # Support both qdrant-client APIs: the legacy `search` and the newer `query_points`.
        if hasattr(self.db_client, 'search'):
            results = self.db_client.search(
                collection_name=COLLECTION_NAME,
                query_vector=query_vector,
                limit=top_k,
                score_threshold=score_threshold
            )
        else:
            results = self.db_client.query_points(
                collection_name=COLLECTION_NAME,
                query=query_vector,
                limit=top_k,
                with_payload=True,
                score_threshold=score_threshold
            ).points
        return results
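
    # Turn Qdrant hits into an XML-style context block for the prompt, plus a
    # "filename (score)" list shown to the user before the answer streams.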
    def format_context(self, search_results):
        context_pieces = []
        sources_summary = []
        for idx, hit in enumerate(search_results, 1):
            raw_source = hit.payload['metadata']['source']
            filename = raw_source.split('/')[-1]
            text = hit.payload['text']
            score = hit.score
            sources_summary.append(f"`{filename}` (Score: {score:.2f})")
            piece = f"""<doc id="{idx}" source="{filename}">\n{text}\n</doc>"""
            context_pieces.append(piece)
        return "\n\n".join(context_pieces), sources_summary
rag_system = None

def initialize_system():
    global rag_system
    if rag_system is None:
        try:
            rag_system = HFRAG()
        except Exception as e:
            print(f"Error initializing: {e}")
            return None
    return rag_system

# ================= Gradio Logic =================
def predict(message, history):
    rag = initialize_system()
    if not rag:
        yield "❌ System initialization failed. Check logs."
        return
    if not API_KEY:
        yield "❌ Error: `DEEPSEEK_API_KEY` not set in Space secrets."
        return

    # 1. Retrieve
    yield "🔍 Retrieving relevant documents..."
    results = rag.retrieve(message)
    if not results:
        yield "⚠️ No relevant documents found in the knowledge base."
        return

    # 2. Format context
    context_str, sources_list = rag.format_context(results)

    # 3. Build prompt
| system_prompt = """You are an expert AI assistant specializing in the Hugging Face Transformers library. | |
| Your goal is to answer the user's question based ONLY on the provided "Retrieved Context". | |
| GUIDELINES: | |
| 1. **Code First**: Prioritize showing Python code examples. | |
| 2. **Citation**: Cite source filenames like `[model_doc.md]`. | |
| 3. **Honesty**: If the answer isn't in the context, say you don't know. | |
| 4. **Format**: Use Markdown.""" | |
| user_prompt = f"""### User Query\n{message}\n\n### Retrieved Context\n{context_str}""" | |
| header = "**π Found relevant documents:**\n" + "\n".join([f"- {s}" for s in sources_list]) + "\n\n---\n\n" | |
| current_response = header | |
| yield current_response | |
    # 4. Stream the answer from DeepSeek
    try:
        response = rag.llm_client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.1,
            stream=True
        )
        for chunk in response:
            if chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content
                current_response += content
                yield current_response
    except Exception as e:
        yield current_response + f"\n\n❌ LLM API Error: {str(e)}"

demo = gr.ChatInterface(
    fn=predict,
    title="🤗 Hugging Face RAG Expert",
    description="Ask me anything about Transformers! Powered by DeepSeek-V3 & Finetuned Embeddings.",
    examples=[
        "How to implement padding?",
        "How to use BERT pipeline?",
        "How to fine-tune a model using Trainer?",
        "What is the difference between padding and truncation?"
    ],
    theme="soft"
)

if __name__ == "__main__":
    demo.launch()