"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import hashlib
import os
from pathlib import Path

import ase
import backoff
import gradio as gr
import huggingface_hub as hf_hub
import requests
from ase.calculators.calculator import Calculator
from ase.db.core import now
from ase.db.row import AtomsRow
from ase.io.jsonio import decode, encode
from requests.exceptions import HTTPError


def hash_save_file(atoms: ase.Atoms, task_name: str, path: Path | str) -> None:
    """Save a copy of the atoms object under ``path``, named by the MD5 hash of its JSON encoding."""
    atoms = atoms.copy()
    atoms.info["task_name"] = task_name
    atoms.write(
        Path(path)
        / f"{hashlib.md5(atoms_to_json(atoms).encode('utf-8')).hexdigest()}.traj"
    )


def validate_uma_access(oauth_token):
    """Return True if the logged-in user has gated access to the facebook/UMA repo."""
    try:
        hf_hub.HfApi().auth_check(repo_id="facebook/UMA", token=oauth_token.token)
        return True
    except (hf_hub.errors.HfHubHTTPError, AttributeError):
        return False


class HFEndpointCalculator(Calculator):
    """A simple ASE calculator that runs predictions remotely via a Hugging Face Inference Endpoint."""

    implemented_properties = ["energy", "free_energy", "stress", "forces"]

    def __init__(
        self,
        atoms,
        endpoint_url,
        oauth_token,
        task_name,
        example=False,
        *args,
        **kwargs,
    ):
        # If we have an example structure, we don't need to check for authentication
        # Otherwise, we need to check if the user is authenticated and has gated access to the UMA models
        if not example:
            if validate_uma_access(oauth_token):
                try:
                    hash_save_file(atoms, task_name, "/data/custom_inputs/")
                except FileNotFoundError:
                    pass
            else:
                raise gr.Error(
                    "You need to log in to HF and have gated model access to UMA before running your own simulations!"
                )

        self.url = endpoint_url
        self.token = os.environ["HF_TOKEN"]
        self.atoms = atoms
        self.task_name = task_name

        super().__init__(*args, **kwargs)

    @staticmethod
    @backoff.on_exception(
        backoff.expo,
        (requests.exceptions.RequestException,),
        max_tries=10,
        jitter=backoff.full_jitter,
    )
    def _post_with_backoff(url, headers, payload):
        """POST the payload to the endpoint, retrying with exponential backoff on transient request failures."""
        response = requests.post(url, headers=headers, json=payload)
        response.raise_for_status()
        return response

    def calculate(self, atoms, properties, system_changes):
        Calculator.calculate(self, atoms, properties, system_changes)

        task_name = self.task_name.lower()

        payload = {
            "inputs": atoms_to_json(atoms, data=atoms.info),
            "properties": properties,
            "system_changes": system_changes,
            "task_name": task_name,
        }

        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }

        try:
            response = self._post_with_backoff(self.url, headers, payload)
            response_dict = response.json()
        except HTTPError as error:
            hash_save_file(atoms, task_name, "/data/custom_inputs/errors/")
            raise gr.Error(
                f"Backend failure during your calculation; if you have continued issues please file an issue in the main FAIR chemistry repo (https://github.com/facebookresearch/fairchem).\n{error}"
            )

        # Load the response and store the results in the calc and atoms object
        response_dict = decode(response_dict)
        self.results = response_dict["results"]
        atoms.info = response_dict["info"]


def atoms_to_json(atoms, data=None):
    """Serialize an Atoms object to a JSON string, following the row format used by ase.db.jsondb."""

    mtime = now()

    row = AtomsRow(atoms)
    row.ctime = mtime

    dct = {}
    for key in row.__dict__:
        if key[0] == "_" or key in row._keys or key == "id":
            continue
        dct[key] = row[key]

    dct["mtime"] = mtime

    if data:
        dct["data"] = data
    else:
        dct["data"] = {}

    constraints = row.get("constraints")
    if constraints:
        dct["constraints"] = constraints

    return encode(dct)
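

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the app's code path).
# The endpoint URL is a placeholder; in the Gradio app it comes from the Space
# configuration, and the OAuth token from the logged-in user's session.
# "omol" is used here as an example task label. Running this for real requires
# HF_TOKEN to be set in the environment, since __init__ reads it for the
# Authorization header.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from ase.build import molecule

    water = molecule("H2O")

    # `example=True` skips the gated-access check, mirroring how the demo
    # handles its built-in example structures.
    water.calc = HFEndpointCalculator(
        atoms=water,
        endpoint_url="https://<your-inference-endpoint>",  # placeholder
        oauth_token=None,
        task_name="omol",  # example task label
        example=True,
    )

    # Triggers calculate(), which POSTs the serialized structure to the endpoint.
    print(water.get_potential_energy())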