# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in HuggingFace Transformers.
# Portions of this code are adapted from:
# - https://github.com/SafeAILab/EAGLE (Apache License 2.0)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Any, Dict, List, Optional
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from datasets import Dataset
from specforge.distributed import get_draft_sp_group, get_sp_ulysses_group
class DataCollatorWithPadding:
    """
    Data collator that dynamically right-pads a batch of features with zeros.

    Token-level fields are padded to the longest sequence in the batch, rounded
    up to a multiple of the draft sequence-parallel degree (``self.sp_degree``).
    """

    def __init__(self):
        # Sizes of the draft sequence-parallel group and the Ulysses group,
        # queried from torch.distributed at construction time.
        self.sp_degree = torch.distributed.get_world_size(get_draft_sp_group())
        self.ulysses_degree = torch.distributed.get_world_size(
            get_sp_ulysses_group()
        )

    def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor:
        """
        Right-pad a 3D tensor with zeros along dim 1.

        Args:
            intensors: Tensor of shape (B, n, S).
            N: Target length along dim 1; must satisfy N >= n.

        Returns:
            Tensor of shape (B, N, S) with the same dtype and device as
            ``intensors``; positions n..N-1 along dim 1 are zero.
        """
        B, n, S = intensors.shape
        padding_tensor = torch.zeros(
            B, N - n, S, dtype=intensors.dtype, device=intensors.device
        )
        return torch.cat((intensors, padding_tensor), dim=1)

    def paddingtensor2D(self, intensors: torch.Tensor, N: int) -> torch.Tensor:
        """
        Right-pad a 2D tensor with zeros along dim 1.

        Args:
            intensors: Tensor of shape (B, n).
            N: Target length along dim 1; must satisfy N >= n.

        Returns:
            Tensor of shape (B, N) with the same dtype and device as
            ``intensors``; positions n..N-1 along dim 1 are zero.
        """
        B, n = intensors.shape
        padding_tensor = torch.zeros(
            B, N - n, dtype=intensors.dtype, device=intensors.device
        )
        return torch.cat((intensors, padding_tensor), dim=1)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate a list of per-sample features into one padded batch.

        Args:
            features: Non-empty list of dicts. Token-level tensors are 2D,
                ``(b, n)`` with the sequence length on dim 1 (presumably
                b == 1 per sample -- TODO confirm); samples are concatenated
                along dim 0.
                - input_ids, attention_mask, loss_mask: (b, n)
                - position_ids (optional; used only if present in the first
                  feature): (b, n_pos) with n_pos <= N * ulysses_degree
                - hidden_state and target (optional; used only if every
                  feature has hidden_state, which then requires target):
                  3D tensors (b, n, S)

        Returns:
            A dict with:
            - input_ids, attention_mask, loss_mask: (B, N), where N is the
              longest n rounded up to a multiple of ``self.sp_degree``
            - position_ids (only when provided): (B, N * self.ulysses_degree)
            - hidden_state, target: stacked tensors, or None when absent

        Raises:
            ValueError: if ``features`` is empty.
        """
        if not features:
            raise ValueError("features must contain at least one item")

        max_length = max(item["input_ids"].shape[1] for item in features)
        # Pad for sequence parallelism: round up to a multiple of sp_degree.
        max_length = (
            (max_length + self.sp_degree - 1) // self.sp_degree
        ) * self.sp_degree
        # Position ids are padded to the full length times the Ulysses degree;
        # Ulysses does not need chunked position ids.
        position_max_len = max_length * self.ulysses_degree

        batch_input_ids = torch.cat(
            [self.paddingtensor2D(item["input_ids"], max_length) for item in features]
        )
        batch_attention_mask = torch.cat(
            [
                self.paddingtensor2D(item["attention_mask"], max_length)
                for item in features
            ]
        )
        batch_loss_mask = torch.cat(
            [self.paddingtensor2D(item["loss_mask"], max_length) for item in features]
        )
        if "position_ids" in features[0]:
            batch_position_ids = torch.cat(
                [
                    self.paddingtensor2D(item["position_ids"], position_max_len)
                    for item in features
                ]
            )
        else:
            batch_position_ids = None

        batch = {
            "input_ids": batch_input_ids,
            "attention_mask": batch_attention_mask,
            "loss_mask": batch_loss_mask,
            "hidden_state": None,
            "target": None,
        }
        if batch_position_ids is not None:
            batch["position_ids"] = batch_position_ids

        if all("hidden_state" in item for item in features):
            assert all(
                "target" in item for item in features
            ), "target is required when hidden_state is provided"
            if self.sp_degree > 1:  # USP mode
                # hidden_state is concatenated without padding here; presumably
                # the caller already provides it at a uniform length in USP
                # mode -- TODO confirm.
                batch["hidden_state"] = torch.cat(
                    [item["hidden_state"] for item in features]
                )
            else:
                batch["hidden_state"] = torch.cat(
                    [
                        self.paddingtensor(item["hidden_state"], max_length)
                        for item in features
                    ]
                )
            batch["target"] = torch.cat(
                [self.paddingtensor(item["target"], max_length) for item in features]
            )
        return batch
class VlmDataCollatorWithPadding:
    """
    Data collator for vision-language batches.

    Token-level fields are right-padded with zeros to the longest sequence in
    the batch; image inputs (``pixel_values``, ``image_grid_thw``) are
    concatenated along dim 0 without padding.
    """

    def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor:
        """
        Right-pad a 3D tensor with zeros along dim 1.

        Args:
            intensors: Tensor of shape (B, n, S).
            N: Target length along dim 1; must satisfy N >= n.

        Returns:
            Tensor of shape (B, N, S) with the same dtype and device as
            ``intensors``; positions n..N-1 along dim 1 are zero.
        """
        B, n, S = intensors.shape
        # Allocate the padding on the input's device: torch.cat requires all
        # operands on the same device, so a CPU-only zeros tensor would fail
        # for inputs living anywhere else.
        padding_tensor = torch.zeros(
            B, N - n, S, dtype=intensors.dtype, device=intensors.device
        )
        return torch.cat((intensors, padding_tensor), dim=1)

    def paddingtensor2D(self, intensors: torch.Tensor, N: int) -> torch.Tensor:
        """
        Right-pad a 2D tensor with zeros along dim 1.

        Args:
            intensors: Tensor of shape (B, n).
            N: Target length along dim 1; must satisfy N >= n.

        Returns:
            Tensor of shape (B, N) with the same dtype and device as
            ``intensors``; positions n..N-1 along dim 1 are zero.
        """
        B, n = intensors.shape
        # Same device rationale as in paddingtensor.
        padding_tensor = torch.zeros(
            B, N - n, dtype=intensors.dtype, device=intensors.device
        )
        return torch.cat((intensors, padding_tensor), dim=1)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate a list of per-sample features into one padded batch.

        Args:
            features: Non-empty list of dicts. Token-level tensors are 2D,
                ``(b, n)`` with the sequence length on dim 1 (presumably
                b == 1 per sample -- TODO confirm); samples are concatenated
                along dim 0.
                - input_ids, attention_mask, loss_mask: (b, n)
                - pixel_values: (grid_t * grid_h * grid_w,
                  channel * temporal_patch_size * patch_size * patch_size)
                - image_grid_thw: concatenated along dim 0; assumed to be
                  (num_images, 3) per sample -- TODO confirm
                - hidden_state and target (optional; used only if every
                  feature has hidden_state, which then requires target):
                  3D tensors (b, n, S)

        Returns:
            A dict with:
            - input_ids, attention_mask, loss_mask: (B, N), N = longest n
            - pixel_values, image_grid_thw: per-sample tensors concatenated
              along dim 0
            - hidden_state, target: (B, N, S) tensors, or None when absent

        Raises:
            ValueError: if ``features`` is empty.
        """
        if not features:
            raise ValueError("features must contain at least one item")

        max_length = max(item["input_ids"].shape[1] for item in features)
        batch_input_ids = torch.cat(
            [self.paddingtensor2D(item["input_ids"], max_length) for item in features]
        )
        batch_attention_mask = torch.cat(
            [
                self.paddingtensor2D(item["attention_mask"], max_length)
                for item in features
            ]
        )
        batch_loss_mask = torch.cat(
            [self.paddingtensor2D(item["loss_mask"], max_length) for item in features]
        )
        batch_pixel_values = torch.cat(
            [item["pixel_values"] for item in features], dim=0
        )
        batch_image_grid_thw = torch.cat(
            [item["image_grid_thw"] for item in features], dim=0
        )
        batch = {
            "input_ids": batch_input_ids,
            "attention_mask": batch_attention_mask,
            "loss_mask": batch_loss_mask,
            "pixel_values": batch_pixel_values,
            "image_grid_thw": batch_image_grid_thw,
            "hidden_state": None,
            "target": None,
        }
        if all("hidden_state" in item for item in features):
            assert all(
                "target" in item for item in features
            ), "target is required when hidden_state is provided"
            batch["hidden_state"] = torch.cat(
                [
                    self.paddingtensor(item["hidden_state"], max_length)
                    for item in features
                ]
            )
            batch["target"] = torch.cat(
                [self.paddingtensor(item["target"], max_length) for item in features]
            )
        return batch
