File size: 4,838 Bytes
099dc58
 
 
 
 
3d1e264
 
 
 
 
 
 
 
b5dbe8d
3d1e264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452c7ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
---
library_name: transformers
tags: []
---

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Example input: a single radiology sentence
sentence = "The patient has a right pneumothorax."

# Load the pretrained checkpoint and its matching tokenizer
model_name = "IAMJB/RadEvalModernBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
model.eval()

# Encode the sentence as PyTorch tensors and move them next to the model
encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
inputs = encoded.to(device)

# Forward pass without autograd; keep the first token (CLS) of the final hidden layer
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    final_layer = outputs.hidden_states[-1]
    cls_embedding = final_layer[:, 0, :]

print("Sentence:", sentence)
print("Embedding shape:", cls_embedding.shape)
```



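### Similarity heatmap example

The script below embeds a list of radiology phrases via their CLS-token embeddings and plots the pairwise cosine similarities as a heatmap. In addition to `torch` and `transformers`, it requires `numpy`, `matplotlib`, and `seaborn`.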


```python
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import seaborn as sns
from transformers import AutoTokenizer, AutoModel

def get_cls_embeddings(model, tokenizer, texts, device):
    """Embed each text separately and return the stacked CLS-token vectors.

    Every text is tokenized on its own, moved to ``device``, and run through
    ``model`` with ``output_hidden_states=True``. The vector kept for a text is
    the first token position of the last entry in ``hidden_states``.

    Returns:
        A NumPy array with one row per input text (an empty 1-D array when
        ``texts`` is empty).
    """

    def _embed_one(text):
        # Tokenize a single text and place every tensor on the target device.
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        with torch.no_grad():
            result = model(**encoded, output_hidden_states=True)
            final_layer = result.hidden_states[-1]
            # Position 0 is the CLS token; [0] drops the batch dimension of size 1.
            return final_layer[:, 0, :].cpu().numpy()[0]

    return np.array([_embed_one(text) for text in texts])

def compute_similarities(embeddings):
    """Return the pairwise cosine-similarity matrix of the rows of *embeddings*.

    Each row is scaled to unit L2 length, so entry (i, j) is the cosine of the
    angle between row i and row j. The result is square, with one row/column
    per input row. A row whose norm is zero produces NaN entries (0 / 0).
    """
    row_lengths = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit_rows = embeddings / row_lengths
    return unit_rows @ unit_rows.T

def plot_heatmap(similarity_matrix, labels, output_path="cls_embedding_similarities.png"):
    """Draw an annotated heatmap of a similarity matrix, save it, and show it.

    Args:
        similarity_matrix: Square 2-D array of similarity scores.
        labels: Tick labels for both axes, one per row/column.
        output_path: File the figure is written to (300 dpi).
    """
    fig = plt.figure(figsize=(10, 8))

    # Start the colour scale at the smallest observed similarity (never below 0)
    # so that differences among the high, closely spaced values stay visible.
    min_val = max(0.0, float(np.min(similarity_matrix)))

    ax = sns.heatmap(
        similarity_matrix,
        annot=True,
        fmt=".3f",
        cmap="viridis",  # perceptually uniform; separates high values well
        vmin=min_val,
        vmax=1.0,
        xticklabels=labels,
        yticklabels=labels,
        cbar_kws={"label": "Similarity"},
    )

    ax.set_title("CLS Token Embedding Similarities")
    # Rotate the long x labels BEFORE computing the layout, so tight_layout
    # reserves room for the rotated text instead of the unrotated one.
    ax.tick_params(axis="x", labelrotation=90)
    fig.tight_layout()

    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Heatmap saved to {output_path}")

    plt.show()
    # Release the figure; otherwise it stays registered with pyplot and
    # accumulates when this function is called repeatedly or on a
    # non-interactive backend.
    plt.close(fig)

def main(model_name="IAMJB/RadEvalModernBERT", output_path="cls_embedding_similarities.png"):
    """Embed a fixed list of radiology phrases and plot their CLS similarities.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint to load.
            It must be a string; it is passed to both the tokenizer and the model.
        output_path: File the similarity heatmap is written to.
    """
    # Medical terms to compare
    medical_terms = [
        "large right pneumothorax",
        "right pneumothorax",
        "pneumonia in the right lower lobe",
        "consolidation in the right lower lobe",
        "right 9th rib fracture",
        "left 9th rib fracture",
        "left 5th rib fracture",
        "5th metatarsal fracture",
        "no pneumothorax is present",
        "prior consolidation has cleared",
        "no rib fractures",
    ]

    # Set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the tokenizer and model from the same checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    model.eval()

    # Get CLS token embeddings for the medical terms
    print("Generating CLS token embeddings...")
    embeddings = get_cls_embeddings(model, tokenizer, medical_terms, device)

    # Compute similarities
    print("Computing similarity matrix...")
    similarity_matrix = compute_similarities(embeddings)

    # Plot and save the heatmap
    print("Generating heatmap...")
    plot_heatmap(similarity_matrix, medical_terms, output_path)

    print("Done!")


if __name__ == "__main__":
    main()
```

![Heatmap of cosine similarities between CLS embeddings of the example medical terms](https://cdn-uploads.huggingface.co/production/uploads/62716952bcef985363db8485/6mzZ5_Xz2ovl3a6TlAzxo.png)