import datetime
import os
import asyncio
from enum import Enum
from typing import List, AsyncIterable, Annotated, Optional

from dotenv import load_dotenv
from fastapi import FastAPI, Body, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from langchain_openai import ChatOpenAI
from langchain import hub
from langchain_chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document
from langchain_nomic.embeddings import NomicEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.callbacks import AsyncIteratorCallbackHandler

from in_memory import load_all_documents
from loader import load_web_content, load_youtube_content
from get_pattern import generate_pattern
from get_agents import process_agents
# ################################### FastAPI setup ############################################
app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
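# Wide-open CORS (origins=["*"]) is convenient during development; restrict
# `origins` to the actual front-end hosts before deploying.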
# ################################### LLM, RAG and Streaming ############################################
load_dotenv()

GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
GROQ_API_BASE = os.environ.get("GROQ_API_BASE")
OPENAI_MODEL_NAME = os.environ.get("OPENAI_MODEL_NAME")

embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5")
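# Note: NomicEmbeddings calls Nomic's hosted embedding API, which typically
# requires NOMIC_API_KEY in the environment (or a prior `nomic login`).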
def split_documents(documents: List[Document], chunk_size=1000, chunk_overlap=200) -> List[Document]:
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    print("Splitting documents into chunks...")
    return text_splitter.split_documents(documents)
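# Usage sketch (hedged, not part of the app flow): chunking one in-memory document.
#   demo_docs = [Document(page_content="some long text ...", metadata={"source": "demo"})]
#   demo_chunks = split_documents(demo_docs, chunk_size=500, chunk_overlap=50)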
def generate_embeddings(documents: List[Document]) -> List[List[float]]:
    # Embed each chunk's text with Nomic; returns one vector per document.
    # (Currently unused: the routes below hand `embedding_model` to Chroma directly,
    # and Chroma computes embeddings itself.)
    model = NomicEmbeddings(model="nomic-embed-text-v1.5")
    return [
        model.embed([document.page_content], task_type="search_document")[0]
        for document in documents
    ]
def store_embeddings(documents: List[Document], embeddings: NomicEmbeddings) -> Chroma:
    vectorstore = Chroma.from_documents(
        documents=documents, embedding=embeddings, persist_directory="./chroma_db"
    )
    return vectorstore
def load_embeddings(embeddings: NomicEmbeddings) -> Chroma:
    vectorstore = Chroma(
        persist_directory="./chroma_db", embedding_function=embeddings
    )
    return vectorstore
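# Usage sketch (hedged, assumes ./chroma_db was populated by a prior upload):
#   store = load_embeddings(embeddings=embedding_model)
#   hits = store.similarity_search("what does the document say about X?", k=4)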
# ################################### Updated generate_chunks Function ############################################
async def generate_chunks(query: str) -> AsyncIterable[str]:
    callback = AsyncIteratorCallbackHandler()
    llm = ChatOpenAI(
        openai_api_base=GROQ_API_BASE,
        api_key=GROQ_API_KEY,
        temperature=0.0,
        model_name=OPENAI_MODEL_NAME,  # e.g. "mixtral-8x7b-32768"
        streaming=True,  # ! important
        verbose=True,
        callbacks=[callback],
    )

    # Load the vector store (should be pre-populated with documents and embeddings).
    vectorstore = load_embeddings(embeddings=embedding_model)
    retriever = vectorstore.as_retriever()

    # Combine the retrieved documents into a single context string.
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Define the RAG chain.
    prompt = hub.pull("rlm/rag-prompt")
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Run the chain in the background and yield tokens as the callback emits them.
    task = asyncio.create_task(rag_chain.ainvoke(query))
    index = 0
    try:
        async for token in callback.aiter():
            print(index, ": ", token, ": ", datetime.datetime.now().time())
            index += 1
            yield token
    except Exception as e:
        print(f"Caught exception: {e}")
    finally:
        callback.done.set()
        await task
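# Consuming the stream outside FastAPI (hedged sketch):
#   async def demo():
#       async for token in generate_chunks("What is this document about?"):
#           print(token, end="", flush=True)
#   asyncio.run(demo())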
# ################################### Models ########################################
class QuestionType(str, Enum):
    PATTERN = "PATTERN"
    AGENTS = "AGENTS"
    RAG = "RAG"


class Input(BaseModel):
    question: str
    type: QuestionType
    pattern: Optional[str] = None  # only used for PATTERN queries
    chat_history: List[str]


class Metadata(BaseModel):
    conversation_id: str


class Config(BaseModel):
    metadata: Metadata


class RequestBody(BaseModel):
    input: Input
    config: Config
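# Example request body for the chat route (shape implied by the models above):
# {
#   "input": {"question": "Summarize the upload", "type": "RAG",
#             "pattern": null, "chat_history": []},
#   "config": {"metadata": {"conversation_id": "abc123"}}
# }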
# ################################### Routes ############################################
@app.get("/")  # route decorator assumed
def read_root():
    return {"Hello": "World from Marigen"}


@app.post("/chat")  # route path assumed
async def chat(query: RequestBody = Body(...)):
    print(query.input.question)
    print(query.input.type)
    if query.input.type == QuestionType.PATTERN:
        print(query.input.pattern)
        pattern = query.input.pattern
        gen = generate_pattern(pattern=pattern, query=query.input.question)
        return StreamingResponse(gen, media_type="text/event-stream")
    elif query.input.type == QuestionType.AGENTS:
        gen = process_agents(query.input.question)
        return StreamingResponse(gen, media_type="text/event-stream")
    elif query.input.type == QuestionType.RAG:
        gen = generate_chunks(query.input.question)
        return StreamingResponse(gen, media_type="text/event-stream")
    raise HTTPException(status_code=400, detail="No accurate response for your given query")
@app.post("/uploadfiles/")  # route path assumed
async def create_upload_files(
    files: Annotated[List[UploadFile], File(description="Multiple files as UploadFile")],
):
    try:
        # Load documents from the uploaded files (load_all_documents is assumed
        # to return a flat List[Document]).
        documents = await load_all_documents(files)
        print(f"Loaded {len(documents)} documents")
        # Split documents into chunks.
        chunks = split_documents(documents)
        print(f"Split into {len(chunks)} chunks")
        # Store embeddings in the vector store.
        store_embeddings(chunks, embedding_model)
        print("Embeddings stored in vector store")
        return {
            "filenames": [file.filename for file in files],
            "chunks": chunks,
            "message": "Files processed and embeddings generated.",
        }
    except Exception as e:
        print(f"Error loading documents: {e}")
        return {"message": f"Error loading documents: {e}"}
# New routes for YouTube and website content loading
@app.post("/load_youtube/")  # route path assumed
async def load_youtube(youtube_url: str):
    try:
        documents = load_youtube_content(youtube_url)
        chunks = split_documents(documents)
        store_embeddings(chunks, embedding_model)
        return {"message": "YouTube video loaded and processed successfully.", "documents": documents}
    except Exception as e:
        print(f"Error loading YouTube video: {e}")
        return {"message": f"Error loading YouTube video: {e}"}


@app.post("/load_website/")  # route path assumed
async def load_website(website_url: str):
    try:
        documents = load_web_content(website_url)
        chunks = split_documents(documents)
        store_embeddings(chunks, embedding_model)
        return {"message": "Website loaded and processed successfully.", "documents": documents}
    except Exception as e:
        print(f"Error loading website: {e}")
        return {"message": f"Error loading website: {e}"}
@app.get("/query/")  # route path assumed
async def query_vector_store(query: str):
    # Load the persisted vector store and retrieve the most relevant chunks.
    vectorstore = load_embeddings(embeddings=embedding_model)
    relevant_docs = vectorstore.similarity_search(query)
    return {"query": query, "results": [doc.page_content for doc in relevant_docs]}