基于ResNet50实现垃圾分类
只爱喝水 2024-09-10 12:31:02 阅读 69
一、垃圾分类背景
在现代社会中,垃圾分类已经成为环保的重要措施之一。然而,手动分类垃圾既费时又容易出错。借助深度学习技术,特别是卷积神经网络(CNN),我们可以开发一种自动垃圾分类系统。本文将介绍如何基于ResNet50实现垃圾分类。
二、ResNet50简介
ResNet50是Residual Networks(残差网络)的一种变体,由Kaiming He等人在2015年提出。ResNet50包含50个深度层,通过引入残差模块,有效地解决了深层网络的退化问题。残差模块通过引入短连接(skip connections)使得网络在训练时更容易优化。下图在下文中称为主图
1、ResNet50整体架构图
首先需要声明,这张图的内容是ResNet的Backbone部分(即图中没有ResNet中的全局平均池化层和全连接层),输入<code>INPUT经过ResNet50的5个阶段(Stage 0、Stage 1、……)得到输出OUTPUT
。
下面附上ResNet原文展示的ResNet结构,大家可以结合着看,看不懂也没关系,只看本文也可以无痛理解的。
上图描述了ResNet多个版本的具体结构,本文描述的“ResNet50”中的50指有50个层。和上图一样,本图描述的ResNet也分为5个阶段。
2、ResNet50各个部分的具体架构
1)Stage结构
<code>(3,224,224)指输入INPUT
的通道数(channel)、高(height)和宽(width),即(C,H,W)
。现假设输入的高度和宽度相等,所以用(C,W,W)
表示。
该stage中第1层包括3个先后操作
CONV
CONV
是卷积(Convolution)的缩写,7×7
指卷积核大小,64
指卷积核的数量(即该卷积层输出的通道数),/2
指卷积核的步长为2。
BN
BN
是Batch Normalization的缩写,即常说的BN层。
RELU
RELU
指ReLU激活函数。
该stage中第2层为MAXPOOL
,即最大池化层,其kernel大小为3×3
、步长为2
。
(64,56,56)
是该stage输出的通道数(channel)、高(height)和宽(width),其中64
等于该stage第1层卷积层中卷积核的数量,56
等于224/2/2
(步长为2会使输入尺寸减半)。
总体来讲,在Stage 0中,形状为(3,224,224)
的输入先后经过卷积层、BN层、ReLU激活函数、MaxPooling层得到了形状为(64,56,56)
的输出。
2)BINK1、BINK2的结构
BINK2(主图右侧部分):
BINK2有两个参数:C,W
C:代表输入通道数。W:代表输入尺寸。
BINK2左侧经过三个卷积快(包括BN,RELU),设其输出为F(x),将F和x相加再经过Relu激活函数得到BINK2的输出。至于为什么将F和x相加后再输出后面将会介绍。
BINK1:
BINK1有四个参数:C,W,C1,S。
S:代表卷积层中的步长,当S为1时,输入尺寸和输出尺寸相同,代表没有进行下采样。
C1:代表卷积层输出的特征图数目,即输出通道数。
C:代表输入通道数。C和C1相等说明左侧1×1的卷积层没有减少通道数,后三个stage中C=2*C1说明左侧1×1的卷积层减少了通道数。
W:代表输入尺寸,即长和宽。
BINK1相对于BINK2是输入通道和输出通道不一致的情况,BINK1右侧先经过一个卷积层,改变其输出通道数,设其输出为G(x),G函数起到了和左侧输出通道数匹配的作用,这样将F和G相加再经过Relu激活函数得到BINK1的输出。
3)简要分析
原文可知,ResNet后4个stage中都有BTNK1
和BTNK2
。
4个stage中BTNK2
参数规律相同
4个stage中BTNK2
的参数全都是1个模式和规律,只是输入的形状(C,W,W)
不同。
Stage 1中BTNK1
参数的规律与后3个stage不同
然而,4个stage中BTNK1
的参数的模式并非全都一样。具体来讲,后3个stage中BTNK1
的参数模式一致,Stage 1中BTNK1
的模式与后3个stage的不一样,这表现在以下2个方面:
参数S
:BTNK1
左右两个1×1卷积层是否下采样
Stage 1中的BTNK1
:步长S
为1,没有进行下采样,输入尺寸和输出尺寸相等。
后3个stage的BTNK1
:步长S
为2,进行了下采样,输入尺寸是输出尺寸的2倍。
参数C
和C1
:BTNK1
左侧第一个1×1卷积层是否减少通道数
Stage 1中的BTNK1
:输入通道数C
和左侧1×1卷积层通道数C1
相等(C=C1=64
),即左侧1×1卷积层没有减少通道数。
后3个stage的BTNK1
:输入通道数C
和左侧1×1卷积层通道数C1
不相等(C=2*C1
),左侧1×1卷积层有减少通道数。
为什么Stage 1中BTNK1
参数的规律与后3个stage不同?(个人观点)
关于BTNK1
左右两个1×1卷积层是否下采样
因为Stage 0中刚刚对网络输入进行了卷积和最大池化,还没有进行残差学习,此时直接下采样会损失大量信息;而后3个stage直接进行下采样时,前面的网络已经进行过残差学习了,所以可以直接进行下采样。
关于BTNK1
左侧第一个1×1卷积层是否减少通道数
根据ResNet原文可知,Bottleneck左侧两个1×1卷积层的主要作用分别是减少通道数和恢复通道数,这样就可以使它们中间的3×3卷积层的输入和输出的通道数都较小,因此效率更高。
Stage 1中BTNK1
的输入通道数C
为64,它本来就比较小,因此没有必要通过左侧第一个1×1卷积层减少通道数。
4)残差结构
传统的卷积神经网络(CNN)在训练过程中,当网络深度增加时,梯度消失和退化问题变得更加明显。为了解决这些问题,残差结构引入了短连接,通过直接将输入跳跃连接到输出,形成了所谓的“残差”连接。具体来说,残差结构通过以下公式来表达:
y=F(x,{Wi})+x\mathbf{y} = \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathbf{x}y=F(x,{Wi})+x
其中:
x\mathbf{x}x 是输入。F(x,{Wi})\mathcal{F}(\mathbf{x}, \{W_i\})F(x,{Wi}) 表示学习到的残差函数,即输入经过若干卷积层后的输出。y\mathbf{y}y 是残差结构的最终输出。{Wi}\{W_i\}{Wi} 表示卷积层的权重。
通过这种形式,残差结构能够确保输入信息在每一层都能得到保留和传递,从而缓解了深层网络中的梯度消失问题。
三、ResNet50实现垃圾分类
1、数据集准备
1)首先,我们需要一组垃圾分类的数据集。
常用的数据集有:
Garbage Classification DatasetTrashNet
我们以垃圾分类数据集为例,该数据集包含多个类别的垃圾图像,如纸张、塑料、金属等。
本文采用自定义的158类别数据集。
2)数据集的处理
对于文章的数据,我们定义一个split_data.py文件用于生成训练集,验证集和测试集。
<code>import os
from shutil import copy, rmtree
import random
def mk_file(file_path: str):
if os.path.exists(file_path):
# 如果文件夹存在,则先删除原文件夹在重新创建
rmtree(file_path)
os.makedirs(file_path)
def main():
# 保证随机可复现
random.seed(0)
# 将数据集中10%的数据划分到验证集中
split_rate = 0.1
# 将数据集中20%的数据划分到测试集中
split_rate_2 = 0.2
#因此训练集、验证集、测试集比例为 7:1:2
# 指向你解压后的garbage_photos文件夹
cwd = os.getcwd()
data_root_pro = os.path.join(cwd, "data_set")
data_root = os.path.join(data_root_pro, "garbage_data") # data_root = /data_set/garbage
assert os.path.exists(data_root), "path '{}' does not exist.".format(data_root)
# flower_class = [cla for cla in os.listdir(data_root)
# if os.path.isdir(os.path.join(data_root, cla))]
garbage_class = [cla for cla in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, cla))]
# 建立保存训练集的文件夹
train_root = os.path.join(data_root, "train_garbage")
mk_file(train_root)
for cla in garbage_class:
# 建立每个类别对应的文件夹
mk_file(os.path.join(train_root, cla))
# 建立保存验证集的文件夹
val_root = os.path.join(data_root, "val_garbage")
mk_file(val_root)
for cla in garbage_class:
# 建立每个类别对应的文件夹
mk_file(os.path.join(val_root, cla))
# 建立保存测试集的文件夹
test_root = os.path.join(data_root, "test_garbage") #对于测试集无需进行分类保存图片
mk_file(test_root)
for cla in garbage_class:
cla_path = os.path.join(data_root, cla) # /data_set/garbage/0/
images = os.listdir(cla_path)
num = len(images)
# 随机采样验证集和测试集的索引
eval_index = random.sample(images, k=int(num*split_rate))
test_index = random.sample(images, k=int(num*split_rate_2))
for index, image in enumerate(images):
if image in eval_index:
# 将分配至验证集中的文件复制到相应目录
image_path = os.path.join(cla_path, image)
new_path = os.path.join(val_root, cla)
copy(image_path, new_path)
elif image in test_index:
# 将分配至测试集中的文件复制到相应目录
image_path = os.path.join(cla_path, image)
#new_path = os.path.join(test_root, cla)
copy(image_path, test_root)
else:
# 将分配至训练集中的文件复制到相应目录
image_path = os.path.join(cla_path, image)
new_path = os.path.join(train_root, cla)
copy(image_path, new_path)
print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing barcode>
print()
print("processing done!")
if __name__ == '__main__':
main()
运行split_data.py文件之后,目录将生成下图所示三个文件夹,其中train_garbage和val_garbage包含158个子文件夹,每个文件夹存在对应的图像。
2、 模型的构建
<code>import torch.nn as nn
import torch
class BasicBlock(nn.Module):
expansion = 1
# 适用于ResNet18和ResNet34的基本残差块
def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
super(BasicBlock, self).__init__()
# 第一个3x3卷积层
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channel) # 批量归一化层
self.relu = nn.ReLU() # 激活函数
# 第二个3x3卷积层
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channel) # 批量归一化层
self.downsample = downsample # 下采样层(用于调整输入尺寸和通道数)
def forward(self, x):
identity = x # 保存输入值
if self.downsample is not None: # 如果需要下采样
identity = self.downsample(x)
out = self.conv1(x) # 第一个卷积层
out = self.bn1(out) # 批量归一化
out = self.relu(out) # 激活
out = self.conv2(out) # 第二个卷积层
out = self.bn2(out) # 批量归一化
out += identity # 残差连接
out = self.relu(out) # 激活
return out
class Bottleneck(nn.Module):
expansion = 4
# 适用于ResNet50、ResNet101和ResNet152的瓶颈残差块
def __init__(self, in_channel, out_channel, stride=1, downsample=None,
groups=1, width_per_group=64):
super(Bottleneck, self).__init__()
width = int(out_channel * (width_per_group / 64.)) * groups # 计算组卷积的宽度
# 第一个1x1卷积层(压缩通道数)
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
kernel_size=1, stride=1, bias=False)
self.bn1 = nn.BatchNorm2d(width) # 批量归一化层
# 第二个3x3卷积层
self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
kernel_size=3, stride=stride, bias=False, padding=1)
self.bn2 = nn.BatchNorm2d(width) # 批量归一化层
# 第三个1x1卷积层(扩展通道数)
self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
kernel_size=1, stride=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channel*self.expansion) # 批量归一化层
self.relu = nn.ReLU(inplace=True) # 激活函数
self.downsample = downsample # 下采样层
def forward(self, x):
identity = x # 保存输入值
if self.downsample is not None: # 如果需要下采样
identity = self.downsample(x)
out = self.conv1(x) # 第一个卷积层
out = self.bn1(out) # 批量归一化
out = self.relu(out) # 激活
out = self.conv2(out) # 第二个卷积层
out = self.bn2(out) # 批量归一化
out = self.relu(out) # 激活
out = self.conv3(out) # 第三个卷积层
out = self.bn3(out) # 批量归一化
out += identity # 残差连接
out = self.relu(out) # 激活
return out
class ResNet(nn.Module):
def __init__(self,
block,
blocks_num, # 每个阶段的残差块数量
num_classes=1000,
include_top=True,
groups=1,
width_per_group=64):
super(ResNet, self).__init__()
self.include_top = include_top # 是否包含全连接层
self.in_channel = 64 # 初始通道数
self.groups = groups # 组卷积数量
self.width_per_group = width_per_group # 每组的宽度
# 初始卷积层
self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(self.in_channel) # 批量归一化层
self.relu = nn.ReLU(inplace=True) # 激活函数
# 最大池化层
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# 各个阶段的残差块
self.layer1 = self._make_layer(block, 64, blocks_num[0])
self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
# 全局平均池化层和全连接层
if self.include_top:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # 自适应平均池化层
self.fc = nn.Linear(512 * block.expansion, num_classes) # 全连接层
# 初始化卷积层权重
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')code>
def _make_layer(self, block, channel, block_num, stride=1):
downsample = None # 对于ResNet18和ResNet34,下采样默认为None,其他层为下面函数
if stride != 1 or self.in_channel != channel * block.expansion: # 如果步长不为1或通道数不匹配
downsample = nn.Sequential(
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(channel * block.expansion)) # 下采样层
layers = []
layers.append(block(self.in_channel,
channel, # 主分支上第一个卷积核(通道)的个数
downsample=downsample, # 下采样层
stride=stride,
groups=self.groups,
width_per_group=self.width_per_group))
self.in_channel = channel * block.expansion # 更新输入通道数
for _ in range(1, block_num): # 添加剩余的残差块
layers.append(block(self.in_channel,
channel,
groups=self.groups,
width_per_group=self.width_per_group))
return nn.Sequential(*layers) # 将layers列表转化为nn.Sequential
def forward(self, x):
x = self.conv1(x) # 初始卷积层
x = self.bn1(x) # 批量归一化
x = self.relu(x) # 激活
x = self.maxpool(x) # 最大池化层
x = self.layer1(x) # 第一阶段
x = self.layer2(x) # 第二阶段
x = self.layer3(x) # 第三阶段
x = self.layer4(x) # 第四阶段
if self.include_top: # 如果包含全连接层
x = self.avgpool(x) # 全局平均池化层
x = torch.flatten(x, 1) # 展平
x = self.fc(x) # 全连接层
return x
def resnet50(model_name="resnet50", num_classes=1000, init_weights=False, **kwargs):code>
# 创建ResNet-50模型
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=True)
# 替换最后的全连接层以适应新的类别数
model.fc = nn.Linear(model.fc.in_features, num_classes)
# 初始化权重(如果需要)
if init_weights:
def init_weights(m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias,
3、训练模型
# -*- coding: utf-8 -*-
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet50 # 导入定义的ResNet50模型
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 设置运行设备
print("using {} device.".format(device))
# 数据预处理
data_transform = {
"train_garbage": transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪并调整大小
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ToTensor(), # 转换为张量
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])]), # 归一化
"val_garbage": transforms.Compose([transforms.Resize(256), # 调整大小
transforms.CenterCrop(224), # 中心裁剪
transforms.ToTensor(), # 转换为张量
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])} # 归一化
data_root = os.path.abspath(os.path.join(os.getcwd(), "./")) # 获取数据根路径
print(data_root)
image_path = os.path.join(data_root, "data_set", "garbage_data") # 垃圾数据集路径
assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train_garbage"),
transform=data_transform["train_garbage"]) # 加载训练数据集
train_num = len(train_dataset)
# 获取类别映射字典并写入json文件
garbage_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in garbage_list.items())
json_str = json.dumps(cla_dict, indent=4)
with open('/home/dell/CV408/hb/data_set/garbage_data/garbage_classification.json', 'w') as json_file:
json_file.write(json_str)
batch_size = 32
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # 设置dataloader的worker数量
print('Using {} dataloader workers every process'.format(nw))
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,
num_workers=nw) # 加载训练数据集
validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val_garbage"),
transform=data_transform["val_garbage"]) # 加载验证数据集
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
batch_size=batch_size, shuffle=False,
num_workers=nw) # 加载验证数据集
print("using {} images for training, {} images for validation.".format(train_num,
val_num))
# 加载ResNet50模型
model_name = "resnet50"
net = resnet50(model_name=model_name, num_classes=158, init_weights=True)
net.to(device) # 将模型移至指定设备
# 定义损失函数
loss_function = nn.CrossEntropyLoss()
# 定义优化器
params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.0001)
epochs = 30
best_acc = 0.0
save_path = './Test5_resnet{}.pth'.format(model_name) # 保存当前最好的权重路径
train_steps = len(train_loader)
for epoch in range(epochs):
# 训练过程
net.train()
running_loss = 0.0
train_bar = tqdm(train_loader, file=sys.stdout)
for step, data in enumerate(train_bar):
images, labels = data
optimizer.zero_grad()
logits = net(images.to(device))
loss = loss_function(logits, labels.to(device))
loss.backward()
optimizer.step()
# 打印损失
running_loss += loss.item()
train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
epochs,
loss)
# 验证过程
net.eval()
acc = 0.0 # 累积准确数
with torch.no_grad():
val_bar = tqdm(validate_loader, file=sys.stdout)
for val_data in val_bar:
val_images, val_labels = val_data
outputs = net(val_images.to(device))
predict_y = torch.max(outputs, dim=1)[1]
acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
epochs)
val_accurate = acc / val_num
print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
(epoch + 1, running_loss / train_steps, val_accurate))
if val_accurate > best_acc: # 更新最好权重
best_acc = val_accurate
torch.save(net.state_dict(), save_path)
print('Finished Training')
if __name__ == '__main__':
main()
对模型进行预训练之后,将得到一个最好的权重文件 。
4、测试模型
1)单图片测试
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import resnet50 # 导入定义的ResNet50模型
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 设置运行设备
# 定义数据预处理步骤
data_transform = transforms.Compose(
[transforms.Resize(256), # 调整大小到256x256
transforms.CenterCrop(224), # 中心裁剪到224x224
transforms.ToTensor(), # 转换为张量
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) # 归一化
# 加载图像
img_path = "/home/dell/CV408/hb/data_set/garbage_data/test_garbage/8/1171.jpg" # 图像路径
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
img = Image.open(img_path) # 打开图像
plt.imshow(img) # 显示图像
img = data_transform(img) # 对图像进行预处理
img = torch.unsqueeze(img, dim=0) # 扩展批次维度
# 读取类别字典
json_path = '/home/dell/CV408/hb/data_set/garbage_data/garbage_classification.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
with open(json_path, "r") as f:
class_indict = json.load(f)
# 创建模型
model = resnet50(num_classes=158).to(device) # 初始化ResNet50模型并设置类别数为158
# 加载模型权重
weights_path = "/home/dell/CV408/hb/resnet50.pth" # 模型权重路径
assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
model.load_state_dict(torch.load(weights_path, map_location=device)) # 加载权重
# 创建保存预测结果的目录
save_dir = "/home/dell/CV408/hb/data_set/garbage_data/"
os.makedirs(save_dir, exist_ok=True)
# 预测过程
model.eval() # 设置模型为评估模式
with torch.no_grad(): # 禁用梯度计算
output = torch.squeeze(model(img.to(device))).cpu() # 模型预测并去除批次维度
predict = torch.softmax(output, dim=0) # 应用softmax获取概率
predict_cla = torch.argmax(predict).numpy() # 获取预测类别
# 打印预测结果
print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
predict[predict_cla].numpy())
plt.title(print_res) # 在图像上显示预测结果
for i in range(len(predict)): # 打印所有类别的概率
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
predict[i].numpy()))
plt.savefig(os.path.join(save_dir, "predicted_image_garbage.png")) # 保存带有预测结果的图像
plt.show() # 显示图像
if __name__ == '__main__':
main() # 运行main函数
测试结果如下图(类别144为:火龙果)
2)批量测试
<code>import os
import json
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms
from model import resnet50
def main():
# 设置设备为GPU(如果可用)或CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 定义数据预处理操作,包括调整大小、中心裁剪、转换为张量和归一化
data_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 指定需要遍历预测的图像文件夹
imgs_root = "/home/dell/CV408/hb/data_set/garbage_data/test_garbage"
assert os.path.exists(imgs_root), f"file: '{imgs_root}' does not exist."
# 读取指定文件夹下所有jpg图像路径
img_path_list = [os.path.join(imgs_root, i) for i in os.listdir(imgs_root) if i.endswith(".jpg")]
# 读取类别映射文件
json_path = '/home/dell/CV408/hb/data_set/garbage_data/garbage_classification.json'
assert os.path.exists(json_path), f"file: '{json_path}' does not exist."
with open(json_path, "r") as json_file:
class_indict = json.load(json_file)
# 创建ResNet50模型并加载预训练权重
model = resnet50(num_classes=158).to(device)
weights_path = "/home/dell/CV408/hb/Test5_resnetresnet50.pth"
assert os.path.exists(weights_path), f"file: '{weights_path}' does not exist."
model.load_state_dict(torch.load(weights_path, map_location=device))
# 设置模型为评估模式
model.eval()
batch_size = 8 # 每次预测时处理的图像数量
save_dir = "/home/dell/CV408/hb/data_set/garbage_data/predictions/"
os.makedirs(save_dir, exist_ok=True) # 创建保存预测结果的目录
with torch.no_grad():
# 按批次处理图像
for ids in range(0, len(img_path_list) // batch_size):
img_list = []
img_paths = img_path_list[ids * batch_size: (ids + 1) * batch_size]
for img_path in img_paths:
assert os.path.exists(img_path), f"file: '{img_path}' does not exist."
img = Image.open(img_path)
img = data_transform(img)
img_list.append(img)
# 将图像列表打包成一个批次
batch_img = torch.stack(img_list, dim=0)
# 预测类别
output = model(batch_img.to(device)).cpu()
predict = torch.softmax(output, dim=1)
probs, classes = torch.max(predict, dim=1)
# 可视化并保存每张图像的预测结果
for idx, (pro, cla) in enumerate(zip(probs, classes)):
img_path = img_paths[idx]
img = Image.open(img_path)
plt.imshow(img)
plt.title(f"Class: {class_indict[str(cla.numpy())]} Prob: {pro.numpy():.3f}")
save_path = os.path.join(save_dir, f"pred_{os.path.basename(img_path)}")
plt.savefig(save_path)
plt.close()
print(f"image: {img_path} class: {class_indict[str(cla.numpy())]} prob: {pro.numpy():.3f}")
if __name__ == '__main__':
main()
生成一个文件夹保存着测试集图像对应的预测结果。
五、总结
模型的精确率达到了69.7%左右,并且对一些样本少的类别和图像质量较差的预测结果不理想,可能需要对图像的预处理方法进行改进,比如图像增强等等
声明
本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。