| """ |
| This script will re-generate the dataset from target model, |
| which better aligns the draft model with the target model’s output distribution. |
| |
| Usage: |
| 1. Set up one or more SGLang servers for the target model. |
| |
| python3 -m sglang.launch_server \ |
| --model Qwen/Qwen3.5-35B-A3B \ |
| --mem-fraction-static 0.7 \ |
| --tp 1 \ |
| --trust-remote-code \ |
| --cuda-graph-max-bs 128 \ |
| --host 0.0.0.0 \ |
| --port 30000 \ |
| --dtype bfloat16 \ |
| --reasoning-parser qwen3 |
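
   To use additional servers, launch more instances with different --port
   values and pass all of their host:port addresses to --server-address in
   step 2.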

2. Regenerate the dataset using the `regenerate_train_data.py` script.

    python scripts/regenerate_train_data.py \
        --model Qwen/Qwen3.5-35B-A3B \
        --concurrency 128 \
        --max-tokens 4096 \
        --server-address localhost:30000 localhost:30010 localhost:30020 localhost:30030 localhost:30040 localhost:30050 localhost:30060 localhost:30070 \
        --temperature 0.8 \
        --input-file-path /data/jiapingW/pr/SpecForge/cache/dataset/opc_train_first_turn.jsonl \
        --output-file-path ./cache/dataset/opc_train_regen_first_turn.jsonl \
        --resume \
        --is-reasoning-model
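
Input format (inferred from how this script reads and writes data): each
line of --input-file-path is a JSON object with a "conversations" list of
{"role", "content"} messages, e.g.

    {"conversations": [{"role": "user", "content": "Hello!"}]}

The output keeps the same schema, with assistant turns replaced by the
target model's responses and a "status" field added; failed samples are
written to a sibling *_error.jsonl file. With --resume, lines already
present in the output and error files are counted and that many input
lines are skipped.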
| """ |
|
|
import argparse
import json
import os
import random
import time  # used to back off while polling in-flight requests
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List

from openai import OpenAI
from tqdm import tqdm


def parse_arguments():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Re-generate training data using an sglang model server"
    )

    model_group = parser.add_argument_group("model")
    model_group.add_argument("--model", type=str, required=True)
    model_group.add_argument(
        "--is-reasoning-model",
        action="store_true",
        help="Whether the model is a reasoning model",
    )
    model_group.add_argument(
        "--is-gpt-oss",
        action="store_true",
        help="Whether the model is a GPT-OSS model",
    )

    sampling_params_group = parser.add_argument_group("sampling parameters")
    sampling_params_group.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature for the sglang model server",
    )
    sampling_params_group.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Nucleus sampling top_p",
    )
    sampling_params_group.add_argument(
        "--top-k",
        type=int,
        default=None,
        help="Top-k sampling value, sent via extra_body",
    )
    sampling_params_group.add_argument(
        "--repetition-penalty",
        type=float,
        default=None,
        help="Mapped to presence_penalty in the OpenAI API",
    )
    sampling_params_group.add_argument(
        "--max-tokens",
        type=int,
        default=4096,
        help="Maximum number of generated tokens (default: 4096)",
    )

    optimization_group = parser.add_argument_group("optimization")
    optimization_group.add_argument(
        "--concurrency",
        type=int,
        default=64,
        help=(
            "The number of requests to send to a single server concurrently; "
            "the total number of concurrent requests is "
            "concurrency * number of server addresses"
        ),
    )

    data_group = parser.add_argument_group("data")
    data_group.add_argument(
        "--input-file-path", type=str, required=True, help="Path to the input file"
    )
    data_group.add_argument(
        "--output-file-path", type=str, required=True, help="Path to the output file"
    )
    data_group.add_argument(
        "--num-samples",
        type=int,
        default=None,
        help=(
            "The number of samples to regenerate; if not provided, "
            "all samples are regenerated"
        ),
    )
    data_group.add_argument(
        "--resume",
        action="store_true",
        help="Resume from an existing output file, skipping already processed samples",
    )

    server_group = parser.add_argument_group("sglang server")
    server_group.add_argument(
        "--server-address",
        type=str,
        nargs="+",
        help="One or more host:port addresses of sglang model servers",
    )
    return parser.parse_args()


def get_random_reasoning_effort() -> str:
    """Pick a random reasoning effort level with weighted probabilities."""
    reasoning_efforts = [
        "low",
        "medium",
        "high",
    ]
    # Weights of 4/4/2 give probabilities of 0.4/0.4/0.2.
    weights = [4, 4, 2]
    return random.choices(reasoning_efforts, weights=weights, k=1)[0]


def compute_context_length(conversations: List[Dict[str, Any]]) -> int:
    """
    Roughly estimate the context length by counting whitespace-separated
    words; this is a cheap proxy for the tokenized length.
    """
    length = 0
    for message in conversations:
        content = message.get("content")
        if isinstance(content, str):
            length += len(content.split())
        elif isinstance(content, list):
            # Multi-part content: count the text parts only.
            for part in content:
                if isinstance(part, dict):
                    text = part.get("text")
                    if isinstance(text, str):
                        length += len(text.split())
    return length


def build_query_kwargs(args, messages, max_tokens=None):
    """Assemble the keyword arguments for a chat.completions request."""
    effective_max_tokens = max_tokens if max_tokens is not None else args.max_tokens

    query_kwargs = dict(
        model=args.model,
        messages=messages,
        max_tokens=effective_max_tokens,
        temperature=args.temperature,
        stream=False,
    )
    if args.top_p is not None:
        query_kwargs["top_p"] = args.top_p
    if args.repetition_penalty is not None:
        # The OpenAI chat API has no repetition_penalty field, so map it to
        # presence_penalty.
        query_kwargs["presence_penalty"] = args.repetition_penalty
    extra_body = {}
    if args.top_k is not None:
        extra_body["top_k"] = args.top_k
    if extra_body:
        query_kwargs["extra_body"] = extra_body
    if args.is_gpt_oss:
        query_kwargs["reasoning_effort"] = get_random_reasoning_effort()
    return query_kwargs


def call_sglang(
    args,
    server_address: str,
    data: Dict[str, Any],
    max_tokens=None,
) -> Dict[str, Any]:
    """Regenerate one sample's assistant turns via sglang /v1/chat/completions."""
    client = OpenAI(base_url=f"http://{server_address}/v1", api_key="None")

    messages = data["conversations"]
    regenerated_messages = []

    if messages[0]["role"] == "assistant":
        data["status"] = "error"
        data["error"] = "Data starts with an assistant message"
        return data

    for message in messages:
        if message["role"] == "system":
            regenerated_messages.append(message)
        elif message["role"] == "assistant":
            # Drop the original assistant turns; they are replaced by the
            # responses generated after each user turn below.
            continue
        elif message["role"] == "user":
            regenerated_messages.append(message)

            query_kwargs = build_query_kwargs(args, regenerated_messages, max_tokens)

            try:
                resp = client.chat.completions.create(**query_kwargs)
            except Exception as e:
                data["status"] = "error"
                data["error"] = str(e)
                return data
            response_text = resp.choices[0].message.content
            resp_msg = {
                "role": "assistant",
                "content": response_text,
            }
            if args.is_reasoning_model:
                resp_msg["reasoning_content"] = resp.choices[
                    0
                ].message.reasoning_content
            regenerated_messages.append(resp_msg)
        else:
            data["status"] = "error"
            data["error"] = f"Invalid message role: {message['role']}"
            return data
    data["conversations"] = regenerated_messages
    data["status"] = "success"
    return data


def main():
    args = parse_arguments()

    if not (0.0 <= args.temperature <= 1.0):
        raise ValueError("Temperature must be between 0.0 and 1.0")

    if args.max_tokens <= 0:
        raise ValueError("Max tokens must be greater than 0")

    print("Configuration:")
    print(f" Model path: {args.model}")
    print(f" Max tokens: {args.max_tokens}")
    print(f" Concurrency: {args.concurrency}")
    print(f" Temperature: {args.temperature}")
    print(f" API URL: {args.server_address}")
    print(f" Input file: {args.input_file_path}")
    print(f" Output file: {args.output_file_path}")
    print(f" Resume mode: {args.resume}")
    print("-" * 50)
    total_lines = sum(1 for _ in open(args.input_file_path))

    skip_lines = 0
    error_file_path = args.output_file_path.replace(".jsonl", "_error.jsonl")

    if args.resume and os.path.exists(args.output_file_path):
        existing_success = sum(1 for _ in open(args.output_file_path))
        existing_error = 0
        if os.path.exists(error_file_path):
            existing_error = sum(1 for _ in open(error_file_path))
        skip_lines = existing_success + existing_error
        print("Resume mode enabled:")
        print(f" Found {existing_success} successful samples in output file")
        print(f" Found {existing_error} error samples in error file")
        print(f" Skipping first {skip_lines} input samples")
        print("-" * 50)

    if skip_lines >= total_lines:
        print(f"All {total_lines} samples already processed. Nothing to do.")
        return

    # Probe each server with a one-token request. call_sglang always returns
    # the data dict, so availability is judged by its status field.
    valid_server_addresses = []
    for server_address in args.server_address:
        dummy_data = dict(
            conversations=[{"role": "user", "content": "Hello, how are you?"}]
        )
        result = call_sglang(
            args,
            server_address,
            dummy_data,
            max_tokens=1,
        )
        if result.get("status") == "success":
            valid_server_addresses.append(server_address)
        else:
            print(f"Server {server_address} is not available")

    if len(valid_server_addresses) == 0:
        raise ValueError("No server address is available")
    print(
        f"Using {len(valid_server_addresses)} server addresses: {valid_server_addresses}"
    )
    print("-" * 50)

    file_mode = "a" if (args.resume and skip_lines > 0) else "w"
    print(
        f"Regenerating dataset; saving output to {args.output_file_path} and error log to {error_file_path}"
    )
    print(
        f"File open mode: {file_mode} ({'append' if file_mode == 'a' else 'overwrite'})"
    )
    print("-" * 50)
    context_token_sum = 0
    context_token_min = None
    context_token_max = 0
    success_samples = 0
    error_samples = 0

    with (
        open(args.input_file_path, "r") as input_file,
        open(args.output_file_path, file_mode) as output_file_handle,
        open(error_file_path, file_mode) as error_file_handle,
    ):
        executor = ThreadPoolExecutor(
            max_workers=args.concurrency * len(valid_server_addresses)
        )
        waiting_queue = {
            server_address: [] for server_address in valid_server_addresses
        }
        pbar = tqdm(total=total_lines, desc="Processing", initial=skip_lines)
        start_server_index = 0

        if skip_lines > 0:
            print(f"Skipping {skip_lines} already processed samples...")
            for _ in range(skip_lines):
                next(input_file, None)
            print(f"Resuming from sample {skip_lines + 1}")

        for line in input_file:
            if (
                args.num_samples is not None
                and success_samples + error_samples >= args.num_samples
            ):
                break

            data = json.loads(line.strip())

            # Distribute samples across servers round-robin.
            server_address = valid_server_addresses[start_server_index]
            start_server_index = (start_server_index + 1) % len(valid_server_addresses)

            # If this server's queue is full, poll until a request finishes.
            while len(waiting_queue[server_address]) >= args.concurrency:
                finished_one_request = False
                # Iterate over a snapshot so removing finished futures does
                # not skip entries of the list being mutated.
                for req_future in list(waiting_queue[server_address]):
                    if req_future.done():
                        regen_data = req_future.result()

                        if regen_data["status"] == "error":
                            error_file_handle.write(
                                json.dumps(regen_data, ensure_ascii=False) + "\n"
                            )
                            error_samples += 1
                        else:
                            ctx_len = compute_context_length(
                                regen_data.get("conversations", [])
                            )
                            context_token_sum += ctx_len
                            if context_token_min is None:
                                context_token_min = ctx_len
                            else:
                                context_token_min = min(context_token_min, ctx_len)
                            context_token_max = max(context_token_max, ctx_len)

                            output_file_handle.write(
                                json.dumps(regen_data, ensure_ascii=False) + "\n"
                            )
                            success_samples += 1
                        waiting_queue[server_address].remove(req_future)
                        finished_one_request = True

                if not finished_one_request:
                    # Back off briefly instead of busy-waiting.
                    time.sleep(0.1)

            req_future = executor.submit(
                call_sglang,
                args,
                server_address,
                data,
            )
            waiting_queue[server_address].append(req_future)
            pbar.update(1)

        # Drain any requests still in flight and record their results.
        for server_address, waiting_queue_items in waiting_queue.items():
            for req_future in waiting_queue_items:
                regen_data = req_future.result()
                if regen_data["status"] == "error":
                    error_file_handle.write(
                        json.dumps(regen_data, ensure_ascii=False) + "\n"
                    )
                    error_samples += 1
                else:
                    ctx_len = compute_context_length(
                        regen_data.get("conversations", [])
                    )
                    context_token_sum += ctx_len
                    if context_token_min is None:
                        context_token_min = ctx_len
                    else:
                        context_token_min = min(context_token_min, ctx_len)
                    context_token_max = max(context_token_max, ctx_len)

                    output_file_handle.write(
                        json.dumps(regen_data, ensure_ascii=False) + "\n"
                    )
                    success_samples += 1

| print(f"\nProcessing completed!") |
| if success_samples > 0: |
| avg_len = context_token_sum / success_samples |
| print("Context length statistics (token count over conversations):") |
| print(f"Number of successful examples: {success_samples}") |
| print(f"Shortest context length: {context_token_min}") |
| print(f"Longest context length: {context_token_max}") |
| print(f"Average context length: {avg_len:.2f}") |
| else: |
| print("No successful examples to compute context length statistics.") |
|
|
| total_processed = success_samples + error_samples |
| if skip_lines > 0: |
| print(f"\nResume processing completed!") |
| print(f" Previously processed: {skip_lines}") |
| print( |
| f" Newly processed: {total_processed} ({success_samples} success, {error_samples} failed)" |
| ) |
| print(f" Total: {skip_lines + total_processed}") |
| else: |
| print( |
| f"\nProcessing completed! {success_samples} samples regenerated, {error_samples} samples failed." |
| ) |
|
|
|
|
if __name__ == "__main__":
    main()
|
|