import gradio as gr
import base64
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
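# NOTE: these import paths target the pre-0.1 LangChain API; on newer releases
# the same classes live under langchain_community (e.g.
# langchain_community.vectorstores.Chroma), so pin langchain accordingly.
# The vector store and embeddings below also assume chromadb and
# sentence-transformers are installed.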
# === Load and Embed Documents ===
loader = DirectoryLoader(
    "courses",
    glob="**/*.txt",
    loader_cls=TextLoader
)
raw_docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=700,
    chunk_overlap=100,
    separators=["\n###", "\n##", "\n\n", "\n", ".", " "]
)
docs = text_splitter.split_documents(raw_docs)

embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = Chroma.from_documents(docs, embedding=embedding_model)
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 4})
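# Each user query is embedded with the same MiniLM model and the 4 most
# similar chunks are pulled back to fill the prompt's {context} slot. This
# Chroma store is in-memory (no persist_directory), so the course documents
# are re-embedded on every app restart.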
# === Prompt Template ===
custom_prompt_template = """
You are a helpful and knowledgeable course advisor at the University of Hertfordshire. Answer the student's question using only the information provided in the context below.
If the context does not contain the answer, politely respond that the information is not available.

Context:
{context}

Question:
{question}

Answer:
"""

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=custom_prompt_template
)
# === Load Falcon Model ===
model_name = "tiiuae/Falcon3-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
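# trust_remote_code allows the checkpoint's own modelling code to run if the
# repo ships any; Falcon3-1B-Instruct likely loads with the stock transformers
# architecture, making the flag redundant here, but it is harmless.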
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=False  # greedy decoding; temperature/top_p only apply when sampling
)

# NOTE: passing model_kwargs={"return_full_text": False} to the wrapper has no
# effect when a prebuilt pipeline is supplied, so the echoed prompt is stripped
# manually in chat_with_bot (the split on "Answer:") instead.
llm = HuggingFacePipeline(pipeline=generator)
# === Setup Retrieval QA Chain ===
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": prompt}
)
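# The "stuff" chain type simply concatenates the retrieved chunks into the
# prompt's {context} slot and makes a single LLM call. Quick sanity check
# (example query is illustrative; uncomment to run locally):
# print(qa_chain.run("What are the entry requirements for the MSc?"))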
# === Avatar and Crest ===
avatar_img = "images/UH.png"  # Avatar shown beside bot messages
logo = "images/UH Crest.png"  # Crest image

# === Chat Logic with Course Memory ===
def chat_with_bot(message, history, course_state):
    lower_msg = message.lower()
    # Check the explicit course switch first: "change course to MSc ..." would
    # otherwise be swallowed by the "msc" detection below.
    if "change course to" in lower_msg:
        course_state = message.replace("change course to", "").strip()
        response = f"🔄 Course changed. Now answering based on: **{course_state}**"
        history.append((message, response))
        return "", history, course_state
    elif "msc" in lower_msg:
        course_state = message.strip()  # Store the course for later questions
        full_query = f"For the course '{course_state}': {message}"
    elif course_state:
        full_query = f"For the course '{course_state}': {message}"
    else:
        full_query = message  # No course memory yet

    try:
        raw_output = qa_chain.run(full_query)
        response = raw_output.split("Answer:")[-1].strip()  # Drop the echoed prompt
        response = response.replace("<|assistant|>", "").strip()
    except Exception as e:
        response = f"⚠️ An error occurred: {str(e)}"

    history.append((message, response))
    return "", history, course_state
# === Build Gradio UI ===
initial_message = (
    "👋 Welcome! I'm your Assistant for the University of Hertfordshire.\n"
    "Struggling to find something on our website?\n"
    "Want to know anything about your MSc course?\n\n"
    "Simply ask and we can get started!\n\n"
    "⚠️ Please avoid sharing personal details in this chat.\n"
    "If personal details are ever needed, we'll always ask for consent first."
)
with gr.Blocks(title="🎓 UH Academic Advisor", css="""
    .message.user {
        background-color: #d2e5ff !important;
    }
""") as demo:
    # Convert the crest image to base64 so the header needs no static file route
    with open(logo, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")

    # Logo header
    gr.Markdown(f"""
    <div style='display: flex; align-items: center; gap: 6px; line-height: 1;'>
        <img src="data:image/png;base64,{encoded_string}" style="height: 30px; margin-bottom: 2px;">
        <h1 style='font-size: 18px; margin: 0;'>University of Hertfordshire Course Advisor Chatbot</h1>
    </div>
    """)
    chatbot = gr.Chatbot(
        avatar_images=(None, avatar_img),
        value=[(None, initial_message)],  # Greeting renders as a bot message, not a user one
        show_copy_button=True
    )
    state = gr.State("")  # Keeps course memory in-session
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask a question...", lines=1, scale=5)
        send_btn = gr.Button("Send", scale=1)

    msg.submit(chat_with_bot, [msg, chatbot, state], [msg, chatbot, state])
    send_btn.click(chat_with_bot, [msg, chatbot, state], [msg, chatbot, state])
# === Launch ===
demo.launch()