|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
import shutil |
|
from loguru import logger |
|
|
|
import tensorrt as trt |
|
import torch |
|
from torch2trt import torch2trt |
|
|
|
from yolox.exp import get_exp |
|
|
|
|
|
def make_parser():
    """Build the CLI argument parser for YOLOX → TensorRT conversion.

    Returns:
        argparse.ArgumentParser: parser exposing experiment selection
        (-expn/-n/-f), checkpoint path (-c), and TensorRT build limits
        (-w workspace exponent, -b max batch size).
    """
    # NOTE: description previously said "ncnn deploy" — this script targets
    # TensorRT, so the user-facing help text is corrected accordingly.
    parser = argparse.ArgumentParser("YOLOX TensorRT deploy")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    # Workspace is given as a power-of-two exponent; the conversion step
    # uses max_workspace_size = 1 << workspace.
    parser.add_argument(
        "-w", "--workspace", type=int, default=32, help="max workspace size in detect"
    )
    parser.add_argument("-b", "--batch", type=int, default=1, help="max batch size in detect")
    return parser
|
|
|
|
|
@logger.catch
@torch.no_grad()
def main():
    """Convert a trained YOLOX checkpoint into a TensorRT model.

    Loads the experiment description and checkpoint, converts the model
    with torch2trt (fp16), then writes two artifacts into the experiment
    output directory: a torch-loadable state dict (model_trt.pth) and a
    serialized TensorRT engine (model_trt.engine). The engine is also
    copied into the C++ demo directory for standalone inference.
    """
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    # Fall back to the experiment's own name when none was given on the CLI.
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    # Output directory: <exp.output_dir>/<experiment_name>
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)
    # Default to the best checkpoint in the output directory when -c is omitted.
    if args.ckpt is None:
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # Load on CPU first; the model is moved to GPU explicitly below.
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.load_state_dict(ckpt["model"])
    logger.info("loaded checkpoint done.")
    model.eval()
    model.cuda()
    # Disable in-model output decoding before tracing — presumably decoding
    # is handled outside the engine at inference time (TODO confirm against
    # the TensorRT demo code).
    model.head.decode_in_inference = False
    # Dummy input at the experiment's test resolution drives the conversion.
    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,  # build a half-precision engine
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << args.workspace),  # -w is a power-of-two exponent
        max_batch_size=args.batch,
    )
    # State dict loadable from Python via torch2trt's TRTModule.
    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
    logger.info("Converted TensorRT model done.")
    engine_file = os.path.join(file_name, "model_trt.engine")
    engine_file_demo = os.path.join("demo", "TensorRT", "cpp", "model_trt.engine")
    # Serialized engine for torch-free deployment (e.g. the C++ demo).
    with open(engine_file, "wb") as f:
        f.write(model_trt.engine.serialize())

    # NOTE(review): relative path assumes the script is run from the repo
    # root so demo/TensorRT/cpp exists — confirm before relying on the copy.
    shutil.copyfile(engine_file, engine_file_demo)

    logger.info("Converted TensorRT model engine file is saved for C++ inference.")
|
|
|
|
|
# Entry point when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|