|
import scispacy |
|
import spacy |
|
import torch |
|
import re |
|
import sys |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import os |
|
|
|
|
|
# Direct the Hugging Face cache locations to a directory under /tmp.
# NOTE(review): these assignments run after `transformers` has already been
# imported above; confirm the library still honors them when set this late.
# The model loaders in this file also pass cache_dir explicitly.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
|
|
|
class MedicalTextSimplifier:
    """Annotate medical terms in free text with plain-language explanations.

    Entities are detected with the spaCy pipeline ``en_core_sci_sm`` and each
    one is explained by prompting the ``stanford-crfm/BioMedLM`` causal
    language model. Progress and results are printed to stdout.
    """

    def __init__(self):
        """Load the spaCy pipeline, the tokenizer and the language model.

        If anything fails while loading, the error is printed and the process
        exits with status 1.
        """
        print("Loading models...")
        try:
            self.nlp = spacy.load("en_core_sci_sm")
            self.tokenizer = AutoTokenizer.from_pretrained(
                "stanford-crfm/BioMedLM",
                cache_dir="/tmp/huggingface_cache",
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                "stanford-crfm/BioMedLM",
                cache_dir="/tmp/huggingface_cache",
            )
            print("Models loaded successfully!")
        except Exception as e:
            print(f"Error loading models: {e}")
            print("Make sure all required packages are installed.")
            sys.exit(1)

    def identify_medical_terms(self, text):
        """Identify medical entities using SciSpaCy.

        Args:
            text: The text to analyse.

        Returns:
            A list of dicts, one per entity in ``doc.ents``, with the keys
            ``'term'`` (entity text), ``'start'`` and ``'end'`` (character
            offsets of the entity within *text*).
        """
        doc = self.nlp(text)
        return [
            {'term': ent.text, 'start': ent.start_char, 'end': ent.end_char}
            for ent in doc.ents
        ]

    def generate_simplified_explanation(self, term, context):
        """Generate a plain-language explanation using BioMedLM.

        Args:
            term: The medical term to explain.
            context: Surrounding text that is embedded in the prompt.

        Returns:
            The generated explanation with surrounding whitespace removed.
            If generation raises, or yields no text at all, the generic
            phrase ``"a medical term related to <term>"`` is returned instead.
        """
        fallback = f"a medical term related to {term}"
        try:
            prompt = (
                f"Explain the medical term '{term}' in simple language "
                f"for a patient. Context: {context}\nExplanation:"
            )
            # NOTE(review): truncation to 512 tokens presumably drops the end
            # of the prompt, so a very long context may cut off the trailing
            # "Explanation:" cue -- confirm against the tokenizer settings.
            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
            )
            input_ids = encoded["input_ids"]
            prompt_length = len(input_ids[0])

            outputs = self.model.generate(
                input_ids,
                attention_mask=encoded.get("attention_mask"),
                # Budget applies to newly generated tokens only. A total
                # max_length of 100 would be consumed by the prompt itself
                # whenever the prompt is 100 tokens or longer.
                max_new_tokens=100,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # Decode only what the model appended after the prompt instead of
            # searching the full text for the "Explanation:" marker, which
            # may also occur inside the context or the generated text.
            generated_ids = outputs[0][prompt_length:]
            explanation = self.tokenizer.decode(
                generated_ids, skip_special_tokens=True
            ).strip()
        except Exception as e:
            print(f"Error generating explanation: {e}")
            return fallback
        # An empty generation would otherwise render as "term ()".
        return explanation or fallback

    def simplify_text(self, text):
        """Return *text* with every detected term followed by an explanation.

        Each entity reported by ``identify_medical_terms`` is rewritten as
        ``"<term> (<explanation>)"``. The original text, the terms and their
        explanations, and the final result are printed along the way.

        Args:
            text: The text to annotate.

        Returns:
            The annotated text, or *text* unchanged when no terms are found.
        """
        print("\nOriginal text:")
        print(text)

        print("\nIdentifying medical terms using SciSpaCy...")
        medical_terms = self.identify_medical_terms(text)

        if not medical_terms:
            print("No medical terms found.")
            return text

        # Generation is sampled and expensive, so each distinct term is
        # explained once per call and the result reused for its repeats.
        explanations = {}
        pieces = []
        cursor = 0  # index in `text` up to which output has been emitted

        print("\nMedical terms and explanations:")
        # Assumes the spans arrive in ascending, non-overlapping order.
        for item in medical_terms:
            term = item['term']
            if term not in explanations:
                explanations[term] = self.generate_simplified_explanation(
                    term, text
                )
            explanation = explanations[term]

            pieces.append(text[cursor:item['start']])
            pieces.append(f"{term} ({explanation})")
            cursor = item['end']

            print(f"\nTerm: {term}")
            print(f"Explanation: {explanation}")

        pieces.append(text[cursor:])
        simplified_text = "".join(pieces)

        print("\nSimplified text:")
        print(simplified_text)
        return simplified_text