File size: 6,012 Bytes
44d9cf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import torch
import torch.nn.functional as F

from PIL import Image


import torch

def mask_token_segment(
               start_id: int,
               end_id: int,
               input_ids: torch.Tensor,
               fill_value: int = -100):
    """
    Replace every token from each `start_id` through its matching `end_id`
    (boundaries included) with `fill_value`.

    Each occurrence of `start_id` is paired with the first `end_id` found
    after it.  A start token with no later end token is left untouched, as
    is everything outside the matched spans.  Several start tokens may share
    one end token; the masked region is then the union of their spans.
    If `start_id == end_id` nothing is masked, because a token cannot close
    the span it opens.

    The implementation is fully vectorised (no Python loops over tokens).

    Args:
        start_id:   token id that opens a span.
        end_id:     token id that closes a span.
        input_ids:  1-D tensor of token ids.
        fill_value: value written over every masked position.

    Returns:
        A new tensor with the same shape, dtype and device as `input_ids`.
        The input is never modified and the result never aliases it.

    Raises:
        ValueError: if `input_ids` is not 1-D.
    """
    if input_ids.dim() != 1:
        raise ValueError("`input_ids` must be 1-D")

    device = input_ids.device
    n = input_ids.size(0)

    # Positions of the start and end tokens; `nonzero` returns them ascending.
    start_pos = (input_ids == start_id).nonzero(as_tuple=True)[0]
    end_pos = (input_ids == end_id).nonzero(as_tuple=True)[0]

    if start_pos.numel() == 0:
        return input_ids.clone()

    # Pair every start with the first end at or after it: searchsorted gives
    # the insertion index of each start into the sorted end positions.
    idx_in_end = torch.searchsorted(end_pos, start_pos, right=False)

    # Starts whose insertion index falls past the last end have no partner.
    have_match = idx_in_end < end_pos.size(0)
    start_pos = start_pos[have_match]
    end_pos = end_pos[idx_in_end[have_match]]

    # Drop degenerate pairs where the "end" is the start token itself
    # (only possible when start_id == end_id).
    keep = end_pos > start_pos
    start_pos, end_pos = start_pos[keep], end_pos[keep]

    if start_pos.numel() == 0:
        # Clone here as well so the caller always receives a fresh tensor.
        return input_ids.clone()

    # Difference-array ("scan-line") construction of the span mask:
    # +1 at each start index, -1 at the element *after* each end.
    # `index_add_` accumulates on repeated indices, which matters when
    # several starts share one end; plain `delta[idx] -= 1` would apply the
    # decrement only once and leave the running count positive forever.
    # int64 avoids overflow no matter how many spans overlap.
    delta = torch.zeros(n + 1, dtype=torch.long, device=device)
    ones = torch.ones_like(start_pos)
    delta.index_add_(0, start_pos, ones)
    delta.index_add_(0, end_pos + 1, -ones)    # +1 is safe: delta has n+1 slots

    # Running sum > 0 means "inside at least one span" (boundaries included).
    inside = torch.cumsum(delta[:-1], dim=0) > 0

    out = input_ids.clone()
    out[inside] = fill_value
    return out



def maybe_zero_3(param, ignore_status=False, name=None):
    """
    Return a detached CPU copy of `param`.

    Parameters carrying a `ds_id` attribute (presumably DeepSpeed ZeRO-3
    partitioned parameters -- confirm against the training setup) are
    gathered with `deepspeed.zero.GatheredParameters` before being copied.
    Anything else is simply detached, moved to CPU and cloned.

    Args:
        param:         tensor or parameter to copy.
        ignore_status: when False, a message is printed if the parameter's
                       `ds_status` is `ZeroParamStatus.NOT_AVAILABLE`.
        name:          label used in that printed message.

    Returns:
        A detached CPU clone of the parameter's data.
    """
    if not hasattr(param, "ds_id"):
        # Plain tensor: no DeepSpeed involvement, so DeepSpeed does not even
        # need to be installed for this path.
        return param.detach().cpu().clone()

    # Imported lazily so the module only depends on DeepSpeed when a
    # parameter that actually carries DeepSpeed metadata is seen.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(name, 'no ignore status')
    with zero.GatheredParameters([param]):
        param = param.data.detach().cpu().clone()
    return param