def prepare_dp_dataloaders(
    dataset: Dataset,
    batch_size: int,
    num_workers: int = 4,
    process_group: Optional[dist.ProcessGroup] = None,
    pin_memory: Optional[bool] = False,
    shuffle: Optional[bool] = False,
    is_vlm: Optional[bool] = False,
    prefetch_factor: Optional[int] = 2,
    **dataloader_kwargs,
) -> DataLoader:
    """
    Build a DataLoader that shards ``dataset`` across data-parallel ranks.

    Args:
        dataset: Dataset to load samples from.
        batch_size: Per-rank batch size.
        num_workers: Number of DataLoader worker processes.
        process_group: Process group whose world size and rank determine the
            sharding; None means the default group.
        pin_memory: Forwarded to DataLoader.
        shuffle: Forwarded to the DistributedSampler.
        is_vlm: Use ``VlmDataCollatorWithPadding`` instead of
            ``DataCollatorWithPadding`` as the collate function.
        prefetch_factor: Forwarded to DataLoader; replaced by None when
            ``num_workers == 0``.
        **dataloader_kwargs: Extra keyword arguments forwarded to DataLoader.

    Returns:
        A DataLoader using a DistributedSampler and ``drop_last=True``, so a
        trailing incomplete batch is discarded.
    """
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(process_group),
        rank=dist.get_rank(process_group),
        shuffle=shuffle,
    )
    collator = VlmDataCollatorWithPadding() if is_vlm else DataCollatorWithPadding()
    # prefetch_factor only applies when worker processes exist; pass None for
    # the in-process (num_workers == 0) case.
    effective_prefetch = None if num_workers == 0 else prefetch_factor
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=effective_prefetch,
        collate_fn=collator,
        drop_last=True,
        **dataloader_kwargs,
    )
