From 5d17168cbb0c67b7486898483bf6878bf0a800c7 Mon Sep 17 00:00:00 2001 From: "gushen.hkw" Date: Thu, 4 Jul 2024 14:39:45 +0800 Subject: [PATCH 01/11] add inceptionv4 backbone; the performance is not checked --- .../inception/inceptionv4_b32x8_100e.py | 199 +++++++++ .../inceptionv4_b32x8_200e_rmsprop.py | 201 +++++++++ easycv/models/backbones/__init__.py | 1 + easycv/models/backbones/inceptionv4.py | 380 ++++++++++++++++++ easycv/models/modelzoo.py | 5 + 5 files changed, 786 insertions(+) create mode 100644 configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py create mode 100644 configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py create mode 100644 easycv/models/backbones/inceptionv4.py diff --git a/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py new file mode 100644 index 00000000..3752844a --- /dev/null +++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py @@ -0,0 +1,199 @@ +_base_ = 'configs/base.py' + +num_classes = 1000 +# model settings +model = dict( + type='Classification', + pretrained=False, + backbone=dict(type='Inception4'), + head=dict( + type='ClsHead', + with_avg_pool=True, + in_channels=1536, + loss_config=dict( + type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes)) + +class_list = [ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', + '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', + '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', + '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', + '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', + '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', + '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', + '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', + '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', + '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', + '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', + '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', + '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', + '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', + '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', + '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', + '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', + '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', + '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', + '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', + '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', + '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', + '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', + '249', '250', '251', '252', '253', '254', '255', '256', '257', '258', + '259', '260', '261', '262', '263', '264', '265', '266', '267', '268', + '269', '270', '271', '272', '273', '274', '275', '276', '277', '278', + '279', '280', '281', '282', '283', '284', '285', '286', '287', '288', + '289', '290', '291', '292', '293', '294', '295', '296', '297', '298', + '299', '300', '301', '302', '303', '304', '305', '306', '307', '308', + '309', '310', '311', '312', '313', 
'314', '315', '316', '317', '318', + '319', '320', '321', '322', '323', '324', '325', '326', '327', '328', + '329', '330', '331', '332', '333', '334', '335', '336', '337', '338', + '339', '340', '341', '342', '343', '344', '345', '346', '347', '348', + '349', '350', '351', '352', '353', '354', '355', '356', '357', '358', + '359', '360', '361', '362', '363', '364', '365', '366', '367', '368', + '369', '370', '371', '372', '373', '374', '375', '376', '377', '378', + '379', '380', '381', '382', '383', '384', '385', '386', '387', '388', + '389', '390', '391', '392', '393', '394', '395', '396', '397', '398', + '399', '400', '401', '402', '403', '404', '405', '406', '407', '408', + '409', '410', '411', '412', '413', '414', '415', '416', '417', '418', + '419', '420', '421', '422', '423', '424', '425', '426', '427', '428', + '429', '430', '431', '432', '433', '434', '435', '436', '437', '438', + '439', '440', '441', '442', '443', '444', '445', '446', '447', '448', + '449', '450', '451', '452', '453', '454', '455', '456', '457', '458', + '459', '460', '461', '462', '463', '464', '465', '466', '467', '468', + '469', '470', '471', '472', '473', '474', '475', '476', '477', '478', + '479', '480', '481', '482', '483', '484', '485', '486', '487', '488', + '489', '490', '491', '492', '493', '494', '495', '496', '497', '498', + '499', '500', '501', '502', '503', '504', '505', '506', '507', '508', + '509', '510', '511', '512', '513', '514', '515', '516', '517', '518', + '519', '520', '521', '522', '523', '524', '525', '526', '527', '528', + '529', '530', '531', '532', '533', '534', '535', '536', '537', '538', + '539', '540', '541', '542', '543', '544', '545', '546', '547', '548', + '549', '550', '551', '552', '553', '554', '555', '556', '557', '558', + '559', '560', '561', '562', '563', '564', '565', '566', '567', '568', + '569', '570', '571', '572', '573', '574', '575', '576', '577', '578', + '579', '580', '581', '582', '583', '584', '585', '586', '587', '588', + '589', '590', '591', '592', '593', '594', '595', '596', '597', '598', + '599', '600', '601', '602', '603', '604', '605', '606', '607', '608', + '609', '610', '611', '612', '613', '614', '615', '616', '617', '618', + '619', '620', '621', '622', '623', '624', '625', '626', '627', '628', + '629', '630', '631', '632', '633', '634', '635', '636', '637', '638', + '639', '640', '641', '642', '643', '644', '645', '646', '647', '648', + '649', '650', '651', '652', '653', '654', '655', '656', '657', '658', + '659', '660', '661', '662', '663', '664', '665', '666', '667', '668', + '669', '670', '671', '672', '673', '674', '675', '676', '677', '678', + '679', '680', '681', '682', '683', '684', '685', '686', '687', '688', + '689', '690', '691', '692', '693', '694', '695', '696', '697', '698', + '699', '700', '701', '702', '703', '704', '705', '706', '707', '708', + '709', '710', '711', '712', '713', '714', '715', '716', '717', '718', + '719', '720', '721', '722', '723', '724', '725', '726', '727', '728', + '729', '730', '731', '732', '733', '734', '735', '736', '737', '738', + '739', '740', '741', '742', '743', '744', '745', '746', '747', '748', + '749', '750', '751', '752', '753', '754', '755', '756', '757', '758', + '759', '760', '761', '762', '763', '764', '765', '766', '767', '768', + '769', '770', '771', '772', '773', '774', '775', '776', '777', '778', + '779', '780', '781', '782', '783', '784', '785', '786', '787', '788', + '789', '790', '791', '792', '793', '794', '795', '796', '797', '798', + '799', '800', '801', '802', '803', '804', '805', '806', 
'807', '808', + '809', '810', '811', '812', '813', '814', '815', '816', '817', '818', + '819', '820', '821', '822', '823', '824', '825', '826', '827', '828', + '829', '830', '831', '832', '833', '834', '835', '836', '837', '838', + '839', '840', '841', '842', '843', '844', '845', '846', '847', '848', + '849', '850', '851', '852', '853', '854', '855', '856', '857', '858', + '859', '860', '861', '862', '863', '864', '865', '866', '867', '868', + '869', '870', '871', '872', '873', '874', '875', '876', '877', '878', + '879', '880', '881', '882', '883', '884', '885', '886', '887', '888', + '889', '890', '891', '892', '893', '894', '895', '896', '897', '898', + '899', '900', '901', '902', '903', '904', '905', '906', '907', '908', + '909', '910', '911', '912', '913', '914', '915', '916', '917', '918', + '919', '920', '921', '922', '923', '924', '925', '926', '927', '928', + '929', '930', '931', '932', '933', '934', '935', '936', '937', '938', + '939', '940', '941', '942', '943', '944', '945', '946', '947', '948', + '949', '950', '951', '952', '953', '954', '955', '956', '957', '958', + '959', '960', '961', '962', '963', '964', '965', '966', '967', '968', + '969', '970', '971', '972', '973', '974', '975', '976', '977', '978', + '979', '980', '981', '982', '983', '984', '985', '986', '987', '988', + '989', '990', '991', '992', '993', '994', '995', '996', '997', '998', '999' +] + +data_source_type = 'ClsSourceImageList' +base_root = 'data/imagenet_raw/' +data_train_list = base_root + 'meta/train_labeled.txt' +data_train_root = base_root + 'train/' +data_test_list = base_root + 'meta/val_labeled.txt' +data_test_root = base_root + 'validation/' +image_size2 = 299 +image_size1 = int((256 / 224) * image_size2) + +dataset_type = 'ClsDataset' +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict(type='RandomResizedCrop', size=image_size2), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] +test_pipeline = [ + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] + +data = dict( + imgs_per_gpu=32, # total 256 + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_source=dict( + list_file=data_train_list, + root=data_train_root, + type=data_source_type), + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_source=dict( + list_file=data_test_list, + root=data_test_root, + type=data_source_type), + pipeline=test_pipeline)) + +eval_config = dict(initial=False, interval=1, gpu_collect=True) +eval_pipelines = [ + dict( + mode='test', + data=data['val'], + dist_eval=True, + evaluators=[ + dict(type='ClsEvaluator', topk=(1, 5), class_list=class_list) + ], + ) +] + +# optimizer +optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) + +# learning policy +lr_config = dict(policy='step', step=[30, 60, 90]) +checkpoint_config = dict(interval=10) + +# runtime settings +total_epochs = 100 + +predict = dict( + type='ClassificationPredictor', + pipelines=[ + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img']) + ]) + +log_config = dict( + interval=10, + hooks=[dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook')]) diff --git 
a/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py new file mode 100644 index 00000000..5b4151e4 --- /dev/null +++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py @@ -0,0 +1,201 @@ +# A config with the optimization settings from https://arxiv.org/pdf/1602.07261 +# May run with 20 GPUs +_base_ = 'configs/base.py' + +num_classes = 1000 +# model settings +model = dict( + type='Classification', + pretrained=False, + backbone=dict(type='Inception4'), + head=dict( + type='ClsHead', + with_avg_pool=True, + in_channels=1536, + loss_config=dict( + type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes)) + +class_list = [ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', + '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', + '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', + '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', + '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', + '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', + '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', + '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', + '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', + '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', + '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', + '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', + '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', + '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', + '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', + '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', + '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', + '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', + '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', + '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', + '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', + '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', + '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', + '249', '250', '251', '252', '253', '254', '255', '256', '257', '258', + '259', '260', '261', '262', '263', '264', '265', '266', '267', '268', + '269', '270', '271', '272', '273', '274', '275', '276', '277', '278', + '279', '280', '281', '282', '283', '284', '285', '286', '287', '288', + '289', '290', '291', '292', '293', '294', '295', '296', '297', '298', + '299', '300', '301', '302', '303', '304', '305', '306', '307', '308', + '309', '310', '311', '312', '313', '314', '315', '316', '317', '318', + '319', '320', '321', '322', '323', '324', '325', '326', '327', '328', + '329', '330', '331', '332', '333', '334', '335', '336', '337', '338', + '339', '340', '341', '342', '343', '344', '345', '346', '347', '348', + '349', '350', '351', '352', '353', '354', '355', '356', '357', '358', + '359', '360', '361', '362', '363', '364', '365', '366', '367', '368', + '369', '370', '371', '372', '373', '374', '375', '376', '377', '378', + '379', '380', '381', '382', '383', '384', '385', '386', '387', '388', + '389', '390', '391', '392', '393', '394', '395', '396', '397', '398', + 
'399', '400', '401', '402', '403', '404', '405', '406', '407', '408', + '409', '410', '411', '412', '413', '414', '415', '416', '417', '418', + '419', '420', '421', '422', '423', '424', '425', '426', '427', '428', + '429', '430', '431', '432', '433', '434', '435', '436', '437', '438', + '439', '440', '441', '442', '443', '444', '445', '446', '447', '448', + '449', '450', '451', '452', '453', '454', '455', '456', '457', '458', + '459', '460', '461', '462', '463', '464', '465', '466', '467', '468', + '469', '470', '471', '472', '473', '474', '475', '476', '477', '478', + '479', '480', '481', '482', '483', '484', '485', '486', '487', '488', + '489', '490', '491', '492', '493', '494', '495', '496', '497', '498', + '499', '500', '501', '502', '503', '504', '505', '506', '507', '508', + '509', '510', '511', '512', '513', '514', '515', '516', '517', '518', + '519', '520', '521', '522', '523', '524', '525', '526', '527', '528', + '529', '530', '531', '532', '533', '534', '535', '536', '537', '538', + '539', '540', '541', '542', '543', '544', '545', '546', '547', '548', + '549', '550', '551', '552', '553', '554', '555', '556', '557', '558', + '559', '560', '561', '562', '563', '564', '565', '566', '567', '568', + '569', '570', '571', '572', '573', '574', '575', '576', '577', '578', + '579', '580', '581', '582', '583', '584', '585', '586', '587', '588', + '589', '590', '591', '592', '593', '594', '595', '596', '597', '598', + '599', '600', '601', '602', '603', '604', '605', '606', '607', '608', + '609', '610', '611', '612', '613', '614', '615', '616', '617', '618', + '619', '620', '621', '622', '623', '624', '625', '626', '627', '628', + '629', '630', '631', '632', '633', '634', '635', '636', '637', '638', + '639', '640', '641', '642', '643', '644', '645', '646', '647', '648', + '649', '650', '651', '652', '653', '654', '655', '656', '657', '658', + '659', '660', '661', '662', '663', '664', '665', '666', '667', '668', + '669', '670', '671', '672', '673', '674', '675', '676', '677', '678', + '679', '680', '681', '682', '683', '684', '685', '686', '687', '688', + '689', '690', '691', '692', '693', '694', '695', '696', '697', '698', + '699', '700', '701', '702', '703', '704', '705', '706', '707', '708', + '709', '710', '711', '712', '713', '714', '715', '716', '717', '718', + '719', '720', '721', '722', '723', '724', '725', '726', '727', '728', + '729', '730', '731', '732', '733', '734', '735', '736', '737', '738', + '739', '740', '741', '742', '743', '744', '745', '746', '747', '748', + '749', '750', '751', '752', '753', '754', '755', '756', '757', '758', + '759', '760', '761', '762', '763', '764', '765', '766', '767', '768', + '769', '770', '771', '772', '773', '774', '775', '776', '777', '778', + '779', '780', '781', '782', '783', '784', '785', '786', '787', '788', + '789', '790', '791', '792', '793', '794', '795', '796', '797', '798', + '799', '800', '801', '802', '803', '804', '805', '806', '807', '808', + '809', '810', '811', '812', '813', '814', '815', '816', '817', '818', + '819', '820', '821', '822', '823', '824', '825', '826', '827', '828', + '829', '830', '831', '832', '833', '834', '835', '836', '837', '838', + '839', '840', '841', '842', '843', '844', '845', '846', '847', '848', + '849', '850', '851', '852', '853', '854', '855', '856', '857', '858', + '859', '860', '861', '862', '863', '864', '865', '866', '867', '868', + '869', '870', '871', '872', '873', '874', '875', '876', '877', '878', + '879', '880', '881', '882', '883', '884', '885', '886', '887', '888', + '889', '890', '891', 
'892', '893', '894', '895', '896', '897', '898', + '899', '900', '901', '902', '903', '904', '905', '906', '907', '908', + '909', '910', '911', '912', '913', '914', '915', '916', '917', '918', + '919', '920', '921', '922', '923', '924', '925', '926', '927', '928', + '929', '930', '931', '932', '933', '934', '935', '936', '937', '938', + '939', '940', '941', '942', '943', '944', '945', '946', '947', '948', + '949', '950', '951', '952', '953', '954', '955', '956', '957', '958', + '959', '960', '961', '962', '963', '964', '965', '966', '967', '968', + '969', '970', '971', '972', '973', '974', '975', '976', '977', '978', + '979', '980', '981', '982', '983', '984', '985', '986', '987', '988', + '989', '990', '991', '992', '993', '994', '995', '996', '997', '998', '999' +] + +data_source_type = 'ClsSourceImageList' +base_root = 'data/imagenet_raw/' +data_train_list = base_root + 'meta/train_labeled.txt' +data_train_root = base_root + 'train/' +data_test_list = base_root + 'meta/val_labeled.txt' +data_test_root = base_root + 'validation/' +image_size2 = 299 +image_size1 = int((256 / 224) * image_size2) + +dataset_type = 'ClsDataset' +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict(type='RandomResizedCrop', size=image_size2), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] +test_pipeline = [ + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] + +data = dict( + imgs_per_gpu=32, # total 256 + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_source=dict( + list_file=data_train_list, + root=data_train_root, + type=data_source_type), + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_source=dict( + list_file=data_test_list, + root=data_test_root, + type=data_source_type), + pipeline=test_pipeline)) + +eval_config = dict(initial=False, interval=1, gpu_collect=True) +eval_pipelines = [ + dict( + mode='test', + data=data['val'], + dist_eval=True, + evaluators=[ + dict(type='ClsEvaluator', topk=(1, 5), class_list=class_list) + ], + ) +] + +# optimizer +optimizer = dict(type='RMSprop', lr=0.045, momentum=0.9, weight_decay=0.9, eps=1.0) + +# learning policy +lr_config = dict(policy='exp', gamma=0.96954) # gamma**2 ~ 0.94 +checkpoint_config = dict(interval=10) + +# runtime settings +total_epochs = 200 + +predict = dict( + type='ClassificationPredictor', + pipelines=[ + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img']) + ]) + +log_config = dict( + interval=10, + hooks=[dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook')]) diff --git a/easycv/models/backbones/__init__.py b/easycv/models/backbones/__init__.py index 5dc4561b..0db2af72 100644 --- a/easycv/models/backbones/__init__.py +++ b/easycv/models/backbones/__init__.py @@ -28,3 +28,4 @@ from .vision_transformer import VisionTransformer from .vitdet import ViTDet from .x3d import X3D +from .inceptionv4 import Inception4 \ No newline at end of file diff --git a/easycv/models/backbones/inceptionv4.py b/easycv/models/backbones/inceptionv4.py new file mode 100644 index 00000000..b0398c61 --- /dev/null +++ b/easycv/models/backbones/inceptionv4.py @@ -0,0 +1,380 @@ +from 
__future__ import print_function, division, absolute_import +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.batchnorm import _BatchNorm + +from mmcv.cnn import constant_init, kaiming_init +from ..registry import BACKBONES + + +_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits']) +__all__ = ['Inception4'] + +class BasicConv2d(nn.Module): + + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_planes, out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, bias=False) # verify bias false + self.bn = nn.BatchNorm2d(out_planes, + eps=0.001, # value found in tensorflow + momentum=0.1, # default pytorch value + affine=True) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed_3a(nn.Module): + + def __init__(self): + super(Mixed_3a, self).__init__() + self.maxpool = nn.MaxPool2d(3, stride=2) + self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) + + def forward(self, x): + x0 = self.maxpool(x) + x1 = self.conv(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed_4a(nn.Module): + + def __init__(self): + super(Mixed_4a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(160, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(160, 64, kernel_size=1, stride=1), + BasicConv2d(64, 64, kernel_size=(1,7), stride=1, padding=(0,3)), + BasicConv2d(64, 64, kernel_size=(7,1), stride=1, padding=(3,0)), + BasicConv2d(64, 96, kernel_size=(3,3), stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed_5a(nn.Module): + + def __init__(self): + super(Mixed_5a, self).__init__() + self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) + self.maxpool = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.conv(x) + x1 = self.maxpool(x) + out = torch.cat((x0, x1), 1) + return out + + +class Inception_A(nn.Module): + + def __init__(self): + super(Inception_A, self).__init__() + self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(384, 96, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Reduction_A(nn.Module): + + def __init__(self): + super(Reduction_A, self).__init__() + self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1), + BasicConv2d(224, 256, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 
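+                        # 384 (branch0) + 256 (branch1) + 384 (maxpool) = 1024
+                        # channels, the width the following Inception_B blocks expect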
1) + return out + + +class Inception_B(nn.Module): + + def __init__(self): + super(Inception_B, self).__init__() + self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), + BasicConv2d(224, 256, kernel_size=(7,1), stride=1, padding=(3,0)) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 192, kernel_size=(7,1), stride=1, padding=(3,0)), + BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), + BasicConv2d(224, 224, kernel_size=(7,1), stride=1, padding=(3,0)), + BasicConv2d(224, 256, kernel_size=(1,7), stride=1, padding=(0,3)) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1024, 128, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Reduction_B(nn.Module): + + def __init__(self): + super(Reduction_B, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 192, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 256, kernel_size=1, stride=1), + BasicConv2d(256, 256, kernel_size=(1,7), stride=1, padding=(0,3)), + BasicConv2d(256, 320, kernel_size=(7,1), stride=1, padding=(3,0)), + BasicConv2d(320, 320, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Inception_C(nn.Module): + + def __init__(self): + super(Inception_C, self).__init__() + + self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) + + self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1,3), stride=1, padding=(0,1)) + self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3,1), stride=1, padding=(1,0)) + + self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3,1), stride=1, padding=(1,0)) + self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1,3), stride=1, padding=(0,1)) + self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1,3), stride=1, padding=(0,1)) + self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3,1), stride=1, padding=(1,0)) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1536, 256, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + + x1_0 = self.branch1_0(x) + x1_1a = self.branch1_1a(x1_0) + x1_1b = self.branch1_1b(x1_0) + x1 = torch.cat((x1_1a, x1_1b), 1) + + x2_0 = self.branch2_0(x) + x2_1 = self.branch2_1(x2_0) + x2_2 = self.branch2_2(x2_1) + x2_3a = self.branch2_3a(x2_2) + x2_3b = self.branch2_3b(x2_2) + x2 = torch.cat((x2_3a, x2_3b), 1) + + x3 = self.branch3(x) + + out = torch.cat((x0, x1, x2, x3), 1) + return out + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes): + super(InceptionAux, self).__init__() + self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) + self.conv1 = BasicConv2d(128, 768, kernel_size=5) + self.conv1.stddev = 0.01 + self.fc = nn.Linear(768, num_classes) + self.fc.stddev = 0.001 + + def forward(self, 
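+                # x: a B x 1024 x 17 x 17 feature map (the Mixed_6h output)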
x):
+        # N x 1024 x 17 x 17 (the Mixed_6h feature map)
+        x = F.avg_pool2d(x, kernel_size=5, stride=3)
+        # N x 1024 x 5 x 5
+        x = self.conv0(x)
+        # N x 128 x 5 x 5
+        x = self.conv1(x)
+        # N x 768 x 1 x 1
+        # Adaptive average pooling
+        x = F.adaptive_avg_pool2d(x, (1, 1))
+        # N x 768 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 768
+        x = self.fc(x)
+        # N x num_classes
+        return x
+
+
+class BasicConv2d(nn.Module):
+
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(BasicConv2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)
+
+
+@BACKBONES.register_module
+class Inception4(nn.Module):
+    """InceptionV4 backbone.
+
+    Args:
+        num_classes (int): The num_classes of InceptionV4. An extra fc will be
+            used if num_classes > 0.
+    """
+    def __init__(self, num_classes: int=0, p_dropout=0.2, aux_logits: bool=True):
+        super(Inception4, self).__init__()
+        self.aux_logits = aux_logits
+        # Modules
+        self.features = nn.Sequential(
+            BasicConv2d(3, 32, kernel_size=3, stride=2),
+            BasicConv2d(32, 32, kernel_size=3, stride=1),
+            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
+            Mixed_3a(),
+            Mixed_4a(),
+            Mixed_5a(),
+            Inception_A(),
+            Inception_A(),
+            Inception_A(),
+            Inception_A(),
+            Reduction_A(), # Mixed_6a
+            Inception_B(),
+            Inception_B(),
+            Inception_B(),
+            Inception_B(),
+            Inception_B(),
+            Inception_B(),
+            Inception_B(), # Mixed_6h 1024 x 17 x 17
+            Reduction_B(), # Mixed_7a
+            Inception_C(),
+            Inception_C(),
+            Inception_C()
+        )
+
+        if aux_logits:
+            self.AuxLogits = InceptionAux(1024, num_classes)
+
+        self.dropout = nn.Dropout(p_dropout)
+        self.last_linear = None
+        if num_classes > 0:
+            self.last_linear = nn.Linear(1536, num_classes)
+
+    @property
+    def fc(self):
+        return self.last_linear
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                kaiming_init(m, mode='fan_in', nonlinearity='relu')
+            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                constant_init(m, 1)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def logits(self, features):
+        # Allows image of any size to be processed
+        adaptiveAvgPoolWidth = features.shape[2]
+        x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth)
+        x = x.view(x.size(0), -1) # B x 1536
+        x = self.fc(x)
+        # B x num_classes
+        return x
+
+    def forward(self, input: torch.Tensor):
+        """Forward the backbone.
+
+        Args:
+            input (torch.Tensor): An RGB image tensor with shape B x C x H x W
+
+        Returns:
+            list: A logit tensor (together with the auxiliary logits during
+                training), or a feature tensor when num_classes is 0 (default)
+        """
+
+        if self.training and self.aux_logits:
+            x = self.features[:-4](input)
+            aux = self.AuxLogits(x)
+            x = self.features[-4:](x)
+        else:
+            x = self.features(input)
+            aux = None
+
+        if self.fc is not None:
+            x = self.logits(x)
+
+        if self.training and self.aux_logits and self.fc is not None:
+            return [_InceptionOutputs(x, aux)]
+        return [x]
+
+
+
diff --git a/easycv/models/modelzoo.py b/easycv/models/modelzoo.py
index d367cc62..f70725c4 100644
--- a/easycv/models/modelzoo.py
+++ b/easycv/models/modelzoo.py
@@ -79,6 +79,11 @@
     'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/inceptionv3/inception_v3.pth',
}
 
+inceptionv4 = {
+    # Inception v4 ported from http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth
+    'Inception4': '',
+}
+
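+# The empty URL above means no converted Inception4 checkpoint has been
+# published yet, so there are no pretrained weights to download for it.
+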
genet = { 'PlainNetnormal': 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/genet/GENet_normal.pth', From 03c825d8eb8ec62085707414b103fcb2019f707a Mon Sep 17 00:00:00 2001 From: "gushen.hkw" Date: Fri, 5 Jul 2024 10:52:41 +0800 Subject: [PATCH 02/11] fix for inception DDP training --- .../inception/inceptionv3_b32x8_100e.py | 23 +++++++++++++++---- easycv/models/backbones/inceptionv3.py | 12 ++-------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py b/configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py index a0cf8d57..666ffdd1 100644 --- a/configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py +++ b/configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py @@ -4,16 +4,31 @@ # model settings model = dict( type='Classification', - backbone=dict(type='Inception3'), - head=dict( + backbone=dict(type='Inception3', num_classes=1000), + head=[dict( type='ClsHead', - with_avg_pool=True, + with_fc=False, in_channels=2048, loss_config=dict( type='CrossEntropyLossWithLabelSmooth', label_smooth=0, ), - num_classes=num_classes)) + num_classes=num_classes, + input_feature_index=[1], + ), + dict( + type='ClsHead', + with_fc=False, + in_channels=768, + loss_config=dict( + type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes, + input_feature_index=[0], + ) + ] + ) class_list = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', diff --git a/easycv/models/backbones/inceptionv3.py b/easycv/models/backbones/inceptionv3.py index 69f591bb..55e54dff 100644 --- a/easycv/models/backbones/inceptionv3.py +++ b/easycv/models/backbones/inceptionv3.py @@ -2,9 +2,6 @@ r""" This model is taken from the official PyTorch model zoo. 
- torchvision.models.inception.py on 31th Aug, 2019 """ - -from collections import namedtuple - import torch import torch.nn as nn import torch.nn.functional as F @@ -16,9 +13,6 @@ __all__ = ['Inception3'] -_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits']) - - @BACKBONES.register_module class Inception3(nn.Module): @@ -113,6 +107,7 @@ def forward(self, x): # N x 768 x 17 x 17 x = self.Mixed_6e(x) # N x 768 x 17 x 17 + aux = None if self.training and self.aux_logits: aux = self.AuxLogits(x) # N x 768 x 17 x 17 @@ -132,10 +127,7 @@ def forward(self, x): if hasattr(self, 'fc'): x = self.fc(x) - # N x 1000 (num_classes) - if self.training and self.aux_logits and hasattr(self, 'fc'): - return [_InceptionOutputs(x, aux)] - return [x] + return [aux, x] class InceptionA(nn.Module): From 6c4a183f80cd506926b7659b5a09ad12f08768af Mon Sep 17 00:00:00 2001 From: "gushen.hkw" Date: Fri, 5 Jul 2024 11:10:56 +0800 Subject: [PATCH 03/11] add inceptionv4 backbone/training settings --- .../inception/inceptionv4_b32x8_100e.py | 24 +++++-- .../inceptionv4_b32x8_200e_rmsprop.py | 24 +++++-- easycv/models/backbones/inceptionv4.py | 5 +- .../test_models/backbones/test_inceptionv4.py | 63 +++++++++++++++++++ 4 files changed, 102 insertions(+), 14 deletions(-) create mode 100644 tests/test_models/backbones/test_inceptionv4.py diff --git a/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py index 3752844a..d895e524 100644 --- a/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py +++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py @@ -4,17 +4,31 @@ # model settings model = dict( type='Classification', - pretrained=False, - backbone=dict(type='Inception4'), - head=dict( + backbone=dict(type='Inception3', num_classes=num_classes), + head=[dict( type='ClsHead', - with_avg_pool=True, + with_fc=False, in_channels=1536, loss_config=dict( type='CrossEntropyLossWithLabelSmooth', label_smooth=0, ), - num_classes=num_classes)) + num_classes=num_classes, + input_feature_index=[1], + ), + dict( + type='ClsHead', + with_fc=False, + in_channels=768, + loss_config=dict( + type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes, + input_feature_index=[0], + ) + ] + ) class_list = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', diff --git a/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py index 5b4151e4..ee6c6b2f 100644 --- a/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py +++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py @@ -6,17 +6,31 @@ # model settings model = dict( type='Classification', - pretrained=False, - backbone=dict(type='Inception4'), - head=dict( + backbone=dict(type='Inception3', num_classes=num_classes), + head=[dict( type='ClsHead', - with_avg_pool=True, + with_fc=False, in_channels=1536, loss_config=dict( type='CrossEntropyLossWithLabelSmooth', label_smooth=0, ), - num_classes=num_classes)) + num_classes=num_classes, + input_feature_index=[1], + ), + dict( + type='ClsHead', + with_fc=False, + in_channels=768, + loss_config=dict( + type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes, + input_feature_index=[0], + ) + ] + ) class_list = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', 
diff --git a/easycv/models/backbones/inceptionv4.py b/easycv/models/backbones/inceptionv4.py
index b0398c61..f79047e0 100644
--- a/easycv/models/backbones/inceptionv4.py
+++ b/easycv/models/backbones/inceptionv4.py
@@ -10,7 +10,6 @@
 from ..registry import BACKBONES
 
 
-_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
 __all__ = ['Inception4']
 
 class BasicConv2d(nn.Module):
@@ -372,9 +371,7 @@ def forward(self, input: torch.Tensor):
         if self.fc is not None:
             x = self.logits(x)
 
-        if self.training and self.aux_logits and self.fc is not None:
-            return [_InceptionOutputs(x, aux)]
-        return [x]
+        return [aux, x]
 
 
diff --git a/tests/test_models/backbones/test_inceptionv4.py b/tests/test_models/backbones/test_inceptionv4.py
new file mode 100644
index 00000000..d9b7625c
--- /dev/null
+++ b/tests/test_models/backbones/test_inceptionv4.py
@@ -0,0 +1,63 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import copy
+import random
+import unittest
+
+import numpy as np
+import torch
+
+from easycv.models import modelzoo
+from easycv.models.backbones import Inception4
+
+
+class InceptionV4Test(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+
+    def test_inceptionv4_withfc(self):
+        with torch.no_grad():
+            # input data
+            batch_size = random.randint(10, 30)
+            a = torch.rand(batch_size, 3, 299, 299).to('cuda')
+
+            num_classes = random.randint(10, 1000)
+            net = Inception4(
+                aux_logits=True, num_classes=num_classes).to('cuda')
+            net.init_weights()
+            net.train()
+
+            self.assertTrue(len(list(net(a)[-1].shape)) == 2)
+            self.assertTrue(len(list(net(a)[0].shape)) == 2)
+            self.assertTrue(net(a)[-1].size(1) == num_classes)
+            self.assertTrue(net(a)[-1].size(0) == batch_size)
+            self.assertTrue(net(a)[0].size(1) == num_classes)
+            self.assertTrue(net(a)[0].size(0) == batch_size)
+
+    def test_inceptionv4_withoutfc(self):
+        with torch.no_grad():
+            # input data
+            batch_size = random.randint(10, 30)
+            a = torch.rand(batch_size, 3, 299, 299).to('cuda')
+
+            net = Inception4(aux_logits=True, num_classes=0).to('cuda')
+            net.init_weights()
+            net.eval()
+
+            self.assertTrue(net(a)[-1].size(1) == 1536)
+            self.assertTrue(net(a)[-1].size(0) == batch_size)
+
+    def test_inceptionv4_load_modelzoo(self):
+        with torch.no_grad():
+            net = Inception4(aux_logits=True, num_classes=1000).to('cuda')
+            original_weight = net.features[0].conv.weight
+            original_weight = copy.deepcopy(original_weight.cpu().data.numpy())
+
+            net.init_weights()
+            load_weight = net.features[0].conv.weight.cpu().data.numpy()
+
+            self.assertFalse(np.allclose(original_weight, load_weight))
+
+
+if __name__ == '__main__':
+    unittest.main()

From 80105f236b57b31b6c0eabbe54f4705ae1b8fb72 Mon Sep 17 00:00:00 2001
From: "gushen.hkw" 
Date: Fri, 5 Jul 2024 11:23:04 +0800
Subject: [PATCH 04/11] fix config

---
 .../classification/imagenet/inception/inceptionv4_b32x8_100e.py | 2 +-
 .../imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py
index d895e524..528d804a 100644
--- a/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py
+++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py
@@ -4,7 +4,7 @@
 # model settings
 model = dict(
     type='Classification',
-    backbone=dict(type='Inception3', num_classes=num_classes),
+    backbone=dict(type='Inception4',
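+                  # wiring sketch: the backbone forward returns [aux, x], so the
+                  # two ClsHead entries below pick their inputs via
+                  # input_feature_index=[1] (main logits) and [0] (auxiliary logits)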
num_classes=num_classes), head=[dict( type='ClsHead', with_fc=False, diff --git a/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py index ee6c6b2f..ae55310e 100644 --- a/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py +++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py @@ -6,7 +6,7 @@ # model settings model = dict( type='Classification', - backbone=dict(type='Inception3', num_classes=num_classes), + backbone=dict(type='Inception4', num_classes=num_classes), head=[dict( type='ClsHead', with_fc=False, From 1cf7cfe186f3505130ba533290928c5a13fe9139 Mon Sep 17 00:00:00 2001 From: "gushen.hkw" Date: Mon, 8 Jul 2024 19:13:02 +0800 Subject: [PATCH 05/11] add converted backbone, top-1 acc 80.08 --- .../inception/inceptionv4_b32x8_100e.py | 48 +++--- .../inceptionv4_b32x8_200e_rmsprop.py | 53 +++--- easycv/models/backbones/inceptionv4.py | 152 ++++++++++-------- easycv/models/modelzoo.py | 3 +- 4 files changed, 135 insertions(+), 121 deletions(-) diff --git a/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py index 528d804a..5aee7685 100644 --- a/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py +++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py @@ -5,30 +5,30 @@ model = dict( type='Classification', backbone=dict(type='Inception4', num_classes=num_classes), - head=[dict( - type='ClsHead', - with_fc=False, - in_channels=1536, - loss_config=dict( - type='CrossEntropyLossWithLabelSmooth', - label_smooth=0, + head=[ + dict( + type='ClsHead', + with_fc=False, + in_channels=1536, + loss_config=dict( + type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes, + input_feature_index=[1], ), - num_classes=num_classes, - input_feature_index=[1], - ), - dict( - type='ClsHead', - with_fc=False, - in_channels=768, - loss_config=dict( - type='CrossEntropyLossWithLabelSmooth', - label_smooth=0, - ), - num_classes=num_classes, - input_feature_index=[0], - ) - ] - ) + dict( + type='ClsHead', + with_fc=False, + in_channels=768, + loss_config=dict( + type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes, + input_feature_index=[0], + ) + ]) class_list = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', @@ -141,7 +141,7 @@ image_size1 = int((256 / 224) * image_size2) dataset_type = 'ClsDataset' -img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) train_pipeline = [ dict(type='RandomResizedCrop', size=image_size2), dict(type='RandomHorizontalFlip'), diff --git a/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py index ae55310e..2fc8b5ba 100644 --- a/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py +++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py @@ -7,30 +7,30 @@ model = dict( type='Classification', backbone=dict(type='Inception4', num_classes=num_classes), - head=[dict( - type='ClsHead', - with_fc=False, - in_channels=1536, - loss_config=dict( - type='CrossEntropyLossWithLabelSmooth', - label_smooth=0, + head=[ + dict( + type='ClsHead', + with_fc=False, + in_channels=1536, + loss_config=dict( + 
type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes, + input_feature_index=[1], ), - num_classes=num_classes, - input_feature_index=[1], - ), - dict( - type='ClsHead', - with_fc=False, - in_channels=768, - loss_config=dict( - type='CrossEntropyLossWithLabelSmooth', - label_smooth=0, - ), - num_classes=num_classes, - input_feature_index=[0], - ) - ] - ) + dict( + type='ClsHead', + with_fc=False, + in_channels=768, + loss_config=dict( + type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes, + input_feature_index=[0], + ) + ]) class_list = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', @@ -143,7 +143,7 @@ image_size1 = int((256 / 224) * image_size2) dataset_type = 'ClsDataset' -img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) train_pipeline = [ dict(type='RandomResizedCrop', size=image_size2), dict(type='RandomHorizontalFlip'), @@ -190,10 +190,11 @@ ] # optimizer -optimizer = dict(type='RMSprop', lr=0.045, momentum=0.9, weight_decay=0.9, eps=1.0) +optimizer = dict( + type='RMSprop', lr=0.045, momentum=0.9, weight_decay=0.9, eps=1.0) # learning policy -lr_config = dict(policy='exp', gamma=0.96954) # gamma**2 ~ 0.94 +lr_config = dict(policy='exp', gamma=0.96954) # gamma**2 ~ 0.94 checkpoint_config = dict(interval=10) # runtime settings diff --git a/easycv/models/backbones/inceptionv4.py b/easycv/models/backbones/inceptionv4.py index f79047e0..173949f9 100644 --- a/easycv/models/backbones/inceptionv4.py +++ b/easycv/models/backbones/inceptionv4.py @@ -1,28 +1,34 @@ -from __future__ import print_function, division, absolute_import +from __future__ import absolute_import, division, print_function from collections import namedtuple import torch import torch.nn as nn import torch.nn.functional as F +from mmcv.cnn import constant_init, kaiming_init from torch.nn.modules.batchnorm import _BatchNorm -from mmcv.cnn import constant_init, kaiming_init +from ..modelzoo import inceptionv4 as model_urls from ..registry import BACKBONES - __all__ = ['Inception4'] + class BasicConv2d(nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): super(BasicConv2d, self).__init__() - self.conv = nn.Conv2d(in_planes, out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, bias=False) # verify bias false - self.bn = nn.BatchNorm2d(out_planes, - eps=0.001, # value found in tensorflow - momentum=0.1, # default pytorch value - affine=True) + self.conv = nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False) # verify bias false + self.bn = nn.BatchNorm2d( + out_planes, + eps=0.001, # value found in tensorflow + momentum=0.1, # default pytorch value + affine=True) self.relu = nn.ReLU(inplace=True) def forward(self, x): @@ -53,15 +59,13 @@ def __init__(self): self.branch0 = nn.Sequential( BasicConv2d(160, 64, kernel_size=1, stride=1), - BasicConv2d(64, 96, kernel_size=3, stride=1) - ) + BasicConv2d(64, 96, kernel_size=3, stride=1)) self.branch1 = nn.Sequential( BasicConv2d(160, 64, kernel_size=1, stride=1), - BasicConv2d(64, 64, kernel_size=(1,7), stride=1, padding=(0,3)), - BasicConv2d(64, 64, kernel_size=(7,1), stride=1, padding=(3,0)), - BasicConv2d(64, 96, kernel_size=(3,3), stride=1) - ) + BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)), + 
BasicConv2d(64, 96, kernel_size=(3, 3), stride=1)) def forward(self, x): x0 = self.branch0(x) @@ -92,19 +96,16 @@ def __init__(self): self.branch1 = nn.Sequential( BasicConv2d(384, 64, kernel_size=1, stride=1), - BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1) - ) + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1)) self.branch2 = nn.Sequential( BasicConv2d(384, 64, kernel_size=1, stride=1), BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), - BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) - ) + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)) self.branch3 = nn.Sequential( nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), - BasicConv2d(384, 96, kernel_size=1, stride=1) - ) + BasicConv2d(384, 96, kernel_size=1, stride=1)) def forward(self, x): x0 = self.branch0(x) @@ -124,8 +125,7 @@ def __init__(self): self.branch1 = nn.Sequential( BasicConv2d(384, 192, kernel_size=1, stride=1), BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1), - BasicConv2d(224, 256, kernel_size=3, stride=2) - ) + BasicConv2d(224, 256, kernel_size=3, stride=2)) self.branch2 = nn.MaxPool2d(3, stride=2) @@ -145,22 +145,25 @@ def __init__(self): self.branch1 = nn.Sequential( BasicConv2d(1024, 192, kernel_size=1, stride=1), - BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), - BasicConv2d(224, 256, kernel_size=(7,1), stride=1, padding=(3,0)) - ) + BasicConv2d( + 192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d( + 224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0))) self.branch2 = nn.Sequential( BasicConv2d(1024, 192, kernel_size=1, stride=1), - BasicConv2d(192, 192, kernel_size=(7,1), stride=1, padding=(3,0)), - BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), - BasicConv2d(224, 224, kernel_size=(7,1), stride=1, padding=(3,0)), - BasicConv2d(224, 256, kernel_size=(1,7), stride=1, padding=(0,3)) - ) + BasicConv2d( + 192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d( + 192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d( + 224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d( + 224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3))) self.branch3 = nn.Sequential( nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), - BasicConv2d(1024, 128, kernel_size=1, stride=1) - ) + BasicConv2d(1024, 128, kernel_size=1, stride=1)) def forward(self, x): x0 = self.branch0(x) @@ -178,15 +181,15 @@ def __init__(self): self.branch0 = nn.Sequential( BasicConv2d(1024, 192, kernel_size=1, stride=1), - BasicConv2d(192, 192, kernel_size=3, stride=2) - ) + BasicConv2d(192, 192, kernel_size=3, stride=2)) self.branch1 = nn.Sequential( BasicConv2d(1024, 256, kernel_size=1, stride=1), - BasicConv2d(256, 256, kernel_size=(1,7), stride=1, padding=(0,3)), - BasicConv2d(256, 320, kernel_size=(7,1), stride=1, padding=(3,0)), - BasicConv2d(320, 320, kernel_size=3, stride=2) - ) + BasicConv2d( + 256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d( + 256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(320, 320, kernel_size=3, stride=2)) self.branch2 = nn.MaxPool2d(3, stride=2) @@ -206,19 +209,24 @@ def __init__(self): self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) - self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1,3), stride=1, padding=(0,1)) - self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3,1), stride=1, padding=(1,0)) + self.branch1_1a = 
BasicConv2d( + 384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch1_1b = BasicConv2d( + 384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) - self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3,1), stride=1, padding=(1,0)) - self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1,3), stride=1, padding=(0,1)) - self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1,3), stride=1, padding=(0,1)) - self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3,1), stride=1, padding=(1,0)) + self.branch2_1 = BasicConv2d( + 384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0)) + self.branch2_2 = BasicConv2d( + 448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch2_3a = BasicConv2d( + 512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch2_3b = BasicConv2d( + 512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) self.branch3 = nn.Sequential( nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), - BasicConv2d(1536, 256, kernel_size=1, stride=1) - ) + BasicConv2d(1536, 256, kernel_size=1, stride=1)) def forward(self, x): x0 = self.branch0(x) @@ -240,6 +248,7 @@ def forward(self, x): out = torch.cat((x0, x1, x2, x3), 1) return out + class InceptionAux(nn.Module): def __init__(self, in_channels, num_classes): @@ -268,17 +277,17 @@ def forward(self, x): return x -class BasicConv2d(nn.Module): +# class BasicConv2d(nn.Module): - def __init__(self, in_channels, out_channels, **kwargs): - super(BasicConv2d, self).__init__() - self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) - self.bn = nn.BatchNorm2d(out_channels, eps=0.001) +# def __init__(self, in_channels, out_channels, **kwargs): +# super(BasicConv2d, self).__init__() +# self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) +# self.bn = nn.BatchNorm2d(out_channels, eps=0.001) - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - return F.relu(x, inplace=True) +# def forward(self, x): +# x = self.conv(x) +# x = self.bn(x) +# return F.relu(x, inplace=True) @BACKBONES.register_module @@ -288,7 +297,11 @@ class Inception4(nn.Module): Args: num_classes (int): The num_classes of InceptionV4. 
An extra fc will be used if num_classes > 0.
     """
-    def __init__(self, num_classes: int=0, p_dropout=0.2, aux_logits: bool=True):
+
+    def __init__(self,
+                 num_classes: int = 0,
+                 p_dropout=0.2,
+                 aux_logits: bool = True):
         super(Inception4, self).__init__()
         self.aux_logits = aux_logits
         # Modules
@@ -303,19 +316,18 @@ def __init__(self, num_classes: int=0, p_dropout=0.2, aux_logits: bool=True):
             Inception_A(),
             Inception_A(),
             Inception_A(),
-            Reduction_A(), # Mixed_6a
+            Reduction_A(),  # Mixed_6a
             Inception_B(),
             Inception_B(),
             Inception_B(),
             Inception_B(),
             Inception_B(),
             Inception_B(),
-            Inception_B(), # Mixed_6h 1024 x 17 x 17
-            Reduction_B(), # Mixed_7a
+            Inception_B(),  # Mixed_6h 1024 x 17 x 17
+            Reduction_B(),  # Mixed_7a
             Inception_C(),
             Inception_C(),
-            Inception_C()
-        )
+            Inception_C())
 
         if aux_logits:
             self.AuxLogits = InceptionAux(1024, num_classes)
@@ -325,6 +337,9 @@ def __init__(self, num_classes: int=0, p_dropout=0.2, aux_logits: bool=True):
         if num_classes > 0:
             self.last_linear = nn.Linear(1536, num_classes)
 
+        self.default_pretrained_model_path = model_urls[
+            self.__class__.__name__]
+
     @property
     def fc(self):
         return self.last_linear
@@ -340,15 +355,15 @@ def init_weights(self):
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
-    
+
     def logits(self, features):
         # Allows image of any size to be processed
         adaptiveAvgPoolWidth = features.shape[2]
         x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth)
-        x = x.view(x.size(0), -1) # B x 1536
+        x = x.view(x.size(0), -1)  # B x 1536
         x = self.fc(x)
         # B x num_classes
-        return x 
+        return x
 
     def forward(self, input: torch.Tensor):
         """Forward the backbone.
 
         Args:
             input (torch.Tensor): An RGB image tensor with shape B x C x H x W
 
         Returns:
             list: A logit tensor (together with the auxiliary logits during
                 training), or a feature tensor when num_classes is 0 (default)
         """
-        
+
         if self.training and self.aux_logits:
             x = self.features[:-4](input)
@@ -372,6 +387,3 @@ def forward(self, input: torch.Tensor):
         x = self.logits(x)
 
         return [aux, x]
-
-
-
diff --git a/easycv/models/modelzoo.py b/easycv/models/modelzoo.py
index f70725c4..baa49d37 100644
--- a/easycv/models/modelzoo.py
+++ b/easycv/models/modelzoo.py
@@ -81,7 +81,8 @@
 inceptionv4 = {
     # Inception v4 ported from http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth
-    'Inception4': '',
+    'Inception4':
+    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/inceptionv4/inception_v4.pth',
 }
 
 genet = {

From 83ea6b77234ef15a6a4a38b44116127e336e5a4e Mon Sep 17 00:00:00 2001
From: "gushen.hkw" 
Date: Tue, 9 Jul 2024 10:20:57 +0800
Subject: [PATCH 06/11] fix bug caused by commenting out the duplicated
 BasicConv2d

---
 easycv/models/backbones/inceptionv4.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/easycv/models/backbones/inceptionv4.py b/easycv/models/backbones/inceptionv4.py
index 173949f9..917dbf0f 100644
--- a/easycv/models/backbones/inceptionv4.py
+++ b/easycv/models/backbones/inceptionv4.py
@@ -15,7 +15,12 @@
 class BasicConv2d(nn.Module):
 
-    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 kernel_size,
+                 stride=1,
+                 padding=0):
         super(BasicConv2d, self).__init__()
         self.conv = nn.Conv2d(
             in_planes,
@@ -357,9 +362,8 @@ def init_weights(self):
 
     def logits(self, features):
-        # Allows image of any size to be processed
-        adaptiveAvgPoolWidth = features.shape[2]
-        x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth)
+        x = F.adaptive_avg_pool2d(features,
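+                                  # pools to 1 x 1 for any input resolution,
+                                  # keeping the old any-size behaviour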
output_size=(1, 1)) + # x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth) x = x.view(x.size(0), -1) # B x 1536 x = self.fc(x) # B x num_classes From 1388f0c1fb8912962e307f1e37c5adc815dcc8e7 Mon Sep 17 00:00:00 2001 From: "gushen.hkw" Date: Tue, 9 Jul 2024 12:56:17 +0800 Subject: [PATCH 07/11] fix onnx export for inception3/4, resnext, mobilenetv2 --- .../inception/inceptionv3_b32x8_100e.py | 48 +++---- .../inception/inceptionv4_b32x8_100e.py | 2 + .../inceptionv4_b32x8_200e_rmsprop.py | 2 + .../imagenet/mobilenet/mobilenetv2.py | 5 +- .../resnext/resnext50-32x4d_b32x8_100e_jpg.py | 4 +- easycv/apis/export.py | 69 ++++++---- tests/test_tools/test_export.py | 128 ++++++++++++++++++ tools/export.py | 2 +- 8 files changed, 206 insertions(+), 54 deletions(-) create mode 100644 tests/test_tools/test_export.py diff --git a/configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py b/configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py index 666ffdd1..1d74ec39 100644 --- a/configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py +++ b/configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py @@ -5,30 +5,30 @@ model = dict( type='Classification', backbone=dict(type='Inception3', num_classes=1000), - head=[dict( - type='ClsHead', - with_fc=False, - in_channels=2048, - loss_config=dict( - type='CrossEntropyLossWithLabelSmooth', - label_smooth=0, + head=[ + dict( + type='ClsHead', + with_fc=False, + in_channels=2048, + loss_config=dict( + type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes, + input_feature_index=[1], ), - num_classes=num_classes, - input_feature_index=[1], - ), - dict( - type='ClsHead', - with_fc=False, - in_channels=768, - loss_config=dict( - type='CrossEntropyLossWithLabelSmooth', - label_smooth=0, - ), - num_classes=num_classes, - input_feature_index=[0], - ) - ] - ) + dict( + type='ClsHead', + with_fc=False, + in_channels=768, + loss_config=dict( + type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes, + input_feature_index=[0], + ) + ]) class_list = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', @@ -211,3 +211,5 @@ interval=10, hooks=[dict(type='TextLoggerHook'), dict(type='TensorboardLoggerHook')]) + +export = dict(export_type='raw', export_neck=True) diff --git a/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py index 5aee7685..127f0e74 100644 --- a/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py +++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py @@ -211,3 +211,5 @@ interval=10, hooks=[dict(type='TextLoggerHook'), dict(type='TensorboardLoggerHook')]) + +export = dict(export_type='raw', export_neck=True) diff --git a/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py index 2fc8b5ba..2d31d6bb 100644 --- a/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py +++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py @@ -214,3 +214,5 @@ interval=10, hooks=[dict(type='TextLoggerHook'), dict(type='TensorboardLoggerHook')]) + +export = dict(export_type='raw', export_neck=True) diff --git a/configs/classification/imagenet/mobilenet/mobilenetv2.py b/configs/classification/imagenet/mobilenet/mobilenetv2.py index 29a663f7..354d966a 100644 --- 
a/configs/classification/imagenet/mobilenet/mobilenetv2.py +++ b/configs/classification/imagenet/mobilenet/mobilenetv2.py @@ -13,7 +13,8 @@ type='CrossEntropyLossWithLabelSmooth', label_smooth=0, ), - num_classes=num_classes)) + num_classes=num_classes), + pretrained=True) # optimizer optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) @@ -25,4 +26,4 @@ # runtime settings total_epochs = 100 checkpoint_sync_export = True -export = dict(export_neck=True) +export = dict(export_type='raw', export_neck=True) diff --git a/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py b/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py index c1156bd2..c5cacc07 100644 --- a/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py +++ b/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py @@ -19,7 +19,8 @@ type='CrossEntropyLossWithLabelSmooth', label_smooth=0, ), - num_classes=num_classes)) + num_classes=num_classes), + pretrained=True) # optimizer optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) @@ -30,3 +31,4 @@ # runtime settings total_epochs = 100 +export = dict(export_type='raw', export_neck=True) diff --git a/easycv/apis/export.py b/easycv/apis/export.py index 85424a62..701d9f23 100644 --- a/easycv/apis/export.py +++ b/easycv/apis/export.py @@ -158,34 +158,49 @@ def _get_blade_model(): def _export_onnx_cls(model, model_config, cfg, filename, meta): + support_backbones = { + 'ResNet': { + 'depth': [50] + }, + 'MobileNetV2': {}, + 'Inception3': {}, + 'Inception4': {}, + 'ResNeXt': { + 'depth': [50] + } + } + if model_config['backbone'].get('type', None) not in support_backbones: + tmp = ' '.join(support_backbones.keys()) + info_str = f'Only support export onnx model for {tmp} now!' 
+        raise ValueError(info_str)
+    configs = support_backbones[model_config['backbone'].get('type')]
+    for k, v in configs.items():
+        if v[0].__class__(model_config['backbone'].get(k, None)) not in v:
+            raise ValueError(
+                f"Unsupported config for {model_config['backbone'].get('type')}")
+
+    # save json config for test_pipeline and class
+    with io.open(
+            filename +
+            '.config.json' if filename.endswith('onnx') else filename +
+            '.onnx.config.json', 'w') as ofile:
+        json.dump(meta, ofile)

-    if model_config['backbone'].get(
-            'type', None) == 'ResNet' and model_config['backbone'].get(
-                'depth', None) == 50:
-        # save json config for test_pipline and class
-        with io.open(
-                filename +
-                '.config.json' if filename.endswith('onnx') else filename +
-                '.onnx.config.json', 'w') as ofile:
-            json.dump(meta, ofile)
-
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        model.eval()
-        model.to(device)
-        img_size = int(cfg.image_size2)
-        x_input = torch.randn((1, 3, img_size, img_size)).to(device)
-        torch.onnx.export(
-            model,
-            (x_input, 'onnx'),
-            filename if filename.endswith('onnx') else filename + '.onnx',
-            export_params=True,
-            opset_version=12,
-            do_constant_folding=True,
-            input_names=['input'],
-            output_names=['output'],
-        )
-    else:
-        raise ValueError('Only support export onnx model for ResNet now!')
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model.eval()
+    model.to(device)
+    img_size = int(cfg.image_size2)
+    x_input = torch.randn((1, 3, img_size, img_size)).to(device)
+    torch.onnx.export(
+        model,
+        (x_input, 'onnx'),
+        filename if filename.endswith('onnx') else filename + '.onnx',
+        export_params=True,
+        opset_version=12,
+        do_constant_folding=True,
+        input_names=['input'],
+        output_names=['output'],
+    )


 def _export_cls(model, cfg, filename):
diff --git a/tests/test_tools/test_export.py b/tests/test_tools/test_export.py
new file mode 100644
index 00000000..8be0b3fe
--- /dev/null
+++ b/tests/test_tools/test_export.py
@@ -0,0 +1,128 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import os
+import sys
+import unittest
+
+import numpy as np
+import onnxruntime
+import torch
+
+from easycv.models import build_model
+from easycv.utils.checkpoint import load_checkpoint
+from easycv.utils.config_tools import mmcv_config_fromfile, rebuild_config
+from easycv.utils.test_util import run_in_subprocess
+
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
+logging.basicConfig(level=logging.INFO)
+
+BASIC_EXPORT_CONFIGS = {
+    'config_file': None,
+    'checkpoint': 'dummy',
+    'output_filename': 'work_dir/test_out.pth',
+    'user_config_params': ['--export.export_type', 'onnx']
+}
+
+
+def build_cmd(export_configs) -> str:
+    base_cmd = 'python tools/export.py'
+    base_cmd += f" {export_configs['config_file']}"
+    base_cmd += f" {export_configs['checkpoint']}"
+    base_cmd += f" {export_configs['output_filename']}"
+    user_params = ' '.join(export_configs['user_config_params'])
+    base_cmd += f' --user_config_params {user_params}'
+    return base_cmd
+
+
+class ExportTest(unittest.TestCase):
+    """In this unittest, we test the onnx export functionality of
+    some classification/detection models.
+    """
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+
+    def tearDown(self):
+        super().tearDown()
+
+    def run_test(self, CONFIG_FILE, img_size: int = 224, **override_configs):
+        configs = BASIC_EXPORT_CONFIGS.copy()
+        configs['config_file'] = CONFIG_FILE
+
+        configs.update(override_configs)
+
+        cmd = build_cmd(configs)
+        logging.info(f'Export with commands: {cmd}')
+        run_in_subprocess(cmd)
+
+        cfg = mmcv_config_fromfile(configs['config_file'])
+        cfg = rebuild_config(cfg, configs['user_config_params'])
+
+        if hasattr(cfg.model, 'pretrained'):
+            cfg.model.pretrained = False
+
+        torch_model = build_model(cfg.model).eval()
+        if 'checkpoint' in override_configs:
+            load_checkpoint(
+                torch_model,
+                override_configs['checkpoint'],
+                strict=False,
+                logger=logging.getLogger())
+        session = onnxruntime.InferenceSession(configs['output_filename'] +
+                                               '.onnx')
+        input_tensor = torch.randn((1, 3, img_size, img_size))
+
+        torch_output = torch_model(input_tensor, mode='test')['prob']
+
+        onnx_output = session.run(
+            [session.get_outputs()[0].name],
+            {session.get_inputs()[0].name: np.array(input_tensor)})
+        if isinstance(onnx_output, list):
+            onnx_output = onnx_output[0]
+
+        onnx_output = torch.tensor(onnx_output)
+
+        is_same_shape = torch_output.shape == onnx_output.shape
+
+        self.assertTrue(
+            is_same_shape,
+            f'The shapes of the two outputs do not match, got {torch_output.shape} and {onnx_output.shape}'
+        )
+        is_allclose = torch.allclose(torch_output, onnx_output)
+
+        torch_out_minmax = f'{float(torch_output.min())}~{float(torch_output.max())}'
+        onnx_out_minmax = f'{float(onnx_output.min())}~{float(onnx_output.max())}'
+
+        info_msg = f'got avg: {float(torch_output.mean())} and {float(onnx_output.mean())},'
+        info_msg += f' and range: {torch_out_minmax} and {onnx_out_minmax}'
+        self.assertTrue(
+            is_allclose,
+            f'The values of the two outputs do not match, {info_msg}')
+
+    def test_inceptionv3(self):
+        CONFIG_FILE = 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py'
+        self.run_test(CONFIG_FILE, 299)
+
+    def test_inceptionv4(self):
+        CONFIG_FILE = 'configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py'
+        self.run_test(CONFIG_FILE, 299)
+
+    def test_resnext50(self):
+        CONFIG_FILE = 'configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py'
+        self.run_test(
+            CONFIG_FILE,
+            checkpoint=
+            'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext50-32x4d/epoch_100.pth'
+        )
+
+    def test_mobilenetv2(self):
+        CONFIG_FILE = 'configs/classification/imagenet/mobilenet/mobilenetv2.py'
+        self.run_test(
+            CONFIG_FILE,
+            checkpoint=
+            'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/mobilenetv2/mobilenet_v2.pth'
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tools/export.py b/tools/export.py
index 36dbe0d1..505f7fc1 100644
--- a/tools/export.py
+++ b/tools/export.py
@@ -80,7 +80,7 @@ def main():
     cfg = mmcv_config_fromfile(args.config)

     if args.user_config_params is not None:
-        assert args.model_type is not None, 'model_type must be setted'
+        # assert args.model_type is not None, 'model_type must be setted'
         # rebuild config by user config params
         cfg = rebuild_config(cfg, args.user_config_params)


From 4e2f39382db5166cd436b7880bb56484d4c747d7 Mon Sep 17 00:00:00 2001
From: "gushen.hkw"
Date: Wed, 10 Jul 2024 18:30:36 +0800
Subject: [PATCH 08/11] fix format

---
 easycv/models/backbones/__init__.py    | 2 +-
easycv/models/backbones/inceptionv3.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/easycv/models/backbones/__init__.py b/easycv/models/backbones/__init__.py index 0db2af72..794f4426 100644 --- a/easycv/models/backbones/__init__.py +++ b/easycv/models/backbones/__init__.py @@ -10,6 +10,7 @@ from .genet import PlainNet from .hrnet import HRNet from .inceptionv3 import Inception3 +from .inceptionv4 import Inception4 from .lighthrnet import LiteHRNet from .mae_vit_transformer import * from .mit import MixVisionTransformer @@ -28,4 +29,3 @@ from .vision_transformer import VisionTransformer from .vitdet import ViTDet from .x3d import X3D -from .inceptionv4 import Inception4 \ No newline at end of file diff --git a/easycv/models/backbones/inceptionv3.py b/easycv/models/backbones/inceptionv3.py index 55e54dff..0439d9fa 100644 --- a/easycv/models/backbones/inceptionv3.py +++ b/easycv/models/backbones/inceptionv3.py @@ -13,6 +13,7 @@ __all__ = ['Inception3'] + @BACKBONES.register_module class Inception3(nn.Module): From d05b7e754554cc3b236222fb3bf4f08be087a0b7 Mon Sep 17 00:00:00 2001 From: "gushen.hkw" Date: Wed, 10 Jul 2024 19:12:20 +0800 Subject: [PATCH 09/11] fix upon comments --- .../inception/inceptionv4_b32x8_100e.py | 184 +----------------- .../inceptionv4_b32x8_200e_rmsprop.py | 174 +---------------- easycv/utils/config_tools.py | 4 + tests/test_tools/test_export.py | 17 +- tools/export.py | 2 +- 5 files changed, 19 insertions(+), 362 deletions(-) diff --git a/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py index 127f0e74..09f00a3a 100644 --- a/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py +++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py @@ -1,4 +1,4 @@ -_base_ = 'configs/base.py' +_base_ = 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py' num_classes = 1000 # model settings @@ -30,186 +30,4 @@ ) ]) -class_list = [ - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', - '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', - '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', - '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', - '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', - '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', - '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', - '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', - '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', - '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', - '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', - '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', - '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', - '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', - '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', - '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', - '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', - '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', - '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', - '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', - '219', '220', '221', '222', '223', '224', '225', '226', '227', 
'228', - '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', - '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', - '249', '250', '251', '252', '253', '254', '255', '256', '257', '258', - '259', '260', '261', '262', '263', '264', '265', '266', '267', '268', - '269', '270', '271', '272', '273', '274', '275', '276', '277', '278', - '279', '280', '281', '282', '283', '284', '285', '286', '287', '288', - '289', '290', '291', '292', '293', '294', '295', '296', '297', '298', - '299', '300', '301', '302', '303', '304', '305', '306', '307', '308', - '309', '310', '311', '312', '313', '314', '315', '316', '317', '318', - '319', '320', '321', '322', '323', '324', '325', '326', '327', '328', - '329', '330', '331', '332', '333', '334', '335', '336', '337', '338', - '339', '340', '341', '342', '343', '344', '345', '346', '347', '348', - '349', '350', '351', '352', '353', '354', '355', '356', '357', '358', - '359', '360', '361', '362', '363', '364', '365', '366', '367', '368', - '369', '370', '371', '372', '373', '374', '375', '376', '377', '378', - '379', '380', '381', '382', '383', '384', '385', '386', '387', '388', - '389', '390', '391', '392', '393', '394', '395', '396', '397', '398', - '399', '400', '401', '402', '403', '404', '405', '406', '407', '408', - '409', '410', '411', '412', '413', '414', '415', '416', '417', '418', - '419', '420', '421', '422', '423', '424', '425', '426', '427', '428', - '429', '430', '431', '432', '433', '434', '435', '436', '437', '438', - '439', '440', '441', '442', '443', '444', '445', '446', '447', '448', - '449', '450', '451', '452', '453', '454', '455', '456', '457', '458', - '459', '460', '461', '462', '463', '464', '465', '466', '467', '468', - '469', '470', '471', '472', '473', '474', '475', '476', '477', '478', - '479', '480', '481', '482', '483', '484', '485', '486', '487', '488', - '489', '490', '491', '492', '493', '494', '495', '496', '497', '498', - '499', '500', '501', '502', '503', '504', '505', '506', '507', '508', - '509', '510', '511', '512', '513', '514', '515', '516', '517', '518', - '519', '520', '521', '522', '523', '524', '525', '526', '527', '528', - '529', '530', '531', '532', '533', '534', '535', '536', '537', '538', - '539', '540', '541', '542', '543', '544', '545', '546', '547', '548', - '549', '550', '551', '552', '553', '554', '555', '556', '557', '558', - '559', '560', '561', '562', '563', '564', '565', '566', '567', '568', - '569', '570', '571', '572', '573', '574', '575', '576', '577', '578', - '579', '580', '581', '582', '583', '584', '585', '586', '587', '588', - '589', '590', '591', '592', '593', '594', '595', '596', '597', '598', - '599', '600', '601', '602', '603', '604', '605', '606', '607', '608', - '609', '610', '611', '612', '613', '614', '615', '616', '617', '618', - '619', '620', '621', '622', '623', '624', '625', '626', '627', '628', - '629', '630', '631', '632', '633', '634', '635', '636', '637', '638', - '639', '640', '641', '642', '643', '644', '645', '646', '647', '648', - '649', '650', '651', '652', '653', '654', '655', '656', '657', '658', - '659', '660', '661', '662', '663', '664', '665', '666', '667', '668', - '669', '670', '671', '672', '673', '674', '675', '676', '677', '678', - '679', '680', '681', '682', '683', '684', '685', '686', '687', '688', - '689', '690', '691', '692', '693', '694', '695', '696', '697', '698', - '699', '700', '701', '702', '703', '704', '705', '706', '707', '708', - '709', '710', '711', '712', '713', '714', '715', '716', '717', '718', - '719', '720', 
'721', '722', '723', '724', '725', '726', '727', '728', - '729', '730', '731', '732', '733', '734', '735', '736', '737', '738', - '739', '740', '741', '742', '743', '744', '745', '746', '747', '748', - '749', '750', '751', '752', '753', '754', '755', '756', '757', '758', - '759', '760', '761', '762', '763', '764', '765', '766', '767', '768', - '769', '770', '771', '772', '773', '774', '775', '776', '777', '778', - '779', '780', '781', '782', '783', '784', '785', '786', '787', '788', - '789', '790', '791', '792', '793', '794', '795', '796', '797', '798', - '799', '800', '801', '802', '803', '804', '805', '806', '807', '808', - '809', '810', '811', '812', '813', '814', '815', '816', '817', '818', - '819', '820', '821', '822', '823', '824', '825', '826', '827', '828', - '829', '830', '831', '832', '833', '834', '835', '836', '837', '838', - '839', '840', '841', '842', '843', '844', '845', '846', '847', '848', - '849', '850', '851', '852', '853', '854', '855', '856', '857', '858', - '859', '860', '861', '862', '863', '864', '865', '866', '867', '868', - '869', '870', '871', '872', '873', '874', '875', '876', '877', '878', - '879', '880', '881', '882', '883', '884', '885', '886', '887', '888', - '889', '890', '891', '892', '893', '894', '895', '896', '897', '898', - '899', '900', '901', '902', '903', '904', '905', '906', '907', '908', - '909', '910', '911', '912', '913', '914', '915', '916', '917', '918', - '919', '920', '921', '922', '923', '924', '925', '926', '927', '928', - '929', '930', '931', '932', '933', '934', '935', '936', '937', '938', - '939', '940', '941', '942', '943', '944', '945', '946', '947', '948', - '949', '950', '951', '952', '953', '954', '955', '956', '957', '958', - '959', '960', '961', '962', '963', '964', '965', '966', '967', '968', - '969', '970', '971', '972', '973', '974', '975', '976', '977', '978', - '979', '980', '981', '982', '983', '984', '985', '986', '987', '988', - '989', '990', '991', '992', '993', '994', '995', '996', '997', '998', '999' -] - -data_source_type = 'ClsSourceImageList' -base_root = 'data/imagenet_raw/' -data_train_list = base_root + 'meta/train_labeled.txt' -data_train_root = base_root + 'train/' -data_test_list = base_root + 'meta/val_labeled.txt' -data_test_root = base_root + 'validation/' -image_size2 = 299 -image_size1 = int((256 / 224) * image_size2) - -dataset_type = 'ClsDataset' img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -train_pipeline = [ - dict(type='RandomResizedCrop', size=image_size2), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] -test_pipeline = [ - dict(type='Resize', size=image_size1), - dict(type='CenterCrop', size=image_size2), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] - -data = dict( - imgs_per_gpu=32, # total 256 - workers_per_gpu=4, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type=data_source_type), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type=data_source_type), - pipeline=test_pipeline)) - -eval_config = dict(initial=False, interval=1, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[ - dict(type='ClsEvaluator', topk=(1, 5), class_list=class_list) - ], - ) -] - -# optimizer -optimizer = 
dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) - -# learning policy -lr_config = dict(policy='step', step=[30, 60, 90]) -checkpoint_config = dict(interval=10) - -# runtime settings -total_epochs = 100 - -predict = dict( - type='ClassificationPredictor', - pipelines=[ - dict(type='Resize', size=image_size1), - dict(type='CenterCrop', size=image_size2), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img']) - ]) - -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) - -export = dict(export_type='raw', export_neck=True) diff --git a/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py index 2d31d6bb..efa4a786 100644 --- a/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py +++ b/configs/classification/imagenet/inception/inceptionv4_b32x8_200e_rmsprop.py @@ -1,6 +1,6 @@ # A config with the optimization settings from https://arxiv.org/pdf/1602.07261 # May run with 20 GPUs -_base_ = 'configs/base.py' +_base_ = 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py' num_classes = 1000 # model settings @@ -32,162 +32,7 @@ ) ]) -class_list = [ - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', - '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', - '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', - '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', - '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', - '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', - '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', - '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', - '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', - '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', - '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', - '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', - '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', - '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', - '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', - '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', - '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', - '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', - '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', - '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', - '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', - '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', - '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', - '249', '250', '251', '252', '253', '254', '255', '256', '257', '258', - '259', '260', '261', '262', '263', '264', '265', '266', '267', '268', - '269', '270', '271', '272', '273', '274', '275', '276', '277', '278', - '279', '280', '281', '282', '283', '284', '285', '286', '287', '288', - '289', '290', '291', '292', '293', '294', '295', '296', '297', '298', - '299', '300', '301', '302', '303', '304', '305', '306', '307', '308', - '309', '310', '311', '312', '313', '314', '315', '316', '317', '318', - '319', '320', '321', '322', '323', '324', '325', '326', 
'327', '328', - '329', '330', '331', '332', '333', '334', '335', '336', '337', '338', - '339', '340', '341', '342', '343', '344', '345', '346', '347', '348', - '349', '350', '351', '352', '353', '354', '355', '356', '357', '358', - '359', '360', '361', '362', '363', '364', '365', '366', '367', '368', - '369', '370', '371', '372', '373', '374', '375', '376', '377', '378', - '379', '380', '381', '382', '383', '384', '385', '386', '387', '388', - '389', '390', '391', '392', '393', '394', '395', '396', '397', '398', - '399', '400', '401', '402', '403', '404', '405', '406', '407', '408', - '409', '410', '411', '412', '413', '414', '415', '416', '417', '418', - '419', '420', '421', '422', '423', '424', '425', '426', '427', '428', - '429', '430', '431', '432', '433', '434', '435', '436', '437', '438', - '439', '440', '441', '442', '443', '444', '445', '446', '447', '448', - '449', '450', '451', '452', '453', '454', '455', '456', '457', '458', - '459', '460', '461', '462', '463', '464', '465', '466', '467', '468', - '469', '470', '471', '472', '473', '474', '475', '476', '477', '478', - '479', '480', '481', '482', '483', '484', '485', '486', '487', '488', - '489', '490', '491', '492', '493', '494', '495', '496', '497', '498', - '499', '500', '501', '502', '503', '504', '505', '506', '507', '508', - '509', '510', '511', '512', '513', '514', '515', '516', '517', '518', - '519', '520', '521', '522', '523', '524', '525', '526', '527', '528', - '529', '530', '531', '532', '533', '534', '535', '536', '537', '538', - '539', '540', '541', '542', '543', '544', '545', '546', '547', '548', - '549', '550', '551', '552', '553', '554', '555', '556', '557', '558', - '559', '560', '561', '562', '563', '564', '565', '566', '567', '568', - '569', '570', '571', '572', '573', '574', '575', '576', '577', '578', - '579', '580', '581', '582', '583', '584', '585', '586', '587', '588', - '589', '590', '591', '592', '593', '594', '595', '596', '597', '598', - '599', '600', '601', '602', '603', '604', '605', '606', '607', '608', - '609', '610', '611', '612', '613', '614', '615', '616', '617', '618', - '619', '620', '621', '622', '623', '624', '625', '626', '627', '628', - '629', '630', '631', '632', '633', '634', '635', '636', '637', '638', - '639', '640', '641', '642', '643', '644', '645', '646', '647', '648', - '649', '650', '651', '652', '653', '654', '655', '656', '657', '658', - '659', '660', '661', '662', '663', '664', '665', '666', '667', '668', - '669', '670', '671', '672', '673', '674', '675', '676', '677', '678', - '679', '680', '681', '682', '683', '684', '685', '686', '687', '688', - '689', '690', '691', '692', '693', '694', '695', '696', '697', '698', - '699', '700', '701', '702', '703', '704', '705', '706', '707', '708', - '709', '710', '711', '712', '713', '714', '715', '716', '717', '718', - '719', '720', '721', '722', '723', '724', '725', '726', '727', '728', - '729', '730', '731', '732', '733', '734', '735', '736', '737', '738', - '739', '740', '741', '742', '743', '744', '745', '746', '747', '748', - '749', '750', '751', '752', '753', '754', '755', '756', '757', '758', - '759', '760', '761', '762', '763', '764', '765', '766', '767', '768', - '769', '770', '771', '772', '773', '774', '775', '776', '777', '778', - '779', '780', '781', '782', '783', '784', '785', '786', '787', '788', - '789', '790', '791', '792', '793', '794', '795', '796', '797', '798', - '799', '800', '801', '802', '803', '804', '805', '806', '807', '808', - '809', '810', '811', '812', '813', '814', '815', '816', '817', '818', - '819', 
'820', '821', '822', '823', '824', '825', '826', '827', '828', - '829', '830', '831', '832', '833', '834', '835', '836', '837', '838', - '839', '840', '841', '842', '843', '844', '845', '846', '847', '848', - '849', '850', '851', '852', '853', '854', '855', '856', '857', '858', - '859', '860', '861', '862', '863', '864', '865', '866', '867', '868', - '869', '870', '871', '872', '873', '874', '875', '876', '877', '878', - '879', '880', '881', '882', '883', '884', '885', '886', '887', '888', - '889', '890', '891', '892', '893', '894', '895', '896', '897', '898', - '899', '900', '901', '902', '903', '904', '905', '906', '907', '908', - '909', '910', '911', '912', '913', '914', '915', '916', '917', '918', - '919', '920', '921', '922', '923', '924', '925', '926', '927', '928', - '929', '930', '931', '932', '933', '934', '935', '936', '937', '938', - '939', '940', '941', '942', '943', '944', '945', '946', '947', '948', - '949', '950', '951', '952', '953', '954', '955', '956', '957', '958', - '959', '960', '961', '962', '963', '964', '965', '966', '967', '968', - '969', '970', '971', '972', '973', '974', '975', '976', '977', '978', - '979', '980', '981', '982', '983', '984', '985', '986', '987', '988', - '989', '990', '991', '992', '993', '994', '995', '996', '997', '998', '999' -] - -data_source_type = 'ClsSourceImageList' -base_root = 'data/imagenet_raw/' -data_train_list = base_root + 'meta/train_labeled.txt' -data_train_root = base_root + 'train/' -data_test_list = base_root + 'meta/val_labeled.txt' -data_test_root = base_root + 'validation/' -image_size2 = 299 -image_size1 = int((256 / 224) * image_size2) - -dataset_type = 'ClsDataset' img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -train_pipeline = [ - dict(type='RandomResizedCrop', size=image_size2), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] -test_pipeline = [ - dict(type='Resize', size=image_size1), - dict(type='CenterCrop', size=image_size2), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] - -data = dict( - imgs_per_gpu=32, # total 256 - workers_per_gpu=4, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type=data_source_type), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type=data_source_type), - pipeline=test_pipeline)) - -eval_config = dict(initial=False, interval=1, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[ - dict(type='ClsEvaluator', topk=(1, 5), class_list=class_list) - ], - ) -] # optimizer optimizer = dict( @@ -199,20 +44,3 @@ # runtime settings total_epochs = 200 - -predict = dict( - type='ClassificationPredictor', - pipelines=[ - dict(type='Resize', size=image_size1), - dict(type='CenterCrop', size=image_size2), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img']) - ]) - -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) - -export = dict(export_type='raw', export_neck=True) diff --git a/easycv/utils/config_tools.py b/easycv/utils/config_tools.py index e673f49a..631f50d0 100644 --- a/easycv/utils/config_tools.py +++ b/easycv/utils/config_tools.py @@ -515,6 +515,10 @@ def validate_export_config(cfg): 
'configs/classification/imagenet/swint/imagenet_swin_tiny_patch4_window7_224_jpg.py', 'CLASSIFICATION_M0BILENET': 'configs/classification/imagenet/mobilenet/mobilenetv2.py', + 'CLASSIFICATION_INCEPTIONV4': + 'configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py', + 'CLASSIFICATION_INCEPTIONV3': + 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py', # metric learning 'METRICLEARNING': diff --git a/tests/test_tools/test_export.py b/tests/test_tools/test_export.py index 8be0b3fe..79558de0 100644 --- a/tests/test_tools/test_export.py +++ b/tests/test_tools/test_export.py @@ -24,11 +24,12 @@ } -def build_cmd(export_configs) -> str: +def build_cmd(export_configs, MODEL_TYPE) -> str: base_cmd = 'python tools/export.py' base_cmd += f" {export_configs['config_file']}" base_cmd += f" {export_configs['checkpoint']}" base_cmd += f" {export_configs['output_filename']}" + base_cmd += f' --model_type {MODEL_TYPE}' user_params = ' '.join(export_configs['user_config_params']) base_cmd += f' --user_config_params {user_params}' return base_cmd @@ -45,13 +46,17 @@ def setUp(self): def tearDown(self): super().tearDown() - def run_test(self, CONFIG_FILE, img_size: int = 224, **override_configs): + def run_test(self, + CONFIG_FILE, + MODEL_TYPE, + img_size: int = 224, + **override_configs): configs = BASIC_EXPORT_CONFIGS.copy() configs['config_file'] = CONFIG_FILE configs.update(override_configs) - cmd = build_cmd(configs) + cmd = build_cmd(configs, MODEL_TYPE) logging.info(f'Export with commands: {cmd}') run_in_subprocess(cmd) @@ -101,16 +106,17 @@ def run_test(self, CONFIG_FILE, img_size: int = 224, **override_configs): def test_inceptionv3(self): CONFIG_FILE = 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py' - self.run_test(CONFIG_FILE, 299) + self.run_test(CONFIG_FILE, 'CLASSIFICATION_INCEPTIONV3', 299) def test_inceptionv4(self): CONFIG_FILE = 'configs/classification/imagenet/inception/inceptionv4_b32x8_100e.py' - self.run_test(CONFIG_FILE, 299) + self.run_test(CONFIG_FILE, 'CLASSIFICATION_INCEPTIONV4', 299) def test_resnext50(self): CONFIG_FILE = 'configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py' self.run_test( CONFIG_FILE, + 'CLASSIFICATION_RESNEXT', checkpoint= 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext50-32x4d/epoch_100.pth' ) @@ -119,6 +125,7 @@ def test_mobilenetv2(self): CONFIG_FILE = 'configs/classification/imagenet/mobilenet/mobilenetv2.py' self.run_test( CONFIG_FILE, + 'CLASSIFICATION_M0BILENET', checkpoint= 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/mobilenetv2/mobilenet_v2.pth' ) diff --git a/tools/export.py b/tools/export.py index 505f7fc1..36dbe0d1 100644 --- a/tools/export.py +++ b/tools/export.py @@ -80,7 +80,7 @@ def main(): cfg = mmcv_config_fromfile(args.config) if args.user_config_params is not None: - # assert args.model_type is not None, 'model_type must be setted' + assert args.model_type is not None, 'model_type must be setted' # rebuild config by user config params cfg = rebuild_config(cfg, args.user_config_params) From bf3cf8b569f056e00181215110ffeab167cee473 Mon Sep 17 00:00:00 2001 From: "gushen.hkw" Date: Fri, 12 Jul 2024 12:08:26 +0800 Subject: [PATCH 10/11] fix bug when work_dir in test_export does not exist --- tests/test_tools/test_export.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_tools/test_export.py b/tests/test_tools/test_export.py index 
79558de0..7b5a96cf 100644
--- a/tests/test_tools/test_export.py
+++ b/tests/test_tools/test_export.py
@@ -16,10 +16,12 @@
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 logging.basicConfig(level=logging.INFO)

+WORK_DIRECTORY = 'work_dir3'
+
 BASIC_EXPORT_CONFIGS = {
     'config_file': None,
     'checkpoint': 'dummy',
-    'output_filename': 'work_dir/test_out.pth',
+    'output_filename': f'{WORK_DIRECTORY}/test_out.pth',
     'user_config_params': ['--export.export_type', 'onnx']
 }

@@ -42,6 +44,7 @@ class ExportTest(unittest.TestCase):

     def setUp(self):
         print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        os.makedirs(WORK_DIRECTORY, exist_ok=True)

     def tearDown(self):
         super().tearDown()

From a392bb0f1c2d0d4be64bb0bc313b7cd3c4b19b18 Mon Sep 17 00:00:00 2001
From: "gushen.hkw"
Date: Thu, 18 Jul 2024 15:49:11 +0800
Subject: [PATCH 11/11] fix jit test

---
 easycv/models/classification/classification.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/easycv/models/classification/classification.py b/easycv/models/classification/classification.py
index 479b4d1e..3118c24a 100644
--- a/easycv/models/classification/classification.py
+++ b/easycv/models/classification/classification.py
@@ -151,6 +151,7 @@ def forward_backbone(self, img: torch.Tensor) -> List[torch.Tensor]:
         x = self.backbone(img)
         return x

+    @torch.jit.unused
     def forward_onnx(self, img: torch.Tensor) -> Dict[str, torch.Tensor]:
         """ forward_onnx means generate prob from image
         only support one neck + one head
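
Two closing sketches for reviewers (editorial illustrations only, not part of the applied patches; every helper and class name below is hypothetical).

First, the export path added in PATCH 07 reduces to: trace the model with a dummy input via torch.onnx.export, then compare the onnxruntime output against eager PyTorch, which is what tests/test_tools/test_export.py automates end to end. A minimal standalone version is below; it assumes a plain eager-mode nn.Module whose forward takes a single image tensor (EasyCV's real exporter additionally threads the 'onnx' mode argument through forward and writes the *.config.json metadata).

    import numpy as np
    import onnxruntime
    import torch


    def check_onnx_parity(model, onnx_path, img_size=299):
        # Export the model with a fixed-size dummy input.
        model.eval()
        x = torch.randn(1, 3, img_size, img_size)
        torch.onnx.export(
            model,
            x,
            onnx_path,
            export_params=True,
            opset_version=12,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'])
        # Run the same input through both backends.
        with torch.no_grad():
            torch_out = model(x).numpy()
        session = onnxruntime.InferenceSession(onnx_path)
        onnx_out = session.run(None, {'input': x.numpy()})[0]
        # Same default tolerances as the torch.allclose check in the unit test.
        assert np.allclose(torch_out, onnx_out, rtol=1e-5, atol=1e-8), \
            'PyTorch and ONNX outputs diverge'

Second, PATCH 11 marks forward_onnx with @torch.jit.unused so that torch.jit.script can still compile the Classification model: the decorated method is skipped by the TorchScript compiler and any call site compiles into a runtime raise, so its Dict return type cannot break type unification with the tensor-returning branches. The pattern in isolation:

    from typing import Dict

    import torch


    class ToyClassifier(torch.nn.Module):  # hypothetical stand-in

        @torch.jit.unused
        def forward_onnx(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
            # Skipped by the TorchScript compiler; reachable in eager mode only.
            return {'prob': x.softmax(dim=-1)}

        def forward(self, x: torch.Tensor, mode: str = 'train') -> torch.Tensor:
            if mode == 'onnx':
                # In a scripted module this branch raises instead of running.
                return self.forward_onnx(x)['prob']
            return x.softmax(dim=-1)


    scripted = torch.jit.script(ToyClassifier())  # compiles without error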