# Page header captured together with this file (Space status: "Runtime error").
import pandas as pd | |
import json | |
import gradio as gr | |
import numpy as np | |
import torch | |
from sklearn.ensemble import GradientBoostingClassifier | |
from sklearn.metrics import classification_report | |
from sklearn.preprocessing import LabelEncoder | |
from transformers import AutoTokenizer, AutoModel | |
# Load the train/test splits from local storage. Both CSVs must provide a
# free-text `text` column (embedded below) and a `Disease` label column.
train_df = pd.read_csv("Train_dataset.csv")
test_df = pd.read_csv("Test_dataset.csv")

# Load per-disease metadata, keyed by disease name so predictions can be
# enriched with a single dict lookup.
with open("disease_mapping.json", "r", encoding="utf-8") as f:
    disease_info = {item["Disease"]: item for item in json.load(f)}

# Encode disease names as integer class ids. The encoder is fit on the training
# split only, so it defines the closed set of diseases the model can predict.
le = LabelEncoder()
train_df['label'] = le.fit_transform(train_df['Disease'])

# Drop test rows whose disease never appears in training, because
# LabelEncoder.transform raises on unseen labels. `.copy()` makes test_df an
# independent frame, so the column assignment below does not trigger pandas'
# SettingWithCopyWarning and is guaranteed to land on test_df itself.
test_df = test_df[test_df['Disease'].isin(le.classes_)].copy()
test_df['label'] = le.transform(test_df['Disease'])
# SciBERT (uncased, scientific vocabulary) serves only as a feature extractor
# here. The tokenizer and the encoder must come from the same checkpoint.
SCIBERT_CHECKPOINT = "allenai/scibert_scivocab_uncased"
tokenizer = AutoTokenizer.from_pretrained(SCIBERT_CHECKPOINT)
model = AutoModel.from_pretrained(SCIBERT_CHECKPOINT)
# Function to get [CLS] token embedding
def get_embedding(text):
    """Return SciBERT's final-layer [CLS] vector(s) for ``text`` as a NumPy array.

    ``text`` is passed straight to the tokenizer. A single string yields a 1-D
    vector. A list of strings yields one row per string, except that a
    one-element list is squeezed to 1-D as well. Inputs longer than 128 tokens
    are truncated.
    """
    # padding=True pads only to the longest sequence in this call, which is a
    # no-op for a single string. The former padding='max_length' padded every
    # input to 128 tokens, making the encoder process masked filler each time.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model(**inputs)
    # Sequence position 0 holds the [CLS] token. `.cpu()` keeps `.numpy()`
    # working even if the model is later moved to an accelerator.
    return outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
# Generate embeddings for training and testing
def _embed_texts(texts, batch_size=32):
    """Embed a sequence of strings batch-by-batch and stack into an (n, hidden) array.

    Returns an empty (0, hidden) array for empty input; np.vstack([]) would
    raise instead.
    """
    texts = list(texts)
    if not texts:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    # One model call per batch instead of one per row. get_embedding squeezes a
    # final batch of size 1 to 1-D; np.vstack treats that as a single row.
    return np.vstack([
        get_embedding(texts[start:start + batch_size])
        for start in range(0, len(texts), batch_size)
    ])


print("Generating embeddings for training data...")
X_train = _embed_texts(train_df['text'].tolist())
y_train = train_df['label'].values

print("Generating embeddings for test data...")
# test_df can be empty when every test disease was unseen in training.
X_test = _embed_texts(test_df['text'].tolist())
y_test = test_df['label'].values
# Train a gradient-boosted tree classifier on the SciBERT [CLS] features.
print("Training classifier...")
clf = GradientBoostingClassifier()
clf.fit(X_train, y_train)

# Evaluate on the held-out split that was embedded above. This reports
# per-disease precision, recall and F1, skipped when no usable test rows remain.
if len(y_test):
    y_pred = clf.predict(X_test)
    # Restrict the report to labels that actually occur in y_test or y_pred.
    eval_labels = np.unique(np.concatenate([y_test, y_pred]))
    print("Test-set evaluation:")
    print(classification_report(
        y_test,
        y_pred,
        labels=eval_labels,
        target_names=[str(name) for name in le.inverse_transform(eval_labels)],
        zero_division=0,
    ))
else:
    print("No test samples with diseases seen in training; skipping evaluation.")
# Prediction function
def predict_disease(symptoms, top_k=3):
    """Return the ``top_k`` most probable diseases for a free-text symptom description.

    Args:
        symptoms: Text passed to ``get_embedding``.
        top_k: Number of predictions to return (default 3). Fewer are returned
            when the classifier knows fewer classes.

    Returns:
        A list of dicts, most probable first. Each dict has the keys
        "Disease", "Confidence", "Description", "Severity" and "Precaution".
        "Confidence" is a percentage rounded to 2 decimals. The three metadata
        fields come from ``disease_info`` and fall back to "N/A" when the
        disease or the field is missing.
    """
    emb = get_embedding(symptoms).reshape(1, -1)
    probs = clf.predict_proba(emb)[0]
    # Column j of predict_proba belongs to clf.classes_[j]. Map through it
    # rather than assuming the column index equals the encoded label.
    top_cols = np.argsort(probs)[::-1][:top_k]
    diseases = le.inverse_transform(clf.classes_[top_cols])

    results = []
    for col, disease in zip(top_cols, diseases):
        info = disease_info.get(disease, {})
        results.append({
            "Disease": disease,
            "Confidence": round(float(probs[col]) * 100, 2),
            "Description": info.get("Description", "N/A"),
            "Severity": info.get("Severity", "N/A"),
            "Precaution": info.get("Precaution", "N/A"),
        })
    return results
# Gradio chatbot interface
def chatbot_interface(symptom_text):
    """Gradio callback: render the top disease predictions for ``symptom_text`` as Markdown.

    Returns a short prompt instead of predictions when the input is empty or
    whitespace-only.
    """
    # Do not run the classifier on blank input. It would still return a
    # ranked list of diseases for no symptoms at all.
    if not symptom_text or not symptom_text.strip():
        return "Please describe your symptoms to get a prediction."

    sections = []
    for i, pred in enumerate(predict_disease(symptom_text), 1):
        sections.append(
            f"### Prediction {i}\n"
            f"- **Disease:** {pred['Disease']} ({pred['Confidence']}%)\n"
            f"- **Description:** {pred['Description']}\n"
            f"- **Severity:** {pred['Severity']}\n"
            f"- **Precaution:** {pred['Precaution']}"
        )
    # Sections are separated by a blank line. The final strip() matches the
    # original output, which trimmed trailing whitespace.
    return "\n\n".join(sections).strip()
# Assemble the web UI: one free-text input box, Markdown output produced by
# chatbot_interface. Then start the Gradio server.
demo = gr.Interface(
    fn=chatbot_interface,
    inputs=gr.Textbox(label="Enter your symptoms"),
    outputs=gr.Markdown(),
    title="SciBERT Medical Chatbot",
    description="AI Medical Assistant that predicts diseases based on symptoms using SciBERT embeddings.",
)
demo.launch()