def parse_harmony_message_content(content):
    """
    Split a content string into its Harmony channel segments.

    Every ``<|channel|>NAME<|message|>BODY<|end|>`` occurrence becomes one
    ``{"channel": NAME, "content": BODY}`` dict, with NAME and BODY stripped of
    surrounding whitespace, in order of appearance. Text outside such segments
    is ignored. If the string contains no segment at all, it is returned
    unchanged as a single segment on the default ``"text"`` channel.

    Args:
        content: The message content string.

    Returns:
        A non-empty list of ``{"channel": ..., "content": ...}`` dicts.
    """
    # Every "|" in the special tokens must be escaped: an unescaped "|" is
    # regex alternation (e.g. "<\|end|>" would mean "<|end" OR ">"), which
    # makes any ">" in plain text produce a bogus empty segment.
    pattern = r"<\|channel\|>(.*?)<\|message\|>(.*?)<\|end\|>"
    # DOTALL lets message bodies span multiple lines.
    matches = re.findall(pattern, content, re.DOTALL)
    if not matches:
        # No Harmony tags found: treat the whole string as plain text.
        return [{"channel": "text", "content": content}]
    return [
        {"channel": channel.strip(), "content": msg_body.strip()}
        for channel, msg_body in matches
    ]
def process_harmony_conversations(conversation):
    """
    Flatten one conversation into a message per Harmony channel segment.

    Args:
        conversation: A list of message dicts. Each message's "role" and
            "content" are read with ``dict.get`` (missing content becomes "").

    Returns:
        A new list of dicts with keys "role", "channel" and "content": one
        entry per segment that ``parse_harmony_message_content`` extracts from
        each message's content, preserving message and segment order.
    """
    flattened = []
    for message in conversation:
        speaker = message.get("role")
        raw_text = message.get("content", "")
        # Each parsed channel becomes its own message; "channel" is a new key
        # identifying which Harmony channel the text came from.
        flattened.extend(
            {
                "role": speaker,
                "channel": segment["channel"],
                "content": segment["content"],
            }
            for segment in parse_harmony_message_content(raw_text)
        )
    return flattened
|