Spaces:
Sleeping
Sleeping
# app.py
from functools import lru_cache

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer
# Hugging Face repository that provides the UME model and its tokenizers.
model_name = "karina-zadorozhny/ume"


def _load_tokenizer(subfolder):
    """Load the tokenizer stored under ``subfolder`` of the model repository."""
    return AutoTokenizer.from_pretrained(
        model_name, subfolder=subfolder, trust_remote_code=True
    )


# The model is created a single time at import and eval() is called on it.
# NOTE: trust_remote_code=True lets code shipped in the model repository run
# locally, so the repository must be trusted.
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()

# One tokenizer per supported modality.
tokenizer_aa = _load_tokenizer("tokenizer_amino_acid")
tokenizer_nt = _load_tokenizer("tokenizer_nucleotide")
tokenizer_sm = _load_tokenizer("tokenizer_smiles")
# Alphabets used for classification; input is upper-cased before matching.
_NUCLEOTIDE_ALPHABET = frozenset("ATGCUN")
_AMINO_ACID_ALPHABET = frozenset("ACDEFGHIKLMNPQRSTVWYBXZJUO")


def detect_modality(seq):
    """Guess the modality of a raw sequence string.

    The input is stripped and upper-cased, then its characters are compared
    against two alphabets in order:

    * only characters from ``ATGCUN`` -> ``"nucleotide"``
    * only characters from the amino-acid alphabet -> ``"amino_acid"``
    * anything else -> ``"smiles"``

    The nucleotide test runs first, so a string built solely from
    A/T/G/C/U/N is reported as a nucleotide even though those letters are
    also amino-acid codes. An empty (or whitespace-only) string satisfies the
    first test and yields ``"nucleotide"``. Letter-only SMILES such as
    ``"CCO"`` match the amino-acid alphabet and are reported as such.

    Args:
        seq: The sequence text to classify.

    Returns:
        One of ``"nucleotide"``, ``"amino_acid"`` or ``"smiles"``.
    """
    symbols = set(seq.strip().upper())
    if symbols <= _NUCLEOTIDE_ALPHABET:
        return "nucleotide"
    if symbols <= _AMINO_ACID_ALPHABET:
        return "amino_acid"
    return "smiles"
