"""Gradio RAG assistant for Canadian residential tenancy law.

At import time this script loads Mistral-7B-Instruct (FP16) and a Legal-BERT
sentence embedder, extracts the provincial text corpus from a ZIP archive,
splits every ``.txt`` file into paragraphs, embeds them, and serves a chat UI
that answers questions from the most similar paragraphs.
"""

import os
import re
import shutil
import zipfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# =======================================================
# 1) Load Mistral LLM (FP16)
# =======================================================
llm = pipeline(
    "text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.float16,
    device_map="auto",
)

# =======================================================
# 2) Load Embedding Model (Legal-BERT)
# =======================================================
embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")

# =======================================================
# 3) Extract the ZIP dataset
# =======================================================
zip_path = "/app/provinces.zip"  # Make sure this is uploaded in your HF Space
extract_folder = "/app/provinces_texts"

# Start from a clean folder so files removed from the archive do not linger.
if os.path.exists(extract_folder):
    shutil.rmtree(extract_folder)
with zipfile.ZipFile(zip_path, "r") as zip_ref:
    zip_ref.extractall(extract_folder)

# NOTE(review): not referenced anywhere in this script.
date_pattern = re.compile(r"(\d{4}[-]\d{2}[_-]\d{2})")


# =======================================================
# 4) Parse TXT files into documents
# =======================================================
def parse_metadata_and_content(raw_text):
    """Split a corpus file into its header metadata and its body text.

    The header is everything before the first ``CONTENT:`` marker. Header lines
    of the form ``KEY: value`` become ``metadata[KEY.upper()] = value``; lines
    starting with ``-`` are collected (newline-joined) under ``PDF_LINKS``.

    Returns:
        ``(metadata, content)`` where *content* is the stripped body text.

    Raises:
        ValueError: if the text has no ``CONTENT:`` separator.
    """
    if "CONTENT:" not in raw_text:
        raise ValueError("File missing CONTENT: separator.")
    header, content = raw_text.split("CONTENT:", 1)

    metadata = {}
    pdf_list = []
    for line in header.strip().split("\n"):
        stripped = line.strip()
        if stripped.startswith("-"):
            pdf_list.append(stripped)
        elif ":" in line:
            key, value = line.split(":", 1)
            metadata[key.strip().upper()] = value.strip()

    if pdf_list:
        metadata["PDF_LINKS"] = "\n".join(pdf_list)
    return metadata, content.strip()


def _read_text_file(path):
    """Return the text of *path*, decoding as UTF-8 and falling back to Latin-1.

    Text mode is kept so universal-newline translation still turns CRLF line
    endings into LF (paragraph splitting relies on blank lines). ``utf-8-sig``
    drops a leading BOM that would otherwise corrupt the first header key.
    """
    try:
        with open(path, "r", encoding="utf-8-sig") as f:
            return f.read()
    except UnicodeDecodeError:
        # Legacy non-UTF-8 files: Latin-1 decodes any byte sequence.
        with open(path, "r", encoding="latin-1") as f:
            return f.read()


def _load_documents(folder):
    """Walk *folder* and return one record per paragraph of every ``.txt`` file.

    Files are visited in sorted order so the corpus (and retrieval
    tie-breaking) is reproducible across runs. ``._*`` files and non-``.txt``
    files are ignored; unreadable or malformed files are reported and skipped.
    """
    docs = []
    for root, dirs, files in os.walk(folder):
        dirs.sort()  # os.walk (top-down) follows the in-place order of dirs
        for filename in sorted(files):
            if filename.startswith("._") or not filename.endswith(".txt"):
                continue
            filepath = os.path.join(root, filename)
            try:
                metadata, content = parse_metadata_and_content(_read_text_file(filepath))
            except (OSError, ValueError) as e:
                print(f"Skipping {filepath}: {e}")
                continue

            paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
            docs.extend(
                {
                    "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
                    "province": metadata.get("PROVINCE", "Unknown"),
                    "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
                    "url": metadata.get("URL", "N/A"),
                    "pdf_links": metadata.get("PDF_LINKS", ""),
                    "text": p,
                }
                for p in paragraphs
            )
    return docs


documents = _load_documents(extract_folder)
print(f"Loaded {len(documents)} paragraphs from all provinces.")
if not documents:
    # Fail fast: an empty DataFrame has no "province"/"text" columns, which
    # would otherwise surface as an opaque KeyError on the first query.
    raise RuntimeError(f"No usable .txt documents found under {extract_folder}")

