---
library_name: transformers
tags: []
---

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Example radiology sentence to embed.
sentence = "The patient has a right pneumothorax."

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained tokenizer and encoder, then put the encoder on the
# chosen device and switch it to evaluation mode.
model_name = "IAMJB/RadEvalModernBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.to(device)
model.eval()

# Tokenize the sentence and move the encoded tensors to the same device.
encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
inputs = encoded.to(device)

# Forward pass without gradient tracking. The sentence embedding is taken
# from the last entry of the hidden states, at the first token position.
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
last_hidden_state = outputs.hidden_states[-1]
cls_embedding = last_hidden_state[:, 0, :]  # CLS token

print("Sentence:", sentence)
print("Embedding shape:", cls_embedding.shape)
```

### Similarity heatmap example

```python
import argparse

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from transformers import AutoTokenizer, AutoModel


def get_cls_embeddings(model, tokenizer, texts, device):
    """Return one CLS-position embedding per input text.

    Each text is tokenized and run through the model on its own. The vector
    kept for a text is the first-token slice of the last entry in the
    model's ``hidden_states`` output.

    Args:
        model: Callable model that accepts the tokenizer's tensors as keyword
            arguments plus ``output_hidden_states=True`` and returns an
            object exposing ``hidden_states``.
        tokenizer: Callable that maps a text to a mapping of tensors.
        texts: Iterable of strings to embed.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        A NumPy array built from the per-text embedding vectors, in the same
        order as ``texts``.
    """

    def _embed_one(text):
        # Encode a single text and move every tensor to the target device.
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        batch = {name: tensor.to(device) for name, tensor in encoded.items()}

        # Inference only: no gradients are needed for embedding extraction.
        with torch.no_grad():
            result = model(**batch, output_hidden_states=True)

        # Last hidden state, first token position, first (only) batch row.
        first_token = result.hidden_states[-1][:, 0, :]
        return first_token.cpu().numpy()[0]

    return np.array([_embed_one(text) for text in texts])


def compute_similarities(embeddings):
    """Compute the pairwise cosine-similarity matrix of a set of embeddings.

    Args:
        embeddings: 2-D array-like with one embedding vector per row.

    Returns:
        A square NumPy array whose entry ``[i, j]`` is the cosine similarity
        between row ``i`` and row ``j``. A row with zero norm has no
        direction, so its similarity to every row (itself included) is
        reported as 0 instead of NaN.
    """
    embeddings = np.asarray(embeddings)

    # Row-wise L2 norms, kept as a column so they broadcast over each row.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)

    # Guard against division by zero: an all-zero row would otherwise turn
    # into NaNs that spread through its whole row and column of the result.
    safe_norms = norms.copy()
    safe_norms[safe_norms == 0] = 1

    normalized_embeddings = embeddings / safe_norms

    # Dot products of unit vectors are cosine similarities.
    return np.matmul(normalized_embeddings, normalized_embeddings.T)


def plot_heatmap(similarity_matrix, labels, output_path="cls_embedding_similarities.png"):
    """Draw an annotated heatmap of a similarity matrix, save it, and show it.

    Args:
        similarity_matrix: Matrix of pairwise similarity values to plot.
        labels: Tick labels used for both the x and y axes.
        output_path: File path the rendered figure is written to.
    """
    plt.figure(figsize=(10, 8))

    # Lower end of the colour scale: the smallest similarity in the matrix,
    # but never below zero.
    lowest = np.min(similarity_matrix)
    color_floor = lowest if lowest > 0.0 else 0.0

    heatmap_options = {
        "annot": True,
        "fmt": ".3f",
        "cmap": "viridis",
        "vmin": color_floor,
        "vmax": 1.0,
        "xticklabels": labels,
        "yticklabels": labels,
        "cbar_kws": {"label": "Similarity"},
    }
    sns.heatmap(similarity_matrix, **heatmap_options)

    plt.title("CLS Token Embedding Similarities")
    plt.tight_layout()

    # Vertical x-axis labels keep long phrases readable.
    plt.xticks(rotation=90)

    # Write the figure to disk before opening the interactive window.
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Heatmap saved to {output_path}")

    plt.show()


def main():
    """Embed a fixed list of radiology phrases and plot their similarities.

    Loads the RadEvalModernBERT tokenizer and model, computes one CLS
    embedding per phrase, builds the cosine-similarity matrix, and saves and
    shows it as a heatmap in ``cls_embedding_similarities.png``.
    """
    # Medical terms to compare
    medical_terms = [
        "large right pneumothorax",
        "right pneumothorax",
        "pneumonia in the right lower lobe",
        "consolidation in the right lower lobe",
        "right 9th rib fracture",
        "left 9th rib fracture",
        "left 5th rib fracture",
        "5th metatarsal fracture",
        "no pneumothorax is present",
        "prior consolidation has cleared",
        "no rib fractures",
    ]

    # Use the GPU when available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # The model id must be a string; left unquoted it is evaluated as the
    # expression `IAMJB / RadEvalModernBERT` and raises NameError.
    model_name = "IAMJB/RadEvalModernBERT"

    # Load the tokenizer and the model, then prepare the model for inference.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    model.eval()

    # Get CLS token embeddings for the medical terms
    print("Generating CLS token embeddings...")
    embeddings = get_cls_embeddings(model, tokenizer, medical_terms, device)

    # Compute similarities
    print("Computing similarity matrix...")
    similarity_matrix = compute_similarities(embeddings)

    # Plot and save the heatmap
    print("Generating heatmap...")
    plot_heatmap(similarity_matrix, medical_terms, "cls_embedding_similarities.png")

    print("Done!")


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
```

 |
|
|