demo-gradio / app.py
ullahi's picture
updated
b82eba8 verified
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from enformer_pytorch import Enformer
from einops import rearrange
# Build an Enformer with the hyperparameters this demo was written for
# (1536 channels, 11 transformer layers, 8 heads, 896 output bins,
# 5313 output tracks on a single "human" head).
#
# enformer_pytorch's `Enformer.__init__` takes one config object, so keyword
# hyperparameters have to go through `Enformer.from_hparams`, which wraps them
# in an `EnformerConfig`. The config names the channel width `dim` and the
# output sizes `output_heads` (a mapping of head name -> number of tracks).
model = Enformer.from_hparams(
    dim=1536,
    depth=11,
    heads=8,
    output_heads={"human": 5313},
    target_length=896,
)
# Inference only: switch off dropout and use batch-norm running statistics.
model.eval()
# NOTE: no weights are loaded here, so the network is randomly initialised and
# its predictions carry no biological meaning until a checkpoint is loaded.
# Optionally load pretrained weights if available locally or upload to HF Spaces manually
# model.load_state_dict(torch.load("enformer-191k.pth")) # optional for offline Spaces
# Helper function to one-hot encode DNA
def one_hot_encode(sequence: str, length: int = 196608) -> np.ndarray:
    """One-hot encode a DNA string into a fixed-size ``(length, 4)`` array.

    Columns are ordered A, C, G, T. The sequence is left-aligned: bases past
    ``length`` are dropped and, for shorter input, the remaining rows stay
    all-zero.

    Args:
        sequence: DNA text. Case-insensitive. Whitespace (spaces, tabs, line
            breaks from a multi-line paste) is removed before encoding so it
            does not consume positions and shift the bases that follow.
        length: Number of rows in the returned array.

    Returns:
        A ``float32`` array of shape ``(length, 4)``. ``N`` is encoded as
        ``A``; any other unrecognised character yields an all-zero row.
    """
    one_hot = np.zeros((length, 4), dtype=np.float32)

    # NOTE(review): N (unknown base) is deliberately mapped to A here, as in
    # the original implementation; an all-zero row may be the better encoding
    # for unknown bases -- confirm before changing.
    cleaned = "".join(sequence.split()).upper().replace("N", "A")[:length]
    if not cleaned:
        return one_hot

    # One byte per character; non-ASCII characters become "?" so positions
    # are preserved and they fall through to the "unknown" case below.
    codes = np.frombuffer(
        cleaned.encode("ascii", errors="replace"), dtype=np.uint8
    )

    # Byte value -> column index, with -1 marking "not a recognised base".
    lookup = np.full(256, -1, dtype=np.int8)
    for column, base in enumerate("ACGT"):
        lookup[ord(base)] = column

    columns = lookup[codes]
    rows = np.flatnonzero(columns >= 0)
    one_hot[rows, columns[rows]] = 1.0
    return one_hot
# Prediction function
def predict_expression(dna_sequence):
    """Run the model on a pasted DNA sequence and plot the first 10 outputs.

    Args:
        dna_sequence: Raw sequence text from the UI. ``None`` is treated as
            an empty string.

    Returns:
        A matplotlib ``Figure`` holding a bar chart of the first 10 output
        tracks, each averaged over the model's output positions.
    """
    encoded = one_hot_encode(dna_sequence or "")

    # enformer_pytorch takes one-hot input as (batch, seq_len, 4) and moves the
    # channel axis itself, so the array is passed through without transposing.
    input_tensor = torch.from_numpy(encoded).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)

    # Without an explicit head the model returns a mapping of head name ->
    # tensor; accept that as well as a bare tensor.
    if isinstance(output, dict):
        output = output["human"] if "human" in output else next(iter(output.values()))

    # Drop the batch axis, then average over output positions -> one value per track.
    avg_expression = output[0].mean(dim=0).cpu().numpy()

    # Plot first 10 expression predictions
    shown = min(10, avg_expression.shape[0])
    positions = range(shown)

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(positions, avg_expression[:shown])
    ax.set_xticks(list(positions))
    ax.set_xticklabels([f"Tissue {i}" for i in positions])
    ax.set_title("Predicted Gene Expression")
    ax.set_ylabel("Signal")
    fig.tight_layout()

    # pyplot keeps every figure it creates alive until it is closed; in a
    # long-running server that leaks one figure per request. Closing only
    # unregisters it from pyplot -- the Figure object can still be rendered.
    plt.close(fig)
    return fig
# Gradio app: a multi-line text box for the sequence in, a plot out.
_sequence_input = gr.Textbox(lines=6, label="Paste DNA Sequence (200k bp)")
_expression_output = gr.Plot(label="Predicted Expression Tracks (first 10 tissues)")

demo = gr.Interface(
    title="Gene Expression Prediction with Enformer",
    description="Paste a 200kb DNA sequence and see predicted expression levels using Enformer.",
    fn=predict_expression,
    inputs=_sequence_input,
    outputs=_expression_output,
)

# Start the web server for the interface.
demo.launch()