| import os |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from torchvision import transforms |
|
|
| from configs.state_vec import STATE_VEC_IDX_MAPPING |
| from models.multimodal_encoder.siglip_encoder import SiglipVisionTower |
| from models.multimodal_encoder.t5_encoder import T5Embedder |
| from models.rdt_runner import RDTRunner |
|
|
|
|
# Positions, inside the unified state vector, of the quantities this wrapper
# reads and writes: the seven right-arm joint positions followed by the
# right gripper opening (8 entries in total, in that order).
MANISKILL_INDICES = [
    STATE_VEC_IDX_MAPPING[key]
    for key in (
        *(f"right_arm_joint_{joint}_pos" for joint in range(7)),
        "right_gripper_open",
    )
]
|
|
|
|
def create_model(args, pretrained, **kwargs):
    """Build the RDT wrapper and optionally load a policy checkpoint.

    Args:
        args: configuration mapping forwarded to
            ``RoboticDiffusionTransformerModel``.
        pretrained: path to a checkpoint, or ``None`` to skip weight loading.
        **kwargs: extra keyword arguments forwarded unchanged to the
            ``RoboticDiffusionTransformerModel`` constructor.

    Returns:
        The constructed ``RoboticDiffusionTransformerModel``.
    """
    rdt = RoboticDiffusionTransformerModel(args, **kwargs)
    if pretrained is None:
        return rdt
    rdt.load_pretrained_weights(pretrained)
    return rdt
