import json
import os
import gradio as gr
from langchain_groq import ChatGroq
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from transformers import pipeline
import traceback

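# Mental-health support chatbot: answers questions through a RetrievalQA chain over
# PDF documents in ./data and, when a message shows strong negative sentiment,
# appends psychiatrist suggestions for Indian states from psychiatrists_data.json.
# The GROQ_API_KEY environment variable must be set before launching.
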
# Load psychiatrist details from JSON file, keyed by lower-cased state name
def load_psychiatrists_data():
    json_path = "psychiatrists_data.json"  # Adjusted for local use
    try:
        with open(json_path, "r", encoding="utf-8") as file:
            data = json.load(file)
        return {key.strip().lower(): value for key, value in data.get("India", {}).items()}
    except FileNotFoundError:
        print("❌ Error: psychiatrists_data.json file not found.")
        return {}
    except json.JSONDecodeError as e:
        print(f"❌ Error: psychiatrists_data.json is not valid JSON: {e}")
        return {}

doc_data = load_psychiatrists_data()
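# doc_data maps lower-cased Indian state names to lists of psychiatrist records
# with "name", "hospital", and "specialization" fields (used in chatbot_interface).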

# Initialize sentiment analysis model
sentiment_classifier = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

# Initialize LLM
def initialize_llm():
    return ChatGroq(
        temperature=0,
        groq_api_key=os.getenv("GROQ_API_KEY"), 
        model_name="llama-3.3-70b-versatile"
    )

# Create or Load ChromaDB
def create_vector_db():
    db_path = "./chroma_db"
    embeddings = HuggingFaceBgeEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # Reuse the persisted index if it already exists
    if os.path.exists(db_path):
        print("📂 Loading existing ChromaDB...")
        return Chroma(persist_directory=db_path, embedding_function=embeddings)

    print("📄 Creating new ChromaDB...")
    loader = DirectoryLoader("./data", glob="*.pdf", loader_cls=PyPDFLoader)  # Adjusted path
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    vector_db = Chroma.from_documents(texts, embeddings, persist_directory=db_path)
    vector_db.persist()
    return vector_db

# Setup QA Chain
def setup_qa_chain(vector_db, llm):
    retriever = vector_db.as_retriever()
    prompt_template = """You are a compassionate mental health chatbot. Respond thoughtfully to the following question:
        {context}
        User: {question}
        Chatbot: """

    PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": PROMPT}
    )

# Initialize LLM and QA Chain before using them
llm = initialize_llm()
vector_db = create_vector_db()
qa_chain = setup_qa_chain(vector_db, llm)
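# Built once at startup; the first run also creates the Chroma index from the
# PDFs in ./data, which may take some time while the documents are embedded.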

# Detect serious issues: flag messages the sentiment model classifies as strongly negative
def detect_serious_issue(user_message):
    result = sentiment_classifier(user_message)[0]
    return result["label"] == "NEGATIVE" and result["score"] > 0.7

# Fetch Top Psychiatrists Based on Location
def get_psychiatrists_by_location(state):
    state = state.strip().lower()
    return doc_data.get(state, [])

def chatbot_interface(user_message, country, state):
    try:
        # Ensure default values if inputs are None
        user_message = user_message or ""  # Default empty string if None
        country = country or "Other"  # Default to "Other"
        state = state or ""  # Default empty string

        if user_message.strip().lower() == "exit":
            return "Chatbot: Take care of yourself. Goodbye! ❤️"

        # Generate chatbot response
        response = qa_chain.run(user_message)

        # Check for serious issues
        if detect_serious_issue(user_message):
            if country.lower() == "india":
                doctors = get_psychiatrists_by_location(state)
                if doctors:
                    doc_info = "\n".join([f"🏥 {doc['name']}\n📍 {doc['hospital']}\n🧠 {doc['specialization']}\n" for doc in doctors])
                    return f"Chatbot: {response}\n\n🔹 Your mental health is important. If you're struggling, please don't hesitate to seek help. Here are some top psychiatrists in {state} who can assist you:\n{doc_info}"
                else:
                    return f"Chatbot: {response}\n\n⚠️ Sorry, no specific doctors found for {state}. Please visit a nearby hospital."
            else:
                return f"Chatbot: {response}\n\n⚠️ Currently, psychiatrist details are only available for India."

        return f"Chatbot: {response}"

    except Exception as e:
        error_message = traceback.format_exc()
        print("❌ ERROR DETECTED:\n", error_message)
        return f"⚠️ Error in chatbot: {str(e)}"


custom_css = """
body {
    background-color: #1e1e1e; /* Dark Grey */
    color: white;
    font-size: 18px;
}
h1 {
    text-align: center;
    color: #FFD700; /* Gold */
}
"""

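# Build and launch the Gradio UI; share=True also exposes a temporary public link.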
gr.Interface(
    fn=chatbot_interface,
    inputs=[
        gr.Textbox(label="Enter your message"),
        gr.Textbox(label="Country"),
        gr.Textbox(label="State (if in India)")
    ],
    outputs=gr.Textbox(label="Output"),
    title="Welcome to AI-Powered Chatbot Created by Sandeep",
    theme="dark",
    css=custom_css
).launch(share=True)