Skip to content

Commit

Permalink
fix onnx export for inception3/4, resnext, mobilenetv2
Browse files Browse the repository at this point in the history
  • Loading branch information
lostkevin committed Jul 9, 2024
1 parent 83ea6b7 commit 1388f0c
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 54 deletions.
48 changes: 25 additions & 23 deletions configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,30 @@
model = dict(
type='Classification',
backbone=dict(type='Inception3', num_classes=1000),
head=[dict(
type='ClsHead',
with_fc=False,
in_channels=2048,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
head=[
dict(
type='ClsHead',
with_fc=False,
in_channels=2048,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes,
input_feature_index=[1],
),
num_classes=num_classes,
input_feature_index=[1],
),
dict(
type='ClsHead',
with_fc=False,
in_channels=768,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes,
input_feature_index=[0],
)
]
)
dict(
type='ClsHead',
with_fc=False,
in_channels=768,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes,
input_feature_index=[0],
)
])

class_list = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13',
Expand Down Expand Up @@ -211,3 +211,5 @@
interval=10,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])

export = dict(export_type='raw', export_neck=True)
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,5 @@
interval=10,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])

export = dict(export_type='raw', export_neck=True)
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,5 @@
interval=10,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])

export = dict(export_type='raw', export_neck=True)
5 changes: 3 additions & 2 deletions configs/classification/imagenet/mobilenet/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes))
num_classes=num_classes),
pretrained=True)

# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
Expand All @@ -25,4 +26,4 @@
# runtime settings
total_epochs = 100
checkpoint_sync_export = True
export = dict(export_neck=True)
export = dict(export_type='raw', export_neck=True)
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes))
num_classes=num_classes),
pretrained=True)

# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
Expand All @@ -30,3 +31,4 @@

# runtime settings
total_epochs = 100
export = dict(export_type='raw', export_neck=True)
69 changes: 42 additions & 27 deletions easycv/apis/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,34 +158,49 @@ def _get_blade_model():


def _export_onnx_cls(model, model_config, cfg, filename, meta):
support_backbones = {
'ResNet': {
'depth': [50]
},
'MobileNetV2': {},
'Inception3': {},
'Inception4': {},
'ResNeXt': {
'depth': [50]
}
}
if model_config['backbone'].get('type', None) not in support_backbones:
tmp = ' '.join(support_backbones.keys())
info_str = f'Only support export onnx model for {tmp} now!'
raise ValueError(info_str)
configs = support_backbones[model_config['backbone'].get('type')]
for k, v in configs.items():
if v[0].__class__(model_config['backbone'].get(k, None)) not in v:
raise ValueError(
f"Unsupport config for {model_config['backbone'].get('type')}")

# save json config for test_pipline and class
with io.open(
filename +
'.config.json' if filename.endswith('onnx') else filename +
'.onnx.config.json', 'w') as ofile:
json.dump(meta, ofile)

if model_config['backbone'].get(
'type', None) == 'ResNet' and model_config['backbone'].get(
'depth', None) == 50:
# save json config for test_pipline and class
with io.open(
filename +
'.config.json' if filename.endswith('onnx') else filename +
'.onnx.config.json', 'w') as ofile:
json.dump(meta, ofile)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()
model.to(device)
img_size = int(cfg.image_size2)
x_input = torch.randn((1, 3, img_size, img_size)).to(device)
torch.onnx.export(
model,
(x_input, 'onnx'),
filename if filename.endswith('onnx') else filename + '.onnx',
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
)
else:
raise ValueError('Only support export onnx model for ResNet now!')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()
model.to(device)
img_size = int(cfg.image_size2)
x_input = torch.randn((1, 3, img_size, img_size)).to(device)
torch.onnx.export(
model,
(x_input, 'onnx'),
filename if filename.endswith('onnx') else filename + '.onnx',
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
)


def _export_cls(model, cfg, filename):
Expand Down
128 changes: 128 additions & 0 deletions tests/test_tools/test_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import sys
import unittest

import numpy as np
import onnxruntime
import torch

