# Author: dsk129 — commit a5292d0 "Update app.py" (verified)
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, EsmModel
from sklearn.decomposition import PCA
from Bio.PDB import PDBParser, PDBIO
import py3Dmol
import tempfile
import os
# Load the ESM-1b encoder and its tokenizer once at import time; both are
# shared by every request handled by this app. Only ``last_hidden_state`` is
# read from the model output, so the per-layer hidden states are not requested
# (requesting them would return and retain every layer's activations per call).
ESM_CHECKPOINT = "facebook/esm1b_t33_650M_UR50S"
model = EsmModel.from_pretrained(ESM_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(ESM_CHECKPOINT)
def compute_scaled_pca_scores(seq, components):
    """Project per-residue ESM-1b embeddings onto selected principal components.

    The sequence is embedded with the module-level ESM-1b ``model`` and
    ``tokenizer``. PCA is fitted on the per-residue embeddings of this single
    sequence, and each requested component is min-max scaled to [0, 100] so it
    can be stored in a PDB B-factor column.

    Args:
        seq: Protein sequence in one-letter code. Whitespace (e.g. line breaks
            from a pasted FASTA body) is removed and letters are upper-cased
            before tokenization.
        components: Zero-based indices of the principal components to return.

    Returns:
        A list with one 1-D numpy array per entry of ``components``, in the
        same order. An empty ``components`` yields an empty list. A component
        with no spread (all values equal) is returned as all zeros instead of
        NaN.

    Raises:
        ValueError: if the cleaned sequence is empty, or if a requested
            component index is not smaller than the number of available
            components (PCA cannot yield more components than residues).
    """
    components = list(components)
    if not components:
        return []

    # Pasted sequences often contain spaces/newlines, and the ESM vocabulary is
    # upper-case; without this cleanup len(seq) no longer matches the residues.
    seq = "".join(seq.split()).upper()
    if not seq:
        raise ValueError("Empty protein sequence.")

    inputs = tokenizer(seq, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    n_residues = len(seq)
    # Skip the leading CLS token and keep one row per residue (drops EOS).
    embedding = outputs.last_hidden_state[0, 1:n_residues + 1].cpu().numpy()

    n_needed = max(components) + 1
    n_available = min(embedding.shape)
    if n_needed > n_available:
        raise ValueError(
            f"Cannot compute PC{max(components)}: only {n_available} "
            f"components are available for a sequence of {embedding.shape[0]} residues."
        )

    pca_result = PCA(n_components=n_needed).fit_transform(embedding)

    scaled_components = []
    for c in components:
        values = pca_result[:, c]
        low = values.min()
        span = values.max() - low
        if span > 0:
            scaled = (values - low) / span * 100
        else:
            # Constant component (e.g. a single residue): avoid 0/0 -> NaN.
            scaled = np.zeros_like(values)
        scaled_components.append(scaled)
    return scaled_components
def inject_bfactors_and_save(pdb_file, scores_list, component_indices):
    """Write one PDB file per PCA component, with scores in the B-factor column.

    Scores are assigned in file order to the amino-acid residues of each model
    (restarting for every model of an ensemble). Every atom of a residue
    receives that residue's score. Atoms of non-amino-acid residues (waters,
    ligands) and of residues beyond the end of the score array are set to 0.0,
    so no original B-factors leak into the output files.

    Args:
        pdb_file: The uploaded PDB file, either a path string or an object
            whose ``.name`` attribute holds the path.
        scores_list: Per-residue score sequences, one per component.
        component_indices: Component indices parallel to ``scores_list``. They
            are used only to name the output files (``..._PC{idx}.pdb``).

    Returns:
        Paths of the written PDB files, one per component. They are created
        with ``delete=False`` and are not removed by this function.
    """
    from Bio.PDB.Polypeptide import is_aa

    pdb_path = getattr(pdb_file, "name", pdb_file)
    structure = PDBParser(QUIET=True).get_structure("prot", pdb_path)
    writer = PDBIO()
    output_paths = []

    for scores, idx in zip(scores_list, component_indices):
        for pdb_model in structure:
            # Each model of an ensemble is the same chain(s); colour them alike.
            residue_scores = iter(scores)
            for chain in pdb_model:
                for residue in chain:
                    if is_aa(residue, standard=False):
                        value = float(next(residue_scores, 0.0))
                    else:
                        # Waters/ligands are not part of the input sequence.
                        value = 0.0
                    for atom in residue:
                        atom.bfactor = value

        # Close the handle right away: PDBIO reopens the file by path, and an
        # open handle would otherwise leak (and block reopening on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=f"_PC{idx}.pdb") as tmp:
            out_path = tmp.name
        writer.set_structure(structure)
        writer.save(out_path)
        output_paths.append(out_path)

    return output_paths
def render_structure(pdb_path):
    """Return embeddable HTML that shows the PDB file at ``pdb_path`` in 3D.

    The structure is drawn as a cartoon coloured by each atom's B-factor on
    the 'roygb' gradient between 0 and 100, the range the PCA scores are
    scaled to.

    The py3Dmol page (preceded by an explicit 3Dmol.js script tag) is placed
    in an ``<iframe srcdoc=...>``. Browsers do not execute ``<script>``
    elements inserted into an existing page as an HTML snippet, but the
    scripts do run inside the iframe's own document.
    NOTE(review): this relies on gr.HTML injecting its value as markup;
    confirm for the Gradio version in use.
    """
    from html import escape

    with open(pdb_path, "r", encoding="utf-8") as handle:
        pdb_data = handle.read()

    view = py3Dmol.view(width=600, height=400)
    view.addModel(pdb_data, "pdb")
    # 3Dmol has no 'bfactor' colour name; colouring by B-factor requires a
    # colorscheme keyed on the atom property 'b'.
    view.setStyle(
        {"cartoon": {"colorscheme": {"prop": "b", "gradient": "roygb", "min": 0, "max": 100}}}
    )
    view.zoomTo()

    page = (
        '<script src="https://3Dmol.org/build/3Dmol.js"></script>'
        + view._make_html()
    )
    return (
        '<iframe style="width: 100%; height: 430px; border: none;" '
        f'srcdoc="{escape(page, quote=True)}"></iframe>'
    )
def _parse_components(component_string):
    """Parse a comma-separated list such as ``"0, 1,2"`` into ``[0, 1, 2]``.

    Blank entries (e.g. from a trailing comma) are ignored. Raises ValueError
    if any entry is not a non-negative ASCII integer or if no entries remain.
    """
    tokens = [token.strip() for token in (component_string or "").split(",")]
    tokens = [token for token in tokens if token]
    if not tokens:
        raise ValueError("no PCA components given")
    # isascii() rejects characters like '²' that pass isdigit() but break int().
    if not all(token.isascii() and token.isdigit() for token in tokens):
        raise ValueError("components must be non-negative integers")
    return [int(token) for token in tokens]


def process(seq, pdb_file, component_string):
    """Gradio callback: project PCA components of ``seq`` onto the uploaded PDB.

    Args:
        seq: Protein sequence text from the sequence textbox.
        pdb_file: The uploaded PDB file as delivered by ``gr.File``. It is
            passed through unchanged to ``inject_bfactors_and_save``.
        component_string: Comma-separated component indices, e.g. ``"0,1,2"``.

    Returns:
        ``(pdb_paths, html_view)``: the written PDB paths (one per component)
        and viewer HTML for the first one. On invalid input, it returns
        ``([], "<p style='color:red'>Error: ...</p>")`` instead of raising.
    """
    from html import escape

    def error(message):
        return [], f"<p style='color:red'>Error: {message}</p>"

    try:
        components = _parse_components(component_string)
    except ValueError:
        return error("Invalid component list. Use comma-separated integers.")
    if not seq or not seq.strip():
        return error("Please enter a protein sequence.")
    if pdb_file is None:
        return error("Please upload a PDB file.")

    try:
        scores_list = compute_scaled_pca_scores(seq, components)
    except ValueError as err:
        # e.g. a component index larger than the sequence allows.
        return error(escape(str(err)))

    pdb_paths = inject_bfactors_and_save(pdb_file, scores_list, components)
    html_view = render_structure(pdb_paths[0]) if pdb_paths else ""
    return pdb_paths, html_view
# Gradio UI: sequence + PDB upload + component list in; per-component PDB
# downloads and an interactive viewer of the first component out.
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="Input Protein Sequence (1-letter code)"),
        gr.File(label="Upload PDB File", file_types=[".pdb"]),
        gr.Textbox(label="Comma-separated PCA Components (e.g. 0,1,2)"),
    ],
    outputs=[
        gr.File(
            label="Download PDBs with PCA Projections",
            file_types=[".pdb"],
            file_count="multiple",
        ),
        gr.HTML(label="Interactive Structure Viewer (first PCA component only)"),
    ],
    title="ESM-1b PCA Component Projection with Interactive 3D Structure",
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module (tests, reload tooling) does not start a server.
if __name__ == "__main__":
    demo.launch()