# zm-f21's picture
# Update app.py
# 9c66a72 verified
import gradio as gr
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import pandas as pd
import numpy as np
import zipfile
import os
import re
import torch
# -----------------------------
# Load Mistral pipeline
# -----------------------------
# Text-generation pipeline used to produce the final answers. The model is
# requested in half precision and device placement is left to device_map="auto".
llm = pipeline(
    task="text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.2",
    device_map="auto",
    torch_dtype=torch.float16,
)
# -----------------------------
# Load SentenceTransformer embeddings
# -----------------------------
# Embedding model shared by indexing (paragraph embeddings) and retrieval
# (query embeddings), so both live in the same vector space.
# NOTE(review): "nlpaueb/legal-bert-base-uncased" looks like a plain BERT
# checkpoint rather than a sentence-transformers model — presumably the library
# adds its own pooling layer; confirm this gives usable sentence embeddings.
embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")
# -----------------------------
# Extract Provinces ZIP
# -----------------------------
import shutil

# Location of the uploaded archive and the directory it is unpacked into.
zip_path = "/app/provinces.zip"  # Make sure you upload this to your HF Space
extract_folder = "/app/provinces_texts"

# Wipe any previous extraction so stale files cannot end up in the index.
if os.path.exists(extract_folder):
    shutil.rmtree(extract_folder)

with zipfile.ZipFile(zip_path, mode="r") as archive:
    archive.extractall(path=extract_folder)
# Regex to capture YYYY_MM_DD or YYYY-MM-DD anywhere in a filename.
# Each separator may be "_" or "-"; the single capture group holds the date
# exactly as it appears in the filename.
date_pattern = re.compile(r"(\d{4}[_-]\d{2}[_-]\d{2})")
# -----------------------------
# Parse TXT files and create documents
# -----------------------------
def parse_metadata_and_content(raw_text):
    """Split a raw document into header metadata and body text.

    Everything before the first ``CONTENT:`` marker is treated as the header.
    Header lines of the form ``key: value`` become metadata entries (keys are
    upper-cased, both sides stripped). Header lines starting with ``-`` are
    collected and stored, newline-joined, under ``PDF_LINKS``. Other header
    lines are ignored.

    Args:
        raw_text: Full text of one document file.

    Returns:
        A ``(metadata, content)`` tuple: the metadata dict and the stripped
        text that follows the marker.

    Raises:
        ValueError: If ``raw_text`` contains no ``CONTENT:`` marker.
    """
    head, marker, body = raw_text.partition("CONTENT:")
    if not marker:
        raise ValueError("File missing CONTENT: separator.")

    metadata = {}
    bullet_lines = []
    for raw_line in head.strip().split("\n"):
        stripped = raw_line.strip()
        if stripped.startswith("-"):
            # Bullet entries are kept verbatim (minus surrounding whitespace).
            bullet_lines.append(stripped)
        elif ":" in raw_line:
            # Split on the first colon only, so values such as URLs stay intact.
            name, _, value = raw_line.partition(":")
            metadata[name.strip().upper()] = value.strip()

    if bullet_lines:
        metadata["PDF_LINKS"] = "\n".join(bullet_lines)
    return metadata, body.strip()
# One record per non-empty paragraph of every .txt file under extract_folder.
# Each record carries the file-level metadata alongside the paragraph text.
documents = []
for folder, _subdirs, names in os.walk(extract_folder):
    for name in names:
        # Ignore "._"-prefixed files (presumably macOS metadata stubs left in
        # the archive) and anything that is not a .txt file.
        if name.startswith("._") or not name.endswith(".txt"):
            continue

        filepath = os.path.join(folder, name)
        try:
            # NOTE(review): latin-1 never fails to decode, but will garble
            # files that are actually UTF-8 — confirm the source encoding.
            with open(filepath, "r", encoding="latin-1") as handle:
                metadata, content = parse_metadata_and_content(handle.read())

            shared_fields = {
                "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
                "province": metadata.get("PROVINCE", "Unknown"),
                "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
                "url": metadata.get("URL", "N/A"),
                "pdf_links": metadata.get("PDF_LINKS", ""),
            }
            # Paragraphs are separated by blank lines; whitespace-only chunks are dropped.
            for chunk in content.split("\n\n"):
                paragraph = chunk.strip()
                if paragraph:
                    documents.append({**shared_fields, "text": paragraph})
        except ValueError as e:
            # Raised by parse_metadata_and_content for files without a CONTENT: marker.
            print(f"Skipping {filepath}: {e}")
            continue

