import os
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
from rdkit import Chem
from tqdm import tqdm

from protac_splitter.evaluation import check_reassembly


def generate_protacs(
    poi_fg_distr: Dict[str, float],
    e3_fg_distr: Dict[str, float],
    substr_fg_2_linker: Dict[str, List[str]],
    poi_fg_2_substr: Dict[str, List[str]],
    e3_fg_2_substr: Dict[str, List[str]],
    num_samples: int,
    random_state: int = 42,
    batch_size: int = 1000,
    max_workers: int = 4,
    original_df: Optional[pd.DataFrame] = None,
    filename_generated_df: Optional[str] = None,
    base_data_dir: Optional[str] = None,
    cover_all_smiles: bool = False,
) -> pd.DataFrame:
""" Generate PROTACs given the distributions of functional groups at attachment points. | |
Args: | |
poi_fg_distr: The distribution of functional groups at the POI attachment point. | |
e3_fg_distr: The distribution of functional groups at the E3 attachment point. | |
substr_fg_2_linker: The mapping of functional groups to linkers. | |
poi_fg_2_substr: The mapping of functional groups to POI substrates. | |
e3_fg_2_substr: The mapping of functional groups to E3 substrates. | |
num_samples: The number of PROTACs to generate. | |
random_state: The random state for reproducibility. | |
batch_size: The batch size for generating PROTACs. | |
max_workers: The maximum number of workers for the ThreadPoolExecutor. | |
original_df: The original DataFrame containing the PROTACs. Must have a | |
column named 'PROTAC SMILES' containing the strings to | |
avoid generating. The check is done on strings, so make | |
sure to canonize/standardize the SMILES strings. | |
filename_generated_df: The filename to save the generated PROTACs. | |
Returns: | |
pd.DataFrame: The DataFrame containing the generated PROTACs. | |
""" | |
    np.random.seed(random_state)
    final_df = pd.DataFrame()
    total_batches = int(np.ceil(num_samples / batch_size))

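    # NOTE: Each batch is generated with its own seed offset (random_state + batch
    # index) to vary the sampling across batches. Since np.random.seed sets NumPy's
    # global random state, which the worker threads share, per-batch sampling is not
    # strictly deterministic under multithreading.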
    def generate_protac_batch(batch_size: int, random_state: int) -> List[dict]:
        np.random.seed(random_state)
        # Sample functional groups for POI and E3
        poi_fgs = np.random.choice(list(poi_fg_distr.keys()), size=batch_size, p=list(poi_fg_distr.values()))
        e3_fgs = np.random.choice(list(e3_fg_distr.keys()), size=batch_size, p=list(e3_fg_distr.values()))
        # Map functional groups to corresponding substrates
        # NOTE: When the size argument is specified, the output is a numpy array.
        # NOTE: If the functional group is not in the dictionary, the output is an empty numpy array.
        poi_samples = [
            np.random.choice(poi_fg_2_substr.get(fg, []), size=1 if poi_fg_2_substr.get(fg) else 0)
            for fg in poi_fgs
        ]
        e3_samples = [
            np.random.choice(e3_fg_2_substr.get(fg, []), size=1 if e3_fg_2_substr.get(fg) else 0)
            for fg in e3_fgs
        ]
        generated_protacs = []
        for poi_smiles, poi_fg, e3_smiles, e3_fg in zip(poi_samples, poi_fgs, e3_samples, e3_fgs):
            # Skip if poi_smiles or e3_smiles is an empty numpy array
            if poi_smiles.size == 0 or e3_smiles.size == 0:
                continue
            # Convert the single-element numpy arrays to strings
            poi_smiles, e3_smiles = poi_smiles[0], e3_smiles[0]
            # Keep only linkers compatible with both functional groups
            linkers = set(substr_fg_2_linker.get(poi_fg, [])) & set(substr_fg_2_linker.get(e3_fg, []))
            if not linkers:
                continue
            linker_smiles = np.random.choice(list(linkers))
            # Get the PROTAC by combining the POI, linker, and E3:
            # molzip fuses the fragments at dummy atoms carrying matching labels.
            ligands_smiles = '.'.join([poi_smiles, linker_smiles, e3_smiles])
            protac = Chem.MolFromSmiles(ligands_smiles)
            if protac is None:
                continue
            try:
                protac = Chem.molzip(protac)
            except Exception:
                continue
            # Sanitize molecule
            try:
                zero_on_success = Chem.SanitizeMol(protac, catchErrors=True)
                if zero_on_success != 0:
                    continue
                protac_smiles = Chem.MolToSmiles(protac, canonical=True)
            except Exception:
                continue
            # Skip PROTACs already present in the original dataset
            if original_df is not None and protac_smiles in original_df['PROTAC SMILES'].values:
                continue
            # Check if the PROTAC can be reassembled from its ligands
            if not check_reassembly(protac_smiles, ligands_smiles):
                continue
            generated_protacs.append({
                'PROTAC SMILES': protac_smiles,
                'POI Ligand SMILES with direction': poi_smiles,
                'Linker SMILES with direction': linker_smiles,
                'E3 Binder SMILES with direction': e3_smiles,
                'POI Ligand Functional Group': poi_fg,
                'E3 Binder Functional Group': e3_fg,
            })
        return generated_protacs
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for i in tqdm(range(total_batches), desc="Generating Batches"):
            futures.append(executor.submit(generate_protac_batch, batch_size, random_state + i))
        for i, future in tqdm(enumerate(futures), desc="Processing Results", total=total_batches):
            generated_batch = future.result()
            if generated_batch:
                batch_df = pd.DataFrame(generated_batch)
                final_df = pd.concat([final_df, batch_df]).drop_duplicates()
                if i % 100 == 0:
                    if base_data_dir:
                        batch_df.to_csv(os.path.join(base_data_dir, f'generated_protacs_batch={i}.csv'), index=False)
                    else:
                        batch_df.to_csv(f'generated_protacs_batch={i}.csv', index=False)
                    if filename_generated_df:
                        final_df.to_csv(filename_generated_df, index=False)
            if len(final_df) >= num_samples:
                break

    if not final_df.empty:
        generated_pois = set(final_df['POI Ligand SMILES with direction'].unique())
        generated_e3s = set(final_df['E3 Binder SMILES with direction'].unique())
        generated_linkers = set(final_df['Linker SMILES with direction'].unique())
    else:
        generated_pois = set()
        generated_e3s = set()
        generated_linkers = set()

    # Check how well the available substructures are covered
    avail_pois = set()
    avail_e3s = set()
    avail_linkers = set()
    for fg in poi_fg_2_substr:
        avail_pois.update(set(poi_fg_2_substr[fg]))
    for fg in e3_fg_2_substr:
        avail_e3s.update(set(e3_fg_2_substr[fg]))
    for fg in substr_fg_2_linker:
        avail_linkers.update(set(substr_fg_2_linker[fg]))

    e3_coverage = len(generated_e3s) / len(avail_e3s)
    poi_coverage = len(generated_pois) / len(avail_pois)
    linker_coverage = len(generated_linkers) / len(avail_linkers)
    print(f"POI coverage: {poi_coverage:.3%}")
    print(f"E3 coverage: {e3_coverage:.3%}")
    print(f"Linker coverage: {linker_coverage:.3%}")
# Get the "leftover" ligands | |
leftover_pois = avail_pois - generated_pois | |
leftover_e3s = avail_e3s - generated_e3s | |
leftover_linkers = avail_linkers - generated_linkers | |
covering_df = [] | |
with tqdm(total=len(leftover_pois) + len(leftover_e3s) + len(leftover_linkers), desc="Covering Leftover Ligands") as pbar: | |
while True: | |
if not cover_all_smiles: | |
break | |
            # Randomly select a POI, E3, and linker, preferring leftover ones
            pois_to_sample = leftover_pois if leftover_pois else avail_pois
            e3s_to_sample = leftover_e3s if leftover_e3s else avail_e3s
            linkers_to_sample = leftover_linkers if leftover_linkers else avail_linkers
            poi_smiles = np.random.choice(list(pois_to_sample))
            e3_smiles = np.random.choice(list(e3s_to_sample))
            linker_smiles = np.random.choice(list(linkers_to_sample))
            # Get the PROTAC by combining the POI, linker, and E3
            ligands_smiles = '.'.join([poi_smiles, linker_smiles, e3_smiles])
            protac = Chem.MolFromSmiles(ligands_smiles)
            if protac is None:
                continue
            try:
                protac = Chem.molzip(protac)
            except Exception:
                continue
            # Sanitize molecule
            try:
                zero_on_success = Chem.SanitizeMol(protac, catchErrors=True)
                if zero_on_success != 0:
                    continue
                protac_smiles = Chem.MolToSmiles(protac, canonical=True)
            except Exception:
                continue
            # Skip PROTACs already present in the original dataset
            if original_df is not None and protac_smiles in original_df['PROTAC SMILES'].values:
                continue
            # Check if the PROTAC can be reassembled from its ligands
            if not check_reassembly(protac_smiles, ligands_smiles):
                continue
            covering_df.append({
                'PROTAC SMILES': protac_smiles,
                'POI Ligand SMILES with direction': poi_smiles,
                'Linker SMILES with direction': linker_smiles,
                'E3 Binder SMILES with direction': e3_smiles,
                'POI Ligand Functional Group': None,
                'E3 Binder Functional Group': None,
            })
            generated_pois.add(poi_smiles)
            generated_e3s.add(e3_smiles)
            generated_linkers.add(linker_smiles)
            ligands_added = 0
            if poi_smiles in leftover_pois:
                leftover_pois.remove(poi_smiles)
                ligands_added += 1
            if e3_smiles in leftover_e3s:
                leftover_e3s.remove(e3_smiles)
                ligands_added += 1
            if linker_smiles in leftover_linkers:
                leftover_linkers.remove(linker_smiles)
                ligands_added += 1
            e3_coverage = len(generated_e3s) / len(avail_e3s)
            poi_coverage = len(generated_pois) / len(avail_pois)
            linker_coverage = len(generated_linkers) / len(avail_linkers)
            # Update the pbar and write the coverage
            pbar.update(ligands_added)
            pbar.set_postfix({
                'POI': f"{poi_coverage:.2%}",
                'E3': f"{e3_coverage:.2%}",
                'Linker': f"{linker_coverage:.2%}",
            })
            if not leftover_pois and not leftover_e3s and not leftover_linkers:
                break
    final_df = pd.concat([final_df, pd.DataFrame(covering_df)]).drop_duplicates()

    # Save to file if specified
    if filename_generated_df:
        final_df.to_csv(filename_generated_df, index=False)
        print(f"Generated PROTACs saved to: {filename_generated_df}")
    return final_df
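

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). The fragment SMILES and functional
    # group names below are hypothetical placeholders; real inputs must carry dummy
    # atoms with matching atom-map numbers (e.g. [*:1] joining POI to linker and
    # [*:2] joining linker to E3) so that Chem.molzip can fuse them, and whether a
    # given toy PROTAC survives the check_reassembly filter depends on that
    # function's implementation.
    toy_poi_fg_2_substr = {"amide": ["O=C(N[*:1])c1ccccc1"]}
    toy_e3_fg_2_substr = {"amine": ["CN([*:2])c1ccncc1"]}
    toy_substr_fg_2_linker = {
        "amide": ["[*:1]CCOCC[*:2]"],
        "amine": ["[*:1]CCOCC[*:2]"],
    }
    df = generate_protacs(
        poi_fg_distr={"amide": 1.0},
        e3_fg_distr={"amine": 1.0},
        substr_fg_2_linker=toy_substr_fg_2_linker,
        poi_fg_2_substr=toy_poi_fg_2_substr,
        e3_fg_2_substr=toy_e3_fg_2_substr,
        num_samples=10,
        batch_size=5,
        max_workers=1,
    )
    print(df.head())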