Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 4,334 Bytes
20d2150 875e2f3 20d2150 875e2f3 20d2150 875e2f3 20d2150 98b9e26 20d2150 98b9e26 db5eef3 98b9e26 20d2150 875e2f3 20d2150 875e2f3 20d2150 875e2f3 20d2150 875e2f3 20d2150 875e2f3 20d2150 875e2f3 20d2150 875e2f3 20d2150 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import hashlib
import os
from pathlib import Path
import ase
import backoff
import gradio as gr
import huggingface_hub as hf_hub
import requests
from ase.calculators.calculator import Calculator
from ase.db.core import now
from ase.db.row import AtomsRow
from ase.io.jsonio import decode, encode
from requests.exceptions import HTTPError
def hash_save_file(atoms: ase.Atoms, task_name, path: Path | str):
    """Write a copy of ``atoms``, tagged with ``task_name``, to ``path`` as a ``.traj`` file.

    The caller's ``atoms`` object is not modified: the tag is added to
    ``info["task_name"]`` of a copy, and that copy is what gets written.

    The file name is the MD5 hex digest of the copy's ``atoms_to_json``
    serialization; the digest is used only to derive the name. Note that this
    serialization embeds the current time (``ctime``/``mtime`` from ``now()``),
    so saving the same structure twice will generally produce two files.

    Errors raised by ``ase.Atoms.write`` (for example, a missing target
    directory) propagate to the caller.
    """
    tagged = atoms.copy()
    tagged.info["task_name"] = task_name
    serialized = atoms_to_json(tagged).encode("utf-8")
    digest = hashlib.md5(serialized).hexdigest()
    tagged.write(Path(path) / f"{digest}.traj")
def validate_uma_access(oauth_token):
    """Return whether ``oauth_token`` grants access to the gated ``facebook/UMA`` repo.

    Access is checked with ``HfApi.auth_check`` using ``oauth_token.token``.
    An ``HfHubHTTPError`` from the check is reported as "no access". So is an
    ``AttributeError``, which covers the case where ``oauth_token`` has no
    ``.token`` attribute (e.g. ``None`` for a user who is not logged in).
    Any other exception propagates.
    """
    try:
        api = hf_hub.HfApi()
        api.auth_check(repo_id="facebook/UMA", token=oauth_token.token)
    except (hf_hub.errors.HfHubHTTPError, AttributeError):
        return False
    return True
