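"""Convert a Hugging Face Qwen checkpoint into the TensorRT-LLM checkpoint format.

Minimal usage sketch (the script name and paths below are placeholders, not from the source):

    python convert_checkpoint.py --model_dir ./Qwen-7B-Chat \
        --output_dir ./tllm_checkpoint --dtype float16 --tp_size 2

See parse_arguments() for the full set of dtype, parallelism, and quantization flags.
"""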
import argparse
import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

from transformers import AutoConfig

import tensorrt_llm
from tensorrt_llm._utils import release_gc
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import QWenForCausalLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo


def parse_arguments():
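    """Parse the command line flags that control dtype, parallelism, and quantization for the conversion."""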
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=None, required=True)
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'float16', 'bfloat16', 'float32'],
        help=
        "The data type for the model weights and activations if not quantized. "
        "If 'auto', the data type is automatically inferred from the source model; "
        "however, if the source dtype is float32, it is converted to float16.")
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision.')
    parser.add_argument(
        '--disable_weight_only_quant_plugin',
        default=False,
        action="store_true",
        help=
        'By default, the plugin implementation is used for weight-only quantization. '
        'Enabling this flag uses the OOTB implementation instead of the plugin. '
        'You must also pass --use_weight_only for this argument to have an effect.')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4', 'int4_gptq'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also pass --use_weight_only for this argument to have an effect.')
    parser.add_argument(
        '--calib_dataset',
        type=str,
        default='ccdv/cnn_dailymail',
        help=
        "The huggingface dataset name or the local directory of the dataset for calibration."
    )
    parser.add_argument(
        "--smoothquant",
        "-sq",
        type=float,
        default=None,
        help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
        " to Smoothquant the model, and output int8 weights."
        " A good first try is 0.5. Must be in [0, 1]")
    parser.add_argument(
        '--per_channel',
        action="store_true",
        default=False,
        help=
        'By default, we use a single static scaling factor for the GEMM\'s result. '
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--per_token',
        action="store_true",
        default=False,
        help=
        'By default, we use a single static scaling factor to scale activations in the int8 range. '
        'per_token instead chooses a custom scaling factor for each token at run time. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--int8_kv_cache',
        default=False,
        action="store_true",
        help=
        'By default, we use dtype for the KV cache. int8_kv_cache chooses int8 quantization for the KV cache.'
    )
    parser.add_argument(
        '--per_group',
        default=False,
        action="store_true",
        help=
        'By default, we use a single static scaling factor to scale weights in the int4 range. '
        'per_group instead chooses a custom scaling factor for each group at run time. '
        'The flag is built for GPTQ/AWQ quantization.')
    parser.add_argument('--group_size',
                        type=int,
                        default=128,
                        help='Group size used in GPTQ quantization.')
    parser.add_argument("--load_model_on_cpu", action="store_true")
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along the vocab dimension (embedding_sharding_dim=0). '
        'To shard it along the hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0.'
    )
    parser.add_argument('--output_dir',
                        type=str,
                        default='tllm_checkpoint',
                        help='The path to save the TensorRT-LLM checkpoint')
    parser.add_argument(
        '--workers',
        type=int,
        default=1,
        help='The number of workers for converting checkpoint in parallel')
    parser.add_argument(
        '--moe_tp_size',
        type=int,
        default=-1,
        help=
        'N-way tensor parallelism size for MoE; defaults to tp_size, which does tp-only for MoE'
    )
    parser.add_argument(
        '--moe_ep_size',
        type=int,
        default=-1,
        help=
        'N-way expert parallelism size for MoE; defaults to 1, which does tp-only for MoE'
    )
    args = parser.parse_args()
    return args


def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
    '''Return a QuantConfig with quantization info based on the command line args.
    '''
    quant_config = QuantConfig()
    if args.use_weight_only:
        if args.weight_only_precision == 'int8':
            quant_config.quant_algo = QuantAlgo.W8A16
        elif args.weight_only_precision == 'int4':
            quant_config.quant_algo = QuantAlgo.W4A16
    elif args.smoothquant:
        quant_config.smoothquant_val = args.smoothquant
        if args.per_channel:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
        else:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN

    if args.int8_kv_cache:
        quant_config.kv_cache_quant_algo = QuantAlgo.INT8

    if args.weight_only_precision == 'int4_gptq':
        quant_config.group_size = args.group_size
        quant_config.has_zero_point = True
        quant_config.pre_quant_scale = False
        quant_config.quant_algo = QuantAlgo.W4A16_GPTQ

    return quant_config


def update_quant_config_from_hf(quant_config, hf_config,
                                override_fields) -> tuple[QuantConfig, dict]:
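    """Override quant_config and override_fields from any AWQ/GPTQ quantization_config found in the HF config."""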
    hf_config_dict = hf_config.to_dict()
    if hf_config_dict.get('quantization_config'):
        # update the quant_algo, and clamp_val.
        if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
            logger.info(
                "Load quantization configs from huggingface model_config.")
            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
            quant_config.group_size = hf_config_dict['quantization_config'].get(
                'group_size', 128)
            quant_config.has_zero_point = hf_config_dict[
                'quantization_config'].get('zero_point', False)
            override_fields.update({"use_autoawq": True})
        elif hf_config_dict['quantization_config'].get(
                'quant_method') == 'gptq':
            logger.info(
                "Load quantization configs from huggingface model_config.")
            desc_act = hf_config_dict['quantization_config'].get(
                'desc_act', False)
            if desc_act:
                raise ValueError("GPTQ with desc_act=True is not implemented!")
            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
            quant_config.group_size = hf_config_dict['quantization_config'].get(
                'group_size', 128)
            quant_config.has_zero_point = hf_config_dict[
                'quantization_config'].get('sym', False)
    return quant_config, override_fields


def args_to_build_options(args):
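    """Collect build-related options from the parsed args into a dict of override fields."""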
    return {
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'disable_weight_only_quant_plugin':
        args.disable_weight_only_quant_plugin
    }


def convert_and_save_hf(args):
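    """Load the HF Qwen model, apply the requested quantization, and save the TensorRT-LLM checkpoint (one file per rank)."""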
    model_dir = args.model_dir
    world_size = args.tp_size * args.pp_size
    # Need to convert the cli args to key-value pairs and override them in the generated config dict.
    # Ideally these fields should be moved out of the config and passed into the build API;
    # they are kept here for compatibility until the refactor is done.
    override_fields = {}
    override_fields.update(args_to_build_options(args))
    quant_config = args_to_quant_config(args)

    try:
        hf_config = AutoConfig.from_pretrained(model_dir,
                                               trust_remote_code=True)
        quant_config, override_fields = update_quant_config_from_hf(
            quant_config, hf_config, override_fields)
    except Exception:
        logger.warning("AutoConfig cannot load the huggingface config.")
    if args.smoothquant is not None or args.int8_kv_cache:
        mapping = Mapping(
            world_size=world_size,
            tp_size=args.tp_size,
            pp_size=args.pp_size,
            moe_tp_size=args.moe_tp_size,
            moe_ep_size=args.moe_ep_size,
        )
        QWenForCausalLM.quantize(args.model_dir,
                                 args.output_dir,
                                 dtype=args.dtype,
                                 mapping=mapping,
                                 quant_config=quant_config,
                                 calib_dataset=args.calib_dataset,
                                 **override_fields)
    else:

        def convert_and_save_rank(args, rank):
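            """Convert and save the checkpoint shard owned by a single rank."""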
            mapping = Mapping(world_size=world_size,
                              rank=rank,
                              tp_size=args.tp_size,
                              pp_size=args.pp_size,
                              moe_tp_size=args.moe_tp_size,
                              moe_ep_size=args.moe_ep_size)
            qwen = QWenForCausalLM.from_hugging_face(model_dir,
                                                     args.dtype,
                                                     mapping=mapping,
                                                     quant_config=quant_config,
                                                     **override_fields)
            qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
            del qwen

        execute(args.workers, [convert_and_save_rank] * world_size, args)
        release_gc()


def execute(workers, func, args):
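    """Run the per-rank conversion callables, serially when workers == 1 or in a thread pool otherwise."""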
    if workers == 1:
        for rank, f in enumerate(func):
            f(args, rank)
    else:
        with ThreadPoolExecutor(max_workers=workers) as p:
            futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
            exceptions = []
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    traceback.print_exc()
                    exceptions.append(e)
            assert len(
                exceptions
            ) == 0, "Checkpoint conversion failed, please check error log."


def main():
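    """Entry point: parse args, resolve the MoE parallelism sizes, and run the conversion."""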
    print(tensorrt_llm.__version__)
    args = parse_arguments()

    if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
        # moe defaults to tp-only
        args.moe_tp_size = args.tp_size
        args.moe_ep_size = 1
    elif (args.moe_tp_size == -1):
        args.moe_tp_size = args.tp_size // args.moe_ep_size
    elif (args.moe_ep_size == -1):
        args.moe_ep_size = args.tp_size // args.moe_tp_size
    assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
            ), "moe_tp_size * moe_ep_size must equal tp_size"

    tik = time.time()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    assert args.model_dir is not None
    convert_and_save_hf(args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')


if __name__ == '__main__':
    main()