|
import logging |
|
import pathlib |
|
import gradio as gr |
|
import pandas as pd |
|
from gt4sd.algorithms.conditional_generation.reinvent import Reinvent, ReinventGenerator |
|
from gt4sd.algorithms.registry import ApplicationsRegistry |
|
|
|
from utils import draw_grid_generate |
|
|
|
# Per-module logger. Attaching a NullHandler means this logger always has a
# handler, so Python's last-resort stderr handler is not used when the host
# application has not configured logging.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
|
|
|
|
|
def run_inference(
    algorithm_version: str,
    smiles: str,
    length: float,
    sample_uniquely: bool,
    number_of_samples: int,
):
    """Sample molecules with a REINVENT generator and render them as a grid.

    Args:
        algorithm_version: Version of the Reinvent algorithm to load.
        smiles: Primer SMILES. It is passed to the model as ``target`` and is
            also shown as the seed in the rendered grid.
        length: Maximal sequence length. Gradio sliders may deliver this as a
            float, so it is cast to ``int`` before it reaches the generator.
        sample_uniquely: Forwarded as ``ReinventGenerator.sample_uniquely``.
        number_of_samples: Number of samples to draw. It is cast to ``int``
            for the same reason as ``length``.

    Returns:
        The output of ``draw_grid_generate`` for the drawn samples (laid out
        in 5 columns). The UI renders it through ``gr.HTML``, so it is
        presumably an HTML string; confirm against ``utils``.
    """
    # Slider values can arrive as floats (e.g. 100.0). Both settings are
    # integer counts, so normalise them once here.
    max_sequence_length = int(length)
    n_samples = int(number_of_samples)

    config = ReinventGenerator(
        algorithm_version=algorithm_version,
        max_sequence_length=max_sequence_length,
        # Always enabled in this demo; not user-configurable.
        randomize=True,
        sample_uniquely=sample_uniquely,
    )
    model = Reinvent(config, target=smiles)
    # model.sample is consumed eagerly so the grid gets a concrete list.
    samples = list(model.sample(n_samples))

    return draw_grid_generate(samples=samples, n_cols=5, seeds=[smiles])
|
|
|
|
|
if __name__ == "__main__":

    # Every registered version of any algorithm whose name contains "Reinvent";
    # these populate the version dropdown.
    algos = [
        entry["algorithm_version"]
        for entry in ApplicationsRegistry.list_available()
        if "Reinvent" in entry["algorithm_name"]
    ]

    # Prefer "v0" as the initial selection. If it is not registered, fall back
    # to the first available version so the dropdown does not default to a
    # value that is not among its choices.
    default_version = "v0"
    if algos and default_version not in algos:
        default_version = algos[0]

    # Model-card assets live next to this script. Resolve the path so the
    # lookup does not depend on how the script was invoked.
    metadata_root = pathlib.Path(__file__).resolve().parent.joinpath("model_cards")

    # The examples CSV has no header row. Missing cells become "" so Gradio
    # receives plain values instead of NaN.
    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
        ""
    )

    # Read the markdown explicitly as UTF-8 rather than the platform default
    # encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="REINVENT",
        inputs=[
            gr.Dropdown(
                algos,
                label="Algorithm version",
                value=default_version,
            ),
            gr.Textbox(
                label="Primer SMILES",
                placeholder="FP(F)F.CP(C)c1ccccc1.[Au]",
                lines=1,
            ),
            gr.Slider(
                minimum=5,
                maximum=400,
                value=100,
                label="Maximal sequence length",
                step=1,
            ),
            gr.Radio(choices=[True, False], label="Sampling uniquely", value=True),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)
|
|