# moPPIt-v2 / sample.py
# Web-page header captured together with the file (not part of the script):
#   AlienChen's picture
#   Update sample.py
#   f095528 verified
import os
import hydra
import lightning as L
import numpy as np
import omegaconf
import pandas as pd
import rdkit
import rich.syntax
import rich.tree
import torch
from tqdm.auto import tqdm
import pdb
import dataloader
import diffusion
from models.bindevaluator import BindEvaluator
# Silence RDKit's 'rdApp.error' log channel for the whole process.
rdkit.rdBase.DisableLog('rdApp.error')

# Custom OmegaConf resolvers, registered in the order listed.
# NOTE(review): the 'eval' resolver passes config text to Python's built-in
# eval(); only use configs and command-line overrides from trusted sources.
_RESOLVERS = (
    ('cwd', os.getcwd),
    ('device_count', torch.cuda.device_count),
    ('eval', eval),
    # Integer division rounded up (for positive y).
    ('div_up', lambda x, y: (x + y - 1) // y),
    ('if_then_else', lambda condition, x, y: x if condition else y),
)
for _resolver_name, _resolver_fn in _RESOLVERS:
    omegaconf.OmegaConf.register_new_resolver(_resolver_name, _resolver_fn)
def _print_config(
        config: omegaconf.DictConfig,
        resolve: bool = True) -> None:
    """Render a Hydra/OmegaConf config as a Rich tree and print it.

    Each top-level key of ``config`` becomes one branch. Nested
    ``DictConfig`` sections are dumped as YAML; any other value is shown via
    ``str``. Every branch body is syntax-highlighted as YAML.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        resolve (bool): Whether reference fields are resolved when a nested
            section is converted to YAML.
    """
    dim = 'dim'
    root = rich.tree.Tree('CONFIG', style=dim, guide_style=dim)

    for name in config.keys():
        node = root.add(name, style=dim, guide_style=dim)
        section = config.get(name)

        if isinstance(section, omegaconf.DictConfig):
            rendered = omegaconf.OmegaConf.to_yaml(section, resolve=resolve)
        else:
            rendered = str(section)

        node.add(rich.syntax.Syntax(rendered, 'yaml'))

    rich.print(root)
def parse_motif(motif: str) -> torch.Tensor:
    """Expand a motif specification into a 1-D tensor of integer positions.

    The specification is a comma-separated list whose items are either a
    single integer (``'7'``) or an inclusive ``start-end`` range (``'3-5'``).
    Items are expanded in the order given; duplicates are kept and nothing is
    sorted. A range whose end is below its start contributes no positions.

    Args:
        motif: Specification such as ``'1-3, 7'``. Non-string values (for
            example a bare integer produced by a config override) are
            converted with ``str`` first.

    Returns:
        An int64 tensor of positions. It is empty when the specification
        contains no items.

    Raises:
        ValueError: If an item is neither an integer nor a ``start-end``
            pair of integers.
    """
    positions = []
    for item in str(motif).split(','):
        item = item.strip()
        if not item:
            # Tolerate an empty specification and stray/trailing commas.
            continue
        if '-' in item:
            start, end = map(int, item.split('-'))
            positions.extend(range(start, end + 1))
        else:
            positions.append(int(item))
    # Explicit dtype: an empty list would otherwise yield a float tensor.
    return torch.tensor(positions, dtype=torch.long)
@hydra.main(version_base=None, config_path='./configs',
            config_name='config')
def main(config: omegaconf.DictConfig) -> None:
    """Sample sequences from the pretrained diffusion model and print them.

    Steps: seed and enable deterministic kernels, tokenize the target
    sequence, load the diffusion model and the BindEvaluator classifier onto
    the GPU, draw ``config.sampling.num_sample_batches`` batches, decode them
    with the model's tokenizer, and print the cleaned-up strings.

    Args:
        config: Configuration composed by Hydra. Fields read here:
            ``seed``, ``eval.checkpoint_path``, ``eval.target_sequence``,
            ``eval.target_motifs``, ``guidance.classifier_checkpoint_path``
            and ``sampling.num_sample_batches``.
    """
    # Reproducibility. PyTorch's deterministic mode needs the cuBLAS
    # workspace setting below when running on CUDA.
    L.seed_everything(config.seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    # To dump the full resolved config, call: _print_config(config, resolve=True)
    print(f"Checkpoint: {config.eval.checkpoint_path}")

    tokenizer = dataloader.get_tokenizer(config)
    target_sequence = tokenizer(
        config.eval.target_sequence, return_tensors='pt')['input_ids']

    pretrained = diffusion.Diffusion.load_from_checkpoint(
        config.eval.checkpoint_path,
        tokenizer=tokenizer,
        config=config, logger=False)
    pretrained.eval()
    pretrained = pretrained.to('cuda')

    # Classifier passed to the sampler as `classifier_model`.
    bindevaluator = BindEvaluator.load_from_checkpoint(
        config.guidance.classifier_checkpoint_path,
        n_layers=8,
        d_model=128,
        d_hidden=128,
        n_head=8,
        d_k=64,
        d_v=128,
        d_inner=64)
    bindevaluator = bindevaluator.to('cuda')

    samples = []
    for _ in tqdm(
            range(config.sampling.num_sample_batches),
            desc='Gen. batches', leave=False):
        # The motif tensor is rebuilt for every batch so each call to
        # sample() receives a fresh object.
        sample = pretrained.sample(
            target_sequence=target_sequence,
            target_motifs=parse_motif(config.eval.target_motifs),
            classifier_model=bindevaluator)
        samples.extend(
            pretrained.tokenizer.batch_decode(sample))

    # Remove the spaces in the decoded text, then trim five characters from
    # each end (presumably the '<cls>' / '<eos>' markers -- confirm against
    # the tokenizer). Computed and printed once.
    samples = [text.replace(' ', '')[5:-5] for text in samples]
    print(samples)
if __name__ == '__main__':
    # Script entry point: the @hydra.main decorator on main() supplies the
    # `config` argument, so it is called without arguments here.
    main()