Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import numpy as np | |
from torch.utils.data import Dataset | |
from sklearn.preprocessing import MinMaxScaler | |
from Models.interpretable_diffusion.model_utils import ( | |
normalize_to_neg_one_to_one, | |
unnormalize_to_zero_to_one, | |
) | |
from Utils.masking_utils import noise_mask | |
class MuJoCoDataset(Dataset):
    """Synthetic MuJoCo hopper trajectories for time-series generation/imputation.

    On construction, ``num`` random trajectories of ``window`` steps are
    simulated with the DeepMind Control Suite (``hopper`` / ``stand``). Each
    step stores ``qpos`` in the first ``dim // 2`` channels and ``qvel`` in the
    remaining channels. The data is then min-max scaled (with a scaler fitted
    on the generated data, or the externally supplied ``scalar``). When
    ``neg_one_to_one`` is True, it is further passed through
    ``normalize_to_neg_one_to_one``.

    Items:
        * ``period == "train"``: a float32 tensor of shape ``(window, dim)``.
        * ``period == "test"``: ``(sample, mask)``, where ``mask`` is a boolean
          tensor of the same shape. For forecasting (``predict_length``) the
          last ``predict_length`` steps are ``False``. For imputation
          (``missing_ratio``) the mask comes from ``noise_mask``.
    """

    def __init__(
        self,
        window=128,
        num=30000,
        dim=12,
        save2npy=True,
        neg_one_to_one=True,
        seed=123,
        scalar=None,
        period="train",
        output_dir="./OUTPUT",
        predict_length=None,
        missing_ratio=None,
        style="separate",
        distribution="geometric",
        mean_mask_length=3,
    ):
        """Generate, normalize and (in test mode) mask the dataset.

        Args:
            window: Number of simulation steps per trajectory.
            num: Number of trajectories to simulate.
            dim: Channels per step. The first ``dim // 2`` hold ``qpos`` and
                the rest hold ``qvel``, so the split must match the hopper
                model's sizes (checked during generation).
            save2npy: If True, save ground-truth / normalized arrays (and the
                imputation mask) as ``.npy`` files under ``<output_dir>/samples``.
            neg_one_to_one: If True, apply ``normalize_to_neg_one_to_one``
                after min-max scaling.
            seed: Seed for trajectory generation and for the imputation mask.
            scalar: Optional scaler used instead of the one fitted on the
                generated data.
            period: ``"train"`` or ``"test"``.
            output_dir: Root directory for saved ``.npy`` files (the
                ``samples`` subdirectory is always created).
            predict_length: Test mode only. Hide the last ``predict_length``
                steps of every sample.
            missing_ratio: Test mode only. Build masks with ``noise_mask``
                using this ratio. Takes precedence over ``predict_length``.
            style, distribution, mean_mask_length: Forwarded to ``noise_mask``.

        Raises:
            AssertionError: If ``period`` is invalid, or if ``predict_length``
                or ``missing_ratio`` is given with ``period == "train"``.
            NotImplementedError: In test mode when neither ``predict_length``
                nor ``missing_ratio`` is given.
        """
        super().__init__()
        assert period in ["train", "test"], "period must be train or test."
        if period == "train":
            # Must be logical ``not``: the bitwise ``~`` used previously maps
            # True -> -2 and False -> -1, both truthy, so it never failed.
            assert not (
                predict_length is not None or missing_ratio is not None
            ), "predict_length and missing_ratio are only supported when period is 'test'."
        self.window, self.var_num = window, dim
        self.auto_norm = neg_one_to_one
        self.dir = os.path.join(output_dir, "samples")
        os.makedirs(self.dir, exist_ok=True)
        self.pred_len, self.missing_ratio = predict_length, missing_ratio
        self.style, self.distribution, self.mean_mask_length = (
            style,
            distribution,
            mean_mask_length,
        )
        self.rawdata, self.scaler = self._generate_random_trajectories(
            n_samples=num, seed=seed
        )
        if scalar is not None:
            self.scaler = scalar
        # normalize() reads period/save2npy to decide what to write to disk.
        self.period, self.save2npy = period, save2npy
        self.samples = self.normalize(self.rawdata)
        self.sample_num = self.samples.shape[0]
        if period == "test":
            if missing_ratio is not None:
                self.masking = self.mask_data(seed)
            elif predict_length is not None:
                masks = np.ones(self.samples.shape)
                # Explicit start index: ``masks[:, -0:, :]`` would hide every
                # step when predict_length == 0.
                start = max(masks.shape[1] - predict_length, 0)
                masks[:, start:, :] = 0
                self.masking = masks.astype(bool)
            else:
                raise NotImplementedError()

    def _generate_random_trajectories(self, n_samples, seed=123):
        """Simulate ``n_samples`` random hopper trajectories.

        Each trajectory starts from uniformly sampled positions and velocities
        and is rolled forward ``self.window`` physics steps. The global NumPy
        RNG is seeded with ``seed`` for reproducibility. Its previous state is
        restored afterwards, even if simulation fails.

        Returns:
            ``(data, scaler)``. ``data`` has shape
            ``(n_samples, self.window, self.var_num)``. ``scaler`` is a
            ``MinMaxScaler`` fitted on all time steps of ``data``.

        Raises:
            ImportError: If ``dm_control`` is not installed.
            ValueError: If ``self.var_num`` cannot hold the model's
                ``qpos``/``qvel`` in its two halves.
        """
        try:
            from dm_control import suite
        except ImportError as e:
            raise ImportError(
                "Deepmind Control Suite is required to generate the dataset."
            ) from e

        env = suite.load("hopper", "stand")
        physics = env.physics

        # qpos fills channels [0, n_pos), qvel fills [n_pos, var_num).
        n_pos = self.var_num // 2
        n_vel = self.var_num - n_pos
        nq, nv = len(physics.data.qpos), len(physics.data.qvel)
        if (nq, nv) != (n_pos, n_vel):
            raise ValueError(
                f"dim={self.var_num} is incompatible with the hopper model "
                f"(qpos size {nq}, qvel size {nv}); the first dim // 2 "
                f"channels hold qpos and the remaining channels hold qvel."
            )

        data = np.zeros((n_samples, self.window, self.var_num))
        # Seed the global RNG for reproducibility; restore the caller's state
        # afterwards, even on failure.
        st0 = np.random.get_state()
        np.random.seed(seed)
        try:
            for i in range(n_samples):
                with physics.reset_context():
                    # x and z positions of the hopper. We want z > 0 for the hopper to stay above ground.
                    physics.data.qpos[:2] = np.random.uniform(0, 0.5, size=2)
                    physics.data.qpos[2:] = np.random.uniform(
                        -2, 2, size=physics.data.qpos[2:].shape
                    )
                    physics.data.qvel[:] = np.random.uniform(
                        -5, 5, size=physics.data.qvel.shape
                    )
                for t in range(self.window):
                    data[i, t, :n_pos] = physics.data.qpos
                    data[i, t, n_pos:] = physics.data.qvel
                    physics.step()
        finally:
            np.random.set_state(st0)

        scaler = MinMaxScaler().fit(data.reshape(-1, self.var_num))
        return data, scaler

    def normalize(self, sq):
        """Scale raw trajectories and optionally save them to ``self.dir``.

        Args:
            sq: Raw values; reshaped to ``(-1, var_num)`` for scaling and
                returned as ``(-1, window, var_num)``.

        Returns:
            The normalized array of shape ``(-1, window, var_num)``.

        When ``self.save2npy`` is set, two files are written: the raw input
        (``mujoco_ground_truth_*``) and the min-max-scaled values
        (``mujoco_norm_truth_*``). For the latter, ``unnormalize_to_zero_to_one``
        is applied when ``auto_norm`` is on, presumably to undo
        ``normalize_to_neg_one_to_one``.
        """
        d = self.__normalize(sq.reshape(-1, self.var_num))
        data = d.reshape(-1, self.window, self.var_num)
        if self.save2npy:
            np.save(
                os.path.join(
                    self.dir, f"mujoco_ground_truth_{self.window}_{self.period}.npy"
                ),
                sq,
            )
            norm_truth = unnormalize_to_zero_to_one(data) if self.auto_norm else data
            np.save(
                os.path.join(
                    self.dir, f"mujoco_norm_truth_{self.window}_{self.period}.npy"
                ),
                norm_truth,
            )
        return data

    def __normalize(self, rawdata):
        """Apply the scaler, then ``normalize_to_neg_one_to_one`` if enabled."""
        data = self.scaler.transform(rawdata)
        if self.auto_norm:
            data = normalize_to_neg_one_to_one(data)
        return data

    def unnormalize(self, sq):
        """Map normalized data back to raw units as ``(-1, window, var_num)``."""
        d = self.__unnormalize(sq.reshape(-1, self.var_num))
        return d.reshape(-1, self.window, self.var_num)

    def __unnormalize(self, data):
        """Invert ``__normalize``: optional zero-to-one step, then the scaler."""
        if self.auto_norm:
            data = unnormalize_to_zero_to_one(data)
        return self.scaler.inverse_transform(data)

    def mask_data(self, seed=2023):
        """Build one ``noise_mask`` per sample for the imputation task.

        The global NumPy RNG is seeded with ``seed`` and restored afterwards,
        even if mask generation fails. When ``self.save2npy`` is set, the
        (non-boolean) masks are also saved as ``mujoco_masking_<window>.npy``.

        Returns:
            A boolean array with the same shape as ``self.samples``.
        """
        masks = np.ones_like(self.samples)
        st0 = np.random.get_state()
        np.random.seed(seed)
        try:
            for idx in range(self.samples.shape[0]):
                x = self.samples[idx, :, :]  # (seq_length, feat_dim) array
                masks[idx, :, :] = noise_mask(
                    x,
                    self.missing_ratio,
                    self.mean_mask_length,
                    self.style,
                    self.distribution,
                )  # (seq_length, feat_dim) boolean array
            if self.save2npy:
                np.save(
                    os.path.join(self.dir, f"mujoco_masking_{self.window}.npy"), masks
                )
        finally:
            np.random.set_state(st0)
        return masks.astype(bool)

    def __getitem__(self, ind):
        """Return sample ``ind`` (and its mask in test mode) as torch tensors."""
        x = self.samples[ind, :, :]  # (seq_length, feat_dim) array
        if self.period == "test":
            m = self.masking[ind, :, :]  # (seq_length, feat_dim) boolean array
            return torch.from_numpy(x).float(), torch.from_numpy(m)
        return torch.from_numpy(x).float()

    def __len__(self):
        """Number of samples in the dataset."""
        return self.sample_num