print(f"Loaded {len(documents)} paragraphs from all provinces.")
# -----------------------------
# Create embeddings and dataframe
# -----------------------------
# Build the searchable table: one row per paragraph, with its embedding stored
# as a float16 vector in the "Embedding" column.
df = pd.DataFrame(documents)
texts = [doc["text"] for doc in documents]
embeddings = embedding_model.encode(texts).astype("float16")
df["Embedding"] = list(embeddings)  # one vector per row
print("Indexing complete. Total:", len(df))
# -----------------------------
# Retrieve with Pandas
# -----------------------------
def retrieve_with_pandas(query, province=None, top_k=2):
    """Return the ``top_k`` indexed paragraphs most similar to ``query``.

    Args:
        query: Free-text question; encoded with the module-level
            ``embedding_model``.
        province: If given, only rows whose ``province`` column equals this
            value are considered; otherwise every row is scored.
        top_k: Maximum number of rows to return.

    Returns:
        A copy of the matching rows of the module-level ``df`` with an added
        ``Similarity`` column (cosine similarity to the query), sorted from
        most to least similar. The result is empty when no row matches
        ``province``.
    """
    # Work in float32: the stored vectors are float16, and taking norms in
    # half precision loses accuracy.
    query_emb = np.asarray(embedding_model.encode([query])[0], dtype=np.float32)

    if province is not None:
        filtered_df = df[df["province"] == province].copy()
    else:
        filtered_df = df.copy()

    if filtered_df.empty:
        # Nothing to score; still expose the Similarity column for callers.
        filtered_df["Similarity"] = np.array([], dtype=np.float32)
        return filtered_df

    # Score every row in one vectorised pass instead of a per-row Python
    # callback; the query norm is computed once rather than once per row.
    matrix = np.stack(filtered_df["Embedding"].tolist()).astype(np.float32)
    denominators = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_emb)
    filtered_df["Similarity"] = (matrix @ query_emb) / denominators

    return filtered_df.sort_values("Similarity", ascending=False).head(top_k)
# -----------------------------
# Province detection
# -----------------------------
def detect_province(query):
    """Return the province or territory mentioned in ``query``, if any.

    Matching is case-insensitive and only succeeds on whole words/phrases, so
    short abbreviations such as "nl" or "nb" do not fire inside ordinary words
    like "only" or "sunburn".

    Args:
        query: The user's question.

    Returns:
        The canonical province/territory name for the first key (in the order
        listed below) found in the query, or ``None`` when nothing matches.
    """
    provinces = {
        "yukon": "Yukon",
        "alberta": "Alberta",
        "bc": "British Columbia",
        "british columbia": "British Columbia",
        "manitoba": "Manitoba",
        "nl": "Newfoundland and Labrador",
        "newfoundland": "Newfoundland and Labrador",
        "sask": "Saskatchewan",
        "saskatchewan": "Saskatchewan",
        "ontario": "Ontario",
        "pei": "Prince Edward Island",
        "prince edward island": "Prince Edward Island",
        "quebec": "Quebec",
        "nb": "New Brunswick",
        "new brunswick": "New Brunswick",
        "nova scotia": "Nova Scotia",
        "nunavut": "Nunavut",
        "nwt": "Northwest Territories",
        "northwest territories": "Northwest Territories",
    }
    q = query.lower()
    for key, prov in provinces.items():
        # \b anchors the key to word boundaries; re.escape keeps keys literal.
        if re.search(rf"\b{re.escape(key)}\b", q):
            return prov
    return None
# -----------------------------
# Guardrails
# -----------------------------
def is_disallowed(query):
    """Return True when ``query`` contains any banned term.

    The check is a case-insensitive substring match, so a term also matches
    when it appears inside a longer word.
    """
    lowered = query.lower()
    for term in ("kill", "suicide", "harm yourself", "bomb", "weapon"):
        if term in lowered:
            return True
    return False
def is_off_topic(query):
    """Return True when ``query`` mentions none of the tenancy keywords.

    Keywords are matched as case-insensitive substrings of the query.
    """
    lowered = query.lower()
    tenancy_keywords = (
        "tenant", "landlord", "rent", "evict", "lease",
        "deposit", "tenancy", "rental", "apartment",
        "unit", "heating", "notice", "repair", "pets",
    )
    for keyword in tenancy_keywords:
        if keyword in lowered:
            return False
    return True
