"""Gradio app that maps PCA components of ESM-1b embeddings onto a PDB file.

For a protein sequence, the per-residue embeddings from the last hidden layer
are reduced with PCA. Each requested component is min-max scaled to 0-100 and
written into the B-factor column of an uploaded PDB structure, producing one
output PDB per component. The first output is also rendered with py3Dmol.
"""

import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import py3Dmol
import torch
from Bio.PDB import PDBIO, PDBParser
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, EsmModel

_MODEL_NAME = "facebook/esm1b_t33_650M_UR50S"

# Load ESM-1b model and tokenizer (done once, at import time)
model = EsmModel.from_pretrained(_MODEL_NAME, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)


def compute_scaled_pca_scores(seq, components):
    """Return min-max scaled PCA scores (0-100) for the selected components.

    Args:
        seq: Protein sequence in one-letter code.
        components: Non-empty iterable of zero-based PCA component indices.

    Returns:
        A list with one 1-D NumPy array per entry of ``components``; each
        array has one value per embedding row kept for the sequence.

    Raises:
        ValueError: If ``components`` is empty (PCA itself also raises
            ``ValueError`` when more components are requested than the
            embedding can support).
    """
    components = list(components)
    if not components:
        raise ValueError("At least one PCA component index is required.")

    inputs = tokenizer(seq, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Remove CLS (first row) and EOS (row after the last residue)
    seq_len = len(seq)
    embedding = outputs.last_hidden_state[0][1:seq_len + 1]

    # Fit just enough components to reach the highest requested index
    pca = PCA(n_components=max(components) + 1)
    pca_result = pca.fit_transform(embedding.detach().cpu().numpy())

    scaled_components = []
    for c in components:
        selected = pca_result[:, c]
        low = selected.min()
        span = selected.max() - low
        if span == 0:
            # A constant component would otherwise produce 0/0 -> NaN scores
            scaled = np.zeros_like(selected)
        else:
            scaled = (selected - low) / span * 100
        scaled_components.append(scaled)
    return scaled_components


def inject_bfactors_and_save(pdb_file, scores_list, component_indices):
    """Write each score vector into the B-factor column and save one PDB each.

    Args:
        pdb_file: Either a path string or an object exposing the path as
            ``.name`` (both forms are accepted).
        scores_list: One sequence of per-residue scores per component.
        component_indices: Component indices, parallel to ``scores_list``;
            used only to build the output file suffix ``_PC{idx}.pdb``.

    Returns:
        List of paths to the temporary PDB files, in input order. The files
        are not deleted by this function.
    """
    pdb_path = getattr(pdb_file, "name", pdb_file)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("prot", pdb_path)

    writer = PDBIO()
    writer.set_structure(structure)

    output_paths = []
    for scores, idx in zip(scores_list, component_indices):
        # Residues are visited in file order across all models and chains;
        # zip stops at the shorter side, so residues beyond len(scores) keep
        # whatever B-factor they already have.
        # NOTE(review): every residue object is counted, so hetero residues
        # or waters would consume scores too -- confirm inputs are clean.
        for residue, score in zip(structure.get_residues(), scores):
            value = float(score)
            for atom in residue:
                atom.bfactor = value

        # mkstemp returns an open descriptor; close it so no handle leaks
        fd, out_path = tempfile.mkstemp(suffix=f"_PC{idx}.pdb")
        os.close(fd)
        writer.save(out_path)
        output_paths.append(out_path)
    return output_paths


def render_structure(pdb_path):
    """Return HTML for a py3Dmol cartoon view coloured by B-factor.

    Args:
        pdb_path: Path to the PDB file to display.

    Returns:
        HTML string produced by the py3Dmol viewer.
    """
    with open(pdb_path, "r") as f:
        pdb_data = f.read()

    view = py3Dmol.view(width=600, height=400)
    view.addModel(pdb_data, "pdb")
    # Colour by the B-factor property; scores were scaled to the 0-100 range
    view.setStyle(
        {
            "cartoon": {
                "colorscheme": {
                    "prop": "b",
                    "gradient": "roygb",
                    "min": 0,
                    "max": 100,
                }
            }
        }
    )
    view.zoomTo()

    # NOTE(review): the prefix below is an empty string; an explicit 3Dmol.js
    # script tag may have been intended here -- confirm against the original.
    html = "" + view._make_html()
    return html


def process(seq, pdb_file, component_string):
    """Gradio callback: compute PCA scores, write PDBs, and render the first.

    Args:
        seq: Protein sequence from the textbox.
        pdb_file: Uploaded PDB file (path or file-like wrapper).
        component_string: Comma-separated component indices, e.g. ``"0,1,2"``.
            Entries that are not made of digits are ignored.

    Returns:
        ``(pdb_paths, html)``: the generated PDB paths and viewer HTML, or
        ``([], message)`` when the input cannot be processed.
    """
    invalid_components = (
        "\n\nError: Invalid component list. Use comma-separated integers.\n\n"
    )

    try:
        components = [
            int(c.strip())
            for c in (component_string or "").split(",")
            if c.strip().isdigit()
        ]
    except ValueError:
        # str.isdigit accepts some characters that int() rejects
        return [], invalid_components
    if not components:
        return [], invalid_components

    # Drop whitespace so len(seq) matches the number of residues
    seq = "".join((seq or "").split())
    if not seq:
        return [], "\n\nError: Please provide a protein sequence.\n\n"
    if pdb_file is None:
        return [], "\n\nError: Please upload a PDB file.\n\n"

    scores_list = compute_scaled_pca_scores(seq, components)
    pdb_paths = inject_bfactors_and_save(pdb_file, scores_list, components)
    html_view = render_structure(pdb_paths[0]) if pdb_paths else ""
    return pdb_paths, html_view


# Gradio UI
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="Input Protein Sequence (1-letter code)"),
        gr.File(label="Upload PDB File", file_types=[".pdb"]),
        gr.Textbox(label="Comma-separated PCA Components (e.g. 0,1,2)"),
    ],
    outputs=[
        gr.File(
            label="Download PDBs with PCA Projections",
            file_types=[".pdb"],
            file_count="multiple",
        ),
        gr.HTML(label="Interactive Structure Viewer (first PCA component only)"),
    ],
    title="ESM-1b PCA Component Projection with Interactive 3D Structure",
)

demo.launch()