|
import os |
|
import json |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from transformers import BertForSequenceClassification, BertTokenizer |
|
from tqdm import tqdm |
|
import re |
|
import nltk |
|
|
|
|
|
def srr_bert_parse_sentences(text):
    """Normalize a free-text report and split it into sentences.

    Numbered-list markers such as ``"1."`` are stripped or turned into
    sentence breaks, and a space is inserted after a period that is directly
    followed by a capital letter. A final period is appended when the text
    does not already end with ``.``, ``!`` or ``?``. The result is then split
    with ``nltk.sent_tokenize``.

    Args:
        text: the report text to split.

    Returns:
        A list of sentence strings. Empty or whitespace-only input, including
        input that consists only of list numbering such as ``"1."``, yields
        ``[]`` instead of a lone ``"."`` sentence.

    Note:
        ``nltk.sent_tokenize`` relies on NLTK's sentence-tokenizer data being
        available in the environment; this function does not download it.
    """
    # Leading "N." list marker at the very start of the string. re.MULTILINE
    # is not set, so only the first marker of the whole text is handled here.
    text = re.sub(r'^\s*\d+\.(?!\d)\s*', '', text)

    # "word.2." -> "word. " : a marker glued onto the end of the previous sentence.
    text = re.sub(r'(\w)\.(\d+)\.(?!\d)\s*', r'\1. ', text)

    # "word2." -> "word. " : a marker glued onto the previous word with no period.
    # NOTE(review): this also rewrites letter+digit tokens that end a sentence,
    # e.g. "T12." -> "T. "; confirm that is acceptable for the input reports.
    text = re.sub(r'(\w)(\d+)\.(?!\d)\s*', r'\1. ', text)

    # "word 2." -> "word. " : a marker separated from the previous word by whitespace.
    # NOTE(review): a sentence ending in a bare number ("measures 3.") is
    # rewritten the same way and loses the number.
    text = re.sub(r'(\w)\s+\d+\.(?!\d)\s*', r'\1. ', text)

    # "end. 2." -> "end. " : a marker that follows sentence-final punctuation.
    text = re.sub(r'([.!?])\s*\d+\.(?!\d)\s*', r'\1 ', text)

    # "end.Next" -> "end. Next" so the tokenizer sees a sentence boundary.
    text = re.sub(r'\.([A-Z])', r'. \1', text)

    # Nothing left to split: without this guard the period appended below
    # would turn empty input into the bogus sentence ".".
    if not text.strip():
        return []

    # Ensure the last sentence is terminated.
    if not text.strip().endswith(('.', '!', '?')):
        text = text.strip() + '.'

    return nltk.sent_tokenize(text)
