import gradio as gr
import os
os.environ['SPCONV_ALGO'] = 'native'
import spaces
from gradio_litmodel3d import LitModel3D
import warp as wp
import subprocess
import re
import torch
import uuid
from threading import Thread
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer, AutoTokenizer
from qwen_vl_utils import process_vision_info
from trellis.pipelines import TrellisImageTo3DPipeline, TrellisTextTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils
import trimesh
from trimesh.exchange.gltf import export_glb
import tempfile
import copy
import plotly.graph_objects as go
from PIL import Image
import plotly.express as px
import random
import open3d as o3d
import imageio
from huggingface_hub import hf_hub_download
import numpy as np

HF_TOKEN = os.environ.get("HF_TOKEN", None)

TMP_DIR = "/tmp/Trellis-demo"
os.makedirs(TMP_DIR, exist_ok=True)
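
# Strip <ref>/<box> grounding markup that the VL model may emit so it does not
# show up in the chat display.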
def _remove_image_special(text):
    text = text.replace('<ref>', '').replace('</ref>', '')
    return re.sub(r'<box>.*?(</box>|$)', '', text)


def is_video_file(filename):
    video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
    return any(filename.lower().endswith(ext) for ext in video_extensions)
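
# Parse the model's "<mesh-start><mesh123>...<mesh-end>" output into a (1, 1024)
# tensor of VQ-VAE codebook indices, repeating the last index if fewer than 1024
# are present.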
def token_to_mesh(full_response):
    d1 = full_response.split("><mesh")
    d2 = []
    for i in range(len(d1)):
        try:
            if d1[i][:5] == "<mesh":
                d2.append(int(d1[i][5:]))
            else:
                d2.append(int(d1[i]))
        except ValueError:
            pass
    while len(d2) < 1024:
        d2.append(d2[-1])
    encoding_indices = torch.tensor(d2).unsqueeze(0)
    return encoding_indices
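
# Write an (N, 3) vertex array to a temporary ASCII PLY file and return its path.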
def save_ply_from_array(verts):
    header = [
        "ply",
        "format ascii 1.0",
        f"element vertex {verts.shape[0]}",
        "property float x",
        "property float y",
        "property float z",
        "end_header",
    ]
    tmpf = tempfile.NamedTemporaryFile(suffix=".ply", delete=False)
    tmpf.write(("\n".join(header) + "\n").encode("utf-8"))
    np.savetxt(tmpf, verts, fmt="%.6f")
    tmpf.flush()
    tmpf.close()
    return tmpf.name
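
# Main chat handler: rebuilds the multimodal message history, streams a response from
# the ShapeLLM model, and, when the response contains mesh tokens, decodes them with
# the 3D VQ-VAE and runs the TRELLIS pipelines to produce a preview video and a
# textured GLB.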
def predict(_chatbot, task_history, viewer_voxel, viewer_mesh, task_new, seed, top_k, top_p, temperature, video_path, simplify, texture_size):
    torch.manual_seed(seed)
    chat_query = _chatbot[-1][0]
    query = task_history[-1][0]
    if len(chat_query) == 0:
        _chatbot.pop()
        task_history.pop()
        return _chatbot, task_history, viewer_voxel, viewer_mesh, task_new, video_path
    print("User: " + _parse_text(query))
    history_cp = copy.deepcopy(task_history)
    full_response = ""
    messages = []
    content = []
    image_lst = []
    # Collect any uploaded file paths (image or video) from the pending turns;
    # used later to decide between text-to-3D and image-to-3D.
    for q, a in task_new:
        if isinstance(q, (tuple, list)):
            image_lst.append(q[0])
    task_new.clear()
    for q, a in history_cp:
        if isinstance(q, (tuple, list)):
            if is_video_file(q[0]):
                content.append({'video': f'file://{q[0]}'})
            else:
                trial_id = uuid.uuid4()
                pipeline_image.preprocess_image_white(Image.open(q[0])).save(f"{TMP_DIR}/{trial_id}.png", "png")
                content.append({'image': f'file://{TMP_DIR}/{trial_id}.png'})
        else:
            content.append({'text': q})
            messages.append({'role': 'user', 'content': content})
            messages.append({'role': 'assistant', 'content': [{'text': a}]})
            content = []
    messages.pop()
    messages = _transform_messages(messages)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt')
    inputs = inputs.to("cuda")
    eos_token_id = [tokenizer.eos_token_id, 159858]
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {'max_new_tokens': 2048, 'streamer': streamer, 'eos_token_id': eos_token_id,
                  'top_k': top_k, 'top_p': top_p, 'temperature': temperature, **inputs}
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    full_response = ""
    encoding_indices = None
    _chatbot[-1] = (_parse_text(chat_query), "")
    for new_text in streamer:
        if new_text:
            if "<mesh" in new_text:
                encoding_indices = token_to_mesh(new_text)
                new_text = new_text.replace("><", ",")[1:-1]
                new_text = new_text.split("mesh-start,")[1].split(",mesh-end")[0]
                new_text = f"mesh-start\n{new_text}\nmesh-end"
            full_response += new_text
            _chatbot[-1] = (_parse_text(chat_query), _parse_text(full_response))
            yield _chatbot, viewer_voxel, viewer_mesh, task_new, video_path
    task_history[-1] = (chat_query, full_response)
    yield _chatbot, viewer_voxel, viewer_mesh, task_new, video_path
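
    # If the model emitted mesh tokens, decode them with the 3D VQ-VAE back into
    # a 64^3 occupancy grid and preview the recovered voxels as a point cloud.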
    if encoding_indices is not None:
        print("processing mesh...")
        recon = vqvae.Decode(encoding_indices.to("cuda"))
        z_s = recon[0].detach().cpu()
        z_s = (z_s > 0) * 1
        indices = torch.nonzero(z_s[0] == 1)
        position_recon = (indices.float() + 0.5) / 64 - 0.5
        fig = make_pointcloud_figure(position_recon)
        yield _chatbot, fig, viewer_mesh, task_new, video_path
        position = position_recon
        coords = ((position + 0.5) * 64).int().contiguous()
        ss = torch.zeros(1, 64, 64, 64, dtype=torch.long)
        ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
        ss = ss.unsqueeze(0)
        coords = torch.argwhere(ss > 0)[:, [0, 2, 3, 4]].int()
        coords = coords.to("cuda")
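        # Condition TRELLIS on the predicted sparse structure: use the text-to-3D
        # pipeline when no image was uploaded, otherwise the image-to-3D pipeline.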
        try:
            print("processing mesh...")
            if len(image_lst) == 0:
                # text to 3d
                with torch.no_grad():
                    prompt = chat_query
                    cond = pipeline_text.get_cond([prompt])
                    slat = pipeline_text.sample_slat(cond, coords)
                    outputs = pipeline_text.decode_slat(slat, ['mesh', 'gaussian'])
                video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
                video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
                video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
                trial_id = uuid.uuid4()
                video_path = f"{TMP_DIR}/{trial_id}.mp4"
                os.makedirs(os.path.dirname(video_path), exist_ok=True)
                imageio.mimsave(video_path, video, fps=15)
                yield _chatbot, fig, viewer_mesh, task_new, video_path
                glb = postprocessing_utils.to_glb(
                    outputs['gaussian'][0],
                    outputs['mesh'][0],
                    simplify=simplify,
                    texture_size=texture_size,
                    verbose=False
                )
                glb.export(f"{TMP_DIR}/{trial_id}.glb")
                print("processing mesh over...")
                yield _chatbot, fig, f"{TMP_DIR}/{trial_id}.glb", task_new, video_path
            else:
                # image to 3d
                with torch.no_grad():
                    img = pipeline_image.preprocess_image(Image.open(image_lst[-1]))
                    cond = pipeline_image.get_cond([img])
                    slat = pipeline_image.sample_slat(cond, coords)
                    outputs = pipeline_image.decode_slat(slat, ['mesh', 'gaussian'])
                video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
                video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
                video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
                trial_id = uuid.uuid4()
                video_path = f"{TMP_DIR}/{trial_id}.mp4"
                os.makedirs(os.path.dirname(video_path), exist_ok=True)
                imageio.mimsave(video_path, video, fps=15)
                yield _chatbot, fig, viewer_mesh, task_new, video_path
                glb = postprocessing_utils.to_glb(
                    outputs['gaussian'][0],
                    outputs['mesh'][0],
                    simplify=simplify,
                    texture_size=texture_size,
                    verbose=False
                )
                glb.export(f"{TMP_DIR}/{trial_id}.glb")
                print("processing mesh over...")
                yield _chatbot, fig, f"{TMP_DIR}/{trial_id}.glb", task_new, video_path
        except Exception as e:
            print(e)
            yield _chatbot, fig, viewer_mesh, task_new, video_path
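
# Re-run generation for the last user turn. It mirrors the submit flow, so it takes
# and yields the same Gradio state as predict().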
def regenerate(_chatbot, task_history, viewer_voxel, viewer_mesh, task_new, seed, top_k, top_p, temperature, video_path, simplify, texture_size):
    if not task_history or task_history[-1][1] is None:
        yield _chatbot, viewer_voxel, viewer_mesh, task_new, video_path
        return
    item = task_history[-1]
    task_history[-1] = (item[0], None)
    chatbot_item = _chatbot.pop(-1)
    if chatbot_item[0] is None:
        _chatbot[-1] = (_chatbot[-1][0], None)
    else:
        _chatbot.append((chatbot_item[0], None))
    yield from predict(_chatbot, task_history, viewer_voxel, viewer_mesh, task_new,
                       seed, top_k, top_p, temperature, video_path, simplify, texture_size)
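
# Convert model output into HTML for the Gradio chatbot: fenced code blocks become
# <pre><code> and special characters inside them are escaped as HTML entities.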
def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text

def add_text_prefix(text):
    text = f"Please generate a 3D asset based on the prompt I provided: {text}"
    return gr.update(value=text)
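
# Expand a list of 1024 VQ-VAE codebook indices into the model's mesh-token string,
# e.g. "<mesh-start><mesh17>...<mesh-end>".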
def token_to_words(token):
    mesh = "<mesh-start>"
    for j in range(1024):
        mesh += f"<mesh{token[j]}>"
    mesh += "<mesh-end>"
    return mesh

def add_text(history, task_history, text, task_new):
    task_text = text
    history = history if history is not None else []
    task_history = task_history if task_history is not None else []
    history = history + [(_parse_text(text), None)]
    task_history = task_history + [(task_text, None)]
    task_new = task_new + [(task_text, None)]
    return history, task_history, task_new
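
# Handle uploads from the file button. 3D files (.obj/.glb) are voxelized, encoded
# into mesh tokens by the VQ-VAE, and pre-filled into the query box for captioning;
# images and videos are appended to the chat history as file turns.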
def add_file(history, task_history, file, task_new, fig, query):
    if file.name.endswith(('.obj', '.glb')):
        position_recon = load_vertices(file.name)  # (N, 3)
        coords = ((torch.from_numpy(position_recon) + 0.5) * 64).int().contiguous()
        ss = torch.zeros(1, 64, 64, 64, dtype=torch.long)
        ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
        token = vqvae.Encode(ss.to(dtype=torch.float32).unsqueeze(0).to("cuda"))
        token = token[0].cpu().numpy().tolist()
        words = token_to_words(token)
        fig = make_pointcloud_figure(position_recon, rotate=True)
        return history, task_history, file.name, task_new, fig, gr.update(
            value=f"{words}\nGive a quick overview of the object represented by this 3D mesh.")
    history = history if history is not None else []
    task_history = task_history if task_history is not None else []
    history = history + [((file.name,), None)]
    task_history = task_history + [((file.name,), None)]
    task_new = task_new + [((file.name,), None)]
    return history, task_history, file.name, task_new, fig, query

def reset_user_input():
    return gr.update(value="")


def reset_state(task_history):
    task_history.clear()
    return []
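
# Placeholder figure (a short diagonal line of markers); it does not appear to be
# used by the current UI.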
def make_pointcloud_figure_success(verts, rotate=False):
    fig = go.Figure(go.Scatter3d(
        x=[0.005 * n for n in range(100)],
        y=[0.005 * n for n in range(100)],
        z=[0.005 * n for n in range(100)],
        mode='markers', marker=dict(size=8)
    ))
    return fig
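
# Plotly scatter figure used to preview a voxel point cloud in the "Voxel Visual" panel.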
def make_pointcloud_figure(verts, rotate=False):
    if rotate:
        verts = verts.copy()
        verts[:, 0] *= -1.0
    N = len(verts)
    soft_palette = ["#FFEBEE", "#FFF3E0", "#FFFDE7", "#E8F5E9"]
    palette = px.colors.qualitative.Set3
    base_colors = [palette[i % len(palette)] for i in range(N)]
    random.shuffle(base_colors)
    camera = dict(
        eye=dict(x=0.0, y=2.5, z=0.0),
        center=dict(x=0.0, y=0.0, z=0.0),
        up=dict(x=0.0, y=0.0, z=1.0),
        projection=dict(type="orthographic")
    )
    scatter = go.Scatter3d(
        x=verts[:, 0].tolist(),
        y=verts[:, 1].tolist(),
        z=verts[:, 2].tolist(),
        mode='markers',
        marker=dict(
            size=2,
            color=base_colors,
            opacity=1,
            line=dict(width=1)
        )
    )
    layout = go.Layout(
        width=800,
        height=300,
        scene=dict(
            xaxis=dict(visible=False, range=[-0.6, 0.6]),
            yaxis=dict(visible=False, range=[-0.6, 0.6]),
            zaxis=dict(visible=False, range=[-0.6, 0.6]),
            camera=camera
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )
    fig = go.Figure(data=[scatter], layout=layout)
    return fig

def rotate_points(points, axis='x', angle_deg=90):
    angle_rad = np.deg2rad(angle_deg)
    if axis == 'x':
        R = trimesh.transformations.rotation_matrix(angle_rad, [1, 0, 0])[:3, :3]
    elif axis == 'y':
        R = trimesh.transformations.rotation_matrix(angle_rad, [0, 1, 0])[:3, :3]
    elif axis == 'z':
        R = trimesh.transformations.rotation_matrix(angle_rad, [0, 0, 1])[:3, :3]
    else:
        raise ValueError("axis must be 'x', 'y', or 'z'")
    return points @ R.T

def convert_trimesh_to_open3d(trimesh_mesh):
    o3d_mesh = o3d.geometry.TriangleMesh()
    o3d_mesh.vertices = o3d.utility.Vector3dVector(
        np.asarray(trimesh_mesh.vertices, dtype=np.float64)
    )
    o3d_mesh.triangles = o3d.utility.Vector3iVector(
        np.asarray(trimesh_mesh.faces, dtype=np.int32)
    )
    return o3d_mesh
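
# Load a mesh with trimesh, normalize it into the [-0.5, 0.5] cube, voxelize it on a
# 64^3 grid with Open3D, and return the occupied voxel centers as points.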
def load_vertices(filepath):
    mesh = trimesh.load(filepath, force='mesh')
    mesh = convert_trimesh_to_open3d(mesh)
    vertices = np.asarray(mesh.vertices)
    min_vals = vertices.min()
    max_vals = vertices.max()
    vertices_normalized = (vertices - min_vals) / (max_vals - min_vals)
    vertices = vertices_normalized * 1.0 - 0.5
    vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6)
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
        mesh, voxel_size=1 / 64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
    vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
    assert np.all(vertices >= 0) and np.all(vertices < 64), "Some vertices are out of bounds"
    vertices = (vertices + 0.5) / 64 - 0.5
    voxel = rotate_points(vertices, axis='x', angle_deg=90)
    return voxel

def add_file2(history, task_history, file, task_new):
    history = history if history is not None else []
    task_history = task_history if task_history is not None else []
    history = history + [((file,), None)]
    task_history = task_history + [((file,), None)]
    task_new = task_new + [((file,), None)]
    return history, task_history, file, task_new

def _transform_messages(original_messages):
    transformed_messages = []
    for message in original_messages:
        new_content = []
        for item in message['content']:
            if 'image' in item:
                new_item = {'type': 'image', 'image': item['image']}
            elif 'text' in item:
                new_item = {'type': 'text', 'text': item['text']}
            elif 'video' in item:
                new_item = {'type': 'video', 'video': item['video']}
            else:
                continue
            new_content.append(new_item)
        new_message = {'role': message['role'], 'content': new_content}
        transformed_messages.append(new_message)
    return transformed_messages
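
# Model loading: the 3D sparse-structure VQ-VAE, the ShapeLLM-7B-omni chat model
# (built on Qwen2.5-VL), and the TRELLIS text- and image-to-3D pipelines.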
print(f"CUDA Available: {torch.cuda.is_available()}") | |
print(f"CUDA Version: {torch.version.cuda}") | |
print(f"Number of GPUs: {torch.cuda.device_count()}") | |
from trellis.models.sparse_structure_vqvae import VQVAE3D | |
device = torch.device("cuda") | |
vqvae = VQVAE3D(num_embeddings=8192) | |
vqvae.eval() | |
filepath = hf_hub_download(repo_id="yejunliang23/3DVQVAE",filename="3DVQVAE.bin") | |
state_dict = torch.load(filepath, map_location="cpu") | |
vqvae.load_state_dict(state_dict) | |
vqvae=vqvae.to(device) | |
MODEL_DIR = "yejunliang23/ShapeLLM-7B-omni" | |
model_ckpt_path=MODEL_DIR | |
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_ckpt_path, torch_dtype="auto", device_map="auto") | |
processor = AutoProcessor.from_pretrained(model_ckpt_path) | |
tokenizer = processor.tokenizer | |
from huggingface_hub import hf_hub_download | |
pipeline_text = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge") | |
pipeline_text.to(device) | |
pipeline_image = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") | |
pipeline_image.to(device) | |

_DESCRIPTION = '''
* Project page of ShapeLLM-Omni: https://jamesyjl.github.io/ShapeLLM/
* Generation tasks do not yet support multi-turn dialogue, so it is strongly recommended to clear the chat history before starting a new task.
* The model's 3D understanding is limited to shape only; color and texture should be ignored in 3D captioning tasks.
'''
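
# Gradio UI: chat panel and generation settings on the left; voxel preview,
# turntable video, and extracted GLB viewer on the right.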
with gr.Blocks() as demo:
    gr.Markdown("# ShapeLLM-omni: A Native Multimodal LLM for 3D Generation and Understanding")
    gr.Markdown(_DESCRIPTION)
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot(label='ShapeLLM-Omni', elem_classes="control-height", height=500)
            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Number(value=42, label="seed", precision=0)
                top_k = gr.Slider(label="top_k", minimum=1024, maximum=8194, value=1024, step=10)
                top_p = gr.Slider(label="top_p", minimum=0.1, maximum=1.0, value=0.1, step=0.05)
                temperature = gr.Slider(label="temperature", minimum=0.1, maximum=1.0, value=0.1, step=0.05)
            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
            query = gr.Textbox(lines=2, label='Input')
            image_input = gr.Image(visible=False, type="filepath", label="Image Input")
            with gr.Column():
                with gr.Row():
                    addfile_btn = gr.UploadButton("📁 Upload", file_types=["image", "video", ".obj", ".glb"])
                    submit_btn = gr.Button("🚀 Submit")
                with gr.Row():
                    regen_btn = gr.Button("🤔️ Regenerate")
                    empty_bin = gr.Button("🧹 Clear History")
            task_history = gr.State([])
            task_new = gr.State([])
        with gr.Column():
            viewer_plot = gr.Plot(label="Voxel Visual", scale=0.5)
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            viewer_mesh = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)

    examples_text = gr.Examples(
        examples=[
            ["A drone with four propellers and a central body."],
            ["A stone axe with a handle."],
            ["the titanic, aerial view."],
            ["A 3D model of a small yellow and blue robot with wheels and two pots."],
            ["A futuristic vehicle with a sleek design and multiple wheels."],
            ["A car with four wheels and a roof."],
        ],
        inputs=[query],
        label="text-to-3d examples",
        fn=add_text_prefix,
        outputs=[query],
        cache_examples=True,
    )
    examples_text.dataset.click(
        fn=add_text,
        inputs=[chatbot, task_history, query, task_new],
        outputs=[chatbot, task_history, task_new],
    )
    examples_image = gr.Examples(
        label="image-to-3d examples",
        examples=[os.path.join("examples", i) for i in os.listdir("examples")],
        inputs=[image_input],
        examples_per_page=20,
    )
    image_input.change(
        fn=add_file2,
        inputs=[chatbot, task_history, image_input, task_new],
        outputs=[chatbot, task_history, viewer_mesh, task_new],
        show_progress=True
    )
    submit_btn.click(
        add_text, [chatbot, task_history, query, task_new], [chatbot, task_history, task_new]
    ).then(
        predict,
        [chatbot, task_history, viewer_plot, viewer_mesh, task_new, seed, top_k, top_p, temperature, video_output, mesh_simplify, texture_size],
        [chatbot, viewer_plot, viewer_mesh, task_new, video_output],
        show_progress=True
    )
    submit_btn.click(reset_user_input, [], [query])
    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
    regen_btn.click(
        regenerate,
        [chatbot, task_history, viewer_plot, viewer_mesh, task_new, seed, top_k, top_p, temperature, video_output, mesh_simplify, texture_size],
        [chatbot, viewer_plot, viewer_mesh, task_new, video_output],
        show_progress=True
    )
    addfile_btn.upload(
        add_file,
        [chatbot, task_history, addfile_btn, task_new, viewer_plot, query],
        [chatbot, task_history, viewer_mesh, task_new, viewer_plot, query],
        show_progress=True
    )

demo.launch() |