[PyTorch Tutorial for Beginners] Part 8: Using Image Data Augmentation to Improve Accuracy on the CIFAR-10 Dataset

Date: 2019-09-08 14:36:52

@Author: Runsen

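Last time, the image classification model built with PyTorch on the CIFAR-10 dataset reached an accuracy of 60%. To improve on that, this part applies common image data augmentation techniques from torchvision's transforms module.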

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import os
import warnings
from matplotlib import pyplot as plt

warnings.filterwarnings('ignore')

# use the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Loading the dataset

# NOTE: transform_train and transform_test are defined in the transform
# section further down; define them before running this code.

# number of images in one forward and backward pass
batch_size = 128

# number of subprocesses used for data loading
# (normally set this to 0 if your OS is Windows)
num_workers = 2

train_dataset = datasets.CIFAR10('./data/CIFAR10/', train=True, download=True, transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# validation reuses the training images, but with the test-time transform (no augmentation)
val_dataset = datasets.CIFAR10('./data/CIFAR10', train=True, transform=transform_test)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

test_dataset = datasets.CIFAR10('./data/CIFAR10', train=False, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# declare classes in CIFAR10
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
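As a quick sanity check on the loader settings, the number of batches per epoch can be worked out by hand. The sketch below assumes the standard CIFAR-10 split (50,000 training and 10,000 test images) and the DataLoader default of keeping the last, smaller batch; the helper name batches_per_epoch is made up for this illustration.

```python
import math

# Standard CIFAR-10 split sizes (assumption stated in the text above).
TRAIN_SIZE = 50_000
TEST_SIZE = 10_000
batch_size = 128  # same value as in the loading code


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches yielded for one full pass over the data."""
    if drop_last:
        # incomplete final batch is discarded
        return num_samples // batch_size
    # incomplete final batch is kept (DataLoader default)
    return math.ceil(num_samples / batch_size)


train_batches = batches_per_epoch(TRAIN_SIZE, batch_size)
test_batches = batches_per_epoch(TEST_SIZE, batch_size)
# size of the final, incomplete training batch
last_batch_size = TRAIN_SIZE - (train_batches - 1) * batch_size

print(train_batches, test_batches, last_batch_size)  # 391 79 80
```

Note that len(dataset) // batch_size gives 390 for the training set, one less than the number of batches the loader actually yields.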

The previous transform only performed scaling and normalization. Here, RandomCrop and RandomHorizontalFlip are added for the training set.

# define the transforms: augmentation + normalization for training,
# normalization only for validation and testing
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),   # pad by 4 pixels, then crop a random 32x32 window
    transforms.RandomHorizontalFlip(),      # mirror the image left-to-right with probability 0.5
    transforms.ToTensor(),                  # converting images to tensor
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    # if the image dataset is black and white, mean and std can be just one number each
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
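To make the two augmentations less abstract, here is a minimal pure-Python sketch of the idea on a tiny single-channel image stored as a list of rows. It is not torchvision's implementation; the function names and the 4x4 toy image are assumptions for illustration. It mimics the default behavior of zero padding before the crop and a flip probability of 0.5.

```python
import random


def random_crop_with_padding(img, size, padding, rng):
    """Zero-pad a 2D image on all sides, then cut out a random size x size window."""
    width = len(img[0])
    blank = [0] * (width + 2 * padding)
    padded = [blank[:] for _ in range(padding)]
    padded += [[0] * padding + list(row) + [0] * padding for row in img]
    padded += [blank[:] for _ in range(padding)]
    # choose the top-left corner of the crop window uniformly at random
    top = rng.randint(0, len(padded) - size)
    left = rng.randint(0, len(padded[0]) - size)
    return [row[left:left + size] for row in padded[top:top + size]]


def random_horizontal_flip(img, p, rng):
    """Mirror the image left-to-right with probability p."""
    if rng.random() < p:
        return [row[::-1] for row in img]
    return img


rng = random.Random(0)
# a 4x4 toy "image" standing in for one channel of a 32x32 CIFAR-10 picture
image = [[r * 4 + c + 1 for c in range(4)] for r in range(4)]

augmented = random_horizontal_flip(random_crop_with_padding(image, 4, 1, rng), 0.5, rng)
for row in augmented:
    print(row)
```

Because the crop window and the flip are drawn anew every time an image is loaded, the network sees a slightly shifted or mirrored version of each picture in every epoch, while the output size stays the same as the input size.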

Visualizing sample images

# function that will be used for visualizing the data
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    plt.imshow(np.transpose(img, (1, 2, 0)))  # convert from tensor layout (C, H, W) to (H, W, C)

# obtain one batch of images from the train dataset
dataiter = iter(train_loader)
images, labels = next(dataiter)  # dataiter.next() no longer works in recent PyTorch versions
images = images.numpy()  # convert images to numpy for display

# plot the images in one batch with the corresponding labels
fig = plt.figure(figsize=(25, 4))

# display images
for idx in np.arange(10):
    ax = fig.add_subplot(1, 10, idx + 1, xticks=[], yticks=[])
    imshow(images[idx])
    ax.set_title(classes[labels[idx]])
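The line img / 2 + 0.5 in imshow undoes the Normalize step. A small arithmetic sketch, assuming mean 0.5 and std 0.5 as in the transforms (the function names here are only for illustration):

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-channel normalization: (x - mean) / std."""
    return (pixel - mean) / std


def unnormalize(value):
    """Inverse mapping used in imshow: x / 2 + 0.5."""
    return value / 2 + 0.5


# ToTensor yields pixel values in [0, 1]; Normalize maps them to [-1, 1]
for pixel in (0.0, 0.25, 0.5, 1.0):
    z = normalize(pixel)
    print(pixel, "->", z, "->", unnormalize(z))
```

With these settings 0.0 maps to -1.0 and 1.0 stays at 1.0, so the network input lies in [-1, 1] and unnormalizing brings it back to the displayable range [0, 1].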

Building a typical CNN model

# define the CNN architecture
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # shape comments give the output of each layer; O = (N + 2P - F) / S + 1
        self.main = nn.Sequential(
            # input: 3x32x32
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),  # 32x32x32
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32x16x16
            nn.BatchNorm2d(32),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # 64x16x16
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 64x8x8
            nn.BatchNorm2d(64),

            nn.Conv2d(64, 128, 3, padding=1),  # 128x8x8
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 128x4x4
            nn.BatchNorm2d(128),
        )

        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        # Conv and Pooling layers
        x = self.main(x)
        # Flatten before Fully Connected layers
        x = x.view(-1, 128 * 4 * 4)
        # Fully Connected Layers
        x = self.fc(x)
        return x

cnn = CNN().to(device)
print(cnn)
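The input size 128*4*4 of the first Linear layer follows from the output-size formula O = (N + 2P - F) / S + 1 applied layer by layer. A short sketch that traces the shapes, assuming square inputs (the helper conv_output_size is a name chosen for this illustration):

```python
def conv_output_size(n, kernel, padding=0, stride=1):
    """O = (N + 2P - F) // S + 1 for a square input of side n."""
    return (n + 2 * padding - kernel) // stride + 1


size = 32
shapes = [(3, size, size)]  # CIFAR-10 input: 3 channels, 32x32 pixels
for out_channels in (32, 64, 128):
    # Conv2d with a 3x3 kernel and padding 1 keeps the spatial size
    size = conv_output_size(size, kernel=3, padding=1, stride=1)
    shapes.append((out_channels, size, size))
    # MaxPool2d(2, 2) halves the spatial size
    size = conv_output_size(size, kernel=2, padding=0, stride=2)
    shapes.append((out_channels, size, size))

for shape in shapes:
    print(shape)

channels, height, width = shapes[-1]
flat_features = channels * height * width
print(flat_features)  # 2048
```

ReLU and BatchNorm2d do not change the shape, so they are left out of the trace. The final feature map is 128x4x4, which is 2048 values per image after flattening.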

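torch.nn.CrossEntropyLoss is the loss function used for this classification model. It takes the raw, unnormalized outputs (logits) of the network, applies log-softmax internally to turn them into class probabilities between 0 and 1, and returns the negative log-likelihood of the correct class. For that reason the model has no softmax layer at the end.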
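As a rough illustration of what this loss computes for a single sample, the following pure-Python sketch applies softmax to a list of logits and takes the negative log of the probability assigned to the target class. The function names and the example logits are assumptions for this sketch; PyTorch additionally averages the result over the batch by default.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1 (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability of the target class for one sample."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# hypothetical logits for a 10-class problem; class 3 has the highest score
logits = [0.1, -1.2, 0.3, 4.0, 0.0, -0.5, 0.2, 1.0, -2.0, 0.4]
probs = softmax(logits)

print(round(sum(probs), 6))       # probabilities sum to 1
print(cross_entropy(logits, 3))   # small loss: the correct class is favored
print(cross_entropy(logits, 8))   # large loss: the correct class is unlikely
```

A model that outputs identical logits for all 10 classes gets a loss of ln(10), roughly 2.30, which is a useful reference point when reading the first training losses.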

Training the model

# Hyperparameters
learning_rate = 0.001

# one loss value is stored per batch
train_losses = []
val_losses = []

# Loss function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)

# define the train function that trains the model using the CIFAR10 dataset
def train(model, epoch, num_epochs):
    model.train()
    total_batch = len(train_loader)

    for i, (images, labels) in enumerate(train_loader):
        X = images.to(device)
        Y = labels.to(device)

        ### forward pass and loss calculation
        # forward pass
        pred = model(X)
        # calculation of loss value
        cost = criterion(pred, Y)

        ### backward pass and optimization
        # gradient initialization
        optimizer.zero_grad()
        # backward pass
        cost.backward()
        # parameter update
        optimizer.step()

        # training stats (running average over all batches seen so far)
        if (i + 1) % 100 == 0:
            print('Train, Epoch [%d/%d], Iter [%d/%d], Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, total_batch, np.average(train_losses)))

        train_losses.append(cost.item())

# define the validation function that validates the model using the CIFAR10 dataset
def validation(model, epoch, num_epochs):
    model.eval()
    total_batch = len(val_loader)

    for i, (images, labels) in enumerate(val_loader):
        X = images.to(device)
        Y = labels.to(device)

        # no gradients are needed during validation
        with torch.no_grad():
            pred = model(X)
            cost = criterion(pred, Y)

        if (i + 1) % 100 == 0:
            print('Validation, Epoch [%d/%d], Iter [%d/%d], Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, total_batch, np.average(val_losses)))

        val_losses.append(cost.item())

def plot_losses(train_losses, val_losses):
    plt.figure(figsize=(5, 5))
    plt.plot(train_losses, label='Train', alpha=0.5)
    plt.plot(val_losses, label='Validation', alpha=0.5)
    plt.xlabel('Iterations (batches)')  # the lists hold one value per batch, not per epoch
    plt.ylabel('Losses')
    plt.legend()
    plt.grid(True)  # the b= keyword was removed in recent matplotlib versions
    plt.title('CIFAR 10 Train/Val Losses')
    plt.show()

num_epochs = 20

for epoch in range(num_epochs):
    train(cnn, epoch, num_epochs)
    validation(cnn, epoch, num_epochs)
    # save a checkpoint after every epoch
    torch.save(cnn.state_dict(), './data/Tutorial_3_CNN_Epoch_{}.pkl'.format(epoch + 1))

plot_losses(train_losses, val_losses)
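The lists train_losses and val_losses collect one loss value per batch, so with 391 batches per epoch over 20 epochs each list holds 7,820 points. If one point per epoch is preferred for plotting, the per-batch values can be averaged in chunks. A minimal sketch, where epoch_means is a hypothetical helper and the toy numbers are made up:

```python
def epoch_means(losses, batches_per_epoch):
    """Average consecutive chunks of per-batch losses, one chunk per epoch."""
    means = []
    for start in range(0, len(losses), batches_per_epoch):
        chunk = losses[start:start + batches_per_epoch]
        means.append(sum(chunk) / len(chunk))
    return means


# toy example: 3 "epochs" of 2 batches each
toy_losses = [2.0, 1.0, 1.5, 0.5, 1.0, 0.0]
print(epoch_means(toy_losses, 2))  # [1.5, 1.0, 0.5]
```

In the tutorial's setting this could be called as epoch_means(train_losses, len(train_loader)) before plotting, which gives 20 points per curve.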

Testing the model

def test(model):
    # declare that the model is about to evaluate
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        # iterate over the test set one image at a time
        for images, labels in test_dataset:
            images = images.unsqueeze(0).to(device)  # add a batch dimension

            # forward pass
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # index of the highest score

            total += 1
            correct += (predicted == labels).sum().item()

    print("Accuracy of Test Images: %f %%" % (100 * float(correct) / total))

test(cnn)

With image data augmentation, the model's test accuracy improved from 60% to 84%.

Next, check which classes the model performs well on:

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

cnn.eval()
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = cnn(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)

        # count every sample in the batch (the last batch may be smaller than batch_size)
        for i in range(len(labels)):
            label = labels[i].item()
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
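The per-class tally boils down to counting, for each true label, how many samples were seen and how many of them were predicted correctly. A small sketch on made-up predictions, independent of PyTorch (per_class_accuracy and the toy data are assumptions for illustration):

```python
from collections import Counter


def per_class_accuracy(predictions, labels, class_names):
    """Fraction of correctly predicted samples for each true class."""
    correct = Counter()
    total = Counter()
    for pred, label in zip(predictions, labels):
        total[label] += 1
        if pred == label:
            correct[label] += 1
    # skip classes that never occur to avoid dividing by zero
    return {name: correct[i] / total[i]
            for i, name in enumerate(class_names) if total[i] > 0}


toy_classes = ('plane', 'car', 'bird')
toy_labels = [0, 0, 1, 1, 2, 2, 2, 2]
toy_predictions = [0, 1, 1, 1, 2, 2, 0, 2]

accuracy = per_class_accuracy(toy_predictions, toy_labels, toy_classes)
for name, value in accuracy.items():
    print('Accuracy of %5s : %2d %%' % (name, 100 * value))
```

Every sample must contribute to the counts; tallying only a fixed number of samples per batch would base the estimate on a small fraction of the test set.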
