
Keras source code for training on the CIFAR-10 dataset

Date: 2019-09-01 09:54:20


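Preface

Classifying the CIFAR-10 dataset is a public benchmark problem in machine learning. The task is to classify a set of 32x32 RGB images that cover 10 categories:

airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.

First, a look at the CIFAR-10 dataset:

It contains five training batch files and one test batch file. Most tutorials online require five separate files; here the training code is implemented in a single file. The code requires the CIFAR-10 data to be downloaded in advance, specifically the "CIFAR-10 python version" archive.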
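In the Python version of the dataset, each batch stores every image as one flat row of 3072 bytes: 1024 red values, then 1024 green, then 1024 blue, each plane in row-major order. The following is a minimal sketch of how such rows turn into image arrays. It uses a small synthetic array rather than a real batch file, and the variable names are illustrative only.

```python
import numpy as np

# Synthetic stand-in for the 'data' array of one batch file: 4 fake images
# instead of 10000, each a flat row of 3 * 32 * 32 = 3072 uint8 values.
flat = (np.arange(4 * 3072) % 256).astype(np.uint8).reshape(4, 3072)

# Rows are stored channel by channel, so reshape to (N, channels, height, width).
chw = flat.reshape(flat.shape[0], 3, 32, 32)

# Move the channel axis last: (N, height, width, channels). This is the layout
# used by the 'channels_last' image data format and by plt.imshow.
hwc = chw.transpose(0, 2, 3, 1)

print(chw.shape)  # (4, 3, 32, 32)
print(hwc.shape)  # (4, 32, 32, 3)

# Pixel (row 5, column 7) of the blue plane of image 1 comes from this flat offset.
assert hwc[1, 5, 7, 2] == flat[1, 2 * 1024 + 5 * 32 + 7]
```

The transpose only changes how the same bytes are indexed, so no pixel values are altered.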

Source code

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 24 09:43:25

@author: new
"""
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras.backend as K
from six.moves import cPickle
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.preprocessing.image import ImageDataGenerator

# Use the first GPU and let TensorFlow allocate GPU memory on demand.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
K.set_session(sess)  # make Keras use this session, otherwise the options have no effect

# [0] Hyperparameters
batch_size = 32
num_classes = 10
epochs = 5
data_augmentation = True


def load_batch(fpath, label_key='labels'):
    """Load one CIFAR-10 batch file and return (data, labels)."""
    with open(fpath, 'rb') as f:
        if sys.version_info < (3,):
            d = cPickle.load(f)
        else:
            d = cPickle.load(f, encoding='bytes')
            # Under Python 3 the keys are bytes; decode them to str.
            d_decoded = {}
            for k, v in d.items():
                d_decoded[k.decode('utf8')] = v
            d = d_decoded
    data = d['data']
    labels = d[label_key]
    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels


def load_data():
    """Load the five training batches and the test batch from disk."""
    dirname = 'C:/Users/new/Desktop/cifar-10-batches-py'
    num_train_samples = 50000

    x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
    y_train = np.empty((num_train_samples,), dtype='uint8')

    for i in range(1, 6):
        fpath = os.path.join(dirname, 'data_batch_' + str(i))
        (x_train[(i - 1) * 10000: i * 10000, :, :, :],
         y_train[(i - 1) * 10000: i * 10000]) = load_batch(fpath)

    fpath = os.path.join(dirname, 'test_batch')
    x_test, y_test = load_batch(fpath)

    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))

    if K.image_data_format() == 'channels_last':
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)

    return (x_train, y_train), (x_test, y_test)


(x_train, y_train), (x_test, y_test) = load_data()
print('x_train shape:', x_train.shape)
print('y_train shape:', y_train.shape)
print('x_test shape:', x_test.shape)
print('y_test shape:', y_test.shape)

plt.figure(1)
plt.imshow(x_train[0])  # show the first training image
plt.figure(2)
plt.imshow(x_test[0])   # show the first test image

# [3] Convert the labels to one-hot encoding
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# [4] Build the deep CNN as a Sequential model
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

# Print the model summary. summary() prints the table itself and returns None,
# which is why the log shows a trailing "None".
print(model.summary())

# [5] Compile the model
opt = keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)  # initialise an RMSprop optimizer
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# [6] Data preprocessing / augmentation and model training
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test))
else:
    print('Using real-time data augmentation.')
    # ImageDataGenerator yields batches of image data. During training it keeps
    # generating data until the requested number of epochs is reached.
    # Image generation (CPU) and training (GPU) run in parallel.
    datagen = ImageDataGenerator(
        featurewise_center=False,
        samplewise_center=False,
        featurewise_std_normalization=False,
        samplewise_std_normalization=False,
        zca_whitening=False,
        rotation_range=0,         # range of random rotation, in degrees
        width_shift_range=0.1,    # range of random horizontal shift
        height_shift_range=0.1,   # range of random vertical shift
        horizontal_flip=True,     # random horizontal flip
        vertical_flip=False)

    # Compute sample statistics for preprocessing (e.g. centering, normalization).
    datagen.fit(x_train)

    # datagen.flow() keeps yielding one batch at a time for training.
    model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        workers=4)

# [7] Save the model together with its weights
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_cifar10_trained_model.h5'
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# [8] Evaluate the model on the test set
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
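Step [3] of the script converts integer labels to one-hot vectors with keras.utils.to_categorical. As a rough illustration of what that conversion produces, here is a NumPy-only sketch. The helper name one_hot is hypothetical and not part of Keras, and it assumes the labels are integers in the range 0 to num_classes - 1.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Return an (N, num_classes) float32 matrix with a single 1.0 per row."""
    labels = np.asarray(labels).reshape(-1)  # accept shape (N,) or (N, 1)
    encoded = np.zeros((labels.shape[0], num_classes), dtype="float32")
    # Set column labels[i] of row i to 1.0 for every sample at once.
    encoded[np.arange(labels.shape[0]), labels] = 1.0
    return encoded

y = np.array([[3], [8], [0]])   # same (N, 1) layout as y_train in the script
print(one_hot(y, 10))
```

Each row contains exactly one 1.0, placed at the index of the class label, which is the target format that the categorical_crossentropy loss expects.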

Experimental results

Using TensorFlow backend.
x_train shape: (50000, 32, 32, 3)
y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3)
y_test shape: (10000, 1)
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 32, 32, 32)        896
_________________________________________________________________
activation_1 (Activation)    (None, 32, 32, 32)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 30, 30, 32)        9248
_________________________________________________________________
activation_2 (Activation)    (None, 30, 30, 32)        0
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 15, 15, 32)        0
_________________________________________________________________
dropout_1 (Dropout)          (None, 15, 15, 32)        0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 15, 15, 64)        18496
_________________________________________________________________
activation_3 (Activation)    (None, 15, 15, 64)        0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 13, 13, 64)        36928
_________________________________________________________________
activation_4 (Activation)    (None, 13, 13, 64)        0
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 6, 6, 64)          0
_________________________________________________________________
dropout_2 (Dropout)          (None, 6, 6, 64)          0
_________________________________________________________________
flatten_1 (Flatten)          (None, 2304)              0
_________________________________________________________________
dense_1 (Dense)              (None, 512)               1180160
_________________________________________________________________
activation_5 (Activation)    (None, 512)               0
_________________________________________________________________
dropout_3 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_2 (Dense)              (None, 10)                5130
_________________________________________________________________
activation_6 (Activation)    (None, 10)                0
=================================================================
Total params: 1,250,858
Trainable params: 1,250,858
Non-trainable params: 0
_________________________________________________________________
None
Using real-time data augmentation.
Epoch 1/5
1563/1563 [==============================] - 102s 65ms/step - loss: 1.8342 - acc: 0.3229 - val_loss: 1.5518 - val_acc: 0.4325
Epoch 2/5
1563/1563 [==============================] - 120s 77ms/step - loss: 1.5533 - acc: 0.4310 - val_loss: 1.4069 - val_acc: 0.4883
Epoch 3/5
1563/1563 [==============================] - 108s 69ms/step - loss: 1.4322 - acc: 0.4846 - val_loss: 1.2653 - val_acc: 0.5508
Epoch 4/5
1563/1563 [==============================] - 107s 68ms/step - loss: 1.3429 - acc: 0.5180 - val_loss: 1.1613 - val_acc: 0.5869
Epoch 5/5
1563/1563 [==============================] - 107s 69ms/step - loss: 1.2704 - acc: 0.5454 - val_loss: 1.1002 - val_acc: 0.6138
Saved trained model at C:\Users\new\Desktop\chapter_2\saved_models\keras_cifar10_trained_model.h5
10000/10000 [==============================] - 6s 611us/step
Test loss: 1.1002309656143188
Test accuracy: 0.6138
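The parameter counts in the summary and the 1563 steps per epoch can be checked by hand. The sketch below redoes that arithmetic in plain Python. The helper names conv2d_params and dense_params are hypothetical, and the layer settings (3x3 kernels, one bias per output unit, batch size 32, 50000 training images) are taken from the script and log.

```python
import math

def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # Each filter has kernel_h * kernel_w * in_channels weights plus one bias.
    return (kernel_h * kernel_w * in_channels + 1) * filters

def dense_params(in_units, out_units):
    # Each output unit has in_units weights plus one bias.
    return (in_units + 1) * out_units

layer_params = {
    "conv2d_1": conv2d_params(3, 3, 3, 32),     # RGB input -> 32 filters
    "conv2d_2": conv2d_params(3, 3, 32, 32),
    "conv2d_3": conv2d_params(3, 3, 32, 64),
    "conv2d_4": conv2d_params(3, 3, 64, 64),
    "dense_1": dense_params(6 * 6 * 64, 512),   # flattened 6x6x64 = 2304 inputs
    "dense_2": dense_params(512, 10),
}
total_params = sum(layer_params.values())

# 50000 training images in batches of 32; the last batch is partial.
steps_per_epoch = math.ceil(50000 / 32)

for name, count in layer_params.items():
    print(name, count)
print("total:", total_params)              # total: 1250858
print("steps per epoch:", steps_per_epoch) # steps per epoch: 1563
```

Almost all of the parameters sit in dense_1, which is the first place to look if the model needs to be made smaller.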

Because the model was trained for only 5 epochs, accuracy on the test set is not very high. Training for more epochs is worth trying.

Loading the model for prediction

from keras.models import load_model

model = load_model('C:/Users/new/Desktop/chapter_2/saved_models/keras_cifar10_trained_model.h5')
print('test after load: ', model.predict(x_test[0:2]))

Result of the prediction:

test after load: [[0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00
  0.00000000e+00 2.44679976e-34 0.00000000e+00 0.00000000e+00
  1.05497485e-29 0.00000000e+00]
 [0.00000000e+00 1.46545753e-08 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  1.00000000e+00 0.00000000e+00]]

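Each row is a vector of softmax probabilities over the 10 classes (here it is almost one-hot). The index of the largest entry is the predicted class.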
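To turn those probability rows into class names, take the argmax of each row and look it up in the label list. A minimal sketch follows, using the two rows printed above. The list CLASS_NAMES is written out by hand here in the standard CIFAR-10 label order (index 0 is airplane, index 9 is truck) rather than read from the dataset's metadata file.

```python
import numpy as np

# Standard CIFAR-10 label order, written out by hand for this sketch.
CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

# The two rows returned by model.predict(x_test[0:2]) in the output above.
probs = np.array([
    [0.0, 0.0, 0.0, 1.0, 0.0, 2.44679976e-34, 0.0, 0.0, 1.05497485e-29, 0.0],
    [0.0, 1.46545753e-08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
])

# Index of the largest probability in each row is the predicted class id.
predicted_ids = probs.argmax(axis=1)
predicted_names = [CLASS_NAMES[i] for i in predicted_ids]

print(predicted_ids)    # [3 8]
print(predicted_names)  # ['cat', 'ship']
```

With a live model, probs would simply be the array returned by model.predict, and the same two lines map it to labels.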
