使用Vision Transformer(ViT)对CIFAR-10数据集进行分类
秋灯冷雨 2024-09-01 15:01:02 阅读 54
多的不说,直接放码过来:
ViT的主要思想是将图片切割为多个patch块,每个patch的边长为patch_size,patch的数量为(img_size/patch_size)^2。
将每个patch展平并映射为一维向量(代码中用卷积核大小和步长均为patch_size的Conv2d实现),拼接分类编码、加上位置编码后传入Transformer的编码器,得到提取特征后的向量序列,这样就和NLP里面的任务一样了!
import os
import timeit

import torch
import torch.nn as nn
import torchvision
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch tokens plus a class token.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a convolution whose kernel size and stride both equal
    ``patch_size``; each patch becomes one ``embed_dim``-sized token. A
    learnable class token is prepended, a learnable position embedding is
    added, and dropout is applied.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of one square patch (conv kernel and stride).
        embed_dim: Size of each token vector (conv output channels).
        num_patches: Number of patches per image; the position embedding
            holds ``num_patches + 1`` entries (one extra for the class token).
        dropout: Dropout probability applied to the embedded sequence.
    """

    def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout):
        super().__init__()
        # Conv2d with kernel == stride embeds every patch independently;
        # Flatten(2) merges the two spatial dims into one "patch" dim.
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            nn.Flatten(2),
        )
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)
        # NOTE: the attribute name (including its spelling) is part of the
        # state_dict keys, so it is kept as-is for checkpoint compatibility.
        self.postion_embedding = nn.Parameter(
            torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch; ``x.shape[0]`` is used as the batch size and the
                tensor is fed to ``Conv2d``, so it is expected to be
                ``[batch_size, in_channels, height, width]``.

        Returns:
            Tensor of shape ``[batch_size, patches + 1, embed_dim]`` where
            index 0 along dim 1 is the class token. ``patches`` is expected
            to equal ``num_patches`` so the position embedding lines up.
        """
        # One class token per sample: [batch_size, 1, embed_dim]
        cls_token = self.cls_token.expand(x.shape[0], 1, -1)
        # [batch_size, embed_dim, patches] -> [batch_size, patches, embed_dim]
        x = self.patcher(x).permute(0, 2, 1)
        # Prepend the class token along the sequence dimension.
        x = torch.cat([cls_token, x], dim=1)
        # Position embedding is [1, num_patches + 1, embed_dim]; broadcast over batch.
        x = x + self.postion_embedding
        x = self.dropout(x)
        return x
