File size: 4,578 Bytes
78669aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# app.py — Gradio demo for exploring UME biosequence embeddings:
# 2D/3D PCA scatter plots and a cosine-similarity heatmap for
# DNA/RNA, protein, and SMILES inputs.

import gradio as gr
import torch
import plotly.express as px
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, AutoModel

# Load model once at import time so every request reuses the same weights.
# NOTE: trust_remote_code=True executes code shipped in the model repo —
# review the repo before deploying.
model_name = "karina-zadorozhny/ume"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()  # inference only

# Load all 3 tokenizers — one per modality, each in its own repo subfolder.
tokenizer_aa = AutoTokenizer.from_pretrained(model_name, subfolder="tokenizer_amino_acid", trust_remote_code=True)
tokenizer_nt = AutoTokenizer.from_pretrained(model_name, subfolder="tokenizer_nucleotide", trust_remote_code=True)
tokenizer_sm = AutoTokenizer.from_pretrained(model_name, subfolder="tokenizer_smiles", trust_remote_code=True)


def detect_modality(seq):
    """Heuristically classify a sequence as nucleotide, amino acid, or SMILES.

    The uppercased sequence is checked against character alphabets, most
    restrictive first: the nucleotide alphabet is a subset of the amino-acid
    alphabet, so DNA/RNA must be tested before protein.  Anything that fits
    neither alphabet falls through to SMILES.

    Args:
        seq: Raw input sequence; surrounding whitespace is ignored.

    Returns:
        One of "nucleotide", "amino_acid", or "smiles".
    """
    seq = seq.strip().upper()
    # Bug fix: all() is vacuously True for an empty string, which previously
    # misclassified "" as "nucleotide".  Route empty input to the SMILES
    # tokenizer (the most permissive one) instead.
    if not seq:
        return "smiles"
    if all(c in "ATGCUN" for c in seq):  # DNA/RNA
        return "nucleotide"
    elif all(c in "ACDEFGHIKLMNPQRSTVWYBXZJUO" for c in seq):  # Protein
        return "amino_acid"
    else:
        return "smiles"


def compute_embeddings(sequences):
    """Embed each sequence with UME, routing it through the tokenizer that
    matches its auto-detected modality.

    Args:
        sequences: Iterable of raw sequence strings.

    Returns:
        A 2-D numpy array with one embedding row per input sequence.
    """
    # Dispatch table keyed by detect_modality(); SMILES is the fallback.
    tokenizer_by_modality = {
        "amino_acid": tokenizer_aa,
        "nucleotide": tokenizer_nt,
    }

    rows = []
    for sequence in sequences:
        tokenizer = tokenizer_by_modality.get(detect_modality(sequence), tokenizer_sm)
        encoded = tokenizer([sequence], return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            # unsqueeze(1) adds an extra dimension to ids/mask before the
            # model call — presumably the shape UME expects; TODO confirm
            # against the UME model card.
            output = model(
                encoded["input_ids"].unsqueeze(1),
                encoded["attention_mask"].unsqueeze(1),
            )
        rows.append(output.squeeze(0).squeeze(0).numpy())

    return np.vstack(rows)

def visualize_embeddings(sequences):
    """Build 2-D and 3-D PCA scatter plots of the UME embeddings.

    Robustness fix: ``PCA(n_components=k)`` raises ``ValueError`` when
    ``k > min(n_samples, n_features)``, so the original code crashed whenever
    fewer than 3 (or 2) sequences were submitted.  The component count is now
    clamped and any missing axes are zero-padded so the plots always render.

    Args:
        sequences: List of raw sequence strings (one plotted point each).

    Returns:
        Tuple ``(fig_2d, fig_3d)`` of plotly scatter figures.
    """
    embeddings = compute_embeddings(sequences)

    # PCA for 2D and 3D, safe for any number of input sequences.
    pca_2d = _pca_project(embeddings, 2)
    pca_3d = _pca_project(embeddings, 3)

    df_2d = pd.DataFrame(pca_2d, columns=["PC1", "PC2"])
    df_2d["Sequence"] = sequences

    df_3d = pd.DataFrame(pca_3d, columns=["X", "Y", "Z"])
    df_3d["Sequence"] = sequences

    fig_2d = px.scatter(df_2d, x="PC1", y="PC2", text="Sequence",
                        title="2D PCA of UME Embeddings", color="Sequence",
                        color_discrete_sequence=px.colors.qualitative.Bold)

    fig_3d = px.scatter_3d(df_3d, x="X", y="Y", z="Z", text="Sequence",
                           title="3D PCA of UME Embeddings", color="Sequence",
                           color_discrete_sequence=px.colors.qualitative.Vivid)

    return fig_2d, fig_3d


def _pca_project(embeddings, n_components):
    """Project onto at most ``n_components`` principal components,
    zero-padding trailing columns when the sample/feature count supports
    fewer components than requested."""
    k = min(n_components, embeddings.shape[0], embeddings.shape[1])
    coords = PCA(n_components=k).fit_transform(embeddings)
    if k < n_components:
        padding = np.zeros((coords.shape[0], n_components - k))
        coords = np.hstack([coords, padding])
    return coords


def similarity_matrix(sequences):
    """Render an annotated heatmap of pairwise cosine similarities.

    Args:
        sequences: List of raw sequence strings (used as row/column labels).

    Returns:
        A plotly figure showing the cosine-similarity matrix.
    """
    vectors = compute_embeddings(sequences)
    scores = pd.DataFrame(
        cosine_similarity(vectors),
        index=sequences,
        columns=sequences,
    )
    return px.imshow(
        scores,
        text_auto=True,
        color_continuous_scale="Viridis",
        title="Cosine Similarity Matrix",
    )


# Markdown rendered at the top of the Gradio page (see gr.Markdown below).
description = """
# 🧬 UME Explorer: Biosequence Embedding Playground
Welcome to **UME Explorer**, an interactive space to explore representations of molecules using the UME model.

Paste in your DNA, amino acid, or SMILES sequences and:
- ✨ Visualize embeddings in 2D and 3D
- πŸ”¬ Explore pairwise similarities
- 🎨 Enjoy colorful, educational plots!

> **Tip**: Keep input sequences short and between 3–20 items for better visuals on CPU.
"""

# Gradio UI: a textbox for newline-separated sequences, a button that
# triggers embedding + plotting, and three plot panes for the results.
# Component construction order inside gr.Blocks determines page layout.
with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {display: none}") as demo:
    gr.Markdown(description)

    gr.Markdown("""
    ℹ️ <b>How sequence type is detected:</b><br>
    - 🧬 <b>Nucleotide (DNA/RNA):</b> Only uses A, T, G, C, U, N<br>
    - πŸ”Ή <b>Protein (Amino Acid):</b> Includes letters like M, K, V, L, etc.<br>
    - πŸ§ͺ <b>SMILES (Chemical):</b> Includes characters like =, (, ), C, O, etc.<br>
    <small>πŸ‘‰ Detection is automatic. You can mix sequence types in one run!</small>
    """)

    with gr.Row():
        seq_input = gr.Textbox(lines=8, placeholder="Enter sequences, one per line...", label="Input Sequences")
        submit_btn = gr.Button("Compute Embeddings & Visualize")

    with gr.Row():
        out2d = gr.Plot(label="2D Plot")
        out3d = gr.Plot(label="3D Plot")

    sim_out = gr.Plot(label="Similarity Heatmap")

    def process_input(text):
        """Split the textbox contents into non-empty trimmed lines and
        return (2D PCA figure, 3D PCA figure, similarity heatmap)."""
        seqs = [s.strip() for s in text.splitlines() if s.strip()]
        fig2d, fig3d = visualize_embeddings(seqs)
        sim_fig = similarity_matrix(seqs)
        return fig2d, fig3d, sim_fig

    submit_btn.click(fn=process_input, inputs=seq_input, outputs=[out2d, out3d, sim_out])

# Start the Gradio server (blocking call).
demo.launch()