import streamlit as st
import logging
import os
import tempfile
import shutil
import pdfplumber
import ollama
import time
import httpx

from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.chat_models import ChatOllama
from langchain_core.runnables import RunnablePassthrough
from langchain.retrievers.multi_query import MultiQueryRetriever

from typing import List, Tuple, Dict, Any, Optional

# Streamlit page configuration
st.set_page_config(
    page_title="Ollama PDF RAG Streamlit UI",
    page_icon="🎈",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger(__name__)

def ollama_list_with_retry(retries=3, delay=5):
    """Attempt to list models from Ollama with retry logic."""
    for attempt in range(retries):
        try:
            response = ollama.list()
            logger.info("Successfully retrieved model list from Ollama")
            return response
        except httpx.ConnectError as e:
            logger.error(f"Connection error: {e}. Attempt {attempt + 1} of {retries}")
            if attempt < retries - 1:
                time.sleep(delay)
            else:
                logger.error("All retry attempts failed. Cannot connect to Ollama service.")
                raise

def extract_model_names(models_info: Dict[str, List[Dict[str, Any]]]) -> Tuple[str, ...]:
    """Extract model names from the provided models information."""
    logger.info("Extracting model names from models_info")
    model_names = tuple(model["name"] for model in models_info["models"])
    logger.info(f"Extracted model names: {model_names}")
    return model_names
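
# Note: recent versions of the ollama Python client return a typed
# ListResponse whose entries expose the name under `model` rather than
# `name`. If the dict access above raises a KeyError/TypeError, something
# like `tuple(m.model for m in models_info.models)` may be needed instead
# (adjust to the client version you actually run).
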
def create_vector_db(file_upload) -> Chroma:
    """Create a vector database from an uploaded PDF file."""
    logger.info(f"Creating vector DB from file upload: {file_upload.name}")
    temp_dir = tempfile.mkdtemp()

    path = os.path.join(temp_dir, file_upload.name)
    with open(path, "wb") as f:
        f.write(file_upload.getvalue())
    logger.info(f"File saved to temporary path: {path}")

    loader = UnstructuredPDFLoader(path)
    data = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
    chunks = text_splitter.split_documents(data)
    logger.info("Document split into chunks")

    embeddings = OllamaEmbeddings(model="nomic-embed-text", show_progress=True)
    vector_db = Chroma.from_documents(
        documents=chunks, embedding=embeddings, collection_name="myRAG"
    )
    logger.info("Vector DB created")

    shutil.rmtree(temp_dir)
    logger.info(f"Temporary directory {temp_dir} removed")
    return vector_db
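
# The Chroma collection above lives in memory only and is rebuilt on every
# upload. If you want it to survive app restarts, Chroma.from_documents also
# accepts a persist_directory argument, e.g. persist_directory="./chroma_db"
# (path is illustrative).
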
def process_question(question: str, vector_db: Chroma, selected_model: str) -> str:
    """Process a user question using the vector database and selected language model."""
    logger.info(f"Processing question: {question} using model: {selected_model}")
    llm = ChatOllama(model=selected_model, temperature=0)

    QUERY_PROMPT = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI language model assistant. Your task is to generate 3
        different versions of the given user question to retrieve relevant documents from
        a vector database. By generating multiple perspectives on the user question, your
        goal is to help the user overcome some of the limitations of the distance-based
        similarity search. Provide these alternative questions separated by newlines.
        Original question: {question}""",
    )

    retriever = MultiQueryRetriever.from_llm(
        vector_db.as_retriever(), llm, prompt=QUERY_PROMPT
    )

    # RAG prompt: answer strictly from the retrieved context
    template = """Answer the question based ONLY on the following context:
    {context}
    Question: {question}
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    Only provide the answer from the {context}, nothing else.
    Add snippets of the context you used to answer the question.
    """

    prompt = ChatPromptTemplate.from_template(template)

    chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    response = chain.invoke(question)
    logger.info("Question processed and response generated")
    return response
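
# Cost note: MultiQueryRetriever makes one extra LLM call to generate the
# three question rephrasings before retrieval, so each answer costs at least
# two model invocations. If latency matters more than recall, passing
# vector_db.as_retriever() straight into the chain is a simpler alternative.
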
def extract_all_pages_as_images(file_upload) -> List[Any]:
    """Extract all pages from a PDF file as images."""
    logger.info(f"Extracting all pages as images from file: {file_upload.name}")
    with pdfplumber.open(file_upload) as pdf:
        pdf_pages = [page.to_image().original for page in pdf.pages]
    logger.info("PDF pages extracted as images")
    return pdf_pages
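
# Rendering every page eagerly keeps all page images in session state at
# once; for very large PDFs it may be worth rendering pages lazily or
# capping the number of pages converted.
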
def delete_vector_db(vector_db: Optional[Chroma]) -> None:
    """Delete the vector database and clear related session state."""
    logger.info("Deleting vector DB")
    if vector_db is not None:
        vector_db.delete_collection()
        st.session_state.pop("pdf_pages", None)
        st.session_state.pop("file_upload", None)
        st.session_state.pop("vector_db", None)
        st.success("Collection and temporary files deleted successfully.")
        logger.info("Vector DB and related session state cleared")
        st.rerun()
    else:
        st.error("No vector database found to delete.")
        logger.warning("Attempted to delete vector DB, but none was found")

def main() -> None:
    """Main function to run the Streamlit application."""
    st.subheader("🧠 Ollama PDF RAG playground", divider="gray", anchor=False)

    try:
        models_info = ollama_list_with_retry()
        available_models = extract_model_names(models_info)
    except httpx.ConnectError:
        st.error("Could not connect to the Ollama service. Please check your setup and try again.")
        return

    col1, col2 = st.columns([1.5, 2])

    if "messages" not in st.session_state:
        st.session_state["messages"] = []

    if "vector_db" not in st.session_state:
        st.session_state["vector_db"] = None