#########################################################################
# Copyright (C)                                                         #
# 2025-August Sen Li (Sen.Li.Sprout@gmail.com)                          #
# Permission given to modify the code only for Non-Profit Research      #
# as long as you keep this declaration at the top                       #
#########################################################################

import os
import subprocess

import gradio as gr
import huggingface_hub
import sentence_transformers
from transformers import AutoTokenizer, AutoModel
import torch


# ----------------------------------------------------------------------------------------------------------------------
def func_ClearInputs():
    """Return three empty strings, one per textbox wired to the Clear button."""
    return "", "", ""


str_ModelID_ModernBERT = "answerdotai/ModernBERT-large"

# Wrap ModernBERT inside SentenceTransformers: a Transformer module that produces
# token embeddings, followed by a Pooling module constructed with its default settings.
# NOTE: this loads the model at import time, exactly as the original script did.
word_embedding_model = sentence_transformers.models.Transformer(str_ModelID_ModernBERT)
pooling_model = sentence_transformers.models.Pooling(word_embedding_model.get_word_embedding_dimension())
sentenceModel_ModernBERT = sentence_transformers.SentenceTransformer(
    modules=[word_embedding_model, pooling_model]
)  # device="cuda")


def get_SentenceEmbeddings_ModernBERT(sentence):
    """Encode ``sentence`` with the module-level ModernBERT SentenceTransformer.

    Args:
        sentence: Text (or list of texts) forwarded unchanged to ``encode``.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for the input.
    """
    return sentenceModel_ModernBERT.encode(sentence)


def get_sentence_embedding(sentence: str) -> torch.Tensor:
    """Compute a mean-pooled sentence embedding directly from the transformer.

    This is the manual alternative to ``get_SentenceEmbeddings_ModernBERT``.
    It reuses the tokenizer and Hugging Face model already held by
    ``word_embedding_model`` (the original referenced an undefined
    ``tokenizer`` and called the SentenceTransformer wrapper as if it were a
    plain Hugging Face model, so it could never run).

    Args:
        sentence: Text to embed.

    Returns:
        The masked mean of the last hidden states, squeezed.
    """
    tokenizer = word_embedding_model.tokenizer
    transformer = word_embedding_model.auto_model

    # Tokenize and move the encoded inputs to wherever the model lives.
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    inputs = inputs.to(transformer.device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = transformer(**inputs)

    token_embeddings = outputs.last_hidden_state

    # Mean pooling across tokens, ignoring padding positions via the attention mask.
    attention_mask = inputs["attention_mask"]
    mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask_expanded, dim=1)
    # Clamp guards against division by zero for an all-padding row.
    counts = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
    sentence_embedding = summed / counts
    return sentence_embedding.squeeze()


def func_sBERT_SimilarityResult(str_Text_1, str_Text_2):
    """Compute the cosine similarity between the ModernBERT embeddings of two texts.

    Args:
        str_Text_1: First text.
        str_Text_2: Second text.

    Returns:
        A message with the similarity score formatted to four decimals, or a
        validation message when either input is missing or blank.
    """
    # Treat None like an empty string so a missing input yields the validation
    # message instead of an AttributeError.
    if not (str_Text_1 or "").strip() or not (str_Text_2 or "").strip():
        return "Both text inputs must be non-empty."

    # 01. Get sentence embeddings from the locally loaded model
    arrEmbedding_Text_1 = get_SentenceEmbeddings_ModernBERT(str_Text_1)
    arrEmbedding_Text_2 = get_SentenceEmbeddings_ModernBERT(str_Text_2)

    # 02. Compute cosine similarity
    tensor_Similarity = sentence_transformers.util.pytorch_cos_sim(arrEmbedding_Text_1, arrEmbedding_Text_2)
    f_Similarity = tensor_Similarity.item()
    return f"Clinical Similarity Score: {f_Similarity:.4f}"


def _print_directory_listing(str_Header, str_Path):
    """Print ``str_Header`` and then the output of ``ls -al str_Path`` (best-effort debug aid)."""
    # flush so the header appears before the child process writes its output
    print(str_Header, flush=True)
    try:
        # Argument list instead of a shell string: paths with spaces or shell
        # metacharacters are passed through verbatim. check=False keeps the
        # original behavior of ignoring a failing `ls`.
        subprocess.run(["ls", "-al", str_Path], check=False)
    except OSError as err:
        # e.g. `ls` is not available; the listing is diagnostic only.
        print(f"Could not list {str_Path}: {err}", flush=True)


# ----------------------------------------------------------------------------------------------------------------------
# Launch the interface and MCP server
if __name__ == "__main__":
    str_Cwd = os.getcwd()
    print(f"os.getcwd() = {str_Cwd}")
    _print_directory_listing(f"ls -al {str_Cwd}", str_Cwd)
    _print_directory_listing("ls -al /:", "/")
    _print_directory_listing("ls -al /home/:", "/home/")

    # 03. Gradio UI elements
    with gr.Blocks() as grBlocks_SentenceSimilarity__MCP_Server:
        gr.Markdown("# ModernBERT for Clinical Text Similarity using HF Inference Server, MaxSeqLength==8192")
        gr.Markdown("This application calculates Cosine Similarity Score between two Texts' ModernBERT Sentence-Embeddings")

        with gr.Row():
            grTextBox_Input_1 = gr.Textbox(label="Text Panel 1", lines=20)
            grTextBox_Input_2 = gr.Textbox(label="Text Panel 2", lines=20)

        with gr.Row():
            with gr.Column(scale=1):
                grButton_Clear = gr.Button("Clear")
                grButton_Submit = gr.Button("Submit")
            with gr.Column(scale=3):
                grTextbox_Output = gr.Textbox(label="Similarity Result", interactive=False)

        # Set button functionality
        grButton_Submit.click(
            fn=func_sBERT_SimilarityResult,
            inputs=[grTextBox_Input_1, grTextBox_Input_2],
            outputs=grTextbox_Output,
        )
        grButton_Clear.click(
            fn=func_ClearInputs,
            inputs=[],
            outputs=[grTextBox_Input_1, grTextBox_Input_2, grTextbox_Output],
        )

    # 04. Launch Gradio MCP server
    grBlocks_SentenceSimilarity__MCP_Server.launch(mcp_server=True, share=True)