@lru_cache(maxsize=256)
def _embed_sequence(seq):
    """Embed a single sequence with the tokenizer matching its modality.

    Results are memoized (bounded LRU) because one UI request asks for the
    embeddings of the same sequences more than once; without the cache every
    sequence would be pushed through the model repeatedly.

    Args:
        seq: The sequence string to embed.

    Returns:
        The model output for ``seq`` as a NumPy array, after the two leading
        axes have been squeezed.
    """
    tokenizers_by_modality = {
        "amino_acid": tokenizer_aa,
        "nucleotide": tokenizer_nt,
    }
    # Anything that is neither protein nor nucleotide is treated as SMILES.
    tokenizer = tokenizers_by_modality.get(detect_modality(seq), tokenizer_sm)

    inputs = tokenizer([seq], return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        emb = model(
            inputs["input_ids"].unsqueeze(1),
            inputs["attention_mask"].unsqueeze(1),
        )
    # NOTE(review): presumably the model returns one vector per input behind
    # leading size-1 axes, which the two squeezes drop -- confirm against the
    # model card.
    # detach()/cpu() keep the NumPy conversion valid regardless of the device
    # the model lives on.
    return emb.squeeze(0).squeeze(0).detach().cpu().numpy()


def compute_embeddings(sequences):
    """Compute one embedding per input sequence and stack them row-wise.

    Each sequence is routed to the amino-acid, nucleotide or SMILES tokenizer
    according to ``detect_modality`` and embedded individually, so sequences
    of different modalities can be mixed in a single call.

    Args:
        sequences: Iterable of sequence strings.

    Returns:
        A NumPy array produced by ``np.vstack`` over the per-sequence
        embeddings, in input order.

    Raises:
        ValueError: If ``sequences`` is empty.
    """
    sequences = list(sequences)
    if not sequences:
        # np.vstack would also raise ValueError here, but with a message that
        # says nothing about the actual problem.
        raise ValueError(
            "No sequences provided: at least one sequence is required to "
            "compute embeddings."
        )
    return np.vstack([_embed_sequence(seq) for seq in sequences])
def visualize_embeddings(sequences):
    """Project sequence embeddings with PCA and build 2D and 3D scatter plots.

    Args:
        sequences: Sequence strings to embed and plot. At least three are
            needed, because the 3D view fits a PCA with three components.

    Returns:
        A ``(fig_2d, fig_3d)`` tuple of Plotly figures. Every point is
        labelled and coloured by its sequence string.

    Raises:
        ValueError: If fewer than three sequences are given.
    """
    sequences = list(sequences)
    # PCA cannot extract more components than there are samples. Check before
    # running the model so the caller gets a clear message instead of an
    # sklearn error after the embeddings were already computed.
    if len(sequences) < 3:
        raise ValueError(
            "At least 3 sequences are required for the 3D PCA projection, "
            f"got {len(sequences)}."
        )

    embeddings = compute_embeddings(sequences)

    # Independent PCA fits for the 2D and the 3D view.
    pca_2d = PCA(n_components=2).fit_transform(embeddings)
    pca_3d = PCA(n_components=3).fit_transform(embeddings)

    df_2d = pd.DataFrame(pca_2d, columns=["PC1", "PC2"])
    df_2d["Sequence"] = sequences
    df_3d = pd.DataFrame(pca_3d, columns=["X", "Y", "Z"])
    df_3d["Sequence"] = sequences

    fig_2d = px.scatter(
        df_2d,
        x="PC1",
        y="PC2",
        text="Sequence",
        title="2D PCA of UME Embeddings",
        color="Sequence",
        color_discrete_sequence=px.colors.qualitative.Bold,
    )
    fig_3d = px.scatter_3d(
        df_3d,
        x="X",
        y="Y",
        z="Z",
        text="Sequence",
        title="3D PCA of UME Embeddings",
        color="Sequence",
        color_discrete_sequence=px.colors.qualitative.Vivid,
    )
    return fig_2d, fig_3d
def similarity_matrix(sequences):
    """Build a heatmap of pairwise cosine similarities between embeddings.

    Args:
        sequences: Sequence strings to embed; they also label both heatmap
            axes.

    Returns:
        A Plotly figure showing the cosine-similarity matrix with the values
        written into the cells.
    """
    pairwise = cosine_similarity(compute_embeddings(sequences))
    # Label rows and columns with the sequences themselves.
    labelled = pd.DataFrame(data=pairwise, index=sequences, columns=sequences)
    return px.imshow(
        labelled,
        title="Cosine Similarity Matrix",
        color_continuous_scale='Viridis',
        text_auto=True,
    )
# Markdown rendered at the top of the page. The emoji and the en dash are
# restored from mis-decoded UTF-8 text.
description = """
# 🧬 UME Explorer: Biosequence Embedding Playground
Welcome to **UME Explorer**, an interactive space to explore representations of molecules using the UME model.
Paste in your DNA, amino acid, or SMILES sequences and:
- ✨ Visualize embeddings in 2D and 3D
- 🔬 Explore pairwise similarities
- 🎨 Enjoy colorful, educational plots!
> **Tip**: Keep input sequences short and between 3–20 items for better visuals on CPU.
"""
# Fewest sequences for which the 3D PCA view (three components) can be fitted.
_MIN_SEQUENCES = 3

# NOTE(review): the nesting of the Row containers below was reconstructed from
# a listing that had lost its indentation -- confirm against the intended
# layout. Two emoji in the help text (the protein bullet and the final hint)
# could not be recovered unambiguously and are best guesses.
with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {display: none}") as demo:
    gr.Markdown(description)
    # The Markdown body stays at column 0: indented lines would be rendered
    # as a code block.
    gr.Markdown("""
ℹ️ <b>How sequence type is detected:</b><br>
- 🧬 <b>Nucleotide (DNA/RNA):</b> Only uses A, T, G, C, U, N<br>
- 🔹 <b>Protein (Amino Acid):</b> Includes letters like M, K, V, L, etc.<br>
- 🧪 <b>SMILES (Chemical):</b> Includes characters like =, (, ), C, O, etc.<br>
<small>🔍 Detection is automatic. You can mix sequence types in one run!</small>
""")

    with gr.Row():
        seq_input = gr.Textbox(
            lines=8,
            placeholder="Enter sequences, one per line...",
            label="Input Sequences",
        )
    submit_btn = gr.Button("Compute Embeddings & Visualize")

    with gr.Row():
        out2d = gr.Plot(label="2D Plot")
        out3d = gr.Plot(label="3D Plot")
    sim_out = gr.Plot(label="Similarity Heatmap")

    def process_input(text):
        """Turn the textbox content into the three figures shown in the UI.

        Args:
            text: Raw textbox content with one sequence per line. Blank lines
                and surrounding whitespace are ignored.

        Returns:
            A ``(fig2d, fig3d, sim_fig)`` tuple matching the three plot
            outputs wired to the button.

        Raises:
            gr.Error: If fewer than ``_MIN_SEQUENCES`` sequences were entered;
                Gradio shows the message to the user.
        """
        seqs = [s.strip() for s in (text or "").splitlines() if s.strip()]
        # Fail early with guidance instead of letting the PCA step raise an
        # error the user cannot act on.
        if len(seqs) < _MIN_SEQUENCES:
            raise gr.Error(
                f"Please enter at least {_MIN_SEQUENCES} sequences "
                f"(one per line); got {len(seqs)}."
            )
        fig2d, fig3d = visualize_embeddings(seqs)
        sim_fig = similarity_matrix(seqs)
        return fig2d, fig3d, sim_fig

    submit_btn.click(fn=process_input, inputs=seq_input, outputs=[out2d, out3d, sim_out])

demo.launch()