# Borrowed from peft.util.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """
    Collect the LoRA-related entries of `named_params` as CPU copies.

    Args:
        named_params: iterable of `(name, tensor)` pairs.
        bias: which bias tensors to include alongside the `lora_` tensors:
            "none"      -- only entries whose name contains "lora_";
            "all"       -- those plus every entry whose name contains "bias";
            "lora_only" -- those plus the bias entries whose name equals the
                           prefix of some LoRA entry (the text before
                           "lora_") followed by "bias".

    Returns:
        dict mapping each selected name to the result of
        `maybe_zero_3(tensor, ignore_status=True)`.

    Raises:
        NotImplementedError: if `bias` is not one of the modes above.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # Name the sibling bias would have, e.g.
                # "layer.lora_A.weight" -> "layer.bias".
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate (name, tensor) pairs and test each candidate's own name.
        # Iterating the dict directly would unpack key strings, and checking
        # the leftover `bias_name` from the loop above would test the wrong
        # name for every candidate.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """
    Collect every entry of `named_params` whose name does not contain
    "lora_" and return CPU copies of them.

    Args:
        named_params:      iterable of `(name, tensor)` pairs.
        require_grad_only: when True, keep only tensors with
                           `requires_grad` set.

    Returns:
        dict mapping each kept name to
        `maybe_zero_3(tensor, ignore_status=True).cpu()`.
    """
    # Stage 1: drop LoRA tensors.
    candidates = {}
    for key, tensor in named_params:
        if "lora_" not in key:
            candidates[key] = tensor

    # Stage 2: optionally keep trainable tensors only.
    if require_grad_only:
        candidates = {
            key: tensor
            for key, tensor in candidates.items()
            if tensor.requires_grad
        }

    # Stage 3: materialise each survivor on the CPU.
    gathered = {}
    for key, tensor in candidates.items():
        gathered[key] = maybe_zero_3(tensor, ignore_status=True).cpu()
    return gathered


def find_all_linear_names(modules):
    """
    Return the distinct leaf names of all `torch.nn.Linear` modules.

    Args:
        modules: zero-argument callable yielding `(qualified_name, module)`
                 pairs.

    Returns:
        list of the last dot-separated component of each Linear module's
        name, without duplicates and with 'lm_head' excluded.  The order of
        the list is unspecified (it comes from a set).
    """
    linear_leaf_names = {
        qualified_name.split('.')[-1]
        for qualified_name, submodule in modules()
        if isinstance(submodule, torch.nn.Linear)
    }
    # needed for 16-bit
    linear_leaf_names.discard('lm_head')
    return list(linear_leaf_names)


def expand2square(pil_img, background_color):
    """
    Pad a PIL image to a square canvas, centring it along the shorter axis.

    Args:
        pil_img:          image exposing `.size` as (width, height) and
                          `.mode`.
        background_color: fill colour passed to `Image.new` for the padding.

    Returns:
        `pil_img` itself when it is already square; otherwise a new image
        whose side equals the longer dimension, with the original pasted in
        the middle of the shorter axis.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    # One of the two offsets is always zero (the axis that already spans the
    # full side); the other centres the image with floor division.
    top_left = ((side - width) // 2, (side - height) // 2)

    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, top_left)
    return canvas

def pad_and_stack(img_list, pad_value=0.0):
    """
    Pad image tensors to one common *square* size and stack them.

    Every image is placed in the top-left corner of a canvas whose side is
    the largest height or width found in `img_list`; the remaining area is
    filled with `pad_value`.

    Args:
        img_list:  non-empty list/tuple of tensors, each (C, H, W), already
                   normalised.  All images must share the same C for the
                   final stack to succeed.
        pad_value: a single number used for every channel, or a tuple/list
                   with one number per channel.  Use 0.0 if your processor
                   has already centred the data to mean 0.

    Returns:
        Tensor of shape (B, C, S, S) with S = max over all heights and widths.

    Raises:
        ValueError: if `img_list` is empty, or if a per-channel `pad_value`
                    does not have exactly C entries.
    """
    if len(img_list) == 0:
        raise ValueError("`img_list` must contain at least one image tensor")

    # 1. target square size ---------------------------------------------------
    tallest = max(t.shape[1] for t in img_list)
    widest = max(t.shape[2] for t in img_list)
    side = max(tallest, widest)

    per_channel = isinstance(pad_value, (tuple, list))

    # 2. create padded copies -------------------------------------------------
    padded = []
    for img in img_list:
        c, h, w = img.shape
        if per_channel:
            if len(pad_value) != c:
                raise ValueError(
                    f"`pad_value` has {len(pad_value)} entries but the image has {c} channels"
                )
            # `new_full` only accepts a scalar fill, so broadcast a (C, 1, 1)
            # column of per-channel values over an empty canvas instead.
            canvas = img.new_empty((c, side, side))
            canvas[:] = img.new_tensor(pad_value).view(c, 1, 1)
        else:
            canvas = img.new_full((c, side, side), pad_value)
        canvas[:, :h, :w] = img                    # top-left corner
        padded.append(canvas)

    return torch.stack(padded, 0)                  # (B, C, S, S)