# =======================================================
# 5) Build embeddings & dataframe
# =======================================================
texts = [d["text"] for d in documents]
embeddings = embedding_model.encode(texts).astype("float16")  # float16 halves memory

df = pd.DataFrame(documents)
df["Embedding"] = list(embeddings)
print("Indexing complete. Total:", len(df))


# =======================================================
# 6) Retrieval
# =======================================================
def retrieve_with_pandas(query, province=None, top_k=2):
    """Return the *top_k* rows of ``df`` most cosine-similar to *query*.

    Args:
        query: free-text question to embed.
        province: if given, only rows whose ``province`` equals it are searched.
        top_k: maximum number of rows to return.

    Returns:
        A new DataFrame of matching rows with an added ``Similarity`` column,
        sorted descending; empty when no row matches *province*.
    """
    candidates = df if province is None else df[df["province"] == province]
    if candidates.empty:
        return candidates.assign(Similarity=np.empty(0, dtype=np.float32))

    query_emb = np.asarray(embedding_model.encode([query])[0], dtype=np.float32)
    # Stored embeddings are float16; upcast so dot products and norms are not
    # accumulated in half precision.
    matrix = np.stack(candidates["Embedding"].to_numpy()).astype(np.float32)
    denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_emb)
    # A zero-norm vector yields NaN (as the per-row formula did); sort_values
    # places NaN last.
    with np.errstate(divide="ignore", invalid="ignore"):
        similarity = (matrix @ query_emb) / denom

    ranked = candidates.assign(Similarity=similarity)
    return ranked.sort_values("Similarity", ascending=False).head(top_k)


# =======================================================
# 7) Province detection
# =======================================================
_PROVINCE_ALIASES = {
    "yukon": "Yukon",
    "alberta": "Alberta",
    "bc": "British Columbia",
    "b.c.": "British Columbia",
    "british columbia": "British Columbia",
    "manitoba": "Manitoba",
    "nl": "Newfoundland and Labrador",
    "newfoundland": "Newfoundland and Labrador",
    "sask": "Saskatchewan",
    "saskatchewan": "Saskatchewan",
    "ontario": "Ontario",
    "pei": "Prince Edward Island",
    "p.e.i.": "Prince Edward Island",
    "prince edward island": "Prince Edward Island",
    "quebec": "Quebec",
    "québec": "Quebec",
    "nb": "New Brunswick",
    "new brunswick": "New Brunswick",
    "nova scotia": "Nova Scotia",
    "nunavut": "Nunavut",
    "nwt": "Northwest Territories",
    "northwest territories": "Northwest Territories",
}


def _alias_pattern(alias):
    """Compile the matcher for one lower-case province alias."""
    escaped = re.escape(alias)
    # Short codes ("nl", "bc", "nb", "pei", "nwt") occur inside ordinary words
    # ("only", "unlawful", "unbearable"), so they must stand alone. Longer
    # names keep substring matching (e.g. "sask" still catches "saskatoon").
    if len(alias) <= 3:
        return re.compile(rf"(?<!\w){escaped}(?!\w)")
    return re.compile(escaped)


# Dict order is preserved, so the first listed alias found in a query wins.
_PROVINCE_PATTERNS = [
    (_alias_pattern(alias), name) for alias, name in _PROVINCE_ALIASES.items()
]


def detect_province(query):
    """Return the province/territory named in *query*, or None if none is found.

    Matching is case-insensitive; when several are mentioned, the one listed
    first in ``_PROVINCE_ALIASES`` is returned.
    """
    q = query.lower()
    for pattern, province in _PROVINCE_PATTERNS:
        if pattern.search(q):
            return province
    return None


# =======================================================
# 8) Guardrails
# =======================================================
def is_disallowed(query):
    """Return True if *query* contains any banned phrase (case-insensitive substring)."""
    banned = ["suicide", "harm yourself", "bomb", "weapon"]
    return any(b in query.lower() for b in banned)


def is_off_topic(query):
    """Return True if *query* contains none of the tenancy keywords.

    Substring matching keeps this lenient: "rent" also matches "rental",
    "evict" matches "eviction", and so on.
    """
    tenancy_keywords = [
        "tenant", "landlord", "rent", "evict", "lease", "deposit",
        "tenancy", "rental", "apartment", "unit", "repair", "pets",
        "heating", "notice",
    ]
    q = query.lower()
    return not any(k in q for k in tenancy_keywords)


