BillyZ1129 commited on
Commit
79bcb1b
·
verified ·
1 Parent(s): a1de76d

Upload 7 files

Browse files
Files changed (7) hide show
  1. Full_Patient_Risk_Prediction_Dataset.csv +0 -0
  2. README.md +119 -19
  3. app.py +598 -0
  4. models.py +549 -0
  5. requirements.txt +12 -3
  6. style.css +245 -0
  7. utils.py +170 -0
Full_Patient_Risk_Prediction_Dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,19 +1,119 @@
1
- ---
2
- title: FinalProject
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: deep learning final project
12
- ---
13
-
14
- # Welcome to Streamlit!
15
-
16
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
17
-
18
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
19
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AI Medical Consultation System
2
+
3
+ An intelligent medical consultation system built with Streamlit that uses multiple AI models to analyze patient symptoms, assess risk levels, and provide personalized medical recommendations.
4
+
5
+ ## 🚀 Features
6
+
7
+ - **Natural Language Symptom Description**: Patients describe their symptoms in natural language
8
+ - **Symptom Extraction**: Automatically extracts key symptoms and duration information using BioBERT
9
+ - **Risk Assessment**: Classifies the risk level (Low, Medium, High) using PubMedBERT
10
+ - **Personalized Recommendations**: Generates tailored medical recommendations using a fine-tuned T5 model
11
+ - **User-Friendly Interface**: Clean, intuitive UI with interactive visualizations
12
+ - **Consultation History**: Save and review past consultations
13
+ - **Responsive Design**: Works on desktop and mobile devices
14
+
15
+ ## 📋 System Components
16
+
17
+ The system consists of three AI models working in a pipeline:
18
+
19
+ 1. **Symptom Extraction Model**: [dmis-lab/biobert-v1.1](https://huggingface.co/dmis-lab/biobert-v1.1)
20
+ - Identifies symptoms and their duration in the patient's description
21
+ - Implemented as a Named Entity Recognition (NER) task
22
+
23
+ 2. **Risk Classification Model**: [microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract](https://huggingface.co/microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract)
24
+ - Classifies the patient's condition into Low, Medium, or High risk
25
+ - Fine-tuned for medical risk assessment
26
+
27
+ 3. **Recommendation Generation Model**: Fine-tuned T5-small
28
+ - Generates personalized medical recommendations
29
+ - Fine-tuned on a dataset of medical advice and recommendations
30
+
31
+ ## 🛠️ Installation
32
+
33
+ 1. Clone this repository:
34
+ ```bash
35
+ git clone <repository-url>
36
+ cd medical-consultation-system
37
+ ```
38
+
39
+ 2. Install the required packages:
40
+ ```bash
41
+ pip install -r requirements.txt
42
+ ```
43
+
44
+ 3. Download the fine-tuned T5 model (if not included):
45
+ ```bash
46
+ # Instructions for downloading or fine-tuning the T5 model would go here
47
+ # For example:
48
+ # python download_models.py
49
+ ```
50
+
51
+ ## 🚀 Usage
52
+
53
+ 1. Run the Streamlit app:
54
+ ```bash
55
+ streamlit run app.py
56
+ ```
57
+
58
+ 2. Open your web browser and navigate to the URL displayed in your terminal (typically http://localhost:8501)
59
+
60
+ 3. Enter your symptoms in natural language in the text area
61
+
62
+ 4. Click the "Analyze Symptoms" button to process your input
63
+
64
+ 5. Review the results in the various tabs:
65
+ - **Overview**: Summary of symptoms, risk level, and recommendations
66
+ - **Symptoms Analysis**: Detailed analysis of extracted symptoms and duration
67
+ - **Risk Assessment**: Risk level with confidence and explanation
68
+ - **Recommendations**: Detailed medical recommendations and department suggestions
69
+
70
+ ## 📊 Example
71
+
72
+ Input:
73
+ ```
74
+ I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous.
75
+ ```
76
+
77
+ Output:
78
+ - **Extracted Symptoms**: Headaches, dizziness, nauseous
79
+ - **Duration**: 2 weeks
80
+ - **Risk Level**: Medium
81
+ - **Recommendation**: Personalized guidance on seeking medical attention and home care
82
+
83
+ ## 📁 Project Structure
84
+
85
+ ```
86
+ medical-consultation-system/
87
+ ├── app.py # Main Streamlit application
88
+ ├── models.py # Model loading and inference code
89
+ ├── utils.py # Helper functions and utilities
90
+ ├── style.css # Custom CSS styling
91
+ ├── requirements.txt # Package dependencies
92
+ ├── README.md # Project documentation
93
+ └── consultation_history/ # Stored consultation records (created on first use)
94
+ ```
95
+
96
+ ## ⚠️ Limitations and Disclaimer
97
+
98
+ - This system is for **informational purposes only** and is not a substitute for professional medical advice, diagnosis, or treatment.
99
+ - The AI models may not capture all symptoms or correctly assess all conditions.
100
+ - Risk assessments and recommendations are based on general patterns and may not be accurate for specific individual cases.
101
+ - Always consult with qualified healthcare providers for medical concerns.
102
+
103
+ ## 🔧 Customization
104
+
105
+ You can customize the system by:
106
+ - Fine-tuning the models on different or additional datasets
107
+ - Modifying the UI in app.py
108
+ - Adjusting the CSS styling in style.css
109
+ - Adding new features like multilingual support or additional visualization options
110
+
111
+ ## 📝 License
112
+
113
+ This project is licensed under the MIT License - see the LICENSE file for details.
114
+
115
+ ## 🙏 Acknowledgements
116
+
117
+ - [Hugging Face](https://huggingface.co/) for providing access to pre-trained models
118
+ - [Streamlit](https://streamlit.io/) for the web application framework
119
+ - [Plotly](https://plotly.com/) for interactive visualizations
app.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import time
4
+ import torch
5
+ import os
6
+ from models import MedicalConsultationPipeline
7
+ from utils import (
8
+ highlight_text_with_entities,
9
+ format_duration,
10
+ create_risk_gauge,
11
+ create_risk_probability_chart,
12
+ save_consultation,
13
+ load_consultation_history,
14
+ init_session_state,
15
+ RISK_COLORS
16
+ )
17
+
18
+ # Page configuration
19
+ st.set_page_config(
20
+ page_title="AI Medical Consultation",
21
+ page_icon="🩺",
22
+ layout="wide",
23
+ initial_sidebar_state="expanded"
24
+ )
25
+
26
+ # Custom CSS
27
+ def load_css():
28
+ with open("style.css", "r") as f:
29
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
30
+
31
+ # 检查本地是否有fine-tuned的T5模型
32
+ def find_fine_tuned_model():
33
+ possible_local_paths = [
34
+ "./finetuned_t5-small", # 添加用户提供的微调模型路径
35
+ "./t5-small-medical-recommendation",
36
+ "./models/t5-small-medical-recommendation",
37
+ "./fine_tuned_models/t5-small",
38
+ "./output",
39
+ "./fine_tuning_output"
40
+ ]
41
+
42
+ for path in possible_local_paths:
43
+ if os.path.exists(path):
44
+ return path
45
+
46
+ return "t5-small" # 如果没有找到,返回基础模型
47
+
48
+ # Initialize session state
49
+ init_session_state()
50
+
51
+ # Apply custom CSS
52
+ load_css()
53
+
54
+ # Sidebar for settings and history
55
+ with st.sidebar:
56
+ st.image("https://img.icons8.com/fluency/96/000000/hospital-3.png", width=80)
57
+ st.title("AI Medical Assistant")
58
+
59
+ st.markdown("---")
60
+ with st.expander("⚙️ Settings", expanded=False):
61
+ # Model settings
62
+ st.subheader("Model Settings")
63
+ symptom_model = st.selectbox(
64
+ "Symptom Extraction Model",
65
+ ["dmis-lab/biobert-v1.1"],
66
+ index=0,
67
+ disabled=st.session_state.loaded_models # Disable after models are loaded
68
+ )
69
+
70
+ risk_model = st.selectbox(
71
+ "Risk Classification Model",
72
+ ["microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"],
73
+ index=0,
74
+ disabled=st.session_state.loaded_models # Disable after models are loaded
75
+ )
76
+
77
+ # 查找可用的t5模型
78
+ available_t5_model = find_fine_tuned_model()
79
+ recommendation_model_options = []
80
+
81
+ # 总是添加基础模型
82
+ recommendation_model_options.append("t5-small (base model)")
83
+
84
+ # 如果找到了fine-tuned模型,添加到选项中
85
+ if available_t5_model != "t5-small":
86
+ recommendation_model_options.insert(0, f"{available_t5_model} (fine-tuned)")
87
+
88
+ recommendation_model_label = st.selectbox(
89
+ "Recommendation Model",
90
+ recommendation_model_options,
91
+ index=0,
92
+ disabled=st.session_state.loaded_models # Disable after models are loaded
93
+ )
94
+
95
+ # 提取实际的模型路径
96
+ if "(fine-tuned)" in recommendation_model_label:
97
+ recommendation_model = available_t5_model
98
+ else:
99
+ recommendation_model = "t5-small"
100
+
101
+ # Device selection
102
+ device = st.radio(
103
+ "Compute Device",
104
+ ["CPU", "GPU (if available)"],
105
+ index=1 if torch.cuda.is_available() else 0,
106
+ disabled=st.session_state.loaded_models # Disable after models are loaded
107
+ )
108
+ device = "cuda" if device == "GPU (if available)" and torch.cuda.is_available() else "cpu"
109
+
110
+ if st.session_state.loaded_models:
111
+ st.info("注意:设置已锁定,因为模型已加载。要更改设置,请刷新页面。")
112
+
113
+ # Consultation history section
114
+ st.markdown("---")
115
+ st.subheader("📋 Consultation History")
116
+
117
+ # Load consultation history
118
+ if st.button("Refresh History"):
119
+ st.session_state.consultation_history = load_consultation_history()
120
+ st.success("History refreshed!")
121
+
122
+ # If history is not already loaded, load it
123
+ if not st.session_state.consultation_history:
124
+ st.session_state.consultation_history = load_consultation_history()
125
+
126
+ # Display history items
127
+ if not st.session_state.consultation_history:
128
+ st.info("No previous consultations found.")
129
+ else:
130
+ for i, consultation in enumerate(st.session_state.consultation_history[:10]): # Show only the 10 most recent
131
+ timestamp = pd.to_datetime(consultation.get("timestamp", "")).strftime("%Y-%m-%d %H:%M")
132
+ risk_level = consultation.get("risk", {}).get("risk_level", "Unknown")
133
+ risk_color = RISK_COLORS.get(risk_level, "#6c757d")
134
+
135
+ # Create a clickable history item
136
+ history_item = f"""
137
+ <div class='history-item' onclick=''>
138
+ <strong>Patient Input:</strong> {consultation.get('input_text', '')[:50]}...<br>
139
+ <strong>Time:</strong> {timestamp}<br>
140
+ <strong>Risk Level:</strong> <span style='color:{risk_color};'>{risk_level}</span>
141
+ </div>
142
+ """
143
+ clicked = st.markdown(history_item, unsafe_allow_html=True)
144
+
145
+ # If clicked, set this consultation as the current result
146
+ if clicked:
147
+ st.session_state.current_result = consultation
148
+
149
+ # Main app layout
150
+ st.markdown("<h1 class='main-header'>AI-Powered Medical Consultation</h1>", unsafe_allow_html=True)
151
+
152
+ # Introduction row
153
+ col1, col2 = st.columns([2, 1])
154
+ with col1:
155
+ st.markdown("""
156
+ <div class="card">
157
+ <h2 class="card-header">How it Works</h2>
158
+ <p>This AI-powered medical consultation system helps you understand your symptoms and provides guidance on next steps.</p>
159
+ <p><strong>Simply describe your symptoms</strong> in natural language and the system will:</p>
160
+ <ol>
161
+ <li>Extract key symptoms and duration information</li>
162
+ <li>Assess your risk level</li>
163
+ <li>Generate personalized medical recommendations</li>
164
+ </ol>
165
+ <p><em>Note: This system is for informational purposes only and does not replace professional medical advice.</em></p>
166
+ </div>
167
+ """, unsafe_allow_html=True)
168
+
169
+ with col2:
170
+ st.markdown("""
171
+ <div class="card">
172
+ <h2 class="card-header">Example Inputs</h2>
173
+ <ul>
174
+ <li>"I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous."</li>
175
+ <li>"My child has had a high fever of 39°C since yesterday and is coughing a lot."</li>
176
+ <li>"I've noticed a persistent rash on my arm for the past 3 days, it's itchy and slightly swollen."</li>
177
+ </ul>
178
+ </div>
179
+ """, unsafe_allow_html=True)
180
+
181
+ # 显示当前使用的模型信息
182
+ model_info = f"""
183
+ <div class="card">
184
+ <h2 class="card-header">当前模型配置</h2>
185
+ <ul>
186
+ <li><strong>症状抽取模型:</strong> {symptom_model}</li>
187
+ <li><strong>风险分类模型:</strong> {risk_model}</li>
188
+ <li><strong>推荐生成模型:</strong> {recommendation_model} {"(微调模型)" if recommendation_model != "t5-small" else "(基础模型)"}</li>
189
+ <li><strong>计算设备:</strong> {device.upper()}</li>
190
+ </ul>
191
+ </div>
192
+ """
193
+ st.markdown(model_info, unsafe_allow_html=True)
194
+
195
+ # Load models on first run or when settings change
196
+ @st.cache_resource
197
+ def load_pipeline(_symptom_model, _risk_model, _recommendation_model, _device):
198
+ return MedicalConsultationPipeline(
199
+ symptom_model=_symptom_model,
200
+ risk_model=_risk_model,
201
+ recommendation_model=_recommendation_model,
202
+ device=_device
203
+ )
204
+
205
+ # Only load models if they haven't been loaded yet
206
+ if not st.session_state.loaded_models:
207
+ try:
208
+ with st.spinner("Loading AI models... This may take a minute..."):
209
+ pipeline = load_pipeline(symptom_model, risk_model, recommendation_model, device)
210
+ st.session_state.pipeline = pipeline
211
+ st.session_state.loaded_models = True
212
+ st.success("✅ Models loaded successfully!")
213
+ except Exception as e:
214
+ st.error(f"Error loading models: {str(e)}")
215
+ else:
216
+ pipeline = st.session_state.pipeline
217
+
218
+ # Input section
219
+ st.markdown("<h2 class='subheader'>Describe Your Symptoms</h2>", unsafe_allow_html=True)
220
+
221
+ # Text input for patient description
222
+ patient_input = st.text_area(
223
+ "Please describe your symptoms, including when they started and any other relevant information:",
224
+ height=150,
225
+ placeholder="Example: I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous."
226
+ )
227
+
228
+ # Process button
229
+ col1, col2, col3 = st.columns([1, 1, 1])
230
+ with col2:
231
+ process_button = st.button("Analyze Symptoms", type="primary", use_container_width=True)
232
+
233
+ # Handle processing
234
+ if process_button and patient_input and not st.session_state.is_processing:
235
+ st.session_state.is_processing = True
236
+
237
+ # Process the input
238
+ with st.spinner("Analyzing your symptoms..."):
239
+ try:
240
+ # Process through pipeline
241
+ start_time = time.time()
242
+ result = pipeline.process(patient_input)
243
+ elapsed_time = time.time() - start_time
244
+
245
+ # Save result to session state
246
+ st.session_state.current_result = result
247
+
248
+ # Save consultation to history
249
+ save_consultation(result)
250
+
251
+ # Success message
252
+ st.success(f"Analysis completed in {elapsed_time:.2f} seconds!")
253
+ except Exception as e:
254
+ st.error(f"Error processing your input: {str(e)}")
255
+
256
+ st.session_state.is_processing = False
257
+
258
+ # Results section - show if there's a current result
259
+ if st.session_state.current_result:
260
+ result = st.session_state.current_result
261
+
262
+ st.markdown("<h2 class='subheader'>Consultation Results</h2>", unsafe_allow_html=True)
263
+
264
+ # Create tabs for different sections of the results
265
+ tabs = st.tabs(["Overview", "Symptoms Analysis", "Risk Assessment", "Recommendations"])
266
+
267
+ # Overview tab - summary of all results
268
+ with tabs[0]:
269
+ col1, col2 = st.columns([3, 2])
270
+
271
+ with col1:
272
+ st.markdown("""
273
+ <div class="card">
274
+ <h3 class="card-header">Patient Description</h3>
275
+ """, unsafe_allow_html=True)
276
+
277
+ # Highlight symptoms and duration in the text
278
+ highlighted_text = highlight_text_with_entities(
279
+ result.get("input_text", ""),
280
+ result.get("extraction", {}).get("symptoms", [])
281
+ )
282
+ st.markdown(f"<p>{highlighted_text}</p>", unsafe_allow_html=True)
283
+
284
+ st.markdown("</div>", unsafe_allow_html=True)
285
+
286
+ # Recommendations card
287
+ st.markdown("""
288
+ <div class="card">
289
+ <h3 class="card-header">Medical Recommendations</h3>
290
+ <div class="recommendation-container">
291
+ """, unsafe_allow_html=True)
292
+
293
+ recommendation = result.get("recommendation", "No recommendations available.")
294
+ st.markdown(f"<p>{recommendation}</p>", unsafe_allow_html=True)
295
+
296
+ st.markdown("""
297
+ </div>
298
+ <p><em>Note: This is AI-generated guidance and should not replace professional medical advice.</em></p>
299
+ </div>
300
+ """, unsafe_allow_html=True)
301
+
302
+ with col2:
303
+ # Risk level card
304
+ risk_level = result.get("risk", {}).get("risk_level", "Unknown")
305
+ confidence = result.get("risk", {}).get("confidence", 0.0)
306
+
307
+ st.markdown(f"""
308
+ <div class="card">
309
+ <h3 class="card-header">Risk Assessment</h3>
310
+ <div style="text-align: center;">
311
+ <span class="risk-{risk_level.lower()}" style="font-size: 1.8rem;">{risk_level}</span>
312
+ <p>Confidence: {confidence:.1%}</p>
313
+ </div>
314
+ """, unsafe_allow_html=True)
315
+
316
+ # Add risk gauge
317
+ risk_gauge = create_risk_gauge(risk_level, confidence)
318
+ st.plotly_chart(risk_gauge, use_container_width=True, key="overview_risk_gauge")
319
+
320
+ st.markdown("</div>", unsafe_allow_html=True)
321
+
322
+ # Extracted symptoms summary
323
+ st.markdown("""
324
+ <div class="card">
325
+ <h3 class="card-header">Key Findings</h3>
326
+ """, unsafe_allow_html=True)
327
+
328
+ symptoms = result.get("extraction", {}).get("symptoms", [])
329
+ duration = result.get("extraction", {}).get("duration", [])
330
+
331
+ if symptoms:
332
+ st.markdown("<strong>Identified Symptoms:</strong>", unsafe_allow_html=True)
333
+ for symptom in symptoms:
334
+ st.markdown(f"• {symptom['text']} ({symptom['score']:.1%} confidence)", unsafe_allow_html=True)
335
+ else:
336
+ st.info("No specific symptoms identified")
337
+
338
+ st.markdown("<br><strong>Duration Information:</strong>", unsafe_allow_html=True)
339
+ st.markdown(f"<p>{format_duration(duration)}</p>", unsafe_allow_html=True)
340
+
341
+ st.markdown("</div>", unsafe_allow_html=True)
342
+
343
+ # Symptoms Analysis tab
344
+ with tabs[1]:
345
+ st.markdown("""
346
+ <div class="card">
347
+ <h3 class="card-header">Detailed Symptom Analysis</h3>
348
+ """, unsafe_allow_html=True)
349
+
350
+ symptoms = result.get("extraction", {}).get("symptoms", [])
351
+
352
+ if symptoms:
353
+ # Create a DataFrame for symptoms
354
+ symptom_df = pd.DataFrame([
355
+ {
356
+ "Symptom": s["text"],
357
+ "Confidence": s["score"],
358
+ "Start Position": s["start"],
359
+ "End Position": s["end"]
360
+ } for s in symptoms
361
+ ])
362
+
363
+ # Sort by confidence
364
+ symptom_df = symptom_df.sort_values("Confidence", ascending=False)
365
+
366
+ # Display DataFrame
367
+ st.dataframe(symptom_df, use_container_width=True)
368
+
369
+ # Bar chart of symptoms by confidence
370
+ if len(symptoms) > 1:
371
+ st.markdown("<h4>Symptom Confidence Scores</h4>", unsafe_allow_html=True)
372
+ chart_data = symptom_df[["Symptom", "Confidence"]].set_index("Symptom")
373
+ st.bar_chart(chart_data, use_container_width=True)
374
+ else:
375
+ st.info("No specific symptoms were detected in the input text.")
376
+
377
+ st.markdown("</div>", unsafe_allow_html=True)
378
+
379
+ # Duration information card
380
+ st.markdown("""
381
+ <div class="card">
382
+ <h3 class="card-header">Duration Analysis</h3>
383
+ """, unsafe_allow_html=True)
384
+
385
+ duration = result.get("extraction", {}).get("duration", [])
386
+
387
+ if duration:
388
+ # Create a DataFrame for duration information
389
+ duration_df = pd.DataFrame([
390
+ {
391
+ "Duration": d["text"],
392
+ "Start Position": d["start"],
393
+ "End Position": d["end"]
394
+ } for d in duration
395
+ ])
396
+
397
+ # Display DataFrame
398
+ st.dataframe(duration_df, use_container_width=True)
399
+
400
+ # Highlight duration in text
401
+ st.markdown("<h4>Original Text with Duration Highlighted</h4>", unsafe_allow_html=True)
402
+
403
+ # Highlight duration in a different color
404
+ duration_text = result.get("input_text", "")
405
+ sorted_duration = sorted(duration, key=lambda x: x['start'], reverse=True)
406
+
407
+ for d in sorted_duration:
408
+ start = d['start']
409
+ end = d['end']
410
+ highlight = f"<span class='duration-highlight'>{duration_text[start:end]}</span>"
411
+ duration_text = duration_text[:start] + highlight + duration_text[end:]
412
+
413
+ st.markdown(f"<p>{duration_text}</p>", unsafe_allow_html=True)
414
+ else:
415
+ st.info("No specific duration information was detected in the input text.")
416
+
417
+ st.markdown("</div>", unsafe_allow_html=True)
418
+
419
+ # Risk Assessment tab
420
+ with tabs[2]:
421
+ st.markdown("""
422
+ <div class="card">
423
+ <h3 class="card-header">Risk Level Assessment</h3>
424
+ """, unsafe_allow_html=True)
425
+
426
+ risk_data = result.get("risk", {})
427
+ risk_level = risk_data.get("risk_level", "Unknown")
428
+ confidence = risk_data.get("confidence", 0.0)
429
+ probabilities = risk_data.get("all_probabilities", {})
430
+
431
+ col1, col2 = st.columns(2)
432
+
433
+ with col1:
434
+ # Display risk gauge
435
+ risk_gauge = create_risk_gauge(risk_level, confidence)
436
+ st.plotly_chart(risk_gauge, use_container_width=True, key="risk_assessment_gauge")
437
+
438
+ with col2:
439
+ # Display probability distribution
440
+ prob_chart = create_risk_probability_chart(probabilities)
441
+ st.plotly_chart(prob_chart, use_container_width=True, key="risk_probability_chart")
442
+
443
+ # Risk level descriptions
444
+ st.markdown("<h4>Risk Levels Explained</h4>", unsafe_allow_html=True)
445
+
446
+ risk_descriptions = {
447
+ "Low": """
448
+ <div style="border-left: 3px solid #7FD8BE; padding-left: 10px; margin: 10px 0;">
449
+ <strong style="color: #7FD8BE;">Low Risk</strong>: Your symptoms suggest a condition that is likely non-urgent.
450
+ While it's good to stay vigilant, these types of conditions typically don't require immediate medical attention
451
+ and can often be managed with self-care or a routine appointment within the next few days or weeks.
452
+ </div>
453
+ """,
454
+
455
+ "Medium": """
456
+ <div style="border-left: 3px solid #FFC857; padding-left: 10px; margin: 10px 0;">
457
+ <strong style="color: #FFC857;">Medium Risk</strong>: Your symptoms indicate a condition that may need medical attention
458
+ soon, but may not be an emergency. Consider scheduling an appointment with your primary care provider within 24-48 hours,
459
+ or visit an urgent care facility if your symptoms worsen or if you cannot schedule a timely appointment.
460
+ </div>
461
+ """,
462
+
463
+ "High": """
464
+ <div style="border-left: 3px solid #E84855; padding-left: 10px; margin: 10px 0;">
465
+ <strong style="color: #E84855;">High Risk</strong>: Your symptoms suggest a potentially serious condition that requires
466
+ prompt medical attention. Consider seeking emergency care or calling emergency services if symptoms are severe or rapidly
467
+ worsening, especially if they include difficulty breathing, severe pain, or altered consciousness.
468
+ </div>
469
+ """
470
+ }
471
+
472
+ # Display the description for the current risk level first
473
+ if risk_level in risk_descriptions:
474
+ st.markdown(risk_descriptions[risk_level], unsafe_allow_html=True)
475
+
476
+ # Then display the others
477
+ for level, desc in risk_descriptions.items():
478
+ if level != risk_level:
479
+ st.markdown(desc, unsafe_allow_html=True)
480
+
481
+ st.markdown("</div>", unsafe_allow_html=True)
482
+
483
+ # Disclaimer
484
+ st.warning("""
485
+ **Important Disclaimer**: This risk assessment is based on AI analysis and should be used as a guidance only.
486
+ It is not a definitive medical diagnosis. Always consult with a healthcare professional for proper evaluation,
487
+ especially if you experience severe symptoms, symptoms that persist or worsen, or if you're unsure about your condition.
488
+ """)
489
+
490
+ # Recommendations tab
491
+ with tabs[3]:
492
+ st.markdown("""
493
+ <div class="card">
494
+ <h3 class="card-header">Detailed Recommendations</h3>
495
+ """, unsafe_allow_html=True)
496
+
497
+ recommendation = result.get("recommendation", "No recommendations available.")
498
+
499
+ # Split recommendation into paragraphs for better readability
500
+ recommendation_parts = recommendation.split('. ')
501
+ formatted_recommendation = ""
502
+
503
+ current_paragraph = []
504
+ for part in recommendation_parts:
505
+ current_paragraph.append(part)
506
+
507
+ # Start a new paragraph every 2-3 sentences
508
+ if len(current_paragraph) >= 2 and part.endswith('.'):
509
+ formatted_recommendation += '. '.join(current_paragraph) + ".<br><br>"
510
+ current_paragraph = []
511
+
512
+ # Add any remaining parts
513
+ if current_paragraph:
514
+ formatted_recommendation += '. '.join(current_paragraph)
515
+
516
+ st.markdown(f"<p>{formatted_recommendation}</p>", unsafe_allow_html=True)
517
+
518
+ st.markdown("</div>", unsafe_allow_html=True)
519
+
520
+ # Department suggestion based on symptoms
521
+ st.markdown("""
522
+ <div class="card">
523
+ <h3 class="card-header">Suggested Medical Departments</h3>
524
+ """, unsafe_allow_html=True)
525
+
526
+ # 使用模型生成的科室建议而不是规则基础的建议
527
+ departments = result.get("structured_recommendation", {}).get("departments", [])
528
+ if not departments:
529
+ departments = ["General Medicine / Primary Care"]
530
+
531
+ # Display departments
532
+ for dept in departments:
533
+ st.markdown(f"• **{dept}**", unsafe_allow_html=True)
534
+
535
+ st.markdown("<br><em>Note: Department suggestions are based on your symptoms and risk level. Consult with a healthcare provider for proper referral.</em>", unsafe_allow_html=True)
536
+
537
+ st.markdown("</div>", unsafe_allow_html=True)
538
+
539
+ # Self-care suggestions
540
+ st.markdown("""
541
+ <div class="card">
542
+ <h3 class="card-header">Self-Care Suggestions</h3>
543
+ """, unsafe_allow_html=True)
544
+
545
+ # 使用模型生成的自我护理建议
546
+ self_care_tips = result.get("structured_recommendation", {}).get("self_care", [])
547
+
548
+ if self_care_tips:
549
+ st.markdown("<ul>", unsafe_allow_html=True)
550
+ for tip in self_care_tips:
551
+ st.markdown(f"<li>{tip}</li>", unsafe_allow_html=True)
552
+ st.markdown("</ul>", unsafe_allow_html=True)
553
+ else:
554
+ # 如果没有获取到模型生成的自我护理建议,则显示默认信息
555
+ risk_level = result.get("risk", {}).get("risk_level", "Medium")
556
+ if risk_level == "Low":
557
+ st.markdown("""
558
+ <p>While waiting for your symptoms to improve:</p>
559
+ <ul>
560
+ <li>Ensure you're getting adequate rest</li>
561
+ <li>Stay hydrated by drinking plenty of water</li>
562
+ <li>Monitor your symptoms and note any changes</li>
563
+ <li>Consider over-the-counter medications appropriate for your symptoms</li>
564
+ <li>Maintain a balanced diet to support your immune system</li>
565
+ </ul>
566
+ """, unsafe_allow_html=True)
567
+ elif risk_level == "Medium":
568
+ st.markdown("""
569
+ <p>While arranging medical care:</p>
570
+ <ul>
571
+ <li>Rest and avoid strenuous activities</li>
572
+ <li>Stay hydrated and maintain proper nutrition</li>
573
+ <li>Take your temperature and other vital signs if possible</li>
574
+ <li>Write down any changes in symptoms and when they occur</li>
575
+ <li>Have someone stay with you if your symptoms are concerning</li>
576
+ </ul>
577
+ """, unsafe_allow_html=True)
578
+ else: # High risk
579
+ st.markdown("""
580
+ <p>While seeking emergency care:</p>
581
+ <ul>
582
+ <li>Don't wait - seek medical attention immediately</li>
583
+ <li>Have someone drive you to the emergency room if safe to do so</li>
584
+ <li>Call emergency services if symptoms are severe</li>
585
+ <li>Bring a list of your current medications if possible</li>
586
+ <li>Follow any first aid protocols appropriate for your symptoms</li>
587
+ </ul>
588
+ """, unsafe_allow_html=True)
589
+
590
+ st.markdown("</div>", unsafe_allow_html=True)
591
+
592
+ # Footer
593
+ st.markdown("""
594
+ <div class="footer">
595
+ <p>AI Medical Consultation System | Created with Streamlit | Not a substitute for professional medical advice</p>
596
+ <p>Powered by: dmis-lab/biobert-v1.1, microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract, and fine-tuned T5-small</p>
597
+ </div>
598
+ """, unsafe_allow_html=True)
models.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForTokenClassification,
6
+ AutoModelForSequenceClassification,
7
+ AutoModelForSeq2SeqLM,
8
+ pipeline
9
+ )
10
+ import re
11
+ import os
12
+ import json
13
+ from typing import Dict, List, Tuple, Any
14
+
15
+ class SymptomExtractor:
16
+ """Model for extracting symptoms from patient descriptions using BioBERT."""
17
+
18
+ def __init__(self, model_name="dmis-lab/biobert-v1.1", device=None):
19
+ self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
20
+ print(f"Loading Symptom Extractor model on {self.device}...")
21
+
22
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23
+ self.model = AutoModelForTokenClassification.from_pretrained(model_name).to(self.device)
24
+ self.nlp = pipeline("ner", model=self.model, tokenizer=self.tokenizer, device=0 if self.device == "cuda" else -1)
25
+
26
+ print("Symptom Extractor model loaded successfully.")
27
+
28
+ def extract_symptoms(self, text: str) -> Dict[str, Any]:
29
+ """Extract symptoms from the input text."""
30
+ results = self.nlp(text)
31
+
32
+ # Process the NER results to group related tokens
33
+ symptoms = []
34
+ current_symptom = None
35
+
36
+ for entity in results:
37
+ if entity["entity"].startswith("B-"): # Beginning of a symptom
38
+ if current_symptom:
39
+ symptoms.append(current_symptom)
40
+ current_symptom = {
41
+ "text": entity["word"],
42
+ "start": entity["start"],
43
+ "end": entity["end"],
44
+ "score": entity["score"]
45
+ }
46
+ elif entity["entity"].startswith("I-") and current_symptom: # Inside a symptom
47
+ current_symptom["text"] += " " + entity["word"].replace("##", "")
48
+ current_symptom["end"] = entity["end"]
49
+ current_symptom["score"] = (current_symptom["score"] + entity["score"]) / 2
50
+
51
+ if current_symptom:
52
+ symptoms.append(current_symptom)
53
+
54
+ # Extract duration information
55
+ duration_patterns = [
56
+ r"(\d+)\s*(day|days|week|weeks|month|months|year|years)",
57
+ r"since\s+(\w+)",
58
+ r"for\s+(\w+)"
59
+ ]
60
+
61
+ duration_info = []
62
+ for pattern in duration_patterns:
63
+ matches = re.finditer(pattern, text, re.IGNORECASE)
64
+ for match in matches:
65
+ duration_info.append({
66
+ "text": match.group(0),
67
+ "start": match.start(),
68
+ "end": match.end()
69
+ })
70
+
71
+ return {
72
+ "symptoms": symptoms,
73
+ "duration": duration_info
74
+ }
75
+
76
+
77
+ class RiskClassifier:
78
+ """Model for classifying patient risk level using PubMedBERT."""
79
+
80
+ def __init__(self, model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", device=None):
81
+ self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
82
+ print(f"Loading Risk Classifier model on {self.device}...")
83
+
84
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
85
+ self.model = AutoModelForSequenceClassification.from_pretrained(
86
+ model_name,
87
+ num_labels=3 # Low, Medium, High
88
+ ).to(self.device)
89
+
90
+ self.id2label = {0: "Low", 1: "Medium", 2: "High"}
91
+ print("Risk Classifier model loaded successfully.")
92
+
93
+ def classify_risk(self, text: str) -> Dict[str, Any]:
94
+ """Classify the risk level based on the input text."""
95
+ inputs = self.tokenizer(
96
+ text,
97
+ return_tensors="pt",
98
+ padding=True,
99
+ truncation=True,
100
+ max_length=512
101
+ ).to(self.device)
102
+
103
+ with torch.no_grad():
104
+ outputs = self.model(**inputs)
105
+
106
+ logits = outputs.logits
107
+ probabilities = torch.softmax(logits, dim=1)[0].cpu().numpy()
108
+ model_prediction = torch.argmax(logits, dim=1).item()
109
+
110
+ # 由于模型没有经过微调,我们添加基于规则的后处理来调整风险级别
111
+ # 检查文本中是否存在高风险关键词
112
+ high_risk_keywords = [
113
+ "severe", "extreme", "intense", "unbearable", "emergency",
114
+ "chest pain", "difficulty breathing", "can't breathe",
115
+ "losing consciousness", "fainted", "seizure", "stroke", "heart attack",
116
+ "allergic reaction", "bleeding heavily", "blood", "poisoning",
117
+ "overdose", "suicide", "self-harm", "hallucinations"
118
+ ]
119
+
120
+ medium_risk_keywords = [
121
+ "worsening", "spreading", "persistent", "chronic", "recurring",
122
+ "infection", "fever", "swelling", "rash", "pain", "vomiting",
123
+ "diarrhea", "dizzy", "headache", "concerning", "worried",
124
+ "weeks", "days", "increasing", "progressing"
125
+ ]
126
+
127
+ low_risk_keywords = [
128
+ "mild", "slight", "minor", "occasional", "intermittent",
129
+ "improving", "better", "sometimes", "rarely", "manageable"
130
+ ]
131
+
132
+ text_lower = text.lower()
133
+
134
+ # 计算匹配的关键词数量
135
+ high_risk_matches = sum(keyword in text_lower for keyword in high_risk_keywords)
136
+ medium_risk_matches = sum(keyword in text_lower for keyword in medium_risk_keywords)
137
+ low_risk_matches = sum(keyword in text_lower for keyword in low_risk_keywords)
138
+
139
+ # 根据关键词匹配调整风险预测
140
+ adjusted_prediction = model_prediction
141
+ if high_risk_matches >= 2:
142
+ adjusted_prediction = 2 # High risk
143
+ elif high_risk_matches == 1 and medium_risk_matches >= 2:
144
+ adjusted_prediction = 2 # High risk
145
+ elif medium_risk_matches >= 3:
146
+ adjusted_prediction = 1 # Medium risk
147
+ elif medium_risk_matches >= 1 and low_risk_matches <= 1:
148
+ adjusted_prediction = 1 # Medium risk
149
+ elif low_risk_matches >= 2 and high_risk_matches == 0:
150
+ adjusted_prediction = 0 # Low risk
151
+
152
+ # 如果文本很长(详细描述),可能表明情况更复杂,风险更高
153
+ if len(text.split()) > 40 and adjusted_prediction == 0:
154
+ adjusted_prediction = 1 # 升级到Medium风险
155
+
156
+ # 对调整后的概率进行修正
157
+ adjusted_probabilities = probabilities.copy()
158
+ # 增强对应风险级别的概率
159
+ adjusted_probabilities[adjusted_prediction] = max(0.6, adjusted_probabilities[adjusted_prediction])
160
+ # 规范化概率使其总和为1
161
+ adjusted_probabilities = adjusted_probabilities / adjusted_probabilities.sum()
162
+
163
+ return {
164
+ "risk_level": self.id2label[adjusted_prediction],
165
+ "confidence": float(adjusted_probabilities[adjusted_prediction]),
166
+ "all_probabilities": {
167
+ self.id2label[i]: float(prob)
168
+ for i, prob in enumerate(adjusted_probabilities)
169
+ },
170
+ "original_prediction": self.id2label[model_prediction]
171
+ }
172
+
173
+
174
+ class RecommendationGenerator:
175
+ """Model for generating medical recommendations using fine-tuned t5-small."""
176
+
177
+ def __init__(self, model_path="t5-small", device=None):
178
+ self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
179
+ print(f"Loading Recommendation Generator model on {self.device}...")
180
+
181
+ # 检查常见的微调模型路径
182
+ possible_local_paths = [
183
+ "./finetuned_t5-small", # 添加用户指定的微调模型路径
184
+ "./t5-small-medical-recommendation",
185
+ "./models/t5-small-medical-recommendation",
186
+ "./fine_tuned_models/t5-small",
187
+ "./output",
188
+ "./fine_tuning_output"
189
+ ]
190
+
191
+ # 检查是否为路径或模型标识符
192
+ model_exists = False
193
+ for path in possible_local_paths:
194
+ if os.path.exists(path):
195
+ model_path = path
196
+ model_exists = True
197
+ print(f"Found fine-tuned model at: {model_path}")
198
+ break
199
+
200
+ if not model_exists and model_path == "t5-small-medical-recommendation":
201
+ print("Fine-tuned model not found locally. Falling back to base t5-small...")
202
+ model_path = "t5-small"
203
+
204
+ try:
205
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
206
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(self.device)
207
+ print(f"Recommendation Generator model '{model_path}' loaded successfully.")
208
+ except Exception as e:
209
+ print(f"Error loading model from {model_path}: {str(e)}")
210
+ print("Falling back to base t5-small model...")
211
+ self.tokenizer = AutoTokenizer.from_pretrained("t5-small")
212
+ self.model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(self.device)
213
+ print("Base t5-small model loaded successfully as fallback.")
214
+
215
+ # 科室映射 - 症状关键词到科室的映射
216
+ self.symptom_to_department = {
217
+ "headache": "Neurology",
218
+ "dizziness": "Neurology",
219
+ "confusion": "Neurology",
220
+ "memory": "Neurology",
221
+ "numbness": "Neurology",
222
+ "tingling": "Neurology",
223
+ "seizure": "Neurology",
224
+ "nerve": "Neurology",
225
+
226
+ "chest pain": "Cardiology",
227
+ "heart": "Cardiology",
228
+ "palpitation": "Cardiology",
229
+ "arrhythmia": "Cardiology",
230
+ "high blood pressure": "Cardiology",
231
+ "hypertension": "Cardiology",
232
+ "heart attack": "Cardiology",
233
+ "cardiovascular": "Cardiology",
234
+
235
+ "cough": "Pulmonology",
236
+ "breathing": "Pulmonology",
237
+ "shortness": "Pulmonology",
238
+ "lung": "Pulmonology",
239
+ "respiratory": "Pulmonology",
240
+ "asthma": "Pulmonology",
241
+ "pneumonia": "Pulmonology",
242
+ "copd": "Pulmonology",
243
+
244
+ "stomach": "Gastroenterology",
245
+ "abdomen": "Gastroenterology",
246
+ "nausea": "Gastroenterology",
247
+ "vomit": "Gastroenterology",
248
+ "diarrhea": "Gastroenterology",
249
+ "constipation": "Gastroenterology",
250
+ "heartburn": "Gastroenterology",
251
+ "liver": "Gastroenterology",
252
+ "digestive": "Gastroenterology",
253
+
254
+ "joint": "Orthopedics",
255
+ "bone": "Orthopedics",
256
+ "muscle": "Orthopedics",
257
+ "pain": "Orthopedics",
258
+ "back": "Orthopedics",
259
+ "arthritis": "Orthopedics",
260
+ "fracture": "Orthopedics",
261
+ "sprain": "Orthopedics",
262
+
263
+ "rash": "Dermatology",
264
+ "skin": "Dermatology",
265
+ "itching": "Dermatology",
266
+ "itch": "Dermatology",
267
+ "acne": "Dermatology",
268
+ "eczema": "Dermatology",
269
+ "psoriasis": "Dermatology",
270
+
271
+ "fever": "General Medicine / Primary Care",
272
+ "infection": "General Medicine / Primary Care",
273
+ "sore throat": "General Medicine / Primary Care",
274
+ "flu": "General Medicine / Primary Care",
275
+ "cold": "General Medicine / Primary Care",
276
+ "fatigue": "General Medicine / Primary Care",
277
+
278
+ "pregnancy": "Obstetrics / Gynecology",
279
+ "menstruation": "Obstetrics / Gynecology",
280
+ "period": "Obstetrics / Gynecology",
281
+ "vaginal": "Obstetrics / Gynecology",
282
+ "menopause": "Obstetrics / Gynecology",
283
+
284
+ "depression": "Psychiatry",
285
+ "anxiety": "Psychiatry",
286
+ "mood": "Psychiatry",
287
+ "stress": "Psychiatry",
288
+ "sleep": "Psychiatry",
289
+ "insomnia": "Psychiatry",
290
+ "mental": "Psychiatry",
291
+
292
+ "ear": "Otolaryngology (ENT)",
293
+ "nose": "Otolaryngology (ENT)",
294
+ "throat": "Otolaryngology (ENT)",
295
+ "hearing": "Otolaryngology (ENT)",
296
+ "sinus": "Otolaryngology (ENT)",
297
+
298
+ "eye": "Ophthalmology",
299
+ "vision": "Ophthalmology",
300
+ "blindness": "Ophthalmology",
301
+ "blurry": "Ophthalmology",
302
+
303
+ "urination": "Urology",
304
+ "kidney": "Urology",
305
+ "bladder": "Urology",
306
+ "urine": "Urology",
307
+ "prostate": "Urology"
308
+ }
309
+
310
+ # 自我护理建议
311
+ self.self_care_by_risk = {
312
+ "Low": [
313
+ "Ensure you're getting adequate rest",
314
+ "Stay hydrated by drinking plenty of water",
315
+ "Monitor your symptoms and note any changes",
316
+ "Consider over-the-counter medications appropriate for your symptoms",
317
+ "Maintain a balanced diet to support your immune system",
318
+ "Try gentle exercises if appropriate for your condition",
319
+ "Avoid activities that worsen your symptoms",
320
+ "Keep track of any patterns in your symptoms"
321
+ ],
322
+ "Medium": [
323
+ "Rest and avoid strenuous activities",
324
+ "Stay hydrated and maintain proper nutrition",
325
+ "Take your temperature and other vital signs if possible",
326
+ "Write down any changes in symptoms and when they occur",
327
+ "Have someone stay with you if your symptoms are concerning",
328
+ "Prepare a list of your symptoms and medications for your doctor",
329
+ "Avoid self-medicating beyond basic over-the-counter remedies",
330
+ "Consider arranging transportation to your medical appointment"
331
+ ],
332
+ "High": [
333
+ "Don't wait - seek medical attention immediately",
334
+ "Have someone drive you to the emergency room if safe to do so",
335
+ "Call emergency services if symptoms are severe",
336
+ "Bring a list of your current medications if possible",
337
+ "Follow any first aid protocols appropriate for your symptoms",
338
+ "Don't eat or drink anything if you might need surgery",
339
+ "Take prescribed emergency medications if applicable (like an inhaler for asthma)",
340
+ "Try to remain calm and focused on getting help"
341
+ ]
342
+ }
343
+
344
+ def _extract_departments_from_symptoms(self, symptoms_text: str) -> List[str]:
345
+ """
346
+ 从症状描述中提取可能的相关科室
347
+
348
+ Args:
349
+ symptoms_text: 症状描述文本
350
+
351
+ Returns:
352
+ 科室名称列表
353
+ """
354
+ departments = set()
355
+ symptoms_lower = symptoms_text.lower()
356
+
357
+ # 通过关键词匹配寻找相关科室
358
+ for keyword, department in self.symptom_to_department.items():
359
+ if keyword in symptoms_lower:
360
+ departments.add(department)
361
+
362
+ # 如果没有找到匹配的科室,返回常规医疗科室
363
+ if not departments:
364
+ departments.add("General Medicine / Primary Care")
365
+
366
+ return list(departments)
367
+
368
+ def _get_self_care_suggestions(self, risk_level: str) -> List[str]:
369
+ """
370
+ 根据风险级别获取自我护理建议
371
+
372
+ Args:
373
+ risk_level: 风险级别 (Low, Medium, High)
374
+
375
+ Returns:
376
+ 自我护理建议列表
377
+ """
378
+ # 确保风险级别有效
379
+ if risk_level not in self.self_care_by_risk:
380
+ risk_level = "Medium" # 默认返回中等风险的建议
381
+
382
+ # 返回为该风险级别准备的建议
383
+ suggestions = self.self_care_by_risk[risk_level]
384
+
385
+ # 随机选择5项建议,避免每次返回完全相同的内容
386
+ import random
387
+ if len(suggestions) > 5:
388
+ selected = random.sample(suggestions, 5)
389
+ else:
390
+ selected = suggestions
391
+
392
+ return selected
393
+
394
+ def _format_structured_recommendation(self, medical_advice: str, departments: List[str], self_care: List[str], risk_level: str) -> str:
395
+ """
396
+ 格式化结构化建议为文本格式
397
+
398
+ Args:
399
+ medical_advice: 主要医疗建议
400
+ departments: 建议科室列表
401
+ self_care: 自我护理建议列表
402
+ risk_level: 风险级别
403
+
404
+ Returns:
405
+ 格式化后的完整建议文本
406
+ """
407
+ # 初始化建议文本
408
+ recommendation = ""
409
+
410
+ # 添加主要医疗建议
411
+ recommendation += medical_advice.strip() + "\n\n"
412
+
413
+ # 添加建议科室部分
414
+ recommendation += f"RECOMMENDED DEPARTMENTS: Based on your symptoms, consider consulting the following departments: {', '.join(departments)}.\n\n"
415
+
416
+ # 添加自我护理部分
417
+ recommendation += f"SELF-CARE SUGGESTIONS: While {risk_level.lower()} risk level requires {'immediate attention' if risk_level == 'High' else 'medical care soon' if risk_level == 'Medium' else 'monitoring'}, you can also:\n"
418
+ for suggestion in self_care:
419
+ recommendation += f"- {suggestion}\n"
420
+
421
+ return recommendation
422
+
423
+ def generate_recommendation(self,
424
+ symptoms: str,
425
+ risk_level: str,
426
+ max_length: int = 150) -> Dict[str, Any]:
427
+ """
428
+ Generate a comprehensive medical recommendation based on symptoms and risk level.
429
+
430
+ Args:
431
+ symptoms: Symptom description text
432
+ risk_level: Risk level (Low, Medium, High)
433
+ max_length: Maximum length for generated text
434
+
435
+ Returns:
436
+ Dictionary containing structured recommendation including medical advice,
437
+ department suggestions, and self-care tips
438
+ """
439
+ # 创建输入提示
440
+ input_text = f"Symptoms: {symptoms} Risk: {risk_level}"
441
+
442
+ # 通过模型生成主要医疗建议
443
+ inputs = self.tokenizer(
444
+ input_text,
445
+ return_tensors="pt",
446
+ padding=True,
447
+ truncation=True,
448
+ max_length=512
449
+ ).to(self.device)
450
+
451
+ with torch.no_grad():
452
+ output_ids = self.model.generate(
453
+ **inputs,
454
+ max_length=max_length,
455
+ num_beams=4,
456
+ early_stopping=True
457
+ )
458
+
459
+ # 解码生成的医疗建议
460
+ medical_advice = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
461
+
462
+ # 从症状提取建议科室
463
+ departments = self._extract_departments_from_symptoms(symptoms)
464
+
465
+ # 如果是高风险,添加急诊科
466
+ if risk_level == "High" and "Emergency Medicine" not in departments:
467
+ departments.insert(0, "Emergency Medicine")
468
+
469
+ # 获取自我护理建议
470
+ self_care_suggestions = self._get_self_care_suggestions(risk_level)
471
+
472
+ # 创建完整的结构化建议
473
+ structured_recommendation = {
474
+ "medical_advice": medical_advice,
475
+ "departments": departments,
476
+ "self_care": self_care_suggestions
477
+ }
478
+
479
+ # 格式化为文本格式的完整建议
480
+ formatted_text = self._format_structured_recommendation(
481
+ medical_advice,
482
+ departments,
483
+ self_care_suggestions,
484
+ risk_level
485
+ )
486
+
487
+ return {
488
+ "text": formatted_text,
489
+ "structured": structured_recommendation
490
+ }
491
+
492
+
493
+ class MedicalConsultationPipeline:
494
+ """Complete pipeline for medical consultation."""
495
+
496
+ def __init__(self,
497
+ symptom_model="dmis-lab/biobert-v1.1",
498
+ risk_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
499
+ recommendation_model="t5-small",
500
+ device=None):
501
+
502
+ self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
503
+ print(f"Initializing Medical Consultation Pipeline on {self.device}...")
504
+
505
+ self.symptom_extractor = SymptomExtractor(model_name=symptom_model, device=self.device)
506
+ self.risk_classifier = RiskClassifier(model_name=risk_model, device=self.device)
507
+ self.recommendation_generator = RecommendationGenerator(model_path=recommendation_model, device=self.device)
508
+
509
+ print("Medical Consultation Pipeline initialized successfully.")
510
+
511
+ def process(self, text: str) -> Dict[str, Any]:
512
+ """Process the patient description through the complete pipeline."""
513
+ # Step 1: Extract symptoms
514
+ extraction_results = self.symptom_extractor.extract_symptoms(text)
515
+
516
+ # Step 2: Classify risk
517
+ risk_results = self.risk_classifier.classify_risk(text)
518
+
519
+ # Create a summary of the symptoms for the recommendation model
520
+ symptoms_summary = ", ".join([symptom["text"] for symptom in extraction_results["symptoms"]])
521
+ if not symptoms_summary:
522
+ symptoms_summary = text # Use original text if no symptoms found
523
+
524
+ # Step 3: Generate recommendation
525
+ recommendation_result = self.recommendation_generator.generate_recommendation(
526
+ symptoms=symptoms_summary,
527
+ risk_level=risk_results["risk_level"]
528
+ )
529
+
530
+ return {
531
+ "extraction": extraction_results,
532
+ "risk": risk_results,
533
+ "recommendation": recommendation_result["text"],
534
+ "structured_recommendation": recommendation_result["structured"],
535
+ "input_text": text
536
+ }
537
+
538
+ # Example usage
539
+ if __name__ == "__main__":
540
+ # This is just a test code that won't run in the Streamlit app
541
+ pipeline = MedicalConsultationPipeline()
542
+
543
+ sample_text = "I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous."
544
+
545
+ result = pipeline.process(sample_text)
546
+ print("Extracted symptoms:", [s["text"] for s in result["extraction"]["symptoms"]])
547
+ print("Duration info:", [d["text"] for d in result["extraction"]["duration"]])
548
+ print("Risk level:", result["risk"]["risk_level"], f"(Confidence: {result['risk']['confidence']:.2f})")
549
+ print("Recommendation:", result["recommendation"])
requirements.txt CHANGED
@@ -1,3 +1,12 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.31.0
2
+ torch==2.0.1
3
+ transformers==4.35.0
4
+ pandas==2.0.3
5
+ numpy==1.24.3
6
+ scikit-learn==1.3.0
7
+ matplotlib==3.7.2
8
+ plotly==5.15.0
9
+ nltk==3.8.1
10
+ spacy==3.6.1
11
+ seaborn==0.12.2
12
+ jsonlines==3.1.0
style.css ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Main style elements */
2
+ @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@100;300;400;500;700&family=Source+Sans+Pro:wght@400;600;700&display=swap');
3
+
4
+ html, body, [class*="css"] {
5
+ font-family: 'Source Sans Pro', -apple-system, BlinkMacSystemFont, sans-serif;
6
+ color: #2C363F;
7
+ }
8
+
9
+ .main .block-container {
10
+ padding-top: 2rem;
11
+ padding-bottom: 2rem;
12
+ }
13
+
14
+ /* Header styling */
15
+ .main-header {
16
+ color: #2C393F;
17
+ font-weight: 600;
18
+ text-align: center;
19
+ margin-bottom: 2rem;
20
+ }
21
+
22
+ .subheader {
23
+ color: #557A95;
24
+ font-weight: 500;
25
+ font-size: 1.2rem;
26
+ margin-bottom: 1rem;
27
+ }
28
+
29
+ /* Card elements */
30
+ .card {
31
+ background-color: #FFFFFF;
32
+ border-radius: 10px;
33
+ border: 1px solid #EAEAEA;
34
+ padding: 1.5rem;
35
+ margin-bottom: 1rem;
36
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
37
+ transition: all 0.3s ease;
38
+ }
39
+
40
+ .card:hover {
41
+ box-shadow: 0 6px 8px rgba(0, 0, 0, 0.1);
42
+ transform: translateY(-2px);
43
+ }
44
+
45
+ .card-header {
46
+ font-weight: 600;
47
+ margin-bottom: 0.8rem;
48
+ color: #557A95;
49
+ border-bottom: 1px solid #EAEAEA;
50
+ padding-bottom: 0.5rem;
51
+ }
52
+
53
+ /* Risk level indicators */
54
+ .risk-low {
55
+ color: #7FD8BE;
56
+ font-weight: 600;
57
+ }
58
+
59
+ .risk-medium {
60
+ color: #FFC857;
61
+ font-weight: 600;
62
+ }
63
+
64
+ .risk-high {
65
+ color: #E84855;
66
+ font-weight: 600;
67
+ }
68
+
69
+ /* Input area */
70
+ .stTextInput > div > div > input {
71
+ border-radius: 8px;
72
+ border: 1px solid #CCCCCC;
73
+ padding: 0.5rem;
74
+ font-size: 1rem;
75
+ }
76
+
77
+ .stTextArea > div > div > textarea {
78
+ border-radius: 8px;
79
+ border: 1px solid #CCCCCC;
80
+ padding: 0.8rem;
81
+ font-size: 1rem;
82
+ min-height: 150px;
83
+ }
84
+
85
+ /* Button styling */
86
+ .stButton > button {
87
+ background-color: #557A95;
88
+ color: white;
89
+ border: none;
90
+ border-radius: 8px;
91
+ padding: 0.5rem 2rem;
92
+ font-weight: 600;
93
+ transition: all 0.3s ease;
94
+ }
95
+
96
+ .stButton > button:hover {
97
+ background-color: #395B74;
98
+ color: white;
99
+ transform: translateY(-2px);
100
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
101
+ }
102
+
103
+ .stButton > button:focus {
104
+ background-color: #395B74;
105
+ color: white;
106
+ }
107
+
108
+ /* Symptom highlight styling */
109
+ .symptom-highlight {
110
+ background-color: rgba(255, 200, 87, 0.3);
111
+ border-radius: 3px;
112
+ padding: 0 3px;
113
+ }
114
+
115
+ /* Duration highlight styling */
116
+ .duration-highlight {
117
+ background-color: rgba(127, 216, 190, 0.3);
118
+ border-radius: 3px;
119
+ padding: 0 3px;
120
+ }
121
+
122
+ /* Recommendation styling */
123
+ .recommendation-container {
124
+ background-color: #F8F9FA;
125
+ border-left: 5px solid #557A95;
126
+ padding: 1rem;
127
+ margin: 1rem 0;
128
+ border-radius: 0 5px 5px 0;
129
+ }
130
+
131
+ /* History item */
132
+ .history-item {
133
+ padding: 1rem;
134
+ margin-bottom: 0.5rem;
135
+ border-radius: 5px;
136
+ border: 1px solid #EAEAEA;
137
+ background-color: #F8F9FA;
138
+ cursor: pointer;
139
+ transition: all 0.2s ease;
140
+ }
141
+
142
+ .history-item:hover {
143
+ background-color: #E9ECEF;
144
+ }
145
+
146
+ /* Loading animation */
147
+ .loading-spinner {
148
+ display: flex;
149
+ justify-content: center;
150
+ align-items: center;
151
+ margin: 2rem 0;
152
+ }
153
+
154
+ /* Custom metric container */
155
+ .metric-container {
156
+ background-color: white;
157
+ border-radius: 10px;
158
+ padding: 1rem;
159
+ text-align: center;
160
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.05);
161
+ }
162
+
163
+ .metric-value {
164
+ font-size: 2.5rem;
165
+ font-weight: 600;
166
+ margin: 0.5rem 0;
167
+ }
168
+
169
+ .metric-label {
170
+ font-size: 1rem;
171
+ color: #6c757d;
172
+ }
173
+
174
+ /* App footer */
175
+ .footer {
176
+ text-align: center;
177
+ margin-top: 3rem;
178
+ padding-top: 1rem;
179
+ border-top: 1px solid #EAEAEA;
180
+ color: #6c757d;
181
+ font-size: 0.8rem;
182
+ }
183
+
184
+ /* Override Streamlit's default padding in widgets */
185
+ div.stRadio > div {
186
+ padding-top: 0.5rem;
187
+ padding-bottom: 0.5rem;
188
+ }
189
+
190
+ div.stCheckbox > div {
191
+ padding-top: 0.5rem;
192
+ padding-bottom: 0.5rem;
193
+ }
194
+
195
+ /* Tabs styling */
196
+ .stTabs [data-baseweb="tab-list"] {
197
+ gap: 1rem;
198
+ }
199
+
200
+ .stTabs [data-baseweb="tab"] {
201
+ height: 3rem;
202
+ border-radius: 8px 8px 0 0;
203
+ padding: 0 1.5rem;
204
+ background-color: #F8F9FA;
205
+ }
206
+
207
+ .stTabs [aria-selected="true"] {
208
+ background-color: white !important;
209
+ border-bottom: 2px solid #557A95 !important;
210
+ font-weight: 600;
211
+ }
212
+
213
+ /* Responsive adjustments */
214
+ @media (max-width: 768px) {
215
+ .main .block-container {
216
+ padding-top: 1rem;
217
+ padding-bottom: 1rem;
218
+ }
219
+
220
+ .card {
221
+ padding: 1rem;
222
+ }
223
+
224
+ .metric-value {
225
+ font-size: 2rem;
226
+ }
227
+ }
228
+
229
+ /* Animation for success message */
230
+ @keyframes fadeInUp {
231
+ from {
232
+ opacity: 0;
233
+ transform: translateY(20px);
234
+ }
235
+ to {
236
+ opacity: 1;
237
+ transform: translateY(0);
238
+ }
239
+ }
240
+
241
+ .fadeInUp {
242
+ animation-name: fadeInUp;
243
+ animation-duration: 0.5s;
244
+ animation-fill-mode: both;
245
+ }
utils.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import plotly.express as px
4
+ import plotly.graph_objects as go
5
+ from datetime import datetime
6
+ import json
7
+ import os
8
+ from typing import Dict, List, Any
9
+
10
+ # Constants
11
+ RISK_COLORS = {
12
+ "Low": "#7FD8BE", # Soft mint green
13
+ "Medium": "#FFC857", # Warm amber
14
+ "High": "#E84855" # Bright red
15
+ }
16
+
17
+ def highlight_text_with_entities(text: str, entities: List[Dict[str, Any]]) -> str:
18
+ """
19
+ Format text with HTML to highlight extracted entities.
20
+
21
+ Args:
22
+ text: Original input text
23
+ entities: List of entity dictionaries with 'start', 'end', and 'text' keys
24
+
25
+ Returns:
26
+ HTML formatted string with highlighted entities
27
+ """
28
+ if not entities:
29
+ return text
30
+
31
+ # Sort entities by start position (descending) to avoid index issues when replacing
32
+ sorted_entities = sorted(entities, key=lambda x: x['start'], reverse=True)
33
+
34
+ result = text
35
+ for entity in sorted_entities:
36
+ start = entity['start']
37
+ end = entity['end']
38
+ highlight = f"<span style='background-color: rgba(255, 200, 87, 0.3); border-radius: 3px; padding: 0px 3px;'>{text[start:end]}</span>"
39
+ result = result[:start] + highlight + result[end:]
40
+
41
+ return result
42
+
43
+ def format_duration(duration_entities: List[Dict[str, Any]]) -> str:
44
+ """Format duration entities into a readable string."""
45
+ if not duration_entities:
46
+ return "No specific duration mentioned"
47
+
48
+ return ", ".join([entity["text"] for entity in duration_entities])
49
+
50
+ def create_risk_gauge(risk_level: str, confidence: float) -> go.Figure:
51
+ """Create a gauge chart for risk level visualization."""
52
+
53
+ # Map risk levels to numerical values for the gauge
54
+ risk_value_map = {"Low": 1, "Medium": 2, "High": 3}
55
+ risk_value = risk_value_map.get(risk_level, 2) # Default to Medium if unknown
56
+
57
+ fig = go.Figure(go.Indicator(
58
+ mode="gauge+number+delta",
59
+ value=risk_value,
60
+ domain={'x': [0, 1], 'y': [0, 1]},
61
+ gauge={
62
+ 'axis': {'range': [0, 3], 'tickvals': [1, 2, 3], 'ticktext': ['Low', 'Medium', 'High']},
63
+ 'bar': {'color': RISK_COLORS[risk_level]},
64
+ 'steps': [
65
+ {'range': [0, 1.5], 'color': "rgba(127, 216, 190, 0.3)"},
66
+ {'range': [1.5, 2.5], 'color': "rgba(255, 200, 87, 0.3)"},
67
+ {'range': [2.5, 3], 'color': "rgba(232, 72, 85, 0.3)"}
68
+ ],
69
+ 'threshold': {
70
+ 'line': {'color': "white", 'width': 2},
71
+ 'thickness': 0.85,
72
+ 'value': risk_value
73
+ }
74
+ },
75
+ number={'valueformat': '.0f', 'font': {'size': 36}},
76
+ title={
77
+ 'text': f"Risk Level: {risk_level}",
78
+ 'font': {'size': 24}
79
+ },
80
+ ))
81
+
82
+ fig.update_layout(
83
+ height=250,
84
+ margin=dict(l=10, r=10, t=50, b=10),
85
+ paper_bgcolor='white',
86
+ font={'color': "#2C363F", 'family': "Arial"}
87
+ )
88
+
89
+ return fig
90
+
91
+ def create_risk_probability_chart(probabilities: Dict[str, float]) -> go.Figure:
92
+ """Create a horizontal bar chart for risk probabilities."""
93
+ labels = list(probabilities.keys())
94
+ values = list(probabilities.values())
95
+ colors = [RISK_COLORS[label] for label in labels]
96
+
97
+ fig = go.Figure(go.Bar(
98
+ x=values,
99
+ y=labels,
100
+ orientation='h',
101
+ marker_color=colors,
102
+ text=[f"{v:.1%}" for v in values],
103
+ textposition='auto'
104
+ ))
105
+
106
+ fig.update_layout(
107
+ title="Risk Probability Distribution",
108
+ xaxis_title="Probability",
109
+ yaxis_title="Risk Level",
110
+ height=250,
111
+ margin=dict(l=10, r=10, t=50, b=10),
112
+ xaxis=dict(range=[0, 1], tickformat=".0%"),
113
+ paper_bgcolor='white',
114
+ plot_bgcolor='white',
115
+ font={'color': "#2C363F", 'family': "Arial"}
116
+ )
117
+
118
+ return fig
119
+
120
+ def save_consultation(consultation_data: Dict[str, Any]):
121
+ """Save consultation data to a JSON file."""
122
+ # Create history directory if it doesn't exist
123
+ os.makedirs("consultation_history", exist_ok=True)
124
+
125
+ # Generate a filename with timestamp
126
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
127
+ filename = f"consultation_history/consultation_{timestamp}.json"
128
+
129
+ # Add timestamp to data
130
+ consultation_data["timestamp"] = datetime.now().isoformat()
131
+
132
+ # Save to file
133
+ with open(filename, "w") as f:
134
+ json.dump(consultation_data, f, indent=2)
135
+
136
+ return filename
137
+
138
+ def load_consultation_history() -> List[Dict[str, Any]]:
139
+ """Load all saved consultations from the history directory."""
140
+ history_dir = "consultation_history"
141
+ if not os.path.exists(history_dir):
142
+ return []
143
+
144
+ history = []
145
+ for filename in os.listdir(history_dir):
146
+ if filename.endswith(".json"):
147
+ try:
148
+ with open(os.path.join(history_dir, filename), "r") as f:
149
+ consultation = json.load(f)
150
+ history.append(consultation)
151
+ except Exception as e:
152
+ st.error(f"Error loading {filename}: {str(e)}")
153
+
154
+ # Sort by timestamp (newest first)
155
+ history.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
156
+ return history
157
+
158
+ def init_session_state():
159
+ """Initialize session state variables."""
160
+ if "consultation_history" not in st.session_state:
161
+ st.session_state.consultation_history = []
162
+
163
+ if "current_result" not in st.session_state:
164
+ st.session_state.current_result = None
165
+
166
+ if "is_processing" not in st.session_state:
167
+ st.session_state.is_processing = False
168
+
169
+ if "loaded_models" not in st.session_state:
170
+ st.session_state.loaded_models = False