# FinalV2 / app.py
# Author: zm-f21
# Last commit: c1b32e3 "Update app.py" (verified)
import gradio as gr
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import pandas as pd
import numpy as np
import zipfile
import os
import re
import torch
import shutil
# =======================================================
# 1) Load Mistral LLM (FP16)
# =======================================================
# Half-precision weights; device_map="auto" lets transformers/accelerate
# decide which available device(s) hold the model.
llm = pipeline(
    task="text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.2",
    device_map="auto",
    torch_dtype=torch.float16,
)
# =======================================================
# 2) Load Embedding Model (Legal-BERT)
# =======================================================
# NOTE(review): this checkpoint looks like a plain BERT model rather than a
# packaged sentence-transformers model; if so, SentenceTransformer wraps it
# with its default pooling layer — confirm this is the intended setup.
embedding_model = SentenceTransformer(
    model_name_or_path="nlpaueb/legal-bert-base-uncased",
)
# =======================================================
# 3) Extract the ZIP dataset
# =======================================================
zip_path = "/app/provinces.zip"  # Make sure this is uploaded in your HF Space
extract_folder = "/app/provinces_texts"

# Wipe any previous extraction so the folder mirrors the current archive.
if os.path.exists(extract_folder):
    shutil.rmtree(extract_folder)
with zipfile.ZipFile(zip_path, mode="r") as archive:
    archive.extractall(path=extract_folder)

# Date stamps such as "2024-01-15" or "2024-01_15", captured as group 1.
# NOTE(review): not referenced anywhere else in this file — confirm it is needed.
date_pattern = re.compile(r"(\d{4}[-]\d{2}[_-]\d{2})")
# =======================================================
# 4) Parse TXT files into documents
# =======================================================
def parse_metadata_and_content(raw_text):
    """Split a province text file into its metadata header and its body.

    Everything before the first ``CONTENT:`` marker is the header.
    ``KEY: value`` lines become metadata entries (key stripped and
    upper-cased, value stripped). Lines starting with ``-`` are collected in
    order and stored, newline-joined, under ``PDF_LINKS``. Header lines that
    match neither form are ignored.

    Returns:
        ``(metadata, content)``, where ``content`` is the stripped text after
        the marker.

    Raises:
        ValueError: if ``raw_text`` has no ``CONTENT:`` marker.
    """
    header, marker, body = raw_text.partition("CONTENT:")
    if not marker:
        raise ValueError("File missing CONTENT: separator.")

    fields = {}
    bullet_lines = []
    for header_line in header.strip().split("\n"):
        stripped = header_line.strip()
        if stripped.startswith("-"):
            # Bulleted entries (PDF links) are kept verbatim, dash included.
            bullet_lines.append(stripped)
        elif ":" in header_line:
            # Split on the first colon only, so URLs keep their "https:".
            name, _, val = header_line.partition(":")
            fields[name.strip().upper()] = val.strip()

    if bullet_lines:
        fields["PDF_LINKS"] = "\n".join(bullet_lines)
    return fields, body.strip()
def _read_text_file(path):
    """Return the text of *path*, decoded as UTF-8 or, failing that, Latin-1.

    If the files are UTF-8 (confirm for this dataset), decoding them as
    Latin-1 never raises but turns every non-ASCII character into mojibake
    (e.g. "é" -> "Ã©"). UTF-8 is therefore tried first. "utf-8-sig" also
    drops a leading byte-order mark, which would otherwise be glued onto the
    first metadata key. Latin-1 maps every byte, so the fallback cannot fail
    to decode.
    """
    try:
        with open(path, "r", encoding="utf-8-sig") as f:
            return f.read()
    except UnicodeDecodeError:
        with open(path, "r", encoding="latin-1") as f:
            return f.read()


documents = []
for root, dirs, files in os.walk(extract_folder):
    # Sorting (dirs in place, files via sorted) makes the traversal and
    # therefore the paragraph order reproducible across runs.
    dirs.sort()
    for filename in sorted(files):
        # Skip "._"-prefixed entries (presumably macOS AppleDouble files)
        # and anything that is not a .txt file.
        if filename.startswith("._") or not filename.endswith(".txt"):
            continue
        filepath = os.path.join(root, filename)
        try:
            metadata, content = parse_metadata_and_content(_read_text_file(filepath))
        except Exception as e:
            # Best effort: report and skip one unreadable/malformed file
            # instead of aborting startup.
            print(f"Skipping {filepath}: {e}")
            continue
        file_fields = {
            "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
            "province": metadata.get("PROVINCE", "Unknown"),
            "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
            "url": metadata.get("URL", "N/A"),
            "pdf_links": metadata.get("PDF_LINKS", ""),
        }
        # Paragraphs are separated by blank lines; a line containing only
        # whitespace also counts as blank.
        for paragraph in re.split(r"\n\s*\n", content):
            paragraph = paragraph.strip()
            if paragraph:
                documents.append({**file_fields, "text": paragraph})
