Jiaming Han committed on
Commit 3c55139 · 1 Parent(s): 5b036b0
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. README.md +4 -4
  2. app.py +90 -436
  3. config/llama2/7B.json +0 -1
  4. config/llama2/tokenizer.model +0 -3
  5. data/__pycache__/conversation_lib.cpython-310.pyc +0 -0
  6. data/__pycache__/conversation_lib.cpython-39.pyc +0 -0
  7. data/__pycache__/fintune_dataset.cpython-310.pyc +0 -0
  8. data/__pycache__/fintune_dataset.cpython-39.pyc +0 -0
  9. data/__pycache__/imu_utils.cpython-310.pyc +0 -0
  10. data/__pycache__/imu_utils.cpython-39.pyc +0 -0
  11. data/__pycache__/video_utils.cpython-310.pyc +0 -0
  12. data/__pycache__/video_utils.cpython-39.pyc +0 -0
  13. data/conversation_lib.py +0 -369
  14. data/fintune_dataset.py +0 -449
  15. data/imu_utils.py +0 -257
  16. data/video_utils.py +0 -204
  17. demos/multi_turn_mm.py +0 -300
  18. examples/bell_ring.wav +0 -3
  19. examples/bird_audio.wav +0 -0
  20. examples/depth_normal/depth/0084.png +0 -0
  21. examples/depth_normal/depth/0131.png +0 -0
  22. examples/depth_normal/depth/0297.png +0 -0
  23. examples/depth_normal/depth/0331.png +0 -0
  24. examples/depth_normal/depth/0432.png +0 -0
  25. examples/depth_normal/depth/0633.png +0 -0
  26. examples/depth_normal/depth/0663.png +0 -0
  27. examples/depth_normal/depth/0771.png +0 -0
  28. examples/depth_normal/depth/0782.png +0 -0
  29. examples/depth_normal/depth/1001.png +0 -0
  30. examples/depth_normal/depth/1051.png +0 -0
  31. examples/depth_normal/depth/1129.png +0 -0
  32. examples/depth_normal/depth/1205.png +0 -0
  33. examples/depth_normal/depth/1336.png +0 -0
  34. examples/depth_normal/depth/1383.png +0 -0
  35. examples/depth_normal/depth/1386.png +0 -0
  36. examples/depth_normal/depth/1393.png +0 -0
  37. examples/depth_normal/depth/1447.png +0 -0
  38. examples/depth_normal/depth_scaled/0084.png +0 -0
  39. examples/depth_normal/depth_scaled/0131.png +0 -0
  40. examples/depth_normal/depth_scaled/0297.png +0 -0
  41. examples/depth_normal/depth_scaled/0331.png +0 -0
  42. examples/depth_normal/depth_scaled/0432.png +0 -0
  43. examples/depth_normal/depth_scaled/0633.png +0 -0
  44. examples/depth_normal/depth_scaled/0663.png +0 -0
  45. examples/depth_normal/depth_scaled/0771.png +0 -0
  46. examples/depth_normal/depth_scaled/0782.png +0 -0
  47. examples/depth_normal/depth_scaled/1001.png +0 -0
  48. examples/depth_normal/depth_scaled/1051.png +0 -0
  49. examples/depth_normal/depth_scaled/1129.png +0 -0
  50. examples/depth_normal/depth_scaled/1205.png +0 -0
README.md CHANGED
@@ -1,13 +1,13 @@
 ---
-title: OneLLM
+title: Tar
 emoji: 🚀
 colorFrom: red
 colorTo: indigo
 sdk: gradio
-sdk_version: 4.7.1
+sdk_version: 5.34.0
 app_file: app.py
 pinned: false
-python_version: 3.9.18
+python_version: 3.10.18
 ---
 
-# OneLLM: One Framework to Align All Modalities with Language
+# Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations
app.py CHANGED
@@ -1,457 +1,111 @@
1
- import sys
2
  import os
3
- import argparse
4
- import multiprocessing as mp
5
- import numpy as np
6
- from typing import List, Optional
7
-
8
- import torch
9
- import torch.distributed as dist
10
-
11
- from fairscale.nn.model_parallel import initialize as fs_init
12
-
13
  import gradio as gr
14
- from util.misc import setup_for_distributed
15
- from util.misc import default_tensor_type
16
- from model.meta import MetaModel
17
- from data.conversation_lib import conv_templates, SeparatorStyle
18
- from PIL import Image
19
- import torchvision.transforms as transforms
20
- from data.fintune_dataset import make_audio_features
21
- from data import video_utils
22
- from dataclasses import dataclass
23
- from huggingface_hub import hf_hub_download
24
- import plotly.graph_objects as go
25
- from data.fintune_dataset import pc_norm
26
- from functools import partial
27
- import glob
28
- import torchvision.transforms.functional as F
29
 
30
- T_random_resized_crop = transforms.Compose([
31
- transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
32
- antialias=None), # 3 is bicubic
33
- transforms.ToTensor(),
34
- transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
35
 
36
- class PairRandomResizedCrop(transforms.RandomResizedCrop):
37
- def forward(self, imgs):
38
- i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
39
- return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) for img in imgs]
40
-
41
- class PairToTensor(transforms.ToTensor):
42
- def __call__(self, pics):
43
- return [F.to_tensor(pic) for pic in pics]
44
-
45
- class PairNormalize(transforms.Normalize):
46
- def forward(self, tensors):
47
- return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors]
48
 
49
- transform_pairimg_train = transforms.Compose([
50
- PairRandomResizedCrop(size=(224, 224), scale=(0.99, 1.0), ratio=(0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic
51
- PairToTensor(),
52
- PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
53
-
54
- def load_audio(audio_path):
55
- fbank = make_audio_features(audio_path, mel_bins=128)
56
- fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024]
57
- return fbank
58
 
59
- def load_video(video_path):
60
- video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
61
- return video_feats[:, :, 0]
62
-
63
- def load_point(point_path):
64
- point_feat = np.load(point_path)
65
- point_feat = torch.tensor(point_feat)
66
- point_feat = pc_norm(point_feat)
67
- return point_feat
68
-
69
- def load_fmri(fmri_path):
70
- data = np.load(fmri_path)
71
- data = data.mean(axis=0)
72
- data = torch.tensor(data[None])
73
- return data
74
-
75
- def load_rgbx(image_path, x_image_path):
76
- # trick: replace path if 'depth_scaled' in path
77
- x_image_path = x_image_path.replace('depth_scaled', 'depth')
78
-
79
- image = Image.open(image_path).convert('RGB')
80
- x_image = Image.open(x_image_path).convert('RGB')
81
- x_image = x_image.resize(image.size[-2:])
82
-
83
- image, x_image = transform_pairimg_train([image, x_image])
84
-
85
- # [2, 3, H, W]
86
- image = torch.stack([image, x_image], dim=0)
87
- return image
88
-
89
-
90
- class Ready: pass
91
-
92
-
93
- def model_worker(
94
- rank: int, args: argparse.Namespace, barrier: mp.Barrier,
95
- request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
96
- ) -> None:
97
- """
98
- The worker function that manipulates the GPU to run the inference.
99
- Exact n_gpu workers are started, with each one operating on a separate GPU.
100
-
101
- Args:
102
- rank (int): Distributed rank of the worker.
103
- args (argparse.Namespace): All command line arguments.
104
- barrier (multiprocessing.Barrier): A barrier used to delay the start
105
- of Web UI to be after the start of the model.
106
- """
107
-
108
- world_size = len(args.gpu_ids)
109
- gpu_id = args.gpu_ids[rank]
110
- dist.init_process_group(
111
- backend="nccl", rank=rank, world_size=world_size,
112
- init_method=f"tcp://{args.master_addr}:{args.master_port}",
113
- )
114
- print(f"| distributed init on worker {rank}/{world_size}. "
115
- f"using gpu: {gpu_id}")
116
- fs_init.initialize_model_parallel(world_size)
117
- torch.cuda.set_device(gpu_id)
118
-
119
- torch.manual_seed(1)
120
- np.random.seed(1)
121
-
122
- # set the print behavior.
123
- setup_for_distributed(rank == 0)
124
-
125
- target_dtype = {
126
- "bf16": torch.bfloat16,
127
- "fp16": torch.float16
128
- }[args.dtype]
129
- with default_tensor_type(dtype=target_dtype, device="cuda"):
130
- model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
131
- for ckpt_id in range(args.num_ckpts):
132
- ckpt_path = hf_hub_download(repo_id=args.pretrained_path, filename=args.ckpt_format.format(str(ckpt_id)))
133
- # ckpt_path = os.path.join(args.pretrained_path, args.ckpt_format.format(str(ckpt_id)))
134
- print(f"Loading pretrained weights {ckpt_path}")
135
- checkpoint = torch.load(ckpt_path, map_location='cpu')
136
- msg = model.load_state_dict(checkpoint, strict=False)
137
- # print("load result:\n", msg)
138
- model.cuda()
139
- model.eval()
140
- print(f"Model = {str(model)}")
141
-
142
- barrier.wait()
143
-
144
- while True:
145
- if response_queue is not None:
146
- response_queue.put(Ready())
147
- img_path, audio_path, video_path, point_path, fmri_path, depth_path, depth_rgb_path, normal_path, normal_rgb_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
148
- try:
149
- if 'image' in modality and img_path is not None:
150
- image = Image.open(img_path).convert('RGB')
151
- inputs = T_random_resized_crop(image)
152
- elif 'video' in modality and video_path is not None:
153
- inputs = load_video(video_path)
154
- elif 'audio' in modality and audio_path is not None:
155
- inputs = load_audio(audio_path)
156
- elif 'point' in modality and point_path is not None:
157
- inputs = load_point(point_path)
158
- elif 'fmri' in modality and fmri_path is not None:
159
- inputs = load_fmri(fmri_path)
160
- elif 'rgbd' in modality and depth_path is not None and depth_rgb_path is not None:
161
- inputs = load_rgbx(depth_rgb_path, depth_path)
162
- elif 'rgbn' in modality and normal_path is not None and normal_rgb_path is not None:
163
- inputs = load_rgbx(normal_rgb_path, normal_path)
164
- else:
165
- inputs = None
166
- except:
167
- inputs = None
168
-
169
- if inputs is not None:
170
- inputs = inputs[None].cuda().to(target_dtype)
171
 
172
- conv = conv_templates["v1"].copy()
173
- for user, bot in chatbot:
174
- conv.append_message(conv.roles[0], user)
175
- conv.append_message(conv.roles[1], bot)
176
-
177
- with torch.cuda.amp.autocast(dtype=target_dtype):
178
- print(conv.get_prompt())
179
- for stream_response in model.stream_generate(
180
- conv.get_prompt(), inputs,
181
- max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
182
- modal = modality
183
- ):
184
- conv_sep = (
185
- conv.sep
186
- if conv.sep_style == SeparatorStyle.SINGLE
187
- else conv.sep2
188
- )
189
- end_pos = stream_response["text"].find(conv_sep)
190
- if end_pos != -1:
191
- stream_response["text"] = (
192
- stream_response['text'][:end_pos].rstrip() + "\n"
193
- )
194
- stream_response["end_of_content"] = True
195
-
196
- # keep a few characters if not end_of_content to avoid sending
197
- # part of conv_sep before all of it is generated.
198
- if not stream_response["end_of_content"]:
199
- if len(stream_response["text"]) < len(conv_sep):
200
- continue
201
- stream_response["text"] = (
202
- stream_response["text"][:-len(conv_sep)]
203
- )
204
-
205
- if response_queue is not None:
206
- response_queue.put(stream_response)
207
-
208
- if stream_response["end_of_content"]:
209
- break
210
-
211
-
212
- def gradio_worker(
213
- request_queues: List[mp.Queue], response_queue: mp.Queue,
214
- args: argparse.Namespace, barrier: mp.Barrier,
215
- ) -> None:
216
- """
217
- The gradio worker is responsible for displaying the WebUI and relay the
218
- requests to model workers. It should be launched only once.
219
 
220
- Args:
221
- request_queues (List[mp.Queue]): A list of request queues (one for
222
- each model worker).
223
- args (argparse.Namespace): All command line arguments.
224
- barrier (multiprocessing.Barrier): A barrier used to delay the start
225
- of Web UI to be after the start of the model.
226
- """
227
 
228
- def show_user_input(msg, chatbot):
229
- return "", chatbot + [[msg, None]]
230
 
231
- def stream_model_output(img_path, audio_path, video_path, point_path, fmri_path, depth_path, depth_rgb_path, normal_path, normal_rgb_path, chatbot, max_gen_len, gen_t, top_p, modality):
232
- while True:
233
- content_piece = response_queue.get()
234
- if isinstance(content_piece, Ready):
235
- break
236
- for queue in request_queues:
237
- queue.put((img_path, audio_path, video_path, point_path, fmri_path, depth_path, depth_rgb_path, normal_path, normal_rgb_path, chatbot, max_gen_len, gen_t, top_p, modality))
238
- while True:
239
- content_piece = response_queue.get()
240
- chatbot[-1][1] = content_piece["text"]
241
- yield chatbot
242
- if content_piece["end_of_content"]:
243
- break
244
 
245
- def undo(chatbot):
246
- if len(chatbot) > 0:
247
- chatbot = chatbot[:-1]
248
- return chatbot
249
 
250
- def clear():
251
- chatbot = []
252
- msg = ""
253
- return chatbot, msg
254
-
255
- def show_point_cloud(file):
256
- point = load_point(file).numpy()
257
- fig = go.Figure(
258
- data=[
259
- go.Scatter3d(
260
- x=point[:,0], y=point[:,1], z=point[:,2],
261
- mode='markers',
262
- marker=dict(
263
- size=1.2,
264
- color=['rgb({},{},{})'.format(r, g, b) for r,g,b in zip(point[:,3], point[:,4], point[:,5])]
265
- ))],
266
- layout=dict(
267
- scene=dict(
268
- xaxis=dict(visible=False),
269
- yaxis=dict(visible=False),
270
- zaxis=dict(visible=False)
271
- )),)
272
- return fig
273
-
274
- def change_modality(modal):
275
- return modal
276
 
277
- CSS ="""
278
- .contain { display: flex; flex-direction: column; }
279
- #component-0 { height: 100%; }
280
- #chatbot { flex-grow: 1; overflow: auto;}
281
- """
282
 
283
- header="""
284
- ## OneLLM: One Framework to Align All Modalities with Language
285
- [[Project Page](https://onellm.csuhan.com)] [[Paper](https://arxiv.org/abs/2312.03700)] [[Code](https://github.com/csuhan/OneLLM)]
286
- """
287
-
288
- with gr.Blocks(css=CSS, theme=gr.themes.Base()) as demo:
289
- gr.Markdown(header)
290
- with gr.Row(equal_height=True):
291
- modality = gr.Textbox(value='image', visible=False)
292
  with gr.Column(scale=1):
293
- with gr.Tab('Image') as img_tab:
294
- img_path = gr.Image(label='Image Input', type='filepath')
295
- gr.Examples(
296
- examples=[
297
- "examples/new_york.jpg",
298
- "examples/food_menu.png",
299
- ],
300
- inputs=[img_path],
301
- )
302
- with gr.Tab('Video') as video_tab:
303
- video_path = gr.Video(label='Video Input', max_length=180)
304
- gr.Examples(
305
- examples=[
306
- "examples/flower.mp4",
307
- "examples/star_kun.mp4",
308
- ],
309
- inputs=[video_path],
310
- )
311
- with gr.Tab('Audio') as audio_tab:
312
- audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
313
- gr.Examples(
314
- examples=[
315
- "examples/bell_ring.wav",
316
- "examples/bird_audio.wav",
317
- ],
318
- inputs=[audio_path],
319
- )
320
- with gr.Tab('Point Cloud') as point_tab:
321
- point_path = gr.File(label='Point Cloud Input', elem_id="pointpath", elem_classes="")
322
- point_vis = gr.Plot()
323
- btn = gr.Button(value="Show Point Cloud")
324
- btn.click(show_point_cloud, point_path, point_vis)
325
- gr.Examples(
326
- examples=glob.glob("examples/point/*.npy"),
327
- inputs=[point_path],
328
- examples_per_page=5,
329
- )
330
- with gr.Tab('IMU') as imu_tab:
331
- gr.Markdown('Coming soon🤗')
332
- with gr.Tab('fMRI') as fmri_tab:
333
- fmri_path = gr.File(label='fMRI Input', elem_id="fmripath", elem_classes="")
334
- fmri_image_path = gr.Image(label='Reference Image', interactive=False)
335
- gr.Examples(
336
- examples=[
337
- [file.replace('.jpg', '.npy'), file]
338
- for file in glob.glob("examples/fmri/*.jpg")
339
- ],
340
- inputs=[fmri_path, fmri_image_path],
341
- examples_per_page=3,
342
- )
343
- with gr.Tab('Depth Map') as depth_tab:
344
- depth_path = gr.Image(label='Depth Map', type='filepath')
345
- depth_rgb_path = gr.Image(label='RGB Image', type='filepath')
346
- gr.Examples(
347
- examples=[
348
- [rgb_image.replace('rgb', 'depth_scaled'), rgb_image]
349
- for rgb_image in glob.glob("examples/depth_normal/rgb/*.png")[:9]
350
- ],
351
- inputs=[depth_path, depth_rgb_path],
352
- examples_per_page=3,
353
- )
354
- with gr.Tab('Normal Map') as normal_tab:
355
- normal_path = gr.Image(label='Normal Map', type='filepath')
356
- normal_rgb_path = gr.Image(label='RGB Image', type='filepath')
357
- gr.Examples(
358
- examples=[
359
- [rgb_image.replace('rgb', 'normal'), rgb_image]
360
- for rgb_image in glob.glob("examples/depth_normal/rgb/*.png")[9:]
361
- ],
362
- inputs=[normal_path, normal_rgb_path],
363
- examples_per_page=3,
364
- )
365
- with gr.Column(scale=2):
366
- chatbot = gr.Chatbot(elem_id="chatbot")
367
- msg = gr.Textbox()
368
-
369
  with gr.Row():
370
- submit_button = gr.Button("Submit", variant="primary")
371
- undo_button = gr.Button("Undo")
372
- clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, point_path, fmri_path, depth_path, depth_rgb_path, normal_path, normal_rgb_path, point_vis])
373
- with gr.Row():
374
- max_gen_len = gr.Slider(
375
- minimum=1, maximum=args.model_max_seq_len // 2,
376
- value=args.model_max_seq_len // 2, interactive=True,
377
- label="Single-turn max response length",
378
- )
379
- gen_t = gr.Slider(
380
- minimum=0, maximum=1, value=0.1, interactive=True,
381
- label="Temperature",
382
- )
383
- top_p = gr.Slider(
384
- minimum=0, maximum=1, value=0.75, interactive=True,
385
- label="Top-p",
386
- )
387
-
388
- img_tab.select(partial(change_modality, 'image'), [], [modality])
389
- video_tab.select(partial(change_modality, 'video'), [], [modality])
390
- audio_tab.select(partial(change_modality, 'audio'), [], [modality])
391
- point_tab.select(partial(change_modality, 'point'), [], [modality])
392
- fmri_tab.select(partial(change_modality, 'fmri'), [], [modality])
393
- depth_tab.select(partial(change_modality, 'rgbd'), [], [modality])
394
- normal_tab.select(partial(change_modality, 'rgbn'), [], [modality])
395
-
396
- img_path.change(clear, [], [chatbot, msg])
397
- audio_path.change(clear, [], [chatbot, msg])
398
- video_path.change(clear, [], [chatbot, msg])
399
- point_path.change(clear, [], [chatbot, msg])
400
- fmri_path.change(clear, [], [chatbot, msg])
401
- depth_path.change(clear, [], [chatbot, msg])
402
- normal_path.change(clear, [], [chatbot, msg])
403
 
404
- msg.submit(
405
- show_user_input, [msg, chatbot], [msg, chatbot],
406
- ).then(
407
- stream_model_output, [img_path, audio_path, video_path, point_path, fmri_path, depth_path, depth_rgb_path, normal_path, normal_rgb_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
408
  )
409
- submit_button.click(
410
- show_user_input, [msg, chatbot], [msg, chatbot],
411
- ).then(
412
- stream_model_output, [img_path, audio_path, video_path, point_path, fmri_path, depth_path, depth_rgb_path, normal_path, normal_rgb_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
413
- )
414
- undo_button.click(undo, chatbot, chatbot)
415
- barrier.wait()
416
- demo.queue(api_open=True).launch(share=True, max_threads=1)
417
-
418
-
419
- @dataclass
420
- class DemoConfig:
421
- gpu_ids = [0]
422
- tokenizer_path = "config/llama2/tokenizer.model"
423
- llama_type = "onellm"
424
- llama_config = "config/llama2/7B.json"
425
- model_max_seq_len = 2048
426
- pretrained_path = "csuhan/OneLLM-7B-hf"
427
- # pretrained_path = "/home/pgao/jiaming/weights/7B_v20_splits/"
428
- ckpt_format = "consolidated.00-of-01.s{}.pth"
429
- num_ckpts = 10
430
- master_port = 23863
431
- master_addr = "127.0.0.1"
432
- dtype = "fp16"
433
-
434
- if __name__ == "__main__":
435
- args = DemoConfig()
436
-
437
- # using the default "fork" method messes up some imported libs (e.g.,
438
- # pandas)
439
- # mp.set_start_method("spawn")
440
 
441
- # setup the queues and start the model workers
442
- request_queues = []
443
- response_queue = mp.Queue()
444
- worker_processes = []
445
- barrier = mp.Barrier(len(args.gpu_ids) + 1)
446
- for rank, gpu_id in enumerate(args.gpu_ids):
447
- request_queue = mp.Queue()
448
- rank_response_queue = response_queue if rank == 0 else None
449
- process = mp.Process(
450
- target=model_worker,
451
- args=(rank, args, barrier, request_queue, rank_response_queue),
452
  )
453
- process.start()
454
- worker_processes.append(process)
455
- request_queues.append(request_queue)
456
 
457
- gradio_worker(request_queues, response_queue, args, barrier)
 
 
 import os
 import gradio as gr
+from torchvision.transforms.functional import to_tensor
+from huggingface_hub import hf_hub_download, login

+from t2i_inference import T2IConfig, TextToImageInference

+def generate_text(self, image: str, prompt: str) -> str:
+    image = image.convert('RGB')
+    image = to_tensor(image).unsqueeze(0).to(self.device)

+    image_code = self.visual_tokenizer.encoder(image)['bottleneck_rep']
+    image_text = "".join([f"<I{x}>" for x in image_code[0].cpu().tolist()])

+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": f"{image_text}\n{prompt}"}
+    ]

+    input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = self.tokenizer(input_text, return_tensors="pt")
+
+    gen_ids = self.model.generate(
+        inputs.input_ids.to(self.device),
+        max_new_tokens=512,
+        do_sample=True)
+    return self.tokenizer.batch_decode(gen_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
+
+login(token=os.getenv('HF_TOKEN'))
+config = T2IConfig()
+config.ar_path = hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_512px.pth")
+config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
+config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
+inference = TextToImageInference(config)
+
+def generate_image(prompt, top_p, top_k, cfg_scale):
+    config.top_p = top_p
+    config.top_k = top_k
+    config.cfg_scale = cfg_scale
+    image = inference.generate_image(prompt)
+    return image

+def clear_inputs_t2i():
+    return "", None

+def understand_image(image, prompt):
+    return generate_text(inference, image, prompt)

+def clear_inputs_i2t():
+    return None, ""

+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        <div align="center">

+        ### Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations

+        [📄 Paper](https://arxiv.org/abs/xxxx.xxxxx) • [💻 Code](https://github.com/csuhan/Tar) • [📦 Model](https://huggingface.co/csuhan/TA-Tok)

+        </div>
+        """,
+        elem_id="title",
+    )
+    with gr.Tab("Image Generation"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt")
+                with gr.Accordion("Advanced Settings", open=False):
+                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
+                    top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
+                    cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
+                with gr.Row():
+                    generate_btn = gr.Button("Generate")
+                    clear_btn = gr.Button("Clear")
+            with gr.Column(scale=2):
+                output_image = gr.Image(label="Generated Image")
+
+        generate_btn.click(
+            generate_image,
+            inputs=[prompt, top_p, top_k, cfg_scale],
+            outputs=output_image
+        )
+        clear_btn.click(
+            clear_inputs_t2i,
+            outputs=[prompt, output_image]
+        )
+
+    with gr.Tab("Image Understanding"):
+        with gr.Row():
             with gr.Column(scale=1):
+                image_input = gr.Image(label="Upload Image", type="pil")
+                question_input = gr.Textbox(label="Instruction", value="Describe the image shortly.")
                 with gr.Row():
+                    qa_btn = gr.Button("Generate")
+                    clear_btn_i2t = gr.Button("Clear")
+            with gr.Column(scale=1):
+                answer_output = gr.Textbox(label="Response", lines=4)

+        qa_btn.click(
+            understand_image,
+            inputs=[image_input, question_input],
+            outputs=answer_output
         )

+        clear_btn_i2t.click(
+            clear_inputs_i2t,
+            outputs=[image_input, question_input, answer_output]
         )

+demo.launch(share=True)
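
Side note for review, not part of the commit: a minimal sketch of how the text-to-image path wired into the new app.py could be exercised outside the Gradio UI. It reuses the setup code added above; it assumes `t2i_inference.py` from this repo is importable, that `HF_TOKEN` is set, and the prompt and output filename are placeholders.

```python
# Sketch only: same text-to-image path the "Image Generation" tab calls.
import os

from huggingface_hub import hf_hub_download, login
from t2i_inference import T2IConfig, TextToImageInference  # repo-local module

login(token=os.getenv("HF_TOKEN"))

# Checkpoint setup copied from the new app.py.
config = T2IConfig()
config.ar_path = hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_512px.pth")
config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")

# Same sampling defaults the Gradio sliders expose.
config.top_p = 0.95
config.top_k = 1200
config.cfg_scale = 4.0

inference = TextToImageInference(config)
image = inference.generate_image("a watercolor painting of a lighthouse at dusk")  # placeholder prompt
image.save("sample.png")
```

The understanding tab goes the other way: it passes the uploaded PIL image and instruction to `generate_text(inference, image, prompt)`, which tokenizes the image into `<I…>` codes and lets the language model answer.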
config/llama2/7B.json DELETED
@@ -1 +0,0 @@
-{"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": -1}

config/llama2/tokenizer.model DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
-size 499723

data/__pycache__/conversation_lib.cpython-310.pyc DELETED
Binary file (9.14 kB)
 
data/__pycache__/conversation_lib.cpython-39.pyc DELETED
Binary file (9.15 kB)
 
data/__pycache__/fintune_dataset.cpython-310.pyc DELETED
Binary file (14.2 kB)
 
data/__pycache__/fintune_dataset.cpython-39.pyc DELETED
Binary file (14.2 kB)
 
data/__pycache__/imu_utils.cpython-310.pyc DELETED
Binary file (6.71 kB)
 
data/__pycache__/imu_utils.cpython-39.pyc DELETED
Binary file (6.71 kB)
 
data/__pycache__/video_utils.cpython-310.pyc DELETED
Binary file (6.53 kB)
 
data/__pycache__/video_utils.cpython-39.pyc DELETED
Binary file (6.51 kB)
 
data/conversation_lib.py DELETED
@@ -1,369 +0,0 @@
1
- import dataclasses
2
- from enum import auto, Enum
3
- from typing import List, Tuple
4
-
5
-
6
- class SeparatorStyle(Enum):
7
- """Different separator style."""
8
- SINGLE = auto()
9
- TWO = auto()
10
- MPT = auto()
11
-
12
-
13
- @dataclasses.dataclass
14
- class Conversation:
15
- """A class that keeps all conversation history."""
16
- system: str
17
- roles: List[str]
18
- messages: List[List[str]]
19
- offset: int
20
- sep_style: SeparatorStyle = SeparatorStyle.SINGLE
21
- sep: str = "###"
22
- sep2: str = None
23
- version: str = "Unknown"
24
-
25
- skip_next: bool = False
26
-
27
- def get_prompt(self):
28
- if self.sep_style == SeparatorStyle.SINGLE:
29
- ret = self.system + '\n\n' + self.sep
30
- for role, message in self.messages:
31
- if message:
32
- if type(message) is tuple:
33
- message, _, _ = message
34
- ret += role + ": " + message + '\n' + self.sep
35
- else:
36
- ret += role + ":"
37
- return ret
38
- elif self.sep_style == SeparatorStyle.TWO:
39
- seps = [self.sep, self.sep2]
40
- ret = self.system + seps[0]
41
- for i, (role, message) in enumerate(self.messages):
42
- if message:
43
- if type(message) is tuple:
44
- message, _, _ = message
45
- ret += role + ": " + message + seps[i % 2]
46
- else:
47
- ret += role + ":"
48
- return ret
49
- if self.sep_style == SeparatorStyle.MPT:
50
- ret = self.system + self.sep
51
- for role, message in self.messages:
52
- if message:
53
- if type(message) is tuple:
54
- message, _, _ = message
55
- ret += role + message + self.sep
56
- else:
57
- ret += role
58
- return ret
59
- else:
60
- raise ValueError(f"Invalid style: {self.sep_style}")
61
-
62
- def append_message(self, role, message):
63
- self.messages.append([role, message])
64
-
65
- def get_images(self, return_pil=False):
66
- images = []
67
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
68
- if i % 2 == 0:
69
- if type(msg) is tuple:
70
- import base64
71
- from io import BytesIO
72
- from PIL import Image
73
- msg, image, image_process_mode = msg
74
- if image_process_mode == "Pad":
75
- def expand2square(pil_img, background_color=(122, 116, 104)):
76
- width, height = pil_img.size
77
- if width == height:
78
- return pil_img
79
- elif width > height:
80
- result = Image.new(pil_img.mode, (width, width), background_color)
81
- result.paste(pil_img, (0, (width - height) // 2))
82
- return result
83
- else:
84
- result = Image.new(pil_img.mode, (height, height), background_color)
85
- result.paste(pil_img, ((height - width) // 2, 0))
86
- return result
87
-
88
- image = expand2square(image)
89
- elif image_process_mode == "Crop":
90
- pass
91
- elif image_process_mode == "Resize":
92
- image = image.resize((224, 224))
93
- else:
94
- raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
95
- max_hw, min_hw = max(image.size), min(image.size)
96
- aspect_ratio = max_hw / min_hw
97
- max_len, min_len = 800, 400
98
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
99
- longest_edge = int(shortest_edge * aspect_ratio)
100
- W, H = image.size
101
- if H > W:
102
- H, W = longest_edge, shortest_edge
103
- else:
104
- H, W = shortest_edge, longest_edge
105
- image = image.resize((W, H))
106
- if return_pil:
107
- images.append(image)
108
- else:
109
- buffered = BytesIO()
110
- image.save(buffered, format="JPEG")
111
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
112
- images.append(img_b64_str)
113
- return images
114
-
115
- def to_gradio_chatbot(self):
116
- ret = []
117
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
118
- if i % 2 == 0:
119
- if type(msg) is tuple:
120
- import base64
121
- from io import BytesIO
122
- msg, image, image_process_mode = msg
123
- max_hw, min_hw = max(image.size), min(image.size)
124
- aspect_ratio = max_hw / min_hw
125
- max_len, min_len = 800, 400
126
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
127
- longest_edge = int(shortest_edge * aspect_ratio)
128
- W, H = image.size
129
- if H > W:
130
- H, W = longest_edge, shortest_edge
131
- else:
132
- H, W = shortest_edge, longest_edge
133
- image = image.resize((W, H))
134
- # image = image.resize((224, 224))
135
- buffered = BytesIO()
136
- image.save(buffered, format="JPEG")
137
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
138
- img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
139
- msg = msg.replace('<image>', img_str)
140
- ret.append([msg, None])
141
- else:
142
- ret[-1][-1] = msg
143
- return ret
144
-
145
- def copy(self):
146
- return Conversation(
147
- system=self.system,
148
- roles=self.roles,
149
- messages=[[x, y] for x, y in self.messages],
150
- offset=self.offset,
151
- sep_style=self.sep_style,
152
- sep=self.sep,
153
- sep2=self.sep2)
154
-
155
- def dict(self):
156
- if len(self.get_images()) > 0:
157
- return {
158
- "system": self.system,
159
- "roles": self.roles,
160
- "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
161
- "offset": self.offset,
162
- "sep": self.sep,
163
- "sep2": self.sep2,
164
- }
165
- return {
166
- "system": self.system,
167
- "roles": self.roles,
168
- "messages": self.messages,
169
- "offset": self.offset,
170
- "sep": self.sep,
171
- "sep2": self.sep2,
172
- }
173
-
174
-
175
- conv_v1 = Conversation(
176
- system="A chat between a curious human and an artificial intelligence assistant. "
177
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
178
- roles=("Human", "Assistant"),
179
- messages=(
180
- ("Human", "Give three tips for staying healthy."),
181
- ("Assistant",
182
- "Sure, here are three tips for staying healthy:\n"
183
- "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
184
- "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
185
- "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
186
- "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
187
- "activities at least two days per week.\n"
188
- "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
189
- "vegetables, whole grains, lean proteins, and healthy fats can help support "
190
- "your overall health. Try to limit your intake of processed and high-sugar foods, "
191
- "and aim to drink plenty of water throughout the day.\n"
192
- "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
193
- "and mental health. Adults should aim for seven to nine hours of sleep per night. "
194
- "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
195
- "help improve the quality of your sleep.")
196
- ),
197
- offset=2,
198
- sep_style=SeparatorStyle.SINGLE,
199
- sep="###",
200
- )
201
-
202
- conv_v1_2 = Conversation(
203
- system="A chat between a curious human and an artificial intelligence assistant. "
204
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
205
- roles=("Human", "Assistant"),
206
- messages=(),
207
-
208
- # (
209
- # ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
210
- # ("Assistant",
211
- # "Renewable energy sources are those that can be replenished naturally in a relatively "
212
- # "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
213
- # "Non-renewable energy sources, on the other hand, are finite and will eventually be "
214
- # "depleted, such as coal, oil, and natural gas. Here are some key differences between "
215
- # "renewable and non-renewable energy sources:\n"
216
- # "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
217
- # "energy sources are finite and will eventually run out.\n"
218
- # "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
219
- # "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
220
- # "and other negative effects.\n"
221
- # "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
222
- # "have lower operational costs than non-renewable sources.\n"
223
- # "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
224
- # "locations than non-renewable sources.\n"
225
- # "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
226
- # "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
227
- # "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
228
- # "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
229
- # )
230
- offset = 2,
231
- sep_style = SeparatorStyle.SINGLE,
232
- sep = "###",
233
- )
234
-
235
- conv_vicuna_v1_1 = Conversation(
236
- system="A chat between a curious user and an artificial intelligence assistant. "
237
- "The assistant gives helpful, detailed, and polite answers to the user's questions.",
238
- roles=("USER", "ASSISTANT"),
239
- version="v1",
240
- messages=(),
241
- offset=0,
242
- sep_style=SeparatorStyle.TWO,
243
- sep=" ",
244
- sep2="</s>",
245
- )
246
-
247
- conv_mpt = Conversation(
248
- system="""<|im_start|>system
249
- - You are a helpful language and vision assistant.
250
- - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
251
- - You should follow the instructions carefully and explain your answers in detail.""",
252
- roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
253
- version="mpt",
254
- messages=(),
255
- offset=0,
256
- sep_style=SeparatorStyle.MPT,
257
- sep="<|im_end|>",
258
- )
259
-
260
- conv_mpt_text = Conversation(
261
- system="""<|im_start|>system
262
- - You are a helpful assistant chatbot trained by MosaicML.
263
- - You answer questions.
264
- - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
265
- - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
266
- roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
267
- version="mpt",
268
- messages=(),
269
- offset=0,
270
- sep_style=SeparatorStyle.MPT,
271
- sep="<|im_end|>",
272
- )
273
-
274
- conv_bair_v1 = Conversation(
275
- system="BEGINNING OF CONVERSATION:",
276
- roles=("USER", "GPT"),
277
- messages=(),
278
- offset=0,
279
- sep_style=SeparatorStyle.TWO,
280
- sep=" ",
281
- sep2="</s>",
282
- )
283
-
284
- simple_conv = Conversation(
285
- system="A chat between a curious human and an artificial intelligence assistant. "
286
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
287
- roles=("Human", "Assistant"),
288
- messages=(
289
- ("Human", "Hi!"),
290
- ("Assistant", "Hi there! How can I help you today?")
291
- ),
292
- offset=2,
293
- sep_style=SeparatorStyle.SINGLE,
294
- sep="###",
295
- )
296
-
297
- simple_conv_multimodal = Conversation(
298
- system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
299
- "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
300
- "Follow the instructions carefully and explain your answers in detail.",
301
- roles=("Human", "Assistant"),
302
- messages=(
303
- ("Human", "Hi!"),
304
- ("Assistant", "Hi there! How can I help you today?\n")
305
- ),
306
- offset=2,
307
- sep_style=SeparatorStyle.SINGLE,
308
- sep="###",
309
- )
310
-
311
- simple_conv_mpt_multimodal = Conversation(
312
- system="""<|im_start|>system
313
- - You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
314
- - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
315
- - You should follow the instructions carefully and explain your answers in detail.""",
316
- roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
317
- version="mpt",
318
- messages=(),
319
- offset=0,
320
- sep_style=SeparatorStyle.MPT,
321
- sep="<|im_end|>",
322
- )
323
-
324
- simple_conv_legacy = Conversation(
325
- system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
326
- "You are designed to assist human with a variety of tasks using natural language."
327
- "Follow the instructions carefully.",
328
- roles=("Human", "Assistant"),
329
- messages=(
330
- ("Human", "Hi!\n\n### Response:"),
331
- ("Assistant", "Hi there! How can I help you today?\n")
332
- ),
333
- offset=2,
334
- sep_style=SeparatorStyle.SINGLE,
335
- sep="###",
336
- )
337
-
338
- conv_llava_v1 = Conversation(
339
- system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
340
- "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
341
- "Follow the instructions carefully and explain your answers in detail.",
342
- roles=("USER", "ASSISTANT"),
343
- version="v1",
344
- messages=(),
345
- offset=0,
346
- sep_style=SeparatorStyle.TWO,
347
- sep=" ",
348
- sep2="</s>",
349
- )
350
-
351
- default_conversation = conv_v1_2
352
- conv_templates = {
353
- "default": conv_v1_2,
354
- "simple": simple_conv,
355
- "simple_legacy": simple_conv_legacy,
356
- "multimodal": simple_conv_multimodal,
357
- "mpt_multimodal": simple_conv_mpt_multimodal,
358
- "llava_v1": conv_llava_v1,
359
-
360
- # fastchat
361
- "v1": conv_v1_2,
362
- "bair_v1": conv_bair_v1,
363
- "vicuna_v1_1": conv_vicuna_v1_1,
364
- "mpt": conv_mpt,
365
- "mpt_text": conv_mpt_text,
366
- }
367
-
368
- if __name__ == "__main__":
369
- print(default_conversation.get_prompt())
data/fintune_dataset.py DELETED
@@ -1,449 +0,0 @@
1
- import warnings
2
-
3
- import torch
4
- import yaml
5
- from torch.utils.data import Dataset
6
- from PIL import Image
7
- import json
8
- from model.tokenizer import Tokenizer
9
- import os
10
- import torchvision.transforms as transforms
11
- import random
12
- import torchvision.transforms.functional as F
13
- import torchaudio
14
- from . import conversation_lib
15
-
16
- import numpy as np
17
- from . import video_utils
18
- from .imu_utils import get_imu_frames
19
-
20
-
21
- IGNORE_INDEX = -100
22
-
23
- DEFAULT_IMAGE_TOKEN = "<image>"
24
- try:
25
- from torchvision.transforms import InterpolationMode
26
-
27
- BICUBIC = InterpolationMode.BICUBIC
28
- except ImportError:
29
- BICUBIC = Image.BICUBIC
30
-
31
- T_random_resized_crop = transforms.Compose([
32
- transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=BICUBIC,
33
- antialias=None), # 3 is bicubic
34
- transforms.ToTensor(),
35
- transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
36
-
37
-
38
- # image transform
39
- transform_img_train = transforms.Compose([
40
- transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
41
- 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic
42
- transforms.ToTensor(),
43
- transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
44
-
45
-
46
- class PairRandomResizedCrop(transforms.RandomResizedCrop):
47
- def forward(self, imgs):
48
- i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
49
- return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) for img in imgs]
50
-
51
-
52
- class PairToTensor(transforms.ToTensor):
53
- def __call__(self, pics):
54
- return [F.to_tensor(pic) for pic in pics]
55
-
56
-
57
- class PairNormalize(transforms.Normalize):
58
- def forward(self, tensors):
59
- return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors]
60
-
61
-
62
- transform_pairimg_train = transforms.Compose([
63
- PairRandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
64
- 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic
65
- PairToTensor(),
66
- PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
67
-
68
-
69
- def pc_norm(pc):
70
- """ pc: NxC, return NxC """
71
- xyz = pc[:, :3]
72
- other_feature = pc[:, 3:]
73
-
74
- centroid = torch.mean(xyz, dim=0)
75
- xyz = xyz - centroid
76
- m = torch.max(torch.sqrt(torch.sum(xyz ** 2, dim=1)))
77
- xyz = xyz / m
78
-
79
- pc = torch.cat((xyz, other_feature), dim=1)
80
- return pc
81
-
82
-
83
- def make_audio_features(wav_name, mel_bins=128, target_length=1024, aug=False):
84
- waveform, sr = torchaudio.load(wav_name)
85
- # assert sr == 16000, 'input audio sampling rate must be 16kHz'
86
- if sr != 16000:
87
- trans = torchaudio.transforms.Resample(sr, 16000)
88
- waveform = trans(waveform)
89
-
90
- waveform = waveform - waveform.mean()
91
-
92
- fbank = torchaudio.compliance.kaldi.fbank(
93
- waveform, htk_compat=True, sample_frequency=16000, use_energy=False,
94
- window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)
95
-
96
- n_frames = fbank.shape[0]
97
-
98
- p = target_length - n_frames
99
- if p > 0:
100
- m = torch.nn.ZeroPad2d((0, 0, 0, p))
101
- fbank = m(fbank)
102
- elif p < 0:
103
- fbank = fbank[0:target_length, :]
104
-
105
- if aug:
106
- freqm = torchaudio.transforms.FrequencyMasking(48)
107
- timem = torchaudio.transforms.TimeMasking(192)
108
- fbank = torch.transpose(fbank, 0, 1)
109
- fbank = fbank.unsqueeze(0)
110
- fbank = freqm(fbank)
111
- fbank = timem(fbank)
112
- fbank = fbank.squeeze(0)
113
- fbank = torch.transpose(fbank, 0, 1)
114
-
115
- fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
116
- return fbank
117
-
118
-
119
- class ConversationGenerator:
120
- def __init__(self, tokenizer):
121
- self.tokenizer = tokenizer
122
- self.header = f"{conversation_lib.default_conversation.system}\n\n"
123
- self._probe_tokenizer_style()
124
-
125
- def _probe_tokenizer_style(self):
126
- """
127
- Given a sentence, e.g. "My darling", some tokenizers will make the space a seperate token,
128
- while some others will merge the space into the next word, forming a token representing " darling".
129
- Knowing which style the tokenizer takes is necessary for correct ground-truth label masking.
130
-
131
- """
132
- probe = "Probe am I"
133
- sentence1 = self.tokenizer.encode(conversation_lib.default_conversation.roles[1] + ": " + probe,
134
- bos=False, eos=False)
135
- sentence2 = self.tokenizer.encode(probe,
136
- bos=False, eos=False)
137
- if sentence1[-len(sentence2):] == sentence2:
138
- self.space_before_to_predict = False
139
- else:
140
- sentence3 = self.tokenizer.encode(" " + probe,
141
- bos=False, eos=False)
142
- assert sentence1[-len(sentence3):] == sentence3
143
- self.space_before_to_predict = True
144
-
145
- def add_speaker_and_signal(self, source, get_conversation=True):
146
- """Add speaker and start/end signal on each round."""
147
- BEGIN_SIGNAL = "### "
148
- END_SIGNAL = "\n"
149
- conversation = self.header
150
-
151
- to_predict_list = []
152
-
153
- for sentence in source:
154
- from_str = sentence["from"]
155
- if from_str.lower() in ["human"]:
156
- from_str = conversation_lib.default_conversation.roles[0]
157
- elif from_str.lower() in ["gpt", "assistant"]:
158
- from_str = conversation_lib.default_conversation.roles[1]
159
- else:
160
- raise ValueError(f"unknown dialog role: {from_str.lower()}")
161
-
162
- value = sentence["value"]
163
- if DEFAULT_IMAGE_TOKEN in value:
164
- value = value.replace(DEFAULT_IMAGE_TOKEN, '').strip()
165
-
166
- sentence_value = BEGIN_SIGNAL + from_str + ": " + value + END_SIGNAL
167
-
168
- if from_str == conversation_lib.default_conversation.roles[1]:
169
- to_predict_value = value + END_SIGNAL + "###"
170
- if self.space_before_to_predict:
171
- to_predict_value = " " + to_predict_value
172
- to_predict_list.append(to_predict_value)
173
-
174
- if get_conversation:
175
- conversation = conversation + sentence_value
176
-
177
- conversation = conversation + BEGIN_SIGNAL
178
- return conversation, to_predict_list
179
-
180
-
181
- DATASETS = dict(
182
- image=[
183
- dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_image.json", type='image'),
184
- dict(path='datasets/InstructionTuning/image/cococap_train.json', type='image'),
185
- dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_text.json", type='text'),
186
- ],
187
- audio=[
188
- dict(path="datasets/InstructionTuning/audio/audiocap_train.json", type='audio'),
189
- dict(path="datasets/InstructionTuning/audio/audiocap_val.json", type='audio'),
190
- dict(path="datasets/InstructionTuning/audio/audio_conversation.json", type='audio'),
191
- ],
192
- video=[
193
- dict(path="datasets/InstructionTuning/video/msrvtt_cap_trainval.json", type='video'),
194
- dict(path="datasets/InstructionTuning/video/msrvtt_cap_test.json", type='video'),
195
- dict(path="datasets/InstructionTuning/video/msrvtt_vqa_train.json", type='video'),
196
- dict(path="datasets/InstructionTuning/video/msrvtt_vqa_val.json", type='video'),
197
- dict(path="datasets/InstructionTuning/video/msrvtt_vqa_test.json", type='video'),
198
- dict(path="datasets/InstructionTuning/video/video_complex_reasoning_10k.json", type='video'),
199
- dict(path="datasets/InstructionTuning/video/video_conversation_10k.json", type='video'),
200
- dict(path="datasets/InstructionTuning/video/video_detail_10k.json", type='video'),
201
- ],
202
- point=[
203
- dict(path="datasets/InstructionTuning/point/pointllm_70k_formated.json", type='point'),
204
- ],
205
- rgbd=[
206
- dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_depth.json", type='rgbd'),
207
- ],
208
- rgbn=[
209
- dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_normal.json", type='rgbn'),
210
- ],
211
- imu=[
212
- dict(path="datasets/InstructionTuning/imu/imu_fixed_50k.json", type='imu'),
213
- ],
214
- fmri=[
215
- dict(path="datasets/InstructionTuning/fmri/fmri_fixed.json", type='fmri'),
216
- ],
217
- )
218
- IMU_PATH = "/mnt/petrelfs/share_data/hanjiaming/ego4d/v2/processed_imu/"
219
-
220
-
221
- class FinetuneDialogDataset(Dataset):
222
- def __init__(self, dataset=['image'], transform=T_random_resized_crop, max_words=2048, image_words=30, tokenizer_path=None):
223
- if isinstance(dataset, str):
224
- dataset = [dataset]
225
-
226
- self.dataset = dataset
227
-
228
- group_ann = {}
229
- for d in dataset:
230
- for meta in DATASETS[d]:
231
- meta_path, meta_type = meta['path'], meta['type']
232
- meta_ext = os.path.splitext(meta_path)[-1]
233
- if meta_ext == ".json":
234
- with open(meta_path) as f:
235
- meta_l = json.load(f)
236
- # add data_type
237
- # this is a temp solution
238
- new_meta_l = []
239
- for l in meta_l:
240
- l['data_type'] = meta_type
241
- new_meta_l.append(l)
242
- meta_l = new_meta_l
243
- elif meta_ext == ".jsonl":
244
- meta_l = []
245
- with open(meta_path) as f:
246
- for i, line in enumerate(f):
247
- try:
248
- meta_l.append(json.loads(line))
249
- except json.decoder.JSONDecodeError as e:
250
- print(
251
- f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}", force=True)
252
- raise e
253
- else:
254
- raise NotImplementedError(
255
- f"Unknown meta file extension: \"{meta_ext}\". "
256
- f"Currently, .json, .jsonl are supported. "
257
- "If you are using a supported format, please set the file extension so that the proper parsing "
258
- "routine can be called."
259
- )
260
- if meta_type not in group_ann:
261
- group_ann[meta_type] = []
262
- print(f"{meta_path}, type {meta_type}: len {len(meta_l)}")
263
- group_ann[meta_type] += meta_l
264
-
265
- # sort group_ann for higher efficiency (items in one global batch with similar length)
266
- for meta_type, meta_l in group_ann.items():
267
- meta_l.sort(key=lambda data_item: sum(
268
- [len(_['value']) for _ in data_item['conversations']]))
269
-
270
- self.group_ann = group_ann
271
- self.ann = sum(list(self.group_ann.values()), start=[])
272
-
273
- self.group_indices = {}
274
- start_pos = 0
275
- for meta_type, meta_l in self.group_ann.items():
276
- self.group_indices[meta_type] = list(
277
- range(start_pos, start_pos + len(meta_l)))
278
- start_pos = start_pos + len(meta_l)
279
-
280
- print(f"total length: {len(self)}")
281
- self.transform = transform
282
- print(f"transform:\n{self.transform}")
283
- self.max_words = max_words
284
- self.image_words = image_words
285
- self.tokenizer = Tokenizer(model_path=tokenizer_path)
286
- self.conversation_generator = ConversationGenerator(self.tokenizer)
287
-
288
- self.load_funcs = dict(
289
- image=self.load_image,
290
- audio=self.load_audio,
291
- video=self.load_video,
292
- point=self.load_point,
293
- rgbd=self.load_rgbx,
294
- rgbn=self.load_rgbx,
295
- imu=self.load_imu,
296
- fmri=self.load_fmri
297
- )
298
-
299
- def __len__(self):
300
- return len(self.ann)
301
-
302
- def load_image(self, data):
303
- filename = data['image']
304
- image = Image.open(filename).convert('RGB')
305
- image = self.transform(image)
306
- return image
307
-
308
- def load_audio(self, data):
309
- audio_path = data['image']
310
- fbank = make_audio_features(audio_path, mel_bins=128)
311
- fbank = fbank.transpose(0, 1)[None] # [1, 128, 1024]
312
- return fbank
313
-
314
- def load_video(self, data):
315
- video_path = data['image']
316
- video_feats = video_utils.load_and_transform_video_data(
317
- video_path, video_path, clip_duration=1, clips_per_video=5)
318
- return video_feats[:, :, 0]
319
-
320
- def load_point(self, data):
321
- point_path = data['image']
322
- point_feat = torch.load(point_path, map_location='cpu')
323
- point_feat = point_feat.transpose(0, 1)
324
- return point_feat
325
-
326
- def load_rgbx(self, data):
327
- image_path = data['image']
328
- x_image_path = data['depth_image'] if 'depth_image' in data else data['normal_image']
329
- image = Image.open(image_path).convert('RGB')
330
- x_image = Image.open(x_image_path).convert('RGB')
331
- x_image = x_image.resize(image.size[-2:])
332
-
333
- image, x_image = transform_pairimg_train([image, x_image])
334
- # [2, 3, H, W]
335
- image = torch.stack([image, x_image], dim=0)
336
- return image
337
-
338
-     def load_fmri(self, data):
-         fmri_path = data['image']
-         data = np.load(fmri_path)
-         data = data.mean(axis=0)
-         data = torch.tensor(data[None])
-         return data
-
-     def load_imu(self, data_dict):
-         uid = data_dict["video_uid"]
-         w_s = data_dict["window_start"]
-         w_e = data_dict["window_end"]
-
-         imu_data = get_imu_frames(
-             IMU_PATH, uid,
-             video_start_sec=w_s,
-             video_end_sec=w_e,
-         )
-         if imu_data is None:
-             raise ValueError
-         return imu_data['signal']
-
-     def __getitem__(self, index, expect_type=None):
-         if expect_type is None:
-             data_item = self.ann[index]
-         else:
-             # in case we want to get data of a specific data_type
-             data_item = self.group_ann[expect_type][index]
-
-         data_type = data_item['data_type']
-         if data_type != 'text':
-             if data_type in self.load_funcs:
-                 try:
-                     image = self.load_funcs[data_type](data_item)
-                     if image is None:
-                         raise ValueError('Data is None')
-                 except Exception:
-                     print('Error', data_item)
-                     rand_idx = random.randint(
-                         0, len(self.group_ann[data_type]) - 1)
-                     return self.__getitem__(rand_idx, expect_type=data_type)
-             else:
-                 raise ValueError(f'Does not support {data_type}')
-         else:
-             image = None
-             # warnings.warn("pure black image for examples without image")
-             # image = torch.zeros(3, 224, 224)
-
-         source = data_item["conversations"]
-         conversation, to_predict_values = self.conversation_generator.add_speaker_and_signal(
-             source)
-         if len(to_predict_values) == 0:
-             warnings.warn(
-                 f"see dialog data with nothing to predict, data: {data_item}")
-             return self[index - 1]
-
-         tokenized_conversation = self.tokenizer.encode(
-             conversation, bos=True, eos=True)
-         labels = [IGNORE_INDEX for _ in tokenized_conversation]
-
-         check_pos = 0
-         for value in to_predict_values:
-             tokenized_value = self.tokenizer.encode(
-                 value, bos=False, eos=False)
-             rel_pos = find_sublist(
-                 tokenized_conversation[check_pos:], tokenized_value)
-             if rel_pos == -1:
-                 print(
-                     "a sentence mismatches the corresponding piece in the conversation")
-                 return self[index - 1]
-             value_pos = rel_pos + check_pos
-             labels[value_pos:value_pos + len(tokenized_value)] = tokenized_value
-             assert labels[value_pos:value_pos + len(
-                 tokenized_value)] == tokenized_conversation[value_pos:value_pos + len(tokenized_value)]
-             check_pos = value_pos + len(tokenized_value)
-
-         input2 = torch.tensor(tokenized_conversation, dtype=torch.int64)
-         labels = torch.tensor(labels, dtype=torch.int64)
-
-         if image is not None:
-             max_words = self.max_words - self.image_words
-         else:
-             max_words = self.max_words
-         padding = max_words - input2.shape[0]
-         if padding > 0:
-             input2 = torch.cat(
-                 (input2, torch.zeros(padding, dtype=torch.int64) - 1))
-             labels = torch.cat(
-                 (labels, torch.zeros(padding, dtype=torch.int64) - 1))
-         elif padding < 0:
-             input2 = input2[:max_words]
-             labels = labels[:max_words]
-
-         input2_mask = input2.ge(0)
-         label_mask = labels.ge(0)
-         input2[~input2_mask] = 0
-         labels[~label_mask] = 0
-         input2_mask = input2_mask.float()
-         label_mask = label_mask.float()
-         if image is None:
-             return input2, labels, data_item['data_type']
-         else:
-             return input2, labels, image, data_item['data_type']
-
-     def groups(self):
-         return list(self.group_indices.values())
-
-
- def find_sublist(a: list, b: list):
-     len_a, len_b = len(a), len(b)
-     for i in range(len_a - len_b + 1):
-         if a[i:i + len_b] == b:
-             return i
-     return -1
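Editor's note: the deleted `__getitem__` masks out everything except the assistant replies by locating each tokenized answer inside the tokenized conversation with `find_sublist`. Below is a minimal, self-contained sketch of that masking step, using toy integer "tokens" instead of a real tokenizer; the `IGNORE_INDEX = -100` value is an assumption for illustration, not taken from the deleted file.

```python
IGNORE_INDEX = -100  # assumed placeholder; the real constant lives elsewhere in the repo

def find_sublist(a: list, b: list) -> int:
    """Return the start index of sublist b inside a, or -1 if absent."""
    len_a, len_b = len(a), len(b)
    for i in range(len_a - len_b + 1):
        if a[i:i + len_b] == b:
            return i
    return -1

# toy "tokenized" conversation: prompt tokens, answer tokens, prompt, answer
conversation = [1, 5, 6, 7, 100, 101, 8, 9, 102, 103, 2]
answers = [[100, 101], [102, 103]]            # spans the model should predict

labels = [IGNORE_INDEX] * len(conversation)
check_pos = 0
for ans in answers:
    pos = find_sublist(conversation[check_pos:], ans)
    if pos == -1:                             # answer not found verbatim
        raise ValueError("answer not found in conversation")
    pos += check_pos
    labels[pos:pos + len(ans)] = ans          # supervise only the answer tokens
    check_pos = pos + len(ans)

print(labels)  # [-100, -100, -100, -100, 100, 101, -100, -100, 102, 103, -100]
```

Everything outside the answer spans keeps the ignore value, so the loss is computed only on the tokens the model is supposed to generate.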
 
data/imu_utils.py DELETED
@@ -1,257 +0,0 @@
- import string
- import numpy as np
- import matplotlib.animation as animation
- from matplotlib import pyplot as plt
- import json
- from collections import defaultdict
- from bisect import bisect_left
- import os
- import torch
- import torchaudio
- torchaudio.set_audio_backend("sox_io")
-
-
- def load_json(json_path: str):
-     """
-     Load a json file
-     """
-     with open(json_path, "r", encoding="utf-8") as f_name:
-         data = json.load(f_name)
-     return data
-
-
- def check_window_signal(info_t, w_s, w_e):
-     length = w_e - w_s
-     frame_offset = int(w_s * info_t.sample_rate)
-     num_frames = int(length * info_t.sample_rate)
-     if frame_offset + num_frames > int(info_t.num_frames):
-         return False
-     else:
-         return True
-
-
- def index_narrations(ann_path):
-     narration_raw = load_json(ann_path)
-
-     narration_dict = defaultdict(list)
-     summary_dict = defaultdict(list)
-     avg_len = []
-     for v_id, narr in narration_raw.items():
-         narr_list = []
-         summ_list = []
-         if "narration_pass_1" in narr:
-             narr_list += narr["narration_pass_1"]["narrations"]
-             summ_list += narr["narration_pass_1"]["summaries"]
-         if "narration_pass_2" in narr:
-             narr_list += narr["narration_pass_2"]["narrations"]
-             summ_list += narr["narration_pass_2"]["summaries"]
-
-         if len(narr_list) > 0:
-             narration_dict[v_id] = [
-                 (
-                     float(n_t["timestamp_sec"]),
-                     n_t["narration_text"],
-                     n_t["annotation_uid"],
-                     n_t["timestamp_frame"],
-                 )
-                 for n_t in narr_list
-             ]
-             avg_len.append(len(narration_dict[v_id]))
-         else:
-             narration_dict[v_id] = []
-         if len(summ_list) > 0:
-             summary_dict[v_id] = [
-                 (
-                     float(s_t["start_sec"]),
-                     float(s_t["end_sec"]),
-                     s_t["summary_text"],
-                 )
-                 for s_t in summ_list
-             ]
-         else:
-             summary_dict[v_id] = []
-     # print(f"Number of Videos with narration {len(narration_dict)}")
-     # print(f"Avg. narration length {np.mean(avg_len)}")
-     # print(f"Number of Videos with summaries {len(summary_dict)}")
-     return narration_dict, summary_dict
-
-
- def get_signal_info(signal_fn: str):
-     return torchaudio.info(signal_fn)
-
-
- def get_signal_frames(signal_fn: str, video_start_sec: float, video_end_sec: float):
-     """
-     Given a signal track return the frames between video_start_sec and video_end_sec
-     """
-     info_t = get_signal_info(signal_fn)
-
-     length = video_end_sec - video_start_sec
-     aframes, _ = torchaudio.load(
-         signal_fn,
-         normalize=True,
-         frame_offset=int(video_start_sec * info_t.sample_rate),
-         num_frames=int(length * info_t.sample_rate),
-     )
-     return {"signal": aframes, "meta": info_t}
-
-
- def tosec(value):
-     return value / 1000
-
-
- def toms(value):
-     return value * 1000
-
-
- def delta(first_num: float, second_num: float):
-     """Compute the absolute value of the difference of two numbers"""
-     return abs(first_num - second_num)
-
-
- def padIMU(signal, duration_sec):
-     """
-     Pad or truncate the signal to round(duration_sec) * 200 samples
-     """
-     expected_elements = round(duration_sec) * 200
-
-     if signal.shape[0] > expected_elements:
-         signal = signal[:expected_elements, :]
-     elif signal.shape[0] < expected_elements:
-         padding = expected_elements - signal.shape[0]
-         padded_zeros = np.zeros((padding, 6))
-         signal = np.concatenate([signal, padded_zeros], 0)
-     return signal
-
-
- def resample(
-     signals: np.ndarray,
-     timestamps: np.ndarray,
-     original_sample_rate: int,
-     resample_rate: int,
- ):
-     """
-     Resamples data to new sample rate
-     """
-     signals = torch.as_tensor(signals)
-     timestamps = torch.from_numpy(timestamps).unsqueeze(-1)
-     signals = torchaudio.functional.resample(
-         waveform=signals.data.T,
-         orig_freq=original_sample_rate,
-         new_freq=resample_rate,
-     ).T.numpy()
-
-     nsamples = len(signals)
-
-     period = 1 / resample_rate
-
-     # timestamps are expected to be shape (N, 1)
-     initial_seconds = timestamps[0] / 1e3
-
-     ntimes = (torch.arange(nsamples) * period).view(-1, 1) + initial_seconds
-
-     timestamps = (ntimes * 1e3).squeeze().numpy()
-     return signals, timestamps
-
-
- def resampleIMU(signal, timestamps):
-     sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps)))))
-     # resample everything to 200 Hz
-     if sampling_rate != 200:
-         signal, timestamps = resample(signal, timestamps, sampling_rate, 200)
-     return signal, timestamps
-
-
- def get_imu_frames(
-     imu_path,
-     uid: str,
-     video_start_sec: float,
-     video_end_sec: float,
- ):
-     """
-     Given an IMU signal return the frames between video_start_sec and video_end_sec
-     """
-     signal = np.load(os.path.join(imu_path, f"{uid}.npy"))
-     signal = signal.transpose()
-     timestamps = np.load(os.path.join(imu_path, f"{uid}_timestamps.npy"))
-
-     if toms(video_start_sec) > timestamps[-1] or toms(video_end_sec) > timestamps[-1]:
-         return None
-
-     start_id = bisect_left(timestamps, toms(video_start_sec))
-     end_id = bisect_left(timestamps, toms(video_end_sec))
-
-     # make sure the retrieved window is within a 4-second margin of the request
-     if (
-         delta(video_start_sec, tosec(timestamps[start_id])) > 4
-         or delta(video_end_sec, tosec(timestamps[end_id])) > 4
-     ):
-         return None
-
-     # get the window
-     if start_id == end_id:
-         start_id -= 1
-         end_id += 1
-     signal, timestamps = signal[start_id:end_id], timestamps[start_id:end_id]
-
-     if len(signal) < 10 or len(timestamps) < 10:
-         return None
-     # resample the signal to 200 Hz if necessary
-     signal, timestamps = resampleIMU(signal, timestamps)
-
-     # pad the signal if necessary
-     signal = padIMU(signal, video_end_sec - video_start_sec)
-
-     sample_dict = {
-         "timestamp": timestamps,
-         "signal": torch.tensor(signal.T),
-         "sampling_rate": 200,
-     }
-
-     return sample_dict
-
-
- def display_animation(frames, title, save_path_gif):
-     fig, ax = plt.subplots()
-     frames = [[ax.imshow(frames[i])] for i in range(len(frames))]
-     plt.title(title)
-     ani = animation.ArtistAnimation(fig, frames)
-     ani.save(save_path_gif, writer="imagemagick")
-     plt.close()
-
-
- def display_animation_imu(frames, imu, title, save_path_gif):
-     fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
-     ax1.set_title(title)
-     ax2.set_title("Acc.")
-     ax3.set_title("Gyro.")
-     frames = [[ax1.imshow(frames[i])] for i in range(len(frames))]
-     ani = animation.ArtistAnimation(fig, frames)
-
-     ax2.plot(imu[0].cpu().numpy(), color="red")
-     ax2.plot(imu[1].cpu().numpy(), color="blue")
-     ax2.plot(imu[2].cpu().numpy(), color="green")
-     ax3.plot(imu[3].cpu().numpy(), color="red")
-     ax3.plot(imu[4].cpu().numpy(), color="blue")
-     ax3.plot(imu[5].cpu().numpy(), color="green")
-     plt.tight_layout()
-     ani.save(save_path_gif, writer="imagemagick")
-     plt.close()
-
-
- def filter_narration(narration_text: str) -> bool:
-     if "#c" in narration_text.lower():
-         return True
-     return False
-
-
- def clean_narration_text(narration_text: str) -> str:
-     return (
-         narration_text.replace("#C C ", "")
-         .replace("#C", "")
-         .replace("#unsure", "something")
-         .strip()
-         .strip(string.punctuation)
-         .lower()[:128]
-     )
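Editor's note: the deleted helpers normalize every IMU window to a fixed 200 Hz, 6-channel layout: `resampleIMU` estimates the source rate from timestamp gaps and resamples, and `padIMU` pads or truncates to `round(duration) * 200` samples. Below is a small sketch of that pipeline on synthetic data, so no Ego4D files are needed; the 160 Hz source rate is an arbitrary example value, not something taken from the repo.

```python
import numpy as np
import torch
import torchaudio

duration_sec = 2.0
orig_rate, target_rate = 160, 200          # example source rate; the target is fixed at 200 Hz
signal = np.random.randn(int(duration_sec * orig_rate), 6)   # (T, 6): accelerometer + gyroscope

# resample each channel to 200 Hz, as resample() above does via torchaudio
resampled = torchaudio.functional.resample(
    torch.as_tensor(signal, dtype=torch.float32).T,  # (6, T)
    orig_freq=orig_rate, new_freq=target_rate,
).T.numpy()                                           # back to (T', 6)

# pad/truncate to exactly round(duration) * 200 samples, as padIMU() does
expected = round(duration_sec) * 200
if resampled.shape[0] < expected:
    resampled = np.concatenate([resampled, np.zeros((expected - resampled.shape[0], 6))], 0)
else:
    resampled = resampled[:expected]

print(resampled.shape)  # (400, 6); get_imu_frames() stores it transposed as a (6, 400) tensor
```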
 
data/video_utils.py DELETED
@@ -1,204 +0,0 @@
- import math
- import torch
- import torch.nn as nn
- from pytorchvideo import transforms as pv_transforms
- from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
- from pytorchvideo.data.encoded_video import EncodedVideo
- from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord
- from torchvision import transforms
- from torchvision.transforms._transforms_video import NormalizeVideo
-
-
- def get_clip_timepoints(clip_sampler, duration):
-     # Read out all clips in this video
-     all_clips_timepoints = []
-     is_last_clip = False
-     end = 0.0
-     while not is_last_clip:
-         start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
-         all_clips_timepoints.append((start, end))
-     return all_clips_timepoints
-
-
- def crop_boxes(boxes, x_offset, y_offset):
-     """
-     Perform crop on the bounding boxes given the offsets.
-     Args:
-         boxes (ndarray or None): bounding boxes to perform crop. The dimension
-             is `num boxes` x 4.
-         x_offset (int): cropping offset in the x axis.
-         y_offset (int): cropping offset in the y axis.
-     Returns:
-         cropped_boxes (ndarray or None): the cropped boxes with dimension of
-             `num boxes` x 4.
-     """
-     cropped_boxes = boxes.copy()
-     cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
-     cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
-
-     return cropped_boxes
-
-
- def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
-     """
-     Perform uniform spatial sampling on the images and corresponding boxes.
-     Args:
-         images (tensor): images to perform uniform crop. The dimension is
-             `num frames` x `channel` x `height` x `width`.
-         size (int): size of height and width to crop the images.
-         spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
-             is larger than height. Or 0, 1, or 2 for top, center, and bottom
-             crop if height is larger than width.
-         boxes (ndarray or None): optional. Corresponding boxes to images.
-             Dimension is `num boxes` x 4.
-         scale_size (int): optional. If not None, resize the images to scale_size before
-             performing any crop.
-     Returns:
-         cropped (tensor): images with dimension of
-             `num frames` x `channel` x `size` x `size`.
-         cropped_boxes (ndarray or None): the cropped boxes with dimension of
-             `num boxes` x 4.
-     """
-     assert spatial_idx in [0, 1, 2]
-     ndim = len(images.shape)
-     if ndim == 3:
-         images = images.unsqueeze(0)
-     height = images.shape[2]
-     width = images.shape[3]
-
-     if scale_size is not None:
-         if width <= height:
-             width, height = scale_size, int(height / width * scale_size)
-         else:
-             width, height = int(width / height * scale_size), scale_size
-         images = torch.nn.functional.interpolate(
-             images,
-             size=(height, width),
-             mode="bilinear",
-             align_corners=False,
-         )
-
-     y_offset = int(math.ceil((height - size) / 2))
-     x_offset = int(math.ceil((width - size) / 2))
-
-     if height > width:
-         if spatial_idx == 0:
-             y_offset = 0
-         elif spatial_idx == 2:
-             y_offset = height - size
-     else:
-         if spatial_idx == 0:
-             x_offset = 0
-         elif spatial_idx == 2:
-             x_offset = width - size
-     cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
-     cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
-     if ndim == 3:
-         cropped = cropped.squeeze(0)
-     return cropped, cropped_boxes
-
-
- class SpatialCrop(nn.Module):
-     """
-     Convert the video into 3 smaller clips spatially. Must be used after the
-     temporal crops to get spatial crops, and should be used with
-     -2 in the spatial crop at the slowfast augmentation stage (so full
-     frames are passed in here). Will return a larger list with the
-     3x spatial crops as well.
-     """
-
-     def __init__(self, crop_size: int = 224, num_crops: int = 3):
-         super().__init__()
-         self.crop_size = crop_size
-         if num_crops == 3:
-             self.crops_to_ext = [0, 1, 2]
-             self.flipped_crops_to_ext = []
-         elif num_crops == 1:
-             self.crops_to_ext = [1]
-             self.flipped_crops_to_ext = []
-         else:
-             raise NotImplementedError("Nothing else supported yet")
-
-     def forward(self, videos):
-         """
-         Args:
-             videos: A list of C, T, H, W videos.
-         Returns:
-             videos: A list with 3x the number of elements. Each video converted
-                 to C, T, H', W' by spatial cropping.
-         """
-         assert isinstance(videos, list), "Must be a list of videos after temporal crops"
-         assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
-         res = []
-         for video in videos:
-             for spatial_idx in self.crops_to_ext:
-                 res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
-             if not self.flipped_crops_to_ext:
-                 continue
-             flipped_video = transforms.functional.hflip(video)
-             for spatial_idx in self.flipped_crops_to_ext:
-                 res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
-         return res
-
-
- def load_and_transform_video_data(
-     video_file,
-     video_path,
-     clip_duration=2,
-     clips_per_video=5,
-     sample_rate=16000,
-     with_audio=False
- ):
-     video_transform = transforms.Compose(
-         [
-             pv_transforms.ShortSideScale(224),
-             NormalizeVideo(
-                 mean=(0.48145466, 0.4578275, 0.40821073),
-                 std=(0.26862954, 0.26130258, 0.27577711),
-             ),
-         ]
-     )
-
-     clip_sampler = ConstantClipsPerVideoSampler(
-         clip_duration=clip_duration, clips_per_video=clips_per_video
-     )
-     frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
-
-     if isinstance(video_file, str):
-         video = EncodedVideo.from_path(
-             video_file,
-             decoder="decord",
-             decode_audio=with_audio,
-             # **{"sample_rate": sample_rate},
-         )
-     else:
-         video = EncodedVideoDecord(
-             video_file, video_name=video_path, decode_video=True,
-             decode_audio=with_audio, sample_rate=sample_rate,
-         )
-
-     all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
-
-     all_video = []
-     for clip_timepoints in all_clips_timepoints:
-         # Read the clip, get frames
-         clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
-         if clip is None:
-             raise ValueError("No clip found")
-         video_clip = frame_sampler(clip["video"])
-         video_clip = video_clip / 255.0  # since this is float, need 0-1
-
-         all_video.append(video_clip)
-
-     all_video = [video_transform(clip) for clip in all_video]
-     all_video = SpatialCrop(224, num_crops=3)(all_video)
-
-     all_video = torch.stack(all_video, dim=0)
-
-     if not with_audio:
-         return all_video
-     else:
-         return all_video, clip['audio']
-
-
- if __name__ == '__main__':
-     video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4"
-     video, audio = load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5, with_audio=True)
-     import pdb; pdb.set_trace()
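Editor's note: the deleted `load_and_transform_video_data` produces `clips_per_video` temporal clips, each turned into three 224x224 spatial crops by `SpatialCrop`, so the stacked output has `clips_per_video * 3` entries. Below is a dependency-free sketch of that shape bookkeeping with random frames standing in for decoded video; the 256x320 source resolution is an arbitrary example, and the crop helper is a simplified stand-in for `uniform_crop`, not the original function.

```python
import torch

clips_per_video, clip_duration = 5, 2          # 2 frames kept per clip by the temporal subsampler
C, H, W, size = 3, 256, 320, 224

clips = [torch.rand(C, clip_duration, H, W) for _ in range(clips_per_video)]

def three_crops(video: torch.Tensor, size: int) -> list:
    """Left/center/right (or top/center/bottom) square crops, in the spirit of uniform_crop."""
    h, w = video.shape[-2:]
    y0, x0 = (h - size) // 2, (w - size) // 2
    if w >= h:   # crop along the wider axis
        offsets = [(y0, 0), (y0, x0), (y0, w - size)]
    else:
        offsets = [(0, x0), (y0, x0), (h - size, x0)]
    return [video[..., y:y + size, x:x + size] for y, x in offsets]

all_crops = [crop for clip in clips for crop in three_crops(clip, size)]
batch = torch.stack(all_crops, dim=0)
print(batch.shape)  # torch.Size([15, 3, 2, 224, 224]) = clips_per_video * 3 crops
```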
 
demos/multi_turn_mm.py DELETED
@@ -1,300 +0,0 @@
- import sys
- import os
- sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0])
-
- import argparse
- import multiprocessing as mp
- import numpy as np
- from typing import List, Optional
-
- import torch
- import torch.distributed as dist
-
- from fairscale.nn.model_parallel import initialize as fs_init
-
- import gradio as gr
- from util.misc import setup_for_distributed
- from util.misc import default_tensor_type
- from model.meta import MetaModel
- from data.conversation_lib import conv_templates, SeparatorStyle
- from PIL import Image
- import torchvision.transforms as transforms
- from data.fintune_dataset import make_audio_features
- from data import video_utils
-
-
- T_random_resized_crop = transforms.Compose([
-     transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
-                                  antialias=None),  # 3 is bicubic
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
-
-
- def load_audio(audio_path):
-     fbank = make_audio_features(audio_path, mel_bins=128)
-     fbank = fbank.transpose(0, 1)[None]  # [1, 128, 1024]
-     return fbank
-
-
- def load_video(video_path):
-     video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
-     return video_feats[:, :, 0]
-
-
- def model_worker(
-     rank: int, args: argparse.Namespace, barrier: mp.Barrier,
-     request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
- ) -> None:
-     """
-     The worker function that manipulates the GPU to run the inference.
-     Exactly n_gpu workers are started, each operating on a separate GPU.
-
-     Args:
-         rank (int): Distributed rank of the worker.
-         args (argparse.Namespace): All command line arguments.
-         barrier (multiprocessing.Barrier): A barrier used to delay the start
-             of the Web UI until after the model has started.
-     """
-
-     world_size = len(args.gpu_ids)
-     gpu_id = args.gpu_ids[rank]
-     dist.init_process_group(
-         backend="nccl", rank=rank, world_size=world_size,
-         init_method=f"tcp://{args.master_addr}:{args.master_port}",
-     )
-     print(f"| distributed init on worker {rank}/{world_size}. "
-           f"using gpu: {gpu_id}")
-     fs_init.initialize_model_parallel(world_size)
-     torch.cuda.set_device(gpu_id)
-
-     torch.manual_seed(1)
-     np.random.seed(1)
-
-     # set the print behavior.
-     setup_for_distributed(rank == 0)
-
-     target_dtype = {
-         "bf16": torch.bfloat16,
-         "fp16": torch.float16
-     }[args.dtype]
-     with default_tensor_type(dtype=target_dtype, device="cuda"):
-         model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
-     print("Loading pretrained weights ...")
-     checkpoint = torch.load(args.pretrained_path, map_location='cpu')
-     msg = model.load_state_dict(checkpoint, strict=False)
-     print("load result:\n", msg)
-     model.cuda()
-     model.eval()
-     print(f"Model = {str(model)}")
-
-     barrier.wait()
-
-     while True:
-         img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
-         if 'image' in modality and img_path is not None:
-             image = Image.open(img_path).convert('RGB')
-             inputs = T_random_resized_crop(image)
-         elif 'video' in modality and video_path is not None:
-             inputs = load_video(video_path)
-         elif 'audio' in modality and audio_path is not None:
-             inputs = load_audio(audio_path)
-         else:
-             inputs = None
-
-         if inputs is not None:
-             inputs = inputs[None].cuda().to(target_dtype)
-
-         conv = conv_templates["v1"].copy()
-         for user, bot in chatbot:
-             conv.append_message(conv.roles[0], user)
-             conv.append_message(conv.roles[1], bot)
-
-         with torch.cuda.amp.autocast(dtype=target_dtype):
-             print(conv.get_prompt())
-             for stream_response in model.stream_generate(
-                 conv.get_prompt(), inputs,
-                 max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
-                 modal=modality
-             ):
-                 conv_sep = (
-                     conv.sep
-                     if conv.sep_style == SeparatorStyle.SINGLE
-                     else conv.sep2
-                 )
-                 end_pos = stream_response["text"].find(conv_sep)
-                 if end_pos != -1:
-                     stream_response["text"] = (
-                         stream_response['text'][:end_pos].rstrip() + "\n"
-                     )
-                     stream_response["end_of_content"] = True
-
-                 # keep a few characters if not end_of_content to avoid sending
-                 # part of conv_sep before all of it is generated.
-                 if not stream_response["end_of_content"]:
-                     if len(stream_response["text"]) < len(conv_sep):
-                         continue
-                     stream_response["text"] = (
-                         stream_response["text"][:-len(conv_sep)]
-                     )
-
-                 if response_queue is not None:
-                     response_queue.put(stream_response)
-
-                 if stream_response["end_of_content"]:
-                     break
-
-
- def gradio_worker(
-     request_queues: List[mp.Queue], response_queue: mp.Queue,
-     args: argparse.Namespace, barrier: mp.Barrier,
- ) -> None:
-     """
-     The gradio worker is responsible for displaying the WebUI and relaying the
-     requests to model workers. It should be launched only once.
-
-     Args:
-         request_queues (List[mp.Queue]): A list of request queues (one for
-             each model worker).
-         args (argparse.Namespace): All command line arguments.
-         barrier (multiprocessing.Barrier): A barrier used to delay the start
-             of the Web UI until after the model has started.
-     """
-
-     def show_user_input(msg, chatbot):
-         return "", chatbot + [[msg, None]]
-
-     def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality):
-         for queue in request_queues:
-             queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality))
-         while True:
-             content_piece = response_queue.get()
-             chatbot[-1][1] = content_piece["text"]
-             yield chatbot
-             if content_piece["end_of_content"]:
-                 break
-
-     def undo(chatbot):
-         if len(chatbot) > 0:
-             chatbot = chatbot[:-1]
-         return chatbot
-
-     def clear():
-         chatbot = []
-         msg = ""
-         return chatbot, msg
-
-     CSS = """
-     .contain { display: flex; flex-direction: column; }
-     #component-0 { height: 100%; }
-     #chatbot { flex-grow: 1; overflow: auto;}
-     """
-     with gr.Blocks(css=CSS) as demo:
-         gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language")
-         with gr.Row(equal_height=True):
-             with gr.Column(scale=1):
-                 img_path = gr.Image(label='Image Input', type='filepath')
-                 video_path = gr.Video(label='Video Input')
-                 audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
-                 modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities')
-
-             with gr.Column(scale=2):
-                 chatbot = gr.Chatbot(elem_id="chatbot")
-                 msg = gr.Textbox()
-
-                 with gr.Row():
-                     submit_button = gr.Button("Submit", variant="primary")
-                     undo_button = gr.Button("Undo")
-                     clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality])
-                 with gr.Row():
-                     max_gen_len = gr.Slider(
-                         minimum=1, maximum=args.model_max_seq_len // 2,
-                         value=args.model_max_seq_len // 2, interactive=True,
-                         label="Single-turn max response length",
-                     )
-                     gen_t = gr.Slider(
-                         minimum=0, maximum=1, value=0.1, interactive=True,
-                         label="Temperature",
-                     )
-                     top_p = gr.Slider(
-                         minimum=0, maximum=1, value=0.75, interactive=True,
-                         label="Top-p",
-                     )
-         msg.submit(
-             show_user_input, [msg, chatbot], [msg, chatbot],
-         ).then(
-             stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
-         )
-         submit_button.click(
-             show_user_input, [msg, chatbot], [msg, chatbot],
-         ).then(
-             stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
-         )
-         undo_button.click(undo, chatbot, chatbot)
-         # img_path.change(clear, [], [chatbot, msg])
-     barrier.wait()
-     demo.queue(api_open=True).launch(share=True, max_threads=1)
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser("Chat Demo")
-     group = parser.add_mutually_exclusive_group()
-     group.add_argument(
-         "--gpu_ids", type=int, nargs="+",
-         help="A list of space-separated gpu ids to run the model on. "
-              "The model will span across GPUs in tensor-parallel mode."
-     )
-     parser.add_argument(
-         "--tokenizer_path", type=str,
-         help="Path to the tokenizer.model file provided along with the LLaMA "
-              "model."
-     )
-     parser.add_argument(
-         "--llama_type", default="onellm", type=str, metavar="MODEL",
-         help="LLaMA model type."
-     )
-     parser.add_argument(
-         "--llama_config", type=str, required=True,
-         help="Path to the llama model config json."
-     )
-     parser.add_argument(
-         "--model_max_seq_len", type=int, default=2048,
-         help="Max sequence length accepted by the pretrained model."
-     )
-     parser.add_argument(
-         "--pretrained_path", type=str, required=True,
-         help="Path to the llama model checkpoints. A list of checkpoints is "
-              "supported and will be merged from left to right.")
-     parser.add_argument(
-         "--master_port", type=int, default=23862,
-         help="A port used by the PyTorch distributed module to initialize."
-     )
-     parser.add_argument(
-         "--master_addr", type=str, default="127.0.0.1",
-         help="An address used by the PyTorch distributed module to initialize."
-     )
-     parser.add_argument(
-         "--dtype", type=str, choices=["fp16", "bf16"], default="fp16",
-         help="The dtype used for model weights and inference."
-     )
-     args = parser.parse_args()
-
-     # using the default "fork" method messes up some imported libs (e.g.,
-     # pandas)
-     mp.set_start_method("spawn")
-
-     # setup the queues and start the model workers
-     request_queues = []
-     response_queue = mp.Queue()
-     worker_processes = []
-     barrier = mp.Barrier(len(args.gpu_ids) + 1)
-     for rank, gpu_id in enumerate(args.gpu_ids):
-         request_queue = mp.Queue()
-         rank_response_queue = response_queue if rank == 0 else None
-         process = mp.Process(
-             target=model_worker,
-             args=(rank, args, barrier, request_queue, rank_response_queue),
-         )
-         process.start()
-         worker_processes.append(process)
-         request_queues.append(request_queue)
-
-     gradio_worker(request_queues, response_queue, args, barrier)
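Editor's note: the deleted demo splits work between one Gradio process and one model worker per GPU, wired together with multiprocessing queues: the UI fans each request out to every worker's request queue, and only rank 0 streams partial results back through a shared response queue. Below is a stripped-down, CPU-only sketch of that relay pattern; the `echo_worker` is a made-up stand-in for `model_worker`, with no GPUs, model, or Gradio involved.

```python
import multiprocessing as mp

def echo_worker(request_queue: mp.Queue, response_queue: mp.Queue) -> None:
    """Stand-in for model_worker: stream a fake reply token by token."""
    while True:
        prompt = request_queue.get()
        if prompt is None:                     # shutdown signal
            break
        text = ""
        for tok in f"echo: {prompt}".split():
            text += tok + " "
            response_queue.put({"text": text, "end_of_content": False})
        response_queue.put({"text": text.strip(), "end_of_content": True})

if __name__ == "__main__":
    mp.set_start_method("spawn")               # same choice as the demo script
    req, resp = mp.Queue(), mp.Queue()
    worker = mp.Process(target=echo_worker, args=(req, resp))
    worker.start()

    req.put("hello world")                     # what the Gradio callback would enqueue
    while True:                                # what stream_model_output would consume
        piece = resp.get()
        print(piece["text"])
        if piece["end_of_content"]:
            break

    req.put(None)
    worker.join()
```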
 
examples/bell_ring.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:415b5406175cafa874396f89811121fb306c17084db8ec20753ab5666b4fdcca
- size 3630044
 
examples/bird_audio.wav DELETED
Binary file (882 kB)
 
examples/depth_normal/depth/0084.png DELETED
Binary file (10.9 kB)
 
examples/depth_normal/depth/0131.png DELETED
Binary file (5.5 kB)
 
examples/depth_normal/depth/0297.png DELETED
Binary file (11.1 kB)
 
examples/depth_normal/depth/0331.png DELETED
Binary file (10.6 kB)
 
examples/depth_normal/depth/0432.png DELETED
Binary file (7.89 kB)
 
examples/depth_normal/depth/0633.png DELETED
Binary file (9.35 kB)
 
examples/depth_normal/depth/0663.png DELETED
Binary file (5.68 kB)
 
examples/depth_normal/depth/0771.png DELETED
Binary file (9.65 kB)
 
examples/depth_normal/depth/0782.png DELETED
Binary file (9.58 kB)
 
examples/depth_normal/depth/1001.png DELETED
Binary file (7.17 kB)
 
examples/depth_normal/depth/1051.png DELETED
Binary file (6.81 kB)
 
examples/depth_normal/depth/1129.png DELETED
Binary file (6.66 kB)
 
examples/depth_normal/depth/1205.png DELETED
Binary file (9.24 kB)
 
examples/depth_normal/depth/1336.png DELETED
Binary file (11.5 kB)
 
examples/depth_normal/depth/1383.png DELETED
Binary file (9.98 kB)
 
examples/depth_normal/depth/1386.png DELETED
Binary file (12.2 kB)
 
examples/depth_normal/depth/1393.png DELETED
Binary file (8.6 kB)
 
examples/depth_normal/depth/1447.png DELETED
Binary file (10.3 kB)
 
examples/depth_normal/depth_scaled/0084.png DELETED
Binary file (11.1 kB)
 
examples/depth_normal/depth_scaled/0131.png DELETED
Binary file (5.68 kB)
 
examples/depth_normal/depth_scaled/0297.png DELETED
Binary file (11.8 kB)
 
examples/depth_normal/depth_scaled/0331.png DELETED
Binary file (11.1 kB)
 
examples/depth_normal/depth_scaled/0432.png DELETED
Binary file (8.24 kB)
 
examples/depth_normal/depth_scaled/0633.png DELETED
Binary file (9.96 kB)
 
examples/depth_normal/depth_scaled/0663.png DELETED
Binary file (5.91 kB)
 
examples/depth_normal/depth_scaled/0771.png DELETED
Binary file (10.1 kB)
 
examples/depth_normal/depth_scaled/0782.png DELETED
Binary file (9.98 kB)
 
examples/depth_normal/depth_scaled/1001.png DELETED
Binary file (7.19 kB)
 
examples/depth_normal/depth_scaled/1051.png DELETED
Binary file (7.05 kB)
 
examples/depth_normal/depth_scaled/1129.png DELETED
Binary file (6.9 kB)
 
examples/depth_normal/depth_scaled/1205.png DELETED
Binary file (9.58 kB)