Image Classification with ViT in Python

老歌老听老掉牙 · 2024-06-16

This project uses CIFAR10 as its dataset, which contains 10 classes. The overall processing pipeline consists of the following steps.

1) Import the required libraries, including the PyTorch core library (torch), data-processing libraries (such as torchvision), and tensor-manipulation libraries (such as einops).

2) Preprocess the data. Load the CIFAR10 dataset with torchvision, then apply normalization, random cropping, and other augmentations to improve data quality.

3) Generate the model's input. Convert the preprocessed images into patch embeddings, then add position embeddings and a classification (cls) token to form the input sequence (see the shape sketch after this list).

4) Build the model, mainly from the encoder part of the Transformer architecture.

5) Train the model. Define a loss function, choose an optimizer, instantiate the model, and train it over multiple epochs (an evaluation sketch follows the training code).
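As a quick sanity check of the shapes in step 3, here is a minimal standalone sketch (not part of the original post, purely illustrative): with a patch size of 4, a 32x32 CIFAR10 image yields (32/4)^2 = 64 patches of 4*4*3 = 48 raw values each; adding the cls token gives 65 tokens, matching the 65 position embeddings the model learns.

import torch
from einops.layers.torch import Rearrange

patch_size = 4
x = torch.randn(1, 3, 32, 32)  # one CIFAR10-sized image
to_patches = Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)',
                       s1=patch_size, s2=patch_size)
print(to_patches(x).shape)  # torch.Size([1, 64, 48])
# A linear (or conv) projection then maps each 48-value patch to a 256-dim
# embedding, after which the cls token and position embeddings are added.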

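For reference, the MultiHeadAttention module in the code below computes standard scaled dot-product attention, dividing the similarity scores by $\sqrt{d}$ before the softmax (in this post $d$ is the full embedding size, 256, rather than the per-head dimension):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right) V$$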
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

# Apply data augmentation to the training data to improve generalization.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize([0.49139968, 0.48215841, 0.44653091],
                         [0.24703223, 0.24348513, 0.26158784])
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.49139968, 0.48215841, 0.44653091],
                         [0.24703223, 0.24348513, 0.26158784])
])

# download=True fetches the data if it is not already present under root.
trainset = torchvision.datasets.CIFAR10(root='../data/', train=True,
                                        download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True,
                                          drop_last=True, pin_memory=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False,
                                         drop_last=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Visualize 4 images from the training set.
NUM_IMAGES = 4
CIFAR_images = torch.stack([trainset[idx][0] for idx in range(NUM_IMAGES)], dim=0)
img_grid = torchvision.utils.make_grid(CIFAR_images, nrow=4, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8, 8))
plt.title("Image examples of the CIFAR10 dataset")
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()

# The patch embedding is developed in three steps below; the final definition
# (conv projection + cls token + position embeddings) is the one the ViT uses.

# Step 1: split the image into patches, flatten each patch and project it
# with a linear layer.
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=4, emb_size=256):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # break the image into s1 x s2 patches and flatten them
            Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
            nn.Linear(patch_size * patch_size * in_channels, emb_size)
        )

    def forward(self, x):
        x = self.projection(x)
        return x

# Step 2: replace the linear projection with a conv layer (better performance)
# and prepend a learnable cls token.
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=4, emb_size=256):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # use a conv layer instead of a linear layer -> better performance
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e h w -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

    def forward(self, x):
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        return x

# Step 3 (final): also add learned position embeddings.
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=4, emb_size=256, img_size=32):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # use a conv layer instead of a linear layer -> better performance
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e h w -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # one position embedding per patch, plus one for the cls token
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))

    def forward(self, x):
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # add the position embeddings
        x += self.positions
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size=256, num_heads=8, dropout=0.):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values into one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x, mask=None):
        # split keys, queries and values into num_heads heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d",
                        h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # dot product over the feature axis -> (batch, num_heads, query_len, key_len)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)
        # scale the scores before the softmax (scaled dot-product attention)
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # weighted sum over the values
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out

class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x

class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size=256, expansion=4, drop_p=0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

class TransformerEncoderBlock(nn.Sequential):
    def __init__(self, emb_size=256, drop_p=0., forward_expansion=4,
                 forward_drop_p=0., **kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(emb_size, expansion=forward_expansion,
                                 drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            ))
        )

class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])

class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size=256, n_classes=10):
        super().__init__(
            # mean-pool over all tokens before the final classifier
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )

class ViT(nn.Sequential):
    def __init__(self, in_channels=3, patch_size=4, emb_size=256,
                 img_size=32, depth=12, n_classes=10, **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

vit = ViT()
vit = vit.to(device)

import torch.optim as optim

LR = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(vit.parameters(), lr=LR)

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the training data
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward and backward passes
        outputs = vit(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate and report the loss
        running_loss += loss.item()
        if i % 100 == 99:  # print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')
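The training loop above only reports the loss. Below is a minimal evaluation sketch (not in the original post) that measures accuracy on the test set; it reuses the vit, testloader, and device objects defined above.

correct = 0
total = 0
vit.eval()  # switch off dropout for evaluation
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = vit(images)
        _, predicted = torch.max(outputs, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy on the 10000 test images: %.2f %%' % (100.0 * correct / total))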


