File size: 6,737 Bytes
63ea7d2
 
 
 
05c6778
1c74fb1
46a04b7
1c74fb1
63ea7d2
 
1784508
 
63ea7d2
 
 
 
46a04b7
63ea7d2
 
05c6778
46a04b7
 
 
 
 
 
 
 
 
 
05c6778
 
46a04b7
 
 
05c6778
46a04b7
63ea7d2
 
e05c3b0
05c6778
63ea7d2
 
 
 
 
 
46a04b7
 
 
63ea7d2
 
05c6778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63ea7d2
 
 
e05c3b0
05c6778
63ea7d2
 
 
 
 
 
46a04b7
 
63ea7d2
 
05c6778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63ea7d2
 
 
84a2044
05c6778
63ea7d2
 
 
 
05c6778
63ea7d2
05c6778
1c74fb1
 
 
 
 
05c6778
63ea7d2
05c6778
1c74fb1
63ea7d2
 
 
1784508
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e05c3b0
1784508
05c6778
63ea7d2
 
 
 
 
 
 
 
 
 
 
46a04b7
05c6778
1784508
46a04b7
 
 
 
 
 
 
 
 
 
 
63ea7d2
 
 
e05c3b0
05c6778
63ea7d2
 
 
 
 
 
 
46a04b7
63ea7d2
 
 
 
05c6778
46a04b7
 
 
 
 
 
 
 
 
 
 
 
05c6778
 
46a04b7
 
