# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from itertools import chain
from typing import Callable, Dict, List, Tuple

import einops
import torch


def flatten(
    hid: List[torch.FloatTensor],  # List of (*** c)
) -> Tuple[
    torch.FloatTensor,  # (L c)
    torch.LongTensor,  # (b n)
]:
    assert len(hid) > 0
    shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid])
    hid = torch.cat([x.flatten(0, -2) for x in hid])
    return hid, shape


def unflatten(
    hid: torch.FloatTensor,  # (L c) or (L ... c)
    hid_shape: torch.LongTensor,  # (b n)
) -> List[torch.Tensor]:  # List of (*** c) or (*** ... c)
    hid_len = hid_shape.prod(-1)
    hid = hid.split(hid_len.tolist())
    hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
    return hid


def concat(
    vid: torch.FloatTensor,  # (VL ... c)
    txt: torch.FloatTensor,  # (TL ... c)
    vid_len: torch.LongTensor,  # (b)
    txt_len: torch.LongTensor,  # (b)
) -> torch.FloatTensor:  # (L ... c)
    vid = torch.split(vid, vid_len.tolist())
    txt = torch.split(txt, txt_len.tolist())
    return torch.cat(list(chain(*zip(vid, txt))))


def concat_idx(
    vid_len: torch.LongTensor,  # (b)
    txt_len: torch.LongTensor,  # (b)
) -> Tuple[
    Callable,
    Callable,
]:
    device = vid_len.device
    vid_idx = torch.arange(vid_len.sum(), device=device)
    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
    tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx),
        lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]),
    )


def unconcat(
    all: torch.FloatTensor,  # (L ... c)
    vid_len: torch.LongTensor,  # (b)
    txt_len: torch.LongTensor,  # (b)
) -> Tuple[
    torch.FloatTensor,  # (VL ... c)
    torch.FloatTensor,  # (TL ... c)
]:
    interleave_len = list(chain(*zip(vid_len.tolist(), txt_len.tolist())))
    all = all.split(interleave_len)
    vid = torch.cat(all[0::2])
    txt = torch.cat(all[1::2])
    return vid, txt

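
# Illustrative usage sketch (the helper below is hypothetical, not part of the
# module's API): round-trip a ragged batch through flatten/unflatten and interleave
# video/text tokens with concat/unconcat.
def _example_flatten_concat_roundtrip() -> None:
    vid = [torch.randn(2, 3, 8), torch.randn(4, 5, 8)]  # two video samples, c = 8
    txt = [torch.randn(7, 8), torch.randn(9, 8)]  # two text samples, c = 8
    vid_flat, vid_shape = flatten(vid)  # (26 8), shape [[2 3], [4 5]]
    txt_flat, txt_shape = flatten(txt)  # (16 8), shape [[7], [9]]
    vid_len = vid_shape.prod(-1)  # tensor([ 6, 20])
    txt_len = txt_shape.prod(-1)  # tensor([7, 9])
    packed = concat(vid_flat, txt_flat, vid_len, txt_len)  # per-sample interleave: v0 t0 v1 t1
    vid_out, txt_out = unconcat(packed, vid_len, txt_len)
    assert torch.equal(vid_out, vid_flat) and torch.equal(txt_out, txt_flat)
    assert all(torch.equal(a, b) for a, b in zip(vid, unflatten(vid_flat, vid_shape)))
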

def repeat_concat(
    vid: torch.FloatTensor,  # (VL ... c)
    txt: torch.FloatTensor,  # (TL ... c)
    vid_len: torch.LongTensor,  # (n*b)
    txt_len: torch.LongTensor,  # (b)
    txt_repeat: List,  # (n)
) -> torch.FloatTensor:  # (L ... c)
    vid = torch.split(vid, vid_len.tolist())
    txt = torch.split(txt, txt_len.tolist())
    txt = [[x] * n for x, n in zip(txt, txt_repeat)]
    txt = list(chain(*txt))
    return torch.cat(list(chain(*zip(vid, txt))))


def repeat_concat_idx(
    vid_len: torch.LongTensor,  # (n*b)
    txt_len: torch.LongTensor,  # (b)
    txt_repeat: torch.LongTensor,  # (n)
) -> Tuple[
    Callable,
    Callable,
]:
    device = vid_len.device
    vid_idx = torch.arange(vid_len.sum(), device=device)
    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
    txt_repeat_list = txt_repeat.tolist()
    tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat)
    src_idx = torch.argsort(tgt_idx)
    txt_idx_len = len(tgt_idx) - len(vid_idx)
    repeat_txt_len = (txt_len * txt_repeat).tolist()

    def unconcat_coalesce(all):
        """
        Un-concat vid & txt, and coalesce the repeated txt.
        e.g. vid [0 1 2 3 4 5 6 7 8] -> 3 splits -> [0 1 2] [3 4 5] [6 7 8]
             txt [9 10]
             repeat_concat ==> [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10]
        1. argsort re-index ==> [0 1 2 3 4 5 6 7 8 9 9 9 10 10 10]
           split ==> vid_out [0 1 2 3 4 5 6 7 8]
                     txt_out [9 9 9 10 10 10]
        2. reshape & mean for each sample to coalesce the repeated txt.
        """
        vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len])
        txt_out_coalesced = []
        for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list):
            txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1)
            txt_out_coalesced.append(txt)
        return vid_out, torch.cat(txt_out_coalesced)

    # Note: the backward of torch.index_select is non-deterministic when the index
    # contains repeated entries, and the error can accumulate (as with
    # torch.repeat_interleave), so we use vanilla indexing here.
    return (
        lambda vid, txt: torch.cat([vid, txt])[tgt_idx],
        lambda all: unconcat_coalesce(all),
    )


def rearrange(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    pattern: str,
    **kwargs: Dict[str, int],
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    return flatten([einops.rearrange(h, pattern, **kwargs) for h in unflatten(hid, hid_shape)])


def rearrange_idx(
    hid_shape: torch.LongTensor,  # (b n)
    pattern: str,
    **kwargs: Dict[str, int],
) -> Tuple[Callable, Callable, torch.LongTensor]:
    hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
    tgt_idx, tgt_shape = rearrange(hid_idx, hid_shape, pattern, **kwargs)
    tgt_idx = tgt_idx.squeeze(-1)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda hid: torch.index_select(hid, 0, tgt_idx),
        lambda hid: torch.index_select(hid, 0, src_idx),
        tgt_shape,
    )


def repeat(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    pattern: str,
    **kwargs: Dict[str, torch.LongTensor],  # (b)
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    hid = unflatten(hid, hid_shape)
    kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
    return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])

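
# Illustrative usage sketch (hypothetical helper, not part of the module's API),
# mirroring the worked example in unconcat_coalesce: one text prompt is repeated
# across three video chunks on the forward pass and averaged back on the inverse pass.
def _example_repeat_concat_idx() -> None:
    c = 4
    vid = torch.randn(9, c)  # three video chunks of length 3
    txt = torch.randn(2, c)  # one text sample of length 2
    vid_len = torch.tensor([3, 3, 3])
    txt_len = torch.tensor([2])
    txt_repeat = torch.tensor([3])  # repeat the text once per video chunk
    cat_fn, split_fn = repeat_concat_idx(vid_len, txt_len, txt_repeat)
    packed = cat_fn(vid, txt)  # (15 c): v0 t v1 t v2 t
    vid_out, txt_out = split_fn(packed)  # (9 c) and (2 c); repeated txt coalesced by mean
    assert torch.allclose(vid_out, vid) and torch.allclose(txt_out, txt)
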

def pack(
    samples: List[torch.Tensor],  # List of (h w c).
) -> Tuple[
    List[torch.Tensor],  # groups [(b1 h1 w1 c1), (b2 h2 w2 c2)]
    List[List[int]],  # reversal indices.
]:
    batches = {}
    indices = {}
    for i, sample in enumerate(samples):
        shape = sample.shape
        batches[shape] = batches.get(shape, [])
        indices[shape] = indices.get(shape, [])
        batches[shape].append(sample)
        indices[shape].append(i)
    batches = list(map(torch.stack, batches.values()))
    indices = list(indices.values())
    return batches, indices


def unpack(
    batches: List[torch.Tensor],
    indices: List[List[int]],
) -> List[torch.Tensor]:
    samples = [None] * (max(chain(*indices)) + 1)
    for batch, index in zip(batches, indices):
        for sample, i in zip(batch.unbind(), index):
            samples[i] = sample
    return samples


def window(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    hid = unflatten(hid, hid_shape)
    hid = list(map(window_fn, hid))
    hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
    hid, hid_shape = flatten(list(chain(*hid)))
    return hid, hid_shape, hid_windows


def window_idx(
    hid_shape: torch.LongTensor,  # (b n)
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
    tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
    tgt_idx = tgt_idx.squeeze(-1)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda hid: torch.index_select(hid, 0, tgt_idx),
        lambda hid: torch.index_select(hid, 0, src_idx),
        tgt_shape,
        tgt_windows,
    )

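
# Illustrative usage sketch (hypothetical helper, not part of the module's API):
# pack groups same-shape samples into stacked batches, and unpack restores the
# original order from the returned indices.
def _example_pack_unpack() -> None:
    samples = [torch.randn(4, 4, 3), torch.randn(8, 8, 3), torch.randn(4, 4, 3)]
    batches, indices = pack(samples)  # groups: (2 4 4 3) and (1 8 8 3); indices: [[0, 2], [1]]
    restored = unpack(batches, indices)
    assert all(torch.equal(a, b) for a, b in zip(samples, restored))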