# bcadkins01's picture
# Update app.py
# 4a19ce8 verified
import streamlit as st
import torch
import os
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer
from admet_ai import ADMETModel
import safe
import io
from PIL import Image
import cairosvg
import pandas as pd
# Page setup: browser-tab title plus a full-width layout.
# Kept ahead of every other Streamlit call in this script.
st.set_page_config(layout='wide', page_title='Beta-Lactam Molecule Generator')
# Load Models
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """
    Build the BART molecule generator, its tokenizer and the ADMET-AI predictor.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the returned objects
    across reruns; because of ``ttl=600`` a cached entry expires after 600
    seconds and the next call reloads everything.

    If the HUGGING_FACE_TOKEN environment variable is unset, an error is shown
    and the script run is halted with ``st.stop()``.

    Returns:
        tuple: (generator model, tokenizer, ADMET-AI model)
    """
    hf_token = os.getenv("HUGGING_FACE_TOKEN")
    if hf_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()

    repo_id = "bcadkins01/beta_lactam_generator"
    generator = BartForConditionalGeneration.from_pretrained(repo_id, token=hf_token)
    bart_tokenizer = BartTokenizer.from_pretrained(repo_id, token=hf_token)

    predictor = ADMETModel()
    return generator, bart_tokenizer, predictor
# Load models once (cached by load_models) and reuse them across reruns.
model, tokenizer, admet_model = load_models()

# Upper bound for how many molecules one click may generate. The help text
# below is derived from it so the limit and its description cannot drift.
MAX_MOLECULES = 3

# Set Generation Parameters in Sidebar
st.sidebar.header('Generation Parameters')

# Creativity slider (sampling temperature).
# NOTE: the slider can return 0.0, which is not a valid sampling temperature;
# any code feeding this into sampled generation must clamp it above zero.
creativity = st.sidebar.slider(
    'Creativity (Temperature):',
    min_value=0.0,
    max_value=2.4,
    value=1.0,
    step=0.2,
    help="Higher values lead to more diverse (or wild) outputs."
)

# Number of molecules to generate per click (1..MAX_MOLECULES, default max).
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:',
    min_value=1,
    max_value=MAX_MOLECULES,
    value=MAX_MOLECULES,
    help=f"Select the number of molecules you want to generate (up to {MAX_MOLECULES})."
)
# Function to Generate Molecule Images
def generate_molecule_image(input_string, use_safe=False):
    """
    Render a molecule as a PIL image.

    Args:
        input_string: A SMILES string, or a SAFE string when ``use_safe`` is True.
        use_safe: If True, draw with ``safe.to_image`` and rasterise the SVG to
            PNG with cairosvg; otherwise parse ``input_string`` as SMILES with
            RDKit and draw a 250x250 image.

    Returns:
        A PIL image, or None when the input is empty, cannot be parsed, or
        rendering fails (rendering errors are also reported via ``st.error``).
    """
    # Nothing to draw. Previously a None SAFE string fell through to the SMILES
    # branch and surfaced as a confusing RDKit argument error, and "" produced
    # an empty drawing.
    if not input_string:
        return None

    try:
        if use_safe:
            # NOTE(review): assumes safe.to_image returns an SVG *string*
            # (hence .encode); confirm against the installed safe version.
            svg_str = safe.to_image(input_string)
            png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
            return Image.open(io.BytesIO(png_bytes))

        mol = Chem.MolFromSmiles(input_string)
        # MolFromSmiles returns None on parse failure; a zero-atom Mol has
        # nothing meaningful to draw either.
        if mol is None or mol.GetNumAtoms() == 0:
            return None
        return Draw.MolToImage(mol, size=(250, 250))
    except Exception as e:
        # UI boundary: report the failure but keep rendering the rest of the page.
        st.error(f"Error generating molecule image: {e}")
        return None
# ---------------------------------------------------------------------------
# Generate Molecules Button
# ---------------------------------------------------------------------------

# ADMET-AI output columns shown for each generated molecule.
ADMET_PROPERTIES = [
    'molecular_weight', 'logP', 'hydrogen_bond_acceptors',
    'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins',
]

# Seed SMILES fed to the generator.
# NOTE(review): "C1C(=O)N(C)C(=O)C1" is a five-membered imide ring
# (N-methylsuccinimide), not the four-membered beta-lactam ring
# (azetidin-2-one, e.g. "O=C1CCN1"). Left unchanged because the model may have
# been trained with this exact prompt -- confirm before changing it.
CORE_SMILES = "C1C(=O)N(C)C(=O)C1"

