[YOLO Improvements] Backbone Swap Series: StarNet, Microsoft's New CVPR 2024 Work, an Ultra-Strong Lightweight Backbone (Based on MMYOLO)

五山一胖 2024-07-04 15:07:04

StarNet

Paper: [2403.19967] Rewrite the Stars

GitHub repository: GitHub - ma-xu/Rewrite-the-Stars: [CVPR 2024] Rewrite the Stars

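The CVPR 2024 paper Rewrite the Stars shows that the star operation (element-wise multiplication) can map inputs into a high-dimensional, nonlinear feature space without widening the network. Building on this finding, the authors propose StarNet, which achieves impressive accuracy and low latency with a compact architecture and low energy consumption.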
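A minimal pure-Python sketch of the star operation, for illustration only: two bias-free linear branches whose outputs are multiplied element-wise, with ReLU6 applied to one branch as in StarNet's Block. A 1x1 convolution evaluated at a single pixel is exactly such a linear map, so this is the per-pixel view of `self.act(self.f1(x)) * self.f2(x)`. The helper names and random weights are hypothetical.

```python
import random


def linear(weights, x):
    """Bias-free dense layer: one output per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def relu6(v):
    """ReLU6 as used in StarNet: clamp to [0, 6]."""
    return min(max(v, 0.0), 6.0)


def star(w1, w2, x, act=relu6):
    """Star operation: element-wise product of two linear branches,
    with the activation applied to the first branch only."""
    a = linear(w1, x)
    b = linear(w2, x)
    return [act(ai) * bi for ai, bi in zip(a, b)]


random.seed(0)
d, m = 4, 12  # input width d, expanded width m = mlp_ratio * d (mlp_ratio = 3)
w1 = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
w2 = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
x = [random.gauss(0.0, 1.0) for _ in range(d)]
y = star(w1, w2, x)
print(len(y))  # 12
```

The output keeps the expanded width m; in the real Block a 1x1 projection (`self.g`) then maps it back to `dim`.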

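Advantages

High-Dimensional and Non-Linear Feature Transformation

StarNet uses the star operation to map features into a high-dimensional, nonlinear space without widening its layers. Much like the kernel trick, the star operation computes on low-dimensional inputs while implicitly producing high-dimensional features (ar5iv). For YOLO-family detectors, this means richer and more expressive feature representations at the same computational cost, which matters for capturing fine-grained detail in object detection.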

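Efficient Network Design

Because the star operation yields expressive features on its own, StarNet needs neither an elaborate architecture nor extra computational overhead. Its key property is that it computes in a low-dimensional space while implicitly accounting for very high-dimensional features (ar5iv). This makes StarNet a good candidate backbone for YOLO-family detectors: it offers efficient computation and better feature representations, which can help improve detection accuracy in resource-constrained settings.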
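To make "computing in a low-dimensional space" concrete, this back-of-the-envelope sketch counts multiplications per spatial position for a star operation versus explicitly building all d(d+1)/2 quadratic features and projecting them with a dense layer. The widths (d = 96, m = 384, i.e. mlp_ratio 4) are hypothetical. The two are not equally expressive: without activation, each output of the star operation is a quadratic form of rank at most 2, so it reaches the high-dimensional space implicitly rather than using it with full freedom.

```python
def star_mults(d, m):
    """Multiplications per spatial position for a star operation with two
    d -> m bias-free linear branches plus m element-wise products."""
    return 2 * d * m + m


def explicit_quadratic_mults(d, m):
    """Multiplications per position when all d(d+1)/2 products x_i * x_j are
    built explicitly and then projected to m outputs by a dense layer."""
    n_features = d * (d + 1) // 2
    return n_features + n_features * m


d, m = 96, 384  # hypothetical widths: d = 96 channels, mlp_ratio = 4
print(star_mults(d, m))                # 74112
print(explicit_quadratic_mults(d, m))  # 1792560
```

Under these assumptions the explicit route costs roughly 24 times more multiplications.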

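Multi-Layer Implicit Feature Expansion

By stacking star operations, StarNet recursively increases the implicit feature dimension, which approaches infinity as more layers are added. For wide and deep networks, this property can substantially strengthen feature expressiveness (ar5iv). For YOLO-family detectors, it suggests that a suitable choice of depth and width can markedly improve the quality of extracted features and, in turn, detection accuracy.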
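A sketch of how fast the implicit dimension grows, under a simplifying assumption (no bias, no activation, no residual connection): each star operation multiplies two linear functions of its input, so after l stacked layers every output is a homogeneous polynomial of degree 2**l in the d inputs. The number of distinct monomials of degree k in d variables is C(d + k - 1, k), which bounds the implicit feature dimension. The function name is hypothetical.

```python
from math import comb


def implicit_terms(d, num_layers):
    """Upper bound on distinct polynomial terms after `num_layers` stacked star
    operations on a d-dim input (no bias, activation, or residual):
    monomials of degree k = 2**num_layers in d variables, C(d + k - 1, k)."""
    k = 2 ** num_layers
    return comb(d + k - 1, k)


for layers in (1, 2, 3):
    print(layers, implicit_terms(24, layers))
```

With d = 24 this gives 300 terms after one layer and 17550 after two, so the growth is doubly exponential in depth.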

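Problems Addressed

Balance Between Computational Complexity and Performance

Through the star operation, StarNet maps features into a high-dimensional space while keeping computational cost low, which eases the trade-off between complexity and accuracy that conventional efficient-network designs face (ar5iv). YOLO-family detectors must likewise balance real-time speed against detection accuracy, so StarNet's efficiency fits this need well.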

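Richness of Feature Representation

Conventional convolutional networks are limited in the high-dimensional nonlinear transformations they can express, whereas StarNet obtains richer feature representations through the star operation (ar5iv). In object detection, especially for small objects and cluttered scenes, richer features can noticeably improve results, so YOLO-family detectors may perform better in these scenarios.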

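Simplified Network Design

The star operation offers a simpler way to design networks: it achieves efficient feature representations without elaborate feature fusion or multi-branch designs (ar5iv). For YOLO-family detectors, this makes efficient backbones easier to design and implement, and reduces design and debugging effort.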

Replacing the YOLOv5 backbone with StarNet in MMYOLO

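1. Download imagenet/starnet.py from the repository linked above and put it in mmyolo/models/backbones/ (the __init__.py below imports it as .starnet).

2. Modify forward() in starnet.py and add an out_indices parameter so that the network returns the feature maps of selected stages.

3. Register class StarNet with MMYOLO's MODELS registry, adapt its __init__() (the classification head is dropped), and import it in mmyolo/models/backbones/__init__.py.

4. Modify the config file, mainly the input and output channel counts of the YOLOv5 neck and head, so that they match the StarNet stages selected by out_indices.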
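Choosing out_indices is where stride and channel bookkeeping matter. In the StarNet code below, the stem is a stride-2 3x3 conv and every stage starts with a stride-2 down-sampler, so stage i outputs base_dim * 2**i channels at stride 2**(i + 2). This pure-Python sketch (hypothetical helper names) lists each stage and picks the indices whose strides match YOLOv5's P3/P4/P5 strides [8, 16, 32].

```python
def starnet_stage_info(base_dim, num_stages=4):
    """(stage index, output channels, stride) for each StarNet stage:
    stem stride 2, then one stride-2 down-sampler per stage."""
    return [(i, base_dim * 2 ** i, 2 ** (i + 2)) for i in range(num_stages)]


def indices_for_strides(base_dim, strides=(8, 16, 32)):
    """out_indices whose feature-map stride matches the detector's strides."""
    return tuple(i for i, _, s in starnet_stage_info(base_dim) if s in strides)


info = starnet_stage_info(24)   # starnet_s1
print(info)                     # [(0, 24, 4), (1, 48, 8), (2, 96, 16), (3, 192, 32)]
print(indices_for_strides(24))  # (1, 2, 3)
```

So for a detector using strides [8, 16, 32], the backbone should return stages 1 to 3, whose channels are base_dim * [2, 4, 8].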

Modified starnet.py

"""

Implementation of Prof-of-Concept Network: StarNet.

We make StarNet as simple as possible [to show the key contribution of element-wise multiplication]:

- like NO layer-scale in network design,

- and NO EMA during training,

- which would improve the performance further.

Created by: Xu Ma (Email: ma.xu1@northeastern.edu)

Modified Date: Mar/29/2024

"""

import torch

import torch.nn as nn

from timm.models.layers import DropPath, trunc_normal_

from typing import List, Sequence, Union

# from timm.models.registry import register_model

from mmyolo.registry import MODELS

model_urls = {

"starnet_s1": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar",

"starnet_s2": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar",

"starnet_s3": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar",

"starnet_s4": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar",

}

class ConvBN(torch.nn.Sequential):
    """Conv2d followed by an optional BatchNorm2d (BN initialised to identity)."""

    def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, with_bn=True):
        super().__init__()
        self.add_module('conv', torch.nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation, groups))
        if with_bn:
            self.add_module('bn', torch.nn.BatchNorm2d(out_planes))
            torch.nn.init.constant_(self.bn.weight, 1)
            torch.nn.init.constant_(self.bn.bias, 0)


class Block(nn.Module):
    """StarNet block: depthwise conv, star operation, 1x1 projection, depthwise conv, residual."""

    def __init__(self, dim, mlp_ratio=3, drop_path=0.):
        super().__init__()
        self.dwconv = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=True)
        # Two parallel 1x1 branches expanding dim -> mlp_ratio * dim
        self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        # 1x1 projection back to dim
        self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=False)
        self.act = nn.ReLU6()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        identity = x
        x = self.dwconv(x)
        x1, x2 = self.f1(x), self.f2(x)
        # Star operation: element-wise product of the two branches
        x = self.act(x1) * x2
        x = self.dwconv2(self.g(x))
        x = identity + self.drop_path(x)
        return x

@MODELS.register_module()
class StarNet(nn.Module):
    """StarNet backbone returning the feature maps of the stages in `out_indices`.

    The stem and every stage halve the resolution, so stage i outputs
    base_dim * 2**i channels at stride 2**(i + 2). The default
    out_indices=(1, 2, 3) yields the stride-8/16/32 maps (P3/P4/P5) used by YOLOv5.
    """

    def __init__(self, base_dim=32, out_indices: Sequence[int] = (1, 2, 3), depths=[3, 3, 12, 5], mlp_ratio=4,
                 drop_path_rate=0.0, num_classes=1000, **kwargs):
        super().__init__()
        self.num_classes = num_classes  # unused: the classification head is removed
        self.in_channel = 32
        self.out_indices = out_indices
        self.depths = depths
        # stem layer
        self.stem = nn.Sequential(ConvBN(3, self.in_channel, kernel_size=3, stride=2, padding=1), nn.ReLU6())
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth
        # build stages
        self.stages = nn.ModuleList()
        cur = 0
        for i_layer in range(len(depths)):
            embed_dim = base_dim * 2 ** i_layer
            down_sampler = ConvBN(self.in_channel, embed_dim, 3, 2, 1)
            self.in_channel = embed_dim
            blocks = [Block(self.in_channel, mlp_ratio, dpr[cur + i]) for i in range(depths[i_layer])]
            cur += depths[i_layer]
            self.stages.append(nn.Sequential(down_sampler, *blocks))
        # classification head, removed for detection
        # self.norm = nn.BatchNorm2d(self.in_channel)
        # self.avgpool = nn.AdaptiveAvgPool2d(1)
        # self.head = nn.Linear(self.in_channel, num_classes)
        # self.apply(self._init_weights)

    def _init_weights(self, m):
        # isinstance needs a tuple: `nn.Linear or nn.Conv2d` evaluates to nn.Linear only
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.stem(x)
        # collect the outputs of the selected stages
        outs = []
        for i in range(len(self.depths)):
            x = self.stages[i](x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)


def _load_pretrained(model, arch):
    """Load ImageNet weights. strict=False because the checkpoint's norm/head
    keys have no counterpart in this backbone-only model."""
    checkpoint = torch.hub.load_state_dict_from_url(url=model_urls[arch], map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model


@MODELS.register_module()
def starnet_s1(pretrained=False, **kwargs):
    model = StarNet(24, depths=[2, 2, 8, 3], **kwargs)
    return _load_pretrained(model, 'starnet_s1') if pretrained else model


@MODELS.register_module()
def starnet_s2(pretrained=False, **kwargs):
    model = StarNet(32, depths=[1, 2, 6, 2], **kwargs)
    return _load_pretrained(model, 'starnet_s2') if pretrained else model


@MODELS.register_module()
def starnet_s3(pretrained=False, **kwargs):
    model = StarNet(32, depths=[2, 2, 8, 4], **kwargs)
    return _load_pretrained(model, 'starnet_s3') if pretrained else model


@MODELS.register_module()
def starnet_s4(pretrained=False, **kwargs):
    model = StarNet(32, depths=[3, 3, 12, 5], **kwargs)
    return _load_pretrained(model, 'starnet_s4') if pretrained else model


# very small networks #
@MODELS.register_module()
def starnet_s050(pretrained=False, **kwargs):
    return StarNet(16, depths=[1, 1, 3, 1], mlp_ratio=3, **kwargs)


@MODELS.register_module()
def starnet_s100(pretrained=False, **kwargs):
    return StarNet(20, depths=[1, 2, 4, 1], mlp_ratio=4, **kwargs)


@MODELS.register_module()
def starnet_s150(pretrained=False, **kwargs):
    return StarNet(24, depths=[1, 2, 4, 2], mlp_ratio=3, **kwargs)

if __name__ == '__main__':
    model = StarNet()
    input_tensor = torch.randn(1, 3, 224, 224)
    outputs = model(input_tensor)
    for out in outputs:
        print(out.shape)  # one (N, C, H, W) tensor per index in out_indices
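To see why the block is cheap, here is a pure-Python sketch that counts the learnable parameters of one Block layer by layer, following the definitions in the listing above: nn.Conv2d keeps its default bias, and BatchNorm2d contributes a weight and a bias (its running statistics are buffers, not parameters). The helper names are hypothetical; if the arithmetic is right, the result should agree with sum(p.numel() for p in Block(24, 4).parameters()) on the real module.

```python
def conv_params(c_in, c_out, k, groups=1, bias=True):
    """Parameters of nn.Conv2d(c_in, c_out, k, groups=groups, bias=bias)."""
    return c_out * (c_in // groups) * k * k + (c_out if bias else 0)


def bn_params(c):
    """Learnable parameters of nn.BatchNorm2d(c): weight and bias."""
    return 2 * c


def block_params(dim, mlp_ratio):
    """Parameter count of one StarNet Block (ConvBN keeps the conv bias)."""
    hidden = mlp_ratio * dim
    dwconv = conv_params(dim, dim, 7, groups=dim) + bn_params(dim)   # 7x7 depthwise + BN
    f1 = conv_params(dim, hidden, 1)                                 # expansion branch 1
    f2 = conv_params(dim, hidden, 1)                                 # expansion branch 2
    g = conv_params(hidden, dim, 1) + bn_params(dim)                 # projection + BN
    dwconv2 = conv_params(dim, dim, 7, groups=dim)                   # 7x7 depthwise, no BN
    return dwconv + f1 + f2 + g + dwconv2


print(block_params(24, 4))  # 9624
print(block_params(32, 4))  # 15904
```

Almost all parameters sit in the three 1x1 convolutions (3 * mlp_ratio * dim**2 weights); the two 7x7 depthwise convolutions add only about 100 parameters per channel.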

Modified __init__.py (mmyolo/models/backbones/__init__.py)

# Copyright (c) OpenMMLab. All rights reserved.
from .base_backbone import BaseBackbone
from .csp_darknet import YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet
from .csp_resnet import PPYOLOECSPResNet
from .cspnext import CSPNeXt
from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
from .yolov7_backbone import YOLOv7Backbone
from .starnet import StarNet

__all__ = [
    'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
    'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet',
    'YOLOv8CSPDarknet', 'StarNet'
]
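Importing StarNet in __init__.py matters because @MODELS.register_module() only runs when starnet.py is imported; until then, type='StarNet' in the config cannot be resolved. The toy registry below is a hypothetical, simplified stand-in (not MMEngine's Registry) that shows the mechanism: the decorator stores the class under its name at import time, and build() looks up cfg['type'] and passes the remaining keys to the constructor. As an alternative to editing __init__.py, MMEngine configs also accept a custom_imports entry that imports the module before the model is built.

```python
class ToyRegistry:
    """Toy name -> class registry illustrating the decorator pattern
    (hypothetical; not MMEngine's actual Registry implementation)."""

    def __init__(self):
        self._modules = {}

    def register_module(self):
        def decorator(cls):
            # Runs once, when the defining module is imported.
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        cfg = dict(cfg)
        name = cfg.pop("type")
        if name not in self._modules:
            raise KeyError(f"{name} is not registered; was its module imported?")
        return self._modules[name](**cfg)


MODELS = ToyRegistry()


@MODELS.register_module()
class StarNetStub:
    def __init__(self, base_dim=32, out_indices=(1, 2, 3)):
        self.base_dim = base_dim
        self.out_indices = out_indices


net = MODELS.build(dict(type="StarNetStub", base_dim=24))
print(type(net).__name__, net.base_dim)  # StarNetStub 24
```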

Modified config file (using yolov5_s-v61_syncbn_8xb16-300e_coco.py as an example)

_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']

# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'  # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/'  # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/'  # Prefix of val image path

num_classes = 80  # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True

# -----model related-----
# Basic size of multi-scale prior box
anchors = [
    [(10, 13), (16, 30), (33, 23)],  # P3/8
    [(30, 61), (62, 45), (59, 119)],  # P4/16
    [(116, 90), (156, 198), (373, 326)]  # P5/32
]

# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300  # Maximum training epochs

model_test_cfg = dict(
    # The config of multi-label for multi-class prediction.
    multi_label=True,
    # The number of boxes before NMS
    nms_pre=30000,
    score_thr=0.001,  # Threshold to filter out boxes.
    nms=dict(type='nms', iou_threshold=0.65),  # NMS type and threshold
    max_per_img=300)  # Max number of detections of each image

# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640)  # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2

# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=val_batch_size_per_gpu,
    img_size=img_scale[0],
    # The image scale of padding should be divided by pad_size_divisor
    size_divisor=32,
    # Additional paddings for pixel scale
    extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3  # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # Normalization config

# -----train val related-----
affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4.  # Prior box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1., 0.4]
lr_factor = 0.01  # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)

'''
StarNet variants (see the factory functions in starnet.py)
variant: base_dim, output channels of stages 1-3 (base_dim * [2, 4, 8]), depths, mlp_ratio
s1: 24, [48, 96, 192], [2, 2, 8, 3], 4
s2: 32, [64, 128, 256], [1, 2, 6, 2], 4
s3: 32, [64, 128, 256], [2, 2, 8, 4], 4
s4: 32, [64, 128, 256], [3, 3, 12, 5], 4
starnet_s050: 16, [32, 64, 128], [1, 1, 3, 1], 3
starnet_s100: 20, [40, 80, 160], [1, 2, 4, 1], 4
starnet_s150: 24, [48, 96, 192], [1, 2, 4, 2], 3
'''
# Stage i of StarNet outputs base_dim * 2**i channels at stride 2**(i + 2),
# so out_indices=(1, 2, 3) gives the stride-8/16/32 maps (P3/P4/P5) in `strides`.
# YOLOv5PAFPN and YOLOv5HeadModule scale their channel arguments by widen_factor
# (0.5 here), so list twice the real channels: starnet_s1 outputs [48, 96, 192].
starnet_channel = [96, 192, 384]
depths = [2, 2, 8, 3]  # starnet_s1

# ===============================Unmodified in most cases====================
model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True),
    backbone=dict(
        # starnet_s1
        type='StarNet',
        base_dim=24,
        out_indices=(1, 2, 3),  # stride-8/16/32 stages, matching `strides`
        depths=depths,
        mlp_ratio=4,
        num_classes=num_classes,  # unused by the headless backbone
        # deepen_factor=deepen_factor,
        # widen_factor=widen_factor,
        # norm_cfg=norm_cfg,
        # act_cfg=dict(type='SiLU', inplace=True)
    ),
    neck=dict(
        type='YOLOv5PAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=starnet_channel,
        out_channels=starnet_channel,
        num_csp_blocks=3,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='YOLOv5Head',
        head_module=dict(
            type='YOLOv5HeadModule',
            num_classes=num_classes,
            in_channels=starnet_channel,
            widen_factor=widen_factor,
            featmap_strides=strides,
            num_base_priors=3),
        prior_generator=dict(
            type='mmdet.YOLOAnchorGenerator',
            base_sizes=anchors,
            strides=strides),
        # scaled based on number of detection layers
        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_cls_weight *
            (num_classes / 80 * 3 / num_det_layers)),
        # Change this entry to swap in a different IoU loss
        loss_bbox=dict(
            type='IoULoss',
            focal=True,
            iou_mode='ciou',
            bbox_format='xywh',
            eps=1e-7,
            reduction='mean',
            loss_weight=loss_bbox_weight * (3 / num_det_layers),
            return_iou=True),
        loss_obj=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_obj_weight *
            ((img_scale[0] / 640) ** 2 * 3 / num_det_layers)),
        prior_match_thr=prior_match_thr,
        obj_level_weights=obj_level_weights),
    test_cfg=model_test_cfg)

albu_train_transforms = [
    dict(type='Blur', p=0.01),
    dict(type='MedianBlur', p=0.01),
    dict(type='ToGray', p=0.01),
    dict(type='CLAHE', p=0.01)
]

pre_transform = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True)
]

train_pipeline = [
    *pre_transform,
    dict(
        type='Mosaic',
        img_scale=img_scale,
        pad_val=114.0,
        pre_transform=pre_transform),
    dict(
        type='YOLOv5RandomAffine',
        max_rotate_degree=0.0,
        max_shear_degree=0.0,
        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2),
        border_val=(114, 114, 114)),
    dict(
        type='mmdet.Albu',
        transforms=albu_train_transforms,
        bbox_params=dict(
            type='BboxParams',
            format='pascal_voc',
            label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
        keymap={
            'img': 'image',
            'gt_bboxes': 'bboxes'
        }),
    dict(type='YOLOv5HSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction'))
]

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=train_ann_file,
        data_prefix=dict(img=train_data_prefix),
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        pipeline=train_pipeline))

test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]

val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(img=val_data_prefix),
        ann_file=val_ann_file,
        pipeline=test_pipeline,
        batch_shapes_cfg=batch_shapes_cfg))

test_dataloader = val_dataloader

param_scheduler = None
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=base_lr,
        momentum=0.937,
        weight_decay=weight_decay,
        nesterov=True,
        batch_size_per_gpu=train_batch_size_per_gpu),
    constructor='YOLOv5OptimizerConstructor')

default_hooks = dict(
    param_scheduler=dict(
        type='YOLOv5ParamSchedulerHook',
        scheduler_type='linear',
        lr_factor=lr_factor,
        max_epochs=max_epochs),
    checkpoint=dict(
        type='CheckpointHook',
        interval=save_checkpoint_intervals,
        save_best='auto',
        max_keep_ckpts=max_keep_ckpts))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        strict_load=False,
        priority=49)
]

val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + val_ann_file,
    metric='bbox')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
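The channel lists given to YOLOv5PAFPN and YOLOv5HeadModule are not used verbatim: both pass them through MMYOLO's make_divisible together with widen_factor. The sketch below re-implements that helper locally as ceil(x * widen_factor / 8) * 8, which is an assumption to verify against your MMYOLO version, and shows its effect with widen_factor = 0.5. A list of [48, 96, 192] makes the neck expect [24, 48, 96], which matches stages 0 to 2 of a base_dim=24 StarNet (strides 4/8/16). To consume stages 1 to 3 ([48, 96, 192] real channels at strides 8/16/32), the list has to be [96, 192, 384].

```python
import math


def make_divisible(x, widen_factor=1.0, divisor=8):
    """Scale a channel count by widen_factor and round up to a multiple of divisor
    (local re-implementation mirroring MMYOLO's helper; an assumption)."""
    return math.ceil(x * widen_factor / divisor) * divisor


def effective_channels(channels, widen_factor):
    """Channel counts the neck/head actually build from a config channel list."""
    return [make_divisible(c, widen_factor) for c in channels]


widen_factor = 0.5
print(effective_channels([48, 96, 192], widen_factor))   # [24, 48, 96]
print(effective_channels([96, 192, 384], widen_factor))  # [48, 96, 192]
```

A mismatch between these effective values and the backbone's real output channels typically surfaces as a shape error in the first neck convolution, whereas a stride mismatch fails silently by decoding boxes on the wrong grid.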


