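"""Generate a LunarLander-v2 image/instruction dataset from a pretrained PPO policy.

Rolls out episodes with the policy, renders each frame, and (at a configurable
sample rate) emits chat-style examples that pair the rendered frame with the
action the policy took, saved as a Hugging Face `datasets` dataset.
"""
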
import argparse
import os
import random

from torch.distributions.categorical import Categorical
from PIL import Image
from datasets import Dataset
import gymnasium as gym
import torch.nn as nn
import numpy as np
import torch

from sonicverse.constants import ROLE_ASSISTANT, ROLE_USER

LUNAR_LANDER_OPTIONS = [
    "[FIRE LEFT ENGINE]",
    "[FIRE RIGHT ENGINE]",
    "[FIRE MAIN ENGINE]",
    "[NOTHING]",
]

MAX_STEPS = 1000  # hard cap on environment steps per episode


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonal weight init with constant bias, as in standard PPO implementations."""
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    """PPO actor-critic: separate 2x64 tanh MLPs for the policy and value heads."""

    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(
                nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)
            ),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(
                nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)
            ),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
        )

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)


def _gen_examples(round_num, args):
    env = gym.make("LunarLander-v2", render_mode="rgb_array")

    # Seed every RNG in play with the round number so episodes are reproducible;
    # torch must be seeded too, since the policy samples actions via torch's RNG.
    random.seed(round_num)
    np.random.seed(round_num)
    torch.manual_seed(round_num)

    # Minimal shim exposing the vectorized-env attributes Agent expects,
    # backed by this single (non-vectorized) environment.
    class EnvWrapper:
        single_observation_space = env.observation_space
        single_action_space = env.action_space

    model = Agent(EnvWrapper()).to("cpu")
    model.load_state_dict(
        torch.load(args.pretrained_ppo_model_path, map_location="cpu")
    )
    model.eval()

    os.makedirs(args.output_image_folder, exist_ok=True)

    observation, info = env.reset(seed=round_num)

    for frame in range(MAX_STEPS):
        img = env.render()
        with torch.no_grad():
            action, logprob, _, value = model.get_action_and_value(
                torch.from_numpy(observation)
            )

        action = int(action.item())
        # LunarLander-v2 discrete actions: 0 = do nothing, 1 = fire left
        # engine, 2 = fire main engine, 3 = fire right engine. A dict lookup
        # fails loudly on an unexpected index instead of yielding an empty reply.
        resp = {
            0: "[NOTHING]",
            1: "[FIRE LEFT ENGINE]",
            2: "[FIRE MAIN ENGINE]",
            3: "[FIRE RIGHT ENGINE]",
        }[action]
        if random.random() < args.sample_rate:
            # Shuffle a copy of the options so the module-level list is never mutated.
            options = random.sample(LUNAR_LANDER_OPTIONS, k=len(LUNAR_LANDER_OPTIONS))
            options_str = ", ".join(options)
            img_fn = os.path.join(args.output_image_folder, f"{round_num}_{frame}.jpg")
            messages = [
                {
                    "role": ROLE_USER,
                    "content": f"<image>\nYou are playing lunar lander. The goal is to land the craft between the yellow flags. What is the optimal next action? {options_str}",
                },
                {"role": ROLE_ASSISTANT, "content": resp},
            ]
            Image.fromarray(img).save(img_fn)
            example = {
                "id": f"{round_num}_{frame}",
                "images": [img_fn],
                "messages": messages,
            }
            yield example

        observation, reward, terminated, truncated, info = env.step(action)

        if terminated or truncated:
            break


def main(args):
    def gen(idxs):
        for r in idxs:
            yield from _gen_examples(r, args)

    # `from_generator` splits the list-valued `idxs` kwarg into shards, one per
    # worker process, so rounds are rolled out in parallel.
    ds = Dataset.from_generator(
        gen, gen_kwargs={"idxs": list(range(args.rounds))}, num_proc=args.num_proc
    )
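    # Persist as an Arrow dataset; reload later with `datasets.load_from_disk`.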
    ds.save_to_disk(args.output_folder)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_ppo_model_path", type=str, required=True, help="Path to the pretrained PPO state_dict (.pt).")
    parser.add_argument("--output_image_folder", type=str, required=True, help="Directory for the rendered JPEG frames.")
    parser.add_argument("--output_folder", type=str, required=True, help="Directory for the saved Hugging Face dataset.")
    parser.add_argument("--rounds", type=int, default=10_000, help="Number of episodes to roll out.")
    parser.add_argument("--sample_rate", type=float, default=0.01, help="Fraction of frames emitted as examples.")
    parser.add_argument("--num_proc", type=int, default=16, help="Worker processes for dataset generation.")
    args = parser.parse_args()
    main(args)
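
# Example invocation (file and directory names are illustrative placeholders,
# not paths from the repo):
#   python generate_lunar_lander_dataset.py \
#       --pretrained_ppo_model_path ppo_lunar_lander.pt \
#       --output_image_folder ./lander_frames \
#       --output_folder ./lander_dataset
#
# The saved dataset can then be inspected with, e.g.:
#   >>> from datasets import load_from_disk
#   >>> ds = load_from_disk("./lander_dataset")
#   >>> ds[0]["messages"]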