---
library_name: transformers
tags: []
---
```python
import torch
from transformers import AutoTokenizer, AutoModel
# Example sentence to embed
sentence = "The patient has a right pneumothorax."

# Fetch the tokenizer and encoder weights from the Hugging Face Hub
model_name = "IAMJB/RadEvalModernBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Prefer the GPU when one is present, and switch the model to inference mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Encode the sentence and move the resulting tensors onto the chosen device
inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True).to(device)

# Forward pass without gradient tracking; keep every layer's output
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    final_layer = outputs.hidden_states[-1]
    # Position 0 is the CLS token; its vector serves as the sentence embedding
    cls_embedding = final_layer[:, 0, :]

print("Sentence:", sentence)
print("Embedding shape:", cls_embedding.shape)
```
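### Similarity heatmap example
The script below embeds a list of radiology phrases using the CLS-token approach shown above. It then computes their pairwise cosine similarities and saves and displays the result as an annotated heatmap. The script requires `matplotlib` and `seaborn`.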
```python
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import seaborn as sns
from transformers import AutoTokenizer, AutoModel
def get_cls_embeddings(model, tokenizer, texts, device):
    """Embed each text on its own and stack the resulting CLS vectors.

    Args:
        model: model called with ``output_hidden_states=True``; its output must
            expose a ``hidden_states`` sequence whose last entry is indexed as
            ``[:, 0, :]``.
        tokenizer: callable that turns one text into a dict-like of tensors.
        texts: iterable of strings to embed.
        device: torch device the encoded tensors are moved to.

    Returns:
        numpy array with one row per text, taken from the first token position
        of the final entry in ``hidden_states``.
    """

    def _embed_one(text):
        # One forward pass per text; with a single text, padding=True pads nothing
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            result = model(**encoded, output_hidden_states=True)
        first_token = result.hidden_states[-1][:, 0, :]
        # Drop the batch dimension of size 1
        return first_token.cpu().numpy()[0]

    return np.array([_embed_one(text) for text in texts])
def compute_similarities(embeddings):
    """Return the pairwise cosine-similarity matrix of the rows of ``embeddings``.

    Every row is scaled to unit L2 length, so the dot product of two scaled
    rows is their cosine similarity. A row with zero norm yields NaNs.
    """
    row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit_rows = embeddings / row_norms
    return unit_rows @ unit_rows.T
def plot_heatmap(similarity_matrix, labels, output_path="cls_embedding_similarities.png"):
    """Draw an annotated heatmap of a similarity matrix, save it, and show it.

    Args:
        similarity_matrix: square array of pairwise similarities.
        labels: tick labels for both axes, in row order.
        output_path: file the figure is written to (300 dpi).
    """
    plt.figure(figsize=(10, 8))
    # Colour scale runs from the smallest similarity (clamped at 0.0, so negative
    # values do not stretch the scale) up to 1.0, the maximum cosine similarity.
    min_val = max(0.0, np.min(similarity_matrix))
    sns.heatmap(
        similarity_matrix,
        annot=True,
        fmt=".3f",
        cmap="viridis",  # Perceptually uniform; separates close high values well
        vmin=min_val,
        vmax=1.0,
        xticklabels=labels,
        yticklabels=labels,
        cbar_kws={"label": "Similarity"},
    )
    plt.title("CLS Token Embedding Similarities")
    # Rotate the long x labels BEFORE tight_layout so the layout accounts for
    # their rotated extent; otherwise the shown figure can clip them.
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Heatmap saved to {output_path}")
    plt.show()
def main():
    """Embed a fixed set of radiology phrases and plot their CLS similarities."""
    # Hub repository id; must be a string literal
    model_name = "IAMJB/RadEvalModernBERT"

    # Medical terms to compare
    medical_terms = [
        "large right pneumothorax",
        "right pneumothorax",
        "pneumonia in the right lower lobe",
        "consolidation in the right lower lobe",
        "right 9th rib fracture",
        "left 9th rib fracture",
        "left 5th rib fracture",
        "5th metatarsal fracture",
        "no pneumothorax is present",
        "prior consolidation has cleared",
        "no rib fractures",
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    model.eval()

    print("Generating CLS token embeddings...")
    embeddings = get_cls_embeddings(model, tokenizer, medical_terms, device)

    print("Computing similarity matrix...")
    similarity_matrix = compute_similarities(embeddings)

    print("Generating heatmap...")
    plot_heatmap(similarity_matrix, medical_terms, "cls_embedding_similarities.png")
    print("Done!")


if __name__ == "__main__":
    main()
```
