|
import json |
|
import os |
|
import gradio as gr |
|
from langchain_groq import ChatGroq |
|
from langchain.embeddings import HuggingFaceBgeEmbeddings |
|
from langchain.vectorstores import Chroma |
|
from langchain.chains import RetrievalQA |
|
from langchain.prompts import PromptTemplate |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.document_loaders import PyPDFLoader, DirectoryLoader |
|
from transformers import pipeline |
|
import traceback |
|
|
|
|
|
def load_psychiatrists_data(json_path="psychiatrists_data.json"):
    """Load the India psychiatrist directory, keyed by normalized state name.

    The JSON file is expected to contain a top-level object with an
    ``"India"`` object that maps state names to their entries. State names
    are stripped and lower-cased so lookups can ignore case and surrounding
    whitespace.

    Args:
        json_path: Path to the JSON file. Defaults to
            ``"psychiatrists_data.json"`` relative to the working directory.

    Returns:
        A dict mapping normalized state name -> entry from the file, or an
        empty dict when the file is missing, is not valid JSON, or does not
        have the expected object structure (a message is printed in each case).
    """
    try:
        with open(json_path, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError:
        print(f"❌ Error: {json_path} file not found.")
        return {}
    except json.JSONDecodeError as err:
        print(f"❌ Error: {json_path} is not valid JSON ({err}).")
        return {}

    # Guard against a top-level list/scalar or a null/non-object "India"
    # value, which would otherwise raise AttributeError at import time.
    india = data.get("India", {}) if isinstance(data, dict) else None
    if not isinstance(india, dict):
        print(f"❌ Error: {json_path} has no valid 'India' object.")
        return {}

    return {key.strip().lower(): value for key, value in india.items()}
|
|
|
# Psychiatrist directory keyed by stripped, lower-cased state name, loaded
# once at import time. It is an empty dict when loading fails (e.g. the JSON
# file is missing); get_psychiatrists_by_location then returns [] for every
# state.
doc_data = load_psychiatrists_data()

# Sentiment model used by detect_serious_issue to flag strongly negative
# messages. It is constructed once at import time. NOTE(review): this presumably
# downloads the model weights on first run — confirm caching in deployment.
sentiment_classifier = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
|
|
|
|
|
def initialize_llm(model_name="llama-3.3-70b-versatile", temperature=0):
    """Create the Groq-hosted chat model used by the QA chain.

    Args:
        model_name: Groq model identifier. Defaults to the model the app
            was built against.
        temperature: Sampling temperature; 0 keeps answers as repeatable as
            the backend allows.

    Returns:
        A configured ``ChatGroq`` instance.

    Raises:
        ValueError: if the ``GROQ_API_KEY`` environment variable is unset or
            empty, so the misconfiguration surfaces at startup with an
            app-specific message.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise ValueError(
            "GROQ_API_KEY environment variable is not set; "
            "cannot initialize the Groq LLM."
        )

    return ChatGroq(
        temperature=temperature,
        groq_api_key=api_key,
        model_name=model_name,
    )
|
|
|
|
|
def create_vector_db(
    db_path="./chroma_db",
    data_dir="./data",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
):
    """Open the persisted Chroma vector store, building it from PDFs if needed.

    If ``db_path`` already exists, it is opened as-is. Otherwise every
    ``*.pdf`` matched in ``data_dir`` is loaded and split into chunks of up
    to 500 characters with 50 characters of overlap. The chunks are embedded
    and persisted to ``db_path``.

    Args:
        db_path: Directory of the persisted Chroma database.
        data_dir: Directory scanned (``glob="*.pdf"``) for source documents.
        embedding_model: Model name used both to build and to query the
            index; it must match the model that built an existing database.

    Returns:
        A ``Chroma`` vector store.

    Raises:
        ValueError: if the database has to be built but ``data_dir`` yields
            no text chunks.
    """
    # Both the open and the build path need the same embedding function.
    # NOTE(review): HuggingFaceBgeEmbeddings prepends a BGE-specific
    # instruction to queries; all-MiniLM-L6-v2 is not a BGE model, so plain
    # sentence-transformer embeddings may retrieve better — confirm first.
    embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model)

    if os.path.exists(db_path):
        return Chroma(persist_directory=db_path, embedding_function=embeddings)

    print("📄 Creating new ChromaDB...")
    loader = DirectoryLoader(data_dir, glob="*.pdf", loader_cls=PyPDFLoader)
    documents = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    if not texts:
        # An empty chunk list would fail or yield an empty index. Because the
        # exists() check above reuses any existing db_path, a half-created
        # directory would then be reused on every later start.
        raise ValueError(
            f"No text could be extracted from PDFs in {data_dir!r}; "
            "cannot build the vector database."
        )

    vector_db = Chroma.from_documents(texts, embeddings, persist_directory=db_path)
    vector_db.persist()
    return vector_db
|
|
|
|
|
def setup_qa_chain(vector_db, llm):
    """Wire a retriever and an LLM into a "stuff"-style RetrievalQA chain.

    The retrieved chunks fill the ``{context}`` slot of a fixed prompt that
    frames the model as a compassionate mental health chatbot. The user's
    message fills ``{question}``.

    Args:
        vector_db: Vector store exposing ``as_retriever()``.
        llm: LangChain-compatible language model.

    Returns:
        A ``RetrievalQA`` chain.
    """
    template_text = """You are a compassionate mental health chatbot. Respond thoughtfully to the following question:

{context}

User: {question}

Chatbot: """

    qa_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=template_text,
    )

    return RetrievalQA.from_chain_type(
        llm=llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        chain_type_kwargs={"prompt": qa_prompt},
    )
|
|
|
|
|
# Build the app's shared components once, at import time, in dependency
# order: the Groq LLM, the Chroma vector store (built from the PDFs in ./data
# when ./chroma_db does not exist yet), and the retrieval QA chain that
# chatbot_interface runs for each message.
llm = initialize_llm()
vector_db = create_vector_db()
qa_chain = setup_qa_chain(vector_db, llm)
|
|
|
|
|
def detect_serious_issue(user_message, threshold=0.7):
    """Return True when a message reads as strongly negative.

    This is a coarse heuristic built on a general-purpose sentiment model,
    not a clinical risk assessment. It only checks that the classifier's
    label is ``NEGATIVE`` and its score exceeds ``threshold``.

    Args:
        user_message: Text typed by the user. ``None``, empty, or
            whitespace-only input is treated as not serious, and the model
            is not run for it.
        threshold: Score a NEGATIVE label must strictly exceed to count.
            Defaults to 0.7.

    Returns:
        bool: True if the message is classified as strongly negative.
    """
    if not user_message or not user_message.strip():
        return False

    # The model has a fixed maximum input length (512 tokens for DistilBERT).
    # Without truncation, long messages make the pipeline raise instead of
    # being classified.
    result = sentiment_classifier(user_message, truncation=True)[0]
    return result["label"] == "NEGATIVE" and result["score"] > threshold
|
|
|
|
|
def get_psychiatrists_by_location(state, data=None):
    """Look up the psychiatrists listed for an Indian state.

    Args:
        state: State name as typed by the user. It is matched
            case-insensitively after stripping surrounding whitespace, which
            mirrors how load_psychiatrists_data normalizes keys. ``None`` is
            treated as an empty name.
        data: Optional mapping of normalized state name -> entry. Defaults to
            the module-level ``doc_data`` loaded at startup.

    Returns:
        The entry stored for that state, or ``[]`` when there is no match.
    """
    directory = doc_data if data is None else data
    normalized = (state or "").strip().lower()
    return directory.get(normalized, [])
|
|
|
def _format_doctor_list(doctors):
    """Render doctor records as emoji-prefixed lines; missing fields show a placeholder."""
    return "\n".join(
        f"🏥 {doc.get('name', 'Unknown')}\n"
        f"📍 {doc.get('hospital', 'N/A')}\n"
        f"🧠 {doc.get('specialization', 'N/A')}\n"
        for doc in doctors
    )


def _referral_message(country, state):
    """Build the follow-up text appended to a reply whose message was flagged as serious."""
    if country.strip().lower() != "india":
        return "⚠️ Currently, psychiatrist details are only available for India."

    if not state:
        return "⚠️ Please enter your state so I can suggest psychiatrists near you."

    doctors = get_psychiatrists_by_location(state)
    if not doctors:
        return f"⚠️ Sorry, no specific doctors found for {state}. Please visit a nearby hospital."

    return (
        "🔹 Your mental health is important. If you're struggling, please don't hesitate "
        f"to seek help. Here are some top psychiatrists in {state} who can assist you:\n"
        f"{_format_doctor_list(doctors)}"
    )


def chatbot_interface(user_message, country, state):
    """Gradio handler: answer the user's message and add referrals when it seems serious.

    Args:
        user_message: Free-text message. ``None`` is treated as empty, and
            surrounding whitespace is ignored. Typing ``exit`` (any case)
            ends the conversation politely.
        country: Country typed by the user; ``None``/empty means "Other".
            Psychiatrist referrals are only available for India.
        state: Indian state used to look up psychiatrists; ``None`` is
            treated as empty.

    Returns:
        The text to display. On any unexpected error, the traceback is
        printed server-side and a short error message is returned instead of
        raising into Gradio.
    """
    try:
        user_message = (user_message or "").strip()
        country = country or "Other"
        state = (state or "").strip()

        if user_message.lower() == "exit":
            return "Chatbot: Take care of yourself. Goodbye! ❤️"

        # Avoid a pointless (and billable) LLM round-trip for a blank message.
        if not user_message:
            return "Chatbot: Please share what's on your mind, and I'll do my best to help."

        response = qa_chain.run(user_message)

        if detect_serious_issue(user_message):
            return f"Chatbot: {response}\n\n{_referral_message(country, state)}"

        return f"Chatbot: {response}"

    except Exception as e:
        # Top-level UI boundary: log the full traceback for the operator and
        # return a short message so the Gradio request does not crash.
        error_message = traceback.format_exc()
        print("❌ ERROR DETECTED:\n", error_message)
        return f"⚠️ Error in chatbot: {str(e)}"
|
|
|
|
|
# Extra CSS passed to gr.Interface(css=...): a dark-grey page background,
# white 18px body text, and centred gold <h1> headings.
custom_css = """

body {

background-color: #1e1e1e; /* Dark Grey */

color: white;

font-size: 18px;

}

h1 {

text-align: center;

color: #FFD700; /* Gold */

}

"""
|
|
|
# Gradio UI: a free-text message plus country/state fields (the state is only
# used for Indian psychiatrist referrals), with a single text output. It is
# bound to a module-level name so tools such as `gradio app.py` can find it.
demo = gr.Interface(
    fn=chatbot_interface,
    inputs=[
        gr.Textbox(label="Enter your message"),
        gr.Textbox(label="Country"),
        gr.Textbox(label="State (if in India)"),
    ],
    outputs=gr.Textbox(label="Output"),
    title="Welcome to AI-Powered Chatbot Created by Sandeep",
    # theme="dark" was dropped. It is not a built-in Gradio theme name, so
    # Gradio tried a Hub lookup, warned, and fell back to the default theme
    # anyway. The dark colours are requested through custom_css instead.
    css=custom_css,
)

if __name__ == "__main__":
    # share=True publishes a temporary public link. NOTE(review): users may
    # enter sensitive mental-health details — confirm a public link is intended.
    demo.launch(share=True)
|
|
|
|