import torch
from PIL import Image


def mask_token_segment(
        start_id: int,
        end_id: int,
        input_ids: torch.Tensor,
        fill_value: int = -100):
    """
    Replace every token from each `start_id` through its matching `end_id`
    (boundaries included) with `fill_value`.

    Spans that start with any other token are left untouched. The routine is
    fully vectorised (no Python loops), so it runs on CPU or CUDA.
    """
    if input_ids.dim() != 1:
        raise ValueError("`input_ids` must be 1-D")

    device = input_ids.device
    n = input_ids.size(0)

    # where the *target* start-tokens and end-tokens sit
    start_pos = (input_ids == start_id).nonzero(as_tuple=True)[0]   # ascending
    end_pos = (input_ids == end_id).nonzero(as_tuple=True)[0]       # ascending

    if start_pos.numel() == 0:
        return input_ids.clone()

    # ── pair every start with the first end at or after it ──────────────────────
    # searchsorted gives the insertion index into the (sorted) end positions
    idx_in_end = torch.searchsorted(end_pos, start_pos, right=False)
    have_match = idx_in_end < end_pos.size(0)            # safety: drop unmatched
    start_pos = start_pos[have_match]
    end_pos = end_pos[idx_in_end[have_match]]

    # (rare) guard against pathological orderings
    keep = end_pos > start_pos
    start_pos, end_pos = start_pos[keep], end_pos[keep]
    if start_pos.numel() == 0:
        return input_ids.clone()

    # ── differential "scan-line" trick to build the span mask in O(N) ───────────
    # +1 at each start index, -1 at the element *after* each end
    delta = torch.zeros(n + 1, dtype=torch.int8, device=device)
    delta[start_pos] += 1
    delta[end_pos + 1] -= 1                 # +1 is safe because delta has length n+1
    inside = torch.cumsum(delta[:-1], dim=0) > 0   # boolean mask, incl. boundaries

    # ── apply ────────────────────────────────────────────────────────────────────
    out = input_ids.clone()
    out[inside] = fill_value
    return out
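
# Illustrative usage of mask_token_segment (a minimal sketch; the token ids
# below are made up and do not correspond to any particular tokenizer):
#
#     ids = torch.tensor([5, 32001, 7, 7, 32002, 9])
#     mask_token_segment(start_id=32001, end_id=32002, input_ids=ids)
#     # -> tensor([   5, -100, -100, -100, -100,    9])
#
# Every token from the start marker through its matching end marker, markers
# included, is replaced with the default label-ignore value -100.
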
def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                print(f"{name}: ds_status is NOT_AVAILABLE and ignore_status is False")
        # gather the ZeRO-3 partitioned parameter before copying it to CPU
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def find_all_linear_names(modules):
    """`modules` is a callable yielding (name, module) pairs, e.g. `model.named_modules`."""
    lora_module_names = set()
    for name, module in modules():
        if isinstance(module, torch.nn.Linear):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:   # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def pad_and_stack(img_list, pad_value=0.0):
    """
    img_list : list[Tensor]  each (C, H, W), already *normalised*
    pad_value: float, or tuple/list of C floats (one per channel).
               Use 0.0 if your processor has already centred to mean 0.

    Returns
    -------
    batch : Tensor (B, C, S, S) with S = max over all heights and widths,
            i.e. every image is placed on a shared square canvas.
    """
    # 1. target square size ---------------------------------------------------
    h_max = max(t.shape[1] for t in img_list)
    w_max = max(t.shape[2] for t in img_list)
    H = W = max(h_max, w_max)

    # 2. create padded copies -------------------------------------------------
    padded = []
    for img in img_list:
        c, h, w = img.shape
        if isinstance(pad_value, (tuple, list)):
            # per-channel fill (e.g. the normalisation mean)
            canvas = img.new_tensor(pad_value).view(-1, 1, 1).expand(c, H, W).clone()
        else:
            canvas = img.new_full((c, H, W), pad_value)      # scalar fill
        canvas[:, :h, :w] = img                              # paste into the top-left corner
        padded.append(canvas)

    return torch.stack(padded, 0)                            # (B, C, H, W)
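

# Minimal smoke test, illustrative only: the shapes, token ids and fill values
# below are arbitrary and are not taken from any training configuration. Run
# this file directly to sanity-check the helpers above (deepspeed is not
# required, since it is only imported lazily inside maybe_zero_3).
if __name__ == "__main__":
    # pad_and_stack: two differently sized images land on a shared square canvas
    imgs = [torch.randn(3, 224, 196), torch.randn(3, 180, 240)]
    batch = pad_and_stack(imgs, pad_value=0.0)
    assert batch.shape == (2, 3, 240, 240)

    # expand2square: a landscape image is letterboxed into a square
    square = expand2square(Image.new("RGB", (64, 32)), (127, 127, 127))
    assert square.size == (64, 64)

    # mask_token_segment: tokens from 2 through 3 (inclusive) become -100
    ids = torch.tensor([1, 2, 8, 8, 3, 4])
    assert mask_token_segment(start_id=2, end_id=3, input_ids=ids).tolist() == \
        [1, -100, -100, -100, -100, 4]

    print("all smoke tests passed")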