# Greeting and legal disclaimer (Markdown) prepended to every reply.
INTRO_TEXT = (
    "Hi! I'm a Canadian rental housing assistant. "
    "I can help you find, summarize, and explain information from the "
    "Residential Tenancies Acts across all provinces and territories.\n\n"
    "**Important:** I'm not a lawyer and this is **not legal advice**. "
    "Use your own judgment.\n\n"
)
# -----------------------------
# RAG generation function
# -----------------------------
def generate_with_rag(query, province=None, top_k=2):
    """Answer a tenancy question using retrieved paragraphs as context.

    The query is first screened by ``is_disallowed`` and ``is_off_topic``;
    either check short-circuits with a canned refusal. Otherwise the most
    similar paragraphs are retrieved, placed into a few-shot prompt and passed
    to the module-level ``llm`` pipeline.

    Args:
        query: The user's question.
        province: Province name to restrict retrieval to. When ``None`` it is
            inferred from the query with ``detect_province`` (which may also
            return ``None``, in which case all provinces are searched).
        top_k: Number of paragraphs to retrieve as context.

    Returns:
        ``INTRO_TEXT`` followed by either a refusal / "not found" message or
        the generated answer plus a "Sources Used" list.
    """
    if is_disallowed(query):
        return INTRO_TEXT + "Sorry — I can’t help with harmful or dangerous topics."
    if is_off_topic(query):
        return INTRO_TEXT + "Sorry — I can only answer questions about Canadian tenancy and housing law."

    if province is None:
        province = detect_province(query)

    top_docs = retrieve_with_pandas(query, province=province, top_k=top_k)
    if top_docs is None or top_docs.empty:
        return INTRO_TEXT + "Sorry — I couldn't find any matching information in the tenancy database."

    context = " ".join(top_docs["text"].tolist())

    # Few-shot style examples (style guide)
    qa_examples = (
        "\n"
        "Q: I asked my landlord three months ago to install handrails in my bathroom. Can the landlord take a long time to respond?\n"
        "A: Landlords should respond promptly to reasonable accommodation requests. If they delay unreasonably, you can file a discrimination complaint.\n"
        "Q: My building manager keeps complaining about my children’s noise. Can I be evicted?\n"
        "A: Reasonable noise from children is expected. If you're treated differently because you have children, you may file a complaint based on family status.\n"
    )

    # NOTE(review): the user's question is inserted into the prompt verbatim.
    prompt = (
        "\n"
        "Use the examples as a STYLE GUIDE ONLY.\n"
        "DO NOT repeat the example questions.\n"
        "DO NOT invent laws — only use the context provided.\n"
        "If the context does not contain the answer, say you cannot confidently answer.\n"
        f"{qa_examples}\n"
        "Context:\n"
        f"{context}\n"
        "Question:\n"
        f"{query}\n"
        "Answer conversationally:\n"
    )

    # Ask the pipeline for the newly generated text only. Stripping the echoed
    # prompt by searching for the "Answer conversationally:" marker breaks when
    # the question itself contains that phrase.
    raw_output = llm(prompt, max_new_tokens=150, return_full_text=False)[0]["generated_text"]
    answer = raw_output.strip()

    # One entry per retrieved paragraph, in ranking order.
    metadata_block = "".join(
        f"- Province: {row['province']}\n"
        f" Source: {row['source_title']}\n"
        f" Updated: {row['last_updated']}\n"
        f" URL: {row['url']}\n"
        for _, row in top_docs.iterrows()
    )
    return INTRO_TEXT + f"{answer}\n\nSources Used:\n{metadata_block}"
# -----------------------------
# Gradio Chat
# -----------------------------
def respond(message, history):
    """Handle one chat turn for the Gradio UI.

    Generates a reply for ``message``, appends the ``(message, reply)`` pair to
    ``history`` in place, and returns that same list twice.
    """
    reply = generate_with_rag(message)
    history.append((message, reply))
    return history, history
# Chat UI: a chat window, a question box and a short disclaimer, created in
# that order inside the Blocks context.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your question")
    # On submit, pass the question and current chat history to respond();
    # both returned values are routed to the chat window.
    msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[chatbot, chatbot])
    gr.Markdown(
        "Ask questions about Canadian tenancy and housing law.\n\n"
        "**Note:** I am not a lawyer. Responses are generated from official documents."
    )

# Start the server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(share=True)