from easycv.models import build_model
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile, rebuild_config
from easycv.utils.test_util import run_in_subprocess

sys.path.append(os.path.dirname(os.path.realpath(__file__)))
logging.basicConfig(level=logging.INFO)

BASIC_EXPORT_CONFIGS = {
'config_file': None,
'checkpoint': 'dummy',
'output_filename': 'work_dir/test_out.pth',
'user_config_params': ['--export.export_type', 'onnx']
}


def build_cmd(export_configs) -> str:
base_cmd = 'python tools/export.py'
base_cmd += f" {export_configs['config_file']}"
base_cmd += f" {export_configs['checkpoint']}"
base_cmd += f" {export_configs['output_filename']}"
user_params = ' '.join(export_configs['user_config_params'])
base_cmd += f' --user_config_params {user_params}'
return base_cmd


class ExportTest(unittest.TestCase):
"""In this unittest, we test the onnx export functionality of
some classification/detection models.
"""

def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

def tearDown(self):
super().tearDown()

def run_test(self, CONFIG_FILE, img_size: int = 224, **override_configs):
configs = BASIC_EXPORT_CONFIGS.copy()
configs['config_file'] = CONFIG_FILE

configs.update(override_configs)

cmd = build_cmd(configs)
logging.info(f'Export with commands: {cmd}')
run_in_subprocess(cmd)

cfg = mmcv_config_fromfile(configs['config_file'])
cfg = rebuild_config(cfg, configs['user_config_params'])

if hasattr(cfg.model, 'pretrained'):
cfg.model.pretrained = False

torch_model = build_model(cfg.model).eval()
if 'checkpoint' in override_configs:
load_checkpoint(
torch_model,
override_configs['checkpoint'],
strict=False,
logger=logging.getLogger())
session = onnxruntime.InferenceSession(configs['output_filename'] +
'.onnx')
input_tensor = torch.randn((1, 3, img_size, img_size))

torch_output = torch_model(input_tensor, mode='test')['prob']

onnx_output = session.run(
[session.get_outputs()[0].name],
{session.get_inputs()[0].name: np.array(input_tensor)})
if isinstance(onnx_output, list):
onnx_output = onnx_output[0]

onnx_output = torch.tensor(onnx_output)

is_same_shape = torch_output.shape == onnx_output.shape

self.assertTrue(
is_same_shape,
f'The shapes of the two outputs are mismatch, got {torch_output.shape} and {onnx_output.shape}'
)
is_allclose = torch.allclose(torch_output, onnx_output)

torch_out_minmax = f'{float(torch_output.min())}~{float(torch_output.max())}'
onnx_out_minmax = f'{float(onnx_output.min())}~{float(onnx_output.max())}'

info_msg = f'got avg: {float(torch_output.mean())} and {float(onnx_output.mean())},'
info_msg += f' and range: {torch_out_minmax} and {onnx_out_minmax}'
self.assertTrue(
is_allclose,
f'The values between the two outputs are mismatch, {info_msg}')

def test_inceptionv3(self):
CONFIG_FILE = 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py'
self.run_test(CONFIG_FILE, 299)

def test_inceptionv4(self):
CONFIG_FILE = 'configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py'
self.run_test(CONFIG_FILE, 299)

def test_resnext50(self):
CONFIG_FILE = 'configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py'
self.run_test(
CONFIG_FILE,
checkpoint=
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext50-32x4d/epoch_100.pth'
)

def test_mobilenetv2(self):
CONFIG_FILE = 'configs/classification/imagenet/mobilenet/mobilenetv2.py'
self.run_test(
CONFIG_FILE,
checkpoint=
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/mobilenetv2/mobilenet_v2.pth'
)


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion tools/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def main():
cfg = mmcv_config_fromfile(args.config)

if args.user_config_params is not None:
assert args.model_type is not None, 'model_type must be setted'
# assert args.model_type is not None, 'model_type must be setted'
# rebuild config by user config params
cfg = rebuild_config(cfg, args.user_config_params)

Expand Down

0 comments on commit 1388f0c

Please sign in to comment.