import argparse
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from sglang.srt.server_args import ATTENTION_BACKEND_CHOICES
@dataclass
class TrackerArgs:
    """Configuration for experiment-tracking integrations (W&B, TensorBoard, SwanLab, MLflow).

    Field defaults mirror the CLI defaults registered by :meth:`add_args`;
    every field is populated from the corresponding ``--...`` flag.
    """

    # Which integration to report results and logs to; "none" disables tracking.
    report_to: str = "none"
    # Weights & Biases settings.
    wandb_project: Optional[str] = None
    wandb_name: Optional[str] = None
    wandb_key: Optional[str] = None
    # SwanLab settings.
    swanlab_project: Optional[str] = None
    swanlab_name: Optional[str] = None
    swanlab_key: Optional[str] = None
    # MLflow settings.
    mlflow_experiment_id: Optional[str] = None
    mlflow_run_name: Optional[str] = None
    mlflow_run_id: Optional[str] = None
    mlflow_tracking_uri: Optional[str] = None
    mlflow_registry_uri: Optional[str] = None
    # Appended last so existing positional construction of TrackerArgs stays valid.
    mlflow_experiment_name: Optional[str] = None

    @staticmethod
    def add_args(parser: argparse.ArgumentParser) -> None:
        """Register all tracker-related CLI options on *parser*."""
        parser.add_argument(
            "--report-to",
            type=str,
            default="none",
            choices=["wandb", "tensorboard", "swanlab", "mlflow", "none"],
            help="The integration to report results and logs to.",
        )
        # wandb-specific args
        parser.add_argument("--wandb-project", type=str, default=None)
        parser.add_argument("--wandb-name", type=str, default=None)
        parser.add_argument("--wandb-key", type=str, default=None, help="W&B API key.")
        # swanlab-specific args
        parser.add_argument(
            "--swanlab-project",
            type=str,
            default=None,
            help="The project name for swanlab.",
        )
        parser.add_argument(
            "--swanlab-name",
            type=str,
            default=None,
            help="The experiment name for swanlab.",
        )
        parser.add_argument(
            "--swanlab-key",
            type=str,
            default=None,
            help="The API key for swanlab non-interactive login.",
        )
        # mlflow-specific args
        parser.add_argument(
            "--mlflow-tracking-uri",
            type=str,
            default=None,
            help="The MLflow tracking URI. If not set, uses MLFLOW_TRACKING_URI environment variable or defaults to local './mlruns'.",
        )
        parser.add_argument(
            "--mlflow-experiment-name",
            type=str,
            default=None,
            help="The MLflow experiment name. If not set, uses MLFLOW_EXPERIMENT_NAME environment variable.",
        )
        parser.add_argument(
            "--mlflow-run-name",
            type=str,
            default=None,
            help="The MLflow run name. If not set, MLflow will auto-generate one.",
        )
        # Flags for the remaining MLflow dataclass fields, which previously had
        # no way to be populated from the command line.
        parser.add_argument(
            "--mlflow-experiment-id",
            type=str,
            default=None,
            help="The MLflow experiment ID. Takes precedence over the experiment name when both are set.",
        )
        parser.add_argument(
            "--mlflow-run-id",
            type=str,
            default=None,
            help="An existing MLflow run ID to resume logging into.",
        )
        parser.add_argument(
            "--mlflow-registry-uri",
            type=str,
            default=None,
            help="The MLflow model registry URI.",
        )

    @staticmethod
    def from_args(args: argparse.Namespace) -> "TrackerArgs":
        """Build a TrackerArgs from parsed CLI args (parallel to SGLangBackendArgs.from_args)."""
        return TrackerArgs(
            report_to=args.report_to,
            wandb_project=args.wandb_project,
            wandb_name=args.wandb_name,
            wandb_key=args.wandb_key,
            swanlab_project=args.swanlab_project,
            swanlab_name=args.swanlab_name,
            swanlab_key=args.swanlab_key,
            mlflow_experiment_id=args.mlflow_experiment_id,
            mlflow_run_name=args.mlflow_run_name,
            mlflow_run_id=args.mlflow_run_id,
            mlflow_tracking_uri=args.mlflow_tracking_uri,
            mlflow_registry_uri=args.mlflow_registry_uri,
            mlflow_experiment_name=args.mlflow_experiment_name,
        )