print(f"Loaded {len(documents)} paragraphs from all provinces.")
# =======================================================
# 5) Build embeddings & dataframe
# =======================================================
# One embedding per paragraph, in the same order as `documents`.
texts = [doc["text"] for doc in documents]
# float16 storage halves the footprint compared with a float32 index.
embeddings = embedding_model.encode(texts).astype("float16")
df = pd.DataFrame(documents).assign(Embedding=[vector for vector in embeddings])
print(f"Indexing complete. Total: {len(df)}")
# =======================================================
# 6) Retrieval
# =======================================================
def retrieve_with_pandas(query, province=None, top_k=2):
    """Return the ``top_k`` rows of ``df`` most cosine-similar to *query*.

    Args:
        query: free-text question; embedded with ``embedding_model``.
        province: if given, only rows whose ``province`` column equals it are
            considered; None searches every row.
        top_k: maximum number of rows to return.

    Returns:
        A new DataFrame with the candidate rows plus a ``Similarity`` column,
        sorted by it in descending order and truncated to ``top_k`` rows.
        ``Similarity`` holds the cosine similarity, or 0 where either vector
        has zero norm.
    """
    candidates = df if province is None else df[df["province"] == province]
    if candidates.empty:
        # np.vstack cannot stack zero rows; return an empty result that still
        # carries the Similarity column the normal path adds.
        return candidates.assign(Similarity=np.array([], dtype=np.float32))

    query_emb = np.asarray(embedding_model.encode([query])[0], dtype=np.float32)
    # Embeddings are stored as float16; upcast once so dot products and norms
    # accumulate in float32 rather than half precision.
    matrix = np.vstack(candidates["Embedding"].tolist()).astype(np.float32)
    denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_emb)
    # Cosine similarity for all rows in one matrix-vector product. Rows with a
    # zero denominator score 0 instead of NaN.
    scores = np.divide(
        matrix @ query_emb,
        denom,
        out=np.zeros(len(matrix), dtype=np.float32),
        where=denom > 0,
    )
    ranked = candidates.assign(Similarity=scores)
    return ranked.sort_values("Similarity", ascending=False).head(top_k)
# =======================================================
# 7) Province detection
# =======================================================
# Lower-cased aliases (full names and common abbreviations) -> province name.
# The returned name is compared for equality against each paragraph's
# PROVINCE metadata during retrieval. Insertion order matters: when a query
# mentions several provinces, the first alias here that matches wins.
_PROVINCE_ALIASES = {
    "yukon": "Yukon",
    "alberta": "Alberta",
    "bc": "British Columbia",
    "british columbia": "British Columbia",
    "manitoba": "Manitoba",
    "nl": "Newfoundland and Labrador",
    "newfoundland": "Newfoundland and Labrador",
    "sask": "Saskatchewan",
    "saskatchewan": "Saskatchewan",
    "ontario": "Ontario",
    "pei": "Prince Edward Island",
    "prince edward island": "Prince Edward Island",
    "quebec": "Quebec",
    "nb": "New Brunswick",
    "new brunswick": "New Brunswick",
    "nova scotia": "Nova Scotia",
    "nunavut": "Nunavut",
    "nwt": "Northwest Territories",
    "northwest territories": "Northwest Territories",
}

# Each alias may only match as a whole word. Plain substring tests let the
# short abbreviations fire inside ordinary words ("nl" in "only"/"unlawful",
# "nb" in "unbearable", "bc" in "abc"), which silently restricted retrieval
# to the wrong province.
_PROVINCE_PATTERNS = [
    (re.compile(rf"\b{re.escape(alias)}\b"), name)
    for alias, name in _PROVINCE_ALIASES.items()
]


def detect_province(query):
    """Return the province/territory named in *query*, or None if none is.

    Matching is case-insensitive and whole-word, so "BC" or "Sask." match but
    "only" does not match "nl".
    """
    lowered = query.lower()
    for pattern, name in _PROVINCE_PATTERNS:
        if pattern.search(lowered):
            return name
    return None
# =======================================================
# 8) Guardrails
# =======================================================
def is_disallowed(query):
    """Return True if *query* contains a blocked harmful term (case-insensitive substring)."""
    lowered = query.lower()
    for term in ("suicide", "harm yourself", "bomb", "weapon"):
        if term in lowered:
            return True
    return False
_TENANCY_KEYWORDS = [
    "tenant", "landlord", "rent", "evict", "lease", "deposit",
    "tenancy", "rental", "apartment", "unit", "repair", "pets",
    "heating", "notice",
]
# A keyword must start a word, optionally after "sub" or "dis" (sublease,
# subtenant, disrepair). Suffixes still match (rented, eviction, units).
# Plain substring tests let "lease" fire inside "please", "rent" inside
# "current"/"parent"/"different", "unit" inside "community" and "heating"
# inside "cheating", so unrelated questions slipped past the topic filter.
_TENANCY_PATTERN = re.compile(
    r"\b(?:sub|dis)?(?:" + "|".join(map(re.escape, _TENANCY_KEYWORDS)) + ")"
)


def is_off_topic(query):
    """Return True if *query* contains no tenancy-related keyword.

    Matching is case-insensitive. Each keyword must appear at the start of a
    word (optionally after "sub"/"dis"); see ``_TENANCY_PATTERN``.
    """
    return _TENANCY_PATTERN.search(query.lower()) is None
