# Spaces: Sleeping
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Checkpoint identifier handed to both from_pretrained calls below.
model_name = "yangheng/PlantRNA-FM"

# Load the tokenizer and the masked-LM model once at import time; predict_rna
# reads these module-level objects on every call.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
def predict_rna(sequence: str) -> str:
    """Predict the token at every mask position of an input sequence.

    Args:
        sequence: Text to tokenize. Positions to predict are marked with the
            tokenizer's mask token (the UI placeholder uses ``<mask>``).

    Returns:
        The highest-scoring token for each mask position, in order of
        appearance, joined by single spaces. An empty string when the input
        is empty/missing or contains no mask token.
    """
    # Nothing to tokenize: avoid handing an empty or missing value to the
    # tokenizer.
    if not sequence:
        return ""

    inputs = tokenizer(sequence, return_tensors="pt")

    # Column indices of the mask tokens in the encoded input.
    mask_token_index = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]
    if mask_token_index.numel() == 0:
        # No mask to fill: skip the forward pass. The result is the same
        # empty string the join at the end would produce.
        return ""

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        outputs = model(**inputs)

    # Keep only the logits at the mask positions of the first (only) sequence.
    mask_token_logits = outputs.logits[0, mask_token_index, :]
    predicted_token_ids = torch.argmax(mask_token_logits, dim=-1)

    # Hand plain Python ints to the tokenizer rather than a tensor.
    predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_token_ids.tolist())
    return " ".join(predicted_tokens)
# Two-line input box; the placeholder shows an example sequence containing a
# mask token.
input_text = gr.Textbox(
    lines=2,
    placeholder="Input RNA Sequence with <mask>, e.g., AAAGAGTCATATACGATATTGTCGACCGTGG<mask>AGAGAGAAGAATGTACGATTGGAGT",
)
output_text = gr.Textbox()

# Text-in / text-out interface backed by predict_rna.
app = gr.Interface(
    fn=predict_rna,
    inputs=input_text,
    outputs=output_text,
    title="Zero-shot PlantFM-RNA MNM Inference",
    description="Zero-shot PlantFM-RNA MNM Inference: Predicts only the <mask> tokens.",
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    app.launch()