@dataclass
class SGLangBackendArgs:
    """Configuration bundle for the SGLang inference backend.

    Fields prefixed ``sglang_`` are parsed from CLI flags of the same name
    (see :meth:`add_args`) and forwarded to the SGLang engine, un-prefixed,
    via :meth:`to_kwargs`.
    """

    # NOTE(review): dataclass default is "fa3" but the CLI default in add_args
    # is "flashinfer" — direct construction and CLI parsing disagree; confirm
    # which default is intended.
    sglang_attention_backend: str = "fa3"
    # Fraction of GPU memory reserved for model weights + KV-cache pool.
    sglang_mem_fraction_static: float = 0.4
    sglang_context_length: Optional[int] = None
    sglang_enable_nccl_nvls: bool = False
    sglang_enable_symm_mem: bool = False
    # NOTE(review): defaults to True here, but the corresponding CLI flag is
    # store_true (default False), so from_args yields False unless the flag is
    # passed — confirm the intended default.
    sglang_enable_torch_compile: bool = True
    sglang_enable_dp_attention: bool = False
    sglang_enable_dp_lm_head: bool = False
    sglang_enable_piecewise_cuda_graph: bool = False
    sglang_piecewise_cuda_graph_max_tokens: int = 4096
    sglang_piecewise_cuda_graph_tokens: Optional[List[int]] = None
    sglang_ep_size: int = 1
    sglang_max_running_requests: Optional[int] = None  # assign based on batch size
    sglang_max_total_tokens: Optional[int] = None  # assign based on batch size and seq length

    @staticmethod
    def add_args(parser: argparse.ArgumentParser) -> None:
        """Register SGLang backend CLI options on *parser*.

        ``sglang_max_running_requests`` / ``sglang_max_total_tokens`` have no
        flags on purpose: they are derived in :meth:`from_args`.
        """
        # sglang arguments
        parser.add_argument(
            "--sglang-attention-backend",
            type=str,
            default="flashinfer",
            choices=ATTENTION_BACKEND_CHOICES,
            help="The attention backend of SGLang backend",
        )
        parser.add_argument(
            "--sglang-mem-fraction-static",
            type=float,
            default=0.4,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--sglang-context-length",
            type=int,
            default=None,
            help="The context length of the SGLang backend",
        )
        parser.add_argument(
            "--sglang-enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill heavy requests when available for SGLang backend",
        )
        parser.add_argument(
            "--sglang-enable-symm-mem",
            action="store_true",
            help="Enable NCCL symmetric memory for fast collectives for SGLang backend",
        )
        parser.add_argument(
            "--sglang-enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile for SGLang backend",
        )
        parser.add_argument(
            "--sglang-enable-dp-attention",
            action="store_true",
            help="Enable DP attention for SGLang backend",
        )
        parser.add_argument(
            "--sglang-enable-dp-lm-head",
            action="store_true",
            # Fixed help text: previously a copy-paste of the piecewise CUDA
            # graph description.
            help="Enable data-parallel LM head for SGLang backend",
        )
        parser.add_argument(
            "--sglang-enable-piecewise-cuda-graph",
            action="store_true",
            help="Enable piecewise CUDA graph for SGLang backend's prefill",
        )
        parser.add_argument(
            "--sglang-piecewise-cuda-graph-max-tokens",
            type=int,
            default=4096,
            help="Set the max tokens for piecewise CUDA graph for SGLang backend",
        )
        parser.add_argument(
            "--sglang-piecewise-cuda-graph-tokens",
            type=int,
            nargs="+",
            default=None,
            help="Set the list of tokens when using piecewise cuda graph for SGLang backend",
        )
        parser.add_argument(
            "--sglang-ep-size",
            type=int,
            default=1,
            help="The ep size of the SGLang backend",
        )

    @staticmethod
    def from_args(args: argparse.Namespace) -> "SGLangBackendArgs":
        """Build an SGLangBackendArgs from parsed CLI args.

        ``max_running_requests`` is taken from ``args.target_batch_size`` and
        ``max_total_tokens`` from ``target_batch_size * max_length`` when those
        attributes exist on *args* (they come from a different arg group);
        otherwise both stay None.
        """
        return SGLangBackendArgs(
            sglang_attention_backend=args.sglang_attention_backend,
            sglang_mem_fraction_static=args.sglang_mem_fraction_static,
            sglang_context_length=args.sglang_context_length,
            sglang_enable_nccl_nvls=args.sglang_enable_nccl_nvls,
            sglang_enable_symm_mem=args.sglang_enable_symm_mem,
            sglang_enable_torch_compile=args.sglang_enable_torch_compile,
            sglang_enable_dp_attention=args.sglang_enable_dp_attention,
            sglang_enable_dp_lm_head=args.sglang_enable_dp_lm_head,
            sglang_enable_piecewise_cuda_graph=args.sglang_enable_piecewise_cuda_graph,
            sglang_piecewise_cuda_graph_max_tokens=args.sglang_piecewise_cuda_graph_max_tokens,
            sglang_piecewise_cuda_graph_tokens=args.sglang_piecewise_cuda_graph_tokens,
            sglang_ep_size=args.sglang_ep_size,
            sglang_max_running_requests=(
                args.target_batch_size if hasattr(args, "target_batch_size") else None
            ),
            sglang_max_total_tokens=(
                args.target_batch_size * args.max_length
                if hasattr(args, "target_batch_size") and hasattr(args, "max_length")
                else None
            ),
        )

    def to_kwargs(self) -> Dict[str, Any]:
        """Return the fields as engine kwargs with the ``sglang_`` prefix stripped."""
        return dict(
            attention_backend=self.sglang_attention_backend,
            mem_fraction_static=self.sglang_mem_fraction_static,
            context_length=self.sglang_context_length,
            enable_nccl_nvls=self.sglang_enable_nccl_nvls,
            enable_symm_mem=self.sglang_enable_symm_mem,
            enable_torch_compile=self.sglang_enable_torch_compile,
            enable_dp_attention=self.sglang_enable_dp_attention,
            enable_dp_lm_head=self.sglang_enable_dp_lm_head,
            enable_piecewise_cuda_graph=self.sglang_enable_piecewise_cuda_graph,
            piecewise_cuda_graph_max_tokens=self.sglang_piecewise_cuda_graph_max_tokens,
            piecewise_cuda_graph_tokens=self.sglang_piecewise_cuda_graph_tokens,
            ep_size=self.sglang_ep_size,
            max_running_requests=self.sglang_max_running_requests,
            max_total_tokens=self.sglang_max_total_tokens,
        )