#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
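
# Convert a trained YOLOX checkpoint into a TensorRT engine with torch2trt,
# saving both a reloadable TRTModule state dict and a serialized engine file.
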
import argparse
import os
import shutil

from loguru import logger

import tensorrt as trt
import torch
from torch2trt import torch2trt

from yolox.exp import get_exp


def make_parser():
    parser = argparse.ArgumentParser("YOLOX TensorRT deploy")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="checkpoint path")
    parser.add_argument(
        "-w", "--workspace", type=int, default=32,
        help="max workspace size as a power of two (2**workspace bytes; 32 means 4 GiB)",
    )
    parser.add_argument(
        "-b", "--batch", type=int, default=1, help="max batch size for the engine"
    )
    return parser


@logger.catch
@torch.no_grad()
def main():
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)
    if args.ckpt is None:
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt
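
    # Load the checkpoint on the CPU first; the model is moved to the GPU below.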
    ckpt = torch.load(ckpt_file, map_location="cpu")
    # Load the trained weights stored under the checkpoint's "model" key.
    model.load_state_dict(ckpt["model"])
    logger.info("Loaded checkpoint.")

    model.eval()
    model.cuda()
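
    # Exclude the decode step from the exported graph so the engine returns the
    # raw head outputs; the TensorRT demos decode boxes outside the engine.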
    model.head.decode_in_inference = False
    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
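
    # torch2trt traces the model with the dummy input above (batch 1 at the
    # experiment's test resolution) and builds an FP16 engine; the builder may
    # use up to 2**args.workspace bytes of scratch memory.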
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << args.workspace),
        max_batch_size=args.batch,
    )
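
    # The result is a torch2trt.TRTModule, so the saved state dict can later be
    # restored for PyTorch-style inference:
    #   from torch2trt import TRTModule
    #   model_trt = TRTModule()
    #   model_trt.load_state_dict(torch.load("model_trt.pth"))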
    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
    logger.info("TensorRT conversion done.")
    engine_file = os.path.join(file_name, "model_trt.engine")
    engine_file_demo = os.path.join("demo", "TensorRT", "cpp", "model_trt.engine")
    with open(engine_file, "wb") as f:
        f.write(model_trt.engine.serialize())
    # Make sure the demo directory exists before copying the engine into it.
    os.makedirs(os.path.dirname(engine_file_demo), exist_ok=True)
    shutil.copyfile(engine_file, engine_file_demo)
    logger.info("Serialized TensorRT engine saved for C++ inference.")


if __name__ == "__main__":
    main()
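
# Example invocation (model name and checkpoint path are illustrative), assuming
# this script lives at tools/trt.py in a YOLOX checkout:
#   python3 tools/trt.py -n yolox-s -c yolox_s.pth -b 1 -w 32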