from typing import (Callable, Dict, Iterable, List, Literal, Mapping,
                    Optional, Protocol, Set, Tuple, Union, overload)

import torch
from torch.func import functional_call
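
# A tensor, or an arbitrarily nested list of tensors, as produced by batched
# multimodal encoders. This local alias is introduced here so that the helper
# signatures below can describe the nested inputs they accept.
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]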


@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    The input tensor should have shape ``(B, N, ...)``. If a list of
    per-batch tensors is passed instead, the result is either a flat list
    of the per-item tensors (the default) or their concatenation along the
    first dimension (``concat=True``).
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]
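

# Example of how ``flatten_bn`` behaves (an illustrative sketch, not executed
# as part of this module):
#
#   >>> x = torch.zeros(2, 3, 8)            # a (B, N, ...) tensor
#   >>> flatten_bn(x).shape
#   torch.Size([6, 8])
#   >>> parts = [torch.zeros(3, 8), torch.zeros(1, 8)]
#   >>> flatten_bn(parts, concat=True).shape
#   torch.Size([4, 8])
#   >>> len(flatten_bn(parts))              # flat list of four (8,) tensors
#   4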


def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.
    """

    if isinstance(embeddings, torch.Tensor):
        # Base case: flatten all but the last (embedding) dimension.
        return embeddings.flatten(0, -2)

    # Recursive case: flatten each nested element, then concatenate.
    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
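

# For instance (illustrative sketch): a nested input such as
# ``[torch.zeros(2, 5, 8), torch.zeros(3, 8)]`` flattens to a single tensor of
# shape ``(2 * 5 + 3, 8) == (13, 8)``.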


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Constructs a debugging representation of the number of embeddings in the
    given nested tensors.
    """

    if isinstance(embeddings, torch.Tensor):
        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])

    return " + ".join(
        _embedding_count_expression(inner) for inner in embeddings)
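

# For instance (illustrative sketch): for ``[torch.zeros(2, 5, 8),
# torch.zeros(3, 8)]`` this produces the string ``"2 x 5 + 3"``, which is used
# in the error message of ``_merge_multimodal_embeddings`` below.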


def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings: NestedTensors,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions of ``inputs_embeds`` selected by the boolean mask
    ``is_multimodal``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    num_expected_tokens = is_multimodal.sum().item()
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    inputs_embeds[is_multimodal] = flattened
    return inputs_embeds
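

# Sketch of the in-place merge performed above (illustrative only):
#
#   >>> inputs_embeds = torch.zeros(5, 4)
#   >>> is_multimodal = torch.tensor([False, True, True, False, True])
#   >>> mm = torch.ones(3, 4)
#   >>> out = _merge_multimodal_embeddings(inputs_embeds, is_multimodal, mm)
#   >>> out[1].sum().item(), out[0].sum().item()
#   (4.0, 0.0)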


def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: Union[int, List[int]],
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    ``placeholder_token_id`` can also be a list of token ids (e.g., the token
    ids of the img_start, img_break, and img_end tokens) when needed. In that
    case, the order of these tokens in ``input_ids`` MUST MATCH the order of
    their embeddings in ``multimodal_embeddings``, since the merge is a
    single slice-assignment rather than a per-token scatter.

    For example, if ``input_ids`` is "TTTTTSIIIBIIIBIIIETTT", where
    - T is a text token
    - S is the image start token
    - I is an image embedding token
    - B is the image break token
    - E is the image end token

    then the image embeddings (corresponding to the I's) from the vision
    encoder must be padded with the embeddings of S, B, and E in the same
    order as in ``input_ids`` for the merge to be correct.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    if isinstance(placeholder_token_id, list):
        placeholder_token_id = torch.tensor(placeholder_token_id,
                                            device=input_ids.device)
        return _merge_multimodal_embeddings(
            inputs_embeds,
            torch.isin(input_ids, placeholder_token_id),
            multimodal_embeddings,
        )
    return _merge_multimodal_embeddings(
        inputs_embeds,
        (input_ids == placeholder_token_id),
        multimodal_embeddings,
    )
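

# Example usage (illustrative sketch; token id 32000 is a made-up placeholder
# id, not one defined by this module):
#
#   >>> input_ids = torch.tensor([1, 32000, 32000, 2])
#   >>> inputs_embeds = torch.zeros(4, 8)
#   >>> image_embeds = torch.ones(2, 8)
#   >>> out = merge_multimodal_embeddings(
#   ...     input_ids, inputs_embeds, image_embeds, placeholder_token_id=32000)
#   >>> out[1].sum().item()
#   8.0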