File size: 3,638 Bytes
21e3f5a 39bbed9 21e3f5a 261b519 21e3f5a 812eb58 30757c1 812eb58 6ea9687 812eb58 21e3f5a 812eb58 61b2559 21e3f5a 7e1e02f 21e3f5a ab8809b 21e3f5a 61b2559 21e3f5a ab8809b 30757c1 21e3f5a 30757c1 21e3f5a ab8809b 21e3f5a ab8809b f345d42 ab8809b f345d42 ab8809b 21e3f5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import argparse
import os
from loguru import logger
import torch
from torch import nn
from yolox.exp import get_exp
from yolox.models.network_blocks import SiLU
from yolox.utils import replace_module
def make_parser():
    """Build the command-line parser for the YOLOX ONNX export script.

    Returns:
        argparse.ArgumentParser: parser exposing the export options
        (output file name, ONNX input/output node names, opset, batch
        size, dynamic batch axis, onnxsim toggle, experiment file/name,
        checkpoint path, trailing config overrides, and the
        decode-in-inference switch).
    """
    cli = argparse.ArgumentParser("YOLOX onnx deploy")

    # (flags, keyword options) pairs. They are registered in exactly this
    # order so the generated ``--help`` listing stays the same.
    argument_specs = (
        (
            ("--output-name",),
            {"type": str, "default": "yolox.onnx", "help": "output name of models"},
        ),
        (
            ("--input",),
            {"default": "images", "type": str, "help": "input node name of onnx model"},
        ),
        (
            ("--output",),
            {"default": "output", "type": str, "help": "output node name of onnx model"},
        ),
        (
            ("-o", "--opset"),
            {"default": 11, "type": int, "help": "onnx opset version"},
        ),
        (
            ("--batch-size",),
            {"type": int, "default": 1, "help": "batch size"},
        ),
        (
            ("--dynamic",),
            {
                "action": "store_true",
                "help": "whether the input shape should be dynamic or not",
            },
        ),
        (
            ("--no-onnxsim",),
            {"action": "store_true", "help": "use onnxsim or not"},
        ),
        (
            ("-f", "--exp_file"),
            {"default": None, "type": str, "help": "experiment description file"},
        ),
        (
            ("-expn", "--experiment-name"),
            {"type": str, "default": None},
        ),
        (
            ("-n", "--name"),
            {"type": str, "default": None, "help": "model name"},
        ),
        (
            ("-c", "--ckpt"),
            {"default": None, "type": str, "help": "ckpt path"},
        ),
        (
            # Positional catch-all: everything left on the command line.
            ("opts",),
            {
                "help": "Modify config options using the command-line",
                "default": None,
                "nargs": argparse.REMAINDER,
            },
        ),
        (
            ("--decode_in_inference",),
            {"action": "store_true", "help": "decode in inference or not"},
        ),
    )

    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)

    return cli
@logger.catch
def main():
    """Export a YOLOX checkpoint to an ONNX file, optionally simplified.

    Parses the command line, builds the experiment and its model, loads
    the checkpoint weights on CPU, swaps ``nn.SiLU`` modules for the
    project's ``SiLU`` implementation, and exports the model to
    ``args.output_name`` using a random dummy input of shape
    ``(batch_size, 3, test_size[0], test_size[1])``. Unless
    ``--no-onnxsim`` is given, the exported file is then simplified with
    onnx-simplifier and overwritten in place.

    Raises:
        RuntimeError: if onnx-simplifier reports that the simplified
            model failed its validation check. The function is wrapped
            in ``logger.catch``, so callers will presumably see this
            logged rather than propagated.
    """
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        # Fall back to the experiment's default best-checkpoint location.
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # load the model state dict
    # NOTE(review): torch.load may unpickle arbitrary objects; only use
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    # Checkpoints may wrap the weights under a "model" key.
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    # Presumably the project's SiLU is used because it exports cleanly to
    # ONNX -- confirm against yolox.models.network_blocks.
    model = replace_module(model, nn.SiLU, SiLU)
    model.head.decode_in_inference = args.decode_in_inference

    logger.info("loading checkpoint done.")
    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])

    # Only the batch dimension is marked dynamic when --dynamic is given.
    if args.dynamic:
        dynamic_axes = {args.input: {0: 'batch'}, args.output: {0: 'batch'}}
    else:
        dynamic_axes = None

    # NOTE(review): `_export` is an underscore-prefixed torch API; consider
    # the public `torch.onnx.export` once equivalence has been confirmed.
    torch.onnx._export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=[args.output],
        dynamic_axes=dynamic_axes,
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        # Imported lazily so the dependencies are only needed when
        # simplification is actually requested.
        import onnx
        from onnxsim import simplify

        # use onnx-simplifier to remove redundant parts of the model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would let an invalid model be saved silently.
        if not check:
            raise RuntimeError("Simplified ONNX model could not be validated")
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))
if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    main()
|