class HFEndpointCalculator(Calculator):
    """ASE calculator that evaluates structures on a Hugging Face Inference Endpoint.

    Each ``calculate`` call serializes the structure with ``atoms_to_json`` and
    POSTs it, along with the requested properties, the system changes and the
    lower-cased task name, to ``endpoint_url``. The decoded response provides the
    calculator's ``results`` and replaces the structure's ``info`` dict.
    """

    implemented_properties = ["energy", "free_energy", "stress", "forces"]

    def __init__(
        self,
        atoms,
        endpoint_url,
        oauth_token,
        task_name,
        example=False,
        *args,
        **kwargs,
    ):
        """Configure the endpoint and, for user-supplied inputs, enforce UMA access.

        Args:
            atoms: Structure this calculator is created for. Non-example inputs
                are archived (best-effort) under ``/data/custom_inputs/``.
            endpoint_url: URL that ``calculate`` POSTs requests to.
            oauth_token: The user's OAuth token object; its ``.token`` is checked
                for access to the gated ``facebook/UMA`` repository.
            task_name: Task name sent to the endpoint (lower-cased there).
            example: When True, skip the access check and the input archiving.
            *args, **kwargs: Forwarded to ``Calculator.__init__``.

        Raises:
            gr.Error: If ``example`` is False and the user lacks UMA access.
            KeyError: If the ``HF_TOKEN`` environment variable is not set.
        """
        # Example structures need no authentication; anything else requires
        # a logged-in user with gated access to the UMA models.
        if not example:
            if not validate_uma_access(oauth_token):
                raise gr.Error(
                    "You need to log in to HF and have gated model access to UMA before running your own simulations!"
                )
            # Best-effort archive of the submitted input: silently skipped when
            # the target directory does not exist.
            try:
                hash_save_file(atoms, task_name, "/data/custom_inputs/")
            except FileNotFoundError:
                pass
        self.url = endpoint_url
        # Endpoint requests are authenticated with the HF_TOKEN environment
        # variable, not with the user's OAuth token.
        self.token = os.environ["HF_TOKEN"]
        # NOTE(review): Calculator.__init__ (called below) appears to reset
        # self.atoms; confirm whether this assignment has any lasting effect.
        self.atoms = atoms
        self.task_name = task_name
        super().__init__(*args, **kwargs)

    @staticmethod
    @backoff.on_exception(
        backoff.expo,
        (requests.exceptions.RequestException,),
        max_tries=10,
        jitter=backoff.full_jitter,
    )
    def _post_with_backoff(url, headers, payload):
        """POST ``payload`` as JSON to ``url`` and return the successful response.

        ``raise_for_status`` turns 4xx/5xx responses into ``HTTPError``. That
        error, like every other ``requests`` exception, is retried with
        exponential backoff and full jitter, for up to 10 attempts in total.
        The last exception is re-raised once the attempts are exhausted.
        """
        response = requests.post(url, headers=headers, json=payload)
        response.raise_for_status()
        return response

    def calculate(self, atoms, properties, system_changes):
        """Evaluate ``atoms`` on the remote endpoint and store the results.

        Args:
            atoms: Structure to evaluate. Its ``info`` dict is sent as the row's
                ``data`` and is replaced by the ``info`` the endpoint returns.
            properties: Property names requested by ASE.
            system_changes: Changes since the last calculation, as given by ASE.

        Raises:
            gr.Error: If the endpoint still answers with an HTTP error after
                all retries. The original ``HTTPError`` is chained as the cause.
        """
        Calculator.calculate(self, atoms, properties, system_changes)
        task_name = self.task_name.lower()
        payload = {
            "inputs": atoms_to_json(atoms, data=atoms.info),
            "properties": properties,
            "system_changes": system_changes,
            "task_name": task_name,
        }
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }
        try:
            response = self._post_with_backoff(self.url, headers, payload)
            response_dict = response.json()
        except HTTPError as error:
            # Keep a copy of the failing input for debugging. This is
            # best-effort: a filesystem failure here (e.g. a missing errors
            # directory) must not mask the backend error reported below.
            try:
                hash_save_file(atoms, task_name, "/data/custom_inputs/errors/")
            except OSError:
                pass
            raise gr.Error(
                f"Backend failure during your calculation; if you have continued issues please file an issue in the main FAIR chemistry repo (https://github.com/facebookresearch/fairchem).\n{error}"
            ) from error

        # Decode the ASE-JSON response; store results on the calculator and
        # replace the structure's info dict with the one returned.
        response_dict = decode(response_dict)
        self.results = response_dict["results"]
        atoms.info = response_dict["info"]
def atoms_to_json(atoms, data=None):
    """Serialize ``atoms`` into an ASE-JSON string laid out like an ``ase.db`` row.

    An ``AtomsRow`` is built from ``atoms`` and stamped with the current time
    as both ``ctime`` and ``mtime``. Its instance attributes are copied, except
    for private names, names listed in ``row._keys``, and ``id``. Non-empty
    constraints are added under ``"constraints"``.

    Args:
        atoms: Structure to serialize.
        data: Optional mapping stored under ``"data"``. A falsy value
            (``None``, an empty dict) is stored as ``{}``.

    Returns:
        The string produced by ``ase.io.jsonio.encode``.
    """
    stamp = now()
    row = AtomsRow(atoms)
    row.ctime = stamp

    record = {
        name: row[name]
        for name in row.__dict__
        if not (name[0] == "_" or name in row._keys or name == "id")
    }
    record["mtime"] = stamp
    record["data"] = data or {}

    constraints = row.get("constraints")
    if constraints:
        record["constraints"] = constraints
    return encode(record)
|