class Vit(nn.Module):
    """Vision Transformer classifier: patch embedding -> encoder stack -> head.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of one square patch.
        embed_dim: Token size, used as ``d_model`` of the encoder layers.
        num_patches: Number of patches per image.
        dropout: Dropout probability, used in the patch embedding and in
            every encoder layer.
        num_head: Number of attention heads per encoder layer.
        activation: Activation passed to ``nn.TransformerEncoderLayer``.
        num_encoders: Number of stacked encoder layers.
        num_class: Number of output classes (size of the logits vector).
    """

    def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout,
                 num_head, activation, num_encoders, num_class):
        super().__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, embed_dim, num_patches, dropout)
        # batch_first=True: tensors are [batch, sequence, feature];
        # norm_first=True: LayerNorm runs before attention / feed-forward.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_head,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        # Build the full encoder out of `num_encoders` copies of the layer.
        self.encoder_layers = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)
        # Classification head applied to a single token vector.
        self.MLP = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_class),
        )

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch accepted by ``PatchEmbedding``.

        Returns:
            Tensor of shape ``[batch_size, num_class]`` with unnormalized
            class scores.
        """
        x = self.patch_embedding(x)
        x = self.encoder_layers(x)
        # Only the first token (the class token) is fed to the head.
        x = self.MLP(x[:, 0, :])
        return x
# ---- Input geometry --------------------------------------------------------
img_size = 32
in_channels = 3
patch_size = 8
# One token per patch: the image is tiled by (img_size // patch_size) patches per side.
num_patches = (img_size // patch_size) * (img_size // patch_size)
# Token size is set to the number of raw values in one patch.
embed_dim = in_channels * patch_size * patch_size

# ---- Model -----------------------------------------------------------------
num_head = 8
num_encoders = 10
num_classes = 10
activation = "gelu"
dropout = 0.01

# ---- Optimisation ----------------------------------------------------------
epochs = 50
batch_size = 64
learning_rate = 1e-4
weight_dacay = 1e-4
betas = (0.9, 0.999)

# ---- Runtime ---------------------------------------------------------------
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Training pipeline: random rotation of up to 15 degrees as augmentation,
# then tensor conversion and normalization with mean 0.5 / std 0.5.
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)

# Evaluation pipeline: same conversion and normalization, no augmentation.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)

# CIFAR-10 splits are read from "../../datas" (no download is requested here).
train_dataset = torchvision.datasets.CIFAR10(
    root="../../datas",
    train=True,
    transform=train_transform,
)
test_dataset = torchvision.datasets.CIFAR10(
    root="../../datas",
    train=False,
    transform=test_transform,
)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# NOTE(review): the evaluation loader is shuffled as well; kept as in the original.
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
# Build the model, loss and optimizer from the hyperparameters defined above.
model = Vit(in_channels, patch_size, embed_dim, num_patches, dropout,
            num_head, activation, num_encoders, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimzer = optim.Adam(model.parameters(), lr=learning_rate, betas=betas, weight_decay=weight_dacay)

checkpoint_path = "./checkpoints/vit_model.pth"

start = timeit.default_timer()
best_acc = 0
print(f"training on : {device}")
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_running_loss = 0
    train_n_batches = 0
    train_n_sum = 0
    train_n_correct = 0
    for X, y in tqdm(train_dataloader, position=0, leave=True):
        X = X.to(device)
        y = y.to(device)
        y_pred = model(X)
        y_pred_label = torch.argmax(y_pred, dim=1)

        loss = criterion(y_pred, y)
        optimzer.zero_grad()
        loss.backward()
        optimzer.step()

        train_running_loss += loss.item()
        train_n_batches += 1
        train_n_sum += X.size(0)
        train_n_correct += (y == y_pred_label).sum().item()
    # Average over the number of batches actually seen; max(..., 1) keeps an
    # empty loader from raising ZeroDivisionError.
    train_loss = train_running_loss / max(train_n_batches, 1)
    train_acc = train_n_correct / max(train_n_sum, 1)

    # ---- evaluation pass (no gradients) ----
    model.eval()
    val_running_loss = 0
    test_n_batches = 0
    test_n_sum = 0
    test_n_correct = 0
    with torch.no_grad():
        for X, y in tqdm(test_dataloader, position=0, leave=True):
            X = X.to(device)
            y = y.to(device)
            y_pred = model(X)
            y_pred_label = torch.argmax(y_pred, dim=1)

            loss = criterion(y_pred, y)
            val_running_loss += loss.item()
            test_n_batches += 1
            test_n_sum += X.size(0)
            test_n_correct += (y == y_pred_label).sum().item()
    test_loss = val_running_loss / max(test_n_batches, 1)
    test_acc = test_n_correct / max(test_n_sum, 1)

    # Keep only the checkpoint with the best test accuracy so far.
    if test_acc > best_acc:
        best_acc = test_acc
        checkpoint = {
            'state': model.state_dict(),
            "acc": test_acc,
            "loss": test_loss
        }
        print("save model : ", checkpoint['acc'])
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(checkpoint, checkpoint_path)

    # ---- per-epoch report ----
    print("-" * 30)
    print(f"train loss epoch : {epoch + 1} : {train_loss:.4f}")
    print(f"test loss epoch : {epoch + 1} : {test_loss:.4f}")
    print(
        f"train acc epoch : {epoch + 1} : {train_acc:.4f}"
    )
    print(
        f"test acc epoch : {epoch + 1} : {test_acc:.4f}"
    )
    print("-" * 30)

stop = timeit.default_timer()
print(f"training time : {stop - start:.2f}")
# Debug snippet (disabled): push one training batch through PatchEmbedding and print the shapes.
# patcher = PatchEmbedding(in_channels=in_channels, patch_size=patch_size,
#                          embed_dim=embed_dim, num_patches=num_patches, dropout=dropout)
# for idx, (x, y) in enumerate(train_dataloader):
#     print("before : ", idx, x.shape, y.shape)
#     x = patcher(x)
#     print("after : ", x.shape)
#     break
附带几个训练效果图:
声明
本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。