""" This script will re-generate the dataset from target model, which better aligns the draft model with the target model’s output distribution. Usage: 1. Set up one or more SGLang servers for the target model. python3 -m sglang.launch_server \ --model Qwen/Qwen3.5-35B-A3B \ --mem-fraction-static 0.7 \ --tp 1 \ --trust-remote-code \ --cuda-graph-max-bs 128 \ --host 0.0.0.0 \ --port 30000 \ --dtype bfloat16 \ --reasoning-parser qwen3 2. Regenerate the dataset using the `regenerate_train_data.py` script. python scripts/regenerate_train_data.py \ --model Qwen/Qwen3.5-35B-A3B \ --concurrency 128 \ --max-tokens 4096 \ --server-address localhost:30000 localhost:30010 localhost:30020 localhost:30030 localhost:30040 localhost:30050 localhost:30060 localhost:30070 \ --temperature 0.8 \ --input-file-path /data/jiapingW/pr/SpecForge/cache/dataset/opc_train_first_turn.jsonl \ --output-file-path ./cache/dataset/opc_train_regen_first_turn.jsonl \ --resume \ --is-reasoning-model """ import argparse import json import os import random from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List from openai import OpenAI from tqdm import tqdm def parse_arguments(): """Parse command line arguments""" parser = argparse.ArgumentParser( description="Re-generate training data using sglang model server" ) # model related arguments model_group = parser.add_argument_group("model") model_group.add_argument("--model", type=str, required=True) model_group.add_argument( "--is-reasoning-model", action="store_true", help="Whether the model is a reasoning model", ) model_group.add_argument( "--is-gpt-oss", action="store_true", help="Whether the model is a GPT-OSS model", ) # sampling params sampling_params_group = parser.add_argument_group("sampling parameters") sampling_params_group.add_argument( "--temperature", type=float, default=0.7, help="Temperature for sglang model server", ) sampling_params_group.add_argument( "--top-p", type=float, default=None, help="Nucleus sampling top_p", ) sampling_params_group.add_argument( "--top-k", type=int, default=None, help="Top-k sampling value sent via extra_body", ) sampling_params_group.add_argument( "--repetition-penalty", type=float, default=None, help="Mapped to presence_penalty in the OpenAI API", ) sampling_params_group.add_argument( "--max-tokens", type=int, default=4096, help="Maximum number of tokens (default: 4096)", ) # optimization optimization_group = parser.add_argument_group("optimization") optimization_group.add_argument( "--concurrency", type=int, default=64, help="The number of requests to send to a single server concurrently, the total number of concurrent requests is concurrency * number of server addresses", ) # data related arguments data_group = parser.add_argument_group("data") data_group.add_argument( "--input-file-path", type=str, required=True, help="Path to the input file" ) data_group.add_argument( "--output-file-path", type=str, required=True, help="Path to the output file" ) data_group.add_argument( "--num-samples", type=int, default=None, help="The number of samples to regenerate, if not provided, all samples will be regenerated", ) data_group.add_argument( "--resume", action="store_true", help="Resume from existing output file, skip already processed samples", ) # sglang server server_group = parser.add_argument_group("sglang server") server_group.add_argument( "--server-address", type=str, nargs="+", help="Server address and port for sglang model server", ) return parser.parse_args() def get_random_reasoning_effort() -> str: """Get a random reasoning effort level for the model with weighted probabilities.""" # usage example: https://huggingface.co/openai/gpt-oss-20b/discussions/28 # Reasoning effort levels with weights: LOW(4), MEDIUM(4), HIGH(2) reasoning_efforts = [ "low", "medium", "high", ] weights = [4, 4, 2] return random.choices(reasoning_efforts, weights=weights, k=1)[0] def compute_context_length(conversations: List[Dict[str, Any]]) -> int: """ This is a rough estimate of the context length measured in untokenized tokens. """ length = 0 for message in conversations: content = message.get("content") if isinstance(content, str): # {"role": "assistant", "content": "Hi, how can I help?"} length += len(content.split()) elif isinstance(content, list): for part in content: if isinstance(part, dict): text = part.get("text") if isinstance(text, str): length += len(text.split()) return length def build_query_kwargs(args, messages, max_tokens=None): effective_max_tokens = max_tokens if max_tokens is not None else args.max_tokens query_kwargs = dict( model=args.model, messages=messages, max_tokens=effective_max_tokens, temperature=args.temperature, stream=False, ) if args.top_p is not None: query_kwargs["top_p"] = args.top_p if args.repetition_penalty is not None: query_kwargs["presence_penalty"] = args.repetition_penalty extra_body = {} if args.top_k is not None: extra_body["top_k"] = args.top_k if extra_body: query_kwargs["extra_body"] = extra_body if args.is_gpt_oss: query_kwargs["reasoning_effort"] = get_random_reasoning_effort() return query_kwargs def call_sglang( args, server_address: str, data: List[Dict[str, Any]], max_tokens=None, ) -> str: """Send a batch of prompts to sglang /v1/completions.""" client = OpenAI(base_url=f"http://{server_address}/v1", api_key="None") messages = data["conversations"] regenerated_messages = [] # ignore data which starts with an assistant message if messages[0]["role"] == "assistant": data["status"] = "error" data["error"] = "Data starts with an assistant message" return data for message in messages: if message["role"] == "system": regenerated_messages.append(message) elif message["role"] == "assistant": continue elif message["role"] == "user": regenerated_messages.append(message) query_kwargs = build_query_kwargs(args, regenerated_messages, max_tokens) try: resp = client.chat.completions.create(**query_kwargs) except Exception as e: data["status"] = "error" data["error"] = str(e) return data response_text = resp.choices[0].message.content resp_msg = { "role": "assistant", "content": response_text, } if args.is_reasoning_model: resp_msg["reasoning_content"] = resp.choices[ 0 ].message.reasoning_content regenerated_messages.append(resp_msg) else: data["status"] = "error" data["error"] = f"Invalid message role: {message['role']}" return data data["conversations"] = regenerated_messages data["status"] = "success" return data def main(): # Parse command line arguments args = parse_arguments() # Validate parameters if not (0.0 <= args.temperature <= 1.0): raise ValueError("Temperature must be between 0.0 and 1.0") if args.max_tokens <= 0: raise ValueError("Max tokens must be greater than 0") print(f"Configuration:") print(f" Model path: {args.model}") print(f" Max tokens: {args.max_tokens}") print(f" Concurrency: {args.concurrency}") print(f" Temperature: {args.temperature}") print(f" API URL: {args.server_address}") print(f" Input file: {args.input_file_path}") print(f" Output file: {args.output_file_path}") print(f" Resume mode: {args.resume}") print("-" * 50) total_lines = sum(1 for _ in open(args.input_file_path)) skip_lines = 0 error_file_path = args.output_file_path.replace(".jsonl", "_error.jsonl") if args.resume and os.path.exists(args.output_file_path): existing_success = sum(1 for _ in open(args.output_file_path)) existing_error = 0 if os.path.exists(error_file_path): existing_error = sum(1 for _ in open(error_file_path)) skip_lines = existing_success + existing_error print(f"Resume mode enabled:") print(f" Found {existing_success} successful samples in output file") print(f" Found {existing_error} error samples in error file") print(f" Skipping first {skip_lines} input samples") print("-" * 50) if skip_lines >= total_lines: print(f"All {total_lines} samples already processed. Nothing to do.") return # test all server addresses valid_server_addresses = [] for server_address in args.server_address: dummy_data = dict( conversations=[{"role": "user", "content": "Hello, how are you?"}] ) result = call_sglang( args, server_address, dummy_data, max_tokens=1, ) if result is not None: valid_server_addresses.append(server_address) else: print(f"Server {server_address} is not available") if len(valid_server_addresses) == 0: raise ValueError("No server address is available") print( f"Using {len(valid_server_addresses)} server addresses: {valid_server_addresses}" ) print("-" * 50) # Determine file open mode based on resume flag file_mode = "a" if (args.resume and skip_lines > 0) else "w" print( f"Regenerating dataset and saving the output to {args.output_file_path} and error log to {error_file_path}" ) print( f"File open mode: {file_mode} ({'append' if file_mode == 'a' else 'overwrite'})" ) print("-" * 50) context_token_sum = 0 context_token_min = None context_token_max = 0 success_samples = 0 error_samples = 0 # Create progress bar with ( open(args.input_file_path, "r") as input_file, open(args.output_file_path, file_mode) as output_file_handle, open(error_file_path, file_mode) as error_file_handle, ): executor = ThreadPoolExecutor( max_workers=args.concurrency * len(valid_server_addresses) ) waiting_queue = { server_address: [] for server_address in valid_server_addresses } pbar = tqdm(total=total_lines, desc="Processing", initial=skip_lines) start_server_index = 0 if skip_lines > 0: print(f"Skipping {skip_lines} already processed samples...") for _ in range(skip_lines): next(input_file, None) print(f"Resuming from sample {skip_lines + 1}") for line in input_file: if ( args.num_samples is not None and success_samples + error_samples >= args.num_samples ): break data = json.loads(line.strip()) # find server address with the least waiting requests server_address = valid_server_addresses[start_server_index] start_server_index = (start_server_index + 1) % len(valid_server_addresses) # submit prompt to sglang while len(waiting_queue[server_address]) >= args.concurrency: finished_on_request = False # check if any future is done, if so, write the result to the output file for req_future in waiting_queue[server_address]: if req_future.done(): regen_data = req_future.result() if regen_data["status"] == "error": error_file_handle.write( json.dumps(regen_data, ensure_ascii=False) + "\n" ) error_samples += 1 else: ctx_len = compute_context_length( regen_data.get("conversations", []) ) context_token_sum += ctx_len if context_token_min is None: context_token_min = ctx_len else: context_token_min = min(context_token_min, ctx_len) context_token_max = max(context_token_max, ctx_len) output_file_handle.write( json.dumps(regen_data, ensure_ascii=False) + "\n" ) success_samples += 1 waiting_queue[server_address].remove(req_future) finished_on_request = True if finished_on_request: break req_future = executor.submit( call_sglang, args, server_address, data, ) waiting_queue[server_address].append(req_future) pbar.update(1) # deal with all the remaining requests for server_address, waiting_queue_items in waiting_queue.items(): for req_future in waiting_queue_items: regen_data = req_future.result() if regen_data["status"] == "error": error_file_handle.write( json.dumps(regen_data, ensure_ascii=False) + "\n" ) error_samples += 1 else: ctx_len = compute_context_length( regen_data.get("conversations", []) ) context_token_sum += ctx_len if context_token_min is None: context_token_min = ctx_len else: context_token_min = min(context_token_min, ctx_len) context_token_max = max(context_token_max, ctx_len) output_file_handle.write( json.dumps(regen_data, ensure_ascii=False) + "\n" ) success_samples += 1 print(f"\nProcessing completed!") if success_samples > 0: avg_len = context_token_sum / success_samples print("Context length statistics (token count over conversations):") print(f"Number of successful examples: {success_samples}") print(f"Shortest context length: {context_token_min}") print(f"Longest context length: {context_token_max}") print(f"Average context length: {avg_len:.2f}") else: print("No successful examples to compute context length statistics.") total_processed = success_samples + error_samples if skip_lines > 0: print(f"\nResume processing completed!") print(f" Previously processed: {skip_lines}") print( f" Newly processed: {total_processed} ({success_samples} success, {error_samples} failed)" ) print(f" Total: {skip_lines + total_processed}") else: print( f"\nProcessing completed! {success_samples} samples regenerated, {error_samples} samples failed." ) if __name__ == "__main__": main()