Using a Vision Transformer to Classify the CIFAR-10 Dataset

秋灯冷雨 2024-09-01 15:01:02

Without further ado, here is the code.

The main idea of ViT is to split an image into patches of size patch_size × patch_size, giving (img_size / patch_size)² patches in total.

Each patch is flattened into a one-dimensional vector and passed through a Transformer encoder to extract features; from there the problem looks just like an NLP sequence task!
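For example, with the 32 × 32 CIFAR-10 images and patch_size = 8 used below, that gives (32 / 8)² = 16 patches, and each flattened patch holds 8 × 8 × 3 = 192 values, which is exactly the embed_dim used in the code.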

import torch

import torch.nn as nn

import torchvision

from torch.utils.data import DataLoader

from torchvision import transforms

from torch import optim

import timeit

import os

from tqdm import tqdm

class PatchEmbedding(nn.Module):

    def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout):

        super(PatchEmbedding, self).__init__()

        self.patcher = nn.Sequential(

            # kernel_size = stride = patch_size: one conv window per patch,
            # projecting each patch to an embed_dim-dimensional token
            nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size),

            nn.Flatten(2)  # [B, embed_dim, H/p, W/p] -> [B, embed_dim, num_patches]

        )

        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)

        self.position_embedding = nn.Parameter(torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):

        cls_token = self.cls_token.expand(x.shape[0], 1, -1)  # [batch_size, 1, embed_dim]

        x = self.patcher(x).permute(0, 2, 1)  # [batch_size, num_patches, embed_dim]

        x = torch.cat([cls_token, x], dim=1)  # prepend the classification token

        x = x + self.position_embedding  # broadcasts over the batch: (1, num_patches + 1, embed_dim)

        x = self.dropout(x)

        return x
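# Quick shape check for PatchEmbedding (a sketch using the CIFAR-10 settings
# defined further below; uncomment to run): a [64, 3, 32, 32] batch becomes
# [64, 17, 192] -- 16 patches plus the cls token.
# patcher_check = PatchEmbedding(in_channels=3, patch_size=8, embed_dim=192, num_patches=16, dropout=0.0)
# print(patcher_check(torch.randn(64, 3, 32, 32)).shape)  # torch.Size([64, 17, 192])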

class Vit(nn.Module):

    def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout,

                 num_head, activation, num_encoders, num_class):

        super(Vit, self).__init__()

        self.patch_embedding = PatchEmbedding(in_channels, patch_size, embed_dim, num_patches, dropout)

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_head, dropout=dropout,

                                                   activation=activation, batch_first=True, norm_first=True)

        # stack num_encoders copies of TransformerEncoderLayer; norm_first=True is the pre-LN variant used by ViT

        self.encoder_layers = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)

        self.MLP = nn.Sequential(

            nn.LayerNorm(normalized_shape=embed_dim),

            nn.Linear(in_features=embed_dim, out_features=num_class)

        )

    def forward(self, x):

        x = self.patch_embedding(x)

        x = self.encoder_layers(x)

        x = self.MLP(x[:, 0, :])  # classify from the cls_token representation

        return x
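# Quick forward-pass check for Vit (a sketch using the hyperparameters defined
# below; uncomment to run). Note that embed_dim (192) must be divisible by
# num_head (8) for multi-head attention to work.
# vit_check = Vit(3, 8, 192, 16, 0.01, 8, "gelu", 10, 10)
# print(vit_check(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])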

in_channels = 3

img_size = 32

patch_size = 8

embed_dim = patch_size ** 2 * in_channels

num_patches = (img_size // patch_size) ** 2

dropout = 0.01

batch_size = 64

device = "cuda" if torch.cuda.is_available() else "cpu"

epochs = 50

num_head = 8

activation = "gelu"

num_encoders = 10

num_classes = 10

learning_rate = 1e-4

weight_decay = 1e-4

betas = (0.9, 0.999)

train_transform = torchvision.transforms.Compose([

    # transforms.ToPILImage(),

    transforms.RandomRotation(15),

    transforms.ToTensor(),

    transforms.Normalize(

        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)

    )

])

test_transform = torchvision.transforms.Compose([

    # transforms.ToPILImage(),

    # transforms.RandomRotation(15),

    transforms.ToTensor(),

    transforms.Normalize(

        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)

    )

])
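# Note: mean=0.5, std=0.5 applies the same statistics to all three channels.
# A common alternative (an assumption here, not from the original post) is to
# use the per-channel CIFAR-10 statistics, roughly:
# transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616))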

train_dataset = torchvision.datasets.CIFAR10(root="../../datas", train=True, transform=train_transform, download=True)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root="../../datas", train=False, transform=test_transform, download=True)

test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)  # no need to shuffle the test set

model = Vit(in_channels, patch_size, embed_dim, num_patches, dropout,

            num_head, activation, num_encoders, num_classes).to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=betas, weight_decay=weight_decay)

start = timeit.default_timer()

best_acc = 0

print(f"training on : {device}")

for epoch in range(epochs):

    model.train()


    train_running_loss = 0

    n = 0

    train_n_sum = 0

    train_n_correct = 0

    for idx, (X, y) in enumerate(tqdm(train_dataloader, position=0, leave=True)):

        X = X.to(device)

        y = y.to(device)

        y_pred = model(X)

        y_pred_label = torch.argmax(y_pred, dim=1)

        # print(y.shape, y_pred.shape)


        loss = criterion(y_pred, y)

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        train_running_loss += loss.item()

        n += 1

        train_n_sum += X.size(0)

        train_n_correct += (y == y_pred_label).sum().item()

    train_loss = train_running_loss / n  # n = number of training batches

    train_acc = train_n_correct / train_n_sum

    model.eval()


    val_running_loss = 0

    test_n_sum = 0

    test_n_correct = 0

    with torch.no_grad():

        n = 0

        for idx, (X, y) in enumerate(tqdm(test_dataloader, position=0, leave=True)):

            X = X.to(device)

            y = y.to(device)

            y_pred = model(X)

            y_pred_label = torch.argmax(y_pred, dim=1)

            # print(y_pred.shape, y.shape)


            loss = criterion(y_pred, y)

            val_running_loss += loss.item()

            n += 1

            test_n_sum += X.size(0)

            test_n_correct += (y == y_pred_label).sum().item()

        test_loss = val_running_loss / n  # n = number of test batches

        test_acc = test_n_correct / test_n_sum

        if test_acc > best_acc:

            best_acc = test_acc

            checkpoint = {

                "state": model.state_dict(),

                "acc": test_acc,

                "loss": test_loss

            }

            print("save model : ", checkpoint["acc"])

            os.makedirs("./checkpoints", exist_ok=True)  # create the save directory if needed

            torch.save(checkpoint, "./checkpoints/vit_model.pth")

    print("-" * 30)

    print(f"train loss epoch : {epoch + 1} : {train_loss:.4f}")

    print(f"test loss epoch : {epoch + 1} : {test_loss:.4f}")

    print(f"train acc epoch : {epoch + 1} : {train_acc:.4f}")

    print(f"test acc epoch : {epoch + 1} : {test_acc:.4f}")

    print("-" * 30)

stop = timeit.default_timer()

print(f"training time : {stop - start:.2f}")

# patcher = PatchEmbedding(in_channels=in_channels, patch_size=patch_size,

#                          embed_dim=embed_dim, num_patches=num_patches, dropout=dropout)

# for idx, (x, y) in enumerate(train_dataloader):

#     print("before : ", idx, x.shape, y.shape)

#     x = patcher(x)

#     print("after : ", x.shape)

#     break

A few plots of the training results:


