File size: 3,631 Bytes
4f558d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5292d0
4f558d2
 
 
 
 
 
 
a5292d0
 
 
 
 
c690ff0
a5292d0
4f558d2
a5292d0
 
 
 
 
 
4f558d2
a5292d0
 
 
 
4f558d2
 
 
 
 
 
 
 
 
 
 
a5292d0
4f558d2
 
 
 
 
a5292d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, EsmModel
from sklearn.decomposition import PCA
from Bio.PDB import PDBParser, PDBIO
import py3Dmol
import tempfile
import os

# Load the ESM-1b model and tokenizer once at import time; every request reuses them.
_ESM_CHECKPOINT = "facebook/esm1b_t33_650M_UR50S"
# Only ``last_hidden_state`` is consumed downstream, so per-layer hidden states
# are not requested (they would otherwise be returned on every forward pass).
model = EsmModel.from_pretrained(_ESM_CHECKPOINT)
model.eval()  # inference only: keep dropout disabled
tokenizer = AutoTokenizer.from_pretrained(_ESM_CHECKPOINT)

# Compute PCA and return scaled values for selected components
def compute_scaled_pca_scores(seq, components):
    """Project per-residue ESM-1b embeddings onto selected principal components.

    The sequence is embedded with the module-level ``model``/``tokenizer``.
    PCA is fitted on the per-residue embeddings of this single sequence, and
    each requested component is min-max scaled to the range [0, 100].

    Args:
        seq: Protein sequence in one-letter code. Whitespace (e.g. line breaks
            from a pasted sequence) is ignored.
        components: Zero-based PCA component indices to return.

    Returns:
        A list of 1-D numpy arrays, one per entry of ``components`` and in the
        same order, each holding one scaled score per residue. An empty list
        if ``components`` is empty.

    Raises:
        ValueError: if the tokenizer does not yield exactly one token per
            residue, or if a requested component cannot be computed for a
            sequence of this length.
    """
    if not components:
        return []

    # Whitespace produces no residue tokens, so it must not count towards the
    # residue length used to slice the embedding below.
    seq = "".join(seq.split())
    inputs = tokenizer(seq, return_tensors="pt")
    with torch.no_grad():
        embedding = model(**inputs).last_hidden_state[0]

    n_residues = len(seq)
    # Expected layout: one CLS token, one token per residue, one EOS token.
    if embedding.shape[0] != n_residues + 2:
        raise ValueError(
            f"Tokenizer produced {embedding.shape[0] - 2} residue tokens for a "
            f"sequence of length {n_residues}; check that the sequence contains "
            f"only one-letter residue codes."
        )
    embedding = embedding[1:n_residues + 1]  # remove CLS and EOS

    n_components = max(components) + 1
    # PCA can extract at most min(n_samples, n_features) components.
    available = min(embedding.shape[0], embedding.shape[1])
    if n_components > available:
        raise ValueError(
            f"PCA component {max(components)} was requested, but only {available} "
            f"component(s) can be computed for a sequence of length {n_residues}."
        )

    pca = PCA(n_components=n_components)
    pca_result = pca.fit_transform(embedding.detach().cpu().numpy())

    scaled_components = []
    for c in components:
        selected = pca_result[:, c]
        low = selected.min()
        span = selected.max() - low
        if span == 0:
            # Constant projection (e.g. a single residue): avoid 0/0 -> NaN scores.
            scaled = np.zeros_like(selected)
        else:
            scaled = (selected - low) / span * 100
        scaled_components.append(scaled)

    return scaled_components

# Inject scores into B-factor column and save each PDB separately
def inject_bfactors_and_save(pdb_file, scores_list, component_indices):
    """Write one PDB file per component with its scores in the B-factor column.

    Residues are matched to scores by position. They are walked in file order
    across all models and chains, and the i-th residue receives ``scores[i]``
    on every one of its atoms. Residues beyond the end of ``scores`` receive a
    B-factor of 0.0, so each output file contains only that component's values.

    Args:
        pdb_file: The uploaded PDB, either a path string or a file-like object
            exposing the path as ``.name``.
        scores_list: One per-residue score sequence per component.
        component_indices: Component indices parallel to ``scores_list``; used
            only to name the output files (``*_PC{idx}.pdb``).

    Returns:
        Paths of the written PDB files, in the order of ``scores_list``.
    """
    # Depending on the Gradio version/config, gr.File may deliver a tempfile-like
    # object (path in ``.name``) or a plain path string; accept either.
    pdb_path = getattr(pdb_file, "name", pdb_file)
    structure = PDBParser(QUIET=True).get_structure("prot", pdb_path)
    output_paths = []

    for scores, idx in zip(scores_list, component_indices):
        # `pdb_model` rather than `model`, to avoid confusion with the global ESM model.
        residues = (
            residue
            for pdb_model in structure
            for chain in pdb_model
            for residue in chain
        )
        for i, residue in enumerate(residues):
            # The same Structure object is reused for every component, so every
            # residue is overwritten; otherwise unscored residues would keep the
            # previous component's (or the input file's) B-factors.
            bfactor = float(scores[i]) if i < len(scores) else 0.0
            for atom in residue:
                atom.bfactor = bfactor

        fd, out_path = tempfile.mkstemp(suffix=f"_PC{idx}.pdb")
        os.close(fd)  # only a unique path is needed; PDBIO opens the file itself
        io = PDBIO()
        io.set_structure(structure)
        io.save(out_path)
        output_paths.append(out_path)

    return output_paths

# Render structure with py3Dmol and inject script tag manually
def render_structure(pdb_path):
    """Return an HTML snippet with an interactive 3Dmol.js view of ``pdb_path``.

    The structure is drawn as a cartoon coloured by the B-factor column over
    the range 0-100 (the range the PCA scores are scaled to).

    Args:
        pdb_path: Path to a PDB file on disk.

    Returns:
        An ``<iframe>`` element whose ``srcdoc`` holds the viewer page.
    """
    from html import escape

    with open(pdb_path, "r") as f:
        pdb_data = f.read()
    view = py3Dmol.view(width=600, height=400)
    view.addModel(pdb_data, "pdb")
    # 3Dmol has no "bfactor" colour name; colouring by a per-atom property uses a
    # colorscheme with `prop` ('b' is the temperature-factor field) and a gradient.
    view.setStyle(
        {"cartoon": {"colorscheme": {"prop": "b", "gradient": "roygb", "min": 0, "max": 100}}}
    )
    view.zoomTo()

    # NOTE: _make_html is a private py3Dmol helper; re-check on py3Dmol upgrades.
    page = '<script src="https://3Dmol.org/build/3Dmol.js"></script>' + view._make_html()
    # <script> elements inserted into an existing page as markup (as gr.HTML
    # renders its value) are not executed. An iframe's srcdoc is parsed as its
    # own document, so the viewer's scripts do run there.
    return (
        f'<iframe srcdoc="{escape(page, quote=True)}" '
        'style="width: 620px; height: 420px; border: none;"></iframe>'
    )

# Gradio interface logic
def process(seq, pdb_file, component_string):
    """Gradio callback: project ESM-1b PCA components onto an uploaded structure.

    Args:
        seq: Protein sequence in one-letter code, in the same residue order as
            the uploaded PDB file.
        pdb_file: The uploaded PDB file as delivered by ``gr.File``, or None if
            nothing was uploaded.
        component_string: Comma-separated zero-based PCA component indices,
            e.g. ``"0,1,2"``.

    Returns:
        A ``(pdb_paths, html)`` pair. It holds the list of written PDB paths
        (one per component) and the viewer HTML for the first requested
        component. On invalid input, it returns an empty list and an HTML
        error message instead.
    """
    from html import escape

    def _error(message):
        # Empty file list plus a red HTML message, matching the two Gradio outputs.
        return [], f"<p style='color:red'>Error: {escape(message)}</p>"

    # Reject malformed tokens instead of silently dropping them (empty tokens,
    # e.g. from a trailing comma, are ignored).
    tokens = [tok.strip() for tok in (component_string or "").split(",")]
    try:
        components = [int(tok) for tok in tokens if tok]
    except ValueError:
        return _error("Invalid component list. Use comma-separated integers.")
    if not components:
        return _error("Invalid component list. Use comma-separated integers.")
    if any(c < 0 for c in components):
        return _error("PCA component indices must be non-negative integers.")
    if not seq or not seq.strip():
        return _error("Please enter a protein sequence.")
    if pdb_file is None:
        return _error("Please upload a PDB file.")

    try:
        scores_list = compute_scaled_pca_scores(seq, components)
        pdb_paths = inject_bfactors_and_save(pdb_file, scores_list, components)
    except ValueError as err:
        # Input problems (e.g. a sequence too short for the requested number
        # of PCA components) are shown in the UI rather than as a traceback.
        return _error(str(err))

    html_view = render_structure(pdb_paths[0]) if pdb_paths else ""
    return pdb_paths, html_view

# Gradio UI
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="Input Protein Sequence (1-letter code)"),
        gr.File(label="Upload PDB File", file_types=[".pdb"]),
        gr.Textbox(label="Comma-separated PCA Components (e.g. 0,1,2)")
    ],
    outputs=[
        gr.File(label="Download PDBs with PCA Projections", file_types=[".pdb"], file_count="multiple"),
        gr.HTML(label="Interactive Structure Viewer (first requested component only)")
    ],
    title="ESM-1b PCA Component Projection with Interactive 3D Structure"
)

# Launch only when run as a script, so importing this module (tests, tooling,
# Gradio reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()