# -*- coding: utf-8 -*- # # This file is part of UniRig. # # This file is derived from https://github.com/NeuralCarver/Michelangelo # # Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors # Copyright (c) 2025 VAST-AI-Research and contributors. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . import torch import torch.nn as nn from typing import Optional, Union from einops import repeat import math from torch_cluster import fps import random import time import numpy as np from ..modules import checkpoint from ..modules.embedder import FourierEmbedder from ..modules.transformer_blocks import ( ResidualCrossAttentionBlock, Transformer ) from .tsal_base import ShapeAsLatentModule class CrossAttentionEncoder(nn.Module): def __init__(self, *, device: Optional[torch.device], dtype: Optional[torch.dtype], num_latents: int, fourier_embedder: FourierEmbedder, point_feats: int, width: int, heads: int, layers: int, init_scale: float = 0.25, qkv_bias: bool = True, flash: bool = False, use_ln_post: bool = False, use_checkpoint: bool = False, query_method: bool = False, use_full_input: bool = True, token_num: int = 256, no_query: bool=False): super().__init__() self.query_method = query_method self.token_num = token_num self.use_full_input = use_full_input self.use_checkpoint = use_checkpoint self.num_latents = num_latents if no_query: self.query = None else: self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02) self.fourier_embedder = fourier_embedder self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype) self.cross_attn = ResidualCrossAttentionBlock( device=device, dtype=dtype, width=width, heads=heads, init_scale=init_scale, qkv_bias=qkv_bias, flash=flash, ) self.self_attn = Transformer( device=device, dtype=dtype, n_ctx=num_latents, width=width, layers=layers, heads=heads, init_scale=init_scale, qkv_bias=qkv_bias, flash=flash, use_checkpoint=False ) if use_ln_post: self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device) else: self.ln_post = None def _forward(self, pc, feats): """ Args: pc (torch.FloatTensor): [B, N, 3] feats (torch.FloatTensor or None): [B, N, C] Returns: """ if self.query_method: token_num = self.num_latents bs = pc.shape[0] data = self.fourier_embedder(pc) if feats is not None: data = torch.cat([data, feats], dim=-1) data = self.input_proj(data) query = repeat(self.query, "m c -> b m c", b=bs) latents = self.cross_attn(query, data) latents = self.self_attn(latents) if self.ln_post is not None: latents = self.ln_post(latents) pre_pc = None else: if isinstance(self.token_num, int): token_num = self.token_num else: token_num = random.choice(self.token_num) if self.training: rng = np.random.default_rng() else: rng = np.random.default_rng(seed=0) ind = rng.choice(pc.shape[1], token_num * 4, replace=token_num * 4 > pc.shape[1]) pre_pc = pc[:,ind,:] pre_feats = feats[:,ind,:] B, N, D = pre_pc.shape C = pre_feats.shape[-1] ###### fps pos = pre_pc.view(B*N, D) pos_feats = pre_feats.view(B*N, C) batch = torch.arange(B).to(pc.device) batch = torch.repeat_interleave(batch, N) idx = fps(pos, batch, ratio=1. / 4, random_start=self.training) sampled_pc = pos[idx] sampled_pc = sampled_pc.view(B, -1, 3) sampled_feats = pos_feats[idx] sampled_feats = sampled_feats.view(B, -1, C) ###### if self.use_full_input: data = self.fourier_embedder(pc) else: data = self.fourier_embedder(pre_pc) if feats is not None: if not self.use_full_input: feats = pre_feats data = torch.cat([data, feats], dim=-1) data = self.input_proj(data) sampled_data = self.fourier_embedder(sampled_pc) if feats is not None: sampled_data = torch.cat([sampled_data, sampled_feats], dim=-1) sampled_data = self.input_proj(sampled_data) latents = self.cross_attn(sampled_data, data) latents = self.self_attn(latents) if self.ln_post is not None: latents = self.ln_post(latents) pre_pc = torch.cat([pre_pc, pre_feats], dim=-1) return latents, pc, token_num, pre_pc def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): """ Args: pc (torch.FloatTensor): [B, N, 3] feats (torch.FloatTensor or None): [B, N, C] Returns: dict """ return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint) class CrossAttentionDecoder(nn.Module): def __init__(self, *, device: Optional[torch.device], dtype: Optional[torch.dtype], num_latents: int, out_channels: int, fourier_embedder: FourierEmbedder, width: int, heads: int, init_scale: float = 0.25, qkv_bias: bool = True, flash: bool = False, use_checkpoint: bool = False, mlp_width_scale: int = 4, supervision_type: str = 'occupancy'): super().__init__() self.use_checkpoint = use_checkpoint self.fourier_embedder = fourier_embedder self.supervision_type = supervision_type self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype) self.cross_attn_decoder = ResidualCrossAttentionBlock( device=device, dtype=dtype, n_data=num_latents, width=width, heads=heads, init_scale=init_scale, qkv_bias=qkv_bias, flash=flash, mlp_width_scale=mlp_width_scale, ) self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype) if self.supervision_type == 'occupancy-sdf': self.output_proj_sdf = nn.Linear(width, out_channels, device=device, dtype=dtype) def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): if next(self.query_proj.parameters()).dtype == torch.float16: queries = queries.half() latents = latents.half() # print(f"queries: {queries.dtype}, {queries.device}") # print(f"latents: {latents.dtype}, {latents.device}"z) queries = self.query_proj(self.fourier_embedder(queries)) x = self.cross_attn_decoder(queries, latents) x = self.ln_post(x) x_1 = self.output_proj(x) if self.supervision_type == 'occupancy-sdf': x_2 = self.output_proj_sdf(x) return x_1, x_2 else: return x_1 def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint) class ShapeAsLatentPerceiver(ShapeAsLatentModule): def __init__(self, *, device: Optional[torch.device], dtype: Optional[torch.dtype], num_latents: int, point_feats: int = 0, embed_dim: int = 0, num_freqs: int = 8, include_pi: bool = True, width: int, heads: int, num_encoder_layers: int, num_decoder_layers: int, decoder_width: Optional[int] = None, init_scale: float = 0.25, qkv_bias: bool = True, flash: bool = False, use_ln_post: bool = False, use_checkpoint: bool = False, supervision_type: str = 'occupancy', query_method: bool = False, token_num: int = 256, grad_type: str = "numerical", grad_interval: float = 0.005, use_full_input: bool = True, freeze_encoder: bool = False, decoder_mlp_width_scale: int = 4, residual_kl: bool = False, ): super().__init__() self.use_checkpoint = use_checkpoint self.num_latents = num_latents assert grad_type in ["numerical", "analytical"] self.grad_type = grad_type self.grad_interval = grad_interval self.supervision_type = supervision_type self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) init_scale = init_scale * math.sqrt(1.0 / width) self.encoder = CrossAttentionEncoder( device=device, dtype=dtype, fourier_embedder=self.fourier_embedder, num_latents=num_latents, point_feats=point_feats, width=width, heads=heads, layers=num_encoder_layers, init_scale=init_scale, qkv_bias=qkv_bias, flash=flash, use_ln_post=use_ln_post, use_checkpoint=use_checkpoint, query_method=query_method, use_full_input=use_full_input, token_num=token_num ) self.embed_dim = embed_dim self.residual_kl = residual_kl if decoder_width is None: decoder_width = width if embed_dim > 0: # VAE embed self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype) self.post_kl = nn.Linear(embed_dim, decoder_width, device=device, dtype=dtype) self.latent_shape = (num_latents, embed_dim) if self.residual_kl: assert self.post_kl.out_features % self.post_kl.in_features == 0 assert self.pre_kl.in_features % self.pre_kl.out_features == 0 else: self.latent_shape = (num_latents, width) self.transformer = Transformer( device=device, dtype=dtype, n_ctx=num_latents, width=decoder_width, layers=num_decoder_layers, heads=heads, init_scale=init_scale, qkv_bias=qkv_bias, flash=flash, use_checkpoint=use_checkpoint ) # geometry decoder self.geo_decoder = CrossAttentionDecoder( device=device, dtype=dtype, fourier_embedder=self.fourier_embedder, out_channels=1, num_latents=num_latents, width=decoder_width, heads=heads, init_scale=init_scale, qkv_bias=qkv_bias, flash=flash, use_checkpoint=use_checkpoint, supervision_type=supervision_type, mlp_width_scale=decoder_mlp_width_scale ) if freeze_encoder: for p in self.encoder.parameters(): p.requires_grad = False for p in self.pre_kl.parameters(): p.requires_grad = False print("freeze encoder and pre kl") def forward(self, pc: torch.FloatTensor, feats: torch.FloatTensor, volume_queries: torch.FloatTensor, sample_posterior: bool = True): """ Args: pc (torch.FloatTensor): [B, N, 3] feats (torch.FloatTensor or None): [B, N, C] volume_queries (torch.FloatTensor): [B, P, 3] sample_posterior (bool): Returns: logits (torch.FloatTensor): [B, P] center_pos (torch.FloatTensor): [B, M, 3] posterior (DiagonalGaussianDistribution or None). """ latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) latents = self.decode(latents) logits = self.query_geometry(volume_queries, latents) return logits, center_pos, posterior class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver): def __init__(self, *, device: Optional[torch.device], dtype: Optional[str], num_latents: int, point_feats: int = 0, embed_dim: int = 0, num_freqs: int = 8, include_pi: bool = True, width: int, heads: int, num_encoder_layers: int, num_decoder_layers: int, decoder_width: Optional[int] = None, init_scale: float = 0.25, qkv_bias: bool = True, flash: bool = False, use_ln_post: bool = False, use_checkpoint: bool = False, supervision_type: str = 'occupancy', grad_type: str = "numerical", grad_interval: float = 0.005, query_method: bool = False, use_full_input: bool = True, token_num: int = 256, freeze_encoder: bool = False, decoder_mlp_width_scale: int = 4, residual_kl: bool = False, ): MAP_DTYPE = { 'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16, } if dtype is not None: dtype = MAP_DTYPE[dtype] super().__init__( device=device, dtype=dtype, num_latents=1 + num_latents, point_feats=point_feats, embed_dim=embed_dim, num_freqs=num_freqs, include_pi=include_pi, width=width, decoder_width=decoder_width, heads=heads, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, init_scale=init_scale, qkv_bias=qkv_bias, flash=flash, use_ln_post=use_ln_post, use_checkpoint=use_checkpoint, supervision_type=supervision_type, grad_type=grad_type, grad_interval=grad_interval, query_method=query_method, token_num=token_num, use_full_input=use_full_input, freeze_encoder=freeze_encoder, decoder_mlp_width_scale=decoder_mlp_width_scale, residual_kl=residual_kl, ) self.width = width def encode(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None, sample_posterior: bool = True, only_shape: bool=False): """ Args: pc (torch.FloatTensor): [B, N, 3] feats (torch.FloatTensor or None): [B, N, c] sample_posterior (bool): Returns: shape_embed (torch.FloatTensor) kl_embed (torch.FloatTensor): posterior (DiagonalGaussianDistribution or None): """ shape_embed, latents, token_num, pre_pc = self.encode_latents(pc, feats) if only_shape: return shape_embed kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior) return shape_embed, kl_embed, posterior, token_num, pre_pc def encode_latents(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): x, _, token_num, pre_pc = self.encoder(pc, feats) shape_embed = x[:, 0] # latents = x[:, 1:] # use all tokens latents = x return shape_embed, latents, token_num, pre_pc def forward(self, pc: torch.FloatTensor, feats: torch.FloatTensor, volume_queries: torch.FloatTensor, sample_posterior: bool = True): raise NotImplementedError() ##################################################### # a simplified verstion of perceiver encoder ##################################################### class ShapeAsLatentPerceiverEncoder(ShapeAsLatentModule): def __init__(self, *, device: Optional[torch.device], dtype: Optional[Union[torch.dtype, str]], num_latents: int, point_feats: int = 0, embed_dim: int = 0, num_freqs: int = 8, include_pi: bool = True, width: int, heads: int, num_encoder_layers: int, init_scale: float = 0.25, qkv_bias: bool = True, flash: bool = False, use_ln_post: bool = False, use_checkpoint: bool = False, supervision_type: str = 'occupancy', query_method: bool = False, token_num: int = 256, grad_type: str = "numerical", grad_interval: float = 0.005, use_full_input: bool = True, freeze_encoder: bool = False, residual_kl: bool = False, ): super().__init__() MAP_DTYPE = { 'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16, } if dtype is not None and isinstance(dtype, str): dtype = MAP_DTYPE[dtype] self.use_checkpoint = use_checkpoint self.num_latents = num_latents assert grad_type in ["numerical", "analytical"] self.grad_type = grad_type self.grad_interval = grad_interval self.supervision_type = supervision_type self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) init_scale = init_scale * math.sqrt(1.0 / width) self.encoder = CrossAttentionEncoder( device=device, dtype=dtype, fourier_embedder=self.fourier_embedder, num_latents=num_latents, point_feats=point_feats, width=width, heads=heads, layers=num_encoder_layers, init_scale=init_scale, qkv_bias=qkv_bias, flash=flash, use_ln_post=use_ln_post, use_checkpoint=use_checkpoint, query_method=query_method, use_full_input=use_full_input, token_num=token_num, no_query=True, ) self.embed_dim = embed_dim self.residual_kl = residual_kl if freeze_encoder: for p in self.encoder.parameters(): p.requires_grad = False print("freeze encoder") self.width = width def encode_latents(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): x, _, token_num, pre_pc = self.encoder(pc, feats) shape_embed = x[:, 0] latents = x return shape_embed, latents, token_num, pre_pc def forward(self): raise NotImplementedError()