import os
import shutil

from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOllama
from langchain.document_loaders import GitLoader
from langchain.embeddings import OllamaEmbeddings
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import DocArrayInMemorySearch


def loader(url: str, branch: str, file_filter: str):
    """Clone the repository at `url` and load every file whose path ends with
    one of the extensions in `file_filter` (comma-separated, e.g. ".py,.md")."""
    repo_path = "./github_repo"
    # Remove any stale clone left over from a previous run.
    if os.path.exists(repo_path):
        shutil.rmtree(repo_path)

    git_loader = GitLoader(
        clone_url=url,
        repo_path=repo_path,
        branch=branch,
        file_filter=lambda file_path: file_path.endswith(tuple(file_filter.split(","))),
    )

    data = git_loader.load()
    return data


def split_data(data):
    """Split the loaded documents into overlapping chunks for embedding."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=150,
        length_function=len,
        add_start_index=True,
    )
    chunks = splitter.split_documents(data)
    return chunks


def ingest_chunks(chunks):
    """Embed the chunks and load them into an in-memory vector store."""
    embedding = OllamaEmbeddings(
        base_url="https://thewise-ollama-server.hf.space",
        model="nomic-embed-text",
    )
    vector_store = DocArrayInMemorySearch.from_documents(chunks, embedding)

    # The clone is no longer needed once its contents are embedded.
    repo_path = "./github_repo"
    if os.path.exists(repo_path):
        shutil.rmtree(repo_path)

    return vector_store


def retrieval(vector_store, k):
    """Build a conversational retrieval chain over the vector store."""
    llm = ChatOllama(
        base_url="https://thewise-ollama-server.hf.space",
        model="codellama",
        callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    )

    system_template = """You're a code summarisation assistant. Given the following extracted parts of a long document as "CONTEXT", create a final answer.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
Only if asked to create a "DIAGRAM" for code, use "MERMAID SYNTAX LANGUAGE" in your answer from "CONTEXT" and "CHAT HISTORY", with a short explanation of the diagram.
CONTEXT: {context}
=======
CHAT HISTORY: {chat_history}
=======
FINAL ANSWER:"""

    human_template = """{question}"""

    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template(human_template),
    ]

    prompt = ChatPromptTemplate.from_messages(messages)

    # Keep only the last five exchanges so the prompt stays within context limits.
    memory = ConversationBufferWindowMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        return_messages=True,
        k=5,
    )

    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": k})

    chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        memory=memory,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": prompt},
    )

    return chain


class ConversationalResponse:
    """End-to-end pipeline: clone a repo, index its files, and answer questions about them."""

    def __init__(self, url, branch, file_filter):
        self.url = url
        self.branch = branch
        self.file_filter = file_filter
        self.data = loader(self.url, self.branch, self.file_filter)
        self.chunks = split_data(self.data)
        self.vector_store = ingest_chunks(self.chunks)
        self.k = 10
        self.chain = retrieval(self.vector_store, self.k)

    def __call__(self, question):
        response = self.chain(question)
        return response['answer']
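

# Example usage (a minimal sketch, not part of the pipeline above): the repo URL,
# branch, and question below are placeholders, and the call assumes the Ollama
# server configured above is reachable. Repeated calls to the same instance carry
# the conversation forward through the chain's window memory.
if __name__ == "__main__":
    qa = ConversationalResponse(
        url="https://github.com/example/example-repo",  # hypothetical repository
        branch="main",
        file_filter=".py,.md",
    )
    print(qa("What does this repository do?"))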