import argparse
import os
import random

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from datasets import Dataset
from PIL import Image
from torch.distributions.categorical import Categorical

from sonicverse.constants import ROLE_ASSISTANT, ROLE_USER

# Action labels presented to the model; the order is shuffled per example.
LUNAR_LANDER_OPTIONS = (
    "[FIRE LEFT ENGINE], [FIRE RIGHT ENGINE], [FIRE MAIN ENGINE], [NOTHING]".split(", ")
)
# Hard cap on episode length so a single rollout cannot run forever.
MAX_STEPS = 1000


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonal weight init with constant bias (a common PPO convention)."""
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class Agent(nn.Module):
    """PPO actor-critic with separate two-layer MLP heads (CleanRL-style)."""

    def __init__(self, envs):
        super().__init__()
        obs_dim = int(np.array(envs.single_observation_space.shape).prod())
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            # Value head: std=1.0 keeps initial value estimates small.
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            # Policy head: std=0.01 makes the initial policy near-uniform.
            layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
        )

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)
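
# A minimal smoke test for the agent (an illustrative sketch, not part of the
# pipeline; it assumes the 8-dim LunarLander observation and 4 discrete actions):
#
#   env = gym.make("LunarLander-v2")
#   class _Wrapper:
#       single_observation_space = env.observation_space
#       single_action_space = env.action_space
#   agent = Agent(_Wrapper())
#   obs, _ = env.reset(seed=0)
#   a, logp, ent, v = agent.get_action_and_value(torch.from_numpy(obs))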

def _gen_examples(round_num, args):
    """Roll out one episode with the pretrained PPO policy and yield an
    (image, chat-formatted label) example for a random subset of frames."""
    env = gym.make("LunarLander-v2", render_mode="rgb_array")
    random.seed(round_num)
    np.random.seed(round_num)

    # Agent expects the vectorized-env attribute names; adapt the single env.
    class EnvWrapper:
        single_observation_space = env.observation_space
        single_action_space = env.action_space

    model = Agent(EnvWrapper()).to("cpu")
    model.load_state_dict(
        torch.load(args.pretrained_ppo_model_path, map_location="cpu")
    )
    model.eval()

    os.makedirs(args.output_image_folder, exist_ok=True)
    observation, info = env.reset(seed=round_num)
    for frame in range(MAX_STEPS):
        img = env.render()
        with torch.no_grad():
            action, logprob, _, value = model.get_action_and_value(
                torch.from_numpy(observation)
            )
        action = int(action.cpu().numpy())

        # Map the discrete action index to the label the model should emit.
        if action == 0:
            resp = "[NOTHING]"
        elif action == 1:
            resp = "[FIRE LEFT ENGINE]"
        elif action == 2:
            resp = "[FIRE MAIN ENGINE]"
        else:
            resp = "[FIRE RIGHT ENGINE]"

        # Keep only a small fraction of frames; shuffle the options so the
        # model cannot learn a positional shortcut.
        if random.random() < args.sample_rate:
            random.shuffle(LUNAR_LANDER_OPTIONS)
            options_str = ", ".join(LUNAR_LANDER_OPTIONS)
            img_fn = os.path.join(args.output_image_folder, f"{round_num}_{frame}.jpg")
            messages = [
                {
                    "role": ROLE_USER,
                    "content": f"<image>\nYou are playing lunar lander. The goal is to land the craft between the yellow flags. What is the optimal next action? {options_str}",
                },
                {"role": ROLE_ASSISTANT, "content": resp},
            ]
            Image.fromarray(img).save(img_fn)
            yield {
                "id": f"{round_num}_{frame}",
                "images": [img_fn],
                "messages": messages,
            }

        observation, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            break
    env.close()

def main(args):
    def gen(idxs):
        for r in idxs:
            yield from _gen_examples(r, args)

    # `datasets` shards the list-valued `idxs` kwarg across `num_proc` workers,
    # so rounds are generated in parallel.
    ds = Dataset.from_generator(
        gen, gen_kwargs={"idxs": list(range(args.rounds))}, num_proc=args.num_proc
    )
    ds.save_to_disk(args.output_folder)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_ppo_model_path", type=str)
    parser.add_argument("--output_image_folder", type=str)
    parser.add_argument("--output_folder", type=str)
    parser.add_argument("--rounds", type=int, default=10_000)
    parser.add_argument("--sample_rate", type=float, default=0.01)
    parser.add_argument("--num_proc", type=int, default=16)
    args = parser.parse_args()
    main(args)
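
# Example invocation (the script name and all paths are illustrative):
#   python generate_lunarlander_dataset.py \
#       --pretrained_ppo_model_path ppo_lunarlander.pt \
#       --output_image_folder ./frames \
#       --output_folder ./dataset \
#       --rounds 1000 --sample_rate 0.01 --num_proc 8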