Spaces: Running on A10G
Jiaming Han committed · Commit 3c55139 · 1 Parent(s): 5b036b0

init
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- README.md +4 -4
- app.py +90 -436
- config/llama2/7B.json +0 -1
- config/llama2/tokenizer.model +0 -3
- data/__pycache__/conversation_lib.cpython-310.pyc +0 -0
- data/__pycache__/conversation_lib.cpython-39.pyc +0 -0
- data/__pycache__/fintune_dataset.cpython-310.pyc +0 -0
- data/__pycache__/fintune_dataset.cpython-39.pyc +0 -0
- data/__pycache__/imu_utils.cpython-310.pyc +0 -0
- data/__pycache__/imu_utils.cpython-39.pyc +0 -0
- data/__pycache__/video_utils.cpython-310.pyc +0 -0
- data/__pycache__/video_utils.cpython-39.pyc +0 -0
- data/conversation_lib.py +0 -369
- data/fintune_dataset.py +0 -449
- data/imu_utils.py +0 -257
- data/video_utils.py +0 -204
- demos/multi_turn_mm.py +0 -300
- examples/bell_ring.wav +0 -3
- examples/bird_audio.wav +0 -0
- examples/depth_normal/depth/0084.png +0 -0
- examples/depth_normal/depth/0131.png +0 -0
- examples/depth_normal/depth/0297.png +0 -0
- examples/depth_normal/depth/0331.png +0 -0
- examples/depth_normal/depth/0432.png +0 -0
- examples/depth_normal/depth/0633.png +0 -0
- examples/depth_normal/depth/0663.png +0 -0
- examples/depth_normal/depth/0771.png +0 -0
- examples/depth_normal/depth/0782.png +0 -0
- examples/depth_normal/depth/1001.png +0 -0
- examples/depth_normal/depth/1051.png +0 -0
- examples/depth_normal/depth/1129.png +0 -0
- examples/depth_normal/depth/1205.png +0 -0
- examples/depth_normal/depth/1336.png +0 -0
- examples/depth_normal/depth/1383.png +0 -0
- examples/depth_normal/depth/1386.png +0 -0
- examples/depth_normal/depth/1393.png +0 -0
- examples/depth_normal/depth/1447.png +0 -0
- examples/depth_normal/depth_scaled/0084.png +0 -0
- examples/depth_normal/depth_scaled/0131.png +0 -0
- examples/depth_normal/depth_scaled/0297.png +0 -0
- examples/depth_normal/depth_scaled/0331.png +0 -0
- examples/depth_normal/depth_scaled/0432.png +0 -0
- examples/depth_normal/depth_scaled/0633.png +0 -0
- examples/depth_normal/depth_scaled/0663.png +0 -0
- examples/depth_normal/depth_scaled/0771.png +0 -0
- examples/depth_normal/depth_scaled/0782.png +0 -0
- examples/depth_normal/depth_scaled/1001.png +0 -0
- examples/depth_normal/depth_scaled/1051.png +0 -0
- examples/depth_normal/depth_scaled/1129.png +0 -0
- examples/depth_normal/depth_scaled/1205.png +0 -0
README.md CHANGED
@@ -1,13 +1,13 @@
 ---
-title:
+title: Tar
 emoji: 🚀
 colorFrom: red
 colorTo: indigo
 sdk: gradio
-sdk_version:
+sdk_version: 5.34.0
 app_file: app.py
 pinned: false
-python_version: 3.
+python_version: 3.10.18
 ---

-#
+# Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations
app.py CHANGED
@@ -1,457 +1,111 @@

Previous version (removed). Several line ranges were collapsed in the diff view; they are marked below as "not shown" and left as gaps.

import sys
import os
import argparse
import multiprocessing as mp
import numpy as np
from typing import List, Optional

import torch
import torch.distributed as dist

from fairscale.nn.model_parallel import initialize as fs_init

import gradio as gr
# (two further imports not shown in the diff view)
from model.meta import MetaModel
from data.conversation_lib import conv_templates, SeparatorStyle
from PIL import Image
import torchvision.transforms as transforms
from data.fintune_dataset import make_audio_features
from data import video_utils
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
import plotly.graph_objects as go
from data.fintune_dataset import pc_norm
from functools import partial
import glob
import torchvision.transforms.functional as F

# (the head of this transforms.Compose not shown in the diff view)
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
                                 antialias=None),  # 3 is bicubic
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])

# (the PairRandomResizedCrop class header not shown in the diff view)
        return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) for img in imgs]

class PairToTensor(transforms.ToTensor):
    def __call__(self, pics):
        return [F.to_tensor(pic) for pic in pics]

class PairNormalize(transforms.Normalize):
    def forward(self, tensors):
        return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors]

# (the head of the pair-image transforms.Compose not shown in the diff view)
    PairToTensor(),
    PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])

def load_audio(audio_path):
    fbank = make_audio_features(audio_path, mel_bins=128)
    fbank = fbank.transpose(0, 1)[None]  # [1, 128, 1024]
    return fbank

# (some lines not shown in the diff view)

def load_point(point_path):
    point_feat = np.load(point_path)
    point_feat = torch.tensor(point_feat)
    point_feat = pc_norm(point_feat)
    return point_feat

def load_fmri(fmri_path):
    data = np.load(fmri_path)
    data = data.mean(axis=0)
    data = torch.tensor(data[None])
    return data

def load_rgbx(image_path, x_image_path):
    # trick: replace path if 'depth_scaled' in path
    x_image_path = x_image_path.replace('depth_scaled', 'depth')

    image = Image.open(image_path).convert('RGB')
    x_image = Image.open(x_image_path).convert('RGB')
    x_image = x_image.resize(image.size[-2:])

    image, x_image = transform_pairimg_train([image, x_image])

    # [2, 3, H, W]
    image = torch.stack([image, x_image], dim=0)
    return image


class Ready: pass


def model_worker(
    rank: int, args: argparse.Namespace, barrier: mp.Barrier,
    request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
) -> None:
    """
    The worker function that manipulates the GPU to run the inference.
    Exact n_gpu workers are started, with each one operating on a separate GPU.

    Args:
        rank (int): Distributed rank of the worker.
        args (argparse.Namespace): All command line arguments.
        barrier (multiprocessing.Barrier): A barrier used to delay the start
            of Web UI to be after the start of the model.
    """
    world_size = len(args.gpu_ids)
    gpu_id = args.gpu_ids[rank]
    dist.init_process_group(
        backend="nccl", rank=rank, world_size=world_size,
        init_method=f"tcp://{args.master_addr}:{args.master_port}",
    )
    print(f"| distributed init on worker {rank}/{world_size}. "
          f"using gpu: {gpu_id}")
    fs_init.initialize_model_parallel(world_size)
    torch.cuda.set_device(gpu_id)

    torch.manual_seed(1)
    np.random.seed(1)

    # set the print behavior.
    setup_for_distributed(rank == 0)

    target_dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16
    }[args.dtype]
    with default_tensor_type(dtype=target_dtype, device="cuda"):
        model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
    for ckpt_id in range(args.num_ckpts):
        ckpt_path = hf_hub_download(repo_id=args.pretrained_path, filename=args.ckpt_format.format(str(ckpt_id)))
        # ckpt_path = os.path.join(args.pretrained_path, args.ckpt_format.format(str(ckpt_id)))
        print(f"Loading pretrained weights {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        msg = model.load_state_dict(checkpoint, strict=False)
        # print("load result:\n", msg)
    model.cuda()
    model.eval()
    print(f"Model = {str(model)}")

    barrier.wait()

    while True:
        if response_queue is not None:
            response_queue.put(Ready())
        img_path, audio_path, video_path, point_path, fmri_path, depth_path, depth_rgb_path, normal_path, normal_rgb_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
        try:
            if 'image' in modality and img_path is not None:
                image = Image.open(img_path).convert('RGB')
                inputs = T_random_resized_crop(image)
            elif 'video' in modality and video_path is not None:
                inputs = load_video(video_path)
            elif 'audio' in modality and audio_path is not None:
                inputs = load_audio(audio_path)
            elif 'point' in modality and point_path is not None:
                inputs = load_point(point_path)
            elif 'fmri' in modality and fmri_path is not None:
                inputs = load_fmri(fmri_path)
            elif 'rgbd' in modality and depth_path is not None and depth_rgb_path is not None:
                inputs = load_rgbx(depth_rgb_path, depth_path)
            elif 'rgbn' in modality and normal_path is not None and normal_rgb_path is not None:
                inputs = load_rgbx(normal_rgb_path, normal_path)
            else:
                inputs = None
        except:
            inputs = None

        if inputs is not None:
            inputs = inputs[None].cuda().to(target_dtype)

        # (the generation/streaming loop not shown in the diff view)
            stream_response["end_of_content"] = True

            # keep a few characters if not end_of_content to avoid sending
            # part of conv_sep before all of it is generated.
            if not stream_response["end_of_content"]:
                if len(stream_response["text"]) < len(conv_sep):
                    continue
                stream_response["text"] = (
                    stream_response["text"][:-len(conv_sep)]
                )

            if response_queue is not None:
                response_queue.put(stream_response)

            if stream_response["end_of_content"]:
                break


def gradio_worker(
    request_queues: List[mp.Queue], response_queue: mp.Queue,
    args: argparse.Namespace, barrier: mp.Barrier,
) -> None:
    """
    The gradio worker is responsible for displaying the WebUI and relay the
    requests to model workers. It should be launched only once.

    (some lines not shown in the diff view)
            each model worker).
        args (argparse.Namespace): All command line arguments.
        barrier (multiprocessing.Barrier): A barrier used to delay the start
            of Web UI to be after the start of the model.
    """
    # (some lines not shown in the diff view)
            content_piece = response_queue.get()
            if isinstance(content_piece, Ready):
                break
        for queue in request_queues:
            queue.put((img_path, audio_path, video_path, point_path, fmri_path, depth_path, depth_rgb_path, normal_path, normal_rgb_path, chatbot, max_gen_len, gen_t, top_p, modality))
        while True:
            content_piece = response_queue.get()
            chatbot[-1][1] = content_piece["text"]
            yield chatbot
            if content_piece["end_of_content"]:
                break

    # (some lines not shown in the diff view)
        chatbot = []
        msg = ""
        return chatbot, msg

    def show_point_cloud(file):
        point = load_point(file).numpy()
        fig = go.Figure(
            data=[
                go.Scatter3d(
                    x=point[:, 0], y=point[:, 1], z=point[:, 2],
                    mode='markers',
                    marker=dict(
                        size=1.2,
                        color=['rgb({},{},{})'.format(r, g, b) for r, g, b in zip(point[:, 3], point[:, 4], point[:, 5])]
                    ))],
            layout=dict(
                scene=dict(
                    xaxis=dict(visible=False),
                    yaxis=dict(visible=False),
                    zaxis=dict(visible=False)
                )),)
        return fig

    def change_modality(modal):
        return modal

    # (some lines not shown in the diff view; the following is the demo's custom CSS string)
    .contain { display: flex; flex-direction: column; }
    #component-0 { height: 100%; }
    #chatbot { flex-grow: 1; overflow: auto;}
    """

    # (some lines not shown in the diff view: the gr.Blocks/Row context wrapping the UI below)
    with gr.Column(scale=1):
        # (some lines not shown in the diff view)
            gr.Examples(
                examples=[
                    "examples/new_york.jpg",
                    "examples/food_menu.png",
                ],
                inputs=[img_path],
            )
        with gr.Tab('Video') as video_tab:
            video_path = gr.Video(label='Video Input', max_length=180)
            gr.Examples(
                examples=[
                    "examples/flower.mp4",
                    "examples/star_kun.mp4",
                ],
                inputs=[video_path],
            )
        with gr.Tab('Audio') as audio_tab:
            audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
            gr.Examples(
                examples=[
                    "examples/bell_ring.wav",
                    "examples/bird_audio.wav",
                ],
                inputs=[audio_path],
            )
        with gr.Tab('Point Cloud') as point_tab:
            point_path = gr.File(label='Point Cloud Input', elem_id="pointpath", elem_classes="")
            point_vis = gr.Plot()
            btn = gr.Button(value="Show Point Cloud")
            btn.click(show_point_cloud, point_path, point_vis)
            gr.Examples(
                examples=glob.glob("examples/point/*.npy"),
                inputs=[point_path],
                examples_per_page=5,
            )
        with gr.Tab('IMU') as imu_tab:
            gr.Markdown('Coming soon🤗')
        with gr.Tab('fMRI') as fmri_tab:
            fmri_path = gr.File(label='fMRI Input', elem_id="fmripath", elem_classes="")
            fmri_image_path = gr.Image(label='Reference Image', interactive=False)
            gr.Examples(
                examples=[
                    [file.replace('.jpg', '.npy'), file]
                    for file in glob.glob("examples/fmri/*.jpg")
                ],
                inputs=[fmri_path, fmri_image_path],
                examples_per_page=3,
            )
        with gr.Tab('Depth Map') as depth_tab:
            depth_path = gr.Image(label='Depth Map', type='filepath')
            depth_rgb_path = gr.Image(label='RGB Image', type='filepath')
            gr.Examples(
                examples=[
                    [rgb_image.replace('rgb', 'depth_scaled'), rgb_image]
                    for rgb_image in glob.glob("examples/depth_normal/rgb/*.png")[:9]
                ],
                inputs=[depth_path, depth_rgb_path],
                examples_per_page=3,
            )
        with gr.Tab('Normal Map') as normal_tab:
            normal_path = gr.Image(label='Normal Map', type='filepath')
            normal_rgb_path = gr.Image(label='RGB Image', type='filepath')
            gr.Examples(
                examples=[
                    [rgb_image.replace('rgb', 'normal'), rgb_image]
                    for rgb_image in glob.glob("examples/depth_normal/rgb/*.png")[9:]
                ],
                inputs=[normal_path, normal_rgb_path],
                examples_per_page=3,
            )
    with gr.Column(scale=2):
        chatbot = gr.Chatbot(elem_id="chatbot")
        msg = gr.Textbox()

        with gr.Row():
            # (some lines not shown in the diff view)
            max_gen_len = gr.Slider(
                minimum=1, maximum=args.model_max_seq_len // 2,
                value=args.model_max_seq_len // 2, interactive=True,
                label="Single-turn max response length",
            )
            gen_t = gr.Slider(
                minimum=0, maximum=1, value=0.1, interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0, maximum=1, value=0.75, interactive=True,
                label="Top-p",
            )

    img_tab.select(partial(change_modality, 'image'), [], [modality])
    video_tab.select(partial(change_modality, 'video'), [], [modality])
    audio_tab.select(partial(change_modality, 'audio'), [], [modality])
    point_tab.select(partial(change_modality, 'point'), [], [modality])
    fmri_tab.select(partial(change_modality, 'fmri'), [], [modality])
    depth_tab.select(partial(change_modality, 'rgbd'), [], [modality])
    normal_tab.select(partial(change_modality, 'rgbn'), [], [modality])

    img_path.change(clear, [], [chatbot, msg])
    audio_path.change(clear, [], [chatbot, msg])
    video_path.change(clear, [], [chatbot, msg])
    point_path.change(clear, [], [chatbot, msg])
    fmri_path.change(clear, [], [chatbot, msg])
    depth_path.change(clear, [], [chatbot, msg])
    normal_path.change(clear, [], [chatbot, msg])

    # (some lines not shown in the diff view)
    )
    submit_button.click(
        show_user_input, [msg, chatbot], [msg, chatbot],
    ).then(
        stream_model_output, [img_path, audio_path, video_path, point_path, fmri_path, depth_path, depth_rgb_path, normal_path, normal_rgb_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
    )
    undo_button.click(undo, chatbot, chatbot)
    barrier.wait()
    demo.queue(api_open=True).launch(share=True, max_threads=1)


@dataclass
class DemoConfig:
    gpu_ids = [0]
    tokenizer_path = "config/llama2/tokenizer.model"
    llama_type = "onellm"
    llama_config = "config/llama2/7B.json"
    model_max_seq_len = 2048
    pretrained_path = "csuhan/OneLLM-7B-hf"
    # pretrained_path = "/home/pgao/jiaming/weights/7B_v20_splits/"
    ckpt_format = "consolidated.00-of-01.s{}.pth"
    num_ckpts = 10
    master_port = 23863
    master_addr = "127.0.0.1"
    dtype = "fp16"

if __name__ == "__main__":
    args = DemoConfig()

    # using the default "fork" method messes up some imported libs (e.g.,
    # pandas)
    # mp.set_start_method("spawn")

    # (some lines not shown in the diff view)
    worker_processes = []
    barrier = mp.Barrier(len(args.gpu_ids) + 1)
    for rank, gpu_id in enumerate(args.gpu_ids):
        request_queue = mp.Queue()
        rank_response_queue = response_queue if rank == 0 else None
        process = mp.Process(
            target=model_worker,
            args=(rank, args, barrier, request_queue, rank_response_queue),
        )
        process.start()
        worker_processes.append(process)
        request_queues.append(request_queue)

    # (some lines not shown in the diff view)
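The removed demo streamed partial responses from per-GPU model workers to the web UI over multiprocessing queues, with a barrier so the UI only comes up after the model has loaded (see the model_worker and gradio_worker docstrings above). Below is a minimal, self-contained sketch of that request/response round trip only; it contains no model code, and the payload strings are made up for illustration.

import multiprocessing as mp

def toy_worker(request_queue: mp.Queue, response_queue: mp.Queue, barrier: mp.Barrier) -> None:
    # Pretend the model loads here, then release the UI side.
    barrier.wait()
    while True:
        prompt = request_queue.get()                     # blocking receive from the UI process
        for piece in ("Hello", " world"):                # stream partial results back
            response_queue.put({"text": piece, "end_of_content": False})
        response_queue.put({"text": "", "end_of_content": True})

if __name__ == "__main__":
    request_queue, response_queue = mp.Queue(), mp.Queue()
    barrier = mp.Barrier(2)                              # one worker + the UI process
    mp.Process(target=toy_worker, args=(request_queue, response_queue, barrier), daemon=True).start()
    barrier.wait()                                       # UI side waits until the worker is ready
    request_queue.put("Describe the image.")
    while True:
        piece = response_queue.get()
        print(piece["text"], end="", flush=True)
        if piece["end_of_content"]:
            break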
New version (added):

import os
import gradio as gr
from torchvision.transforms.functional import to_tensor
from huggingface_hub import hf_hub_download, login

from t2i_inference import T2IConfig, TextToImageInference

def generate_text(self, image: str, prompt: str) -> str:
    image = image.convert('RGB')
    image = to_tensor(image).unsqueeze(0).to(self.device)

    image_code = self.visual_tokenizer.encoder(image)['bottleneck_rep']
    image_text = "".join([f"<I{x}>" for x in image_code[0].cpu().tolist()])

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"{image_text}\n{prompt}"}
    ]

    input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = self.tokenizer(input_text, return_tensors="pt")

    gen_ids = self.model.generate(
        inputs.input_ids.to(self.device),
        max_new_tokens=512,
        do_sample=True)
    return self.tokenizer.batch_decode(gen_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]

login(token=os.getenv('HF_TOKEN'))
config = T2IConfig()
config.ar_path = hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_512px.pth")
config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
inference = TextToImageInference(config)

def generate_image(prompt, top_p, top_k, cfg_scale):
    config.top_p = top_p
    config.top_k = top_k
    config.cfg_scale = cfg_scale
    image = inference.generate_image(prompt)
    return image

def clear_inputs_t2i():
    return "", None

def understand_image(image, prompt):
    return generate_text(inference, image, prompt)

def clear_inputs_i2t():
    return None, ""

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        <div align="center">

        ### Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations

        [📄 Paper](https://arxiv.org/abs/xxxx.xxxxx) • [💻 Code](https://github.com/csuhan/Tar) • [📦 Model](https://huggingface.co/csuhan/TA-Tok)

        </div>
        """,
        elem_id="title",
    )
    with gr.Tab("Image Generation"):
        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt")
                with gr.Accordion("Advanced Settings", open=False):
                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
                    top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
                    cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
                with gr.Row():
                    generate_btn = gr.Button("Generate")
                    clear_btn = gr.Button("Clear")
            with gr.Column(scale=2):
                output_image = gr.Image(label="Generated Image")

        generate_btn.click(
            generate_image,
            inputs=[prompt, top_p, top_k, cfg_scale],
            outputs=output_image
        )
        clear_btn.click(
            clear_inputs_t2i,
            outputs=[prompt, output_image]
        )

    with gr.Tab("Image Understanding"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(label="Upload Image", type="pil")
                question_input = gr.Textbox(label="Instruction", value="Describe the image shortly.")
                with gr.Row():
                    qa_btn = gr.Button("Generate")
                    clear_btn_i2t = gr.Button("Clear")
            with gr.Column(scale=1):
                answer_output = gr.Textbox(label="Response", lines=4)

        qa_btn.click(
            understand_image,
            inputs=[image_input, question_input],
            outputs=answer_output
        )

        clear_btn_i2t.click(
            clear_inputs_i2t,
            outputs=[image_input, question_input, answer_output]
        )

demo.launch(share=True)
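Note how the understanding path reuses generate_text: it is written with a self parameter but called as a plain function, generate_text(inference, image, prompt), so the TextToImageInference instance supplies the tokenizer, model, visual_tokenizer and device attributes it reads. A minimal sketch of driving both pipelines without the Gradio UI follows; it assumes it runs alongside the definitions in the new app.py above (T2IConfig, TextToImageInference, inference, generate_image, generate_text) and that generate_image returns a PIL image, as the gr.Image output suggests.

from PIL import Image

# Assumes the checkpoint setup and helper functions from app.py above are already defined.
# Text-to-image: the sampling knobs are written onto the shared config before generation.
img = generate_image("a red bicycle leaning against a brick wall", top_p=0.95, top_k=1200, cfg_scale=4.0)
img.save("sample.png")  # assumption: the inference pipeline returns a PIL.Image

# Image understanding: the image is encoded into <I...> codes and prepended to the prompt.
answer = generate_text(inference, Image.open("sample.png"), "Describe the image shortly.")
print(answer)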
config/llama2/7B.json DELETED
@@ -1 +0,0 @@
-{"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": -1}
config/llama2/tokenizer.model DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
-size 499723
data/__pycache__/conversation_lib.cpython-310.pyc DELETED (binary file, 9.14 kB)
data/__pycache__/conversation_lib.cpython-39.pyc DELETED (binary file, 9.15 kB)
data/__pycache__/fintune_dataset.cpython-310.pyc DELETED (binary file, 14.2 kB)
data/__pycache__/fintune_dataset.cpython-39.pyc DELETED (binary file, 14.2 kB)
data/__pycache__/imu_utils.cpython-310.pyc DELETED (binary file, 6.71 kB)
data/__pycache__/imu_utils.cpython-39.pyc DELETED (binary file, 6.71 kB)
data/__pycache__/video_utils.cpython-310.pyc DELETED (binary file, 6.53 kB)
data/__pycache__/video_utils.cpython-39.pyc DELETED (binary file, 6.51 kB)
data/conversation_lib.py DELETED
@@ -1,369 +0,0 @@

Removed file content:

import dataclasses
from enum import auto, Enum
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + '\n\n' + self.sep
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + '\n' + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        if self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        image = expand2square(image)
                    elif image_process_mode == "Crop":
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((224, 224))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="JPEG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    # image = image.resize((224, 224))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = msg.replace('<image>', img_str)
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Give three tips for staying healthy."),
        ("Assistant",
         "Sure, here are three tips for staying healthy:\n"
         "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
         "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
         "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
         "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
         "activities at least two days per week.\n"
         "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
         "vegetables, whole grains, lean proteins, and healthy fats can help support "
         "your overall health. Try to limit your intake of processed and high-sugar foods, "
         "and aim to drink plenty of water throughout the day.\n"
         "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
         "and mental health. Adults should aim for seven to nine hours of sleep per night. "
         "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
         "help improve the quality of your sleep.")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_v1_2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(),

    # (
    #     ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
    #     ("Assistant",
    #      "Renewable energy sources are those that can be replenished naturally in a relatively "
    #      "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
    #      "Non-renewable energy sources, on the other hand, are finite and will eventually be "
    #      "depleted, such as coal, oil, and natural gas. Here are some key differences between "
    #      "renewable and non-renewable energy sources:\n"
    #      "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
    #      "energy sources are finite and will eventually run out.\n"
    #      "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
    #      "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
    #      "and other negative effects.\n"
    #      "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
    #      "have lower operational costs than non-renewable sources.\n"
    #      "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
    #      "locations than non-renewable sources.\n"
    #      "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
    #      "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
    #      "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
    #      "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    # )
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_mpt = Conversation(
    system="""<|im_start|>system
- You are a helpful language and vision assistant.
- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
- You should follow the instructions carefully and explain your answers in detail.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_mpt_text = Conversation(
    system="""<|im_start|>system
- You are a helpful assistant chatbot trained by MosaicML.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_bair_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

simple_conv = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there! How can I help you today?")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

simple_conv_multimodal = Conversation(
    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "Follow the instructions carefully and explain your answers in detail.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

simple_conv_mpt_multimodal = Conversation(
    system="""<|im_start|>system
- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
- You should follow the instructions carefully and explain your answers in detail.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

simple_conv_legacy = Conversation(
    system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
           "You are designed to assist human with a variety of tasks using natural language."
           "Follow the instructions carefully.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!\n\n### Response:"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v1 = Conversation(
    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "Follow the instructions carefully and explain your answers in detail.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

default_conversation = conv_v1_2
conv_templates = {
    "default": conv_v1_2,
    "simple": simple_conv,
    "simple_legacy": simple_conv_legacy,
    "multimodal": simple_conv_multimodal,
    "mpt_multimodal": simple_conv_mpt_multimodal,
    "llava_v1": conv_llava_v1,

    # fastchat
    "v1": conv_v1_2,
    "bair_v1": conv_bair_v1,
    "vicuna_v1_1": conv_vicuna_v1_1,
    "mpt": conv_mpt,
    "mpt_text": conv_mpt_text,
}

if __name__ == "__main__":
    print(default_conversation.get_prompt())
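For reference, the Conversation/get_prompt pair in this deleted module is what the old demo used to build "### Human: ... ### Assistant:" style prompts. A short usage sketch against the definitions above, mirroring the module's own __main__ block:

# Build a two-turn prompt with the default "###"-separated template.
conv = conv_templates["v1"].copy()
conv.append_message(conv.roles[0], "What is in this image?")
conv.append_message(conv.roles[1], None)   # leave the assistant slot open for generation
print(conv.get_prompt())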
data/fintune_dataset.py DELETED
@@ -1,449 +0,0 @@

Removed file content (the diff view for this file is cut off partway through __getitem__):

import warnings

import torch
import yaml
from torch.utils.data import Dataset
from PIL import Image
import json
from model.tokenizer import Tokenizer
import os
import torchvision.transforms as transforms
import random
import torchvision.transforms.functional as F
import torchaudio
from . import conversation_lib

import numpy as np
from . import video_utils
from .imu_utils import get_imu_frames


IGNORE_INDEX = -100

DEFAULT_IMAGE_TOKEN = "<image>"
try:
    from torchvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

T_random_resized_crop = transforms.Compose([
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=BICUBIC,
                                 antialias=None),  # 3 is bicubic
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])


# image transform
transform_img_train = transforms.Compose([
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
        0.75, 1.3333), interpolation=3, antialias=None),  # 3 is bicubic
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])


class PairRandomResizedCrop(transforms.RandomResizedCrop):
    def forward(self, imgs):
        i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
        return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) for img in imgs]


class PairToTensor(transforms.ToTensor):
    def __call__(self, pics):
        return [F.to_tensor(pic) for pic in pics]


class PairNormalize(transforms.Normalize):
    def forward(self, tensors):
        return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors]


transform_pairimg_train = transforms.Compose([
    PairRandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
        0.75, 1.3333), interpolation=3, antialias=None),  # 3 is bicubic
    PairToTensor(),
    PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])


def pc_norm(pc):
    """ pc: NxC, return NxC """
    xyz = pc[:, :3]
    other_feature = pc[:, 3:]

    centroid = torch.mean(xyz, dim=0)
    xyz = xyz - centroid
    m = torch.max(torch.sqrt(torch.sum(xyz ** 2, dim=1)))
    xyz = xyz / m

    pc = torch.cat((xyz, other_feature), dim=1)
    return pc


def make_audio_features(wav_name, mel_bins=128, target_length=1024, aug=False):
    waveform, sr = torchaudio.load(wav_name)
    # assert sr == 16000, 'input audio sampling rate must be 16kHz'
    if sr != 16000:
        trans = torchaudio.transforms.Resample(sr, 16000)
        waveform = trans(waveform)

    waveform = waveform - waveform.mean()

    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=16000, use_energy=False,
        window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)

    n_frames = fbank.shape[0]

    p = target_length - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]

    if aug:
        freqm = torchaudio.transforms.FrequencyMasking(48)
        timem = torchaudio.transforms.TimeMasking(192)
        fbank = torch.transpose(fbank, 0, 1)
        fbank = fbank.unsqueeze(0)
        fbank = freqm(fbank)
        fbank = timem(fbank)
        fbank = fbank.squeeze(0)
        fbank = torch.transpose(fbank, 0, 1)

    fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
    return fbank


class ConversationGenerator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.header = f"{conversation_lib.default_conversation.system}\n\n"
        self._probe_tokenizer_style()

    def _probe_tokenizer_style(self):
        """
        Given a sentence, e.g. "My darling", some tokenizers will make the space a seperate token,
        while some others will merge the space into the next word, forming a token representing " darling".
        Knowing which style the tokenizer takes is necessary for correct ground-truth label masking.
        """
        probe = "Probe am I"
        sentence1 = self.tokenizer.encode(conversation_lib.default_conversation.roles[1] + ": " + probe,
                                          bos=False, eos=False)
        sentence2 = self.tokenizer.encode(probe,
                                          bos=False, eos=False)
        if sentence1[-len(sentence2):] == sentence2:
            self.space_before_to_predict = False
        else:
            sentence3 = self.tokenizer.encode(" " + probe,
                                              bos=False, eos=False)
            assert sentence1[-len(sentence3):] == sentence3
            self.space_before_to_predict = True

    def add_speaker_and_signal(self, source, get_conversation=True):
        """Add speaker and start/end signal on each round."""
        BEGIN_SIGNAL = "### "
        END_SIGNAL = "\n"
        conversation = self.header

        to_predict_list = []

        for sentence in source:
            from_str = sentence["from"]
            if from_str.lower() in ["human"]:
                from_str = conversation_lib.default_conversation.roles[0]
            elif from_str.lower() in ["gpt", "assistant"]:
                from_str = conversation_lib.default_conversation.roles[1]
            else:
                raise ValueError(f"unknown dialog role: {from_str.lower()}")

            value = sentence["value"]
            if DEFAULT_IMAGE_TOKEN in value:
                value = value.replace(DEFAULT_IMAGE_TOKEN, '').strip()

            sentence_value = BEGIN_SIGNAL + from_str + ": " + value + END_SIGNAL

            if from_str == conversation_lib.default_conversation.roles[1]:
                to_predict_value = value + END_SIGNAL + "###"
                if self.space_before_to_predict:
                    to_predict_value = " " + to_predict_value
                to_predict_list.append(to_predict_value)

            if get_conversation:
                conversation = conversation + sentence_value

        conversation = conversation + BEGIN_SIGNAL
        return conversation, to_predict_list


DATASETS = dict(
    image=[
        dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_image.json", type='image'),
        dict(path='datasets/InstructionTuning/image/cococap_train.json', type='image'),
        dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_text.json", type='text'),
    ],
    audio=[
        dict(path="datasets/InstructionTuning/audio/audiocap_train.json", type='audio'),
        dict(path="datasets/InstructionTuning/audio/audiocap_val.json", type='audio'),
        dict(path="datasets/InstructionTuning/audio/audio_conversation.json", type='audio'),
    ],
    video=[
        dict(path="datasets/InstructionTuning/video/msrvtt_cap_trainval.json", type='video'),
        dict(path="datasets/InstructionTuning/video/msrvtt_cap_test.json", type='video'),
        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_train.json", type='video'),
        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_val.json", type='video'),
        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_test.json", type='video'),
        dict(path="datasets/InstructionTuning/video/video_complex_reasoning_10k.json", type='video'),
        dict(path="datasets/InstructionTuning/video/video_conversation_10k.json", type='video'),
        dict(path="datasets/InstructionTuning/video/video_detail_10k.json", type='video'),
    ],
    point=[
        dict(path="datasets/InstructionTuning/point/pointllm_70k_formated.json", type='point'),
    ],
    rgbd=[
        dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_depth.json", type='rgbd'),
    ],
    rgbn=[
        dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_normal.json", type='rgbn'),
    ],
    imu=[
        dict(path="datasets/InstructionTuning/imu/imu_fixed_50k.json", type='imu'),
    ],
    fmri=[
        dict(path="datasets/InstructionTuning/fmri/fmri_fixed.json", type='fmri'),
    ],
)
IMU_PATH = "/mnt/petrelfs/share_data/hanjiaming/ego4d/v2/processed_imu/"


class FinetuneDialogDataset(Dataset):
    def __init__(self, dataset=['image'], transform=T_random_resized_crop, max_words=2048, image_words=30, tokenizer_path=None):
        if isinstance(dataset, str):
            dataset = [dataset]

        self.dataset = dataset

        group_ann = {}
        for d in dataset:
            for meta in DATASETS[d]:
                meta_path, meta_type = meta['path'], meta['type']
                meta_ext = os.path.splitext(meta_path)[-1]
                if meta_ext == ".json":
                    with open(meta_path) as f:
                        meta_l = json.load(f)
                        # add data_type
                        # this is a temp solution
                        new_meta_l = []
                        for l in meta_l:
                            l['data_type'] = meta_type
                            new_meta_l.append(l)
                        meta_l = new_meta_l
                elif meta_ext == ".jsonl":
                    meta_l = []
                    with open(meta_path) as f:
                        for i, line in enumerate(f):
                            try:
                                meta_l.append(json.loads(line))
                            except json.decoder.JSONDecodeError as e:
                                print(
                                    f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}", force=True)
                                raise e
                else:
                    raise NotImplementedError(
                        f"Unknown meta file extension: \"{meta_ext}\". "
                        f"Currently, .json, .jsonl are supported. "
                        "If you are using a supported format, please set the file extension so that the proper parsing "
                        "routine can be called."
                    )
                if meta_type not in group_ann:
                    group_ann[meta_type] = []
                print(f"{meta_path}, type {meta_type}: len {len(meta_l)}")
                group_ann[meta_type] += meta_l

        # sort group_ann for higher efficiency (items in one global batch with similar length)
        for meta_type, meta_l in group_ann.items():
            meta_l.sort(key=lambda data_item: sum(
                [len(_['value']) for _ in data_item['conversations']]))

        self.group_ann = group_ann
        self.ann = sum(list(self.group_ann.values()), start=[])

        self.group_indices = {}
        start_pos = 0
        for meta_type, meta_l in self.group_ann.items():
            self.group_indices[meta_type] = list(
                range(start_pos, start_pos + len(meta_l)))
            start_pos = start_pos + len(meta_l)

        print(f"total length: {len(self)}")
        self.transform = transform
        print(f"transform:\n{self.transform}")
        self.max_words = max_words
        self.image_words = image_words
        self.tokenizer = Tokenizer(model_path=tokenizer_path)
        self.conversation_generator = ConversationGenerator(self.tokenizer)

        self.load_funcs = dict(
            image=self.load_image,
            audio=self.load_audio,
            video=self.load_video,
            point=self.load_point,
            rgbd=self.load_rgbx,
            rgbn=self.load_rgbx,
            imu=self.load_imu,
            fmri=self.load_fmri
        )

    def __len__(self):
        return len(self.ann)

    def load_image(self, data):
        filename = data['image']
        image = Image.open(filename).convert('RGB')
        image = self.transform(image)
        return image

    def load_audio(self, data):
        audio_path = data['image']
        fbank = make_audio_features(audio_path, mel_bins=128)
        fbank = fbank.transpose(0, 1)[None]  # [1, 128, 1024]
        return fbank

    def load_video(self, data):
        video_path = data['image']
        video_feats = video_utils.load_and_transform_video_data(
            video_path, video_path, clip_duration=1, clips_per_video=5)
        return video_feats[:, :, 0]

    def load_point(self, data):
        point_path = data['image']
        point_feat = torch.load(point_path, map_location='cpu')
        point_feat = point_feat.transpose(0, 1)
        return point_feat

    def load_rgbx(self, data):
        image_path = data['image']
        x_image_path = data['depth_image'] if 'depth_image' in data else data['normal_image']
        image = Image.open(image_path).convert('RGB')
        x_image = Image.open(x_image_path).convert('RGB')
        x_image = x_image.resize(image.size[-2:])

        image, x_image = transform_pairimg_train([image, x_image])
        # [2, 3, H, W]
        image = torch.stack([image, x_image], dim=0)
        return image

    def load_fmri(self, data):
        fmri_path = data['image']
        data = np.load(fmri_path)
        data = data.mean(axis=0)
        data = torch.tensor(data[None])
        return data

    def load_imu(self, data_dict):
        uid = data_dict["video_uid"]
        w_s = data_dict["window_start"]
        w_e = data_dict["window_end"]

        imu_data = get_imu_frames(
            IMU_PATH, uid,
            video_start_sec=w_s,
            video_end_sec=w_e,
        )
        if imu_data is None:
            raise ValueError
        return imu_data['signal']

    def __getitem__(self, index, expect_type=None):
        if expect_type is None:
            data_item = self.ann[index]
        else:
            # in case we want get data from specific data_type
            data_item = self.group_ann[expect_type][index]

        data_type = data_item['data_type']
        if data_type != 'text':
            if data_type in self.load_funcs:
                try:
                    image = self.load_funcs[data_type](data_item)
                    if image == None:
                        raise ValueError('Data is None')
                except:
                    print('Error', data_item)
                    rand_idx = random.randint(
                        0, len(self.group_ann[data_type]))
                    return self.__getitem__(rand_idx, expect_type=data_type)
            else:
                raise ValueError(f'Does not support {data_type}')
        else:
            image = None
            # warnings.warn("pure black image for examples without image")
            # image = torch.zeros(3, 224, 224)

        source = data_item["conversations"]
        conversation, to_predict_values = self.conversation_generator.add_speaker_and_signal(
            source)
        if len(to_predict_values) == 0:
            warnings.warn(
                f"see dialog data with nothing to predict, data: {data_item}")
            return self[index-1]

        tokenzed_conversation = self.tokenizer.encode(
            conversation, bos=True, eos=True)
        labels = [IGNORE_INDEX for _ in tokenzed_conversation]

        check_pos = 0
        for value in to_predict_values:
            tokenized_value = self.tokenizer.encode(
                value, bos=False, eos=False)
            value_pos = find_sublist(
                tokenzed_conversation[check_pos:], tokenized_value) + check_pos
            if value_pos == -1:
                print(
                    "a sentence mismatches the corresponding piece in the conversation")
                return self[index-1]
            labels[value_pos:value_pos+len(tokenized_value)] = tokenized_value
            assert labels[value_pos:value_pos+len(
                tokenized_value)] == tokenzed_conversation[value_pos:value_pos+len(tokenized_value)]
            check_pos = value_pos+len(tokenized_value)

        input2 = torch.tensor(tokenzed_conversation, dtype=torch.int64)
        labels = torch.tensor(labels, dtype=torch.int64)

        if image is not None:
            max_words = self.max_words - self.image_words
        else:
            max_words = self.max_words
        padding = max_words - input2.shape[0]
        if padding > 0:
            input2 = torch.cat(
                (input2, torch.zeros(padding, dtype=torch.int64) - 1))
            labels = torch.cat(
                (labels, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            input2 = input2[:max_words]
            labels = labels[:max_words]

        input2_mask = input2.ge(0)
|
430 |
-
label_mask = labels.ge(0)
|
431 |
-
input2[~input2_mask] = 0
|
432 |
-
labels[~label_mask] = 0
|
433 |
-
input2_mask = input2_mask.float()
|
434 |
-
label_mask = label_mask.float()
|
435 |
-
if image is None:
|
436 |
-
return input2, labels, data_item['data_type']
|
437 |
-
else:
|
438 |
-
return input2, labels, image, data_item['data_type']
|
439 |
-
|
440 |
-
def groups(self):
|
441 |
-
return list(self.group_indices.values())
|
442 |
-
|
443 |
-
|
444 |
-
def find_sublist(a: list, b: list):
|
445 |
-
len_a, len_b = len(a), len(b)
|
446 |
-
for i in range(len_a - len_b + 1):
|
447 |
-
if a[i:i+len_b] == b:
|
448 |
-
return i
|
449 |
-
return -1
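
As an aside, a minimal self-contained sketch of the label-masking scheme that the deleted __getitem__ builds with find_sublist: only the assistant replies keep real token ids as labels, every other position gets an ignore index. The token ids and the IGNORE_INDEX value below are invented for illustration; the real values come from the tokenizer and the dataset module.

# Illustration only: mask every position except the assistant reply span.
IGNORE_INDEX = -100  # assumed sentinel; the dataset defines the real value

def find_sublist(a: list, b: list):
    len_a, len_b = len(a), len(b)
    for i in range(len_a - len_b + 1):
        if a[i:i + len_b] == b:
            return i
    return -1

conversation_tokens = [1, 9038, 2501, 366, 263, 8444, 20255, 29889, 2]  # made-up ids
reply_tokens = [8444, 20255, 29889]                                     # one assistant turn
labels = [IGNORE_INDEX] * len(conversation_tokens)

pos = find_sublist(conversation_tokens, reply_tokens)
if pos != -1:
    labels[pos:pos + len(reply_tokens)] = reply_tokens
print(labels)  # only the reply span carries supervised targets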
data/imu_utils.py
DELETED
@@ -1,257 +0,0 @@
import string
import numpy as np
import matplotlib.animation as animation
from matplotlib import pyplot as plt
import json
from collections import defaultdict
from bisect import bisect_left
import os
import torch
import torchaudio
torchaudio.set_audio_backend("sox_io")


def load_json(json_path: str):
    """
    Load a json file
    """
    with open(json_path, "r", encoding="utf-8") as f_name:
        data = json.load(f_name)
    return data


def check_window_signal(info_t, w_s, w_e):
    length = w_e - w_s
    frame_offset = int(w_s * info_t.sample_rate)
    num_frames = int(length * info_t.sample_rate)
    if frame_offset + num_frames > int(info_t.num_frames):
        return False
    else:
        return True


def index_narrations(ann_path):
    narration_raw = load_json(ann_path)

    narration_dict = defaultdict(list)
    summary_dict = defaultdict(list)
    avg_len = []
    for v_id, narr in narration_raw.items():
        narr_list = []
        summ_list = []
        if "narration_pass_1" in narr:
            narr_list += narr["narration_pass_1"]["narrations"]
            summ_list += narr["narration_pass_1"]["summaries"]
        if "narration_pass_2" in narr:
            narr_list += narr["narration_pass_2"]["narrations"]
            summ_list += narr["narration_pass_2"]["summaries"]

        if len(narr_list) > 0:
            narration_dict[v_id] = [
                (
                    float(n_t["timestamp_sec"]),
                    n_t["narration_text"],
                    n_t["annotation_uid"],
                    n_t["timestamp_frame"],
                )
                for n_t in narr_list
            ]
            avg_len.append(len(narration_dict[v_id]))
        else:
            narration_dict[v_id] = []
        if len(summ_list) > 0:
            summary_dict[v_id] = [
                (
                    float(s_t["start_sec"]),
                    float(s_t["end_sec"]),
                    s_t["summary_text"],
                )
                for s_t in summ_list
            ]
        else:
            summary_dict[v_id] = []
    # print(f"Number of Videos with narration {len(narration_dict)}")
    # print(f"Avg. narration length {np.mean(avg_len)}")
    # print(f"Number of Videos with summaries {len(summary_dict)}")
    return narration_dict, summary_dict


def get_signal_info(signal_fn: str):
    return torchaudio.info(signal_fn)


def get_signal_frames(signal_fn: str, video_start_sec: float, video_end_sec: float):
    """
    Given a signal track return the frames between video_start_sec and video_end_sec
    """
    info_t = get_signal_info(signal_fn)

    length = video_end_sec - video_start_sec
    aframes, _ = torchaudio.load(
        signal_fn,
        normalize=True,
        frame_offset=int(video_start_sec * info_t.sample_rate),
        num_frames=int(length * info_t.sample_rate),
    )
    return {"signal": aframes, "meta": info_t}


def tosec(value):
    return value / 1000


def toms(value):
    return value * 1000


def delta(first_num: float, second_num: float):
    """Compute the absolute value of the difference of two numbers"""
    return abs(first_num - second_num)


def padIMU(signal, duration_sec):
    """
    Pad the signal if necessary
    """
    expected_elements = round(duration_sec) * 200

    if signal.shape[0] > expected_elements:
        signal = signal[:expected_elements, :]
    elif signal.shape[0] < expected_elements:
        padding = expected_elements - signal.shape[0]
        padded_zeros = np.zeros((padding, 6))
        signal = np.concatenate([signal, padded_zeros], 0)
        # signal = signal[:expected_elements, :]
    return signal


def resample(
    signals: np.ndarray,
    timestamps: np.ndarray,
    original_sample_rate: int,
    resample_rate: int,
):
    """
    Resamples data to new sample rate
    """
    signals = torch.as_tensor(signals)
    timestamps = torch.from_numpy(timestamps).unsqueeze(-1)
    signals = torchaudio.functional.resample(
        waveform=signals.data.T,
        orig_freq=original_sample_rate,
        new_freq=resample_rate,
    ).T.numpy()

    nsamples = len(signals)

    period = 1 / resample_rate

    # timestamps are expected to be shape (N, 1)
    initital_seconds = timestamps[0] / 1e3

    ntimes = (torch.arange(nsamples) * period).view(-1, 1) + initital_seconds

    timestamps = (ntimes * 1e3).squeeze().numpy()
    return signals, timestamps


def resampleIMU(signal, timestamps):
    sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps)))))
    # resample all to 200hz
    if sampling_rate != 200:
        signal, timestamps = resample(signal, timestamps, sampling_rate, 200)
    return signal, timestamps


def get_imu_frames(
    imu_path,
    uid: str,
    video_start_sec: float,
    video_end_sec: float,
):
    """
    Given a IMU signal return the frames between video_start_sec and video_end_sec
    """
    signal = np.load(os.path.join(imu_path, f"{uid}.npy"))
    signal = signal.transpose()
    timestamps = np.load(os.path.join(imu_path, f"{uid}_timestamps.npy"))

    if toms(video_start_sec) > timestamps[-1] or toms(video_end_sec) > timestamps[-1]:
        return None

    start_id = bisect_left(timestamps, toms(video_start_sec))
    end_id = bisect_left(timestamps, toms(video_end_sec))

    # make sure the retrieved window interval are correct by a max of 1 sec margin
    if (
        delta(video_start_sec, tosec(timestamps[start_id])) > 4
        or delta(video_end_sec, tosec(timestamps[end_id])) > 4
    ):
        return None

    # get the window
    if start_id == end_id:
        start_id -= 1
        end_id += 1
    signal, timestamps = signal[start_id:end_id], timestamps[start_id:end_id]

    if len(signal) < 10 or len(timestamps) < 10:
        return None
    # resample the signal at 200hz if necessary
    signal, timestamps = resampleIMU(signal, timestamps)

    # pad the signal if necessary
    signal = padIMU(signal, video_end_sec - video_start_sec)

    sample_dict = {
        "timestamp": timestamps,
        "signal": torch.tensor(signal.T),
        "sampling_rate": 200,
    }

    return sample_dict


def display_animation(frames, title, save_path_gif):
    fig, ax = plt.subplots()
    frames = [[ax.imshow(frames[i])] for i in range(len(frames))]
    plt.title(title)
    ani = animation.ArtistAnimation(fig, frames)
    ani.save(save_path_gif, writer="imagemagick")
    plt.close()


def display_animation_imu(frames, imu, title, save_path_gif):
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
    ax1.set_title(title)
    ax2.set_title("Acc.")
    ax3.set_title("Gyro.")
    frames = [[ax1.imshow(frames[i])] for i in range(len(frames))]
    ani = animation.ArtistAnimation(fig, frames)

    ax2.plot(imu[0].cpu().numpy(), color="red")
    ax2.plot(imu[1].cpu().numpy(), color="blue")
    ax2.plot(imu[2].cpu().numpy(), color="green")
    ax3.plot(imu[3].cpu().numpy(), color="red")
    ax3.plot(imu[4].cpu().numpy(), color="blue")
    ax3.plot(imu[5].cpu().numpy(), color="green")
    plt.tight_layout()
    ani.save(save_path_gif, writer="imagemagick")
    plt.close()


def filter_narration(narration_text: str) -> bool:
    if "#c" in narration_text.lower():
        return True
    return False


def clean_narration_text(narration_text: str) -> str:
    return (
        narration_text.replace("#C C ", "")
        .replace("#C", "")
        .replace("#unsure", "something")
        .strip()
        .strip(string.punctuation)
        .lower()[:128]
    )
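
A small, self-contained sketch of the 200 Hz padding convention that get_imu_frames relies on via padIMU; the input window below is synthetic (random values) and only numpy is needed, so it runs without any IMU files on disk.

import numpy as np

def padIMU(signal, duration_sec):
    # Same rule as above: expect duration_sec * 200 samples of a 6-channel (accel + gyro) signal.
    expected_elements = round(duration_sec) * 200
    if signal.shape[0] > expected_elements:
        signal = signal[:expected_elements, :]
    elif signal.shape[0] < expected_elements:
        padding = expected_elements - signal.shape[0]
        signal = np.concatenate([signal, np.zeros((padding, 6))], 0)
    return signal

window = np.random.randn(750, 6)             # a 3.75 s window at 200 Hz (synthetic)
print(padIMU(window, duration_sec=4).shape)  # (800, 6): zero-padded up to 4 s * 200 Hz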
data/video_utils.py
DELETED
@@ -1,204 +0,0 @@
import math
import torch
import torch.nn as nn
from pytorchvideo import transforms as pv_transforms
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord
from torchvision import transforms
from torchvision.transforms._transforms_video import NormalizeVideo


def get_clip_timepoints(clip_sampler, duration):
    # Read out all clips in this video
    all_clips_timepoints = []
    is_last_clip = False
    end = 0.0
    while not is_last_clip:
        start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
        all_clips_timepoints.append((start, end))
    return all_clips_timepoints


def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to perform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset

    return cropped_boxes


def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and weight to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optinal. If not None, resize the images to scale_size before
            performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes


class SpatialCrop(nn.Module):
    """
    Convert the video into 3 smaller clips spatially. Must be used after the
        temporal crops to get spatial crops, and should be used with
        -2 in the spatial crop at the slowfast augmentation stage (so full
        frames are passed in here). Will return a larger list with the
        3x spatial crops as well.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 3:
            self.crops_to_ext = [0, 1, 2]
            self.flipped_crops_to_ext = []
        elif num_crops == 1:
            self.crops_to_ext = [1]
            self.flipped_crops_to_ext = []
        else:
            raise NotImplementedError("Nothing else supported yet")

    def forward(self, videos):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with 3x the number of elements. Each video converted
                to C, T, H', W' by spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
        res = []
        for video in videos:
            for spatial_idx in self.crops_to_ext:
                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
            if not self.flipped_crops_to_ext:
                continue
            flipped_video = transforms.functional.hflip(video)
            for spatial_idx in self.flipped_crops_to_ext:
                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
        return res


def load_and_transform_video_data(
    video_file,
    video_path,
    clip_duration=2,
    clips_per_video=5,
    sample_rate=16000,
    with_audio=False
):
    video_transform = transforms.Compose(
        [
            pv_transforms.ShortSideScale(224),
            NormalizeVideo(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )
    frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)

    if isinstance(video_file, str):
        video = EncodedVideo.from_path(
            video_file,
            decoder="decord",
            decode_audio=with_audio,
            # **{"sample_rate": sample_rate},
        )
    else:
        video = EncodedVideoDecord(video_file, video_name=video_path, decode_video=True, decode_audio=with_audio, sample_rate=sample_rate)

    all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)

    all_video = []
    for clip_timepoints in all_clips_timepoints:
        # Read the clip, get frames
        clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
        if clip is None:
            raise ValueError("No clip found")
        video_clip = frame_sampler(clip["video"])
        video_clip = video_clip / 255.0  # since this is float, need 0-1

        all_video.append(video_clip)

    all_video = [video_transform(clip) for clip in all_video]
    all_video = SpatialCrop(224, num_crops=3)(all_video)

    all_video = torch.stack(all_video, dim=0)

    if not with_audio:
        return all_video
    else:
        return all_video, clip['audio']


if __name__ == '__main__':
    video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4"
    video, audio = load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5, with_audio=True)
    import pdb;pdb.set_trace()
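
For reference, a short sketch of how the clip sampling above enumerates (start, end) windows. It reuses the same get_clip_timepoints loop and assumes pytorchvideo is installed; the clip_duration, clips_per_video and duration values are purely illustrative.

from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler

def get_clip_timepoints(clip_sampler, duration):
    # Same loop as above: collect (start, end) pairs until the sampler flags the last clip.
    all_clips_timepoints = []
    is_last_clip = False
    end = 0.0
    while not is_last_clip:
        start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
        all_clips_timepoints.append((start, end))
    return all_clips_timepoints

sampler = ConstantClipsPerVideoSampler(clip_duration=1, clips_per_video=5)
print(get_clip_timepoints(sampler, duration=10.0))  # five 1-second windows spread over 10 s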
demos/multi_turn_mm.py
DELETED
@@ -1,300 +0,0 @@
import sys
import os
sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0])

import argparse
import multiprocessing as mp
import numpy as np
from typing import List, Optional

import torch
import torch.distributed as dist

from fairscale.nn.model_parallel import initialize as fs_init

import gradio as gr
from util.misc import setup_for_distributed
from util.misc import default_tensor_type
from model.meta import MetaModel
from data.conversation_lib import conv_templates, SeparatorStyle
from PIL import Image
import torchvision.transforms as transforms
from data.fintune_dataset import make_audio_features
from data import video_utils


T_random_resized_crop = transforms.Compose([
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
                                 antialias=None),  # 3 is bicubic
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])


def load_audio(audio_path):
    fbank = make_audio_features(audio_path, mel_bins=128)
    fbank = fbank.transpose(0, 1)[None]  # [1, 128, 1024]
    return fbank

def load_video(video_path):
    video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
    return video_feats[:, :, 0]


def model_worker(
    rank: int, args: argparse.Namespace, barrier: mp.Barrier,
    request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
) -> None:
    """
    The worker function that manipulates the GPU to run the inference.
    Exact n_gpu workers are started, with each one operating on a separate GPU.

    Args:
        rank (int): Distributed rank of the worker.
        args (argparse.Namespace): All command line arguments.
        barrier (multiprocessing.Barrier): A barrier used to delay the start
            of Web UI to be after the start of the model.
    """

    world_size = len(args.gpu_ids)
    gpu_id = args.gpu_ids[rank]
    dist.init_process_group(
        backend="nccl", rank=rank, world_size=world_size,
        init_method=f"tcp://{args.master_addr}:{args.master_port}",
    )
    print(f"| distributed init on worker {rank}/{world_size}. "
          f"using gpu: {gpu_id}")
    fs_init.initialize_model_parallel(world_size)
    torch.cuda.set_device(gpu_id)

    torch.manual_seed(1)
    np.random.seed(1)

    # set the print behavior.
    setup_for_distributed(rank == 0)

    target_dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16
    }[args.dtype]
    with default_tensor_type(dtype=target_dtype, device="cuda"):
        model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
    print("Loading pretrained weights ...")
    checkpoint = torch.load(args.pretrained_path, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=False)
    print("load result:\n", msg)
    model.cuda()
    model.eval()
    print(f"Model = {str(model)}")

    barrier.wait()

    while True:
        img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
        if 'image' in modality and img_path is not None:
            image = Image.open(img_path).convert('RGB')
            inputs = T_random_resized_crop(image)
        elif 'video' in modality and video_path is not None:
            inputs = load_video(video_path)
        elif 'audio' in modality and audio_path is not None:
            inputs = load_audio(audio_path)
        else:
            inputs = None

        if inputs is not None:
            inputs = inputs[None].cuda().to(target_dtype)

        conv = conv_templates["v1"].copy()
        for user, bot in chatbot:
            conv.append_message(conv.roles[0], user)
            conv.append_message(conv.roles[1], bot)

        with torch.cuda.amp.autocast(dtype=target_dtype):
            print(conv.get_prompt())
            for stream_response in model.stream_generate(
                conv.get_prompt(), inputs,
                max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
                modal = modality
            ):
                conv_sep = (
                    conv.sep
                    if conv.sep_style == SeparatorStyle.SINGLE
                    else conv.sep2
                )
                end_pos = stream_response["text"].find(conv_sep)
                if end_pos != -1:
                    stream_response["text"] = (
                        stream_response['text'][:end_pos].rstrip() + "\n"
                    )
                    stream_response["end_of_content"] = True

                # keep a few characters if not end_of_content to avoid sending
                # part of conv_sep before all of it is generated.
                if not stream_response["end_of_content"]:
                    if len(stream_response["text"]) < len(conv_sep):
                        continue
                    stream_response["text"] = (
                        stream_response["text"][:-len(conv_sep)]
                    )

                if response_queue is not None:
                    response_queue.put(stream_response)

                if stream_response["end_of_content"]:
                    break


def gradio_worker(
    request_queues: List[mp.Queue], response_queue: mp.Queue,
    args: argparse.Namespace, barrier: mp.Barrier,
) -> None:
    """
    The gradio worker is responsible for displaying the WebUI and relay the
    requests to model workers. It should be launched only once.

    Args:
        request_queues (List[mp.Queue]): A list of request queues (one for
            each model worker).
        args (argparse.Namespace): All command line arguments.
        barrier (multiprocessing.Barrier): A barrier used to delay the start
            of Web UI to be after the start of the model.
    """

    def show_user_input(msg, chatbot):
        return "", chatbot + [[msg, None]]

    def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality):
        for queue in request_queues:
            queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality))
        while True:
            content_piece = response_queue.get()
            chatbot[-1][1] = content_piece["text"]
            yield chatbot
            if content_piece["end_of_content"]:
                break

    def undo(chatbot):
        if len(chatbot) > 0:
            chatbot = chatbot[:-1]
        return chatbot

    def clear():
        chatbot = []
        msg = ""
        return chatbot, msg

    CSS ="""
    .contain { display: flex; flex-direction: column; }
    #component-0 { height: 100%; }
    #chatbot { flex-grow: 1; overflow: auto;}
    """
    with gr.Blocks(css=CSS) as demo:
        gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language")
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                img_path = gr.Image(label='Image Input', type='filepath')
                video_path = gr.Video(label='Video Input')
                audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
                modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities')

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(elem_id="chatbot")
                msg = gr.Textbox()

                with gr.Row():
                    submit_button = gr.Button("Submit", variant="primary")
                    undo_button = gr.Button("Undo")
                    clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality])
                with gr.Row():
                    max_gen_len = gr.Slider(
                        minimum=1, maximum=args.model_max_seq_len // 2,
                        value=args.model_max_seq_len // 2, interactive=True,
                        label="Single-turn max response length",
                    )
                    gen_t = gr.Slider(
                        minimum=0, maximum=1, value=0.1, interactive=True,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0, maximum=1, value=0.75, interactive=True,
                        label="Top-p",
                    )
        msg.submit(
            show_user_input, [msg, chatbot], [msg, chatbot],
        ).then(
            stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
        )
        submit_button.click(
            show_user_input, [msg, chatbot], [msg, chatbot],
        ).then(
            stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
        )
        undo_button.click(undo, chatbot, chatbot)
        # img_path.change(clear, [], [chatbot, msg])
    barrier.wait()
    demo.queue(api_open=True).launch(share=True, max_threads=1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Chat Demo")
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--gpu_ids", type=int, nargs="+",
        help="A list of space-separated gpu ids to run the model on. "
             "The model will span across GPUs in tensor-parallel mode."
    )
    parser.add_argument(
        "--tokenizer_path", type=str,
        help="Path to the tokenizer.model file provided along with the LLaMA "
             "model."
    )
    parser.add_argument(
        "--llama_type", default="onellm", type=str, metavar="MODEL",
        help="LLaMA model type."
    )
    parser.add_argument(
        "--llama_config", type=str, required=True,
        help="Path to the llama model config json."
    )
    parser.add_argument(
        "--model_max_seq_len", type=int, default=2048,
        help="Max sequence length accepted by the pretrained model."
    )
    parser.add_argument(
        "--pretrained_path", type=str, required=True,
        help="Path to the llama model checkpoints. A list of checkpoints is "
             "supported and will be merged from left to right.")
    parser.add_argument(
        "--master_port", type=int, default=23862,
        help="A port used by the PyTorch distributed module to initialize."
    )
    parser.add_argument(
        "--master_addr", type=str, default="127.0.0.1",
        help="An address used by the PyTorch distributed module to initialize."
    )
    parser.add_argument(
        "--dtype", type=str, choices=["fp16", "bf16"], default="fp16",
        help="The dtype used for model weights and inference."
    )
    args = parser.parse_args()

    # using the default "fork" method messes up some imported libs (e.g.,
    # pandas)
    mp.set_start_method("spawn")

    # setup the queues and start the model workers
    request_queues = []
    response_queue = mp.Queue()
    worker_processes = []
    barrier = mp.Barrier(len(args.gpu_ids) + 1)
    for rank, gpu_id in enumerate(args.gpu_ids):
        request_queue = mp.Queue()
        rank_response_queue = response_queue if rank == 0 else None
        process = mp.Process(
            target=model_worker,
            args=(rank, args, barrier, request_queue, rank_response_queue),
        )
        process.start()
        worker_processes.append(process)
        request_queues.append(request_queue)

    gradio_worker(request_queues, response_queue, args, barrier)
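
A toy sketch of the request/response queue protocol between gradio_worker and model_worker above, with the GPU model replaced by a fake worker that streams two chunks. Everything here (the echo behaviour, the dummy tuple contents) is illustrative only; it just shows how the Gradio side drains responses until "end_of_content".

import multiprocessing as mp

def fake_model_worker(request_queue, response_queue):
    # Consume one request tuple in the same order the real worker expects.
    img, audio, video, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
    prompt = chatbot[-1][0]
    response_queue.put({"text": prompt.upper(), "end_of_content": False})
    response_queue.put({"text": prompt.upper() + "!", "end_of_content": True})

if __name__ == "__main__":
    req, resp = mp.Queue(), mp.Queue()
    worker = mp.Process(target=fake_model_worker, args=(req, resp))
    worker.start()
    chatbot = [["hello", None]]
    req.put((None, None, None, chatbot, 128, 0.1, 0.75, "image"))
    while True:
        piece = resp.get()
        chatbot[-1][1] = piece["text"]   # stream partial text into the last chat turn
        if piece["end_of_content"]:
            break
    worker.join()
    print(chatbot)  # [['hello', 'HELLO!']]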
examples/bell_ring.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:415b5406175cafa874396f89811121fb306c17084db8ec20753ab5666b4fdcca
size 3630044

examples/bird_audio.wav
DELETED
Binary file (882 kB)

examples/depth_normal/depth/0084.png
DELETED
Binary file (10.9 kB)

examples/depth_normal/depth/0131.png
DELETED
Binary file (5.5 kB)

examples/depth_normal/depth/0297.png
DELETED
Binary file (11.1 kB)

examples/depth_normal/depth/0331.png
DELETED
Binary file (10.6 kB)

examples/depth_normal/depth/0432.png
DELETED
Binary file (7.89 kB)

examples/depth_normal/depth/0633.png
DELETED
Binary file (9.35 kB)

examples/depth_normal/depth/0663.png
DELETED
Binary file (5.68 kB)

examples/depth_normal/depth/0771.png
DELETED
Binary file (9.65 kB)

examples/depth_normal/depth/0782.png
DELETED
Binary file (9.58 kB)

examples/depth_normal/depth/1001.png
DELETED
Binary file (7.17 kB)

examples/depth_normal/depth/1051.png
DELETED
Binary file (6.81 kB)

examples/depth_normal/depth/1129.png
DELETED
Binary file (6.66 kB)

examples/depth_normal/depth/1205.png
DELETED
Binary file (9.24 kB)

examples/depth_normal/depth/1336.png
DELETED
Binary file (11.5 kB)

examples/depth_normal/depth/1383.png
DELETED
Binary file (9.98 kB)

examples/depth_normal/depth/1386.png
DELETED
Binary file (12.2 kB)

examples/depth_normal/depth/1393.png
DELETED
Binary file (8.6 kB)

examples/depth_normal/depth/1447.png
DELETED
Binary file (10.3 kB)

examples/depth_normal/depth_scaled/0084.png
DELETED
Binary file (11.1 kB)

examples/depth_normal/depth_scaled/0131.png
DELETED
Binary file (5.68 kB)

examples/depth_normal/depth_scaled/0297.png
DELETED
Binary file (11.8 kB)

examples/depth_normal/depth_scaled/0331.png
DELETED
Binary file (11.1 kB)

examples/depth_normal/depth_scaled/0432.png
DELETED
Binary file (8.24 kB)

examples/depth_normal/depth_scaled/0633.png
DELETED
Binary file (9.96 kB)

examples/depth_normal/depth_scaled/0663.png
DELETED
Binary file (5.91 kB)

examples/depth_normal/depth_scaled/0771.png
DELETED
Binary file (10.1 kB)

examples/depth_normal/depth_scaled/0782.png
DELETED
Binary file (9.98 kB)

examples/depth_normal/depth_scaled/1001.png
DELETED
Binary file (7.19 kB)

examples/depth_normal/depth_scaled/1051.png
DELETED
Binary file (7.05 kB)

examples/depth_normal/depth_scaled/1129.png
DELETED
Binary file (6.9 kB)

examples/depth_normal/depth_scaled/1205.png
DELETED
Binary file (9.58 kB)