from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames
from langchain_ibm import WatsonxLLM, WatsonxEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain.chains import RetrievalQA
import gradio as gr

# You can use this section to suppress warnings generated by your code:
def warn(*args, **kwargs):
    pass

import warnings
warnings.warn = warn
warnings.filterwarnings('ignore')


## LLM
def get_llm():
    """Instantiate the watsonx.ai Mixtral model used to generate answers."""
    model_id = 'mistralai/mixtral-8x7b-instruct-v01'
    parameters = {
        GenParams.MAX_NEW_TOKENS: 256,
        GenParams.TEMPERATURE: 0.5,
    }
    project_id = "skills-network"
    watsonx_llm = WatsonxLLM(
        model_id=model_id,
        url="https://us-south.ml.cloud.ibm.com",
        project_id=project_id,
        params=parameters,
    )
    return watsonx_llm


## Document loader
def document_loader(file):
    """Load the uploaded PDF into LangChain documents (one per page).

    `file` is a filepath string, because the gr.File component below is
    configured with type="filepath".
    """
    loader = PyPDFLoader(file)
    loaded_document = loader.load()
    return loaded_document


## Text splitter
def text_splitter(data):
    """Split the loaded documents into overlapping chunks for retrieval."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=50,
        length_function=len,
    )
    chunks = splitter.split_documents(data)
    return chunks


## Vector db
def vector_database(chunks):
    """Embed the chunks and index them in an in-memory Chroma store."""
    embedding_model = watsonx_embedding()
    vectordb = Chroma.from_documents(chunks, embedding_model)
    return vectordb


## Embedding model
def watsonx_embedding():
    """Instantiate the watsonx.ai Slate embedding model."""
    embed_params = {
        # Truncate each input to at most 3 tokens before embedding.
        EmbedTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: 3,
        EmbedTextParamsMetaNames.RETURN_OPTIONS: {"input_text": True},
    }
    watsonx_embedding = WatsonxEmbeddings(
        model_id="ibm/slate-125m-english-rtrvr",
        url="https://us-south.ml.cloud.ibm.com",
        project_id="skills-network",
        params=embed_params,
    )
    return watsonx_embedding


## Retriever
def retriever(file):
    """Build a retriever over the uploaded PDF: load, split, embed, index."""
    splits = document_loader(file)
    chunks = text_splitter(splits)
    vectordb = vector_database(chunks)
    return vectordb.as_retriever()


## QA Chain
def retriever_qa(file, query):
    """Answer `query` against the uploaded PDF with a RetrievalQA chain."""
    llm = get_llm()
    retriever_obj = retriever(file)
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # stuff all retrieved chunks into a single prompt
        retriever=retriever_obj,
        return_source_documents=False,
    )
    response = qa.invoke(query)
    return response['result']


# Create Gradio interface
rag_application = gr.Interface(
    fn=retriever_qa,
    allow_flagging="never",
    inputs=[
        # Drag and drop file upload; type="filepath" passes the path as a str
        gr.File(label="Upload PDF File", file_count="single",
                file_types=['.pdf'], type="filepath"),
        gr.Textbox(label="Input Query", lines=2,
                   placeholder="Type your question here..."),
    ],
    outputs=gr.Textbox(label="Output"),
    title="RAG Chatbot",
    description="Upload a PDF document and ask any question. The chatbot will try to answer using the provided document.",
)

# Launch the app
rag_application.launch(server_name="127.0.0.1", server_port=7862, share=True)
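
# ---------------------------------------------------------------------------
# Optional: querying the pipeline without the Gradio UI.
# A minimal sketch, not part of the lab flow; "sample.pdf" and ask_pdf are
# placeholders introduced here for illustration. Note that langchain_ibm
# reads the WATSONX_APIKEY environment variable when no API key is passed
# explicitly, so outside the Skills Network environment you would typically
# export that variable before running this script. Kept commented out because
# launch() above blocks until the server is stopped.
#
#   def ask_pdf(path, question):
#       qa = RetrievalQA.from_chain_type(
#           llm=get_llm(),
#           chain_type="stuff",
#           retriever=retriever(path),
#       )
#       return qa.invoke(question)["result"]
#
#   print(ask_pdf("sample.pdf", "What is this document about?"))
# ---------------------------------------------------------------------------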