# transformers rejects a temperature of 0 when sampling, but the sidebar
# slider starts at 0.0; clamp to this floor instead of crashing.
MIN_TEMPERATURE = 0.1


def _is_valid_smiles(smiles):
    """Return True if RDKit parses *smiles* into a molecule with at least one atom.

    ``Chem.MolFromSmiles("")`` yields an empty (zero-atom) molecule rather than
    None, so an empty decode must be rejected explicitly.
    """
    if not smiles:
        return False
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None and mol.GetNumAtoms() > 0


def _sample_smiles(count, temperature):
    """Sample *count* SMILES strings from the generator seeded with CORE_SMILES.

    Uses nucleus (top-p) sampling at the given temperature (clamped to
    MIN_TEMPERATURE); no beam search is involved.
    """
    inputs = tokenizer(CORE_SMILES, return_tensors='pt')
    output_ids = model.generate(
        **inputs,            # input_ids plus attention_mask
        max_length=128,
        do_sample=True,
        temperature=max(float(temperature), MIN_TEMPERATURE),
        top_k=0,             # disable top-k filtering
        top_p=0.9,           # nucleus (top-p) sampling
        num_return_sequences=count,
        num_beams=1,
    )
    # Strip surrounding whitespace so only the SMILES text reaches validation.
    return [
        tokenizer.decode(ids, skip_special_tokens=True).strip()
        for ids in output_ids
    ]


def _predict_admet(df_valid):
    """Run ADMET-AI on the valid molecules and join the predictions by SMILES.

    Returns a DataFrame indexed by 'Molecule Name' holding 'SMILES', 'Valid'
    and whichever of ADMET_PROPERTIES the predictor returned.
    """
    preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())
    if 'SMILES' not in preds.columns:
        # Assumes predictions come back in input order (positional assignment).
        preds['SMILES'] = df_valid['SMILES'].values
    # Sampling can return the same SMILES more than once; de-duplicate the
    # predictions so the merge yields exactly one row per generated molecule
    # instead of a cross product of the duplicates.
    preds = preds.drop_duplicates(subset='SMILES').reset_index(drop=True)

    df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
    df_results = df_results.set_index('Molecule Name')

    # Keep only the properties actually present rather than raising KeyError.
    available = [p for p in ADMET_PROPERTIES if p in df_results.columns]
    return df_results[['SMILES', 'Valid'] + available]


def _render_molecule(column, mol_name, row):
    """Show one molecule's image, SMILES, SAFE encoding and ADMET values in *column*."""
    smiles = row['SMILES']

    # SAFE encoding is best-effort: report a failure and keep rendering.
    try:
        safe_string = safe.encode(smiles)
    except Exception as e:
        safe_string = None
        st.error(f"Could not convert to SAFE encoding for {mol_name}: {e}")

    img = generate_molecule_image(smiles)

    with column:
        if isinstance(img, Image.Image):
            st.image(img, caption=mol_name)
        else:
            st.error(f"Could not generate image for {mol_name}")

        st.write("**SMILES:**")
        st.text(smiles)

        if safe_string:
            st.write("**SAFE Encoding:**")
            st.text(safe_string)
            safe_img = generate_molecule_image(safe_string, use_safe=True)
            if safe_img is not None:
                st.image(safe_img, caption=f"{mol_name} (SAFE Visualization)")

        st.write("**ADMET Properties:**")
        st.write(row.drop(['SMILES', 'Valid']))


if st.button('Generate Molecules'):
    # A spinner disappears once generation finishes, unlike a lingering st.info.
    with st.spinner("Generating molecules... Please wait."):
        generated_smiles = _sample_smiles(num_molecules, creativity)

    # Generic sequential names for the demo: Mol01, Mol02, ...
    df_molecules = pd.DataFrame({
        'Molecule Name': [f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)],
        'SMILES': generated_smiles,
    })

    # Invalid SMILES check
    df_molecules['Valid'] = df_molecules['SMILES'].apply(_is_valid_smiles)
    df_valid = df_molecules[df_molecules['Valid']].copy()

    n_invalid = int((~df_molecules['Valid']).sum())
    if n_invalid:
        st.warning(f"{n_invalid} generated molecules were invalid and excluded from predictions.")

    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        with st.spinner("Predicting ADMET properties..."):
            df_results_filtered = _predict_admet(df_valid)

        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            st.subheader('Generated Molecules')
            cols_per_row = min(3, len(df_results_filtered))  # Max 3 columns
            cols = st.columns(cols_per_row)
            for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
                _render_molecule(cols[idx % cols_per_row], mol_name, row)
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")