import torch
import torch.nn.functional as F
from PIL import Image
def mask_token_segment(
start_id: int,
end_id: int,
input_ids: torch.Tensor,
fill_value: int = -100):
"""
Replace *every* token from each `start_id` **through** its matching `end_id`
(boundaries included) with `fill_value`. Any spans that start with some
other token are left untouched.
Works on CUDA, TorchScript, batched via vmap, etc.βno Python loops.
"""
if input_ids.dim() != 1:
raise ValueError("`input_ids` must be 1-D")
device = input_ids.device
n = input_ids.size(0)
# where the *target* start-tokens and end-tokens sit
start_pos = (input_ids == start_id).nonzero(as_tuple=True)[0] # ascending
end_pos = (input_ids == end_id).nonzero(as_tuple=True)[0] # ascending
if start_pos.numel() == 0:
return input_ids.clone()
    # ── pair every start with the first end that comes *after* it ───────────────
# searchsorted gives the insertion index into the (sorted) end positions
idx_in_end = torch.searchsorted(end_pos, start_pos, right=False)
have_match = idx_in_end < end_pos.size(0) # safety: drop unmatched
start_pos = start_pos[have_match]
end_pos = end_pos[idx_in_end[have_match]]
# (rare) guard against pathological orderings
keep = end_pos > start_pos
start_pos, end_pos = start_pos[keep], end_pos[keep]
if start_pos.numel() == 0:
        return input_ids.clone()
    # ── differential "scan-line" trick to build the span mask in O(N) ───────────
# +1 at each start index, -1 at the element *after* each end
delta = torch.zeros(n + 1, dtype=torch.int8, device=device)
delta[start_pos] += 1
delta[end_pos + 1] -= 1 # +1 is safe because delta is length n+1
inside = torch.cumsum(delta[:-1], dim=0) > 0 # boolean mask, incl. boundaries
    # ── apply ────────────────────────────────────────────────────────────────────
out = input_ids.clone()
out[inside] = fill_value
return out
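
# Example (a minimal sketch; token ids 3 and 4 are made-up stand-ins for real
# special-token ids such as an image start/end marker):
#   ids = torch.tensor([7, 3, 9, 9, 4, 7, 3, 9, 4, 5])
#   mask_token_segment(3, 4, ids)
#   -> tensor([7, -100, -100, -100, -100, 7, -100, -100, -100, 5])
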
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
print(name, 'no ignore status')
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
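
# Usage sketch (assumes training under DeepSpeed ZeRO-3, where parameters may be
# partitioned across ranks; the parameter name below is illustrative):
#   cpu_copy = maybe_zero_3(param, ignore_status=True, name="lm_head.weight")
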
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
to_return = {k: t for k, t in named_params if "lora_" not in k}
if require_grad_only:
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
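
# Usage sketch for checkpointing a LoRA run (`model` and the file names below are
# illustrative assumptions, not fixed by this module):
#   lora_state = get_peft_state_maybe_zero_3(model.named_parameters(), bias="none")
#   non_lora_state = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
#   torch.save(lora_state, "adapter_model.bin")
#   torch.save(non_lora_state, "non_lora_trainables.bin")
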
def find_all_linear_names(modules):
    """Return the names of all `torch.nn.Linear` sub-modules. `modules` should be
    a callable such as `model.named_modules` that yields (name, module) pairs."""
    lora_module_names = set()
    for name, module in modules():
if isinstance(module, torch.nn.Linear):
names = name.split('.')
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if 'lm_head' in lora_module_names: # needed for 16-bit
lora_module_names.remove('lm_head')
return list(lora_module_names)
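
# Usage sketch: pass the *bound method* itself (not its result), since the helper
# calls `modules()` internally; `model` and the LoraConfig kwargs are illustrative.
#   target_modules = find_all_linear_names(model.named_modules)
#   lora_config = peft.LoraConfig(r=16, target_modules=target_modules)
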
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
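
# Usage sketch (assumes a HuggingFace image processor exposing `image_mean`, as in
# the usual pad-to-square preprocessing recipe):
#   bg = tuple(int(x * 255) for x in processor.image_mean)
#   square = expand2square(Image.open("photo.jpg").convert("RGB"), bg)
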
def pad_and_stack(img_list, pad_value=0.0):
    """
    img_list : list[Tensor]; each (C, H, W), already *normalised*
    pad_value: float, or tuple/list of 3 floats (one per channel).
               Use 0.0 if your processor has already centred to mean 0.
    Returns
    -------
    batch : Tensor (B, C, S, S) where S = max over all heights and widths
    """
# 1. target square size ---------------------------------------------------
h_max = max(t.shape[1] for t in img_list)
w_max = max(t.shape[2] for t in img_list)
H, W = max(h_max, w_max), max(h_max, w_max)
# 2. create padded copies -------------------------------------------------
    padded = []
    for img in img_list:
        c, h, w = img.shape
        if isinstance(pad_value, (tuple, list)):      # per-channel fill values
            canvas = img.new_empty((c, H, W))
            for ch, v in enumerate(pad_value):
                canvas[ch].fill_(v)
        else:                                         # single scalar fill value
            canvas = img.new_full((c, H, W), pad_value)
        canvas[:, :h, :w] = img                       # paste into the top-left corner
        padded.append(canvas)
    return torch.stack(padded, 0)                     # (B, C, S, S)
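
# Usage sketch: two normalised image tensors of different sizes are padded onto a
# shared square canvas before stacking.
#   imgs = [torch.randn(3, 224, 196), torch.randn(3, 180, 240)]
#   batch = pad_and_stack(imgs)   # -> shape (2, 3, 240, 240)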