# sandeepgiri908's picture
# Update app.py
# 987af52 verified
import json
import os
import gradio as gr
from langchain_groq import ChatGroq
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from transformers import pipeline
import traceback
def load_psychiatrists_data(json_path="psychiatrists_data.json"):
    """Load the India psychiatrist directory from a JSON file.

    The file is expected to hold a top-level object with an ``"India"`` key
    whose value maps state names to that state's psychiatrist records.
    State names are normalised (surrounding whitespace stripped, lower-cased)
    so later lookups can be case-insensitive.

    Args:
        json_path: Path to the JSON file. Defaults to
            ``"psychiatrists_data.json"`` in the current working directory.

    Returns:
        A dict mapping normalised state name -> the value stored for that
        state in the file. An empty dict is returned (and an error printed)
        when the file is missing, cannot be decoded/parsed, or does not have
        the expected top-level structure, so the app can still start.
    """
    try:
        with open(json_path, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError:
        print(f"❌ Error: {json_path} file not found.")
        return {}
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses; a corrupt file should not crash the app at import time.
        print(f"❌ Error: {json_path} could not be parsed ({exc}).")
        return {}

    india = data.get("India", {}) if isinstance(data, dict) else None
    if not isinstance(india, dict):
        print(f"❌ Error: {json_path} has no 'India' object mapping states to psychiatrists.")
        return {}
    return {key.strip().lower(): value for key, value in india.items()}
# Psychiatrist directory keyed by normalised (stripped, lower-cased) state name.
doc_data = load_psychiatrists_data()

# Hugging Face sentiment model used to flag strongly negative user messages.
sentiment_classifier = pipeline(
    task="sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
def initialize_llm():
    """Build the Groq-hosted chat model that answers user questions.

    The API key is read from the ``GROQ_API_KEY`` environment variable and
    the model runs with temperature 0.

    Returns:
        A configured ``ChatGroq`` instance for ``llama-3.3-70b-versatile``.
    """
    settings = {
        "temperature": 0,
        "groq_api_key": os.getenv("GROQ_API_KEY"),
        "model_name": "llama-3.3-70b-versatile",
    }
    return ChatGroq(**settings)
def create_vector_db(db_path="./chroma_db", data_dir="./data", chunk_size=500, chunk_overlap=50):
    """Open the persisted Chroma vector store, building it from PDFs if needed.

    If ``db_path`` is a non-empty directory it is reopened as-is. Otherwise
    the files matching ``*.pdf`` in ``data_dir`` are loaded, split into
    overlapping chunks, embedded, and persisted to ``db_path``.

    Args:
        db_path: Directory where the Chroma index is persisted.
        data_dir: Directory scanned with the ``*.pdf`` glob for source documents.
        chunk_size: Maximum characters per chunk for the text splitter.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        A ``Chroma`` vector store backed by ``db_path``.
    """
    # One embedding model for both paths, so stored vectors and query vectors
    # always come from the same configuration.
    # NOTE(review): all-MiniLM-L6-v2 is not a BGE model, and
    # HuggingFaceBgeEmbeddings appears to prepend a BGE query instruction at
    # query time. Consider HuggingFaceEmbeddings, but switching requires
    # rebuilding the persisted index.
    embeddings = HuggingFaceBgeEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # An empty directory (e.g. left behind by an interrupted build) is not a
    # usable index, so only reuse db_path when it actually contains files.
    if os.path.isdir(db_path) and os.listdir(db_path):
        return Chroma(persist_directory=db_path, embedding_function=embeddings)

    print("📄 Creating new ChromaDB...")
    loader = DirectoryLoader(data_dir, glob="*.pdf", loader_cls=PyPDFLoader)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_documents(documents)
    if not texts:
        # Answers would get no retrieved context; make that visible in the logs.
        print(f"⚠️ No PDF text found in {data_dir!r}; the vector store will be empty.")

    vector_db = Chroma.from_documents(texts, embeddings, persist_directory=db_path)
    vector_db.persist()
    return vector_db
def setup_qa_chain(vector_db, llm):
    """Wire a retriever and an LLM into a "stuff" RetrievalQA chain.

    Retrieved chunks fill the ``{context}`` slot and the user's message fills
    ``{question}`` in a compassionate mental-health-chatbot prompt.

    Args:
        vector_db: Vector store providing ``as_retriever()``.
        llm: Language model passed to the chain.

    Returns:
        The configured ``RetrievalQA`` chain.
    """
    qa_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""You are a compassionate mental health chatbot. Respond thoughtfully to the following question:
{context}
User: {question}
Chatbot: """,
    )
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_db.as_retriever(),
        chain_type_kwargs={"prompt": qa_prompt},
    )
# Build the LLM, the vector store and the QA chain once, at import time, so
# every chatbot_interface call reuses the same qa_chain.
llm, vector_db = initialize_llm(), create_vector_db()
qa_chain = setup_qa_chain(vector_db, llm)
def detect_serious_issue(user_message, threshold=0.7, classifier=None):
    """Return True when a message is classified as strongly negative.

    Args:
        user_message: Text typed by the user. Empty or whitespace-only input
            is never treated as serious, and no model call is made for it.
        threshold: Minimum NEGATIVE score (exclusive) needed to flag the
            message. Defaults to 0.7.
        classifier: Callable with the Hugging Face pipeline calling convention
            (``classifier(text, **kwargs) -> [{"label": ..., "score": ...}]``).
            Defaults to the module-level ``sentiment_classifier``.

    Returns:
        True if the top label is ``"NEGATIVE"`` with a score above
        ``threshold``, otherwise False.
    """
    if not user_message or not user_message.strip():
        return False
    if classifier is None:
        classifier = sentiment_classifier
    # truncation=True keeps long messages within the model's maximum input
    # length instead of raising, which would otherwise turn the whole reply
    # into an error message.
    result = classifier(user_message, truncation=True)[0]
    return result["label"] == "NEGATIVE" and result["score"] > threshold
def get_psychiatrists_by_location(state, data=None):
    """Return the psychiatrist records listed for an Indian state.

    Args:
        state: State name as typed by the user. It is matched
            case-insensitively, ignoring surrounding whitespace. ``None`` is
            treated as an empty string.
        data: Mapping of normalised (stripped, lower-cased) state name ->
            records. Defaults to the module-level ``doc_data``.

    Returns:
        The records stored for the state, or an empty list if it is unknown.
    """
    if data is None:
        data = doc_data
    key = (state or "").strip().lower()
    return data.get(key, [])
def _format_psychiatrist(doc):
    """Render one psychiatrist record as three emoji-labelled lines."""
    # .get() means one incomplete JSON record no longer turns the whole reply
    # into an error message.
    return (
        f"🏥 {doc.get('name', 'N/A')}\n"
        f"📍 {doc.get('hospital', 'N/A')}\n"
        f"🧠 {doc.get('specialization', 'N/A')}\n"
    )


def chatbot_interface(user_message, country, state):
    """Handle one Gradio submission and return the chatbot's reply text.

    The message is answered by the module-level RetrievalQA chain. If the
    sentiment check flags the message as strongly negative, the reply is
    extended with psychiatrist suggestions for Indian users, or with a note
    that suggestions are only available for India.

    Args:
        user_message: Text typed by the user. ``"exit"`` (any case,
            surrounding spaces ignored) ends the chat.
        country: Country typed by the user. Only ``"India"`` (any case) gets
            psychiatrist suggestions. ``None`` is treated as ``"Other"``.
        state: Indian state used to look up psychiatrists.

    Returns:
        The reply string shown in the output box. Any exception is printed to
        stdout with its traceback and reported to the user as a short error.
    """
    try:
        # Gradio may pass None for untouched fields. Trimming stray spaces
        # keeps input such as "India " or " exit" from defeating the exact
        # comparisons below.
        user_message = (user_message or "").strip()
        country = (country or "Other").strip()
        state = (state or "").strip()

        if not user_message:
            return "Chatbot: Please type a message so I can help."
        if user_message.lower() == "exit":
            return "Chatbot: Take care of yourself. Goodbye! ❤️"

        # Generate the chatbot response first; psychiatrist info is appended.
        response = qa_chain.run(user_message)

        if not detect_serious_issue(user_message):
            return f"Chatbot: {response}"

        if country.lower() != "india":
            return (
                f"Chatbot: {response}\n\n"
                "⚠️ Currently, psychiatrist details are only available for India."
            )

        if not state:
            # Without a state there is nothing to look up, so avoid the
            # malformed "no specific doctors found for ." message.
            return (
                f"Chatbot: {response}\n\n"
                "⚠️ Please enter your state to see psychiatrists near you, "
                "or visit a nearby hospital."
            )

        doctors = get_psychiatrists_by_location(state)
        if not doctors:
            return (
                f"Chatbot: {response}\n\n"
                f"⚠️ Sorry, no specific doctors found for {state}. Please visit a nearby hospital."
            )

        doc_info = "\n".join(_format_psychiatrist(doc) for doc in doctors)
        return (
            f"Chatbot: {response}\n\n"
            "🔹 Your mental health is important. If you're struggling, "
            "please don't hesitate to seek help. "
            f"Here are some top psychiatrists in {state} who can assist you:\n{doc_info}"
        )
    except Exception as e:
        # Top-level UI boundary: log the full traceback for the operator and
        # return a short message instead of crashing the Gradio handler.
        error_message = traceback.format_exc()
        print("❌ ERROR DETECTED:\n", error_message)
        return f"⚠️ Error in chatbot: {str(e)}"
# Extra CSS handed to gr.Interface(css=...) below: a dark grey page background
# with larger white text, and centred gold <h1> headings.
custom_css = """
body {
background-color: #1e1e1e; /* Dark Grey */
color: white;
font-size: 18px;
}
h1 {
text-align: center;
color: #FFD700; /* Gold */
}
"""
# Gradio UI: a message box plus country/state fields feed chatbot_interface,
# and its reply is shown in a single text output.
input_boxes = [
    gr.Textbox(label="Enter your message"),
    gr.Textbox(label="Country"),
    gr.Textbox(label="State (if in India)"),
]
demo = gr.Interface(
    fn=chatbot_interface,
    inputs=input_boxes,
    outputs=gr.Textbox(label="Output"),
    title="Welcome to AI-Powered Chatbot Created by Sandeep",
    theme="dark",
    css=custom_css,
)
demo.launch(share=True)