05c6778
46a04b7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""Wrappers for BioNeMo NIMs."""

from lynxkite_graph_analytics import Bundle
from lynxkite.core import ops
import httpx
import io
import pandas as pd
import rdkit
import os

from . import k8s


# Name of the LynxKite environment that the operations in this file are registered under.
ENV = "LynxKite Graph Analytics"
# Registration decorator bound to ENV; used below as ``@op("<name>", ...)``.
op = ops.op_registration(ENV)

# Sent as the bearer token by query_bionemo_nim(). None when NVCF_RUN_KEY is not set.
key = os.getenv("NVCF_RUN_KEY")


async def query_bionemo_nim(
    url: str,
    payload: dict,
):
    """POST ``payload`` as JSON to a BioNeMo NIM endpoint and return the decoded reply.

    The module-level ``key`` is sent as the bearer token.

    Args:
        url: Full URL of the endpoint to call.
        payload: JSON-serializable request body.

    Returns:
        The response body parsed as JSON.

    Raises:
        ValueError: If the request itself failed with an ``httpx.RequestError``.
            The original exception is attached as ``__cause__``.

    Errors raised by ``response.raise_for_status()`` for an error status code are
    not wrapped; they propagate as raised by httpx.
    """
    headers = {
        "Authorization": f"Bearer {key}",
        "NVCF-POLL-SECONDS": "500",
        "Content-Type": "application/json",
    }
    # Only the request itself is guarded; status and JSON handling happen afterwards
    # so their errors are never mistaken for transport failures.
    try:
        print(f"Sending request to {url}")
        async with httpx.AsyncClient(timeout=600) as client:
            response = await client.post(url, json=payload, headers=headers)
    except httpx.RequestError as e:
        # Chain explicitly so the traceback shows the transport error as the cause.
        raise ValueError(f"Query failed: {e}") from e
    print(f"Received response from {url}", response.status_code)
    response.raise_for_status()
    return response.json()


@op("MSA-search", slow=True)
async def msa_search(
    bundle: Bundle,
    *,
    protein_table: str,
    protein_column: str,
    e_value: float = 0.0001,
    iterations: int = 1,
    search_type: str = "alphafold2",
    output_alignment_formats: str = "a3m",
    databases: str = "Uniref30_2302,colabfold_envdb_202108",
):
    """Run an MSA search for every protein sequence and store the result per row.

    One request is sent per row of ``protein_table``, sequentially. The
    ``"alignments"`` entry of each response is stored in a new ``alignments``
    column of that table.

    Args:
        bundle: Input bundle; the returned bundle is a copy of it.
        protein_table: Name of the table in ``bundle.dfs`` holding the proteins.
        protein_column: Column of ``protein_table`` with the sequences.
        e_value: Forwarded to the service as ``e_value``.
        iterations: Forwarded to the service as ``iterations``.
        search_type: Forwarded to the service as ``search_type``.
        output_alignment_formats: Comma-separated list, sent as a list of
            stripped strings.
        databases: Comma-separated list, sent as a list of stripped strings.

    Returns:
        The copied bundle whose ``protein_table`` has the extra ``alignments`` column.
    """
    bundle = bundle.copy()
    # NOTE(review): bundle.copy() may share DataFrames with the input bundle (its
    # implementation is not visible here), so work on a private copy of the table
    # instead of adding a column to the caller's DataFrame.
    table = bundle.dfs[protein_table].copy()
    bundle.dfs[protein_table] = table

    formats = [fmt.strip() for fmt in output_alignment_formats.split(",")]
    dbs = [db.strip() for db in databases.split(",")]

    total = len(table)
    alignments = []
    for idx, protein_sequence in enumerate(table[protein_column]):
        print(f"Processing protein {idx + 1}/{total}")

        response = await query_bionemo_nim(
            url="https://health.api.nvidia.com/v1/biology/colabfold/msa-search/predict",
            payload={
                "sequence": protein_sequence,
                "e_value": e_value,
                "iterations": iterations,
                "search_type": search_type,
                "output_alignment_formats": formats,
                "databases": dbs,
            },
        )
        alignments.append(response["alignments"])

    # Results are in row order. Attaching them with the table's own index keeps each
    # result on its row even when the index is not 0..n-1 (a label-based
    # ``.at[idx, ...]`` with a positional counter would hit the wrong rows).
    table["alignments"] = pd.Series(alignments, index=table.index, dtype=object)

    return bundle


@op("Query OpenFold2", slow=True)
async def query_openfold2(
    bundle: Bundle,
    *,
    protein_table: str,
    protein_column: str,
    alignment_table: str,
    alignment_column: str,
    selected_models: str = "1,2",
    relax_prediction: bool = False,
):
    """Fold every protein with OpenFold2 and store the top-ranked structure per row.

    Row ``i`` of ``protein_table`` is paired with row ``i`` (by position) of
    ``alignment_table``. One request is sent per protein, sequentially. The
    ``"structure"`` of the first entry in ``"structures_in_ranked_order"`` is
    stored in a new ``folded_protein`` column of ``protein_table``.

    Args:
        bundle: Input bundle; the returned bundle is a copy of it.
        protein_table: Name of the table holding the protein sequences.
        protein_column: Column of ``protein_table`` with the sequences.
        alignment_table: Name of the table holding the alignments.
        alignment_column: Column of ``alignment_table`` with the alignments.
        selected_models: Comma-separated integers, sent as a list of ints.
        relax_prediction: Forwarded to the service as ``relax_prediction``.

    Returns:
        The copied bundle with the ``folded_protein`` column added and an empty
        ``openfold`` table.

    Raises:
        IndexError: If ``alignment_table`` has fewer rows than ``protein_table``,
            or a response contains no structures.
    """
    bundle = bundle.copy()
    # NOTE(review): bundle.copy() may share DataFrames with the input bundle (its
    # implementation is not visible here), so add the column to a private copy.
    proteins_df = bundle.dfs[protein_table].copy()
    bundle.dfs[protein_table] = proteins_df

    selected_models_list = [int(model) for model in selected_models.split(",")]
    msa_values = bundle.dfs[alignment_table][alignment_column]

    total = len(proteins_df)
    folded = []
    for idx, protein in enumerate(proteins_df[protein_column]):
        print(f"Processing protein {idx + 1}/{total}")

        alignments = msa_values.iloc[idx]

        response = await query_bionemo_nim(
            url="https://health.api.nvidia.com/v1/biology/openfold/openfold2/predict-structure-from-msa-and-template",
            payload={
                "sequence": protein,
                "alignments": alignments,
                "selected_models": selected_models_list,
                "relax_prediction": relax_prediction,
            },
        )

        # Take the best-ranked structure without mutating the response.
        folded.append(response["structures_in_ranked_order"][0]["structure"])

    # Attach by position using the table's own index, so tables whose index is not
    # 0..n-1 still get each structure on the right row.
    proteins_df["folded_protein"] = pd.Series(folded, index=proteins_df.index, dtype=object)

    # NOTE(review): the purpose of this empty table is not visible here; kept so the
    # output bundle has the same tables as before.
    bundle.dfs["openfold"] = pd.DataFrame()

    return bundle


@op("View molecule", view="molecule")
def view_molecule(
    bundle: Bundle,
    *,
    molecule_table: str,
    molecule_column: str,
    row_index: int = 0,
):
    """Return one molecule from a table in a form the molecule view can display.

    Args:
        bundle: Bundle containing the table.
        molecule_table: Name of the table in ``bundle.dfs``.
        molecule_column: Column holding the molecules.
        row_index: Positional index of the row to show.

    Returns:
        A dict with ``"data"`` (the molecule as text) and ``"format"``, which is
        ``"pdb"`` when the text starts with ``"ATOM"`` and ``"sdf"`` otherwise.
        An RDKit ``Mol`` value is serialized to SDF text first.
    """
    # A bare ``import rdkit`` does not guarantee that the ``rdkit.Chem`` submodule
    # has been loaded; import it explicitly so the attribute access cannot fail.
    from rdkit import Chem

    molecule_data = bundle.dfs[molecule_table][molecule_column].iloc[row_index]
    if isinstance(molecule_data, Chem.Mol):
        sio = io.StringIO()
        # Leaving the ``with`` block closes the writer before the buffer is read.
        with Chem.SDWriter(sio) as w:
            w.write(molecule_data)
        molecule_data = sio.getvalue()

    return {
        "data": molecule_data,
        "format": "pdb" if molecule_data.startswith("ATOM") else "sdf",
    }


def _needs_bionemo_k8s(**k8s_kwargs):
    """Build the decorator applied to operations that may need a k8s service.

    With ``USE_K8S`` enabled, the keyword arguments are handed to ``k8s.needs``
    and its result is returned. Otherwise the returned decorator leaves the
    decorated function unchanged.
    """
    if not USE_K8S:

        def passthrough(func):
            # k8s disabled: hand the function back untouched.
            return func

        return passthrough
    return k8s.needs(**k8s_kwargs)


def base_url(service):
    """Return the base URL (with trailing slash) to use for ``service``.

    When ``USE_K8S`` is off this is the fixed hosted API address. Otherwise it is
    an ``http://`` URL built from what ``k8s.get_ip(service)`` returns.
    """
    if not USE_K8S:
        return "https://health.api.nvidia.com/"
    service_ip = k8s.get_ip(service)
    return f"http://{service_ip}/"


# Switch read by _needs_bionemo_k8s() and base_url(). While it is False, the
# decorator built below is a pass-through and base_url() returns the hosted API URL.
USE_K8S = False  # Not production ready yet.
# Decorator for operations that call the GenMol service; the keyword arguments
# are forwarded to k8s.needs() only when USE_K8S is enabled.
needs_genmol_k8s = _needs_bionemo_k8s(
    name="genmol",
    image="nvcr.io/nim/nvidia/genmol:1.0.0",
    port=8000,
)


@op("Query GenMol", slow=True)
@needs_genmol_k8s
async def query_genmol(
    bundle: Bundle,
    *,
    molecule_table: str,
    molecule_column: str,
    num_molecules: int = 5,
    temperature: float = 1.0,
    noise: float = 0.2,
    step_size: int = 4,
    scoring: str = "QED",
):
    """Generate molecules with GenMol from the first molecule of a table.

    Only the first row of ``molecule_column`` is sent to the service. The
    ``"smiles"`` values of the returned molecules are joined with newlines, and
    that single string is written to a new ``ligands`` column (same value in
    every row).

    Args:
        bundle: Input bundle; the returned bundle is a copy of it.
        molecule_table: Name of the table in ``bundle.dfs``.
        molecule_column: Column whose first value is sent as ``smiles``.
        num_molecules: Forwarded to the service as ``num_molecules``.
        temperature: Forwarded to the service as ``temperature``.
        noise: Forwarded to the service as ``noise``.
        step_size: Forwarded to the service as ``step_size``.
        scoring: Forwarded to the service as ``scoring``.

    Returns:
        The copied bundle whose ``molecule_table`` has the extra ``ligands`` column.

    Raises:
        IndexError: If ``molecule_table`` has no rows.
    """
    bundle = bundle.copy()
    # NOTE(review): bundle.copy() may share DataFrames with the input bundle (its
    # implementation is not visible here), so add the column to a private copy
    # rather than to the caller's DataFrame.
    molecules_df = bundle.dfs[molecule_table].copy()
    bundle.dfs[molecule_table] = molecules_df

    response = await query_bionemo_nim(
        url=f"{base_url('genmol')}v1/biology/nvidia/genmol/generate",
        payload={
            "smiles": molecules_df[molecule_column].iloc[0],
            "num_molecules": num_molecules,
            "temperature": temperature,
            "noise": noise,
            "step_size": step_size,
            "scoring": scoring,
        },
    )
    generated_ligands = "\n".join(v["smiles"] for v in response["molecules"])
    molecules_df["ligands"] = generated_ligands
    return bundle


@op("Query DiffDock", slow=True)
async def query_diffdock(
    proteins: Bundle,
    ligands: Bundle,
    *,
    protein_table: str,
    protein_column: str,
    ligand_table: str,
    ligand_column: str,
    ligand_file_type: str = "txt",
    num_poses: int = 10,
    time_divisions: int = 20,
    num_steps: int = 18,
):
    """Dock the first ligand onto the first protein with DiffDock.

    Only the first row of each input column is sent to the service.

    Args:
        proteins: Bundle containing the protein table.
        ligands: Bundle containing the ligand table.
        protein_table: Name of the table in ``proteins.dfs``.
        protein_column: Column whose first value is sent as ``protein``.
        ligand_table: Name of the table in ``ligands.dfs``.
        ligand_column: Column whose first value is sent as ``ligand``.
        ligand_file_type: Forwarded to the service as ``ligand_file_type``.
        num_poses: Forwarded to the service as ``num_poses``.
        time_divisions: Forwarded to the service as ``time_divisions``.
        num_steps: Forwarded to the service as ``num_steps``.

    Returns:
        A new bundle with a single ``diffdock_table``. It has one row per entry of
        the response's ``"status"``, with the response's ``"protein"`` and
        ``"ligand"`` repeated on every row next to the ``trajectory``,
        ``ligand_positions``, ``position_confidence`` and ``status`` values.
    """
    response = await query_bionemo_nim(
        url="https://health.api.nvidia.com/v1/biology/mit/diffdock",
        payload={
            "protein": proteins.dfs[protein_table][protein_column].iloc[0],
            "ligand": ligands.dfs[ligand_table][ligand_column].iloc[0],
            "ligand_file_type": ligand_file_type,
            "num_poses": num_poses,
            "time_divisions": time_divisions,
            "num_steps": num_steps,
        },
    )
    # The number of "status" entries decides how often protein/ligand are repeated.
    n_rows = len(response["status"])
    bundle = Bundle()
    # Dict order fixes the column order of the resulting table.
    bundle.dfs["diffdock_table"] = pd.DataFrame(
        {
            "protein": [response["protein"]] * n_rows,
            "ligand": [response["ligand"]] * n_rows,
            "trajectory": response["trajectory"],
            "ligand_positions": response["ligand_positions"],
            "position_confidence": response["position_confidence"],
            "status": response["status"],
        }
    )

    return bundle