# health-simplify-tool / medical_simplifier.py
# faisalshah012003's picture
# Update medical_simplifier.py
# 22029cc verified
# raw
# history blame
# 3.67 kB
import os
import re
import sys

# Set the Hugging Face cache location at module level, BEFORE any third-party
# import. The Hugging Face libraries resolve their default cache paths when
# they are first imported, so assigning these variables after
# `from transformers import ...` would be too late to take effect.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

import scispacy  # noqa: E402,F401
import spacy  # noqa: E402
import torch  # noqa: E402,F401
from transformers import AutoTokenizer, AutoModelForCausalLM  # noqa: E402
from sklearn.metrics.pairwise import cosine_similarity  # noqa: E402,F401
class MedicalTextSimplifier:
    """Annotate medical terms in free text with plain-language explanations.

    Entities are detected with a spaCy pipeline (SciSpaCy model by default)
    and each detected term is explained by a causal language model
    (BioMedLM by default). The explanation is inserted in parentheses
    directly after the term.
    """

    def __init__(
        self,
        model_name="stanford-crfm/BioMedLM",
        spacy_model="en_core_sci_sm",
        cache_dir="/tmp/huggingface_cache",
    ):
        """Load the entity-recognition pipeline and the language model.

        Args:
            model_name: Hugging Face identifier of the causal LM and its
                tokenizer.
            spacy_model: Name of the spaCy pipeline used to find terms.
            cache_dir: Directory where downloaded model files are cached.

        Note:
            If anything fails to load, the error is printed and the process
            exits with status 1 (``SystemExit``); this mirrors the original
            command-line oriented behaviour.
        """
        print("Loading models...")
        try:
            # Load SciSpaCy for term identification
            self.nlp = spacy.load(spacy_model)
            # Load BioMedLM (Stanford)
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_dir,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
            )
            print("Models loaded successfully!")
        except Exception as e:
            print(f"Error loading models: {e}")
            print("Make sure all required packages are installed.")
            sys.exit(1)

    def identify_medical_terms(self, text):
        """Identify medical entities using SciSpaCy.

        Args:
            text: The text to analyse.

        Returns:
            A list of dicts, one per entity in ``doc.ents``, with the keys
            ``'term'`` (entity text), ``'start'`` and ``'end'`` (character
            offsets into ``text``).
        """
        doc = self.nlp(text)
        return [
            {'term': ent.text, 'start': ent.start_char, 'end': ent.end_char}
            for ent in doc.ents
        ]

    def generate_simplified_explanation(self, term, context):
        """Generate a plain-language explanation using BioMedLM.

        Args:
            term: The medical term to explain.
            context: Surrounding text given to the model as context.

        Returns:
            The generated explanation, or the generic phrase
            ``"a medical term related to <term>"`` when generation fails or
            yields no text.
        """
        fallback = f"a medical term related to {term}"
        try:
            prompt = (
                f"Explain the medical term '{term}' in simple language for a patient. "
                f"Context: {context}\nExplanation:"
            )
            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
            )
            input_ids = encoded["input_ids"]
            outputs = self.model.generate(
                input_ids,
                attention_mask=encoded["attention_mask"],
                # max_new_tokens budgets only the generated continuation;
                # max_length would also count the (up to 512-token) prompt and
                # leave little or no room for an answer.
                max_new_tokens=100,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # A causal LM returns prompt + continuation; decode only the
            # continuation so a truncated prompt (which may have lost its
            # trailing "Explanation:" cue) is never echoed back as the answer.
            prompt_length = input_ids.shape[-1]
            explanation = self.tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True,
            )
            # If the model repeats the cue itself, keep the text after the last one.
            explanation = explanation.split("Explanation:")[-1].strip()
        except Exception as e:
            print(f"Error generating explanation: {e}")
            return fallback
        # An empty generation would otherwise be rendered as "term ()".
        return explanation or fallback

    def simplify_text(self, text):
        """Return ``text`` with every detected term followed by an explanation.

        Each term found by ``identify_medical_terms`` is rewritten as
        ``"<term> (<explanation>)"``. Progress and results are printed to
        stdout. If no terms are found, ``text`` is returned unchanged.

        Args:
            text: The medical text to annotate.

        Returns:
            The annotated text.
        """
        print("\nOriginal text:")
        print(text)
        print("\nIdentifying medical terms using SciSpaCy...")
        medical_terms = self.identify_medical_terms(text)
        if not medical_terms:
            print("No medical terms found.")
            return text

        # One model call per distinct term: repeated terms reuse the same
        # explanation, which is both cheaper and keeps the output consistent
        # (generation is sampled, so separate calls would differ).
        explanations = {}
        # Assumes spans arrive in document order and do not overlap, as the
        # original offset bookkeeping also required.
        pieces = []
        cursor = 0
        print("\nMedical terms and explanations:")
        for item in medical_terms:
            term = item['term']
            if term not in explanations:
                explanations[term] = self.generate_simplified_explanation(term, text)
            explanation = explanations[term]
            # Copy the untouched text before the term, then the annotated term.
            pieces.append(text[cursor:item['start']])
            pieces.append(f"{term} ({explanation})")
            cursor = item['end']
            print(f"\nTerm: {term}")
            print(f"Explanation: {explanation}")
        pieces.append(text[cursor:])
        simplified_text = "".join(pieces)
        print("\nSimplified text:")
        print(simplified_text)
        return simplified_text