import json
import logging
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch

from fairchem.data.omol.modules.evaluator import (
    ligand_pocket,
    ligand_strain,
    geom_conformers,
    protonation_energies,
    unoptimized_ie_ea,
    distance_scaling,
    unoptimized_spin_gap,
)


class SubmissionLoadError(Exception):
    """Raised if unable to load the submission file."""


OMOL_EVAL_FUNCTIONS = {
    "Ligand pocket": ligand_pocket,
    "Ligand strain": ligand_strain,
    "Conformers": geom_conformers,
    "Protonation": protonation_energies,
    "IE_EA": unoptimized_ie_ea,
    "Distance scaling": distance_scaling,
    "Spin gap": unoptimized_spin_gap,
}

OMOL_DATA_ID_MAPPING = {
    "metal_complexes": ["metal_complexes"],
    "electrolytes": ["elytes"],
    "biomolecules": ["biomolecules"],
    "neutral_organics": ["ani2x", "orbnet_denali", "geom_orca6", "trans1x", "rgd"],
}


def reorder(ref: np.ndarray, to_reorder: np.ndarray) -> np.ndarray:
    """
    Get the ordering such that `to_reorder[ordering] == ref`.

    e.g.:
        ref = [c, a, b]
        to_reorder = [b, a, c]
        order = reorder(ref, to_reorder)  # [2, 1, 0]
        assert ref == to_reorder[order]

    Parameters
    ----------
    ref : np.ndarray
        Reference array. Must not contain duplicates.
    to_reorder : np.ndarray
        Array to reorder. Must not contain duplicates.
        Items must be the same as in `ref`.

    Returns
    -------
    np.ndarray
        The ordering to apply to `to_reorder`.
    """
    assert len(ref) == len(set(ref))
    assert len(to_reorder) == len(set(to_reorder))
    assert set(ref) == set(to_reorder)
    item_to_idx = {item: idx for idx, item in enumerate(to_reorder)}
    return np.array([item_to_idx[item] for item in ref])


def get_order(path_submission: Path, path_annotations: Path) -> np.ndarray:
    try:
        with np.load(path_submission) as data:
            submission_ids = data["ids"]
    except Exception as e:
        raise SubmissionLoadError(
            "Error loading submission file. 'ids' must not be object types."
        ) from e
    with np.load(path_annotations, allow_pickle=True) as data:
        annotations_ids = data["ids"]

    # Use sets for faster comparison
    submission_set = set(submission_ids)
    annotations_set = set(annotations_ids)
    if submission_set != annotations_set:
        missing_ids = annotations_set - submission_set
        unexpected_ids = submission_set - annotations_set
        details = (
            f"{len(missing_ids)} missing IDs: ({list(missing_ids)[:3]}, ...)\n"
            f"{len(unexpected_ids)} unexpected IDs: ({list(unexpected_ids)[:3]}, ...)"
        )
        raise Exception(f"IDs don't match.\n{details}")
    assert len(submission_ids) == len(
        submission_set
    ), "Duplicate IDs found in submission."
    return reorder(annotations_ids, submission_ids)


def s2ef_metrics(
    annotations_path: Path,
    submission_filename: Path,
    subsets: List[str] = ["all"],
) -> Dict[str, float]:
    order = get_order(submission_filename, annotations_path)
    try:
        with np.load(submission_filename) as data:
            forces = data["forces"]
            energy = data["energy"][order]
            # Split the concatenated forces back into per-structure arrays,
            # then apply the same ordering as the annotations.
            forces = np.array(
                np.split(forces, np.cumsum(data["natoms"])[:-1]), dtype=object
            )[order]
    except Exception as e:
        raise SubmissionLoadError(
            "Error loading submission data. Make sure you concatenated your forces and there are no object types."
        ) from e

    inf_energy_ids = list(set(np.where(np.isinf(energy))[0]))
    if inf_energy_ids:
        raise Exception(
            f"Inf values found in `energy` for IDs: ({inf_energy_ids[:3]}, ...)"
        )

    with np.load(annotations_path, allow_pickle=True) as data:
        target_forces = data["forces"]
        target_energy = data["energy"]
        target_data_ids = data["data_ids"]

    metrics = {}
    for subset in subsets:
        if subset == "all":
            subset_mask = np.ones(len(target_data_ids), dtype=bool)
        else:
            allowed_ids = set(OMOL_DATA_ID_MAPPING.get(subset, []))
            subset_mask = np.array(
                [data_id in allowed_ids for data_id in target_data_ids]
            )
        sub_energy = energy[subset_mask]
        sub_target_energy = target_energy[subset_mask]
        energy_mae = np.mean(np.abs(sub_target_energy - sub_energy))
        metrics[f"{subset}_energy_mae"] = energy_mae

        forces_mae = 0
        natoms = 0
        for sub_forces, sub_target_forces in zip(
            forces[subset_mask], target_forces[subset_mask]
        ):
            forces_mae += np.sum(np.abs(sub_target_forces - sub_forces))
            natoms += sub_forces.shape[0]
        forces_mae /= 3 * natoms
        metrics[f"{subset}_forces_mae"] = forces_mae
    return metrics
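

# Illustrative sketch (not part of the original evaluator): one way a submission
# .npz compatible with `s2ef_metrics` might be assembled. The keys ("ids",
# "energy", "forces", "natoms") mirror what the loader above reads; the helper
# name and argument layout are assumptions for demonstration only.
def _example_write_submission(path, ids, energies, per_structure_forces):
    # per_structure_forces: list of (natoms_i, 3) float arrays, one per structure,
    # in the same order as `ids`. Forces are concatenated along axis 0 so the
    # file contains no object-dtype arrays.
    natoms = np.array([f.shape[0] for f in per_structure_forces], dtype=int)
    np.savez_compressed(
        path,
        ids=np.asarray(ids),
        energy=np.asarray(energies, dtype=float),
        forces=np.concatenate(per_structure_forces, axis=0),
        natoms=natoms,
    )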


def omol_evaluations(
    annotations_path: Path,
    submission_filename: Path,
    eval_type: str,
) -> Dict[str, float]:
    try:
        with open(submission_filename) as f:
            submission_data = json.load(f)
    except Exception as e:
        raise SubmissionLoadError("Error loading submission file") from e
    with open(annotations_path) as f:
        annotations_data = json.load(f)

    submission_entries = set(submission_data.keys())
    annotation_entries = set(annotations_data.keys())
    if submission_entries != annotation_entries:
        missing = annotation_entries - submission_entries
        unexpected = submission_entries - annotation_entries
        raise ValueError(
            "Submission and annotations entries do not match.\n"
            f"Missing entries in submission: {missing}\n"
            f"Unexpected entries in submission: {unexpected}"
        )
    assert len(submission_entries) == len(
        submission_data
    ), "Duplicate entries found in submission."

    eval_fn = OMOL_EVAL_FUNCTIONS.get(eval_type)
    metrics = eval_fn(annotations_data, submission_data)
    return metrics


def evaluate(
    annotations_path: Path,
    submission_filename: Path,
    eval_type: str,
):
    if eval_type in ["Validation", "Test"]:
        metrics = s2ef_metrics(
            annotations_path,
            submission_filename,
            subsets=[
                "all",
                "metal_complexes",
                "electrolytes",
                "biomolecules",
                "neutral_organics",
            ],
        )
    elif eval_type in OMOL_EVAL_FUNCTIONS:
        metrics = omol_evaluations(
            annotations_path,
            submission_filename,
            eval_type,
        )
    else:
        raise ValueError(f"Unknown eval_type: {eval_type}")
    return metrics
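

# Illustrative usage sketch (an assumption, not how the Space itself invokes the
# evaluator): the file paths below are placeholders. Validation/Test submissions
# are .npz files scored by `s2ef_metrics`; the named OMol evaluations expect
# JSON submissions keyed like their annotations.
if __name__ == "__main__":
    example_metrics = evaluate(
        annotations_path=Path("annotations/val_annotations.npz"),
        submission_filename=Path("submissions/my_submission.npz"),
        eval_type="Validation",
    )
    for name, value in example_metrics.items():
        print(f"{name}: {value:.6f}")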