|
|
|
|
|
class SRRBert(nn.Module):
    """Sentence-level multi-label classifier built on SRR-BERT checkpoints.

    ``model_type`` selects an entry of ``MODEL_CONFIGS``. Each entry names a
    Hugging Face checkpoint and a label-mapping JSON file, which is looked up
    in this module's directory. The mapping is a dict of label name -> output
    column index. A label is predicted for a sentence when its sigmoid score
    exceeds 0.5.
    """

    # model_type -> {"model_path": Hugging Face checkpoint id,
    #                "mapping_file": JSON file name next to this module}
    MODEL_CONFIGS = {
        "leaves": {
            "model_path": "StanfordAIMI/SRR-BERT-Leaves",
            "mapping_file": "leaves_mapping.json"
        },
        "upper": {
            "model_path": "StanfordAIMI/SRR-BERT-Upper",
            "mapping_file": "upper_mapping.json"
        },
        "leaves_with_statuses": {
            "model_path": "StanfordAIMI/SRR-BERT-Leaves-with-Statuses",
            "mapping_file": "leaves_with_statuses_mapping.json"
        },
        "upper_with_statuses": {
            # NOTE(review): this path uses the "SRRG-" prefix while the other
            # three entries use "SRR-"; confirm the checkpoint name on the Hub.
            "model_path": "StanfordAIMI/SRRG-BERT-Upper-with-Statuses",
            "mapping_file": "upper_with_statuses_mapping.json"
        },
    }

    def __init__(
        self,
        model_type: str = "leaves",
        batch_size: int = 4,
        tqdm_enable: bool = False
    ):
        """Load the label mapping, classifier weights and tokenizer.

        Args:
            model_type: key of ``MODEL_CONFIGS`` selecting checkpoint and mapping.
            batch_size: number of sentences tokenized and scored per step.
            tqdm_enable: show a tqdm progress bar over batches in ``forward``.

        Raises:
            ValueError: if ``model_type`` is unknown or ``batch_size`` < 1.
        """
        super().__init__()
        if model_type not in self.MODEL_CONFIGS:
            raise ValueError(
                f"model_type must be one of {list(self.MODEL_CONFIGS.keys())}"
            )
        # Validate up front: 0 would make range() raise and a negative value
        # would produce no batches, both surfacing only later inside forward().
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        config = self.MODEL_CONFIGS[model_type]

        # Label mapping lives next to this file. JSON is UTF-8, so don't
        # depend on the platform's default encoding.
        mapping_path = os.path.join(
            os.path.dirname(__file__),
            config["mapping_file"]
        )
        with open(mapping_path, 'r', encoding='utf-8') as f:
            self.mapping = json.load(f)

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )

        # Classification head sized to the number of labels in the mapping.
        self.model = BertForSequenceClassification.from_pretrained(
            config["model_path"],
            num_labels=len(self.mapping)
        )
        self.model.to(self.device)
        self.model.eval()

        # The same (hard-coded) tokenizer checkpoint is used for every model_type.
        self.tokenizer = BertTokenizer.from_pretrained(
            "microsoft/BiomedVLP-CXR-BERT-general"
        )

        self.batch_size = batch_size
        self.tqdm_enable = tqdm_enable

    def map_predictions_to_labels(self, outputs):
        """Translate rows of 0/1 prediction flags into lists of label names.

        Args:
            outputs: iterable of rows, each enumerable by label column, where
                a value of 1 marks that column's label as predicted.

        Returns:
            One list of label names per row, in column order.
        """
        inverted_mapping = {v: k for k, v in self.mapping.items()}
        return [
            [inverted_mapping[i] for i, flag in enumerate(row) if flag == 1]
            for row in outputs
        ]

    def forward(self, sentences):
        """Classify sentences in batches of ``self.batch_size``.

        Args:
            sentences: a sequence (or other iterable) of sentence strings. A
                single string is treated as a one-sentence input.

        Returns:
            ``(preds, labels)``:
            - ``preds`` is an int numpy array with one row per sentence and
              ``len(self.mapping)`` columns of 0/1 flags.
            - ``labels`` is a list holding, for each sentence, the names of
              the flagged labels.
            Empty input returns a ``(0, len(self.mapping))`` array and ``[]``.
        """
        # A bare string would otherwise be sliced into character chunks below.
        if isinstance(sentences, str):
            sentences = [sentences]
        # Materialise generators/Series so len() and slicing work and the
        # tokenizer receives plain lists.
        sentences = list(sentences)

        if not sentences:
            # np.concatenate([]) raises, so return an empty, correctly shaped result.
            return np.zeros((0, len(self.mapping)), dtype=int), []

        batches = [
            sentences[i:i + self.batch_size]
            for i in range(0, len(sentences), self.batch_size)
        ]
        outputs = []
        with torch.no_grad():
            for batch in tqdm(
                batches, desc="Predicting", disable=not self.tqdm_enable
            ):
                # Every batch is padded to exactly 512 tokens; longer
                # sentences are truncated.
                inputs = self.tokenizer.batch_encode_plus(
                    batch,
                    add_special_tokens=True,
                    max_length=512,
                    padding="max_length",
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors="pt",
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                logits = self.model(**inputs).logits
                # Independent sigmoid per label (multi-label), thresholded at 0.5.
                preds = (torch.sigmoid(logits) > 0.5).cpu().numpy().astype(int)
                outputs.append(preds)

        outputs = np.concatenate(outputs, axis=0)
        return outputs, self.map_predictions_to_labels(outputs)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: classify a few example chest X-ray findings with the
    # "leaves" checkpoint and print both the raw flags and the label names.
    demo_sentences = [
        "Layering pleural effusions",
        "Moderate pulmonary edema.",
        "Chronic fracture and dislocation involving the left humeral surgical neck and glenoid.",
        "Stable cardiomegaly.",
    ]

    classifier = SRRBert(model_type="leaves", batch_size=4, tqdm_enable=True)
    raw_flags, label_names = classifier(demo_sentences)

    for caption, value in (
        ("Raw outputs:", raw_flags),
        ("Predicted labels:", label_names),
    ):
        print(caption, value)
|
|