import argparse
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from sglang.srt.server_args import ATTENTION_BACKEND_CHOICES


@dataclass
class TrackerArgs:
    report_to: str = "none"
    wandb_project: Optional[str] = None
    wandb_name: Optional[str] = None
    wandb_key: Optional[str] = None
    swanlab_project: Optional[str] = None
    swanlab_name: Optional[str] = None
    swanlab_key: Optional[str] = None
    mlflow_experiment_name: Optional[str] = None
    mlflow_run_name: Optional[str] = None
    mlflow_run_id: Optional[str] = None
    mlflow_tracking_uri: Optional[str] = None
    mlflow_registry_uri: Optional[str] = None

    @staticmethod
    def add_args(parser: argparse.ArgumentParser) -> None:
        parser.add_argument(
            "--report-to",
            type=str,
            default="none",
            choices=["wandb", "tensorboard", "swanlab", "mlflow", "none"],
            help="The integration to report results and logs to.",
        )
        # wandb-specific args
        parser.add_argument("--wandb-project", type=str, default=None)
        parser.add_argument("--wandb-name", type=str, default=None)
        parser.add_argument("--wandb-key", type=str, default=None, help="W&B API key.")
        # swanlab-specific args
        parser.add_argument(
            "--swanlab-project",
            type=str,
            default=None,
            help="The project name for swanlab.",
        )
        parser.add_argument(
            "--swanlab-name",
            type=str,
            default=None,
            help="The experiment name for swanlab.",
        )
        parser.add_argument(
            "--swanlab-key",
            type=str,
            default=None,
            help="The API key for swanlab non-interactive login.",
        )
        # mlflow-specific args
        parser.add_argument(
            "--mlflow-tracking-uri",
            type=str,
            default=None,
            help="The MLflow tracking URI. If not set, uses the MLFLOW_TRACKING_URI environment variable or defaults to local './mlruns'.",
        )
        parser.add_argument(
            "--mlflow-experiment-name",
            type=str,
            default=None,
            help="The MLflow experiment name. If not set, uses the MLFLOW_EXPERIMENT_NAME environment variable.",
        )
        parser.add_argument(
            "--mlflow-run-name",
            type=str,
            default=None,
            help="The MLflow run name. If not set, MLflow will auto-generate one.",
        )
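    @staticmethod
    def from_args(args: argparse.Namespace) -> "TrackerArgs":
        # Sketch of a convenience constructor mirroring
        # SGLangBackendArgs.from_args below; it is not part of the original
        # interface. Fields without a CLI flag (mlflow_run_id,
        # mlflow_registry_uri) keep their dataclass defaults here and are
        # left to be set programmatically.
        return TrackerArgs(
            report_to=args.report_to,
            wandb_project=args.wandb_project,
            wandb_name=args.wandb_name,
            wandb_key=args.wandb_key,
            swanlab_project=args.swanlab_project,
            swanlab_name=args.swanlab_name,
            swanlab_key=args.swanlab_key,
            mlflow_experiment_name=args.mlflow_experiment_name,
            mlflow_run_name=args.mlflow_run_name,
            mlflow_tracking_uri=args.mlflow_tracking_uri,
        )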

@dataclass
class SGLangBackendArgs:
    sglang_attention_backend: str = "flashinfer"
    sglang_mem_fraction_static: float = 0.4
    sglang_context_length: Optional[int] = None
    sglang_enable_nccl_nvls: bool = False
    sglang_enable_symm_mem: bool = False
    sglang_enable_torch_compile: bool = False
    sglang_enable_dp_attention: bool = False
    sglang_enable_dp_lm_head: bool = False
    sglang_enable_piecewise_cuda_graph: bool = False
    sglang_piecewise_cuda_graph_max_tokens: int = 4096
    sglang_piecewise_cuda_graph_tokens: Optional[List[int]] = None
    sglang_ep_size: int = 1
    sglang_max_running_requests: Optional[int] = None  # assigned from the batch size
    sglang_max_total_tokens: Optional[int] = None  # assigned from batch size and sequence length

    @staticmethod
    def add_args(parser: argparse.ArgumentParser) -> None:
        # sglang arguments
        parser.add_argument(
            "--sglang-attention-backend",
            type=str,
            default="flashinfer",
            choices=ATTENTION_BACKEND_CHOICES,
            help="The attention backend of the SGLang backend.",
        )
        parser.add_argument(
            "--sglang-mem-fraction-static",
            type=float,
            default=0.4,
            help="The fraction of memory used for static allocation (model weights and the KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--sglang-context-length",
            type=int,
            default=None,
            help="The context length of the SGLang backend.",
        )
        parser.add_argument(
            "--sglang-enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill-heavy requests when available for the SGLang backend.",
        )
        parser.add_argument(
            "--sglang-enable-symm-mem",
            action="store_true",
            help="Enable NCCL symmetric memory for fast collectives for the SGLang backend.",
        )
        parser.add_argument(
            "--sglang-enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile for the SGLang backend.",
        )
        parser.add_argument(
            "--sglang-enable-dp-attention",
            action="store_true",
            help="Enable DP attention for the SGLang backend.",
        )
        parser.add_argument(
            "--sglang-enable-dp-lm-head",
            action="store_true",
            help="Enable the data-parallel LM head for the SGLang backend.",
        )
        parser.add_argument(
            "--sglang-enable-piecewise-cuda-graph",
            action="store_true",
            help="Enable piecewise CUDA graph for the SGLang backend's prefill.",
        )
        parser.add_argument(
            "--sglang-piecewise-cuda-graph-max-tokens",
            type=int,
            default=4096,
            help="Set the max tokens for piecewise CUDA graph for the SGLang backend.",
        )
        parser.add_argument(
            "--sglang-piecewise-cuda-graph-tokens",
            type=int,
            nargs="+",
            default=None,
            help="Set the list of token counts to capture when using piecewise CUDA graph for the SGLang backend.",
        )
        parser.add_argument(
            "--sglang-ep-size",
            type=int,
            default=1,
            help="The expert-parallel (EP) size of the SGLang backend.",
        )

    @staticmethod
    def from_args(args: argparse.Namespace) -> "SGLangBackendArgs":
        return SGLangBackendArgs(
            sglang_attention_backend=args.sglang_attention_backend,
            sglang_mem_fraction_static=args.sglang_mem_fraction_static,
            sglang_context_length=args.sglang_context_length,
            sglang_enable_nccl_nvls=args.sglang_enable_nccl_nvls,
            sglang_enable_symm_mem=args.sglang_enable_symm_mem,
            sglang_enable_torch_compile=args.sglang_enable_torch_compile,
            sglang_enable_dp_attention=args.sglang_enable_dp_attention,
            sglang_enable_dp_lm_head=args.sglang_enable_dp_lm_head,
            sglang_enable_piecewise_cuda_graph=args.sglang_enable_piecewise_cuda_graph,
            sglang_piecewise_cuda_graph_max_tokens=args.sglang_piecewise_cuda_graph_max_tokens,
            sglang_piecewise_cuda_graph_tokens=args.sglang_piecewise_cuda_graph_tokens,
            sglang_ep_size=args.sglang_ep_size,
            # Derive the concurrency cap and token pool size from the training
            # batch shape when those attributes exist on the namespace.
            sglang_max_running_requests=getattr(args, "target_batch_size", None),
            sglang_max_total_tokens=(
                args.target_batch_size * args.max_length
                if hasattr(args, "target_batch_size") and hasattr(args, "max_length")
                else None
            ),
        )

    def to_kwargs(self) -> Dict[str, Any]:
        # Strip the sglang_ prefix so the dict can be splatted directly into
        # the SGLang engine constructor.
        return dict(
            attention_backend=self.sglang_attention_backend,
            mem_fraction_static=self.sglang_mem_fraction_static,
            context_length=self.sglang_context_length,
            enable_nccl_nvls=self.sglang_enable_nccl_nvls,
            enable_symm_mem=self.sglang_enable_symm_mem,
            enable_torch_compile=self.sglang_enable_torch_compile,
            enable_dp_attention=self.sglang_enable_dp_attention,
            enable_dp_lm_head=self.sglang_enable_dp_lm_head,
            enable_piecewise_cuda_graph=self.sglang_enable_piecewise_cuda_graph,
            piecewise_cuda_graph_max_tokens=self.sglang_piecewise_cuda_graph_max_tokens,
            piecewise_cuda_graph_tokens=self.sglang_piecewise_cuda_graph_tokens,
            ep_size=self.sglang_ep_size,
            max_running_requests=self.sglang_max_running_requests,
            max_total_tokens=self.sglang_max_total_tokens,
        )
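
# Usage sketch (an illustration, not part of the original module): register
# both argument groups on one parser, then build the dataclasses and the
# backend kwargs. In the real training script, args.target_batch_size and
# args.max_length are presumably defined by other argument groups; when they
# are absent, from_args falls back to None for max_running_requests and
# max_total_tokens.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Tracker and SGLang backend arguments")
    TrackerArgs.add_args(parser)
    SGLangBackendArgs.add_args(parser)
    args = parser.parse_args()

    tracker_args = TrackerArgs.from_args(args)
    backend_args = SGLangBackendArgs.from_args(args)
    print(tracker_args)
    print(backend_args.to_kwargs())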