# Greeting shown as the chatbot's first message, including the disclaimer.
INTRO_TEXT = "\n\n".join([
    "Hi! I'm a Canadian rental housing assistant. I can help you find, summarize, "
    "and explain information from the Residential Tenancies Acts across all provinces.",
    "**Important:** I'm not a lawyer and this is **not legal advice**.",
])
# =======================================================
# 9) RAG Generation
# =======================================================
# Few-shot examples. The prompt tells the model to treat them purely as a
# style guide, so they must actually be included in it (previously they were
# defined but never interpolated).
_QA_EXAMPLES = (
    "Q: My landlord took too long to install a safety item. Is that allowed?\n"
    "A: Landlords should respond promptly to reasonable accommodation requests.\n"
    "Q: I have kids making noise. Can I be evicted?\n"
    "A: Reasonable family noise is expected; eviction should not be based on discrimination."
)


def _build_prompt(context, query):
    """Assemble the LLM prompt from the style examples, retrieved *context* and *query*."""
    return f"""
Use the examples ONLY AS A STYLE GUIDE.
Do not repeat them and do not invent laws.
If the context does not contain the answer, say so.
Examples:
{_QA_EXAMPLES}
Context:
{context}
Question:
{query}
Answer conversationally:
"""


def _format_sources(top_docs):
    """Render one provenance bullet (province, source, updated, URL) per retrieved row."""
    return "".join(
        f"- Province: {row['province']}\n"
        f" Source: {row['source_title']}\n"
        f" Updated: {row['last_updated']}\n"
        f" URL: {row['url']}\n"
        for _, row in top_docs.iterrows()
    )


def generate_with_rag(query, province=None, top_k=2):
    """Answer *query* with retrieval-augmented generation.

    The guardrails run first; harmful and off-topic queries get a fixed
    refusal string. Otherwise the province is detected from the query when
    not given, and the ``top_k`` most similar paragraphs are retrieved and
    passed to the LLM as context. The generated answer is returned, followed
    by a "Sources Used" list built from the retrieved rows' metadata.

    Args:
        query: the user's question.
        province: restrict retrieval to this province value. None means detect
            it from the query; if detection finds nothing, all rows are searched.
        top_k: number of paragraphs to retrieve.

    Returns:
        The reply text to show to the user.
    """
    if is_disallowed(query):
        return "Sorry — I can’t help with harmful or dangerous topics."
    if is_off_topic(query):
        return "Sorry — I can only answer questions about Canadian tenancy and housing law."
    if province is None:
        province = detect_province(query)

    top_docs = retrieve_with_pandas(query, province=province, top_k=top_k)
    if top_docs.empty:
        return "Sorry — I couldn't find matching information."

    # Blank lines between paragraphs keep their boundaries visible to the model.
    context = "\n\n".join(top_docs["text"].tolist())
    prompt = _build_prompt(context, query)
    # return_full_text=False returns only the newly generated text. The answer
    # is no longer cut out of an echoed prompt by splitting on a marker that a
    # retrieved paragraph or the user's own question could also contain.
    output = llm(prompt, max_new_tokens=150, return_full_text=False)[0]["generated_text"]
    answer = output.strip()

    return f"{answer}\n\nSources Used:\n{_format_sources(top_docs)}"
# =======================================================
# 10) Gradio chat interface (intro message shown once)
# =======================================================
# Seed pair for the Chatbot: no user message, the intro as the bot's message.
INTRO_TUPLE = (None, INTRO_TEXT)


def chat_api(message, history):
    """Add one user turn and the assistant's reply to the chat history.

    *history* is the Chatbot's list of ``(user, bot)`` tuples. It is mutated
    in place: a ``(message, None)`` placeholder is appended before the reply
    is generated, then replaced by the completed pair. The same list is
    returned twice as a ``(history, history)`` pair.
    """
    slot = len(history)
    history.append((message, None))
    answer = generate_with_rag(message)
    history[slot] = (message, answer)
    return history, history
with gr.Blocks() as demo:
    gr.Markdown("## Canada Residential Tenancy Assistant (RAG + Mistral 7B)")
    chatbot = gr.Chatbot(
        value=[INTRO_TUPLE],  # must be a list of (user, bot) tuples
        height=500,
    )
    user_box = gr.Textbox(
        label="Your question",
        placeholder="Ask a question about rentals, repairs, evictions, deposits, etc...",
    )
    send_btn = gr.Button("Send")

    # The button and pressing Enter run the same handler. Once the reply is in,
    # the textbox is cleared so the previous question isn't left behind.
    send_btn.click(
        chat_api, inputs=[user_box, chatbot], outputs=[chatbot, chatbot]
    ).then(lambda: "", inputs=None, outputs=user_box)
    user_box.submit(
        chat_api, inputs=[user_box, chatbot], outputs=[chatbot, chatbot]
    ).then(lambda: "", inputs=None, outputs=user_box)


if __name__ == "__main__":
    demo.launch(share=True)