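# Scraped from a repository file view; page metadata kept as a comment so the module parses:
# author: captaincobb | commit ce16420 ("updates for cube v0.5") | file size 23.9 kB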
from typing import Optional, Tuple
import torch
from tqdm import tqdm
from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast
from cube3d.inference.logits_postprocesses import process_logits
from cube3d.inference.utils import load_config, load_model_weights, parse_structured
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
from cube3d.model.gpt.dual_stream_roformer import DualStreamRoformer
from cube3d.model.transformers.cache import Cache
class Engine:
    """
    Text-to-shape inference engine.

    Combines a CLIP text encoder, the DualStreamRoformer GPT that autoregressively
    predicts shape-code token IDs, and the OneDAutoEncoder that decodes those codes
    into mesh geometry.
    """

    def __init__(
        self,
        config_path: str,
        gpt_ckpt_path: str,
        shape_ckpt_path: str,
        device: torch.device,
    ):
        """
        Initializes the inference engine with the given configuration and checkpoint paths.

        Args:
            config_path (str): Path to the configuration file.
            gpt_ckpt_path (str): Path to the GPT model checkpoint file.
            shape_ckpt_path (str): Path to the shape model checkpoint file.
            device (torch.device): The device to run the models on (e.g., 'cpu' or 'cuda').

        Attributes:
            cfg: Configuration loaded from ``config_path`` via ``load_config``.
            device (torch.device): The device to run the models on.
            gpt_model (DualStreamRoformer): The GPT model, loaded with weights, in eval mode.
            shape_model (OneDAutoEncoder): The shape model, loaded with weights, in eval mode.
            text_model (CLIPTextModelWithProjection): The pretrained text encoder, in eval mode.
            text_tokenizer (CLIPTokenizerFast): The tokenizer for the text model.
            max_new_tokens (int): Number of tokens generated per prompt
                (the shape model's ``num_encoder_latents``).
            min_id (int): Lowest logit index considered when sampling shape codes.
            max_id (int): Exclusive upper logit index considered when sampling
                (the shape model's ``num_codes``).
        """
        self.cfg = load_config(config_path)
        self.device = device

        self.gpt_model = DualStreamRoformer(
            parse_structured(DualStreamRoformer.Config, self.cfg.gpt_model)
        )
        load_model_weights(
            self.gpt_model,
            gpt_ckpt_path,
        )
        self.gpt_model = self.gpt_model.eval().to(self.device)

        self.shape_model = OneDAutoEncoder(
            parse_structured(OneDAutoEncoder.Config, self.cfg.shape_model)
        )
        load_model_weights(
            self.shape_model,
            shape_ckpt_path,
        )
        self.shape_model = self.shape_model.eval().to(self.device)

        # Copy the VQ codebook into the GPT: the projected codebook vectors overwrite
        # the leading rows of the token-embedding table.
        with torch.no_grad():
            codebook = self.shape_model.bottleneck.block.get_codebook()
            codebook = self.gpt_model.shape_proj(codebook).detach()
            self.gpt_model.transformer.wte.weight.data[: codebook.shape[0]] = codebook

        self.text_model = CLIPTextModelWithProjection.from_pretrained(
            self.cfg.text_model_pretrained_model_name_or_path,
            force_download=False,
            device_map=self.device,
        ).eval()
        self.text_tokenizer = CLIPTokenizerFast.from_pretrained(
            self.cfg.text_model_pretrained_model_name_or_path
        )

        self.max_new_tokens = self.shape_model.cfg.num_encoder_latents
        self.min_id = 0
        self.max_id = self.shape_model.cfg.num_codes

    @torch.inference_mode()
    def prepare_conditions_with_bbox(
        self,
        cond: torch.Tensor,
        bounding_box_tensor: Optional[torch.Tensor] = None,
    ):
        """
        Appends a bounding-box embedding token to the condition sequence.

        If the GPT model has a ``bbox_proj`` module, the bounding box is projected
        to one extra condition token and concatenated after ``cond`` along the
        sequence dimension. Models without ``bbox_proj`` get ``cond`` back unchanged.

        Args:
            cond (torch.Tensor): Condition embeddings of shape (B, seq_len, dim).
            bounding_box_tensor (Optional[torch.Tensor], optional): Bounding-box sizes
                as (x, y, z). Accepts shape (3,), (1, 3) or (B, 3); a single box is
                broadcast across the batch. It is cast to ``cond``'s dtype and moved
                to ``self.device``, so integer sizes are accepted. If None, an
                all-zero box is used.

        Returns:
            torch.Tensor: ``cond`` with the bounding-box token appended, shape
            (B, seq_len + 1, dim), or ``cond`` itself when the model has no
            bounding-box projection.
        """
        if not hasattr(self.gpt_model, "bbox_proj"):
            return cond

        batch_size = cond.shape[0]
        if bounding_box_tensor is None:
            bounding_box_tensor = torch.zeros(
                (batch_size, 3), dtype=cond.dtype, device=self.device
            )
        else:
            # Normalize to (N, 3) in the condition dtype (int sizes would otherwise
            # hit bbox_proj as int64), then broadcast a single box over the batch.
            bounding_box_tensor = torch.atleast_2d(bounding_box_tensor).to(
                device=self.device, dtype=cond.dtype
            )
            if bounding_box_tensor.shape[0] == 1 and batch_size > 1:
                bounding_box_tensor = bounding_box_tensor.expand(batch_size, -1)

        bbox_emb = self.gpt_model.bbox_proj(bounding_box_tensor).unsqueeze(dim=1)
        return torch.cat([cond, bbox_emb], dim=1)

    @torch.inference_mode()
    def prepare_inputs(
        self,
        prompts: list[str],
        guidance_scale: float,
        bounding_box_xyz: Optional[Tuple[float]] = None,
    ):
        """
        Prepares the input embeddings for the model based on the provided prompts and guidance scale.

        Args:
            prompts (list[str]): A list of prompt strings to be encoded.
            guidance_scale (float): Classifier-free guidance scale. If greater than 0.0,
                the batch is doubled: conditional entries first, then unconditional
                entries built from empty prompts and an all-zero bounding box.
            bounding_box_xyz (Optional[Tuple[float]], optional): The size of the bounding box
                for generation as (x, y, z) dimensions. Each value should lie between 0 and
                1.925 (not validated here). The same box is applied to every prompt. If None,
                uses default bounding box sizing.

        Returns:
            tuple: A tuple containing:
                - embed (torch.Tensor): The encoded BOS input embeddings.
                - cond (torch.Tensor): The condition embeddings, with unconditional
                  embeddings appended along the batch dimension if guidance_scale > 0.0.
        """
        prompt_embeds = self.run_clip(prompts)
        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            embed = self.encode_input(prompt_embeds, self.gpt_model.shape_bos_id)

        if bounding_box_xyz is not None:
            # Explicit float dtype so integer sizes such as (1, 1, 1) are accepted.
            cond_bbox = torch.atleast_2d(
                torch.as_tensor(bounding_box_xyz, dtype=torch.float32, device=self.device)
            )
            uncond_bbox = torch.zeros_like(cond_bbox)
        else:
            cond_bbox = None
            uncond_bbox = None

        cond = self.prepare_conditions_with_bbox(prompt_embeds, cond_bbox)
        if guidance_scale > 0.0:
            embed = torch.cat([embed, embed], dim=0)
            uncond_embeds = self.run_clip([""] * len(prompts))
            uncond = self.prepare_conditions_with_bbox(uncond_embeds, uncond_bbox)
            cond = torch.cat([cond, uncond], dim=0)
        return embed, cond

    @torch.inference_mode()
    def run_clip(self, text_inputs):
        """
        Tokenizes and encodes text with the CLIP text model, then projects the
        result with ``gpt_model.encode_text``.

        Args:
            text_inputs (str or List[str]): The input text or list of texts to be processed.

        Returns:
            torch.Tensor: The projected text embeddings (pooled embedding when
            ``gpt_model.cfg.use_pooled_text_embed`` is set, otherwise the full
            last hidden state).
        """
        text_inputs = self.text_tokenizer(
            text_inputs,
            max_length=self.text_tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
            # use full precision for text encoder
            with torch.autocast(device_type=self.device.type, enabled=False):
                encoded = self.text_model(**text_inputs)
            if self.gpt_model.cfg.use_pooled_text_embed:
                embed = encoded.text_embeds.unsqueeze(1)  # [bs, 1, 512]
            else:
                embed = encoded.last_hidden_state  # [bs, 77, 512]
            embed = self.gpt_model.encode_text(embed)
        return embed

    @torch.inference_mode()
    def encode_input(self, inputs: torch.Tensor, bos: int):
        """
        Encodes one beginning-of-sequence (BOS) token per batch entry.

        Args:
            inputs (torch.Tensor): Tensor whose first dimension gives the batch size.
            bos (int): The beginning of sequence token ID.

        Returns:
            torch.Tensor: The BOS token embeddings for a (batch, 1) token tensor.
        """
        b = inputs.shape[0]
        bos_embed = self.gpt_model.encode_token(
            torch.full(
                (b, 1),
                fill_value=bos,
                dtype=torch.long,
                device=self.device,
            )
        )
        return bos_embed

    @torch.inference_mode()
    def run_gpt(
        self,
        prompts: list[str],
        use_kv_cache: bool,
        guidance_scale: float = 3.0,
        top_p: Optional[float] = None,
        bounding_box_xyz: Optional[Tuple[float]] = None,
    ):
        """
        Autoregressively generates shape-code token IDs for the given text prompts.

        Args:
            prompts (list[str]): A list of text prompts to condition on.
            use_kv_cache (bool): Whether to use key-value caching for faster generation.
            guidance_scale (float, optional): Classifier-free guidance scale. The effective
                scale decays linearly from ``guidance_scale`` to 0 over the generated
                tokens. Default is 3.0.
            top_p (float, optional): The cumulative probability threshold for nucleus sampling.
                If None, argmax selection is performed (deterministic generation). Otherwise,
                the smallest set of tokens with cumulative probability ≥ top_p is kept
                (stochastic generation).
            bounding_box_xyz (Optional[Tuple[float]], optional): The size of the bounding box
                for generation as (x, y, z) dimensions. Each value should lie between 0 and
                1.925. If None, uses default bounding box sizing.

        Returns:
            torch.Tensor: The generated token IDs, one column per generation step
            (``max_new_tokens`` columns).
        """
        embed, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz)

        output_ids = []
        batch_size, input_seq_len, dim = embed.shape
        max_seq_len = input_seq_len + self.max_new_tokens
        embed_buffer = torch.zeros(
            (batch_size, max_seq_len, dim), dtype=embed.dtype, device=embed.device
        )
        embed_buffer[:, :input_seq_len, :].copy_(embed)

        cond_len = cond.shape[1]
        kv_cache = None
        if use_kv_cache:
            kv_cache = self.gpt_model.init_kv_cache(
                batch_size,
                cond_len,
                self.max_new_tokens + 1,  # +1 for the BOS token
                torch.bfloat16,
                embed.device,
            )

        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            for i in tqdm(range(self.max_new_tokens), desc="generating"):
                curr_pos_id = torch.tensor([i], dtype=torch.long, device=embed.device)
                logits = self.gpt_model(
                    embed_buffer,
                    cond,
                    kv_cache=kv_cache,
                    curr_pos_id=curr_pos_id if use_kv_cache else None,
                    decode=(i > 0) if use_kv_cache else False,
                )
                if use_kv_cache:
                    # Presumably the cached path returns only the current position.
                    logits = logits[:, 0, ...]
                else:
                    # Full forward: position i holds the most recent token because the
                    # prefix is the single BOS embedding produced by encode_input.
                    logits = logits[:, i, ...]

                logits = logits[..., self.min_id : self.max_id]
                if guidance_scale > 0.0:
                    # Conditional half first, unconditional half second (prepare_inputs).
                    logits, uncond_logits = logits.float().chunk(2, dim=0)
                    gamma = (
                        guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
                    )
                    logits = (1 + gamma) * logits - gamma * uncond_logits

                next_id = process_logits(
                    logits,
                    top_p=top_p,
                )
                output_ids.append(next_id)

                next_embed = self.gpt_model.encode_token(next_id)
                if guidance_scale > 0.0:
                    # Feed the same token to both the conditional and unconditional rows.
                    next_embed = torch.cat([next_embed, next_embed], dim=0)
                embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))

        return torch.cat(output_ids, dim=1)

    @torch.inference_mode()
    def run_shape_decode(
        self,
        output_ids: torch.Tensor,
        resolution_base: float = 8.0,
        chunk_size: int = 100_000,
    ):
        """
        Decodes the shape from the given output IDs and extracts the geometry.

        The caller's ``output_ids`` is not modified.

        Args:
            output_ids (torch.Tensor): The generated token IDs; only the first
                ``num_encoder_latents`` columns are used, clamped to valid code IDs.
            resolution_base (float, optional): The base resolution for geometry extraction.
                Defaults to 8.0.
            chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.

        Returns:
            The first value returned by ``shape_model.extract_geometry``: the vertices
            and faces of the mesh.
        """
        num_latents = self.shape_model.cfg.num_encoder_latents
        # Out-of-place clamp so the caller's tensor is not modified through the slice view.
        shape_ids = (
            output_ids[:, :num_latents, ...]
            .clamp(0, self.shape_model.cfg.num_codes - 1)
            .reshape(-1, num_latents)
        )
        latents = self.shape_model.decode_indices(shape_ids)
        mesh_v_f, _ = self.shape_model.extract_geometry(
            latents,
            resolution_base=resolution_base,
            chunk_size=chunk_size,
            use_warp=True,
        )
        return mesh_v_f

    @torch.inference_mode()
    def t2s(
        self,
        prompts: list[str],
        use_kv_cache: bool,
        guidance_scale: float = 3.0,
        resolution_base: float = 8.0,
        chunk_size: int = 100_000,
        top_p: Optional[float] = None,
        bounding_box_xyz: Optional[Tuple[float]] = None,
    ):
        """
        Generates a 3D mesh from text prompts using a GPT model and shape decoder.

        Args:
            prompts (list[str]): A list of text prompts to guide the generation.
            use_kv_cache (bool): Whether to use key-value caching for the GPT model.
            guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0.
            resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0.
            chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000.
            top_p (float, optional): The cumulative probability threshold for nucleus sampling.
                If None, argmax selection is performed (deterministic generation). Otherwise,
                the smallest set of tokens with cumulative probability ≥ top_p is kept
                (stochastic generation).
            bounding_box_xyz (Tuple[float] | None, optional): The size of the bounding box for
                the generated mesh as (x, y, z) dimensions. Each value should lie between 0 and
                1.925. If None, uses default bounding box sizing.

        Returns:
            mesh_v_f: The generated 3D mesh vertices and faces.
        """
        output_ids = self.run_gpt(
            prompts, use_kv_cache, guidance_scale, top_p, bounding_box_xyz
        )
        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            mesh_v_f = self.run_shape_decode(output_ids, resolution_base, chunk_size)
        return mesh_v_f
class EngineFast(Engine):
    """
    CUDA-graph accelerated variant of Engine.

    The GPT model's single-token decode step is captured once into a CUDA graph
    and replayed for every generated token after an eager prefill. The graph is
    captured for one prompt with classifier-free guidance (a conditional plus an
    unconditional row), so ``run_gpt`` only supports that configuration.
    """

    def __init__(
        self,
        config_path: str,
        gpt_ckpt_path: str,
        shape_ckpt_path: str,
        device: torch.device,
    ):
        """
        Initializes the inference engine with the given configuration and checkpoint paths,
        then warms up the GPT model and captures the decode-step CUDA graph.

        Args:
            config_path (str): Path to the configuration file.
            gpt_ckpt_path (str): Path to the GPT checkpoint file.
            shape_ckpt_path (str): Path to the shape checkpoint file.
            device (torch.device): The device to run the inference on; must be a CUDA device.

        Raises:
            AssertionError: If ``device`` is not a CUDA device.
        """
        assert (
            device.type == "cuda"
        ), "EngineFast is only supported on cuda devices, please use Engine on non-cuda devices"
        super().__init__(config_path, gpt_ckpt_path, shape_ckpt_path, device)

        # CUDA Graph params: static buffers read/written by the captured graph.
        self.graph = torch.cuda.CUDAGraph()
        self.embed_buffer = torch.Tensor()
        self.cond_buffer = torch.Tensor()
        self.logits_buffer = torch.Tensor()
        self.curr_pos_id = torch.tensor([0], dtype=torch.long, device=self.device)
        self.kv_cache: list[Cache] = []

        self._warmup_and_capture_graph()

    def _warmup_and_capture_graph(self):
        """
        Warms up the model by running a series of forward passes and captures the CUDA graph for efficient execution.

        This method performs the following steps:
        1. Prepares the input embeddings and conditions using a warmup prompt
           (guidance_scale=3.0, so the buffers hold a conditional + unconditional batch).
        2. Initializes buffers for embeddings and conditions.
        3. Initializes the key-value cache for the GPT model.
        4. Runs one eager prefill followed by decode passes as warmup.
        5. Captures the model's decode forward pass into ``self.graph``; its output
           tensor becomes ``self.logits_buffer``.
        """
        warmup_prompt = "A cube"
        embed, cond = self.prepare_inputs([warmup_prompt], guidance_scale=3.0)
        batch_size, input_seq_len, dim = embed.shape
        max_seq_len = input_seq_len + self.max_new_tokens
        self.embed_buffer = torch.zeros(
            (batch_size, max_seq_len, dim), dtype=embed.dtype, device=self.device
        )
        self.embed_buffer[:, :input_seq_len, :].copy_(embed)

        self.cond_buffer = torch.empty_like(cond)
        self.cond_buffer.copy_(cond)
        cond_len = self.cond_buffer.shape[1]

        # Initialize kv_cache for the first time
        self.kv_cache = self.gpt_model.init_kv_cache(
            batch_size,
            cond_len,
            self.max_new_tokens + 1,  # +1 for the BOS token
            torch.bfloat16,
            self.device,
        )

        num_warmup_passes = 10

        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            self._set_curr_pos_id(0)
            _ = self._prefill_and_return_logits()

            for x in range(1, num_warmup_passes):
                self._set_curr_pos_id(x)
                self.logits_buffer = self.gpt_model(
                    embed=self.embed_buffer,
                    cond=self.cond_buffer,
                    kv_cache=self.kv_cache,
                    curr_pos_id=self.curr_pos_id,
                    decode=True,
                )

        side_stream = torch.cuda.Stream(device=self.device)
        with torch.cuda.graph(self.graph, stream=side_stream):
            with torch.autocast(self.device.type, dtype=torch.bfloat16):
                self.logits_buffer = self.gpt_model(
                    embed=self.embed_buffer,
                    cond=self.cond_buffer,
                    kv_cache=self.kv_cache,
                    curr_pos_id=self.curr_pos_id,
                    decode=True,
                )

    def _reset_kv_cache(self):
        """
        Resets the key-value cache in place by zeroing every cache's
        ``key_states`` and ``value_states``. The tensors themselves are kept
        because the captured CUDA graph refers to them.
        """
        for cache in self.kv_cache:
            cache.key_states.zero_()
            cache.value_states.zero_()

    def _prefill_and_return_logits(self) -> torch.Tensor:
        """
        Resets the key-value cache and runs the prefill forward pass eagerly.

        Returns:
            torch.Tensor: The prefill logits indexed at position 0 of the second
            (sequence) dimension, i.e. ``prefill_logits[:, 0, ...]``.
        """
        self._reset_kv_cache()

        # Prefill is always eager
        prefill_logits = self.gpt_model(
            embed=self.embed_buffer,
            cond=self.cond_buffer,
            kv_cache=self.kv_cache,
            curr_pos_id=self.curr_pos_id,
            decode=False,
        )
        return prefill_logits[:, 0, ...]

    def _set_curr_pos_id(self, pos: int):
        """
        Writes ``pos`` into the static ``curr_pos_id`` tensor in place.

        The tensor is updated rather than replaced because the captured CUDA graph
        reads from this exact storage.

        Args:
            pos (int): The position ID to set.
        """
        self.curr_pos_id.fill_(pos)

    def _sample_next_token(
        self,
        logits: torch.Tensor,
        step: int,
        input_seq_len: int,
        guidance_scale: float,
        top_p: Optional[float],
        output_ids: torch.Tensor,
    ) -> None:
        """
        Applies classifier-free guidance to ``logits``, samples token ``step`` into
        ``output_ids`` and stages its embedding in ``embed_buffer`` for the next pass.
        """
        logits = logits[..., self.min_id : self.max_id]
        # Conditional row first, unconditional row second (see prepare_inputs).
        logits, uncond_logits = logits.float().chunk(2, dim=0)
        # Guidance strength decays linearly to 0, same schedule as the base Engine.
        gamma = guidance_scale * (self.max_new_tokens - step) / self.max_new_tokens
        logits = (1 + gamma) * logits - gamma * uncond_logits

        next_id = process_logits(logits, top_p=top_p)
        output_ids[:, step] = next_id.squeeze()

        next_embed = self.gpt_model.encode_token(next_id)
        # Feed the same token to both the conditional and unconditional rows.
        next_embed = next_embed.repeat(2, 1, 1)
        self.embed_buffer[:, step + input_seq_len, :].copy_(next_embed.squeeze(1))

    @torch.inference_mode()
    def run_gpt(
        self,
        prompts: list[str],
        use_kv_cache: bool,
        guidance_scale: float = 3.0,
        top_p: Optional[float] = None,
        bounding_box_xyz: Optional[Tuple[float]] = None,
    ):
        """
        Generates shape-code token IDs for a single prompt using the captured CUDA graph.

        Args:
            prompts (list[str]): A list of input prompts for the GPT model. Only a single prompt is supported.
            use_kv_cache (bool): Flag indicating whether to use key-value caching. (Currently not used;
                the KV cache is always used.)
            guidance_scale (float, optional): The scale factor for classifier-free guidance; must be
                greater than 0 because the graph was captured for a guided batch. Default is 3.0.
            top_p (float, optional): The cumulative probability threshold for nucleus sampling.
                If None, argmax selection is performed. Otherwise, the smallest
                set of tokens with cumulative probability ≥ top_p is kept.
            bounding_box_xyz (Tuple[float] | None, optional): The size of the bounding box for the
                generated mesh as (x, y, z) dimensions. Each value should lie between 0 and 1.925.
                If None, uses default bounding box sizing.

        Returns:
            torch.Tensor: An int32 tensor of shape (1, max_new_tokens) with the generated token IDs.

        Raises:
            AssertionError: If the batch size is not 1, if guidance_scale <= 0, or if the
                condition shape differs from the one the graph was captured with.
        """
        # Validate before the comparatively expensive text encoding.
        assert len(prompts) == 1, "batch size > 1 not support for EngineFast"
        assert guidance_scale > 0.0, (
            "EngineFast requires guidance_scale > 0 because its CUDA graph was captured "
            "for classifier-free guidance; use Engine for guidance_scale <= 0"
        )

        embed, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz)
        _, input_seq_len, _ = embed.shape

        # Stage this request in the static buffers the captured graph reads from.
        self.embed_buffer.zero_()
        self.embed_buffer[:, :input_seq_len, :].copy_(embed)
        assert self.cond_buffer.shape == cond.shape
        self.cond_buffer.copy_(cond)

        output_ids = torch.zeros(
            (len(prompts), self.max_new_tokens), dtype=torch.int, device=self.device
        )

        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            # Step 0: eager prefill (this also resets the KV cache).
            self._set_curr_pos_id(0)
            logits = self._prefill_and_return_logits()
            self._sample_next_token(
                logits, 0, input_seq_len, guidance_scale, top_p, output_ids
            )

            # Remaining steps: replay the captured single-token decode graph.
            for i in tqdm(range(1, self.max_new_tokens), desc="generating"):
                self._set_curr_pos_id(i)
                self.graph.replay()
                self._sample_next_token(
                    self.logits_buffer[:, 0, ...],
                    i,
                    input_seq_len,
                    guidance_scale,
                    top_p,
                    output_ids,
                )

        return output_ids