# Spaces: Runtime error
import gradio as gr | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from enformer_pytorch import Enformer | |
from einops import rearrange | |
# Initialize Enformer with the architecture hyperparameters used by the
# EleutherAI/enformer-191k checkpoint.
#
# enformer_pytorch's ``Enformer.__init__`` takes a single config object, so
# hyperparameters must go through ``Enformer.from_hparams``. The channel width
# is called ``dim`` and the number of output tracks is given per head through
# ``output_heads`` (there are no ``num_channels`` / ``num_classes`` keywords).
model = Enformer.from_hparams(
    dim=1536,
    depth=11,
    heads=8,
    output_heads=dict(human=5313),
    target_length=896,
)
# Inference only: disables dropout and uses running batch-norm statistics.
model.eval()
# NOTE: the weights above are randomly initialised. Optionally load pretrained
# weights if available locally or uploaded to HF Spaces manually:
# model.load_state_dict(torch.load("enformer-191k.pth"))  # optional for offline Spaces
# Helper function to one-hot encode DNA
def one_hot_encode(sequence, length=196608):
    """One-hot encode a DNA string into a fixed-size ``(length, 4)`` array.

    Columns are ordered A, C, G, T. Whitespace (spaces, tabs, newlines from a
    pasted multi-line sequence) is removed before encoding so it does not
    occupy sequence positions. Input is case-insensitive. Any character other
    than A/C/G/T -- including ``N`` and other ambiguity codes -- is encoded as
    an all-zero row rather than being rewritten to a real base.

    Args:
        sequence: DNA sequence text; ``None`` is treated as an empty string.
        length: Number of rows in the output. Longer input is truncated;
            shorter input is right-padded with all-zero rows.

    Returns:
        ``numpy.ndarray`` of shape ``(length, 4)`` and dtype ``float32``.
    """
    one_hot = np.zeros((length, 4), dtype=np.float32)

    # str.split() with no argument splits on any whitespace run, so joining
    # the pieces drops every whitespace character in one pass.
    cleaned = "".join((sequence or "").split()).upper()[:length]
    if not cleaned:
        return one_hot

    # Map each character to exactly one byte so positions stay aligned;
    # non-ASCII characters become "?" and therefore encode as zeros.
    codes = np.frombuffer(cleaned.encode("ascii", errors="replace"), dtype=np.uint8)

    # Byte value -> column index, -1 for anything that is not A/C/G/T.
    lookup = np.full(256, -1, dtype=np.int64)
    for column, base in enumerate("ACGT"):
        lookup[ord(base)] = column

    columns = lookup[codes]
    positions = np.flatnonzero(columns >= 0)
    one_hot[positions, columns[positions]] = 1.0
    return one_hot
# Prediction function
def predict_expression(dna_sequence):
    """Run Enformer on a DNA sequence and plot the first ten predicted tracks.

    The sequence is one-hot encoded with ``one_hot_encode``, passed through the
    module-level ``model`` without gradient tracking, and the prediction is
    averaged over the output's position axis to give one value per track.

    Args:
        dna_sequence: DNA sequence text as entered in the UI.

    Returns:
        A matplotlib ``Figure`` with a bar chart of the first ten track means.
    """
    encoded = one_hot_encode(dna_sequence)
    # enformer_pytorch takes one-hot input channels-last, (batch, length, 4),
    # and moves channels itself; transposing here would break the stem conv.
    input_tensor = torch.from_numpy(encoded).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)

    # The model returns a dict of head name -> tensor; pick the human head,
    # or the only/first head if the model was built with other head names.
    if isinstance(output, dict):
        output = output["human"] if "human" in output else next(iter(output.values()))

    # output[0] is (positions, tracks); average over positions.
    avg_expression = output[0].mean(dim=0).cpu().numpy()

    # Plot first 10 expression predictions
    shown = min(10, avg_expression.shape[0])
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(range(shown), avg_expression[:shown])
    ax.set_xticks(range(shown))
    ax.set_xticklabels([f"Tissue {i}" for i in range(shown)])
    ax.set_title("Predicted Gene Expression")
    ax.set_ylabel("Signal")
    fig.tight_layout()
    # Deregister the figure from pyplot's global registry so repeated requests
    # do not accumulate open figures; the Figure object itself stays usable.
    plt.close(fig)
    return fig
# Gradio app: declare the input and output components, wire them to the
# prediction function, then start the server.
sequence_box = gr.Textbox(lines=6, label="Paste DNA Sequence (200k bp)")
expression_plot = gr.Plot(label="Predicted Expression Tracks (first 10 tissues)")

demo = gr.Interface(
    fn=predict_expression,
    inputs=sequence_box,
    outputs=expression_plot,
    title="Gene Expression Prediction with Enformer",
    description="Paste a 200kb DNA sequence and see predicted expression levels using Enformer.",
)

demo.launch()