|
|
|
|
# Per-dimension min/max bounds used to rescale proprioceptive states into
# [-1, 1] before inference and to map predicted actions back out of [-1, 1].
# Each list has 8 entries, laid out in the same order as MANISKILL_INDICES.
DATA_STAT = {
    'state_min': [
        -0.7463043928146362,
        -0.0801204964518547,
        -0.4976441562175751,
        -2.657780647277832,
        -0.5742632150650024,
        1.8309762477874756,
        -2.2423808574676514,
        0.0,
    ],
    'state_max': [
        0.7645499110221863,
        1.4967026710510254,
        0.4650936424732208,
        -0.3866899907588959,
        0.5505855679512024,
        3.2900545597076416,
        2.5737812519073486,
        0.03999999910593033,
    ],
    'action_min': [
        -0.7472005486488342,
        -0.08631071448326111,
        -0.4995281398296356,
        -2.658363103866577,
        -0.5751323103904724,
        1.8290787935256958,
        -2.245187997817993,
        -1.0,
    ],
    'action_max': [
        0.7654682397842407,
        1.4984270334243774,
        0.46786263585090637,
        -0.38181185722351074,
        0.5517147779464722,
        3.291581630706787,
        2.575840711593628,
        1.0,
    ],
}
|
|
class RoboticDiffusionTransformerModel:
    """A wrapper for the RDT model, which handles
    1. Model initialization
    2. Encodings of instructions
    3. Model inference

    It bundles a T5 text encoder, a SigLIP vision tower and an ``RDTRunner``
    policy, and converts between raw joint values and the unified state
    vector using ``DATA_STAT`` and ``MANISKILL_INDICES``.
    """

    def __init__(
        self, args,
        device='cuda',
        dtype=torch.bfloat16,
        image_size=None,
        control_frequency=25,
        pretrained_text_encoder_name_or_path=None,
        pretrained_vision_encoder_name_or_path=None,
    ):
        """Build the encoders and the policy, then switch them to eval mode.

        Args:
            args: nested configuration mapping; the ``"common"``, ``"model"``
                and ``"dataset"`` sections are read by this class.
            device: device on which all sub-models and statistics are placed.
            dtype: weight dtype used for the sub-models and for inference.
            image_size: if not ``None``, every input image is resized to this
                size (passed to ``transforms.Resize``) before preprocessing.
            control_frequency: control frequency passed to the policy at every
                step.
            pretrained_text_encoder_name_or_path: forwarded to ``T5Embedder``.
            pretrained_vision_encoder_name_or_path: forwarded to
                ``SiglipVisionTower``.
        """
        self.args = args
        self.dtype = dtype
        self.image_size = image_size
        self.device = device
        self.control_frequency = control_frequency
        self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
        # The policy needs ``self.vision_model.num_patches``, so it must be
        # created after the vision encoder.
        self.policy = self.get_policy()

        # Normalization bounds, kept on the inference device.
        self.state_min = torch.tensor(DATA_STAT['state_min']).to(device)
        self.state_max = torch.tensor(DATA_STAT['state_max']).to(device)
        self.action_min = torch.tensor(DATA_STAT['action_min']).to(device)
        self.action_max = torch.tensor(DATA_STAT['action_max']).to(device)

        self.reset()

    def get_policy(self):
        """Initialize the model.

        Returns:
            A freshly constructed ``RDTRunner`` configured from ``self.args``.
        """
        common_cfg = self.args["common"]
        model_cfg = self.args["model"]
        max_lang_len = self.args["dataset"]["tokenizer_max_length"]
        num_patches = self.vision_model.num_patches

        # Total number of image tokens the policy is conditioned on.
        img_cond_len = (common_cfg["img_history_size"]
                        * common_cfg["num_cameras"]
                        * num_patches)

        _model = RDTRunner(
            action_dim=common_cfg["state_dim"],
            pred_horizon=common_cfg["action_chunk_size"],
            config=model_cfg,
            lang_token_dim=model_cfg["lang_token_dim"],
            img_token_dim=model_cfg["img_token_dim"],
            state_token_dim=model_cfg["state_token_dim"],
            max_lang_cond_len=max_lang_len,
            img_cond_len=img_cond_len,
            # NOTE(review): the negated sizes below are passed through as-is;
            # their meaning is defined by RDTRunner -- confirm there.
            img_pos_embed_config=[
                ("image", (common_cfg["img_history_size"],
                           common_cfg["num_cameras"],
                           -num_patches)),
            ],
            lang_pos_embed_config=[
                ("lang", -max_lang_len),
            ],
            dtype=self.dtype,
        )

        return _model

    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
        """Create the T5 embedder and return ``(tokenizer, text_encoder)``."""
        text_embedder = T5Embedder(from_pretrained=pretrained_text_encoder_name_or_path,
                                   model_max_length=self.args["dataset"]["tokenizer_max_length"],
                                   device=self.device)
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
        return tokenizer, text_encoder

    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
        """Create the SigLIP tower and return ``(image_processor, vision_encoder)``."""
        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
        image_processor = vision_encoder.image_processor
        return image_processor, vision_encoder

    def reset(self):
        """Set all sub-models to evaluation mode and move them to
        ``self.device`` with ``self.dtype`` weights.
        """
        device = self.device
        weight_dtype = self.dtype
        self.policy.eval()
        self.text_model.eval()
        self.vision_model.eval()

        self.policy = self.policy.to(device, dtype=weight_dtype)
        self.text_model = self.text_model.to(device, dtype=weight_dtype)
        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)

    def load_pretrained_weights(self, pretrained=None):
        """Load policy weights from a ``.pt`` or ``.safetensors`` checkpoint.

        Args:
            pretrained: checkpoint path; ``None`` makes this a no-op.

        Raises:
            NotImplementedError: if the file extension is not recognized.
        """
        if pretrained is None:
            return
        print(f'Loading weights from {pretrained}')
        filename = os.path.basename(pretrained)
        if filename.endswith('.pt'):
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # from trusted sources.
            # Tensors are staged on CPU; load_state_dict then copies them into
            # the policy's existing parameters on their current device.
            checkpoint = torch.load(pretrained, map_location="cpu")
            self.policy.load_state_dict(checkpoint["module"])
        elif filename.endswith('.safetensors'):
            from safetensors.torch import load_model
            load_model(self.policy, pretrained)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")

    def encode_instruction(self, instruction, device=None):
        """Encode string instruction to latent embeddings.

        Args:
            instruction: a string of instruction
            device: device on which the token ids are placed before being fed
                to the text encoder. Defaults to ``self.device`` so the tokens
                land on the same device as the text encoder.

        Returns:
            pred: the text encoder's ``last_hidden_state`` for the instruction,
                a tensor with a leading batch dimension of 1
                (presumably ``(1, num_tokens, hidden_dim)`` -- the hidden size
                depends on the loaded encoder).
        """
        if device is None:
            device = self.device

        tokens = self.text_tokenizer(
            instruction, return_tensors="pt",
            padding="longest",
            truncation=True
        )["input_ids"].to(device)

        # Force a single-batch layout.
        tokens = tokens.view(1, -1)
        with torch.no_grad():
            pred = self.text_model(tokens).last_hidden_state.detach()

        return pred

    def _format_joint_to_state(self, joints):
        """
        Format the robot joint state into the unified state vector.

        Args:
            joints (torch.Tensor): The joint state to be formatted,
                shape [B, N, 8] (7 right-arm joint positions followed by the
                gripper opening, matching ``DATA_STAT`` / ``MANISKILL_INDICES``).

        Returns:
            state (torch.Tensor): The formatted state for RDT
                ([B, N, state_token_dim]); zero everywhere except at
                ``MANISKILL_INDICES``, which hold the joints rescaled to [-1, 1].
            state_elem_mask (torch.Tensor): [B, state_token_dim] mask that is 1
                at ``MANISKILL_INDICES`` and 0 elsewhere.
        """
        # Min-max rescale each dimension to [-1, 1].
        joints = (joints - self.state_min) / (self.state_max - self.state_min) * 2 - 1
        B, N, _ = joints.shape
        state_dim = self.args["model"]["state_token_dim"]
        state = torch.zeros(
            (B, N, state_dim),
            device=joints.device, dtype=joints.dtype
        )
        # Scatter the used dimensions into the unified state vector.
        state[:, :, MANISKILL_INDICES] = joints
        state_elem_mask = torch.zeros(
            (B, state_dim),
            device=joints.device, dtype=joints.dtype
        )
        state_elem_mask[:, MANISKILL_INDICES] = 1
        return state, state_elem_mask

    def _unformat_action_to_joint(self, action):
        """Extract joint commands from a unified action tensor.

        Args:
            action (torch.Tensor): 3-D tensor whose last dimension is the
                unified state/action vector.

        Returns:
            torch.Tensor: the entries at ``MANISKILL_INDICES``, mapped from
                [-1, 1] back to the ``action_min`` .. ``action_max`` range.
        """
        action_indices = MANISKILL_INDICES
        joints = action[:, :, action_indices]

        # Inverse of the min-max rescaling, using the action statistics.
        joints = (joints + 1) / 2 * (self.action_max - self.action_min) + self.action_min

        return joints

    @staticmethod
    def _expand_to_square(pil_img, background_color):
        """Center ``pil_img`` on a square canvas filled with ``background_color``.

        Square inputs are returned unchanged.
        """
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        # Only the shorter axis gets a non-zero offset.
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    def _preprocess_image(self, image, background_image, fill_color, resize):
        """Turn one (possibly missing) image into a pixel-value tensor.

        Args:
            image: a PIL image, or ``None`` for a missing frame.
            background_image: uint8 array used in place of a missing frame.
            fill_color: RGB tuple used to pad non-square images.
            resize: a ``transforms.Resize`` instance, or ``None`` to skip
                resizing.

        Returns:
            The ``pixel_values`` tensor produced by the image processor.
        """
        if image is None:
            # Substitute a flat background for a missing camera frame.
            image = Image.fromarray(background_image)

        if resize is not None:
            image = resize(image)

        if self.args["dataset"].get("auto_adjust_image_brightness", False):
            # Mean intensity in [0, 1], normalized over three channels.
            pixels = np.asarray(image)
            num_pixels = pixels.shape[0] * pixels.shape[1]
            average_brightness = pixels.sum(dtype=np.int64) / (num_pixels * 255.0 * 3)
            if average_brightness <= 0.15:
                image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

        if self.args["dataset"].get("image_aspect_ratio", "pad") == 'pad':
            image = self._expand_to_square(image, fill_color)

        return self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

    @torch.no_grad()
    def step(self, proprio, images, text_embeds):
        """
        Predict a chunk of joint actions from the current observation.

        Args:
            proprio: proprioceptive states; a 2-D tensor whose last dimension
                matches ``DATA_STAT`` (8 values). A batch dimension is added
                here and only the last time step is fed to the policy.
            images: RGB images (PIL); ``None`` entries are replaced by a
                background image filled with the image processor's mean color.
            text_embeds: instruction embeddings, e.g. the output of
                ``encode_instruction``.

        Returns:
            action: predicted action as a float32 tensor, restricted to
                ``MANISKILL_INDICES`` and rescaled to the action range.
        """
        device = self.device
        dtype = self.dtype

        # Loop-invariant preprocessing inputs, computed once per step.
        fill_color = tuple(int(x * 255) for x in self.image_processor.image_mean)
        background_image = np.full(
            (self.image_processor.size["height"],
             self.image_processor.size["width"], 3),
            fill_color, dtype=np.uint8
        )
        # Resize target is ``self.image_size`` (the constructor argument).
        resize = None
        if self.image_size is not None:
            resize = transforms.Resize(self.image_size)

        image_tensor_list = [
            self._preprocess_image(image, background_image, fill_color, resize)
            for image in images
        ]

        image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)

        # Flatten all image tokens into one sequence with batch size 1.
        image_embeds = self.vision_model(image_tensor).detach()
        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

        # Normalize the proprioception and embed it in the unified state vector.
        joints = proprio.to(device).unsqueeze(0)
        states, state_elem_mask = self._format_joint_to_state(joints)
        states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
        # Only the most recent state is used as the state token.
        states = states[:, -1:, :]
        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

        text_embeds = text_embeds.to(device, dtype=dtype)

        trajectory = self.policy.predict_action(
            lang_tokens=text_embeds,
            # Every language token is attended to (no padding mask).
            lang_attn_mask=torch.ones(
                text_embeds.shape[:2], dtype=torch.bool,
                device=text_embeds.device),
            img_tokens=image_embeds,
            state_tokens=states,
            action_mask=state_elem_mask.unsqueeze(1),
            ctrl_freqs=ctrl_freqs
        )
        trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)

        return trajectory
|
|