|
import torch |
|
import medspacy |
|
nlp = medspacy.load(medspacy_enable=["medspacy_pyrush", "medspacy_context"]) |
|
|
|
from .utils import sentence_split, post_process |
|
|
|
def run_ner(texts, idx2label, tokenizer, model, device, batch_size): |
|
|
|
clean_text_list, is_start_list = sentence_split(texts) |
|
|
|
predicted_labels = [] |
|
|
|
for i in range(0, len(clean_text_list), batch_size): |
|
batch_text = clean_text_list[i:i+batch_size] |
|
|
|
inputs = tokenizer(batch_text, |
|
max_length=512, |
|
padding=True, |
|
truncation=True, |
|
return_tensors="pt").to(device) |
|
|
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
|
|
predicted_labels.extend(torch.argmax(outputs.logits, dim=2).tolist()) |
|
|
|
inputs = tokenizer(clean_text_list, |
|
max_length=512, |
|
padding=True, |
|
truncation=True, |
|
return_tensors="pt") |
|
|
|
save_pairs = [] |
|
|
|
pad_token_id = tokenizer.pad_token_id |
|
|
|
for i, is_start in enumerate(is_start_list): |
|
|
|
predicted_entities = [idx2label[label] for label in predicted_labels[i]] |
|
|
|
non_pad_mask = inputs["input_ids"][i] != pad_token_id |
|
non_pad_length = non_pad_mask.sum().item() |
|
non_pad_input_ids = inputs["input_ids"][i][:non_pad_length] |
|
|
|
tokenized_text = tokenizer.convert_ids_to_tokens(non_pad_input_ids) |
|
|
|
if is_start: |
|
save_pair = post_process(tokenized_text, predicted_entities, tokenizer) |
|
else: |
|
save_pair = post_process(tokenized_text, predicted_entities, tokenizer) |
|
save_pairs[-1].extend(save_pair) |
|
continue |
|
|
|
save_pairs.append(save_pair) |
|
|
|
return save_pairs |
|
|
|
|
|
def process_embedding(pair, eval_tokenizer, eval_model, device): |
|
entities = [pair[0] for pair in pair] |
|
types = [pair[1] for pair in pair] |
|
|
|
if len(entities) == 0: |
|
embeds_word = torch.tensor([]) |
|
else: |
|
embeds_word = torch.tensor([]).to(device) |
|
|
|
with torch.no_grad(): |
|
|
|
encoded = eval_tokenizer( |
|
entities, |
|
truncation=True, |
|
padding=True, |
|
return_tensors='pt', |
|
max_length=30, |
|
).to(device) |
|
|
|
|
|
embeds_word = torch.cat((embeds_word.to('cpu'), |
|
eval_model(**encoded).last_hidden_state[:, 0, :].to('cpu')), dim=0) |
|
|
|
return embeds_word, types |
|
|
|
|