INTRO_TEXT = (
    "Hi! I'm a Canadian rental housing assistant. I can help you find, summarize, "
    "and explain information from the Residential Tenancies Acts across all provinces.\n\n"
    "**Important:** I'm not a lawyer and this is **not legal advice**."
)

# =======================================================
# 9) RAG Generation
# =======================================================
QA_EXAMPLES = """
Q: My landlord took too long to install a safety item. Is that allowed?
A: Landlords should respond promptly to reasonable accommodation requests.

Q: I have kids making noise. Can I be evicted?
A: Reasonable family noise is expected; eviction should not be based on discrimination.
"""


def generate_with_rag(query, province=None, top_k=2, max_new_tokens=150):
    """Answer *query* with Mistral, grounded in the most similar corpus paragraphs.

    Args:
        query: the user's question.
        province: restrict retrieval to this province name; when None it is
            inferred from the query with ``detect_province`` (None = search all).
        top_k: number of paragraphs retrieved as context.
        max_new_tokens: generation budget passed to the LLM pipeline.

    Returns:
        A refusal string for disallowed, off-topic or unmatched queries;
        otherwise the model's answer followed by a "Sources Used" list.
    """
    if is_disallowed(query):
        return "Sorry — I can’t help with harmful or dangerous topics."
    if is_off_topic(query):
        return "Sorry — I can only answer questions about Canadian tenancy and housing law."

    if province is None:
        province = detect_province(query)

    top_docs = retrieve_with_pandas(query, province=province, top_k=top_k)
    if top_docs.empty:
        return "Sorry — I couldn't find matching information."

    context = "\n\n".join(top_docs["text"].tolist())

    # Mistral-Instruct expects the [INST] ... [/INST] wrapper (the tokenizer adds
    # BOS). The style examples are now actually included, since the
    # instructions refer to them.
    prompt = f"""[INST] Use the examples ONLY AS A STYLE GUIDE. Do not repeat them and do not invent laws.
If the context does not contain the answer, say so.

Examples:
{QA_EXAMPLES}
Context:
{context}

Question: {query}

Answer conversationally: [/INST]"""

    # return_full_text=False returns only newly generated text, so the prompt
    # is never echoed back to the user.
    output = llm(prompt, max_new_tokens=max_new_tokens, return_full_text=False)[0][
        "generated_text"
    ]
    # Defensive: if the model repeats the cue, keep only what follows it.
    answer = output.split("Answer conversationally:", 1)[-1].strip()

    # One entry per distinct source, preserving retrieval order.
    source_entries = dict.fromkeys(
        f"- Province: {row['province']}\n"
        f" Source: {row['source_title']}\n"
        f" Updated: {row['last_updated']}\n"
        f" URL: {row['url']}\n"
        for _, row in top_docs.iterrows()
    )
    metadata = "".join(source_entries)

    return f"{answer}\n\nSources Used:\n{metadata}"


# =======================================================
# 10) Gradio Chat Interface (INTRO only once)
# =======================================================
INTRO_TUPLE = (None, INTRO_TEXT)


def chat_api(message, history):
    """Handle one chat turn and clear the input box.

    Args:
        message: text typed by the user.
        history: current Chatbot value, a list of (user, bot) pairs.

    Returns:
        ``(updated_history, "")``; the empty string resets the Textbox.
    """
    history = list(history or [])
    if not message or not message.strip():
        # Ignore blank submissions instead of replying with the off-topic refusal.
        return history, ""
    reply = generate_with_rag(message)
    history.append((message, reply))
    return history, ""


with gr.Blocks() as demo:
    gr.Markdown("## Canada Residential Tenancy Assistant (RAG + Mistral 7B)")

    chatbot = gr.Chatbot(
        value=[INTRO_TUPLE],  # must be a list of tuples!
        height=500,
    )
    user_box = gr.Textbox(
        label="Your question",
        placeholder="Ask a question about rentals, repairs, evictions, deposits, etc...",
    )
    send_btn = gr.Button("Send")

    # Second output clears the textbox after each turn.
    send_btn.click(chat_api, inputs=[user_box, chatbot], outputs=[chatbot, user_box])
    user_box.submit(chat_api, inputs=[user_box, chatbot], outputs=[chatbot, user_box])

if __name__ == "__main__":
    demo.launch(share=True)