import os

# Set the Hugging Face cache directories before transformers is imported so
# the environment variables are in effect when the library initializes.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

import sys

import scispacy  # noqa: F401  -- ensures the SciSpaCy package is installed
import spacy
from transformers import AutoTokenizer, AutoModelForCausalLM


class MedicalTextSimplifier:
    def __init__(self):
        print("Loading models...")
        try:
            # Load SciSpaCy for medical term identification
            self.nlp = spacy.load("en_core_sci_sm")

            # Load BioMedLM (Stanford) for plain-language generation
            self.tokenizer = AutoTokenizer.from_pretrained(
                "stanford-crfm/BioMedLM", cache_dir="/tmp/huggingface_cache"
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                "stanford-crfm/BioMedLM", cache_dir="/tmp/huggingface_cache"
            )
            print("Models loaded successfully!")
        except Exception as e:
            print(f"Error loading models: {e}")
            print("Make sure all required packages are installed.")
            sys.exit(1)

    def identify_medical_terms(self, text):
        """Identify medical entities using SciSpaCy."""
        doc = self.nlp(text)
        terms = []
        for ent in doc.ents:
            terms.append({"term": ent.text, "start": ent.start_char, "end": ent.end_char})
        return terms

    def generate_simplified_explanation(self, term, context):
        """Generate a plain-language explanation of a term using BioMedLM."""
        try:
            prompt = (
                f"Explain the medical term '{term}' in simple language for a patient. "
                f"Context: {context}\nExplanation:"
            )
            inputs = self.tokenizer.encode(
                prompt, return_tensors="pt", max_length=512, truncation=True
            )
            outputs = self.model.generate(
                inputs,
                max_new_tokens=100,  # cap the generated explanation, not the total length
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            explanation = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep only the text generated after the "Explanation:" cue in the prompt
            explanation = explanation.split("Explanation:")[-1].strip()
            return explanation
        except Exception as e:
            print(f"Error generating explanation: {e}")
            return f"a medical term related to {term}"

    def simplify_text(self, text):
        """Annotate each identified medical term with a generated explanation."""
        print("\nOriginal text:")
        print(text)

        print("\nIdentifying medical terms using SciSpaCy...")
        medical_terms = self.identify_medical_terms(text)
        if not medical_terms:
            print("No medical terms found.")
            return text

        simplified_text = text
        offset = 0  # tracks how much the text has grown after each replacement

        print("\nMedical terms and explanations:")
        for item in medical_terms:
            term = item["term"]
            start = item["start"] + offset
            end = item["end"] + offset

            explanation = self.generate_simplified_explanation(term, text)
            annotated = f"{term} ({explanation})"

            # Splice the annotated form in at the offset-adjusted position
            simplified_text = simplified_text[:start] + annotated + simplified_text[end:]
            offset += len(annotated) - len(term)

            print(f"\nTerm: {term}")
            print(f"Explanation: {explanation}")

        print("\nSimplified text:")
        print(simplified_text)
        return simplified_text
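

# --- Usage sketch (illustrative only) ---
# A minimal way to exercise the class above. The sample sentence is an assumed
# example, not taken from the original source; any clinical text will do.
if __name__ == "__main__":
    simplifier = MedicalTextSimplifier()
    sample = (
        "The patient presented with acute myocardial infarction and "
        "hyperlipidemia, and was started on anticoagulation therapy."
    )
    simplifier.simplify_text(sample)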